diff --git a/README.md b/README.md index 8d66ea7..796b863 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,27 @@ pip install llm-dna Use `llm-dna` for install/package naming, and `llm_dna` for Python imports. +Optional extras are available for model families that need additional runtime dependencies: + +```bash +# Apple Silicon / MLX-backed models +pip install "llm-dna[apple]" + +# Quantized HuggingFace models (bitsandbytes, AWQ, GPTQ, compressed-tensors, optimum) +pip install "llm-dna[quantization]" + +# Architecture-specific model families such as Mamba or TIMM-backed models +pip install "llm-dna[model_families]" + +# Everything above +pip install "llm-dna[full]" +``` + +Extra guidance: +- `apple`: required for MLX and `mlx-community/*` style model families on Apple Silicon. +- `quantization`: required for many AWQ, GPTQ, bitsandbytes, and compressed-tensors model families. +- `model_families`: required for specific architectures whose modeling code depends on packages like `mamba-ssm` or `timm`. 
+ ## Quick Start ```python diff --git a/pyproject.toml b/pyproject.toml index 16a1cca..d202f2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,11 +32,36 @@ dependencies = [ "wonderwords>=2.2.0", "openai>=1.0.0", "tiktoken>=0.7.0", + "python-dotenv>=1.0.0", ] [project.optional-dependencies] model_scraping = ["requests>=2.31.0"] +apple = [ + "mlx>=0.10.0; sys_platform == 'darwin' and platform_machine == 'arm64'", + "mlx-lm>=0.10.0; sys_platform == 'darwin' and platform_machine == 'arm64'", +] + +quantization = [ + "bitsandbytes>=0.46.1", + "autoawq>=0.2.0", + "auto-gptq>=0.5.0", + "optimum>=1.16.0", + "compressed-tensors>=0.1.0", +] + +model_families = [ + "mamba-ssm>=1.0.0; sys_platform == 'linux'", + "timm>=0.9.0", +] + +full = [ + "llm-dna[apple]", + "llm-dna[quantization]", + "llm-dna[model_families]", +] + vllm = ["vllm>=0.4.0"] dev = [ diff --git a/src/llm_dna/api.py b/src/llm_dna/api.py index 858e9a0..ff6cac1 100644 --- a/src/llm_dna/api.py +++ b/src/llm_dna/api.py @@ -11,7 +11,6 @@ from datetime import datetime from dataclasses import asdict, dataclass, replace from pathlib import Path -from types import SimpleNamespace from typing import TYPE_CHECKING, Any, Dict, Optional import numpy as np @@ -310,6 +309,12 @@ def _load_cached_responses(path: Path, expected_count: int) -> Optional[list[str logging.warning("Failed to parse cached responses from %s: %s", path, exc) return None + if isinstance(payload, dict): + complete = payload.get("complete") + if complete is False: + logging.warning("Ignoring incomplete cached responses at %s", path) + return None + responses: list[str] if isinstance(payload, dict) and isinstance(payload.get("items"), list): responses = [str(item.get("response", "")) for item in payload["items"] if isinstance(item, dict)] @@ -322,6 +327,11 @@ def _load_cached_responses(path: Path, expected_count: int) -> Optional[list[str if not responses: return None + non_empty_count = sum(1 for response in responses if response.strip()) + if 
non_empty_count == 0: + logging.warning("Ignoring cached responses at %s because all responses are empty.", path) + return None + if len(responses) != expected_count: logging.warning( "Cached responses at %s have probe count mismatch (%s != %s); normalizing by truncating/padding.", @@ -537,11 +547,11 @@ def calc_dna(config: DNAExtractionConfig) -> DNAExtractionResult: signature: "DNASignature" vector: np.ndarray - is_api_mode = _is_api_parallel_mode(config, [config.model_name]) + if config.extractor_type != "embedding": + raise ValueError(f"Unsupported extractor_type for calc_dna: {config.extractor_type}") + response_path = _response_cache_path(config, config.model_name) - cached_responses: Optional[list[str]] = None - if is_api_mode and config.extractor_type == "embedding": - cached_responses = _load_cached_responses(response_path, expected_count=len(probe_texts)) + cached_responses = _load_cached_responses(response_path, expected_count=len(probe_texts)) if cached_responses is not None: logging.info( @@ -550,21 +560,22 @@ def calc_dna(config: DNAExtractionConfig) -> DNAExtractionResult: response_path, ) model_meta = _default_model_metadata(config.model_name) - signature, vector, _ = _extract_signature_from_text_responses( - model_name=config.model_name, - responses=cached_responses, - config=config, - model_meta=model_meta, - generation_device=resolved_device, - encoder_device=resolved_device, - ) - elif is_api_mode and config.extractor_type == "embedding": - # API model without cached responses: generate via API, then encode - logging.info( - "Generating responses for API model '%s' via provider API...", - config.model_name, - ) + responses = cached_responses + else: + if _is_api_model_type(config.model_type): + logging.info( + "Generating responses for API model '%s' via provider API...", + config.model_name, + ) + model_meta = _load_model_metadata_for_model(config.model_name, metadata_file, token=resolved_token) + is_generative = model_meta.get("architecture", 
{}).get("is_generative") + if is_generative is False: + arch_type = model_meta.get("architecture", {}).get("type") + raise ValueError( + f"Model '{config.model_name}' is non-generative (architecture={arch_type})." + ) + responses = _generate_responses_for_model( model_name=config.model_name, config=config, @@ -574,7 +585,6 @@ def calc_dna(config: DNAExtractionConfig) -> DNAExtractionResult: resolved_token=resolved_token, incremental_save_path=response_path if config.save else None, ) - # Save final response cache if config.save: _save_response_cache( path=response_path, @@ -583,64 +593,15 @@ def calc_dna(config: DNAExtractionConfig) -> DNAExtractionResult: prompts=probe_texts, responses=responses, ) - signature, vector, _ = _extract_signature_from_text_responses( - model_name=config.model_name, - responses=responses, - config=config, - model_meta=model_meta, - generation_device=resolved_device, - encoder_device=resolved_device, - ) - else: - # Non-API model: use hidden-state extraction - model_meta = _load_model_metadata_for_model(config.model_name, metadata_file, token=resolved_token) - - is_generative = model_meta.get("architecture", {}).get("is_generative") - if is_generative is False: - arch_type = model_meta.get("architecture", {}).get("type") - raise ValueError( - f"Model '{config.model_name}' is non-generative (architecture={arch_type})." 
- ) - resolved_model_path = _resolve_model_path(config.model_path, model_meta) - - args = SimpleNamespace( - model_name=config.model_name, - model_path=resolved_model_path, - model_type=config.model_type, - dataset=config.dataset, - probe_set=config.probe_set, - max_samples=config.max_samples, - data_root=config.data_root, - extractor_type=config.extractor_type, - dna_dim=config.dna_dim, - reduction_method=config.reduction_method, - embedding_merge=config.embedding_merge, - max_length=config.max_length, - save_format="json", - output_dir=Path(config.output_dir), - load_in_8bit=config.load_in_8bit, - load_in_4bit=config.load_in_4bit, - no_quantization=config.no_quantization, - metadata_file=metadata_file, - token=resolved_token, - trust_remote_code=config.trust_remote_code, - device=resolved_device, - log_level=config.log_level, - random_seed=config.random_seed, - use_chat_template=config.use_chat_template, - ) - - signature = core.extract_dna_signature( - model_name=config.model_name, - model_path=resolved_model_path, - model_type=config.model_type, - probe_texts=probe_texts, - extractor_type=config.extractor_type, - model_metadata=model_meta, - args=args, - ) - vector = _validate_signature(signature) + signature, vector, _ = _extract_signature_from_text_responses( + model_name=config.model_name, + responses=responses, + config=config, + model_meta=model_meta, + generation_device=resolved_device, + encoder_device=resolved_device, + ) elapsed_seconds = time.time() - start_time diff --git a/src/llm_dna/cli.py b/src/llm_dna/cli.py index 10d203f..56a4d64 100644 --- a/src/llm_dna/cli.py +++ b/src/llm_dna/cli.py @@ -8,6 +8,8 @@ from pathlib import Path from typing import Iterable, List, Optional +from dotenv import load_dotenv + def _load_models_from_file(path: Path) -> List[str]: """Load model names from a file, one per line.""" @@ -183,6 +185,8 @@ def main(argv: Optional[Iterable[str]] = None) -> int: """Main CLI entrypoint for DNA extraction.""" from .api import 
DNAExtractionConfig, calc_dna, calc_dna_parallel + load_dotenv(override=False) + args = parse_arguments(argv) # Resolve model names diff --git a/src/llm_dna/core/extraction.py b/src/llm_dna/core/extraction.py index 4560814..183e406 100644 --- a/src/llm_dna/core/extraction.py +++ b/src/llm_dna/core/extraction.py @@ -630,6 +630,13 @@ def main(): logging.info(f"DNA signature saved to: {output_path}") # Save summary + # Create safe args dict without sensitive information + safe_args = vars(args).copy() + # Remove sensitive fields that should not be saved to output files + sensitive_fields = ['token', 'OPENROUTER_API_KEY', 'OPENAI_API_KEY'] + for field in sensitive_fields: + safe_args.pop(field, None) + summary = { "model_name": args.model_name, "dataset": args.dataset, @@ -642,7 +649,7 @@ def main(): "signature_stats": signature.get_statistics(), "metadata": signature.metadata.__dict__, "output_file": str(output_path), - "args": vars(args) + "args": safe_args } # Keep summary filename model-only as well diff --git a/src/llm_dna/models/ModelLoader.py b/src/llm_dna/models/ModelLoader.py index f3c8b65..15ac4ba 100644 --- a/src/llm_dna/models/ModelLoader.py +++ b/src/llm_dna/models/ModelLoader.py @@ -75,7 +75,8 @@ def _detect_model_type(self, model_path_or_name: str) -> str: "openrouter:", "anthropic/claude-", "deepseek/", - "openai/gpt-", + "openai/gpt-3", + "openai/gpt-4", "google/gemini-", "z-ai/", "x-ai/grok-",