diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 06aa50c65..f24e28559 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -321,23 +321,40 @@ def get_activation_config() -> dict[str, Any]: @staticmethod def get_memreader_config() -> dict[str, Any]: - """Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model).""" - return { - "backend": "openai", - "config": { - "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"), - "temperature": 0.6, - "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")), - "top_p": 0.95, - "top_k": 20, - "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"), - # Default to OpenAI base URL when env var is not provided to satisfy pydantic - # validation requirements during tests/import. - "api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"), - "remove_think_prefix": True, - }, + """Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model). + + When MEMREADER_GENERAL_MODEL is configured (i.e. a separate stable LLM exists), + the backup client is automatically enabled so that primary failures (self-deployed + model) fall back to the general LLM. + """ + config = { + "model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"), + "temperature": 0.6, + "max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")), + "top_p": 0.95, + "top_k": 20, + "api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"), + # Default to OpenAI base URL when env var is not provided to satisfy pydantic + # validation requirements during tests/import. + "api_base": os.getenv("MEMRADER_API_BASE", "https://api.openai.com/v1"), + "remove_think_prefix": True, } + general_model = os.getenv("MEMREADER_GENERAL_MODEL") + enable_backup = os.getenv("MEMREADER_ENABLE_BACKUP", "false").lower() == "true" + if general_model and enable_backup: + config["backup_client"] = True + config["backup_model_name_or_path"] = general_model + config["backup_api_key"] = os.getenv( + "MEMREADER_GENERAL_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY") + ) + config["backup_api_base"] = os.getenv( + "MEMREADER_GENERAL_API_BASE", + os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + ) + + return {"backend": "openai", "config": config} + @staticmethod def get_memreader_general_llm_config() -> dict[str, Any]: """Get general LLM configuration for non-chat/doc tasks. diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py index 5487d117c..11c39b33c 100644 --- a/src/memos/configs/llm.py +++ b/src/memos/configs/llm.py @@ -28,6 +28,22 @@ class OpenAILLMConfig(BaseLLMConfig): default="https://api.openai.com/v1", description="Base URL for OpenAI API" ) extra_body: Any = Field(default=None, description="extra body") + backup_client: bool = Field( + default=False, + description="Whether to enable backup client for fallback on primary failure", + ) + backup_api_key: str | None = Field( + default=None, description="API key for backup OpenAI-compatible endpoint" + ) + backup_api_base: str | None = Field( + default=None, description="Base URL for backup OpenAI-compatible endpoint" + ) + backup_model_name_or_path: str | None = Field( + default=None, description="Model name for backup endpoint" + ) + backup_headers: dict[str, Any] | None = Field( + default=None, description="Default headers for backup client requests" + ) class OpenAIResponsesLLMConfig(BaseLLMConfig): @@ -42,22 +58,18 @@ class OpenAIResponsesLLMConfig(BaseLLMConfig): ) -class QwenLLMConfig(BaseLLMConfig): - api_key: str = Field(..., description="API key for DashScope (Qwen)") +class QwenLLMConfig(OpenAILLMConfig): api_base: str = Field( default="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", description="Base URL for Qwen OpenAI-compatible API", ) - extra_body: Any = Field(default=None, description="extra body") -class DeepSeekLLMConfig(BaseLLMConfig): - api_key: str = Field(..., description="API key for DeepSeek") +class DeepSeekLLMConfig(OpenAILLMConfig): api_base: str = Field( default="https://api.deepseek.com", description="Base URL for DeepSeek OpenAI-compatible API", ) - extra_body: Any = Field(default=None, description="Extra options for API") class AzureLLMConfig(BaseLLMConfig): diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 93dac42fb..f6bb4efc1 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -27,7 +27,39 @@ def __init__(self, config: OpenAILLMConfig): self.client = openai.Client( api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers ) - logger.info("OpenAI LLM instance initialized") + self.use_backup_client = config.backup_client + if self.use_backup_client: + self.backup_client = openai.Client( + api_key=config.backup_api_key, + base_url=config.backup_api_base, + default_headers=config.backup_headers, + ) + logger.info( + f"OpenAI LLM instance initialized with backup " + f"(model={config.backup_model_name_or_path})" + ) + else: + self.backup_client = None + logger.info("OpenAI LLM instance initialized") + + def _parse_response(self, response) -> str: + """Extract text content from a chat completion response.""" + if not response.choices: + logger.warning("OpenAI response has no choices") + return "" + + tool_calls = getattr(response.choices[0].message, "tool_calls", None) + if isinstance(tool_calls, list) and len(tool_calls) > 0: + return self.tool_call_parser(tool_calls) + response_content = response.choices[0].message.content + reasoning_content = getattr(response.choices[0].message, "reasoning_content", None) + if isinstance(reasoning_content, str) and reasoning_content: + reasoning_content = f"{reasoning_content}" + if self.config.remove_think_prefix: + return remove_thinking_tags(response_content) + if reasoning_content: + return reasoning_content + (response_content or "") + return response_content or "" @timed_with_status( log_prefix="OpenAI LLM", @@ -50,29 +82,32 @@ def generate(self, messages: MessageList, **kwargs) -> str: start_time = time.perf_counter() logger.info(f"OpenAI LLM Request body: {request_body}") - response = self.client.chat.completions.create(**request_body) - - cost_time = time.perf_counter() - start_time - logger.info( - f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}" - ) - - if not response.choices: - logger.warning("OpenAI response has no choices") - return "" - - tool_calls = getattr(response.choices[0].message, "tool_calls", None) - if isinstance(tool_calls, list) and len(tool_calls) > 0: - return self.tool_call_parser(tool_calls) - response_content = response.choices[0].message.content - reasoning_content = getattr(response.choices[0].message, "reasoning_content", None) - if isinstance(reasoning_content, str) and reasoning_content: - reasoning_content = f"{reasoning_content}" - if self.config.remove_think_prefix: - return remove_thinking_tags(response_content) - if reasoning_content: - return reasoning_content + (response_content or "") - return response_content or "" + try: + response = self.client.chat.completions.create(**request_body) + cost_time = time.perf_counter() - start_time + logger.info( + f"Request body: {request_body}, Response from OpenAI: " + f"{response.model_dump_json()}, Cost time: {cost_time}" + ) + return self._parse_response(response) + except Exception as e: + if not self.use_backup_client: + raise + logger.warning( + f"Primary LLM request failed with {type(e).__name__}: {e}, " + f"falling back to backup client" + ) + backup_body = { + **request_body, + "model": self.config.backup_model_name_or_path or request_body["model"], + } + backup_response = self.backup_client.chat.completions.create(**backup_body) + cost_time = time.perf_counter() - start_time + logger.info( + f"Backup LLM request succeeded, Response: " + f"{backup_response.model_dump_json()}, Cost time: {cost_time}" + ) + return self._parse_response(backup_response) @timed_with_status( log_prefix="OpenAI LLM Stream", diff --git a/tests/configs/test_llm.py b/tests/configs/test_llm.py index 6562c9a95..f3d4549b5 100644 --- a/tests/configs/test_llm.py +++ b/tests/configs/test_llm.py @@ -56,6 +56,11 @@ def test_openai_llm_config(): "remove_think_prefix", "extra_body", "default_headers", + "backup_client", + "backup_api_key", + "backup_api_base", + "backup_model_name_or_path", + "backup_headers", ], )