Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ pip install llm-dna

Use `llm-dna` for install/package naming, and `llm_dna` for Python imports.

Optional extras are available for model families that need additional runtime dependencies:

```bash
# Apple Silicon / MLX-backed models
pip install "llm-dna[apple]"

# Quantized HuggingFace models (bitsandbytes, GPTQ, compressed-tensors, optimum)
pip install "llm-dna[quantization]"

# Architecture-specific model families such as Mamba or TIMM-backed models
pip install "llm-dna[model_families]"

# Everything above
pip install "llm-dna[full]"
```

Guidance on choosing extras:
- `apple`: required for MLX and `mlx-community/*`-style model families on Apple Silicon.
- `quantization`: required for many GPTQ, bitsandbytes, and compressed-tensors model families.
- `model_families`: required for specific architectures whose modeling code depends on packages such as `mamba-ssm` or `timm`.

## Quick Start

```python
Expand Down
25 changes: 25 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,36 @@ dependencies = [
"wonderwords>=2.2.0",
"openai>=1.0.0",
"tiktoken>=0.7.0",
"python-dotenv>=1.0.0",
]

[project.optional-dependencies]
model_scraping = ["requests>=2.31.0"]

apple = [
"mlx>=0.10.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
"mlx-lm>=0.10.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
]

quantization = [
"bitsandbytes>=0.46.1",
"autoawq>=0.2.0",
"auto-gptq>=0.5.0",
"optimum>=1.16.0",
"compressed-tensors>=0.1.0",
]

model_families = [
"mamba-ssm>=1.0.0; sys_platform == 'linux'",
"timm>=0.9.0",
]

full = [
"llm-dna[apple]",
"llm-dna[quantization]",
"llm-dna[model_families]",
]

vllm = ["vllm>=0.4.0"]

dev = [
Expand Down
115 changes: 38 additions & 77 deletions src/llm_dna/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from datetime import datetime
from dataclasses import asdict, dataclass, replace
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, Dict, Optional

import numpy as np
Expand Down Expand Up @@ -310,6 +309,12 @@ def _load_cached_responses(path: Path, expected_count: int) -> Optional[list[str
logging.warning("Failed to parse cached responses from %s: %s", path, exc)
return None

if isinstance(payload, dict):
complete = payload.get("complete")
if complete is False:
logging.warning("Ignoring incomplete cached responses at %s", path)
return None

responses: list[str]
if isinstance(payload, dict) and isinstance(payload.get("items"), list):
responses = [str(item.get("response", "")) for item in payload["items"] if isinstance(item, dict)]
Expand All @@ -322,6 +327,11 @@ def _load_cached_responses(path: Path, expected_count: int) -> Optional[list[str
if not responses:
return None

non_empty_count = sum(1 for response in responses if response.strip())
if non_empty_count == 0:
logging.warning("Ignoring cached responses at %s because all responses are empty.", path)
return None

if len(responses) != expected_count:
logging.warning(
"Cached responses at %s have probe count mismatch (%s != %s); normalizing by truncating/padding.",
Expand Down Expand Up @@ -537,11 +547,11 @@ def calc_dna(config: DNAExtractionConfig) -> DNAExtractionResult:
signature: "DNASignature"
vector: np.ndarray

is_api_mode = _is_api_parallel_mode(config, [config.model_name])
if config.extractor_type != "embedding":
raise ValueError(f"Unsupported extractor_type for calc_dna: {config.extractor_type}")

response_path = _response_cache_path(config, config.model_name)
cached_responses: Optional[list[str]] = None
if is_api_mode and config.extractor_type == "embedding":
cached_responses = _load_cached_responses(response_path, expected_count=len(probe_texts))
cached_responses = _load_cached_responses(response_path, expected_count=len(probe_texts))

if cached_responses is not None:
logging.info(
Expand All @@ -550,21 +560,22 @@ def calc_dna(config: DNAExtractionConfig) -> DNAExtractionResult:
response_path,
)
model_meta = _default_model_metadata(config.model_name)
signature, vector, _ = _extract_signature_from_text_responses(
model_name=config.model_name,
responses=cached_responses,
config=config,
model_meta=model_meta,
generation_device=resolved_device,
encoder_device=resolved_device,
)
elif is_api_mode and config.extractor_type == "embedding":
# API model without cached responses: generate via API, then encode
logging.info(
"Generating responses for API model '%s' via provider API...",
config.model_name,
)
responses = cached_responses
else:
if _is_api_model_type(config.model_type):
logging.info(
"Generating responses for API model '%s' via provider API...",
config.model_name,
)

model_meta = _load_model_metadata_for_model(config.model_name, metadata_file, token=resolved_token)
is_generative = model_meta.get("architecture", {}).get("is_generative")
if is_generative is False:
arch_type = model_meta.get("architecture", {}).get("type")
raise ValueError(
f"Model '{config.model_name}' is non-generative (architecture={arch_type})."
)

responses = _generate_responses_for_model(
model_name=config.model_name,
config=config,
Expand All @@ -574,7 +585,6 @@ def calc_dna(config: DNAExtractionConfig) -> DNAExtractionResult:
resolved_token=resolved_token,
incremental_save_path=response_path if config.save else None,
)
# Save final response cache
if config.save:
_save_response_cache(
path=response_path,
Expand All @@ -583,64 +593,15 @@ def calc_dna(config: DNAExtractionConfig) -> DNAExtractionResult:
prompts=probe_texts,
responses=responses,
)
signature, vector, _ = _extract_signature_from_text_responses(
model_name=config.model_name,
responses=responses,
config=config,
model_meta=model_meta,
generation_device=resolved_device,
encoder_device=resolved_device,
)
else:
# Non-API model: use hidden-state extraction
model_meta = _load_model_metadata_for_model(config.model_name, metadata_file, token=resolved_token)

is_generative = model_meta.get("architecture", {}).get("is_generative")
if is_generative is False:
arch_type = model_meta.get("architecture", {}).get("type")
raise ValueError(
f"Model '{config.model_name}' is non-generative (architecture={arch_type})."
)

resolved_model_path = _resolve_model_path(config.model_path, model_meta)

args = SimpleNamespace(
model_name=config.model_name,
model_path=resolved_model_path,
model_type=config.model_type,
dataset=config.dataset,
probe_set=config.probe_set,
max_samples=config.max_samples,
data_root=config.data_root,
extractor_type=config.extractor_type,
dna_dim=config.dna_dim,
reduction_method=config.reduction_method,
embedding_merge=config.embedding_merge,
max_length=config.max_length,
save_format="json",
output_dir=Path(config.output_dir),
load_in_8bit=config.load_in_8bit,
load_in_4bit=config.load_in_4bit,
no_quantization=config.no_quantization,
metadata_file=metadata_file,
token=resolved_token,
trust_remote_code=config.trust_remote_code,
device=resolved_device,
log_level=config.log_level,
random_seed=config.random_seed,
use_chat_template=config.use_chat_template,
)

signature = core.extract_dna_signature(
model_name=config.model_name,
model_path=resolved_model_path,
model_type=config.model_type,
probe_texts=probe_texts,
extractor_type=config.extractor_type,
model_metadata=model_meta,
args=args,
)
vector = _validate_signature(signature)
signature, vector, _ = _extract_signature_from_text_responses(
model_name=config.model_name,
responses=responses,
config=config,
model_meta=model_meta,
generation_device=resolved_device,
encoder_device=resolved_device,
)

elapsed_seconds = time.time() - start_time

Expand Down
4 changes: 4 additions & 0 deletions src/llm_dna/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from pathlib import Path
from typing import Iterable, List, Optional

from dotenv import load_dotenv


def _load_models_from_file(path: Path) -> List[str]:
"""Load model names from a file, one per line."""
Expand Down Expand Up @@ -183,6 +185,8 @@ def main(argv: Optional[Iterable[str]] = None) -> int:
"""Main CLI entrypoint for DNA extraction."""
from .api import DNAExtractionConfig, calc_dna, calc_dna_parallel

load_dotenv(override=False)

args = parse_arguments(argv)

# Resolve model names
Expand Down
9 changes: 8 additions & 1 deletion src/llm_dna/core/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,13 @@ def main():
logging.info(f"DNA signature saved to: {output_path}")

# Save summary
# Create safe args dict without sensitive information
safe_args = vars(args).copy()
# Remove sensitive fields that should not be saved to output files
sensitive_fields = ['token', 'OPENROUTER_API_KEY', 'OPENAI_API_KEY']
for field in sensitive_fields:
safe_args.pop(field, None)

summary = {
"model_name": args.model_name,
"dataset": args.dataset,
Expand All @@ -642,7 +649,7 @@ def main():
"signature_stats": signature.get_statistics(),
"metadata": signature.metadata.__dict__,
"output_file": str(output_path),
"args": vars(args)
"args": safe_args
}

# Keep summary filename model-only as well
Expand Down
3 changes: 2 additions & 1 deletion src/llm_dna/models/ModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def _detect_model_type(self, model_path_or_name: str) -> str:
"openrouter:",
"anthropic/claude-",
"deepseek/",
"openai/gpt-",
"openai/gpt-3",
"openai/gpt-4",
"google/gemini-",
"z-ai/",
"x-ai/grok-",
Expand Down