Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
DatasetBuilder,
GenerateMetadata,
GenerateOutputs,
LLMClient,
LLMClientMap,
LogCallback,
Messages,
MessageType,
Expand Down Expand Up @@ -424,7 +426,7 @@ def resolve_optional_args(
MessageType,
]:
"""Resolve optional arguments, fallback to state or class defaults."""
client = client or state["client"]
client = client or state["client"][state["current_client"]]
model = model or state["model"]
assert client is not None and model is not None
oai_tools = oai_tools or state["oai_tools"]
Expand Down Expand Up @@ -622,7 +624,7 @@ async def get_model_response_with_tokens(
async def init_state(
self,
input: RolloutInput,
client: AsyncOpenAI,
client: LLMClientMap | LLMClient,
model: str,
sampling_args: SamplingArgs | None = None,
) -> State:
Expand All @@ -639,7 +641,13 @@ async def init_state(
if "task" not in state_input:
state_input["task"] = self.env_id or "default"
state = State(input=RolloutInput(**state_input)) # type: ignore[missing-typed-dict-key]

if isinstance(client, LLMClient):
client = {"default": client}

state["client"] = client
state["current_client"] = list(client.keys())[0]

state["model"] = model
state["sampling_args"] = sampling_args
state["is_completed"] = False
Expand Down Expand Up @@ -669,7 +677,7 @@ async def init_state(
async def rollout(
self,
input: RolloutInput,
client: AsyncOpenAI,
client: LLMClientMap | LLMClient,
model: str,
sampling_args: SamplingArgs | None = None,
) -> State:
Expand Down
17 changes: 16 additions & 1 deletion verifiers/envs/multiturn_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import verifiers as vf
from verifiers.types import (
LLMClient,
LLMClientMap,
Messages,
ModelResponse,
RolloutInput,
Expand Down Expand Up @@ -65,6 +67,12 @@ async def has_final_env_response(self, state: State) -> bool:
"""Check if env_response signaled termination via final_env_response."""
return state.get("final_env_response") is not None

def get_next_client(self, state: State) -> str:
    """Return the key of the client to use for the next turn.

    Subclasses override this to implement client-switching logic
    mid-rollout; the returned value must be a key of the mapping
    stored in ``state["client"]``. The base implementation keeps
    the currently active client unchanged.
    """
    active_client_key = state["current_client"]
    return active_client_key

async def setup_state(self, state: State) -> State:
    """Hook for environment-specific state initialization.

    The base implementation returns ``state`` unchanged; subclasses
    override this to populate additional state fields before the
    rollout begins.
    """
    # No-op by default: nothing to add at the base-class level.
    return state
Expand All @@ -78,6 +86,13 @@ async def get_prompt_messages(self, state: State) -> Messages:
prev_turn_completion = state["trajectory"][-1]["completion"]
messages = concat_messages([prev_turn_prompt, prev_turn_completion])
env_response = await self.env_response(messages, state)
next_client = self.get_next_client(state)
state["current_client"] = next_client

# TODO: reconstruct messages in the form compatible with the next client
# System messages rewrite, etc.
# if next_client == "anthropic":
# <logic>
return concat_messages([messages, env_response])

async def render_completion(self, state: State):
Expand Down Expand Up @@ -129,7 +144,7 @@ async def add_model_response(
async def rollout(
self,
input: RolloutInput,
client: AsyncOpenAI,
client: LLMClientMap | LLMClient,
model: str,
sampling_args: SamplingArgs | None = None,
) -> State:
Expand Down
5 changes: 4 additions & 1 deletion verifiers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
# Reward callables: a group-level reward returns (or awaits) one float per rollout.
GroupRewardFunc = Callable[..., list[float] | Awaitable[list[float]]]
RewardFunc = IndividualRewardFunc | GroupRewardFunc
# Zero-argument factory that builds the dataset on demand.
DatasetBuilder = Callable[[], Dataset]
# Only OpenAI-compatible async clients for now — add anthropic support later.
LLMClient = AsyncOpenAI
# Named clients keyed by role; a bare LLMClient is wrapped as {"default": client}.
LLMClientMap = dict[str, LLMClient]


class TrajectoryStepTokens(TypedDict):
Expand Down Expand Up @@ -99,7 +101,8 @@ class State(dict):
INPUT_FIELDS = ["prompt", "answer", "task", "info", "example_id"]
# rollout inputs
input: RolloutInput
client: AsyncOpenAI
client: LLMClientMap
current_client: str
model: str
sampling_args: SamplingArgs | None
# created during rollout
Expand Down