From b274703b24ccced5d0b8aa7d362291d0011f65bb Mon Sep 17 00:00:00 2001
From: Ansh-info
Date: Thu, 11 Dec 2025 09:12:59 +0100
Subject: [PATCH] feat: Provider-agnostic judge support added to RULER

Co-authored-by: Apoorva Gupta
---
 src/art/rewards/ruler.py | 86 ++++++++++++++++++++++++++++++----------
 1 file changed, 65 insertions(+), 21 deletions(-)

diff --git a/src/art/rewards/ruler.py b/src/art/rewards/ruler.py
index 2ea333124..2f3cebdd7 100644
--- a/src/art/rewards/ruler.py
+++ b/src/art/rewards/ruler.py
@@ -16,6 +16,7 @@
 from litellm import acompletion
 from litellm.types.utils import ModelResponse
 from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
 from pydantic import BaseModel, Field
 from rich import print
 
@@ -55,6 +56,8 @@ async def ruler(
     judge_model: str = "openai/o3",
     extra_litellm_params: dict | None = None,
     rubric: str = DEFAULT_RUBRIC,
+    judge_chat_model=None,
+    context_window: int | None = None,
     *,
     debug: bool = False,
 ) -> list[TrajectoryScore]:
@@ -80,6 +83,11 @@ async def ruler(
             - "anthropic/claude-3-opus-20240229" - Alternative judge
         extra_litellm_params: Additional parameters to pass to LiteLLM completion.
             Can include temperature, max_tokens, etc.
+        judge_chat_model: Optional LangChain-compatible chat model to use instead of
+            LiteLLM/OpenAI-style identifiers. Must return JSON matching `Response`.
+        context_window: Optional context window override (e.g., Ollama `num_ctx`).
+            If provided, it sets litellm `num_ctx` and `max_input_tokens` unless
+            already supplied in extra_litellm_params.
         rubric: The grading rubric. The default rubric works well for most tasks.
         debug: If True, pretty-print the judge's reasoning to help understand scores.
@@ -172,30 +180,58 @@ async def ruler(
         {"role": "user", "content": user_text},
     ]
 
-    response = await acompletion(
-        model=judge_model,
-        messages=messages,
-        response_format=Response,
-        caching=False,
-        **extra_litellm_params if extra_litellm_params else {},
-    )
-    assert isinstance(response, ModelResponse)
+    litellm_params = dict(extra_litellm_params) if extra_litellm_params else {}
+    if context_window is not None:
+        litellm_params.setdefault("num_ctx", context_window)
+        litellm_params.setdefault("max_input_tokens", context_window)
+
+    if judge_chat_model is not None:
+        lc_messages = [
+            SystemMessage(content=judge_prompt),
+            HumanMessage(content=user_text),
+        ]
+        result = await judge_chat_model.ainvoke(lc_messages)
+        content = getattr(result, "content", result)
+        if isinstance(content, BaseMessage):
+            content = content.content
+        if isinstance(content, list):
+            content = "".join(
+                part["text"] if isinstance(part, dict) and "text" in part else str(part)
+                for part in content
+            )
+        if debug:
+            try:
+                print("\n[RULER] Pretty-printed LLM choice JSON:")
+                print(json.loads(content))
+            except json.JSONDecodeError as e:
+                print(f"[RULER] Could not parse choice content as JSON: {e}")
+                print(f"[RULER] Raw choice content: {content}")
+        parsed = Response.model_validate_json(content)
+    else:
+        response = await acompletion(
+            model=judge_model,
+            messages=messages,
+            response_format=Response,
+            caching=False,
+            **litellm_params,
+        )
+        assert isinstance(response, ModelResponse)
 
-    if len(response.choices) == 0:
-        raise ValueError(f"No choices in response: {response}")
-    first_choice = response.choices[0]
+        if len(response.choices) == 0:
+            raise ValueError(f"No choices in response: {response}")
+        first_choice = response.choices[0]
 
-    if debug:
-        raw_content = first_choice.message.content or "{}"  # type: ignore[attr-defined]
-        try:
-            print("\n[RULER] Pretty-printed LLM choice JSON:")
-            print(json.loads(raw_content))
-        except json.JSONDecodeError as e:
-            print(f"[RULER] Could not parse choice content as JSON: {e}")
-            print(f"[RULER] Raw choice content: {raw_content}")
+        if debug:
+            raw_content = first_choice.message.content or "{}"  # type: ignore[attr-defined]
+            try:
+                print("\n[RULER] Pretty-printed LLM choice JSON:")
+                print(json.loads(raw_content))
+            except json.JSONDecodeError as e:
+                print(f"[RULER] Could not parse choice content as JSON: {e}")
+                print(f"[RULER] Raw choice content: {raw_content}")
 
-    content = first_choice.message.content or "{}"  # type: ignore[attr-defined]
-    parsed = Response.model_validate_json(content)
+        content = first_choice.message.content or "{}"  # type: ignore[attr-defined]
+        parsed = Response.model_validate_json(content)
 
     # If all trajectories were identical, we only sent one to the judge
     # Duplicate the score for all trajectories
@@ -222,6 +258,8 @@ async def ruler_score_group(
     judge_model: str = "openai/o3",
     extra_litellm_params: dict | None = None,
     rubric: str = DEFAULT_RUBRIC,
+    judge_chat_model=None,
+    context_window: int | None = None,
     *,
     swallow_exceptions: bool = False,
     debug: bool = False,
@@ -242,6 +280,10 @@ async def ruler_score_group(
         group: A TrajectoryGroup containing trajectories to score.
         judge_model: The model to use for judging. See `ruler` for options.
         extra_litellm_params: Additional parameters to pass to LiteLLM completion.
+        judge_chat_model: Optional LangChain-compatible chat model used instead of
+            LiteLLM/OpenAI-style identifiers (e.g., ChatOllama, ChatNVIDIA).
+        context_window: Optional context window override (e.g., Ollama `num_ctx`).
+            Sets litellm `num_ctx`/`max_input_tokens` if not already set.
         rubric: Custom rubric or use the default which works well for most tasks.
         swallow_exceptions: If True, returns None on errors instead of raising.
             This is recommended for production to handle API failures gracefully.
@@ -298,6 +340,8 @@ async def ruler_score_group(
         message_lists,
         judge_model=judge_model,
         extra_litellm_params=extra_litellm_params,
+        judge_chat_model=judge_chat_model,
+        context_window=context_window,
         rubric=rubric,
         debug=debug,
     )
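
Usage note (not part of the patch): a minimal sketch of the new LangChain
path. The `langchain-ollama` package, the model name, and the trajectory
group construction are illustrative assumptions, not part of this change.

    import asyncio

    import art
    from art.rewards import ruler_score_group
    from langchain_ollama import ChatOllama

    async def main() -> None:
        # Local judge served by Ollama; format="json" nudges the model to emit
        # the JSON that Response.model_validate_json() expects, since the
        # LangChain path does not enforce a response_format.
        judge = ChatOllama(model="qwen2.5:14b", temperature=0, format="json")
        group = art.TrajectoryGroup(...)  # trajectories to rank, elided here
        scored = await ruler_score_group(
            group,
            judge_chat_model=judge,   # bypasses LiteLLM entirely
            swallow_exceptions=True,  # return None instead of raising on judge errors
            debug=True,
        )

    asyncio.run(main())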
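
Usage note (not part of the patch): `context_window` only feeds
`litellm_params`, so it affects the LiteLLM path and is ignored when
`judge_chat_model` is supplied. A sketch of the LiteLLM path against an
Ollama-served judge (the `ollama_chat/` model name is an assumption):

    # Inside an async function, with message_lists built as in ruler_score_group:
    scores = await ruler(
        message_lists,
        judge_model="ollama_chat/qwen2.5:14b",
        context_window=32768,  # forwarded as num_ctx / max_input_tokens
    )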