Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tests/trace_server/query_builder/test_agent_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sqlparse
from pydantic import ValidationError

from weave.trace_server.agents.schema import AgentSpanCHInsertable
from weave.trace_server.agents.types import (
AgentConversationChatReq,
AgentCustomAttrsSchemaReq,
Expand All @@ -20,6 +21,7 @@
AgentSortBy,
AgentSpanGroupFilter,
AgentSpanMeasureSpec,
AgentSpanSchema,
AgentSpansQueryReq,
AgentSpanValueRef,
AgentsQueryFilters,
Expand Down Expand Up @@ -188,6 +190,52 @@ def test_basic(self) -> None:
expected_params = {"genai_0": "p1", "genai_1": 100, "genai_2": 0}
assert_sql(expected, expected_params, query, pb.get_params())

def test_selects_cache_token_fields(self) -> None:
    """SPANS_LIST_COLS must mention both cache token columns."""
    cache_columns = ("cache_creation_input_tokens", "cache_read_input_tokens")
    for column in cache_columns:
        assert column in SPANS_LIST_COLS

def test_agent_span_schema_covers_persisted_columns(self) -> None:
    """Every AgentSpanCHInsertable field must also exist on AgentSpanSchema."""
    persisted_fields = set(AgentSpanCHInsertable.model_fields)
    schema_fields = set(AgentSpanSchema.model_fields)
    # Fields present on the insertable model but absent from the API schema.
    uncovered = persisted_fields - schema_fields
    assert uncovered == set()

def test_selects_requested_columns(self) -> None:
    """Explicit ``columns`` are projected after the three key columns."""
    pb = ParamBuilder("genai")
    requested = ["request_temperature", "custom_attrs_float", "raw_span_dump"]
    req = AgentSpansQueryReq(project_id="p1", columns=requested)
    query = make_spans_list_query(pb, req)

    expected_params = {"genai_0": "p1", "genai_1": 100, "genai_2": 0}
    expected = """
SELECT project_id, trace_id, span_id, request_temperature, custom_attrs_float, raw_span_dump
FROM spans s
WHERE s.project_id = {genai_0:String}
ORDER BY started_at DESC
LIMIT {genai_1:UInt64} OFFSET {genai_2:UInt64}
"""
    assert_sql(expected, expected_params, query, pb.get_params())

def test_rejects_unknown_requested_columns(self) -> None:
    """A requested column that is not a known span field fails validation."""
    unknown_columns = ["does_not_exist"]
    with pytest.raises(ValidationError):
        AgentSpansQueryReq(project_id="p1", columns=unknown_columns)

def test_requested_columns_rejected_with_group_by(self) -> None:
    """Combining ``columns`` with ``group_by`` fails request validation."""
    with pytest.raises(ValidationError):
        grouping = [AgentGroupByRef(source="column", key="agent_name")]
        AgentSpansQueryReq(
            project_id="p1",
            columns=["raw_span_dump"],
            group_by=grouping,
        )

def test_with_custom_sort(self) -> None:
pb = ParamBuilder("genai")
query = make_spans_list_query(
Expand Down Expand Up @@ -740,6 +788,10 @@ def test_basic(self) -> None:
pb.get_params(),
)

def test_selects_cache_token_fields(self) -> None:
    """CHAT_VIEW_COLS must mention both cache token columns."""
    cache_columns = ("cache_creation_input_tokens", "cache_read_input_tokens")
    for column in cache_columns:
        assert column in CHAT_VIEW_COLS


# ============================================================================
# make_agents_count_query / make_agents_list_query
Expand Down
16 changes: 16 additions & 0 deletions weave/trace_server/agents/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ class AgentSpanSchema(BaseModel):
span_kind: SpanKindLiteral | None = None
started_at: datetime.datetime | None = None
ended_at: datetime.datetime | None = None
created_at: datetime.datetime | None = None
status_code: StatusCodeLiteral | None = None
status_message: str | None = None
operation_name: str | None = None
Expand Down Expand Up @@ -465,10 +466,15 @@ class AgentSpanSchema(BaseModel):
custom_attrs_bool: dict[str, bool] = Field(default_factory=dict)
server_address: str | None = None
server_port: int | None = None
raw_span_dump: str | None = None
attributes_dump: str | None = None
events_dump: str | None = None
resource_dump: str | None = None
wb_user_id: str | None = None
wb_run_id: str | None = None
wb_run_step: int | None = None
wb_run_step_end: int | None = None
expire_at: datetime.datetime | None = None


class AgentSortBy(BaseModel):
Expand Down Expand Up @@ -540,6 +546,7 @@ class AgentSpansQueryReq(BaseModel):
group_by: list[AgentGroupByRef] | None = None
measures: list[AgentSpanMeasureSpec] = Field(default_factory=list)
group_filters: list[AgentSpanGroupFilter] = Field(default_factory=list)
columns: list[str] = Field(default_factory=list)
custom_attr_columns: list[AgentSpanValueRef] = Field(default_factory=list)
sort_by: list[AgentSortBy] | None = None
limit: int = Field(
Expand All @@ -553,10 +560,19 @@ class AgentSpansQueryReq(BaseModel):
def validate_spans_query_request(self) -> AgentSpansQueryReq:
if (self.measures or self.group_filters) and not self.group_by:
raise ValueError("grouped measures and group filters require group_by")
if self.group_by and self.columns:
raise ValueError("columns are only supported for ungrouped spans")
if self.group_by and self.custom_attr_columns:
raise ValueError(
"custom_attr_columns are only supported for ungrouped spans"
)
invalid_columns = [
column
for column in self.columns
if column not in AgentSpanSchema.model_fields
]
if invalid_columns:
raise ValueError(f"unknown span columns: {invalid_columns}")
invalid_custom_attr_columns = [
col.source
for col in self.custom_attr_columns
Expand Down
39 changes: 31 additions & 8 deletions weave/trace_server/query_builder/agent_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,17 @@ def _projection(cols: list[str], *, table_alias: str | None = None) -> str:
return ", ".join(cols)


# Spans list query: lightweight table projection. Custom attrs and raw dumps
# remain queryable/filterable server-side, but the UI does not need to hydrate
# arbitrary Map/blob payloads for every span row.
_SPANS_LIST_FIELD_NAMES = [
# Spans list query default: lightweight table projection. Callers can request
# additional AgentSpanSchema columns with AgentSpansQueryReq.columns when they
# need heavier typed columns such as messages, custom attribute maps, or raw
# span dumps.
_SPANS_LIST_REQUIRED_FIELD_NAMES = [
"project_id",
"trace_id",
"span_id",
]
_SPANS_LIST_DEFAULT_FIELD_NAMES = [
*_SPANS_LIST_REQUIRED_FIELD_NAMES,
"parent_span_id",
"span_name",
"span_kind",
Expand All @@ -283,6 +287,8 @@ def _projection(cols: list[str], *, table_alias: str | None = None) -> str:
"input_tokens",
"output_tokens",
"reasoning_tokens",
"cache_creation_input_tokens",
"cache_read_input_tokens",
"conversation_id",
"conversation_name",
"tool_name",
Expand All @@ -293,7 +299,17 @@ def _projection(cols: list[str], *, table_alias: str | None = None) -> str:
"wb_user_id",
"wb_run_id",
]
SPANS_LIST_COLS: str = _projection(_SPANS_LIST_FIELD_NAMES)
SPANS_LIST_COLS: str = _projection(_SPANS_LIST_DEFAULT_FIELD_NAMES)


def _spans_list_field_names(req: AgentSpansQueryReq) -> list[str]:
    """Return the ordered column names to project for a spans list query.

    When ``req.columns`` is empty the default projection
    (``_SPANS_LIST_DEFAULT_FIELD_NAMES``) is used. Otherwise the required
    columns (``_SPANS_LIST_REQUIRED_FIELD_NAMES``) come first, followed by the
    requested columns in request order, with duplicates dropped so that the
    first occurrence wins.

    A fresh list is returned in both cases, so callers can never mutate the
    module-level defaults through the result.
    """
    if not req.columns:
        # Copy rather than hand out the module-level list: an in-place edit by
        # a caller would otherwise leak into every later query.
        return list(_SPANS_LIST_DEFAULT_FIELD_NAMES)
    # dict keys keep insertion order, giving first-occurrence-wins dedup.
    return list(dict.fromkeys((*_SPANS_LIST_REQUIRED_FIELD_NAMES, *req.columns)))


def _custom_attr_field_name(ref: AgentSpanValueRef) -> str:
Expand All @@ -305,13 +321,15 @@ def _custom_attr_map_projection(
refs: list[AgentSpanValueRef],
*,
table_alias: str = "s",
skip_sources: set[str] | None = None,
) -> str:
"""Project only selected custom-attribute Map keys for spans table rows."""
keys_by_source: dict[str, set[str]] = {
source: set() for source in sorted(_CUSTOM_ATTR_SOURCES)
}
skip_sources = skip_sources or set()
for ref in refs:
if ref.source in keys_by_source:
if ref.source in keys_by_source and ref.source not in skip_sources:
keys_by_source[ref.source].add(str(ref.key))

projections: list[str] = []
Expand Down Expand Up @@ -372,6 +390,8 @@ def _custom_attr_sort_exprs(
"input_tokens",
"output_tokens",
"reasoning_tokens",
"cache_creation_input_tokens",
"cache_read_input_tokens",
"reasoning_content",
"conversation_id",
"conversation_name",
Expand Down Expand Up @@ -900,11 +920,14 @@ def make_spans_list_query(pb: ParamBuilder, req: AgentSpansQueryReq) -> str:
"started_at DESC",
column_exprs=custom_sort_exprs,
)
span_list_field_names = _spans_list_field_names(req)
custom_attr_projection = _custom_attr_map_projection(
pb, req.custom_attr_columns
pb,
req.custom_attr_columns,
skip_sources=set(span_list_field_names),
)
return f"""
SELECT {SPANS_LIST_COLS}{custom_attr_projection}
SELECT {_projection(span_list_field_names)}{custom_attr_projection}
FROM spans s
WHERE {span_filters.where}
ORDER BY {order_by}
Expand Down
Loading