Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,23 +321,40 @@ def get_activation_config() -> dict[str, Any]:

@staticmethod
def get_memreader_config() -> dict[str, Any]:
"""Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model)."""
return {
"backend": "openai",
"config": {
"model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
"temperature": 0.6,
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
"top_p": 0.95,
"top_k": 20,
"api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
# Default to OpenAI base URL when env var is not provided to satisfy pydantic
# validation requirements during tests/import.
"api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"),
"remove_think_prefix": True,
},
"""Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model).

When MEMREADER_GENERAL_MODEL is configured (i.e. a separate stable LLM exists),
the backup client is automatically enabled so that primary failures (self-deployed
model) fall back to the general LLM.
"""
config = {
"model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
"temperature": 0.6,
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
"top_p": 0.95,
"top_k": 20,
"api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
# Default to OpenAI base URL when env var is not provided to satisfy pydantic
# validation requirements during tests/import.
"api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"),
"remove_think_prefix": True,
}

general_model = os.getenv("MEMREADER_GENERAL_MODEL")
enable_backup = os.getenv("MEMREADER_ENABLE_BACKUP", "false").lower() == "true"
if general_model and enable_backup:
config["backup_client"] = True
config["backup_model_name_or_path"] = general_model
config["backup_api_key"] = os.getenv(
"MEMREADER_GENERAL_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY")
)
config["backup_api_base"] = os.getenv(
"MEMREADER_GENERAL_API_BASE",
os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
)

return {"backend": "openai", "config": config}

@staticmethod
def get_memreader_general_llm_config() -> dict[str, Any]:
"""Get general LLM configuration for non-chat/doc tasks.
Expand Down
24 changes: 18 additions & 6 deletions src/memos/configs/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ class OpenAILLMConfig(BaseLLMConfig):
default="https://api.openai.com/v1", description="Base URL for OpenAI API"
)
extra_body: Any = Field(default=None, description="extra body")
# Backup-client settings: when backup_client is True, OpenAILLM builds a second
# OpenAI-compatible client from the backup_* fields and retries failed primary
# requests against it. All fields except the flag default to None, so enabling
# the flag without the companion fields yields a client with no credentials —
# callers are expected to set them together (see get_memreader_config).
backup_client: bool = Field(
    default=False,
    description="Whether to enable backup client for fallback on primary failure",
)
# Credentials/endpoint for the fallback client.
backup_api_key: str | None = Field(
    default=None, description="API key for backup OpenAI-compatible endpoint"
)
backup_api_base: str | None = Field(
    default=None, description="Base URL for backup OpenAI-compatible endpoint"
)
# Model used on the fallback endpoint; when None, the primary model name is reused.
backup_model_name_or_path: str | None = Field(
    default=None, description="Model name for backup endpoint"
)
backup_headers: dict[str, Any] | None = Field(
    default=None, description="Default headers for backup client requests"
)


class OpenAIResponsesLLMConfig(BaseLLMConfig):
Expand All @@ -42,22 +58,18 @@ class OpenAIResponsesLLMConfig(BaseLLMConfig):
)


class QwenLLMConfig(OpenAILLMConfig):
    """Config for DashScope (Qwen) via its OpenAI-compatible endpoint.

    Inherits ``api_key``, ``extra_body`` and the ``backup_*`` fields from
    ``OpenAILLMConfig``; only the default base URL differs.
    """

    api_base: str = Field(
        default="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        description="Base URL for Qwen OpenAI-compatible API",
    )


class DeepSeekLLMConfig(OpenAILLMConfig):
    """Config for DeepSeek via its OpenAI-compatible endpoint.

    Inherits ``api_key``, ``extra_body`` and the ``backup_*`` fields from
    ``OpenAILLMConfig``; only the default base URL differs.
    """

    api_base: str = Field(
        default="https://api.deepseek.com",
        description="Base URL for DeepSeek OpenAI-compatible API",
    )


class AzureLLMConfig(BaseLLMConfig):
Expand Down
83 changes: 59 additions & 24 deletions src/memos/llms/openai.py
Original file line number Diff line number Diff line change
def __init__(self, config: OpenAILLMConfig):
    """Create the primary OpenAI client and, if enabled, a backup client.

    Args:
        config: LLM configuration. When ``config.backup_client`` is True, the
            ``backup_*`` fields define a second OpenAI-compatible endpoint
            used as a fallback when the primary request raises (see
            ``generate``).
    """
    self.client = openai.Client(
        api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers
    )
    self.use_backup_client = config.backup_client
    if self.use_backup_client:
        # Second client built entirely from the backup_* config fields;
        # only used on the except path in generate().
        self.backup_client = openai.Client(
            api_key=config.backup_api_key,
            base_url=config.backup_api_base,
            default_headers=config.backup_headers,
        )
        logger.info(
            f"OpenAI LLM instance initialized with backup "
            f"(model={config.backup_model_name_or_path})"
        )
    else:
        # Explicit None so attribute access is always safe.
        self.backup_client = None
        logger.info("OpenAI LLM instance initialized")

def _parse_response(self, response) -> str:
    """Extract text content from a chat completion response.

    Handles, in order: empty choices (returns ""), tool calls (delegated to
    ``self.tool_call_parser``), optional ``reasoning_content`` (wrapped in
    <think> tags unless ``remove_think_prefix`` strips thinking output).

    Args:
        response: An OpenAI chat-completion response object.

    Returns:
        str: The parsed text content; never None.
    """
    if not response.choices:
        logger.warning("OpenAI response has no choices")
        return ""

    message = response.choices[0].message
    tool_calls = getattr(message, "tool_calls", None)
    if isinstance(tool_calls, list) and len(tool_calls) > 0:
        return self.tool_call_parser(tool_calls)

    response_content = message.content
    reasoning_content = getattr(message, "reasoning_content", None)
    if isinstance(reasoning_content, str) and reasoning_content:
        reasoning_content = f"<think>{reasoning_content}</think>"
    if self.config.remove_think_prefix:
        # Fix: content may be None (e.g. a response carrying only refusal or
        # metadata); coalesce to "" like every other exit path instead of
        # passing None into remove_thinking_tags.
        return remove_thinking_tags(response_content or "")
    if reasoning_content:
        return reasoning_content + (response_content or "")
    return response_content or ""

@timed_with_status(
log_prefix="OpenAI LLM",
Expand All @@ -50,29 +82,32 @@ def generate(self, messages: MessageList, **kwargs) -> str:
start_time = time.perf_counter()
logger.info(f"OpenAI LLM Request body: {request_body}")

response = self.client.chat.completions.create(**request_body)

cost_time = time.perf_counter() - start_time
logger.info(
f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}"
)

if not response.choices:
logger.warning("OpenAI response has no choices")
return ""

tool_calls = getattr(response.choices[0].message, "tool_calls", None)
if isinstance(tool_calls, list) and len(tool_calls) > 0:
return self.tool_call_parser(tool_calls)
response_content = response.choices[0].message.content
reasoning_content = getattr(response.choices[0].message, "reasoning_content", None)
if isinstance(reasoning_content, str) and reasoning_content:
reasoning_content = f"<think>{reasoning_content}</think>"
if self.config.remove_think_prefix:
return remove_thinking_tags(response_content)
if reasoning_content:
return reasoning_content + (response_content or "")
return response_content or ""
try:
response = self.client.chat.completions.create(**request_body)
cost_time = time.perf_counter() - start_time
logger.info(
f"Request body: {request_body}, Response from OpenAI: "
f"{response.model_dump_json()}, Cost time: {cost_time}"
)
return self._parse_response(response)
except Exception as e:
if not self.use_backup_client:
raise
logger.warning(
f"Primary LLM request failed with {type(e).__name__}: {e}, "
f"falling back to backup client"
)
backup_body = {
**request_body,
"model": self.config.backup_model_name_or_path or request_body["model"],
}
backup_response = self.backup_client.chat.completions.create(**backup_body)
cost_time = time.perf_counter() - start_time
logger.info(
f"Backup LLM request succeeded, Response: "
f"{backup_response.model_dump_json()}, Cost time: {cost_time}"
)
return self._parse_response(backup_response)

@timed_with_status(
log_prefix="OpenAI LLM Stream",
Expand Down
5 changes: 5 additions & 0 deletions tests/configs/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def test_openai_llm_config():
"remove_think_prefix",
"extra_body",
"default_headers",
"backup_client",
"backup_api_key",
"backup_api_base",
"backup_model_name_or_path",
"backup_headers",
],
)

Expand Down
Loading