Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions e2e/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func TestRuntime_OpenAI_Basic(t *testing.T) {
require.NoError(t, err)

response := sess.GetLastAssistantMessageContent()
assert.Equal(t, "2 + 2 is equal to 4.", response)
assert.Equal(t, "Simple Math Calculation", sess.Title)
assert.Equal(t, "2 + 2 equals 4.", response)
assert.Equal(t, "Basic Math Question", sess.Title)
}

func TestRuntime_Mistral_Basic(t *testing.T) {
Expand All @@ -55,5 +55,5 @@ func TestRuntime_Mistral_Basic(t *testing.T) {

response := sess.GetLastAssistantMessageContent()
assert.Equal(t, "The sum of 2 + 2 is 4.", response)
assert.Equal(t, "Basic Arithmetic: Sum of 2 and 2", sess.Title)
assert.Equal(t, "Math Basics: Simple Addition", sess.Title)
}
168 changes: 88 additions & 80 deletions e2e/testdata/cassettes/TestRuntime_Mistral_Basic.yaml

Large diffs are not rendered by default.

118 changes: 78 additions & 40 deletions e2e/testdata/cassettes/TestRuntime_OpenAI_Basic.yaml

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions pkg/runtime/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,11 @@ type SessionTitleEvent struct {
AgentContext
}

func SessionTitle(sessionID, title, agentName string) Event {
func SessionTitle(sessionID, title string) Event {
return &SessionTitleEvent{
Type: "session_title",
SessionID: sessionID,
Title: title,
AgentContext: AgentContext{AgentName: agentName},
Type: "session_title",
SessionID: sessionID,
Title: title,
}
}

Expand Down
83 changes: 10 additions & 73 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ type LocalRuntime struct {
elicitationEventsChannel chan Event // Current events channel for sending elicitation requests
elicitationEventsChannelMux sync.RWMutex // Protects elicitationEventsChannel
ragInitialized atomic.Bool
titleGenerationWg sync.WaitGroup // Wait group for title generation
titleGen *titleGenerator
}

type streamResult struct {
Expand Down Expand Up @@ -210,6 +210,13 @@ func New(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
return nil, err
}

model := agents.Model()
if model == nil {
return nil, errors.New("no model found for the team; ensure at least one agent has a valid model")
}

r.titleGen = newTitleGenerator(model)

slog.Debug("Creating new runtime", "agent", r.currentAgent, "available_agents", agents.Size())

return r, nil
Expand Down Expand Up @@ -488,8 +495,7 @@ func (r *LocalRuntime) finalizeEventChannel(ctx context.Context, sess *session.S

telemetry.RecordSessionEnd(ctx)

// Wait for title generation if it's in progress
r.titleGenerationWg.Wait()
r.titleGen.Wait()
}

// RunStream starts the agent's interaction loop and returns a channel of events
Expand Down Expand Up @@ -543,7 +549,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
return
}

// Emit toolset information
events <- ToolsetInfo(len(agentTools), r.currentAgent)

messages := sess.GetMessages(a)
Expand All @@ -558,9 +563,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
r.registerDefaultTools()

if sess.Title == "" {
r.titleGenerationWg.Go(func() {
r.generateSessionTitle(ctx, sess, events)
})
r.titleGen.Generate(ctx, sess, events)
}

iteration := 0
Expand Down Expand Up @@ -1353,72 +1356,6 @@ func (r *LocalRuntime) handleHandoff(_ context.Context, _ *session.Session, tool
}, nil
}

// truncateTitle truncates a title to maxLength characters, adding an ellipsis if needed
func truncateTitle(title string, maxLength int) string {
if len(title) <= maxLength {
return title
}
// Ensure we have room for the ellipsis
if maxLength < 3 {
return "..."
}
return title[:maxLength-3] + "..."
}

// generateSessionTitle generates a title for the session based on the first user message
func (r *LocalRuntime) generateSessionTitle(ctx context.Context, sess *session.Session, events chan Event) {
slog.Debug("Generating title for session", "session_id", sess.ID)

firstUserMessage := sess.GetLastUserMessageContent()
if firstUserMessage == "" {
slog.Error("Failed generating session title: no user message found in session", "session_id", sess.ID)
events <- SessionTitle(sess.ID, "Untitled", r.currentAgent)
return
}

systemPrompt := "You are a helpful AI assistant that generates concise, descriptive titles for conversations. You will be given a conversation history and asked to create a title that captures the main topic."
userPrompt := fmt.Sprintf("Based on the following message a user sent to an AI assistant, generate a short, descriptive title (maximum 50 characters) that captures the main topic or purpose of the conversation. Return ONLY the title text, nothing else.\n\nUser message: %s\n\n", firstUserMessage)

titleModel := provider.CloneWithOptions(
ctx,
r.CurrentAgent().Model(),
options.WithStructuredOutput(nil),
options.WithMaxTokens(100),
options.WithGeneratingTitle(),
)
newTeam := team.New(
team.WithAgents(agent.New("root", systemPrompt, agent.WithModel(titleModel))),
)
titleSession := session.New(
session.WithUserMessage(userPrompt),
session.WithTitle("Generating title..."),
)

titleRuntime, err := New(newTeam, WithSessionCompaction(false))
if err != nil {
slog.Error("Failed to create title generator runtime", "error", err)
return
}

// Run the title generation (this will be a simple back-and-forth)
_, err = titleRuntime.Run(ctx, titleSession)
if err != nil {
slog.Error("Failed to generate session title", "session_id", sess.ID, "error", err)
return
}

// Get the generated title from the last assistant message
title := titleSession.GetLastAssistantMessageContent()
if title == "" {
return
}
// Truncate title to 50 characters with ellipsis if needed
title = truncateTitle(title, 50)
sess.Title = title
slog.Debug("Generated session title", "session_id", sess.ID, "title", title)
events <- SessionTitle(sess.ID, title, r.currentAgent)
}

// Summarize generates a summary for the session based on the conversation history
func (r *LocalRuntime) Summarize(ctx context.Context, sess *session.Session, events chan Event) {
slog.Debug("Generating summary for session", "session_id", sess.ID)
Expand Down
67 changes: 2 additions & 65 deletions pkg/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ func TestGetTools_WarningHandling(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
root := agent.New("root", "test", agent.WithToolSets(tt.toolsets...))
root := agent.New("root", "test", agent.WithToolSets(tt.toolsets...), agent.WithModel(&mockProvider{}))
tm := team.New(team.WithAgents(root))
rt, err := New(tm, WithModelStore(mockModelStore{}))
require.NoError(t, err)
Expand Down Expand Up @@ -769,7 +769,7 @@ func TestSummarize_EmptySession(t *testing.T) {

func TestProcessToolCalls_UnknownTool_NoToolResultMessage(t *testing.T) {
// Build a runtime with a simple agent but no tools registered matching the call
root := agent.New("root", "You are a test agent")
root := agent.New("root", "You are a test agent", agent.WithModel(&mockProvider{}))
tm := team.New(team.WithAgents(root))

rt, err := New(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}))
Expand Down Expand Up @@ -856,66 +856,3 @@ func TestEmitStartupInfo(t *testing.T) {
// Should be empty due to deduplication
require.Empty(t, collectedEvents2, "EmitStartupInfo should not emit duplicate events")
}

func TestTruncateTitle(t *testing.T) {
tests := []struct {
name string
title string
maxLength int
expected string
}{
{
name: "title shorter than max length",
title: "Short title",
maxLength: 50,
expected: "Short title",
},
{
name: "title exactly at max length",
title: "This is exactly fifty characters in length now.",
maxLength: 50,
expected: "This is exactly fifty characters in length now.",
},
{
name: "title longer than max length",
title: "This is a very long title that exceeds the maximum character limit",
maxLength: 50,
expected: "This is a very long title that exceeds the maxi...",
},
{
name: "very short max length",
title: "Any title",
maxLength: 5,
expected: "An...",
},
{
name: "max length less than 3",
title: "Any title",
maxLength: 2,
expected: "...",
},
{
name: "empty title",
title: "",
maxLength: 50,
expected: "",
},
{
name: "title with unicode characters",
title: "こんにちは、これは日本語のタイトルです。とても長いタイトルなので切り捨てられるはずです。",
maxLength: 50,
expected: "こんにちは、これは日本語のタイトルです。とても長いタイトルなので切り捨てられるはずです。"[:47] + "...",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncateTitle(tt.title, tt.maxLength)
require.Equal(t, tt.expected, result)
// Only check length constraint if maxLength >= 3 (otherwise ellipsis alone is 3 chars)
if tt.maxLength >= 3 {
require.LessOrEqual(t, len(result), tt.maxLength)
}
})
}
}
89 changes: 89 additions & 0 deletions pkg/runtime/title_generator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package runtime

import (
"context"
"fmt"
"log/slog"
"sync"

"github.com/docker/cagent/pkg/agent"
"github.com/docker/cagent/pkg/model/provider"
"github.com/docker/cagent/pkg/model/provider/options"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/team"
)

const (
titleSystemPrompt = "You are a helpful AI assistant that generates concise, descriptive titles for conversations. You will be given a conversation history and asked to create a title that captures the main topic."
titleUserPromptFormat = "Based on the following message a user sent to an AI assistant, generate a short, descriptive title (maximum 50 characters) that captures the main topic or purpose of the conversation. Return ONLY the title text, nothing else.\n\nUser message: %s\n\n"
)

type titleGenerator struct {
wg sync.WaitGroup
model provider.Provider
}

func newTitleGenerator(model provider.Provider) *titleGenerator {
return &titleGenerator{
model: model,
}
}

func (t *titleGenerator) Generate(ctx context.Context, sess *session.Session, events chan<- Event) {
t.wg.Go(func() {
t.generate(ctx, sess, events)
})
}

func (t *titleGenerator) Wait() {
t.wg.Wait()
}

func (t *titleGenerator) generate(ctx context.Context, sess *session.Session, events chan<- Event) {
slog.Debug("Generating title for session", "session_id", sess.ID)

firstUserMessage := sess.GetLastUserMessageContent()
if firstUserMessage == "" {
return
}

userPrompt := fmt.Sprintf(titleUserPromptFormat, firstUserMessage)

titleModel := provider.CloneWithOptions(
ctx,
t.model,
options.WithStructuredOutput(nil),
options.WithMaxTokens(20),
options.WithGeneratingTitle(),
)

newTeam := team.New(
team.WithAgents(agent.New("root", titleSystemPrompt, agent.WithModel(titleModel))),
)

titleSession := session.New(
session.WithUserMessage(userPrompt),
session.WithTitle("Generating title..."),
)

titleRuntime, err := New(newTeam, WithSessionCompaction(false))
if err != nil {
slog.Error("Failed to create title generator runtime", "error", err)
return
}

_, err = titleRuntime.Run(ctx, titleSession)
if err != nil {
slog.Error("Failed to generate session title", "session_id", sess.ID, "error", err)
return
}

title := titleSession.GetLastAssistantMessageContent()
if title == "" {
return
}

sess.Title = title
slog.Debug("Generated session title", "session_id", sess.ID, "title", title)
events <- SessionTitle(sess.ID, title)
}
16 changes: 16 additions & 0 deletions pkg/team/team.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"

"github.com/docker/cagent/pkg/agent"
"github.com/docker/cagent/pkg/model/provider"
"github.com/docker/cagent/pkg/rag"
)

Expand Down Expand Up @@ -66,6 +67,21 @@ func (t *Team) Agent(name string) (*agent.Agent, error) {
return found, nil
}

func (t *Team) Model() provider.Provider {
root, err := t.Agent("root")
if err == nil {
return root.Model()
}

for _, agentName := range t.AgentNames() {
a, err := t.Agent(agentName)
if err == nil {
return a.Model()
}
}
return nil
}

func (t *Team) Size() int {
return len(t.agents)
}
Expand Down