86 changes: 65 additions & 21 deletions src/art/rewards/ruler.py
@@ -16,6 +16,7 @@
from litellm import acompletion
from litellm.types.utils import ModelResponse
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from rich import print

@@ -55,6 +56,8 @@ async def ruler(
judge_model: str = "openai/o3",
extra_litellm_params: dict | None = None,
rubric: str = DEFAULT_RUBRIC,
judge_chat_model=None,
context_window: int | None = None,
*,
debug: bool = False,
) -> list[TrajectoryScore]:
Expand All @@ -80,6 +83,11 @@ async def ruler(
- "anthropic/claude-3-opus-20240229" - Alternative judge
extra_litellm_params: Additional parameters to pass to LiteLLM completion.
Can include temperature, max_tokens, etc.
judge_chat_model: Optional LangChain-compatible chat model to use instead of
LiteLLM/OpenAI-style identifiers. Must return JSON matching `Response`.
context_window: Optional context window override (e.g., Ollama `num_ctx`).
If provided, it sets litellm `num_ctx` and `max_input_tokens` unless
already supplied in extra_litellm_params.
rubric: The grading rubric. The default rubric works well for most tasks.
debug: If True, pretty-print the judge's reasoning to help understand scores.

@@ -172,30 +180,58 @@ async def ruler(
{"role": "user", "content": user_text},
]

response = await acompletion(
model=judge_model,
messages=messages,
response_format=Response,
caching=False,
**extra_litellm_params if extra_litellm_params else {},
)
assert isinstance(response, ModelResponse)
litellm_params = dict(extra_litellm_params) if extra_litellm_params else {}
if context_window is not None:
litellm_params.setdefault("num_ctx", context_window)
litellm_params.setdefault("max_input_tokens", context_window)

if judge_chat_model is not None:
lc_messages = [
SystemMessage(content=judge_prompt),
HumanMessage(content=user_text),
]
result = await judge_chat_model.ainvoke(lc_messages)
content = getattr(result, "content", result)
if isinstance(content, BaseMessage):
content = content.content
if isinstance(content, list):
content = "".join(
part["text"] if isinstance(part, dict) and "text" in part else str(part)
for part in content
)
if debug:
try:
print("\n[RULER] Pretty-printed LLM choice JSON:")
print(json.loads(content))
except json.JSONDecodeError as e:
print(f"[RULER] Could not parse choice content as JSON: {e}")
print(f"[RULER] Raw choice content: {content}")
parsed = Response.model_validate_json(content)
else:
response = await acompletion(
model=judge_model,
messages=messages,
response_format=Response,
caching=False,
**litellm_params,
)
assert isinstance(response, ModelResponse)

if len(response.choices) == 0:
raise ValueError(f"No choices in response: {response}")
first_choice = response.choices[0]
if len(response.choices) == 0:
raise ValueError(f"No choices in response: {response}")
first_choice = response.choices[0]

if debug:
raw_content = first_choice.message.content or "{}" # type: ignore[attr-defined]
try:
print("\n[RULER] Pretty-printed LLM choice JSON:")
print(json.loads(raw_content))
except json.JSONDecodeError as e:
print(f"[RULER] Could not parse choice content as JSON: {e}")
print(f"[RULER] Raw choice content: {raw_content}")
if debug:
raw_content = first_choice.message.content or "{}" # type: ignore[attr-defined]
try:
print("\n[RULER] Pretty-printed LLM choice JSON:")
print(json.loads(raw_content))
except json.JSONDecodeError as e:
print(f"[RULER] Could not parse choice content as JSON: {e}")
print(f"[RULER] Raw choice content: {raw_content}")

content = first_choice.message.content or "{}" # type: ignore[attr-defined]
parsed = Response.model_validate_json(content)
content = first_choice.message.content or "{}" # type: ignore[attr-defined]
parsed = Response.model_validate_json(content)

# If all trajectories were identical, we only sent one to the judge
# Duplicate the score for all trajectories
@@ -222,6 +258,8 @@ async def ruler_score_group(
judge_model: str = "openai/o3",
extra_litellm_params: dict | None = None,
rubric: str = DEFAULT_RUBRIC,
judge_chat_model=None,
context_window: int | None = None,
*,
swallow_exceptions: bool = False,
debug: bool = False,
@@ -242,6 +280,10 @@
        group: A TrajectoryGroup containing trajectories to score.
        judge_model: The model to use for judging. See `ruler` for options.
        extra_litellm_params: Additional parameters to pass to LiteLLM completion.
        judge_chat_model: Optional LangChain-compatible chat model used in place of
            a LiteLLM/OpenAI model identifier (e.g., ChatOllama, ChatNVIDIA).
        context_window: Optional context window override (e.g., Ollama `num_ctx`).
            Sets LiteLLM's `num_ctx`/`max_input_tokens` if not already set.
        rubric: Custom grading rubric; the default works well for most tasks.
        swallow_exceptions: If True, returns None on errors instead of raising.
            This is recommended for production to handle API failures gracefully.
@@ -298,6 +340,8 @@ async def ruler_score_group(
        message_lists,
        judge_model=judge_model,
        extra_litellm_params=extra_litellm_params,
        judge_chat_model=judge_chat_model,
        context_window=context_window,
        rubric=rubric,
        debug=debug,
    )
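
For completeness, a hedged sketch of the corresponding `ruler_score_group` call on the LiteLLM path, using the new `context_window` parameter to raise a local Ollama judge's context limit. The `ollama_chat/...` identifier, the 32k window, and the import path are assumptions, and `group` would come from your own rollouts; this is not part of the diff above.

import art
from art.rewards import ruler_score_group  # assumed public re-export


async def score_with_local_judge(
    group: art.TrajectoryGroup,
) -> art.TrajectoryGroup | None:
    # context_window is forwarded to LiteLLM as num_ctx / max_input_tokens
    # unless those keys are already present in extra_litellm_params.
    return await ruler_score_group(
        group,
        judge_model="ollama_chat/qwen3:32b",
        context_window=32768,
        swallow_exceptions=True,  # return None instead of raising on judge failures
        debug=True,
    )
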