Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import json
import time
import uuid
Expand Down Expand Up @@ -32,6 +33,9 @@
convert_messages,
extract_usage,
generate_images,
validate_attachments,
AttachmentValidationError,
_convert_content_part,
)
from tee_gateway.model_registry import get_model_config
from tee_gateway.pricing import compute_session_cost
Expand Down Expand Up @@ -103,6 +107,13 @@ def create_chat_completion(body):
connexion.request.get_json()
)

# Reject attachments the target model can't handle, and enforce the size cap,
# before doing any provider work.
try:
validate_attachments(chat_request.messages, chat_request.model)
except AttachmentValidationError as e:
return {"error": "Invalid attachment", "message": str(e)}, e.status

if chat_request.stream:
return _create_streaming_response(chat_request)
else:
Expand Down Expand Up @@ -901,6 +912,40 @@ def generate():
# ---------------------------------------------------------------------------


def _canonical_user_content(content) -> Any:
"""Canonicalize user-message content for request hashing.

Plain-string content is returned unchanged. For multimodal content (a list of
parts), inline attachment bytes are replaced with a ``sha256`` digest so the
signed request commits to the exact attachment content without bloating the
hashed payload with megabytes of base64. URL / file_id references are kept
verbatim.
"""
if isinstance(content, str):
return content
if not isinstance(content, list):
return str(content)

canonical = []
for part in content:
block = _convert_content_part(part)
if block is None:
continue
if block["type"] == "text":
canonical.append({"type": "text", "text": block.get("text", "")})
continue
entry = {"type": block["type"]}
if "base64" in block:
entry["sha256"] = hashlib.sha256(
block["base64"].encode("utf-8")
).hexdigest()
for key in ("mime_type", "filename", "url", "file_id"):
if block.get(key):
entry[key] = block[key]
canonical.append(entry)
return canonical


def _chat_request_to_dict(chat_request: CreateChatCompletionRequest) -> dict:
"""Serialize a CreateChatCompletionRequest to a canonical dict for hashing."""
messages = []
Expand All @@ -911,7 +956,7 @@ def _chat_request_to_dict(chat_request: CreateChatCompletionRequest) -> dict:
messages.append(
{
"role": "user",
"content": msg.content,
"content": _canonical_user_content(msg.content),
}
)
elif isinstance(msg, ChatCompletionRequestAssistantMessage):
Expand Down
207 changes: 203 additions & 4 deletions tee_gateway/llm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
import logging
from typing import List, Dict, Optional, Any
from typing import List, Dict, Optional, Any, Generator
from functools import lru_cache

import httpx
Expand Down Expand Up @@ -50,6 +50,11 @@
# BytePlus ModelArk OpenAI-compatible endpoint (ap-southeast)
BYTEDANCE_BASE_URL = "https://ark.ap-southeast.bytepluses.com/api/v3"

# Upper bound on the combined size of inline (base64) attachments in a single
# request; it applies to every model. Inline base64 travels inside the
# encrypted payload, so this limits the request size the enclave will accept.
MAX_ATTACHMENT_BYTES = 30 * 1024**2  # 30 MB

# Shared synchronous HTTP clients for each provider.
# Initialized to None; built by set_provider_config() after key injection.
openai_http_client: Optional[httpx.Client] = None
Expand Down Expand Up @@ -325,17 +330,207 @@ def generate_images(model: str, prompt: str, n: int = 1) -> tuple[list[str], int
return images, len(images)


def _parse_data_uri(uri: str) -> Optional[tuple[str, str]]:
"""Parse a ``data:<mime>;base64,<data>`` URI into ``(mime_type, base64_data)``.

Returns ``None`` if the string is not a base64 data URI.
"""
if not isinstance(uri, str) or not uri.startswith("data:"):
return None
try:
header, data = uri.split(",", 1)
except ValueError:
return None
if ";base64" not in header:
return None
mime_type = header[len("data:") :].split(";", 1)[0]
return mime_type, data


def _convert_content_part(part: Any) -> Optional[Dict[str, Any]]:
"""Convert one OpenAI-format content part into a LangChain v1 standard content
block (``text`` / ``image`` / ``file``).

The standard blocks (``langchain_core.messages.content``) are translated into
each provider's native API by the respective ``langchain-<provider>`` package,
so a single representation works for Anthropic, OpenAI, Gemini and xAI. Returns
``None`` for empty or unrecognized parts.
"""
if not isinstance(part, dict):
text = str(part)
return {"type": "text", "text": text} if text else None

ptype = part.get("type")

if ptype == "text":
text = part.get("text", "") or ""
return {"type": "text", "text": text} if text else None

if ptype in ("image_url", "image"):
image_url = part.get("image_url", part)
url = image_url.get("url") if isinstance(image_url, dict) else image_url
if not url:
# Already-standard image block carrying base64 directly.
if part.get("base64"):
block: Dict[str, Any] = {"type": "image", "base64": part["base64"]}
if part.get("mime_type"):
block["mime_type"] = part["mime_type"]
return block
return None
parsed = _parse_data_uri(url)
if parsed:
mime_type, data = parsed
return {"type": "image", "base64": data, "mime_type": mime_type}
return {"type": "image", "url": url}

if ptype in ("file", "input_file"):
file_obj = part.get("file", part)
if not isinstance(file_obj, dict):
file_obj = {}
file_id = file_obj.get("file_id") or part.get("file_id")
if file_id:
return {"type": "file", "file_id": file_id}

filename = file_obj.get("filename") or part.get("filename")
file_data = (
file_obj.get("file_data") or file_obj.get("base64") or part.get("base64")
)
if file_data:
file_mime: Optional[str]
parsed_file = _parse_data_uri(file_data)
if parsed_file:
file_mime, file_b64 = parsed_file
else:
file_mime = part.get("mime_type") or file_obj.get("mime_type")
file_b64 = file_data
block = {"type": "file", "base64": file_b64}
if file_mime:
block["mime_type"] = file_mime
# OpenAI requires a filename for file uploads; carry it through so
# langchain-openai doesn't substitute a placeholder.
if filename:
block["filename"] = filename
return block

file_url = file_obj.get("file_url") or file_obj.get("url") or part.get("url")
if file_url:
block = {"type": "file", "url": file_url}
if filename:
block["filename"] = filename
return block
return None

# Unknown part type: best-effort text extraction.
text = part.get("text", "") or ""
return {"type": "text", "text": text} if text else None


def _normalize_user_content_parts(content: list) -> list:
"""Preserve multimodal user content while tolerating primitive text parts."""
normalized = []
"""Pass OpenAI content parts through to LangChain mostly unchanged.

Text and image parts already convert correctly to every provider's native
API in their OpenAI form, so they are forwarded as-is. Only ``file`` /
``input_file`` parts are rewritten into LangChain standard file blocks: the
raw OpenAI ``{"type": "file", "file": {...}}`` shape is passed straight
through to providers like Anthropic, which expect a ``document`` block and
would otherwise reject it. Primitive (non-dict) parts are wrapped as text.
"""
normalized: List[Any] = []
for part in content:
if isinstance(part, dict):
normalized.append(part)
if part.get("type") in ("file", "input_file"):
block = _convert_content_part(part)
normalized.append(block if block is not None else part)
else:
normalized.append(part)
else:
normalized.append({"type": "text", "text": str(part)})
return normalized


class AttachmentValidationError(ValueError):
    """Signals that a request's attachments break a model-capability or size
    rule.

    The ``status`` attribute holds the HTTP status code the caller should
    respond with (400 unless another value is given).
    """

    def __init__(self, message: str, status: int = 400) -> None:
        self.status = status
        super().__init__(message)


def _decoded_base64_len(b64: str) -> int:
"""Length in bytes of base64-encoded data without decoding it."""
data = b64.split(",", 1)[-1] # tolerate a leftover data: prefix
n = len(data)
padding = data[-2:].count("=") if n >= 2 else 0
return max((n * 3) // 4 - padding, 0)


def get_model_capabilities(model: str) -> Dict[str, Any]:
    """Return the capability profile for a model, or ``{}`` when unavailable.

    The profile is read from the ``profile`` attribute of the chat model
    returned by ``get_chat_model_cached``. Callers look up flags such as
    ``image_inputs`` and ``pdf_inputs`` in the result.

    The lookup is best-effort: any failure to build the model or read its
    profile yields ``{}`` instead of raising. The failure is logged at debug
    level so it can still be diagnosed.

    Args:
        model: Model identifier passed to ``get_chat_model_cached``.

    Returns:
        The model's profile, or an empty dict if it is missing, falsy, or
        could not be obtained.
    """
    try:
        chat = get_chat_model_cached(model, 0.0, 16)
        return getattr(chat, "profile", None) or {}
    except Exception:
        # Deliberately fail open, but leave a trace rather than swallowing the
        # error silently.
        logging.getLogger(__name__).debug(
            "Could not load capability profile for model %r", model, exc_info=True
        )
        return {}


def _iter_content_parts(messages: list) -> Generator[Dict[str, Any], None, None]:
for msg in messages:
content = (
msg.get("content")
if isinstance(msg, dict)
else getattr(msg, "content", None)
)
if isinstance(content, list):
for part in content:
if isinstance(part, dict):
yield part


def validate_attachments(messages: list, model: str) -> None:
    """Enforce per-model modality support and the inline attachment size cap.

    Modality gating fails *open*: an image or file attachment is rejected only
    when the model's capability profile sets ``image_inputs`` / ``pdf_inputs``
    to exactly ``False``. A missing flag, or a model with no profile, never
    blocks a request. The size cap is a hard limit on the summed decoded size
    of all inline (base64) image and file attachments, checked after every
    part has been inspected.

    Args:
        messages: Request messages (dicts or objects with a ``content``
            attribute).
        model: Target model identifier.

    Raises:
        AttachmentValidationError: with status 400 when a modality is
            explicitly unsupported or an inline payload is not a string, and
            with status 413 when the inline total exceeds
            ``MAX_ATTACHMENT_BYTES``.
    """
    caps = get_model_capabilities(model)
    # Block type -> (capability flag from the profile, label used in the error).
    gates = {
        "image": (caps.get("image_inputs"), "image"),
        "file": (caps.get("pdf_inputs"), "document"),
    }

    total_bytes = 0
    for part in _iter_content_parts(messages):
        block = _convert_content_part(part)
        if block is None or block["type"] not in gates:
            continue

        supported, label = gates[block["type"]]
        if supported is False:
            raise AttachmentValidationError(
                f"Model {model!r} does not support {label} attachments."
            )

        if "base64" in block:
            payload = block["base64"]
            # Request bodies are untrusted: a non-string payload (number, list,
            # object) must become a client error, not an unhandled exception.
            if not isinstance(payload, str):
                raise AttachmentValidationError(
                    "Inline attachment data must be a base64-encoded string."
                )
            total_bytes += _decoded_base64_len(payload)

    if total_bytes > MAX_ATTACHMENT_BYTES:
        raise AttachmentValidationError(
            f"Attachments exceed the {MAX_ATTACHMENT_BYTES // (1024 * 1024)} MB limit.",
            status=413,
        )


def convert_messages(messages: list) -> List[Any]:
"""Convert OpenAI-format message objects or dicts to LangChain message objects."""
langchain_messages: List[BaseMessage] = []
Expand Down Expand Up @@ -363,6 +558,10 @@ def convert_messages(messages: list) -> List[Any]:
langchain_messages.append(SystemMessage(content=content))

elif role == "user":
# content may be a string or a list of multimodal content parts
# (text / image / file). Pass parts through as-is (file parts are
# normalized to standard LangChain blocks) so the providers handle
# the native conversion.
if isinstance(content, list):
content = _normalize_user_content_parts(content)
langchain_messages.append(HumanMessage(content=content))
Expand Down
Loading
Loading