diff --git a/agentex-ui/hooks/use-task-messages.ts b/agentex-ui/hooks/use-task-messages.ts index 61b3fb28..0a1da8df 100644 --- a/agentex-ui/hooks/use-task-messages.ts +++ b/agentex-ui/hooks/use-task-messages.ts @@ -116,19 +116,27 @@ export function useSendMessage({ throw new Error(response.error.message); } - queryClient.setQueryData(queryKey, data => ({ - messages: data?.messages || [], - deltaAccumulator: data?.deltaAccumulator || null, - rpcStatus: 'pending', - })); + // Refetch messages and spans now that the agent has finished processing + await queryClient.invalidateQueries({ queryKey: taskMessagesKeys.byTaskId(taskId) }); + queryClient.invalidateQueries({ queryKey: ['spans', taskId] }); - return ( - queryClient.getQueryData(queryKey) || { - messages: [], - deltaAccumulator: null, - rpcStatus: 'pending', - } - ); + const finalMessages = await agentexClient.messages.list({ + task_id: taskId, + }); + + const chronologicalMessages = finalMessages.slice().reverse(); + + queryClient.setQueryData(queryKey, { + messages: chronologicalMessages, + deltaAccumulator: null, + rpcStatus: 'success', + }); + + return { + messages: chronologicalMessages, + deltaAccumulator: null, + rpcStatus: 'success', + }; } case 'sync': { diff --git a/agentex/database/migrations/alembic/versions/2026_02_11_0802_add_langgraph_checkpoint_tables_d1a6cde41b3f.py b/agentex/database/migrations/alembic/versions/2026_02_11_0802_add_langgraph_checkpoint_tables_d1a6cde41b3f.py new file mode 100644 index 00000000..254949de --- /dev/null +++ b/agentex/database/migrations/alembic/versions/2026_02_11_0802_add_langgraph_checkpoint_tables_d1a6cde41b3f.py @@ -0,0 +1,84 @@ +"""add_langgraph_checkpoint_tables + +Revision ID: d1a6cde41b3f +Revises: d024851e790c +Create Date: 2026-02-11 08:02:10.739927 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
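+# NOTE: down_revision chains onto d024851e790c (add_performance_indexes),
+# the previous head listed in migration_history.txt, keeping the history linear.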
+revision: str = 'd1a6cde41b3f' +down_revision: Union[str, None] = 'd024851e790c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # checkpoint_migrations + op.create_table('checkpoint_migrations', + sa.Column('v', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('v') + ) + + # checkpoints + op.create_table('checkpoints', + sa.Column('thread_id', sa.Text(), nullable=False), + sa.Column('checkpoint_ns', sa.Text(), server_default='', nullable=False), + sa.Column('checkpoint_id', sa.Text(), nullable=False), + sa.Column('parent_checkpoint_id', sa.Text(), nullable=True), + sa.Column('type', sa.Text(), nullable=True), + sa.Column('checkpoint', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), server_default='{}', nullable=False), + sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'checkpoint_id') + ) + op.create_index('checkpoints_thread_id_idx', 'checkpoints', ['thread_id'], unique=False) + + # checkpoint_blobs + op.create_table('checkpoint_blobs', + sa.Column('thread_id', sa.Text(), nullable=False), + sa.Column('checkpoint_ns', sa.Text(), server_default='', nullable=False), + sa.Column('channel', sa.Text(), nullable=False), + sa.Column('version', sa.Text(), nullable=False), + sa.Column('type', sa.Text(), nullable=False), + sa.Column('blob', sa.LargeBinary(), nullable=True), + sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'channel', 'version') + ) + op.create_index('checkpoint_blobs_thread_id_idx', 'checkpoint_blobs', ['thread_id'], unique=False) + + # checkpoint_writes + op.create_table('checkpoint_writes', + sa.Column('thread_id', sa.Text(), nullable=False), + sa.Column('checkpoint_ns', sa.Text(), server_default='', nullable=False), + sa.Column('checkpoint_id', sa.Text(), nullable=False), + sa.Column('task_id', sa.Text(), nullable=False), + sa.Column('idx', sa.Integer(), nullable=False), + sa.Column('channel', sa.Text(), nullable=False), + sa.Column('type', sa.Text(), nullable=True), + sa.Column('blob', sa.LargeBinary(), nullable=False), + sa.Column('task_path', sa.Text(), server_default='', nullable=False), + sa.PrimaryKeyConstraint('thread_id', 'checkpoint_ns', 'checkpoint_id', 'task_id', 'idx') + ) + op.create_index('checkpoint_writes_thread_id_idx', 'checkpoint_writes', ['thread_id'], unique=False) + + # Pre-populate checkpoint_migrations so LangGraph sees all its + # internal migrations as already applied and skips setup(). 
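+    # NOTE: seeding versions 0-9 assumes the internal migration list shipped
+    # with the currently pinned langgraph-checkpoint-postgres release; if a
+    # newer release adds migrations, this seed (and possibly the table schema
+    # above) will need to be extended so that setup() stays a no-op.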
+ op.execute( + sa.text( + "INSERT INTO checkpoint_migrations (v) VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)" + ) + ) + + +def downgrade() -> None: + op.drop_index('checkpoint_writes_thread_id_idx', table_name='checkpoint_writes') + op.drop_table('checkpoint_writes') + op.drop_index('checkpoint_blobs_thread_id_idx', table_name='checkpoint_blobs') + op.drop_table('checkpoint_blobs') + op.drop_index('checkpoints_thread_id_idx', table_name='checkpoints') + op.drop_table('checkpoints') + op.drop_table('checkpoint_migrations') diff --git a/agentex/database/migrations/migration_history.txt b/agentex/database/migrations/migration_history.txt index 3754d855..2671729c 100644 --- a/agentex/database/migrations/migration_history.txt +++ b/agentex/database/migrations/migration_history.txt @@ -1,4 +1,5 @@ -24429f13b8bd -> d024851e790c (head), add_performance_indexes +d024851e790c -> d1a6cde41b3f (head), add_langgraph_checkpoint_tables +24429f13b8bd -> d024851e790c, add_performance_indexes a5d67f2d7356 -> 24429f13b8bd, add agent input type 329fbafa4ff9 -> a5d67f2d7356, add unhealthy status d7addd4229e8 -> 329fbafa4ff9, change_default_acp_to_async diff --git a/agentex/src/adapters/orm.py b/agentex/src/adapters/orm.py index 8d24fed7..b5545bc3 100644 --- a/agentex/src/adapters/orm.py +++ b/agentex/src/adapters/orm.py @@ -5,6 +5,8 @@ DateTime, ForeignKey, Index, + Integer, + LargeBinary, String, Text, func, @@ -213,3 +215,56 @@ class DeploymentHistoryORM(BaseORM): "commit_hash", ), ) + + +# LangGraph checkpoint tables +# These mirror the schema from langgraph.checkpoint.postgres so that +# tables are created via Alembic migrations rather than at agent runtime. + + +class CheckpointMigrationORM(BaseORM): + __tablename__ = "checkpoint_migrations" + v = Column(Integer, primary_key=True) + + +class CheckpointORM(BaseORM): + __tablename__ = "checkpoints" + thread_id = Column(Text, nullable=False, primary_key=True) + checkpoint_ns = Column(Text, nullable=False, primary_key=True, server_default="") + checkpoint_id = Column(Text, nullable=False, primary_key=True) + parent_checkpoint_id = Column(Text, nullable=True) + type = Column(Text, nullable=True) + checkpoint = Column(JSONB, nullable=False) + metadata_ = Column("metadata", JSONB, nullable=False, server_default="{}") + __table_args__ = ( + Index("checkpoints_thread_id_idx", "thread_id"), + ) + + +class CheckpointBlobORM(BaseORM): + __tablename__ = "checkpoint_blobs" + thread_id = Column(Text, nullable=False, primary_key=True) + checkpoint_ns = Column(Text, nullable=False, primary_key=True, server_default="") + channel = Column(Text, nullable=False, primary_key=True) + version = Column(Text, nullable=False, primary_key=True) + type = Column(Text, nullable=False) + blob = Column(LargeBinary, nullable=True) + __table_args__ = ( + Index("checkpoint_blobs_thread_id_idx", "thread_id"), + ) + + +class CheckpointWriteORM(BaseORM): + __tablename__ = "checkpoint_writes" + thread_id = Column(Text, nullable=False, primary_key=True) + checkpoint_ns = Column(Text, nullable=False, primary_key=True, server_default="") + checkpoint_id = Column(Text, nullable=False, primary_key=True) + task_id = Column(Text, nullable=False, primary_key=True) + idx = Column(Integer, nullable=False, primary_key=True) + channel = Column(Text, nullable=False) + type = Column(Text, nullable=True) + blob = Column(LargeBinary, nullable=False) + task_path = Column(Text, nullable=False, server_default="") + __table_args__ = ( + Index("checkpoint_writes_thread_id_idx", "thread_id"), + ) diff --git 
a/agentex/src/api/app.py b/agentex/src/api/app.py index d24ec06e..1090ebe0 100644 --- a/agentex/src/api/app.py +++ b/agentex/src/api/app.py @@ -19,6 +19,7 @@ agent_api_keys, agent_task_tracker, agents, + checkpoints, deployment_history, events, messages, @@ -183,6 +184,7 @@ async def handle_unexpected(request, exc): fastapi_app.include_router(agent_api_keys.router) fastapi_app.include_router(deployment_history.router) fastapi_app.include_router(schedules.router) +fastapi_app.include_router(checkpoints.router) # Wrap FastAPI app with health check interceptor for sub-millisecond K8s probe responses. # This must be the outermost layer to bypass all middleware. diff --git a/agentex/src/api/routes/checkpoints.py b/agentex/src/api/routes/checkpoints.py new file mode 100644 index 00000000..81f3a831 --- /dev/null +++ b/agentex/src/api/routes/checkpoints.py @@ -0,0 +1,206 @@ +import base64 + +from fastapi import APIRouter, Response + +from src.api.schemas.checkpoints import ( + BlobResponse, + CheckpointListItem, + CheckpointTupleResponse, + DeleteThreadRequest, + GetCheckpointTupleRequest, + ListCheckpointsRequest, + PutCheckpointRequest, + PutCheckpointResponse, + PutWritesRequest, + WriteResponse, +) +from src.api.schemas.authorization_types import ( + AgentexResourceType, + AuthorizedOperationType, +) +from src.domain.use_cases.checkpoints_use_case import DCheckpointsUseCase +from src.utils.authorization_shortcuts import DAuthorizedBodyId +from src.utils.logging import make_logger + +logger = make_logger(__name__) + +router = APIRouter(prefix="/checkpoints", tags=["Checkpoints"]) + + +def _bytes_to_b64(data: bytes | None) -> str | None: + if data is None: + return None + return base64.b64encode(data).decode("ascii") + + +def _b64_to_bytes(data: str | None) -> bytes | None: + if data is None: + return None + return base64.b64decode(data) + + +@router.post( + "/get-tuple", + response_model=CheckpointTupleResponse | None, +) +async def get_checkpoint_tuple( + request: GetCheckpointTupleRequest, + checkpoints_use_case: DCheckpointsUseCase, + _authorized_task_id: DAuthorizedBodyId( + AgentexResourceType.task, AuthorizedOperationType.read, field_name="thread_id" + ), +) -> CheckpointTupleResponse | None: + result = await checkpoints_use_case.get_tuple( + thread_id=request.thread_id, + checkpoint_ns=request.checkpoint_ns, + checkpoint_id=request.checkpoint_id, + ) + if result is None: + return None + + return CheckpointTupleResponse( + thread_id=result["thread_id"], + checkpoint_ns=result["checkpoint_ns"], + checkpoint_id=result["checkpoint_id"], + parent_checkpoint_id=result["parent_checkpoint_id"], + checkpoint=result["checkpoint"], + metadata=result["metadata"], + blobs=[ + BlobResponse( + channel=b["channel"], + version=b["version"], + type=b["type"], + blob=_bytes_to_b64(b["blob"]), + ) + for b in result.get("blobs", []) + ], + pending_writes=[ + WriteResponse( + task_id=w["task_id"], + idx=w["idx"], + channel=w["channel"], + type=w["type"], + blob=_bytes_to_b64(w["blob"]), + ) + for w in result.get("pending_writes", []) + ], + ) + + +@router.post( + "/put", + response_model=PutCheckpointResponse, +) +async def put_checkpoint( + request: PutCheckpointRequest, + checkpoints_use_case: DCheckpointsUseCase, + _authorized_task_id: DAuthorizedBodyId( + AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id" + ), +) -> PutCheckpointResponse: + blobs = [ + { + "channel": b.channel, + "version": b.version, + "type": b.type, + "blob": _b64_to_bytes(b.blob), + } + for b in 
request.blobs + ] + + await checkpoints_use_case.put( + thread_id=request.thread_id, + checkpoint_ns=request.checkpoint_ns, + checkpoint_id=request.checkpoint_id, + parent_checkpoint_id=request.parent_checkpoint_id, + checkpoint=request.checkpoint, + metadata=request.metadata, + blobs=blobs, + ) + + return PutCheckpointResponse( + thread_id=request.thread_id, + checkpoint_ns=request.checkpoint_ns, + checkpoint_id=request.checkpoint_id, + ) + + +@router.post( + "/put-writes", + status_code=204, +) +async def put_writes( + request: PutWritesRequest, + checkpoints_use_case: DCheckpointsUseCase, + _authorized_task_id: DAuthorizedBodyId( + AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id" + ), +) -> Response: + writes = [ + { + "task_id": w.task_id, + "idx": w.idx, + "channel": w.channel, + "type": w.type, + "blob": _b64_to_bytes(w.blob), + "task_path": w.task_path, + } + for w in request.writes + ] + + await checkpoints_use_case.put_writes( + thread_id=request.thread_id, + checkpoint_ns=request.checkpoint_ns, + checkpoint_id=request.checkpoint_id, + writes=writes, + upsert=request.upsert, + ) + + return Response(status_code=204) + + +@router.post( + "/list", + response_model=list[CheckpointListItem], +) +async def list_checkpoints( + request: ListCheckpointsRequest, + checkpoints_use_case: DCheckpointsUseCase, + _authorized_task_id: DAuthorizedBodyId( + AgentexResourceType.task, AuthorizedOperationType.read, field_name="thread_id" + ), +) -> list[CheckpointListItem]: + results = await checkpoints_use_case.list_checkpoints( + thread_id=request.thread_id, + checkpoint_ns=request.checkpoint_ns, + before_checkpoint_id=request.before_checkpoint_id, + filter_metadata=request.filter_metadata, + limit=request.limit, + ) + + return [ + CheckpointListItem( + thread_id=r["thread_id"], + checkpoint_ns=r["checkpoint_ns"], + checkpoint_id=r["checkpoint_id"], + parent_checkpoint_id=r["parent_checkpoint_id"], + checkpoint=r["checkpoint"], + metadata=r["metadata"], + ) + for r in results + ] + + +@router.post( + "/delete-thread", + status_code=204, +) +async def delete_thread( + request: DeleteThreadRequest, + checkpoints_use_case: DCheckpointsUseCase, + _authorized_task_id: DAuthorizedBodyId( + AgentexResourceType.task, AuthorizedOperationType.delete, field_name="thread_id" + ), +) -> Response: + await checkpoints_use_case.delete_thread(thread_id=request.thread_id) + return Response(status_code=204) diff --git a/agentex/src/api/schemas/checkpoints.py b/agentex/src/api/schemas/checkpoints.py new file mode 100644 index 00000000..4b910315 --- /dev/null +++ b/agentex/src/api/schemas/checkpoints.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field + +# ── Request models ── + + +class GetCheckpointTupleRequest(BaseModel): + thread_id: str = Field(..., title="Thread ID") + checkpoint_ns: str = Field("", title="Checkpoint namespace") + checkpoint_id: str | None = Field(None, title="Checkpoint ID (None = latest)") + + +class PutCheckpointRequest(BaseModel): + thread_id: str = Field(..., title="Thread ID") + checkpoint_ns: str = Field("", title="Checkpoint namespace") + checkpoint_id: str = Field(..., title="Checkpoint ID") + parent_checkpoint_id: str | None = Field(None, title="Parent checkpoint ID") + checkpoint: dict[str, Any] = Field(..., title="Checkpoint JSONB payload") + metadata: dict[str, Any] = Field(default_factory=dict, title="Checkpoint metadata") + blobs: list[BlobData] = Field(default_factory=list, 
title="Channel blob data") + + +class BlobData(BaseModel): + channel: str = Field(..., title="Channel name") + version: str = Field(..., title="Channel version") + type: str = Field(..., title="Serialization type tag") + blob: str | None = Field(None, title="Base64-encoded binary data") + + +# Rebuild PutCheckpointRequest now that BlobData is defined +PutCheckpointRequest.model_rebuild() + + +class WriteData(BaseModel): + task_id: str = Field(..., title="Task ID") + idx: int = Field(..., title="Write index") + channel: str = Field(..., title="Channel name") + type: str | None = Field(None, title="Serialization type tag") + blob: str = Field(..., title="Base64-encoded binary data") + task_path: str = Field("", title="Task path") + + +class PutWritesRequest(BaseModel): + thread_id: str = Field(..., title="Thread ID") + checkpoint_ns: str = Field("", title="Checkpoint namespace") + checkpoint_id: str = Field(..., title="Checkpoint ID") + writes: list[WriteData] = Field(..., title="Write data") + upsert: bool = Field(False, title="Upsert mode") + + +class ListCheckpointsRequest(BaseModel): + thread_id: str = Field(..., title="Thread ID") + checkpoint_ns: str | None = Field(None, title="Checkpoint namespace") + before_checkpoint_id: str | None = Field(None, title="Before checkpoint ID") + filter_metadata: dict[str, Any] | None = Field( + None, title="Metadata filter (JSONB @>)" + ) + limit: int = Field(100, title="Max results", ge=1, le=1000) + + +class DeleteThreadRequest(BaseModel): + thread_id: str = Field(..., title="Thread ID") + + +# ── Response models ── + + +class BlobResponse(BaseModel): + channel: str + version: str + type: str + blob: str | None = None # base64 + + +class WriteResponse(BaseModel): + task_id: str + idx: int + channel: str + type: str | None = None + blob: str | None = None # base64 + + +class CheckpointTupleResponse(BaseModel): + thread_id: str + checkpoint_ns: str + checkpoint_id: str + parent_checkpoint_id: str | None = None + checkpoint: dict[str, Any] + metadata: dict[str, Any] + blobs: list[BlobResponse] = Field(default_factory=list) + pending_writes: list[WriteResponse] = Field(default_factory=list) + + +class CheckpointListItem(BaseModel): + thread_id: str + checkpoint_ns: str + checkpoint_id: str + parent_checkpoint_id: str | None = None + checkpoint: dict[str, Any] + metadata: dict[str, Any] + + +class PutCheckpointResponse(BaseModel): + thread_id: str + checkpoint_ns: str + checkpoint_id: str diff --git a/agentex/src/domain/repositories/checkpoint_repository.py b/agentex/src/domain/repositories/checkpoint_repository.py new file mode 100644 index 00000000..b320517d --- /dev/null +++ b/agentex/src/domain/repositories/checkpoint_repository.py @@ -0,0 +1,307 @@ +from typing import Annotated, Any + +from fastapi import Depends +from sqlalchemy import delete, or_, select +from sqlalchemy.dialects.postgresql import insert +from src.adapters.crud_store.adapter_postgres import async_sql_exception_handler +from src.adapters.orm import CheckpointBlobORM, CheckpointORM, CheckpointWriteORM +from src.config.dependencies import ( + DDatabaseAsyncReadOnlySessionMaker, + DDatabaseAsyncReadWriteSessionMaker, +) +from src.utils.logging import make_logger + +logger = make_logger(__name__) + + +class CheckpointRepository: + """Repository for LangGraph checkpoint operations. + + Uses raw SQLAlchemy queries because the checkpoint tables have + composite primary keys that don't fit the generic CRUD repository. 
+ """ + + def __init__( + self, + async_read_write_session_maker: DDatabaseAsyncReadWriteSessionMaker, + async_read_only_session_maker: DDatabaseAsyncReadOnlySessionMaker, + ): + self.async_rw_session_maker = async_read_write_session_maker + self.async_ro_session_maker = async_read_only_session_maker + + async def get_tuple( + self, + thread_id: str, + checkpoint_ns: str = "", + checkpoint_id: str | None = None, + ) -> dict[str, Any] | None: + """Fetch a checkpoint along with its blobs and pending writes. + + If checkpoint_id is None, returns the latest checkpoint for the thread/ns. + """ + async with ( + self.async_ro_session_maker() as session, + async_sql_exception_handler(), + ): + # Build checkpoint query + query = select(CheckpointORM).where( + CheckpointORM.thread_id == thread_id, + CheckpointORM.checkpoint_ns == checkpoint_ns, + ) + if checkpoint_id: + query = query.where(CheckpointORM.checkpoint_id == checkpoint_id) + else: + query = query.order_by(CheckpointORM.checkpoint_id.desc()).limit(1) + + result = await session.execute(query) + cp = result.scalar_one_or_none() + if cp is None: + return None + + # Fetch blobs whose (channel, version) appears in checkpoint.channel_versions + channel_versions: dict[str, str] = cp.checkpoint.get("channel_versions", {}) + blobs: list[dict[str, Any]] = [] + if channel_versions: + # Build OR conditions for each (channel, version) pair + blob_query = select(CheckpointBlobORM).where( + CheckpointBlobORM.thread_id == thread_id, + CheckpointBlobORM.checkpoint_ns == checkpoint_ns, + ) + # Filter to only matching channel+version pairs + conditions = [] + for channel, version in channel_versions.items(): + conditions.append( + (CheckpointBlobORM.channel == channel) + & (CheckpointBlobORM.version == str(version)) + ) + if conditions: + blob_query = blob_query.where(or_(*conditions)) + + blob_result = await session.execute(blob_query) + for b in blob_result.scalars().all(): + blobs.append( + { + "channel": b.channel, + "version": b.version, + "type": b.type, + "blob": bytes(b.blob) if b.blob is not None else None, + } + ) + + # Fetch pending writes for this checkpoint + writes_query = ( + select(CheckpointWriteORM) + .where( + CheckpointWriteORM.thread_id == thread_id, + CheckpointWriteORM.checkpoint_ns == checkpoint_ns, + CheckpointWriteORM.checkpoint_id == cp.checkpoint_id, + ) + .order_by(CheckpointWriteORM.task_id, CheckpointWriteORM.idx) + ) + writes_result = await session.execute(writes_query) + writes: list[dict[str, Any]] = [] + for w in writes_result.scalars().all(): + writes.append( + { + "task_id": w.task_id, + "idx": w.idx, + "channel": w.channel, + "type": w.type, + "blob": bytes(w.blob) if w.blob is not None else None, + } + ) + + return { + "thread_id": cp.thread_id, + "checkpoint_ns": cp.checkpoint_ns, + "checkpoint_id": cp.checkpoint_id, + "parent_checkpoint_id": cp.parent_checkpoint_id, + "checkpoint": cp.checkpoint, + "metadata": cp.metadata_, + "blobs": blobs, + "pending_writes": writes, + } + + async def put( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + parent_checkpoint_id: str | None, + checkpoint: dict[str, Any], + metadata: dict[str, Any], + blobs: list[dict[str, Any]], + ) -> None: + """Upsert a checkpoint and its blobs in one transaction.""" + async with ( + self.async_rw_session_maker() as session, + async_sql_exception_handler(), + ): + # Upsert blobs + for blob in blobs: + stmt = ( + insert(CheckpointBlobORM) + .values( + thread_id=thread_id, + checkpoint_ns=checkpoint_ns, + channel=blob["channel"], 
+ version=blob["version"], + type=blob["type"], + blob=blob.get("blob"), + ) + .on_conflict_do_nothing( + index_elements=[ + "thread_id", + "checkpoint_ns", + "channel", + "version", + ] + ) + ) + await session.execute(stmt) + + # Upsert checkpoint + stmt = ( + insert(CheckpointORM) + .values( + thread_id=thread_id, + checkpoint_ns=checkpoint_ns, + checkpoint_id=checkpoint_id, + parent_checkpoint_id=parent_checkpoint_id, + checkpoint=checkpoint, + metadata_=metadata, + ) + .on_conflict_do_update( + index_elements=["thread_id", "checkpoint_ns", "checkpoint_id"], + set_={ + "checkpoint": checkpoint, + "metadata": metadata, # use DB column name, not Python attr + }, + ) + ) + await session.execute(stmt) + await session.commit() + + async def put_writes( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + writes: list[dict[str, Any]], + upsert: bool = False, + ) -> None: + """Batch insert/upsert checkpoint writes.""" + async with ( + self.async_rw_session_maker() as session, + async_sql_exception_handler(), + ): + for w in writes: + stmt = insert(CheckpointWriteORM).values( + thread_id=thread_id, + checkpoint_ns=checkpoint_ns, + checkpoint_id=checkpoint_id, + task_id=w["task_id"], + idx=w["idx"], + channel=w["channel"], + type=w.get("type"), + blob=w["blob"], + task_path=w.get("task_path", ""), + ) + if upsert: + stmt = stmt.on_conflict_do_update( + index_elements=[ + "thread_id", + "checkpoint_ns", + "checkpoint_id", + "task_id", + "idx", + ], + set_={ + "channel": w["channel"], + "type": w.get("type"), + "blob": w["blob"], + }, + ) + else: + stmt = stmt.on_conflict_do_nothing( + index_elements=[ + "thread_id", + "checkpoint_ns", + "checkpoint_id", + "task_id", + "idx", + ], + ) + await session.execute(stmt) + await session.commit() + + async def list_checkpoints( + self, + thread_id: str, + checkpoint_ns: str | None = None, + before_checkpoint_id: str | None = None, + filter_metadata: dict[str, Any] | None = None, + limit: int = 100, + ) -> list[dict[str, Any]]: + """List checkpoints matching criteria, ordered newest first.""" + async with ( + self.async_ro_session_maker() as session, + async_sql_exception_handler(), + ): + query = select(CheckpointORM).where( + CheckpointORM.thread_id == thread_id + ) + + if checkpoint_ns is not None: + query = query.where(CheckpointORM.checkpoint_ns == checkpoint_ns) + if before_checkpoint_id is not None: + query = query.where(CheckpointORM.checkpoint_id < before_checkpoint_id) + if filter_metadata: + # JSONB containment operator @> + query = query.where(CheckpointORM.metadata_.op("@>")(filter_metadata)) + + query = query.order_by(CheckpointORM.checkpoint_id.desc()) + query = query.limit(limit) + + result = await session.execute(query) + rows = result.scalars().all() + + checkpoints = [] + for cp in rows: + # For list, include checkpoint + metadata but not full blobs/writes + # to keep the response lightweight. Clients call get_tuple for full data. 
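+                # Shape mirrors the CheckpointListItem response schema in
+                # src/api/schemas/checkpoints.py.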
+ checkpoints.append( + { + "thread_id": cp.thread_id, + "checkpoint_ns": cp.checkpoint_ns, + "checkpoint_id": cp.checkpoint_id, + "parent_checkpoint_id": cp.parent_checkpoint_id, + "checkpoint": cp.checkpoint, + "metadata": cp.metadata_, + } + ) + return checkpoints + + async def delete_thread(self, thread_id: str) -> None: + """Delete all checkpoint data for a thread.""" + async with ( + self.async_rw_session_maker() as session, + async_sql_exception_handler(), + ): + await session.execute( + delete(CheckpointWriteORM).where( + CheckpointWriteORM.thread_id == thread_id + ) + ) + await session.execute( + delete(CheckpointBlobORM).where( + CheckpointBlobORM.thread_id == thread_id + ) + ) + await session.execute( + delete(CheckpointORM).where(CheckpointORM.thread_id == thread_id) + ) + await session.commit() + + +DCheckpointRepository = Annotated[CheckpointRepository, Depends(CheckpointRepository)] diff --git a/agentex/src/domain/use_cases/agents_acp_use_case.py b/agentex/src/domain/use_cases/agents_acp_use_case.py index dee4d8cd..1c3b97c8 100644 --- a/agentex/src/domain/use_cases/agents_acp_use_case.py +++ b/agentex/src/domain/use_cases/agents_acp_use_case.py @@ -160,6 +160,7 @@ def convert_to_content(self) -> TaskMessageContentEntity: ) return ReasoningContentEntity( author=MessageAuthor.AGENT, + summary=[], content=[reasoning_content_str], ) elif self._delta_type == DeltaType.REASONING_SUMMARY: diff --git a/agentex/src/domain/use_cases/checkpoints_use_case.py b/agentex/src/domain/use_cases/checkpoints_use_case.py new file mode 100644 index 00000000..115113c3 --- /dev/null +++ b/agentex/src/domain/use_cases/checkpoints_use_case.py @@ -0,0 +1,83 @@ +from typing import Annotated, Any + +from fastapi import Depends + +from src.domain.repositories.checkpoint_repository import DCheckpointRepository +from src.utils.logging import make_logger + +logger = make_logger(__name__) + + +class CheckpointsUseCase: + def __init__(self, checkpoint_repository: DCheckpointRepository): + self.checkpoint_repository = checkpoint_repository + + async def get_tuple( + self, + thread_id: str, + checkpoint_ns: str = "", + checkpoint_id: str | None = None, + ) -> dict[str, Any] | None: + return await self.checkpoint_repository.get_tuple( + thread_id=thread_id, + checkpoint_ns=checkpoint_ns, + checkpoint_id=checkpoint_id, + ) + + async def put( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + parent_checkpoint_id: str | None, + checkpoint: dict[str, Any], + metadata: dict[str, Any], + blobs: list[dict[str, Any]], + ) -> None: + await self.checkpoint_repository.put( + thread_id=thread_id, + checkpoint_ns=checkpoint_ns, + checkpoint_id=checkpoint_id, + parent_checkpoint_id=parent_checkpoint_id, + checkpoint=checkpoint, + metadata=metadata, + blobs=blobs, + ) + + async def put_writes( + self, + thread_id: str, + checkpoint_ns: str, + checkpoint_id: str, + writes: list[dict[str, Any]], + upsert: bool = False, + ) -> None: + await self.checkpoint_repository.put_writes( + thread_id=thread_id, + checkpoint_ns=checkpoint_ns, + checkpoint_id=checkpoint_id, + writes=writes, + upsert=upsert, + ) + + async def list_checkpoints( + self, + thread_id: str, + checkpoint_ns: str | None = None, + before_checkpoint_id: str | None = None, + filter_metadata: dict[str, Any] | None = None, + limit: int = 100, + ) -> list[dict[str, Any]]: + return await self.checkpoint_repository.list_checkpoints( + thread_id=thread_id, + checkpoint_ns=checkpoint_ns, + before_checkpoint_id=before_checkpoint_id, + 
filter_metadata=filter_metadata, + limit=limit, + ) + + async def delete_thread(self, thread_id: str) -> None: + await self.checkpoint_repository.delete_thread(thread_id=thread_id) + + +DCheckpointsUseCase = Annotated[CheckpointsUseCase, Depends(CheckpointsUseCase)] diff --git a/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py b/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py new file mode 100644 index 00000000..062e8c86 --- /dev/null +++ b/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py @@ -0,0 +1,666 @@ +""" +Integration tests for the checkpoint repository. + +Tests the CheckpointRepository against a real PostgreSQL database to validate +that our reimplementation of the LangGraph checkpoint storage operations +(get_tuple, put, put_writes, list_checkpoints, delete_thread) works correctly. +""" + +import pytest + + +@pytest.mark.asyncio +class TestCheckpointRepository: + """Integration tests for CheckpointRepository CRUD operations.""" + + # ── put + get_tuple round-trip ── + + async def test_put_and_get_tuple(self, isolated_repositories): + """Test basic round-trip: put a checkpoint then get it back.""" + repo = isolated_repositories["checkpoint_repository"] + + checkpoint_data = { + "id": "cp-1", + "v": 4, + "channel_values": {"counter": 42}, + "channel_versions": {"messages": "00000001.123"}, + } + metadata = {"source": "input", "step": 1, "writes": {}} + blobs = [ + { + "channel": "messages", + "version": "00000001.123", + "type": "json", + "blob": b'["hello"]', + }, + ] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint=checkpoint_data, + metadata=metadata, + blobs=blobs, + ) + + result = await repo.get_tuple( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + ) + + assert result is not None + assert result["thread_id"] == "thread-1" + assert result["checkpoint_ns"] == "" + assert result["checkpoint_id"] == "cp-1" + assert result["parent_checkpoint_id"] is None + assert result["checkpoint"] == checkpoint_data + assert result["metadata"] == metadata + assert len(result["blobs"]) == 1 + assert result["blobs"][0]["channel"] == "messages" + assert result["blobs"][0]["type"] == "json" + assert bytes(result["blobs"][0]["blob"]) == b'["hello"]' + + async def test_put_updates_existing_checkpoint(self, isolated_repositories): + """Test that putting a checkpoint with same PK upserts (updates).""" + repo = isolated_repositories["checkpoint_repository"] + + original = {"id": "cp-1", "v": 4, "channel_values": {"counter": 1}} + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint=original, + metadata={"step": 1}, + blobs=[], + ) + + updated = {"id": "cp-1", "v": 4, "channel_values": {"counter": 99}} + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint=updated, + metadata={"step": 2}, + blobs=[], + ) + + result = await repo.get_tuple( + thread_id="thread-1", checkpoint_ns="", checkpoint_id="cp-1" + ) + assert result is not None + assert result["checkpoint"]["channel_values"]["counter"] == 99 + assert result["metadata"]["step"] == 2 + + # ── get_tuple: latest checkpoint ── + + async def test_get_tuple_latest(self, isolated_repositories): + """Test that get_tuple without checkpoint_id returns the latest.""" + repo = isolated_repositories["checkpoint_repository"] + + for cp_id in ["cp-1", 
"cp-2", "cp-3"]: + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id=cp_id, + parent_checkpoint_id=None, + checkpoint={"id": cp_id}, + metadata={}, + blobs=[], + ) + + result = await repo.get_tuple(thread_id="thread-1", checkpoint_ns="") + assert result is not None + # "cp-3" is lexicographically greatest → latest + assert result["checkpoint_id"] == "cp-3" + + async def test_get_tuple_not_found(self, isolated_repositories): + """Test that get_tuple returns None for non-existent checkpoint.""" + repo = isolated_repositories["checkpoint_repository"] + + result = await repo.get_tuple( + thread_id="nonexistent", checkpoint_ns="", checkpoint_id="nope" + ) + assert result is None + + # ── blobs ── + + async def test_blobs_only_matching_versions_returned(self, isolated_repositories): + """Test that get_tuple only returns blobs matching channel_versions.""" + repo = isolated_repositories["checkpoint_repository"] + + # Store blobs for two versions + blobs = [ + {"channel": "messages", "version": "v1", "type": "json", "blob": b"old"}, + {"channel": "messages", "version": "v2", "type": "json", "blob": b"new"}, + ] + checkpoint = { + "id": "cp-1", + "v": 4, + "channel_versions": {"messages": "v2"}, + } + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint=checkpoint, + metadata={}, + blobs=blobs, + ) + + result = await repo.get_tuple( + thread_id="thread-1", checkpoint_ns="", checkpoint_id="cp-1" + ) + assert result is not None + # Should only return v2 blob (matching channel_versions) + assert len(result["blobs"]) == 1 + assert result["blobs"][0]["version"] == "v2" + assert bytes(result["blobs"][0]["blob"]) == b"new" + + # ── pending writes ── + + async def test_put_writes_and_get(self, isolated_repositories): + """Test that writes stored via put_writes appear in get_tuple.""" + repo = isolated_repositories["checkpoint_repository"] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1"}, + metadata={}, + blobs=[], + ) + + writes = [ + { + "task_id": "task-abc", + "idx": 0, + "channel": "messages", + "type": "json", + "blob": b'{"role": "ai"}', + "task_path": "", + }, + { + "task_id": "task-abc", + "idx": 1, + "channel": "output", + "type": "json", + "blob": b'"done"', + "task_path": "", + }, + ] + await repo.put_writes( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + writes=writes, + ) + + result = await repo.get_tuple( + thread_id="thread-1", checkpoint_ns="", checkpoint_id="cp-1" + ) + assert result is not None + assert len(result["pending_writes"]) == 2 + assert result["pending_writes"][0]["task_id"] == "task-abc" + assert result["pending_writes"][0]["channel"] == "messages" + assert result["pending_writes"][1]["channel"] == "output" + + async def test_put_writes_upsert(self, isolated_repositories): + """Test that upsert=True updates existing writes.""" + repo = isolated_repositories["checkpoint_repository"] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1"}, + metadata={}, + blobs=[], + ) + + original_write = [ + { + "task_id": "task-1", + "idx": 0, + "channel": "messages", + "type": "json", + "blob": b"original", + "task_path": "", + }, + ] + await repo.put_writes( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + writes=original_write, + ) + + updated_write = [ + { + 
"task_id": "task-1", + "idx": 0, + "channel": "messages", + "type": "json", + "blob": b"updated", + "task_path": "", + }, + ] + await repo.put_writes( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + writes=updated_write, + upsert=True, + ) + + result = await repo.get_tuple( + thread_id="thread-1", checkpoint_ns="", checkpoint_id="cp-1" + ) + assert result is not None + assert len(result["pending_writes"]) == 1 + assert bytes(result["pending_writes"][0]["blob"]) == b"updated" + + async def test_put_writes_no_upsert_skips_duplicates(self, isolated_repositories): + """Test that upsert=False (default) skips conflicting writes.""" + repo = isolated_repositories["checkpoint_repository"] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1"}, + metadata={}, + blobs=[], + ) + + write = [ + { + "task_id": "task-1", + "idx": 0, + "channel": "messages", + "type": "json", + "blob": b"first", + "task_path": "", + }, + ] + await repo.put_writes( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + writes=write, + ) + + # Try to write again with same PK but different blob — should be skipped + duplicate_write = [ + { + "task_id": "task-1", + "idx": 0, + "channel": "messages", + "type": "json", + "blob": b"second", + "task_path": "", + }, + ] + await repo.put_writes( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + writes=duplicate_write, + upsert=False, + ) + + result = await repo.get_tuple( + thread_id="thread-1", checkpoint_ns="", checkpoint_id="cp-1" + ) + assert result is not None + assert len(result["pending_writes"]) == 1 + assert bytes(result["pending_writes"][0]["blob"]) == b"first" + + # ── list_checkpoints ── + + async def test_list_checkpoints_basic(self, isolated_repositories): + """Test listing checkpoints returns them in descending order.""" + repo = isolated_repositories["checkpoint_repository"] + + for cp_id in ["cp-1", "cp-2", "cp-3"]: + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id=cp_id, + parent_checkpoint_id=None, + checkpoint={"id": cp_id}, + metadata={"source": "loop"}, + blobs=[], + ) + + results = await repo.list_checkpoints(thread_id="thread-1") + assert len(results) == 3 + # Descending order + assert results[0]["checkpoint_id"] == "cp-3" + assert results[1]["checkpoint_id"] == "cp-2" + assert results[2]["checkpoint_id"] == "cp-1" + + async def test_list_checkpoints_with_metadata_filter(self, isolated_repositories): + """Test JSONB containment filter (@>) on metadata.""" + repo = isolated_repositories["checkpoint_repository"] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1"}, + metadata={"source": "input", "step": 1}, + blobs=[], + ) + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-2", + parent_checkpoint_id="cp-1", + checkpoint={"id": "cp-2"}, + metadata={"source": "loop", "step": 2, "writes": {"foo": "bar"}}, + blobs=[], + ) + + # Filter by source=loop + results = await repo.list_checkpoints( + thread_id="thread-1", filter_metadata={"source": "loop"} + ) + assert len(results) == 1 + assert results[0]["checkpoint_id"] == "cp-2" + + # Filter by source=input + results = await repo.list_checkpoints( + thread_id="thread-1", filter_metadata={"source": "input"} + ) + assert len(results) == 1 + assert results[0]["checkpoint_id"] == "cp-1" + + # Filter that matches nothing + results 
= await repo.list_checkpoints( + thread_id="thread-1", filter_metadata={"source": "nonexistent"} + ) + assert len(results) == 0 + + async def test_list_checkpoints_with_before(self, isolated_repositories): + """Test before_checkpoint_id pagination.""" + repo = isolated_repositories["checkpoint_repository"] + + for cp_id in ["cp-1", "cp-2", "cp-3"]: + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id=cp_id, + parent_checkpoint_id=None, + checkpoint={"id": cp_id}, + metadata={}, + blobs=[], + ) + + results = await repo.list_checkpoints( + thread_id="thread-1", before_checkpoint_id="cp-3" + ) + assert len(results) == 2 + assert results[0]["checkpoint_id"] == "cp-2" + assert results[1]["checkpoint_id"] == "cp-1" + + async def test_list_checkpoints_with_limit(self, isolated_repositories): + """Test limit parameter caps results.""" + repo = isolated_repositories["checkpoint_repository"] + + for cp_id in ["cp-1", "cp-2", "cp-3"]: + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id=cp_id, + parent_checkpoint_id=None, + checkpoint={"id": cp_id}, + metadata={}, + blobs=[], + ) + + results = await repo.list_checkpoints(thread_id="thread-1", limit=2) + assert len(results) == 2 + # Should be the two newest + assert results[0]["checkpoint_id"] == "cp-3" + assert results[1]["checkpoint_id"] == "cp-2" + + # ── delete_thread ── + + async def test_delete_thread(self, isolated_repositories): + """Test that delete_thread removes all data for a thread.""" + repo = isolated_repositories["checkpoint_repository"] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1", "channel_versions": {"ch": "v1"}}, + metadata={}, + blobs=[{"channel": "ch", "version": "v1", "type": "json", "blob": b"data"}], + ) + await repo.put_writes( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + writes=[ + { + "task_id": "t1", + "idx": 0, + "channel": "ch", + "type": "json", + "blob": b"w", + "task_path": "", + } + ], + ) + + # Verify data exists + result = await repo.get_tuple(thread_id="thread-1", checkpoint_ns="") + assert result is not None + + # Delete + await repo.delete_thread(thread_id="thread-1") + + # Verify everything is gone + result = await repo.get_tuple(thread_id="thread-1", checkpoint_ns="") + assert result is None + + results = await repo.list_checkpoints(thread_id="thread-1") + assert len(results) == 0 + + async def test_delete_thread_does_not_affect_other_threads( + self, isolated_repositories + ): + """Test that deleting one thread doesn't affect another.""" + repo = isolated_repositories["checkpoint_repository"] + + for thread_id in ["thread-1", "thread-2"]: + await repo.put( + thread_id=thread_id, + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1"}, + metadata={}, + blobs=[], + ) + + await repo.delete_thread(thread_id="thread-1") + + assert await repo.get_tuple(thread_id="thread-1", checkpoint_ns="") is None + assert await repo.get_tuple(thread_id="thread-2", checkpoint_ns="") is not None + + # ── isolation ── + + async def test_thread_isolation(self, isolated_repositories): + """Test that different thread_ids are fully isolated.""" + repo = isolated_repositories["checkpoint_repository"] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1", "thread": "1"}, + metadata={}, + blobs=[], + ) + await repo.put( + 
thread_id="thread-2", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1", "thread": "2"}, + metadata={}, + blobs=[], + ) + + r1 = await repo.get_tuple(thread_id="thread-1", checkpoint_ns="") + r2 = await repo.get_tuple(thread_id="thread-2", checkpoint_ns="") + + assert r1 is not None + assert r2 is not None + assert r1["checkpoint"]["thread"] == "1" + assert r2["checkpoint"]["thread"] == "2" + + async def test_namespace_isolation(self, isolated_repositories): + """Test that different checkpoint_ns values are isolated.""" + repo = isolated_repositories["checkpoint_repository"] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1", "ns": "root"}, + metadata={}, + blobs=[], + ) + await repo.put( + thread_id="thread-1", + checkpoint_ns="inner", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1", "ns": "inner"}, + metadata={}, + blobs=[], + ) + + root = await repo.get_tuple(thread_id="thread-1", checkpoint_ns="") + inner = await repo.get_tuple(thread_id="thread-1", checkpoint_ns="inner") + + assert root is not None + assert inner is not None + assert root["checkpoint"]["ns"] == "root" + assert inner["checkpoint"]["ns"] == "inner" + + async def test_list_checkpoints_filters_by_namespace(self, isolated_repositories): + """Test that list_checkpoints respects checkpoint_ns filter.""" + repo = isolated_repositories["checkpoint_repository"] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1"}, + metadata={}, + blobs=[], + ) + await repo.put( + thread_id="thread-1", + checkpoint_ns="subgraph", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1"}, + metadata={}, + blobs=[], + ) + + root_results = await repo.list_checkpoints( + thread_id="thread-1", checkpoint_ns="" + ) + sub_results = await repo.list_checkpoints( + thread_id="thread-1", checkpoint_ns="subgraph" + ) + + assert len(root_results) == 1 + assert len(sub_results) == 1 + assert root_results[0]["checkpoint_ns"] == "" + assert sub_results[0]["checkpoint_ns"] == "subgraph" + + # ── parent checkpoint tracking ── + + async def test_parent_checkpoint_id_tracked(self, isolated_repositories): + """Test that parent_checkpoint_id is stored and returned correctly.""" + repo = isolated_repositories["checkpoint_repository"] + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + checkpoint={"id": "cp-1"}, + metadata={}, + blobs=[], + ) + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-2", + parent_checkpoint_id="cp-1", + checkpoint={"id": "cp-2"}, + metadata={}, + blobs=[], + ) + + result = await repo.get_tuple( + thread_id="thread-1", checkpoint_ns="", checkpoint_id="cp-2" + ) + assert result is not None + assert result["parent_checkpoint_id"] == "cp-1" + + # ── blob edge cases ── + + async def test_null_blob_stored_correctly(self, isolated_repositories): + """Test that a blob with None data is stored and returned.""" + repo = isolated_repositories["checkpoint_repository"] + + blobs = [ + {"channel": "empty_channel", "version": "v1", "type": "empty", "blob": None}, + ] + checkpoint = { + "id": "cp-1", + "channel_versions": {"empty_channel": "v1"}, + } + + await repo.put( + thread_id="thread-1", + checkpoint_ns="", + checkpoint_id="cp-1", + parent_checkpoint_id=None, + 
checkpoint=checkpoint, + metadata={}, + blobs=blobs, + ) + + result = await repo.get_tuple( + thread_id="thread-1", checkpoint_ns="", checkpoint_id="cp-1" + ) + assert result is not None + assert len(result["blobs"]) == 1 + assert result["blobs"][0]["type"] == "empty" + assert result["blobs"][0]["blob"] is None diff --git a/agentex/tests/integration/fixtures/integration_client.py b/agentex/tests/integration/fixtures/integration_client.py index 15c9dbfc..ade4bdc4 100644 --- a/agentex/tests/integration/fixtures/integration_client.py +++ b/agentex/tests/integration/fixtures/integration_client.py @@ -266,6 +266,7 @@ async def __aenter__(self): from src.domain.repositories.span_repository import SpanRepository from src.domain.repositories.task_message_repository import TaskMessageRepository from src.domain.repositories.task_repository import TaskRepository + from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.task_state_repository import TaskStateRepository # Create Redis repository with mock environment variables @@ -314,6 +315,10 @@ def __init__(self, redis_url): "task_state_repository": TaskStateRepository(mongodb_database), # Redis repositories "redis_stream_repository": redis_stream_repository, + # Checkpoint repository + "checkpoint_repository": CheckpointRepository( + async_rw_session_factory, async_ro_session_factory + ), # Direct access for advanced use cases "postgres_rw_session_factory": async_rw_session_factory, "postgres_ro_session_factory": async_ro_session_factory, @@ -372,6 +377,7 @@ async def isolated_integration_app( from src.domain.use_cases.messages_use_case import MessagesUseCase from src.domain.use_cases.spans_use_case import SpanUseCase from src.domain.use_cases.states_use_case import StatesUseCase + from src.domain.use_cases.checkpoints_use_case import CheckpointsUseCase from src.domain.use_cases.tasks_use_case import TasksUseCase # Create use case factory functions with isolated repositories @@ -439,6 +445,11 @@ async def send_message(self, *args, **kwargs): return TasksUseCase(task_service=task_service) + def create_checkpoints_use_case(): + return CheckpointsUseCase( + checkpoint_repository=isolated_repositories["checkpoint_repository"], + ) + def create_messages_use_case(): """Create MessagesUseCase for comprehensive testing""" from src.domain.services.task_message_service import TaskMessageService @@ -456,6 +467,7 @@ def create_messages_use_case(): DDatabaseAsyncReadWriteSessionMaker, DMongoDBDatabase, ) + from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.agent_api_key_repository import AgentAPIKeyRepository from src.domain.repositories.agent_repository import AgentRepository from src.domain.repositories.agent_task_tracker_repository import ( @@ -483,6 +495,7 @@ def create_messages_use_case(): "postgres_ro_session_factory" ], # Use cases + CheckpointsUseCase: create_checkpoints_use_case, AgentsUseCase: create_agents_use_case, EventUseCase: create_events_use_case, SpanUseCase: create_spans_use_case, @@ -493,6 +506,9 @@ def create_messages_use_case(): AgentAPIKeysUseCase: create_agent_api_keys_use_case, DeploymentHistoryUseCase: create_deployment_history_use_case, # Repositories - these ensure consistent isolated instances + CheckpointRepository: lambda: isolated_repositories[ + "checkpoint_repository" + ], TaskStateRepository: lambda: isolated_repositories["task_state_repository"], TaskMessageRepository: lambda: isolated_repositories[ 
"task_message_repository"