Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions desktop/Backend-Rust/src/routes/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ const GEMINI_MAX_BODY_SIZE: usize = 5 * 1024 * 1024;
/// App uses 8192 (GeminiClient.swift:553,922,1026).
const MAX_OUTPUT_TOKENS_CAP: u64 = 8192;

/// Default thinking budget injected when client omits thinkingConfig.
/// Gemini 2.5 Flash thinking output costs $3.50/M vs $0.60/M regular (5.8x).
/// 1024 tokens caps reasoning for old clients; current Swift client sends budget=0
/// explicitly on all production paths.
const DEFAULT_THINKING_BUDGET: u64 = 1024;

/// Proxy-specific error type — allows JSON 429 responses alongside bare status codes.
enum ProxyError {
Status(StatusCode),
Expand Down Expand Up @@ -399,6 +405,7 @@ fn is_gemini_model_allowed(model: &str) -> bool {
/// - Cap generation_config.max_output_tokens to MAX_OUTPUT_TOKENS_CAP
/// - Reject candidate_count > 1
/// - Strip safety_settings and cached_content
/// - Inject default thinkingConfig if absent (cost control for Gemini 2.5)
/// - Preserve all other fields (contents, system_instruction, tools, etc.)
///
/// For embedContent/batchEmbedContents:
Expand Down Expand Up @@ -502,8 +509,11 @@ fn sanitize_gemini_body(body: &[u8], action: &str) -> Result<Vec<u8>, String> {
// Validate inside generation_config / generationConfig.
// Check BOTH casings to prevent dual-key bypass where an attacker
// sends an empty generation_config + a real generationConfig.
let mut found_generation_config = false;
for gc_key in &["generation_config", "generationConfig"] {
if let Some(gc) = obj.get_mut(*gc_key).and_then(|v| v.as_object_mut()) {
found_generation_config = true;

// Reject candidate_count > 1
for cc_key in &["candidate_count", "candidateCount"] {
if let Some(v) = gc.get(*cc_key) {
Expand All @@ -525,8 +535,29 @@ fn sanitize_gemini_body(body: &[u8], action: &str) -> Result<Vec<u8>, String> {
}
}
}

// Defense-in-depth: inject default thinking budget if client omits it.
// Gemini 2.5 Flash defaults to unlimited thinking which is 5.8x more
// expensive than regular output tokens. Cap at 1024 when absent.
let has_thinking = gc.contains_key("thinking_config")
|| gc.contains_key("thinkingConfig");
if !has_thinking {
gc.insert(
"thinkingConfig".to_string(),
serde_json::json!({"thinkingBudget": DEFAULT_THINKING_BUDGET}),
);
}
Comment on lines +539 to +549
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Defense-in-depth bypass when generation_config is absent

The injection only fires when the request already contains a generation_config/generationConfig object. A request that omits the key entirely (valid Gemini API behavior — model uses defaults) skips this block, leaving thinking unlimited. The PR comment says "inject default budget=1024 when client omits thinkingConfig" but the actual contract is narrower: the budget is injected only when a generation_config exists without a thinkingConfig. Any future client call that forgets to set generationConfig bypasses the proxy's cost cap entirely, defeating the stated defense-in-depth goal.

The fix is to add a fallback after the loop: if neither generation_config nor generationConfig exists in the object, insert a new generation_config containing only the default thinkingConfig.

}
}

// If no generation_config exists at all (legacy clients), create one
// with the default thinking budget to prevent unlimited thinking spend.
if !found_generation_config {
obj.insert(
"generationConfig".to_string(),
serde_json::json!({"thinkingConfig": {"thinkingBudget": DEFAULT_THINKING_BUDGET}}),
);
}
}

serde_json::to_vec(&json).map_err(|e| format!("failed to re-serialize: {}", e))
Expand Down Expand Up @@ -1548,4 +1579,151 @@ mod tests {
// Vertex AI uses Bearer auth, not API key in URL — but query params are forwarded as-is
assert!(url.starts_with("https://us-central1-aiplatform.googleapis.com"));
}

// --- Thinking budget injection ---

#[test]
fn sanitize_injects_thinking_budget_when_absent() {
let body = serde_json::json!({
"contents": [{"role": "user", "parts": [{"text": "hello"}]}],
"generation_config": {"max_output_tokens": 1000}
});
let result = sanitize_gemini_body(
serde_json::to_vec(&body).unwrap().as_slice(),
"generateContent",
).unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
let gc = &parsed["generation_config"];
assert_eq!(
gc["thinkingConfig"]["thinkingBudget"],
serde_json::json!(DEFAULT_THINKING_BUDGET),
"should inject default thinking budget when absent"
);
}

#[test]
fn sanitize_preserves_existing_thinking_config_snake() {
let body = serde_json::json!({
"contents": [{"role": "user", "parts": [{"text": "hello"}]}],
"generation_config": {
"max_output_tokens": 1000,
"thinking_config": {"thinking_budget": 0}
}
});
let result = sanitize_gemini_body(
serde_json::to_vec(&body).unwrap().as_slice(),
"generateContent",
).unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
let gc = &parsed["generation_config"];
// Should not overwrite existing thinking_config
assert_eq!(gc["thinking_config"]["thinking_budget"], serde_json::json!(0));
assert!(gc.get("thinkingConfig").is_none(), "should not inject when thinking_config present");
}

#[test]
fn sanitize_preserves_existing_thinking_config_camel() {
let body = serde_json::json!({
"contents": [{"role": "user", "parts": [{"text": "hello"}]}],
"generationConfig": {
"maxOutputTokens": 1000,
"thinkingConfig": {"thinkingBudget": 4096}
}
});
let result = sanitize_gemini_body(
serde_json::to_vec(&body).unwrap().as_slice(),
"generateContent",
).unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
let gc = &parsed["generationConfig"];
assert_eq!(gc["thinkingConfig"]["thinkingBudget"], serde_json::json!(4096));
}

#[test]
fn sanitize_no_thinking_injection_for_embed() {
let body = serde_json::json!({
"content": {"parts": [{"text": "hello"}]}
});
let result = sanitize_gemini_body(
serde_json::to_vec(&body).unwrap().as_slice(),
"embedContent",
).unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
// Embed requests have no generation_config, so no injection
assert!(parsed.get("thinkingConfig").is_none());
assert!(parsed.get("generation_config").is_none());
}

#[test]
fn sanitize_injects_generation_config_when_absent() {
// Legacy clients may send only contents with no generation_config at all.
// Proxy must create generationConfig with thinking budget to prevent
// unlimited thinking spend.
let body = serde_json::json!({
"contents": [{"parts": [{"text": "hello"}]}]
});
let result = sanitize_gemini_body(
serde_json::to_vec(&body).unwrap().as_slice(),
"generateContent",
).unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
let gc = parsed.get("generationConfig").expect("generationConfig should be created");
let tc = gc.get("thinkingConfig").expect("thinkingConfig should be injected");
assert_eq!(tc["thinkingBudget"], DEFAULT_THINKING_BUDGET);
}

#[test]
fn sanitize_dual_generation_config_both_get_thinking() {
// Attacker sends both casings — both should get thinking budget injected.
let body = serde_json::json!({
"contents": [{"parts": [{"text": "hello"}]}],
"generation_config": {"max_output_tokens": 100},
"generationConfig": {"maxOutputTokens": 200}
});
let result = sanitize_gemini_body(
serde_json::to_vec(&body).unwrap().as_slice(),
"generateContent",
).unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
// Both casings should have thinkingConfig injected
let gc_snake = parsed.get("generation_config").unwrap().as_object().unwrap();
assert!(gc_snake.contains_key("thinkingConfig"));
let gc_camel = parsed.get("generationConfig").unwrap().as_object().unwrap();
assert!(gc_camel.contains_key("thinkingConfig"));
}

#[test]
fn sanitize_null_generation_config_gets_new_one() {
// Malformed: generation_config is null — proxy should create a fresh one.
let body = serde_json::json!({
"contents": [{"parts": [{"text": "hello"}]}],
"generation_config": null
});
let result = sanitize_gemini_body(
serde_json::to_vec(&body).unwrap().as_slice(),
"generateContent",
).unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
// null generation_config is not an object, so proxy creates generationConfig
let gc = parsed.get("generationConfig").expect("generationConfig should be created");
let tc = gc.get("thinkingConfig").expect("thinkingConfig should be injected");
assert_eq!(tc["thinkingBudget"], DEFAULT_THINKING_BUDGET);
}

#[test]
fn sanitize_string_generation_config_gets_new_one() {
// Malformed: generation_config is a string — proxy should create a fresh one.
let body = serde_json::json!({
"contents": [{"parts": [{"text": "hello"}]}],
"generation_config": "invalid"
});
let result = sanitize_gemini_body(
serde_json::to_vec(&body).unwrap().as_slice(),
"generateContent",
).unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
let gc = parsed.get("generationConfig").expect("generationConfig should be created");
let tc = gc.get("thinkingConfig").expect("thinkingConfig should be injected");
assert_eq!(tc["thinkingBudget"], DEFAULT_THINKING_BUDGET);
}
}
4 changes: 3 additions & 1 deletion desktop/CHANGELOG.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{
"unreleased": [],
"unreleased": [
"Reduced AI processing costs with thinking budget controls"
],
"releases": [
{
"version": "0.11.378",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,8 @@ actor InsightAssistant: ProactiveAssistant {
contents: iterContents,
systemPrompt: iterSystemPrompt,
tools: iterTools,
forceToolCall: iterForce
forceToolCall: iterForce,
thinkingBudget: 1024
)
}
} catch {
Expand Down Expand Up @@ -739,7 +740,8 @@ actor InsightAssistant: ProactiveAssistant {
contents: p2Contents,
systemPrompt: p2SystemPrompt,
tools: p2Tools,
forceToolCall: p2Force
forceToolCall: p2Force,
thinkingBudget: 1024
)
}
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,8 @@ actor TaskAssistant: ProactiveAssistant {
contents: contents,
systemPrompt: currentSystemPrompt,
tools: [tools],
forceToolCall: iteration == 0
forceToolCall: iteration == 0,
thinkingBudget: 1024
)

guard let toolCall = result.toolCalls.first else {
Expand Down
Loading
Loading