diff --git a/desktop/Backend-Rust/src/routes/proxy.rs b/desktop/Backend-Rust/src/routes/proxy.rs index 0c14a5ee8dc..d0eef5a2e63 100644 --- a/desktop/Backend-Rust/src/routes/proxy.rs +++ b/desktop/Backend-Rust/src/routes/proxy.rs @@ -40,6 +40,12 @@ const GEMINI_MAX_BODY_SIZE: usize = 5 * 1024 * 1024; /// App uses 8192 (GeminiClient.swift:553,922,1026). const MAX_OUTPUT_TOKENS_CAP: u64 = 8192; +/// Default thinking budget injected when client omits thinkingConfig. +/// Gemini 2.5 Flash thinking output costs $3.50/M vs $0.60/M regular (5.8x). +/// 1024 tokens caps reasoning for old clients; current Swift client sends budget=0 +/// explicitly on all production paths. +const DEFAULT_THINKING_BUDGET: u64 = 1024; + /// Proxy-specific error type — allows JSON 429 responses alongside bare status codes. enum ProxyError { Status(StatusCode), @@ -399,6 +405,7 @@ fn is_gemini_model_allowed(model: &str) -> bool { /// - Cap generation_config.max_output_tokens to MAX_OUTPUT_TOKENS_CAP /// - Reject candidate_count > 1 /// - Strip safety_settings and cached_content +/// - Inject default thinkingConfig if absent (cost control for Gemini 2.5) /// - Preserve all other fields (contents, system_instruction, tools, etc.) /// /// For embedContent/batchEmbedContents: @@ -502,8 +509,11 @@ fn sanitize_gemini_body(body: &[u8], action: &str) -> Result, String> { // Validate inside generation_config / generationConfig. // Check BOTH casings to prevent dual-key bypass where an attacker // sends an empty generation_config + a real generationConfig. + let mut found_generation_config = false; for gc_key in &["generation_config", "generationConfig"] { if let Some(gc) = obj.get_mut(*gc_key).and_then(|v| v.as_object_mut()) { + found_generation_config = true; + // Reject candidate_count > 1 for cc_key in &["candidate_count", "candidateCount"] { if let Some(v) = gc.get(*cc_key) { @@ -525,8 +535,29 @@ fn sanitize_gemini_body(body: &[u8], action: &str) -> Result, String> { } } } + + // Defense-in-depth: inject default thinking budget if client omits it. + // Gemini 2.5 Flash defaults to unlimited thinking which is 5.8x more + // expensive than regular output tokens. Cap at 1024 when absent. + let has_thinking = gc.contains_key("thinking_config") + || gc.contains_key("thinkingConfig"); + if !has_thinking { + gc.insert( + "thinkingConfig".to_string(), + serde_json::json!({"thinkingBudget": DEFAULT_THINKING_BUDGET}), + ); + } } } + + // If no generation_config exists at all (legacy clients), create one + // with the default thinking budget to prevent unlimited thinking spend. + if !found_generation_config { + obj.insert( + "generationConfig".to_string(), + serde_json::json!({"thinkingConfig": {"thinkingBudget": DEFAULT_THINKING_BUDGET}}), + ); + } } serde_json::to_vec(&json).map_err(|e| format!("failed to re-serialize: {}", e)) @@ -1548,4 +1579,151 @@ mod tests { // Vertex AI uses Bearer auth, not API key in URL — but query params are forwarded as-is assert!(url.starts_with("https://us-central1-aiplatform.googleapis.com")); } + + // --- Thinking budget injection --- + + #[test] + fn sanitize_injects_thinking_budget_when_absent() { + let body = serde_json::json!({ + "contents": [{"role": "user", "parts": [{"text": "hello"}]}], + "generation_config": {"max_output_tokens": 1000} + }); + let result = sanitize_gemini_body( + serde_json::to_vec(&body).unwrap().as_slice(), + "generateContent", + ).unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap(); + let gc = &parsed["generation_config"]; + assert_eq!( + gc["thinkingConfig"]["thinkingBudget"], + serde_json::json!(DEFAULT_THINKING_BUDGET), + "should inject default thinking budget when absent" + ); + } + + #[test] + fn sanitize_preserves_existing_thinking_config_snake() { + let body = serde_json::json!({ + "contents": [{"role": "user", "parts": [{"text": "hello"}]}], + "generation_config": { + "max_output_tokens": 1000, + "thinking_config": {"thinking_budget": 0} + } + }); + let result = sanitize_gemini_body( + serde_json::to_vec(&body).unwrap().as_slice(), + "generateContent", + ).unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap(); + let gc = &parsed["generation_config"]; + // Should not overwrite existing thinking_config + assert_eq!(gc["thinking_config"]["thinking_budget"], serde_json::json!(0)); + assert!(gc.get("thinkingConfig").is_none(), "should not inject when thinking_config present"); + } + + #[test] + fn sanitize_preserves_existing_thinking_config_camel() { + let body = serde_json::json!({ + "contents": [{"role": "user", "parts": [{"text": "hello"}]}], + "generationConfig": { + "maxOutputTokens": 1000, + "thinkingConfig": {"thinkingBudget": 4096} + } + }); + let result = sanitize_gemini_body( + serde_json::to_vec(&body).unwrap().as_slice(), + "generateContent", + ).unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap(); + let gc = &parsed["generationConfig"]; + assert_eq!(gc["thinkingConfig"]["thinkingBudget"], serde_json::json!(4096)); + } + + #[test] + fn sanitize_no_thinking_injection_for_embed() { + let body = serde_json::json!({ + "content": {"parts": [{"text": "hello"}]} + }); + let result = sanitize_gemini_body( + serde_json::to_vec(&body).unwrap().as_slice(), + "embedContent", + ).unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap(); + // Embed requests have no generation_config, so no injection + assert!(parsed.get("thinkingConfig").is_none()); + assert!(parsed.get("generation_config").is_none()); + } + + #[test] + fn sanitize_injects_generation_config_when_absent() { + // Legacy clients may send only contents with no generation_config at all. + // Proxy must create generationConfig with thinking budget to prevent + // unlimited thinking spend. + let body = serde_json::json!({ + "contents": [{"parts": [{"text": "hello"}]}] + }); + let result = sanitize_gemini_body( + serde_json::to_vec(&body).unwrap().as_slice(), + "generateContent", + ).unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap(); + let gc = parsed.get("generationConfig").expect("generationConfig should be created"); + let tc = gc.get("thinkingConfig").expect("thinkingConfig should be injected"); + assert_eq!(tc["thinkingBudget"], DEFAULT_THINKING_BUDGET); + } + + #[test] + fn sanitize_dual_generation_config_both_get_thinking() { + // Attacker sends both casings — both should get thinking budget injected. + let body = serde_json::json!({ + "contents": [{"parts": [{"text": "hello"}]}], + "generation_config": {"max_output_tokens": 100}, + "generationConfig": {"maxOutputTokens": 200} + }); + let result = sanitize_gemini_body( + serde_json::to_vec(&body).unwrap().as_slice(), + "generateContent", + ).unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap(); + // Both casings should have thinkingConfig injected + let gc_snake = parsed.get("generation_config").unwrap().as_object().unwrap(); + assert!(gc_snake.contains_key("thinkingConfig")); + let gc_camel = parsed.get("generationConfig").unwrap().as_object().unwrap(); + assert!(gc_camel.contains_key("thinkingConfig")); + } + + #[test] + fn sanitize_null_generation_config_gets_new_one() { + // Malformed: generation_config is null — proxy should create a fresh one. + let body = serde_json::json!({ + "contents": [{"parts": [{"text": "hello"}]}], + "generation_config": null + }); + let result = sanitize_gemini_body( + serde_json::to_vec(&body).unwrap().as_slice(), + "generateContent", + ).unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap(); + // null generation_config is not an object, so proxy creates generationConfig + let gc = parsed.get("generationConfig").expect("generationConfig should be created"); + let tc = gc.get("thinkingConfig").expect("thinkingConfig should be injected"); + assert_eq!(tc["thinkingBudget"], DEFAULT_THINKING_BUDGET); + } + + #[test] + fn sanitize_string_generation_config_gets_new_one() { + // Malformed: generation_config is a string — proxy should create a fresh one. + let body = serde_json::json!({ + "contents": [{"parts": [{"text": "hello"}]}], + "generation_config": "invalid" + }); + let result = sanitize_gemini_body( + serde_json::to_vec(&body).unwrap().as_slice(), + "generateContent", + ).unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap(); + let gc = parsed.get("generationConfig").expect("generationConfig should be created"); + let tc = gc.get("thinkingConfig").expect("thinkingConfig should be injected"); + assert_eq!(tc["thinkingBudget"], DEFAULT_THINKING_BUDGET); + } } diff --git a/desktop/CHANGELOG.json b/desktop/CHANGELOG.json index dc7db3d037d..7824b87d122 100644 --- a/desktop/CHANGELOG.json +++ b/desktop/CHANGELOG.json @@ -1,5 +1,7 @@ { - "unreleased": [], + "unreleased": [ + "Reduced AI processing costs with thinking budget controls" + ], "releases": [ { "version": "0.11.378", diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift index f7a9414f9c8..0497da584d7 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift @@ -590,7 +590,8 @@ actor InsightAssistant: ProactiveAssistant { contents: iterContents, systemPrompt: iterSystemPrompt, tools: iterTools, - forceToolCall: iterForce + forceToolCall: iterForce, + thinkingBudget: 1024 ) } } catch { @@ -739,7 +740,8 @@ actor InsightAssistant: ProactiveAssistant { contents: p2Contents, systemPrompt: p2SystemPrompt, tools: p2Tools, - forceToolCall: p2Force + forceToolCall: p2Force, + thinkingBudget: 1024 ) } } catch { diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift index 3ced48065e3..61bcd71d5be 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift @@ -778,7 +778,8 @@ actor TaskAssistant: ProactiveAssistant { contents: contents, systemPrompt: currentSystemPrompt, tools: [tools], - forceToolCall: iteration == 0 + forceToolCall: iteration == 0, + thinkingBudget: 1024 ) guard let toolCall = result.toolCalls.first else { diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift index e67d35a8963..f9d48a6c237 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift @@ -1,5 +1,24 @@ import Foundation +// MARK: - Thinking Budget Configuration + +/// Controls how many tokens Gemini 2.5 spends on internal reasoning. +/// Budget 0 disables thinking (cheapest). Budget -1 = dynamic (model decides). +/// Flash range: 0–24576. Pro range: 128–32768. +struct ThinkingConfig: Encodable { + let thinkingBudget: Int + + enum CodingKeys: String, CodingKey { + case thinkingBudget = "thinking_budget" + } + + /// Minimum thinking budget that disables or minimizes reasoning for a given model. + /// Flash supports 0 (fully off). Pro requires at least 128. + static func minimumBudget(for model: String) -> Int { + model.contains("pro") ? 128 : 0 + } +} + // MARK: - Gemini API Request/Response Types struct GeminiRequest: Encodable { @@ -56,12 +75,14 @@ struct GeminiRequest: Encodable { } struct GenerationConfig: Encodable { - let responseMimeType: String + let responseMimeType: String? let responseSchema: ResponseSchema? + let thinkingConfig: ThinkingConfig? enum CodingKeys: String, CodingKey { case responseMimeType = "response_mime_type" case responseSchema = "response_schema" + case thinkingConfig = "thinking_config" } struct ResponseSchema: Encodable { @@ -248,15 +269,6 @@ actor GeminiClient { URL(string: "\(Self.proxyBaseURL)v1/proxy/gemini/models/\(model):\(action)")! } - /// Build streaming proxy URL for a Gemini model action - private func streamProxyURL(action: String, queryParams: [String: String] = [:]) -> URL { - var urlString = "\(Self.proxyBaseURL)v1/proxy/gemini-stream/models/\(model):\(action)" - if !queryParams.isEmpty { - let params = queryParams.map { "\($0.key)=\($0.value)" }.joined(separator: "&") - urlString += "?\(params)" - } - return URL(string: urlString)! - } /// Log the raw API error message for debugging and throw a sanitized error. /// The `errorDescription` on GeminiClientError is user-friendly; this log preserves the raw detail. @@ -334,7 +346,8 @@ actor GeminiClient { prompt: String, imageData: Data, systemPrompt: String, - responseSchema: GeminiRequest.GenerationConfig.ResponseSchema + responseSchema: GeminiRequest.GenerationConfig.ResponseSchema, + thinkingBudget: Int = 0 ) async throws -> String { let maxRetries = 2 var lastError: Error? @@ -359,7 +372,8 @@ actor GeminiClient { ), generationConfig: GeminiRequest.GenerationConfig( responseMimeType: "application/json", - responseSchema: responseSchema + responseSchema: responseSchema, + thinkingConfig: ThinkingConfig(thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) ) ) @@ -417,7 +431,8 @@ actor GeminiClient { prompt: String, systemPrompt: String, maxRetries: Int = 2, - timeout: TimeInterval = 300 + timeout: TimeInterval = 300, + thinkingBudget: Int = 0 ) async throws -> String { var lastError: Error? @@ -432,7 +447,11 @@ actor GeminiClient { systemInstruction: GeminiRequest.SystemInstruction( parts: [GeminiRequest.SystemInstruction.TextPart(text: systemPrompt)] ), - generationConfig: nil + generationConfig: GeminiRequest.GenerationConfig( + responseMimeType: nil, + responseSchema: nil, + thinkingConfig: ThinkingConfig(thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) + ) ) let url = proxyURL(action: "generateContent") @@ -482,7 +501,8 @@ actor GeminiClient { func sendRequest( prompt: String, systemPrompt: String, - responseSchema: GeminiRequest.GenerationConfig.ResponseSchema + responseSchema: GeminiRequest.GenerationConfig.ResponseSchema, + thinkingBudget: Int = 0 ) async throws -> String { let maxRetries = 2 var lastError: Error? @@ -500,7 +520,8 @@ actor GeminiClient { ), generationConfig: GeminiRequest.GenerationConfig( responseMimeType: "application/json", - responseSchema: responseSchema + responseSchema: responseSchema, + thinkingConfig: ThinkingConfig(thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) ) ) @@ -541,153 +562,8 @@ actor GeminiClient { throw lastError! } - /// Send a multi-turn chat request with streaming response - /// - Parameters: - /// - messages: Array of chat messages (role: user/model, text) - /// - systemPrompt: System instructions for the model - /// - onChunk: Callback for each text chunk received - /// - Returns: The complete text response - func sendChatStreamRequest( - messages: [ChatMessage], - systemPrompt: String, - onChunk: @escaping (String) -> Void - ) async throws -> String { - // Build contents from chat messages - let contents = messages.map { message in - GeminiChatRequest.Content( - role: message.role, - parts: [GeminiChatRequest.Part(text: message.text)] - ) - } - - let request = GeminiChatRequest( - contents: contents, - systemInstruction: GeminiChatRequest.SystemInstruction( - parts: [GeminiChatRequest.SystemInstruction.TextPart(text: systemPrompt)] - ), - generationConfig: GeminiChatRequest.GenerationConfig( - temperature: 0.7, - maxOutputTokens: 8192 - ) - ) - - // Use streamGenerateContent endpoint for streaming (via backend proxy) - let url = streamProxyURL(action: "streamGenerateContent", queryParams: ["alt": "sse"]) - var urlRequest = URLRequest(url: url) - urlRequest.httpMethod = "POST" - urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") - urlRequest.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") - urlRequest.timeoutInterval = 300 - urlRequest.httpBody = try JSONEncoder().encode(request) - - var fullText = "" - - // Use URLSession bytes for streaming - let (bytes, response) = try await URLSession.shared.bytes(for: urlRequest) - - guard let httpResponse = response as? HTTPURLResponse else { - throw GeminiClientError.invalidResponse - } - - if httpResponse.statusCode != 200 { - throw GeminiClientError.apiError("HTTP \(httpResponse.statusCode)") - } - - // Parse SSE stream - for try await line in bytes.lines { - // SSE format: "data: {json}" - if line.hasPrefix("data: ") { - let jsonString = String(line.dropFirst(6)) - if let data = jsonString.data(using: .utf8) { - if let chunk = try? JSONDecoder().decode(GeminiStreamChunk.self, from: data) { - if let text = chunk.candidates?.first?.content?.parts?.first?.text { - fullText += text - onChunk(text) - } - } - } - } - } - - return fullText - } - - /// Chat message for multi-turn conversation - struct ChatMessage { - let role: String // "user" or "model" - let text: String - } } -// MARK: - Gemini Chat Request (multi-turn with roles) - -struct GeminiChatRequest: Encodable { - let contents: [Content] - let systemInstruction: SystemInstruction? - let generationConfig: GenerationConfig? - - enum CodingKeys: String, CodingKey { - case contents - case systemInstruction = "system_instruction" - case generationConfig = "generation_config" - } - - struct Content: Encodable { - let role: String // "user" or "model" - let parts: [Part] - } - - struct Part: Encodable { - let text: String - } - - struct SystemInstruction: Encodable { - let parts: [TextPart] - - struct TextPart: Encodable { - let text: String - } - } - - struct GenerationConfig: Encodable { - let temperature: Double? - let maxOutputTokens: Int? - - enum CodingKeys: String, CodingKey { - case temperature - case maxOutputTokens = "max_output_tokens" - } - } -} - -// MARK: - Gemini Stream Chunk Response - -struct GeminiStreamChunk: Decodable { - let candidates: [Candidate]? - - struct Candidate: Decodable { - let content: Content? - - struct Content: Decodable { - let parts: [Part]? - - struct Part: Decodable { - let text: String? - let functionCall: FunctionCall? - - enum CodingKeys: String, CodingKey { - case text - case functionCall = "functionCall" - } - } - } - } - - struct FunctionCall: Decodable { - let name: String - let args: [String: AnyCodable]? - } -} // MARK: - Tool Calling Support @@ -764,100 +640,12 @@ struct GeminiTool: Encodable { } } -/// Chat request with tools -struct GeminiToolChatRequest: Encodable { - let contents: [Content] - let systemInstruction: SystemInstruction? - let generationConfig: GenerationConfig? - let tools: [GeminiTool]? - - enum CodingKeys: String, CodingKey { - case contents - case systemInstruction = "system_instruction" - case generationConfig = "generation_config" - case tools - } - - struct Content: Encodable { - let role: String - let parts: [Part] - } - - struct Part: Encodable { - let text: String? - let functionCall: FunctionCallPart? - let functionResponse: FunctionResponsePart? - let thoughtSignature: String? - - enum CodingKeys: String, CodingKey { - case text - case functionCall = "functionCall" - case functionResponse = "functionResponse" - case thoughtSignature = "thought_signature" - } - - init(text: String) { - self.text = text - self.functionCall = nil - self.functionResponse = nil - self.thoughtSignature = nil - } - - init(functionResponse: FunctionResponsePart) { - self.text = nil - self.functionCall = nil - self.functionResponse = functionResponse - self.thoughtSignature = nil - } - - init(functionCall: FunctionCallPart, thoughtSignature: String? = nil) { - self.text = nil - self.functionCall = functionCall - self.functionResponse = nil - self.thoughtSignature = thoughtSignature - } - } - - struct FunctionCallPart: Encodable { - let name: String - let args: [String: String] - } - - struct FunctionResponsePart: Encodable { - let name: String - let response: ResponseContent - - struct ResponseContent: Encodable { - let result: String - } - } - - struct SystemInstruction: Encodable { - let parts: [TextPart] - - struct TextPart: Encodable { - let text: String - } - } - - struct GenerationConfig: Encodable { - let temperature: Double? - let maxOutputTokens: Int? - - enum CodingKeys: String, CodingKey { - case temperature - case maxOutputTokens = "max_output_tokens" - } - } -} /// Result of a tool-enabled chat (may include tool calls) struct ToolChatResult { let text: String let toolCalls: [ToolCall] let requiresToolExecution: Bool - /// Accumulated conversation contents for multi-turn tool loops - var contents: [GeminiToolChatRequest.Content]? } /// A function call from the model @@ -867,252 +655,32 @@ struct ToolCall { let thoughtSignature: String? } -// MARK: - GeminiClient Tool Extensions - -extension GeminiClient { - - /// Available chat tools - static let chatTools: [GeminiTool] = [ - GeminiTool(functionDeclarations: [ - // Execute SQL on local omi.db - GeminiTool.FunctionDeclaration( - name: "execute_sql", - description: - "Execute a SQL query on the local omi.db database. Supports SELECT, INSERT, UPDATE, DELETE. Use this for any structured data lookup — app usage, screenshots, tasks, conversations, time-based queries, aggregations, etc. The system prompt contains the full database schema.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init( - type: "string", - description: - "SQL query to execute. SELECT queries auto-limit to 200 rows. UPDATE/DELETE require WHERE clause. DROP/ALTER/CREATE are blocked." - ) - ], - required: ["query"] - ) - ), - // Semantic vector search - GeminiTool.FunctionDeclaration( - name: "semantic_search", - description: - "Search screen history using semantic similarity (vector embeddings). Use this for fuzzy conceptual queries where exact keywords won't work — e.g. 'reading about machine learning', 'working on design mockups', 'chatting with friends'. Returns screenshots ranked by semantic similarity.", - parameters: GeminiTool.FunctionDeclaration.Parameters( - type: "object", - properties: [ - "query": .init( - type: "string", description: "Natural language description of what to search for."), - "days": .init( - type: "integer", - description: "Search the last N days (default: 7). Use 1 for today only."), - "app_filter": .init( - type: "string", - description: "Optional: filter by app name (e.g., 'Google Chrome', 'Cursor', 'Slack')" - ), - ], - required: ["query"] - ) - ), - ]) - ] - - /// Send a chat request with tool support (non-streaming) - func sendToolChatRequest( - messages: [ChatMessage], - systemPrompt: String, - tools: [GeminiTool]? = nil - ) async throws -> ToolChatResult { - // Build contents from chat messages - let contents = messages.map { message in - GeminiToolChatRequest.Content( - role: message.role, - parts: [GeminiToolChatRequest.Part(text: message.text)] - ) - } - - let request = GeminiToolChatRequest( - contents: contents, - systemInstruction: GeminiToolChatRequest.SystemInstruction( - parts: [GeminiToolChatRequest.SystemInstruction.TextPart(text: systemPrompt)] - ), - generationConfig: GeminiToolChatRequest.GenerationConfig( - temperature: 0.7, - maxOutputTokens: 8192 - ), - tools: tools ?? Self.chatTools - ) - - let url = proxyURL(action: "generateContent") - var urlRequest = URLRequest(url: url) - urlRequest.httpMethod = "POST" - urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") - urlRequest.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") - urlRequest.timeoutInterval = 300 - urlRequest.httpBody = try JSONEncoder().encode(request) - - let (data, urlResponse) = try await URLSession.shared.data(for: urlRequest) - try checkHTTPStatus(urlResponse, data: data) - - // Parse response - let response = try JSONDecoder().decode(GeminiToolResponse.self, from: data) - - if let error = response.error { - try throwAPIError(error.message) - } - - guard let candidate = response.candidates?.first, - let parts = candidate.content?.parts - else { - try throwBlockedOrInvalidResponse( - blockReason: response.promptFeedback?.blockReason, - finishReason: response.candidates?.first?.finishReason - ) - } - - // Check for function calls - var toolCalls: [ToolCall] = [] - var textResponse = "" - - for part in parts { - if let functionCall = part.functionCall { - let args = functionCall.args?.mapValues { $0.value } ?? [:] - toolCalls.append( - ToolCall( - name: functionCall.name, arguments: args, thoughtSignature: part.thoughtSignature)) - } - if let text = part.text { - textResponse += text - } - } - - return ToolChatResult( - text: textResponse, - toolCalls: toolCalls, - requiresToolExecution: !toolCalls.isEmpty, - contents: contents - ) - } - - /// Continue a conversation after executing tools - /// Returns ToolChatResult so the caller can check if more tool calls are needed (multi-turn loop) - func continueWithToolResults( - previousContents: [GeminiToolChatRequest.Content], - toolCalls: [ToolCall], - toolResults: [String: String], - systemPrompt: String, - tools: [GeminiTool]? = nil - ) async throws -> ToolChatResult { - var contents = previousContents - - // Add the model's function call as a model turn - var functionCallParts: [GeminiToolChatRequest.Part] = [] - for call in toolCalls { - functionCallParts.append( - GeminiToolChatRequest.Part( - functionCall: GeminiToolChatRequest.FunctionCallPart( - name: call.name, - args: call.arguments.mapValues { "\($0)" } - ), - thoughtSignature: call.thoughtSignature - )) - } - contents.append(GeminiToolChatRequest.Content(role: "model", parts: functionCallParts)) - - // Add function responses - for call in toolCalls { - let result = toolResults[call.name] ?? "No result" - contents.append( - GeminiToolChatRequest.Content( - role: "function", - parts: [ - GeminiToolChatRequest.Part( - functionResponse: GeminiToolChatRequest.FunctionResponsePart( - name: call.name, - response: .init(result: result) - ) - ) - ] - )) - } - - let request = GeminiToolChatRequest( - contents: contents, - systemInstruction: GeminiToolChatRequest.SystemInstruction( - parts: [GeminiToolChatRequest.SystemInstruction.TextPart(text: systemPrompt)] - ), - generationConfig: GeminiToolChatRequest.GenerationConfig( - temperature: 0.7, - maxOutputTokens: 8192 - ), - tools: tools - ) - - let url = proxyURL(action: "generateContent") - var urlRequest = URLRequest(url: url) - urlRequest.httpMethod = "POST" - urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") - urlRequest.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") - urlRequest.timeoutInterval = 300 - urlRequest.httpBody = try JSONEncoder().encode(request) - - let (data, urlResponse) = try await URLSession.shared.data(for: urlRequest) - try checkHTTPStatus(urlResponse, data: data) - - let response = try JSONDecoder().decode(GeminiToolResponse.self, from: data) - - if let error = response.error { - try throwAPIError(error.message) - } - - guard let candidate = response.candidates?.first, - let parts = candidate.content?.parts - else { - try throwBlockedOrInvalidResponse( - blockReason: response.promptFeedback?.blockReason, - finishReason: response.candidates?.first?.finishReason - ) - } - - // Check for more function calls or text - var newToolCalls: [ToolCall] = [] - var textResponse = "" - - for part in parts { - if let functionCall = part.functionCall { - let args = functionCall.args?.mapValues { $0.value } ?? [:] - newToolCalls.append( - ToolCall( - name: functionCall.name, arguments: args, thoughtSignature: part.thoughtSignature)) - } - if let text = part.text { - textResponse += text - } - } - - return ToolChatResult( - text: textResponse, - toolCalls: newToolCalls, - requiresToolExecution: !newToolCalls.isEmpty, - contents: contents - ) - } -} - // MARK: - Image + Tool Calling Request /// Request type combining image analysis with tool calling struct GeminiImageToolRequest: Encodable { let contents: [Content] let systemInstruction: SystemInstruction? + let generationConfig: GenerationConfig? let tools: [GeminiTool]? let toolConfig: ToolConfig? enum CodingKeys: String, CodingKey { case contents case systemInstruction = "system_instruction" + case generationConfig = "generation_config" case tools case toolConfig = "tool_config" } + struct GenerationConfig: Encodable { + let thinkingConfig: ThinkingConfig? + + enum CodingKeys: String, CodingKey { + case thinkingConfig = "thinking_config" + } + } + struct Content: Encodable { let role: String let parts: [Part] @@ -1215,116 +783,17 @@ struct GeminiImageToolRequest: Encodable { extension GeminiClient { - /// Send image + text + tools, returns the model's function call - /// Retries up to 2 times for transient errors. - func sendImageToolRequest( - prompt: String, - imageData: Data, - systemPrompt: String, - tools: [GeminiTool], - forceToolCall: Bool = true - ) async throws -> ToolChatResult { - let maxRetries = 2 - var lastError: Error? - - for attempt in 0...maxRetries { - do { - // Wrap base64 encoding + JSON serialization in autoreleasepool. - // See sendRequest() comment for rationale. - let requestBody: Data = try autoreleasepool { - let base64Data = imageData.base64EncodedString() - - let toolConfig = - forceToolCall - ? GeminiImageToolRequest.ToolConfig( - functionCallingConfig: .init(mode: "ANY") - ) : nil - - let request = GeminiImageToolRequest( - contents: [ - GeminiImageToolRequest.Content( - role: "user", - parts: [ - GeminiImageToolRequest.Part(text: prompt), - GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64Data), - ] - ) - ], - systemInstruction: GeminiImageToolRequest.SystemInstruction( - parts: [.init(text: systemPrompt)] - ), - tools: tools, - toolConfig: toolConfig - ) - - return try JSONEncoder().encode(request) - } - - let url = proxyURL(action: "generateContent") - var urlRequest = URLRequest(url: url) - urlRequest.httpMethod = "POST" - urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") - urlRequest.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") - urlRequest.timeoutInterval = 300 - urlRequest.httpBody = requestBody - - let (data, urlResponse) = try await URLSession.shared.data(for: urlRequest) - try checkHTTPStatus(urlResponse, data: data) - - let response = try JSONDecoder().decode(GeminiToolResponse.self, from: data) - - if let error = response.error { - try throwAPIError(error.message) - } - - guard let candidate = response.candidates?.first, - let parts = candidate.content?.parts - else { - try throwBlockedOrInvalidResponse( - blockReason: response.promptFeedback?.blockReason, - finishReason: response.candidates?.first?.finishReason - ) - } - - var toolCalls: [ToolCall] = [] - var textResponse = "" - - for part in parts { - if let functionCall = part.functionCall { - let args = functionCall.args?.mapValues { $0.value } ?? [:] - toolCalls.append( - ToolCall( - name: functionCall.name, arguments: args, thoughtSignature: part.thoughtSignature)) - } - if let text = part.text { - textResponse += text - } - } - - return ToolChatResult( - text: textResponse, - toolCalls: toolCalls, - requiresToolExecution: !toolCalls.isEmpty - ) - } catch { - lastError = error - guard attempt < maxRetries && isTransientError(error) else { - throw error - } - await retryBackoff(attempt: attempt, error: error) - } - } - - throw lastError! - } - /// Send image + tool loop request: takes pre-built contents array for multi-turn tool calling. /// Retries up to 2 times for transient errors. + /// - Parameter thinkingBudget: Token budget for model reasoning. Tool-calling features that need + /// multi-step reasoning (e.g. InsightAssistant SQL generation, TaskAssistant screen analysis) + /// should pass a reasonable budget (e.g. 1024). Default 0 = minimal thinking. func sendImageToolLoop( contents: [GeminiImageToolRequest.Content], systemPrompt: String, tools: [GeminiTool], - forceToolCall: Bool = false + forceToolCall: Bool = false, + thinkingBudget: Int = 0 ) async throws -> ToolChatResult { let maxRetries = 2 var lastError: Error? @@ -1345,6 +814,9 @@ extension GeminiClient { systemInstruction: GeminiImageToolRequest.SystemInstruction( parts: [.init(text: systemPrompt)] ), + generationConfig: GeminiImageToolRequest.GenerationConfig( + thinkingConfig: ThinkingConfig(thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) + ), tools: tools, toolConfig: toolConfig ) @@ -1410,105 +882,6 @@ extension GeminiClient { throw lastError! } - /// Continue conversation after tool execution: sends full history + tool result, returns text - /// No tools on continuation — model returns plain JSON guided by system prompt. - func continueImageToolRequest( - originalPrompt: String, - originalImageData: Data, - toolCall: ToolCall, - toolResult: String, - systemPrompt: String - ) async throws -> String { - let maxRetries = 2 - var lastError: Error? - - for attempt in 0...maxRetries { - do { - // Wrap base64 encoding + JSON serialization in autoreleasepool. - // See sendRequest() comment for rationale. - let requestBody: Data = try autoreleasepool { - let base64Data = originalImageData.base64EncodedString() - - let contents: [GeminiImageToolRequest.Content] = [ - GeminiImageToolRequest.Content( - role: "user", - parts: [ - GeminiImageToolRequest.Part(text: originalPrompt), - GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64Data), - ] - ), - GeminiImageToolRequest.Content( - role: "model", - parts: [ - GeminiImageToolRequest.Part( - functionCall: .init( - name: toolCall.name, - args: toolCall.arguments.compactMapValues { "\($0)" } - ), - thoughtSignature: toolCall.thoughtSignature - ) - ] - ), - GeminiImageToolRequest.Content( - role: "user", - parts: [ - GeminiImageToolRequest.Part( - functionResponse: .init( - name: toolCall.name, - response: .init(result: toolResult) - )) - ] - ), - ] - - let request = GeminiImageToolRequest( - contents: contents, - systemInstruction: GeminiImageToolRequest.SystemInstruction( - parts: [.init(text: systemPrompt)] - ), - tools: nil, - toolConfig: nil - ) - - return try JSONEncoder().encode(request) - } - - let url = proxyURL(action: "generateContent") - var urlRequest = URLRequest(url: url) - urlRequest.httpMethod = "POST" - urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") - urlRequest.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") - urlRequest.timeoutInterval = 300 - urlRequest.httpBody = requestBody - - let (data, urlResponse) = try await URLSession.shared.data(for: urlRequest) - try checkHTTPStatus(urlResponse, data: data) - - let response = try JSONDecoder().decode(GeminiToolResponse.self, from: data) - - if let error = response.error { - try throwAPIError(error.message) - } - - guard let text = response.candidates?.first?.content?.parts?.first?.text else { - try throwBlockedOrInvalidResponse( - blockReason: response.promptFeedback?.blockReason, - finishReason: response.candidates?.first?.finishReason - ) - } - - return text - } catch { - lastError = error - guard attempt < maxRetries && isTransientError(error) else { - throw error - } - await retryBackoff(attempt: attempt, error: error) - } - } - - throw lastError! - } } /// Response type for tool-enabled requests diff --git a/desktop/Desktop/Tests/ThinkingBudgetTests.swift b/desktop/Desktop/Tests/ThinkingBudgetTests.swift new file mode 100644 index 00000000000..21b6707eda1 --- /dev/null +++ b/desktop/Desktop/Tests/ThinkingBudgetTests.swift @@ -0,0 +1,106 @@ +import XCTest +@testable import Omi_Computer + +final class ThinkingBudgetTests: XCTestCase { + + // MARK: - ThinkingConfig.minimumBudget(for:) + + func testFlashModelMinimumBudgetIsZero() { + XCTAssertEqual(ThinkingConfig.minimumBudget(for: "gemini-2.5-flash"), 0) + } + + func testFlashPreviewModelMinimumBudgetIsZero() { + XCTAssertEqual(ThinkingConfig.minimumBudget(for: "gemini-2.5-flash-preview-04-17"), 0) + } + + func testProModelMinimumBudgetIs128() { + XCTAssertEqual(ThinkingConfig.minimumBudget(for: "gemini-2.5-pro"), 128) + } + + func testProPreviewModelMinimumBudgetIs128() { + XCTAssertEqual(ThinkingConfig.minimumBudget(for: "gemini-2.5-pro-preview-05-06"), 128) + } + + func testUnknownModelDefaultsToZero() { + XCTAssertEqual(ThinkingConfig.minimumBudget(for: "gemini-2.0-flash"), 0) + } + + // MARK: - ThinkingConfig encoding + + func testThinkingConfigEncodesSnakeCase() throws { + let config = ThinkingConfig(thinkingBudget: 1024) + let data = try JSONEncoder().encode(config) + let json = try JSONSerialization.jsonObject(with: data) as! [String: Any] + XCTAssertEqual(json["thinking_budget"] as? Int, 1024) + XCTAssertNil(json["thinkingBudget"], "Should use snake_case key, not camelCase") + } + + func testThinkingConfigEncodesZeroBudget() throws { + let config = ThinkingConfig(thinkingBudget: 0) + let data = try JSONEncoder().encode(config) + let json = try JSONSerialization.jsonObject(with: data) as! [String: Any] + XCTAssertEqual(json["thinking_budget"] as? Int, 0) + } + + // MARK: - Budget floor enforcement via max() + + func testFlashBudgetZeroPassesThroughAsZero() { + let budget = max(0, ThinkingConfig.minimumBudget(for: "gemini-2.5-flash")) + XCTAssertEqual(budget, 0) + } + + func testProBudgetZeroFloorsTo128() { + let budget = max(0, ThinkingConfig.minimumBudget(for: "gemini-2.5-pro")) + XCTAssertEqual(budget, 128) + } + + func testProBudget1024StaysAt1024() { + let budget = max(1024, ThinkingConfig.minimumBudget(for: "gemini-2.5-pro")) + XCTAssertEqual(budget, 1024) + } + + func testFlashBudget1024StaysAt1024() { + let budget = max(1024, ThinkingConfig.minimumBudget(for: "gemini-2.5-flash")) + XCTAssertEqual(budget, 1024) + } + + // MARK: - GeminiRequest includes thinkingConfig in generationConfig + + func testGeminiRequestEncodesThinkingConfig() throws { + let request = GeminiRequest( + contents: [GeminiRequest.Content(parts: [GeminiRequest.Part(text: "test")])], + systemInstruction: nil, + generationConfig: GeminiRequest.GenerationConfig( + responseMimeType: "application/json", + responseSchema: nil, + thinkingConfig: ThinkingConfig(thinkingBudget: 0) + ) + ) + let data = try JSONEncoder().encode(request) + let json = try JSONSerialization.jsonObject(with: data) as! [String: Any] + let genConfig = json["generation_config"] as! [String: Any] + let thinkingConfig = genConfig["thinking_config"] as! [String: Any] + XCTAssertEqual(thinkingConfig["thinking_budget"] as? Int, 0) + } + + // MARK: - GeminiImageToolRequest includes thinkingConfig + + func testImageToolRequestEncodesThinkingBudget() throws { + let request = GeminiImageToolRequest( + contents: [ + GeminiImageToolRequest.Content(role: "user", parts: [.init(text: "test")]) + ], + systemInstruction: GeminiImageToolRequest.SystemInstruction(parts: [.init(text: "sys")]), + generationConfig: GeminiImageToolRequest.GenerationConfig( + thinkingConfig: ThinkingConfig(thinkingBudget: 1024) + ), + tools: [], + toolConfig: nil + ) + let data = try JSONEncoder().encode(request) + let json = try JSONSerialization.jsonObject(with: data) as! [String: Any] + let genConfig = json["generation_config"] as! [String: Any] + let thinkingConfig = genConfig["thinking_config"] as! [String: Any] + XCTAssertEqual(thinkingConfig["thinking_budget"] as? Int, 1024) + } +} diff --git a/desktop/run.sh b/desktop/run.sh index a2ef0b621f3..f4f3d58d3a7 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -399,6 +399,22 @@ if [ -d "$SPARKLE_FRAMEWORK" ]; then cp -R "$SPARKLE_FRAMEWORK" "$APP_BUNDLE/Contents/Frameworks/" fi +# Copy Sentry framework +SENTRY_FRAMEWORK="Desktop/.build/arm64-apple-macosx/debug/Sentry.framework" +if [ -d "$SENTRY_FRAMEWORK" ]; then + substep "Copying Sentry framework" + rm -rf "$APP_BUNDLE/Contents/Frameworks/Sentry.framework" + cp -R "$SENTRY_FRAMEWORK" "$APP_BUNDLE/Contents/Frameworks/" +fi + +# Copy onnxruntime framework +ONNX_FRAMEWORK="Desktop/.build/arm64-apple-macosx/debug/onnxruntime.framework" +if [ -d "$ONNX_FRAMEWORK" ]; then + substep "Copying onnxruntime framework" + rm -rf "$APP_BUNDLE/Contents/Frameworks/onnxruntime.framework" + cp -R "$ONNX_FRAMEWORK" "$APP_BUNDLE/Contents/Frameworks/" +fi + # Copy libwebp dylibs and rewrite load paths WEBP_LIB="$(pkg-config --variable=libdir libwebp 2>/dev/null)/libwebp.7.dylib" if [ -f "$WEBP_LIB" ]; then @@ -553,6 +569,14 @@ if [ -n "$SIGN_IDENTITY" ]; then substep "Signing Sparkle framework" codesign --force --options runtime --sign "$SIGN_IDENTITY" "$APP_BUNDLE/Contents/Frameworks/Sparkle.framework" fi + if [ -d "$APP_BUNDLE/Contents/Frameworks/Sentry.framework" ]; then + substep "Signing Sentry framework" + codesign --force --options runtime --sign "$SIGN_IDENTITY" "$APP_BUNDLE/Contents/Frameworks/Sentry.framework" + fi + if [ -d "$APP_BUNDLE/Contents/Frameworks/onnxruntime.framework" ]; then + substep "Signing onnxruntime framework" + codesign --force --options runtime --sign "$SIGN_IDENTITY" "$APP_BUNDLE/Contents/Frameworks/onnxruntime.framework" + fi if [ -f "$APP_BUNDLE/Contents/Frameworks/libsharpyuv.0.dylib" ]; then substep "Signing libsharpyuv" codesign --force --options runtime --sign "$SIGN_IDENTITY" "$APP_BUNDLE/Contents/Frameworks/libsharpyuv.0.dylib"