Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/arxiv_explorer/cli/daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ def show(
arxiv_id, paper.title, paper.abstract, detailed=detailed, force=force
)

if (summary or detailed) and paper_summary is None:
import sys

print("Failed to generate summary (check provider settings)", file=sys.stderr)
raise typer.Exit(1)

paper_translation = None
if translate:
translator = TranslationService()
Expand All @@ -252,6 +258,12 @@ def show(
arxiv_id, paper.title, paper.abstract, force=force
)

if translate and paper_translation is None:
import sys

print("Failed to generate translation (check provider settings)", file=sys.stderr)
raise typer.Exit(1)

print_paper_detail(paper, paper_summary, paper_translation)


Expand Down
9 changes: 9 additions & 0 deletions src/arxiv_explorer/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Custom AI providers
CREATE TABLE IF NOT EXISTS custom_providers (
name TEXT PRIMARY KEY NOT NULL,
preset TEXT NOT NULL,
command_template TEXT NOT NULL,
default_model TEXT DEFAULT '',
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Paper review sections (incremental cache)
CREATE TABLE IF NOT EXISTS paper_review_sections (
id INTEGER PRIMARY KEY AUTOINCREMENT,
Expand Down
8 changes: 8 additions & 0 deletions src/arxiv_explorer/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ class Language(str, Enum):
KO = "ko"


@dataclass
class CustomProviderConfig:
name: str
preset: str
command_template: str
default_model: str = ""


class JobType(Enum):
SUMMARIZE = "summarize"
TRANSLATE = "translate"
Expand Down
5 changes: 3 additions & 2 deletions src/arxiv_explorer/services/arxiv_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ def _build_query(query: str) -> str:
import re

# Already formatted: contains field prefix or boolean operator
if re.search(r'\b(all|ti|au|abs|cat|co|jr|rn|id):', query) or \
re.search(r'\b(AND|OR|ANDNOT)\b', query):
if re.search(r"\b(all|ti|au|abs|cat|co|jr|rn|id):", query) or re.search(
r"\b(AND|OR|ANDNOT)\b", query
):
return query

words = query.split()
Expand Down
35 changes: 26 additions & 9 deletions src/arxiv_explorer/services/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,29 @@ def build_command(self, prompt: str, model: str = "") -> list[str]:
}


def get_provider(provider_type: AIProviderType) -> AIProvider:
"""Return a provider instance. If custom, load the template from settings."""
provider = PROVIDERS[provider_type]
if provider_type == AIProviderType.CUSTOM:
from .settings_service import SettingsService

template = SettingsService().get("custom_command")
provider.configure(template)
return provider
def get_provider(provider_name: str | AIProviderType) -> AIProvider:
"""Return a provider instance. Checks built-in registry first, then custom_providers table."""
# Normalize to string
name = provider_name.value if isinstance(provider_name, AIProviderType) else provider_name

# Try built-in registry
for ptype, prov in PROVIDERS.items():
if ptype.value == name:
if ptype == AIProviderType.CUSTOM:
from .settings_service import SettingsService

template = SettingsService().get("custom_command")
prov.configure(template)
return prov

# Try custom_providers table
from .settings_service import SettingsService

for cp in SettingsService().get_custom_providers():
if cp.name == name:
provider = CustomProvider()
provider.configure(cp.command_template)
return provider

# Fallback to gemini
return PROVIDERS[AIProviderType.GEMINI]
50 changes: 47 additions & 3 deletions src/arxiv_explorer/services/settings_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def get_all(self) -> dict[str, str]:
settings[row["key"]] = row["value"]
return settings

def get_provider(self) -> AIProviderType:
"""Get the current AI provider."""
return AIProviderType(self.get("ai_provider"))
def get_provider(self) -> str:
"""Return the active provider name as a string."""
return self.get("ai_provider")

def get_model(self) -> str:
"""Get the current AI model override."""
Expand Down Expand Up @@ -104,3 +104,47 @@ def set_weights(self, weights: dict[str, int]) -> None:
def reset_weights(self) -> None:
"""Reset recommendation weights to defaults."""
self.set_weights(DEFAULT_WEIGHTS)

# Reserved names that cannot be used for custom providers
RESERVED_PROVIDERS = {"gemini", "claude", "openai", "ollama", "opencode", "custom"}

def get_custom_providers(self) -> list:
"""Return all custom providers as list of CustomProviderConfig."""
from ..core.models import CustomProviderConfig

with get_connection() as conn:
rows = conn.execute(
"SELECT name, preset, command_template, default_model FROM custom_providers ORDER BY name"
).fetchall()
return [
CustomProviderConfig(
name=r["name"],
preset=r["preset"],
command_template=r["command_template"],
default_model=r["default_model"] or "",
)
for r in rows
]

def add_custom_provider(
self, name: str, preset: str, command_template: str, default_model: str = ""
) -> None:
"""Register a custom provider. Raises ValueError if name is reserved or duplicate."""
if name.lower() in self.RESERVED_PROVIDERS:
raise ValueError(f"'{name}' is a reserved provider name")
with get_connection() as conn:
conn.execute(
"INSERT OR REPLACE INTO custom_providers (name, preset, command_template, default_model) "
"VALUES (?, ?, ?, ?)",
(name, preset, command_template, default_model),
)
conn.commit()

def remove_custom_provider(self, name: str) -> None:
"""Remove a custom provider. If it's the active provider, switch to gemini."""
with get_connection() as conn:
conn.execute("DELETE FROM custom_providers WHERE name = ?", (name,))
conn.commit()
# If active provider was deleted, reset to gemini
if self.get("ai_provider") == name:
self.set("ai_provider", "gemini")
16 changes: 8 additions & 8 deletions src/arxiv_explorer/services/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,19 @@ def summarize(
settings = SettingsService()
provider = get_provider(settings.get_provider())
if not provider.is_available():
import sys

print("Summary generation failed: provider not available", file=sys.stderr)
return None
output = provider.invoke(
prompt,
model=settings.get_model(),
timeout=settings.get_timeout(),
)
if output is None:
import sys

print("Summary generation failed: provider returned no output", file=sys.stderr)
return None
# Extract JSON block (may be in ```json ... ``` format)
if "```json" in output:
Expand All @@ -82,13 +88,9 @@ def summarize(
try:
data = json.loads(output)
except json.JSONDecodeError as e:
# JSON parse failure - print debug info and return None
import sys

if "--verbose" in sys.argv or "-v" in sys.argv:
print(f"\nSummary generation failed ({arxiv_id}): JSON parse error")
print(f"Error: {e}")
print(f"Output sample: {output[:300]}...")
print(f"Summary generation failed: JSON parse error: {e}", file=sys.stderr)
return None

summary = PaperSummary(
Expand All @@ -106,11 +108,9 @@ def summarize(
return summary

except Exception as e:
# Other error - fail silently
import sys

if "--verbose" in sys.argv or "-v" in sys.argv:
print(f"\nError during summary generation ({arxiv_id}): {e}")
print(f"Summary generation failed: {e}", file=sys.stderr)
return None

def _get_cached(self, arxiv_id: str) -> PaperSummary | None:
Expand Down
14 changes: 8 additions & 6 deletions src/arxiv_explorer/services/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,19 @@ def translate(
settings = SettingsService()
provider = get_provider(settings.get_provider())
if not provider.is_available():
import sys

print("Translation failed: provider not available", file=sys.stderr)
return None
output = provider.invoke(
prompt,
model=settings.get_model(),
timeout=settings.get_timeout(),
)
if output is None:
import sys

print("Translation failed: provider returned no output", file=sys.stderr)
return None

# Extract JSON block
Expand All @@ -94,10 +100,7 @@ def translate(
except json.JSONDecodeError as e:
import sys

if "--verbose" in sys.argv or "-v" in sys.argv:
print(f"\nTranslation failed ({arxiv_id}): JSON parse error")
print(f"Error: {e}")
print(f"Output sample: {output[:300]}...")
print(f"Translation failed: JSON parse error: {e}", file=sys.stderr)
return None

translation = PaperTranslation(
Expand All @@ -115,8 +118,7 @@ def translate(
except Exception as e:
import sys

if "--verbose" in sys.argv or "-v" in sys.argv:
print(f"\nTranslation error ({arxiv_id}): {e}")
print(f"Translation failed: {e}", file=sys.stderr)
return None

def _get_cached(self, arxiv_id: str, target_language: Language) -> PaperTranslation | None:
Expand Down
1 change: 1 addition & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"paper_review_sections",
"preferred_authors",
"daily_fetch_cache",
"custom_providers",
}


Expand Down
19 changes: 19 additions & 0 deletions tui-rs/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ pub struct PrefsState {
pub weights: [i64; 4],
pub provider: String,
pub language: String,
pub custom_providers: Vec<crate::db::models::CustomProviderEntry>,
pub custom_provider_selected: usize,
pub selected: usize, // cursor in weights section
pub focus_section: usize, // 0=cats, 1=keywords, 2=authors, 3=weights, 4=config
pub section_selected: [usize; 5], // cursor per section (section 4: 0=provider, 1=language)
Expand All @@ -205,6 +207,8 @@ impl Default for PrefsState {
weights: [60, 20, 15, 5],
provider: "gemini".to_string(),
language: "en".to_string(),
custom_providers: vec![],
custom_provider_selected: 0,
selected: 0,
focus_section: 0,
section_selected: [0; 5],
Expand All @@ -220,6 +224,7 @@ impl Default for PrefsState {
pub enum ConfirmAction {
RegenerateSummary,
RegenerateTranslation,
RemoveCustomProvider(String),
}

// =============================================================================
Expand Down Expand Up @@ -251,6 +256,18 @@ pub enum OverlayMode {
AuthorInput {
text: String,
},
PresetPicker {
selected: usize,
},
ProviderNameInput {
preset: String,
text: String,
},
CommandTemplateInput {
preset: String,
name: String,
text: String,
},
}

// =============================================================================
Expand Down Expand Up @@ -353,6 +370,8 @@ impl App {
weights,
provider,
language,
custom_providers: vec![],
custom_provider_selected: 0,
selected: 0,
focus_section: 0,
section_selected: [0; 5],
Expand Down
54 changes: 53 additions & 1 deletion tui-rs/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ impl Database {
}
let conn = Connection::open(path)?;
conn.execute_batch("PRAGMA foreign_keys = ON; PRAGMA journal_mode = WAL;")?;
Ok(Self { conn })
let db = Self { conn };
db.ensure_custom_providers_table()?;
Ok(db)
}

/// Return the default database path.
Expand Down Expand Up @@ -475,6 +477,56 @@ impl Database {
Ok(())
}

// =========================================================================
// Custom Providers
// =========================================================================

pub fn ensure_custom_providers_table(&self) -> Result<()> {
self.conn.execute_batch(
"CREATE TABLE IF NOT EXISTS custom_providers (
name TEXT PRIMARY KEY NOT NULL,
preset TEXT NOT NULL,
command_template TEXT NOT NULL,
default_model TEXT DEFAULT '',
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);"
)?;
Ok(())
}

pub fn get_custom_providers(&self) -> Result<Vec<CustomProviderEntry>> {
let mut stmt = self.conn.prepare(
"SELECT name, preset, command_template, default_model FROM custom_providers ORDER BY name"
)?;
let rows = stmt.query_map([], |row| {
Ok(CustomProviderEntry {
name: row.get(0)?,
preset: row.get(1)?,
command_template: row.get(2)?,
default_model: row.get::<_, Option<String>>(3)?.unwrap_or_default(),
})
})?;
rows.collect()
}

pub fn add_custom_provider(&self, entry: &CustomProviderEntry) -> Result<()> {
self.conn.execute(
"INSERT OR REPLACE INTO custom_providers (name, preset, command_template, default_model) VALUES (?1, ?2, ?3, ?4)",
params![entry.name, entry.preset, entry.command_template, entry.default_model],
)?;
Ok(())
}

pub fn remove_custom_provider(&self, name: &str) -> Result<()> {
self.conn.execute("DELETE FROM custom_providers WHERE name = ?1", params![name])?;
// If active provider was deleted, reset to gemini
let current = self.get_setting("ai_provider", "gemini")?;
if current == name {
self.set_setting("ai_provider", "gemini")?;
}
Ok(())
}

// =========================================================================
// Summaries & Translations
// =========================================================================
Expand Down
Loading
Loading