diff --git a/tests/trace_server/query_builder/test_agent_query_builder.py b/tests/trace_server/query_builder/test_agent_query_builder.py index 68dc4df1bc6f..69f34f642a67 100644 --- a/tests/trace_server/query_builder/test_agent_query_builder.py +++ b/tests/trace_server/query_builder/test_agent_query_builder.py @@ -12,6 +12,7 @@ import sqlparse from pydantic import ValidationError +from weave.trace_server.agents.schema import AgentSpanCHInsertable from weave.trace_server.agents.types import ( AgentConversationChatReq, AgentCustomAttrsSchemaReq, @@ -20,6 +21,7 @@ AgentSortBy, AgentSpanGroupFilter, AgentSpanMeasureSpec, + AgentSpanSchema, AgentSpansQueryReq, AgentSpanValueRef, AgentsQueryFilters, @@ -188,6 +190,52 @@ def test_basic(self) -> None: expected_params = {"genai_0": "p1", "genai_1": 100, "genai_2": 0} assert_sql(expected, expected_params, query, pb.get_params()) + def test_selects_cache_token_fields(self) -> None: + assert "cache_creation_input_tokens" in SPANS_LIST_COLS + assert "cache_read_input_tokens" in SPANS_LIST_COLS + + def test_agent_span_schema_covers_persisted_columns(self) -> None: + missing = set(AgentSpanCHInsertable.model_fields) - set( + AgentSpanSchema.model_fields + ) + assert missing == set() + + def test_selects_requested_columns(self) -> None: + pb = ParamBuilder("genai") + query = make_spans_list_query( + pb, + AgentSpansQueryReq( + project_id="p1", + columns=[ + "request_temperature", + "custom_attrs_float", + "raw_span_dump", + ], + ), + ) + + expected = """ + SELECT project_id, trace_id, span_id, request_temperature, custom_attrs_float, raw_span_dump + FROM spans s + WHERE s.project_id = {genai_0:String} + ORDER BY started_at DESC + LIMIT {genai_1:UInt64} OFFSET {genai_2:UInt64} + """ + expected_params = {"genai_0": "p1", "genai_1": 100, "genai_2": 0} + assert_sql(expected, expected_params, query, pb.get_params()) + + def test_rejects_unknown_requested_columns(self) -> None: + with pytest.raises(ValidationError): + AgentSpansQueryReq(project_id="p1", columns=["does_not_exist"]) + + def test_requested_columns_rejected_with_group_by(self) -> None: + with pytest.raises(ValidationError): + AgentSpansQueryReq( + project_id="p1", + group_by=[AgentGroupByRef(source="column", key="agent_name")], + columns=["raw_span_dump"], + ) + def test_with_custom_sort(self) -> None: pb = ParamBuilder("genai") query = make_spans_list_query( @@ -740,6 +788,10 @@ def test_basic(self) -> None: pb.get_params(), ) + def test_selects_cache_token_fields(self) -> None: + assert "cache_creation_input_tokens" in CHAT_VIEW_COLS + assert "cache_read_input_tokens" in CHAT_VIEW_COLS + # ============================================================================ # make_agents_count_query / make_agents_list_query diff --git a/weave/trace_server/agents/types.py b/weave/trace_server/agents/types.py index ea282c03c3f9..949508856618 100644 --- a/weave/trace_server/agents/types.py +++ b/weave/trace_server/agents/types.py @@ -413,6 +413,7 @@ class AgentSpanSchema(BaseModel): span_kind: SpanKindLiteral | None = None started_at: datetime.datetime | None = None ended_at: datetime.datetime | None = None + created_at: datetime.datetime | None = None status_code: StatusCodeLiteral | None = None status_message: str | None = None operation_name: str | None = None @@ -465,10 +466,15 @@ class AgentSpanSchema(BaseModel): custom_attrs_bool: dict[str, bool] = Field(default_factory=dict) server_address: str | None = None server_port: int | None = None + raw_span_dump: str | None = None + attributes_dump: str | None = None + events_dump: str | None = None + resource_dump: str | None = None wb_user_id: str | None = None wb_run_id: str | None = None wb_run_step: int | None = None wb_run_step_end: int | None = None + expire_at: datetime.datetime | None = None class AgentSortBy(BaseModel): @@ -540,6 +546,7 @@ class AgentSpansQueryReq(BaseModel): group_by: list[AgentGroupByRef] | None = None measures: list[AgentSpanMeasureSpec] = Field(default_factory=list) group_filters: list[AgentSpanGroupFilter] = Field(default_factory=list) + columns: list[str] = Field(default_factory=list) custom_attr_columns: list[AgentSpanValueRef] = Field(default_factory=list) sort_by: list[AgentSortBy] | None = None limit: int = Field( @@ -553,10 +560,19 @@ class AgentSpansQueryReq(BaseModel): def validate_spans_query_request(self) -> AgentSpansQueryReq: if (self.measures or self.group_filters) and not self.group_by: raise ValueError("grouped measures and group filters require group_by") + if self.group_by and self.columns: + raise ValueError("columns are only supported for ungrouped spans") if self.group_by and self.custom_attr_columns: raise ValueError( "custom_attr_columns are only supported for ungrouped spans" ) + invalid_columns = [ + column + for column in self.columns + if column not in AgentSpanSchema.model_fields + ] + if invalid_columns: + raise ValueError(f"unknown span columns: {invalid_columns}") invalid_custom_attr_columns = [ col.source for col in self.custom_attr_columns diff --git a/weave/trace_server/query_builder/agent_query_builder.py b/weave/trace_server/query_builder/agent_query_builder.py index aa13909e01e4..5a82511b5fd6 100644 --- a/weave/trace_server/query_builder/agent_query_builder.py +++ b/weave/trace_server/query_builder/agent_query_builder.py @@ -257,13 +257,17 @@ def _projection(cols: list[str], *, table_alias: str | None = None) -> str: return ", ".join(cols) -# Spans list query: lightweight table projection. Custom attrs and raw dumps -# remain queryable/filterable server-side, but the UI does not need to hydrate -# arbitrary Map/blob payloads for every span row. -_SPANS_LIST_FIELD_NAMES = [ +# Spans list query default: lightweight table projection. Callers can request +# additional AgentSpanSchema columns with AgentSpansQueryReq.columns when they +# need heavier typed columns such as messages, custom attribute maps, or raw +# span dumps. +_SPANS_LIST_REQUIRED_FIELD_NAMES = [ "project_id", "trace_id", "span_id", +] +_SPANS_LIST_DEFAULT_FIELD_NAMES = [ + *_SPANS_LIST_REQUIRED_FIELD_NAMES, "parent_span_id", "span_name", "span_kind", @@ -283,6 +287,8 @@ def _projection(cols: list[str], *, table_alias: str | None = None) -> str: "input_tokens", "output_tokens", "reasoning_tokens", + "cache_creation_input_tokens", + "cache_read_input_tokens", "conversation_id", "conversation_name", "tool_name", @@ -293,7 +299,17 @@ def _projection(cols: list[str], *, table_alias: str | None = None) -> str: "wb_user_id", "wb_run_id", ] -SPANS_LIST_COLS: str = _projection(_SPANS_LIST_FIELD_NAMES) +SPANS_LIST_COLS: str = _projection(_SPANS_LIST_DEFAULT_FIELD_NAMES) + + +def _spans_list_field_names(req: AgentSpansQueryReq) -> list[str]: + if not req.columns: + return _SPANS_LIST_DEFAULT_FIELD_NAMES + out: list[str] = [] + for col in [*_SPANS_LIST_REQUIRED_FIELD_NAMES, *req.columns]: + if col not in out: + out.append(col) + return out def _custom_attr_field_name(ref: AgentSpanValueRef) -> str: @@ -305,13 +321,15 @@ def _custom_attr_map_projection( refs: list[AgentSpanValueRef], *, table_alias: str = "s", + skip_sources: set[str] | None = None, ) -> str: """Project only selected custom-attribute Map keys for spans table rows.""" keys_by_source: dict[str, set[str]] = { source: set() for source in sorted(_CUSTOM_ATTR_SOURCES) } + skip_sources = skip_sources or set() for ref in refs: - if ref.source in keys_by_source: + if ref.source in keys_by_source and ref.source not in skip_sources: keys_by_source[ref.source].add(str(ref.key)) projections: list[str] = [] @@ -372,6 +390,8 @@ def _custom_attr_sort_exprs( "input_tokens", "output_tokens", "reasoning_tokens", + "cache_creation_input_tokens", + "cache_read_input_tokens", "reasoning_content", "conversation_id", "conversation_name", @@ -900,11 +920,14 @@ def make_spans_list_query(pb: ParamBuilder, req: AgentSpansQueryReq) -> str: "started_at DESC", column_exprs=custom_sort_exprs, ) + span_list_field_names = _spans_list_field_names(req) custom_attr_projection = _custom_attr_map_projection( - pb, req.custom_attr_columns + pb, + req.custom_attr_columns, + skip_sources=set(span_list_field_names), ) return f""" - SELECT {SPANS_LIST_COLS}{custom_attr_projection} + SELECT {_projection(span_list_field_names)}{custom_attr_projection} FROM spans s WHERE {span_filters.where} ORDER BY {order_by}