diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index 06aa50c65..f24e28559 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -321,23 +321,40 @@ def get_activation_config() -> dict[str, Any]:
@staticmethod
def get_memreader_config() -> dict[str, Any]:
- """Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model)."""
- return {
- "backend": "openai",
- "config": {
- "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
- "temperature": 0.6,
- "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
- "top_p": 0.95,
- "top_k": 20,
- "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
- # Default to OpenAI base URL when env var is not provided to satisfy pydantic
- # validation requirements during tests/import.
- "api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"),
- "remove_think_prefix": True,
- },
+ """Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model).
+
+    When MEMREADER_GENERAL_MODEL is set and MEMREADER_ENABLE_BACKUP is "true",
+    the backup client is enabled so that primary failures (self-deployed model)
+    fall back to the general LLM.
+ """
+ config = {
+ "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
+ "temperature": 0.6,
+ "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
+ "top_p": 0.95,
+ "top_k": 20,
+ "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
+ # Default to OpenAI base URL when env var is not provided to satisfy pydantic
+ # validation requirements during tests/import.
+ "api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"),
+ "remove_think_prefix": True,
}
+ general_model = os.getenv("MEMREADER_GENERAL_MODEL")
+ enable_backup = os.getenv("MEMREADER_ENABLE_BACKUP", "false").lower() == "true"
+ if general_model and enable_backup:
+ config["backup_client"] = True
+ config["backup_model_name_or_path"] = general_model
+ config["backup_api_key"] = os.getenv(
+ "MEMREADER_GENERAL_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY")
+ )
+ config["backup_api_base"] = os.getenv(
+ "MEMREADER_GENERAL_API_BASE",
+ os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ )
+
+ return {"backend": "openai", "config": config}
+
@staticmethod
def get_memreader_general_llm_config() -> dict[str, Any]:
"""Get general LLM configuration for non-chat/doc tasks.
diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py
index 5487d117c..11c39b33c 100644
--- a/src/memos/configs/llm.py
+++ b/src/memos/configs/llm.py
@@ -28,6 +28,22 @@ class OpenAILLMConfig(BaseLLMConfig):
default="https://api.openai.com/v1", description="Base URL for OpenAI API"
)
extra_body: Any = Field(default=None, description="extra body")
+ backup_client: bool = Field(
+ default=False,
+ description="Whether to enable backup client for fallback on primary failure",
+ )
+ backup_api_key: str | None = Field(
+ default=None, description="API key for backup OpenAI-compatible endpoint"
+ )
+ backup_api_base: str | None = Field(
+ default=None, description="Base URL for backup OpenAI-compatible endpoint"
+ )
+ backup_model_name_or_path: str | None = Field(
+ default=None, description="Model name for backup endpoint"
+ )
+ backup_headers: dict[str, Any] | None = Field(
+ default=None, description="Default headers for backup client requests"
+ )
class OpenAIResponsesLLMConfig(BaseLLMConfig):
@@ -42,22 +58,18 @@ class OpenAIResponsesLLMConfig(BaseLLMConfig):
)
-class QwenLLMConfig(BaseLLMConfig):
- api_key: str = Field(..., description="API key for DashScope (Qwen)")
+class QwenLLMConfig(OpenAILLMConfig):
api_base: str = Field(
default="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
description="Base URL for Qwen OpenAI-compatible API",
)
- extra_body: Any = Field(default=None, description="extra body")
-class DeepSeekLLMConfig(BaseLLMConfig):
- api_key: str = Field(..., description="API key for DeepSeek")
+class DeepSeekLLMConfig(OpenAILLMConfig):
api_base: str = Field(
default="https://api.deepseek.com",
description="Base URL for DeepSeek OpenAI-compatible API",
)
- extra_body: Any = Field(default=None, description="Extra options for API")
class AzureLLMConfig(BaseLLMConfig):
diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index 93dac42fb..f6bb4efc1 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -27,7 +27,39 @@ def __init__(self, config: OpenAILLMConfig):
self.client = openai.Client(
api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers
)
- logger.info("OpenAI LLM instance initialized")
+ self.use_backup_client = config.backup_client
+ if self.use_backup_client:
+ self.backup_client = openai.Client(
+ api_key=config.backup_api_key,
+ base_url=config.backup_api_base,
+ default_headers=config.backup_headers,
+ )
+ logger.info(
+ f"OpenAI LLM instance initialized with backup "
+ f"(model={config.backup_model_name_or_path})"
+ )
+ else:
+ self.backup_client = None
+ logger.info("OpenAI LLM instance initialized")
+
+ def _parse_response(self, response) -> str:
+ """Extract text content from a chat completion response."""
+ if not response.choices:
+ logger.warning("OpenAI response has no choices")
+ return ""
+
+ tool_calls = getattr(response.choices[0].message, "tool_calls", None)
+ if isinstance(tool_calls, list) and len(tool_calls) > 0:
+ return self.tool_call_parser(tool_calls)
+ response_content = response.choices[0].message.content
+ reasoning_content = getattr(response.choices[0].message, "reasoning_content", None)
+ if isinstance(reasoning_content, str) and reasoning_content:
+ reasoning_content = f"{reasoning_content}"
+ if self.config.remove_think_prefix:
+ return remove_thinking_tags(response_content)
+ if reasoning_content:
+ return reasoning_content + (response_content or "")
+ return response_content or ""
@timed_with_status(
log_prefix="OpenAI LLM",
@@ -50,29 +82,32 @@ def generate(self, messages: MessageList, **kwargs) -> str:
start_time = time.perf_counter()
logger.info(f"OpenAI LLM Request body: {request_body}")
- response = self.client.chat.completions.create(**request_body)
-
- cost_time = time.perf_counter() - start_time
- logger.info(
- f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}"
- )
-
- if not response.choices:
- logger.warning("OpenAI response has no choices")
- return ""
-
- tool_calls = getattr(response.choices[0].message, "tool_calls", None)
- if isinstance(tool_calls, list) and len(tool_calls) > 0:
- return self.tool_call_parser(tool_calls)
- response_content = response.choices[0].message.content
- reasoning_content = getattr(response.choices[0].message, "reasoning_content", None)
- if isinstance(reasoning_content, str) and reasoning_content:
- reasoning_content = f"{reasoning_content}"
- if self.config.remove_think_prefix:
- return remove_thinking_tags(response_content)
- if reasoning_content:
- return reasoning_content + (response_content or "")
- return response_content or ""
+ try:
+ response = self.client.chat.completions.create(**request_body)
+ cost_time = time.perf_counter() - start_time
+ logger.info(
+ f"Request body: {request_body}, Response from OpenAI: "
+ f"{response.model_dump_json()}, Cost time: {cost_time}"
+ )
+ return self._parse_response(response)
+ except Exception as e:
+ if not self.use_backup_client:
+ raise
+ logger.warning(
+ f"Primary LLM request failed with {type(e).__name__}: {e}, "
+ f"falling back to backup client"
+ )
+ backup_body = {
+ **request_body,
+ "model": self.config.backup_model_name_or_path or request_body["model"],
+ }
+ backup_response = self.backup_client.chat.completions.create(**backup_body)
+ cost_time = time.perf_counter() - start_time
+ logger.info(
+ f"Backup LLM request succeeded, Response: "
+ f"{backup_response.model_dump_json()}, Cost time: {cost_time}"
+ )
+ return self._parse_response(backup_response)
@timed_with_status(
log_prefix="OpenAI LLM Stream",
diff --git a/tests/configs/test_llm.py b/tests/configs/test_llm.py
index 6562c9a95..f3d4549b5 100644
--- a/tests/configs/test_llm.py
+++ b/tests/configs/test_llm.py
@@ -56,6 +56,11 @@ def test_openai_llm_config():
"remove_think_prefix",
"extra_body",
"default_headers",
+ "backup_client",
+ "backup_api_key",
+ "backup_api_base",
+ "backup_model_name_or_path",
+ "backup_headers",
],
)