diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..90bc9f2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,57 @@ +# Data files +dataset/*.json + +# Cache and generated files +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# Logging +*.log +logs/ +nohup.out + +# OS-specific files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# IDE files +.idea/ +.vscode/ +*.swp +*.swo + +# Result files +evaluation/results/ +*.csv + +# Jupyter Notebook +.ipynb_checkpoints + +# Keep dataset README +!dataset/README.md + +# Keep configuration files +!evaluation/my_ragas/data/gt_cache.json \ No newline at end of file diff --git a/README.md b/README.md index 847260c..472a569 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,284 @@ -## My Project +

+ Amazon +        + SenTSR-Bench +

-TODO: Fill this README out! +

SenTSR-Bench: Thinking with Injected Knowledge for Time-Series Reasoning

-Be sure to: +

+ Zelin He, Boran Han, Xiyuan Zhang, Shuai Zhang, Haotian Lin, Qi Zhu, Haoyang Fang,
Danielle C. Maddix, Abdul Fatir Ansari, Akash Chandrayan, Abhinav Pradhan, Bernie Wang, Matthew Reimherr
+

-* Change the title in this README -* Edit your repository description on GitHub +

+ Paper + Dataset + Website + License: Apache 2.0 + Venue +

-## Security +

+ Official implementation of the SenTSR-Bench knowledge injection framework.
+ Inject in-domain knowledge from fine-tuned time-series specialists into frozen general reasoning LMs
+ for robust, context-aware diagnostic reasoning. +

-See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. +--- -## License +## Overview + +
+ SenTSR-Bench Overview +
+ (a) A fine-tuned TSLM captures key time-series patterns but fails on diagnostic reasoning; (b) A general-purpose GRLM reasons well but overlooks domain-specific patterns; (c) Our knowledge injection steers the GRLM's reasoning with in-domain knowledge from the TSLM, producing the correct diagnosis. +
+ +
+ +**General reasoning LMs** (GRLMs) show strong reasoning but lack domain-specific time-series knowledge. **Time-series specialist LMs** (TSLMs) capture signal patterns but struggle with multi-step diagnostic reasoning. Our framework bridges this gap by **injecting TSLM-generated insights directly into the GRLM's reasoning trace** — no weight updates needed. + +
+ Knowledge Injection Paradigm +
+ (a) Knowledge injection paradigm; (b) RL-honed thinking traces via RLVR enable effective injection without human supervision. +
+ +### Key Contributions + +1. **New Paradigm** — A framework that injects in-domain knowledge from a TSLM into a GRLM's reasoning process, steering reasoning with domain knowledge +2. **RL-Based Injection** — Reinforcement learning with verifiable rewards (RLVR) elicits knowledge-rich thinking traces *without manual supervision* +3. **SenTSR-Bench** — A new benchmark of 110 real-world multivariate sensor streams with 330 human-curated diagnostic questions +4. **Strong Results** — Surpasses TSLMs by 9.1%–26.1% and GRLMs by 7.9%–22.4% across benchmarks + +### Main Results + +
+ Main Results +
+ Reasoning performance on SenTSR-Bench. RL-injection consistently outperforms all baselines. +
+ +--- + +## Table of Contents + +- [Overview](#overview) +- [Setup](#setup) +- [Benchmark](#benchmark) + - [SenTSR-Bench Evaluation Benchmark](#sentsr-bench-evaluation-benchmark) + - [Public Benchmark Data Curation](#public-benchmark-data-curation) + - [Synthetic Data Generation Pipeline](#synthetic-data-generation-pipeline) +- [Method: Knowledge Injection](#method-knowledge-injection) + - [Claude (GRLM) + ChatTS (TSLM)](#closed-source-claude-grlm--chatts-tslm) + - [Qwen3 (GRLM) + Qwen-VL (TSLM)](#open-source-qwen3-grlm--qwen-vl-tslm) + - [DeepSeek-R1 (GRLM) + Qwen-VL (TSLM)](#open-source-deepseek-r1-grlm--qwen-vl-tslm) +- [TSLM Training](#tslm-training) +- [Evaluation](#evaluation) +- [Citation](#citation) +- [License](#license) + +--- + +## Setup + +### Prerequisites + +- Python 3.10+ +- Conda for environment management +- AWS account with access to Claude models via Bedrock (for closed-source experiments) +- GPU for running self-hosted model servers (8xA100 recommended) + +### Installation + +```bash +git clone https://github.com/amazon-science/SenTSR-Bench.git +cd SenTSR-Bench + +conda create -n tsr-env python=3.10 +conda activate tsr-env +pip install -r requirements.txt + +# (For Claude experiments) Configure AWS credentials +aws configure +``` + +--- + +## Benchmark + +### SenTSR-Bench Evaluation Benchmark + +SenTSR-Bench is a first-of-its-kind dataset of **110 multivariate sensor streams** with **330 human-curated diagnostic questions**, built from real-world industrial operations. Each time series contains 3 sensor channels (acceleration, velocity, temperature). + +The benchmark evaluates a three-stage diagnostic reasoning chain: + +| Stage | Task | Description | +|-------|------|-------------| +| **What Happened** | Anomaly Characterization | Identify key time-series anomaly patterns | +| **How Happened** | Root-Cause Diagnosis | Determine the most likely causes | +| **Suggested Fix** | Action Recommendation | Propose corrective actions | + +Download the evaluation benchmark from HuggingFace: + +```bash +# Install huggingface_hub if needed +pip install huggingface_hub + +# Download dataset to ./dataset/ +huggingface-cli download ZLHe0/SenTSR-Bench --repo-type dataset --local-dir ./dataset +``` + +See `dataset/README.md` for format specifications. + +### Public Benchmark Data Curation + +We additionally evaluate on two public benchmarks: **TSEvol** (Dataset A) and **TS&Language** (MCQ2 subset). + +```bash +python dataset/preprocess_dataset.py \ + --dataset_a path/to/dataset_a.json \ + --mcq2_source path/to/MCQ_2_TS.jsonl \ + --output_dir ./dataset/processed \ + --mcq2_sample_size 100 +``` + +### Synthetic Data Generation Pipeline + +The synthetic training data pipeline uses VLM-assisted code synthesis to bootstrap realistic simulators from 23 seed signals, producing **6,000 MCQ training entries**. + +| Stage | Script | Description | +|-------|--------|-------------| +| 1. Iterative Code Synthesis | `./scripts/run_iterative_generation.sh` | Claude generates Python simulators from real data | +| 2. Stochastic Diversification | `./scripts/run_stochastic_refinement.sh` | Convert to sampling-based generators | +| 3–4. Benchmark Generation | `./scripts/run_synthetic_benchmark.sh 100` | Generate synthetic time series + QA/MCQ | + +> **Note:** Stage 2 requires manual review to select the best stochastic model per sample before proceeding. + +See `dataset/synthetic/README.md` for full pipeline documentation. + +--- + +## Method: Knowledge Injection -This project is licensed under the Apache-2.0 License. +The knowledge injection framework injects TSLM-generated insights directly into the GRLM's reasoning trace. We provide end-to-end examples with multiple model combinations: + +- **GRLMs** (General Reasoners): Claude 3.7 Sonnet, Qwen3-32B, DeepSeek-R1-Distill-Qwen-32B +- **TSLMs** (Time-Series Specialists): ChatTS-14B, Qwen2.5-VL-3B (SFT/RL fine-tuned) + +### Closed-Source: Claude (GRLM) + ChatTS (TSLM) + +**Claude** (via AWS Bedrock) serves as the general reasoner; **ChatTS** provides injected observations via an instructional proxy (`` tags). + +```bash +# 1. Start the ChatTS server +./src/chatts_utils/start_chatts_server.sh + +# 2. Run standalone baselines +./scripts/run_chatts_inference.sh --dataset ./dataset/dataset_a_with_mcq2.json +./scripts/run_claude_inference.sh --dataset ./dataset/dataset_a_with_mcq2.json + +# 3. Run knowledge injection (generates observations + injects into Claude) +./scripts/run_injection_workflow.sh --dataset ./dataset/dataset_a_with_mcq2.json + +# 4. Stop the server +./src/chatts_utils/stop_chatts_server.sh +``` + +### Open-Source: Qwen3 (GRLM) + Qwen-VL (TSLM) + +**Qwen3-32B** serves as the GRLM; **Qwen2.5-VL-3B** provides injected thoughts via `continue_final_message` assistant prefill. + +```bash +# 1. Start both servers +./src/qwen_utils/start_qwen_vl_server.sh # Qwen2.5-VL-3B on port 5003 +./src/qwen3_utils/start_qwen3_server.sh # Qwen3-32B on port 5001 + +# 2. Run standalone baseline +./scripts/run_qwen_inference.sh --dataset ./dataset/dataset_a_with_mcq2.json + +# 3. Run knowledge injection +./scripts/run_qwen3_injection_workflow.sh --dataset ./dataset/dataset_a_with_mcq2.json + +# 4. Stop servers +./src/qwen_utils/stop_qwen_vl_server.sh +./src/qwen3_utils/stop_qwen3_server.sh +``` + +### Open-Source: DeepSeek-R1 (GRLM) + Qwen-VL (TSLM) + +**DeepSeek-R1-Distill-Qwen-32B** is an alternative open-source GRLM. It shares the same tokenizer and API as Qwen3, so the same injection script supports both via `--model_name`. + +```bash +# 1. Start servers +./src/qwen_utils/start_qwen_vl_server.sh # Qwen2.5-VL-3B on port 5003 +./src/r1_utils/start_r1_server.sh # DeepSeek-R1 on port 5002 + +# 2. Run knowledge injection +./scripts/run_r1_injection_workflow.sh --dataset ./dataset/dataset_a_with_mcq2.json + +# 3. Stop servers +./src/qwen_utils/stop_qwen_vl_server.sh +./src/r1_utils/stop_r1_server.sh +``` + +--- + +## TSLM Training + +The time-series specialist (TSLM) is initialized from the public `Qwen2.5-VL-3B-Instruct` checkpoint and post-trained in two stages: + +1. **Supervised Fine-Tuning (SFT)** using [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) +2. **Reinforcement Learning (GRPO)** using [VERL](https://github.com/volcengine/verl) + +The synthetic training data generated by `dataset/synthetic/` can be used directly with these frameworks. See Appendix B of the paper for hyperparameter details. + +--- + +## Evaluation + +Evaluate results with sampling for statistical robustness (mean ± std over 3 runs): + +```bash +python evaluation/evaluate_with_sampling.py \ + --exp experiment_name \ + --dataset ./dataset/dataset_a_with_mcq2.json \ + --generated ./evaluation/results/experiment_name/generated_answer.json +``` + +The evaluation uses custom metrics based on the [RAGAS](https://github.com/explodinggradients/ragas) framework. Results are saved to `evaluation/exp//`. + +--- + +## Citation + +If you find this work useful, please cite: + +```bibtex +@misc{he2026sentsrbenchthinkinginjectedknowledge, + title={SenTSR-Bench: Thinking with Injected Knowledge for Time-Series Reasoning}, + author={Zelin He and Boran Han and Xiyuan Zhang and Shuai Zhang and Haotian Lin and Qi Zhu and Haoyang Fang and Danielle C. Maddix and Abdul Fatir Ansari and Akash Chandrayan and Abhinav Pradhan and Bernie Wang and Matthew Reimherr}, + year={2026}, + eprint={2602.19455}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2602.19455}, +} +``` + +--- + +## Acknowledgements + +This project is inspired by and builds upon several excellent open-source projects: + +- [**ChatTS**](https://github.com/NetManAIOps/ChatTS) +- [**LLaMA-Factory**](https://github.com/hiyouga/LLaMA-Factory) +- [**VERL**](https://github.com/volcengine/verl) +- [**vLLM**](https://github.com/vllm-project/vllm) + +--- + +## License +This project is licensed under the Apache License 2.0. diff --git a/assets/Intro.png b/assets/Intro.png new file mode 100644 index 0000000..7bdb79d Binary files /dev/null and b/assets/Intro.png differ diff --git a/assets/Paradigm.png b/assets/Paradigm.png new file mode 100644 index 0000000..60e37c2 Binary files /dev/null and b/assets/Paradigm.png differ diff --git a/assets/amazon_logo.png b/assets/amazon_logo.png new file mode 100644 index 0000000..6082c1f Binary files /dev/null and b/assets/amazon_logo.png differ diff --git a/assets/logo_new.png b/assets/logo_new.png new file mode 100644 index 0000000..b25fa4c Binary files /dev/null and b/assets/logo_new.png differ diff --git a/assets/result.png b/assets/result.png new file mode 100644 index 0000000..7bc48c2 Binary files /dev/null and b/assets/result.png differ diff --git a/dataset/README.md b/dataset/README.md new file mode 100644 index 0000000..1a90dca --- /dev/null +++ b/dataset/README.md @@ -0,0 +1,57 @@ +# TSR Knowledge Injection Dataset Processing + +This directory contains the scripts and data necessary for preprocessing datasets for time series reasoning (TSR) knowledge injection. + +## Data Requirements + +To run the preprocessing script successfully, you need the following input data files: + +1. **dataset_a.json**: A JSON file containing multi-question time series reasoning entries. + - Format: Array of JSON objects with `timeseries`, `cols`, `question`, `answer`, `attributes`, and `ability_types` fields + - Each entry may contain multiple questions and answers that will be split into individual entries + +2. **MCQ_2_TS.jsonl**: A JSONL (JSON Lines) file containing the raw TS&Language (MCQ2) dataset. + - Format: Each line is a JSON object containing: + - `uuid`: Unique identifier + - `description`: Text description of the time series + - `question`: Question text + - `options`: Array of possible answers + - `answer_index`: Index of the correct answer + - `series`: Original time series data (array of numbers) + - `new_series`: Updated time series data (array of numbers) + +## Output Files + +The script generates the following outputs in the specified output directory: + +1. **dataset_a_split.json**: Dataset A entries split into individual questions +2. **dataset_a_split_filtered.json**: Filtered version of the split dataset, containing only entries with specific ability types +3. **mcq2_qa_eval_100.json**: 100 sampled entries from the MCQ2 dataset +4. **dataset_a_with_mcq2.json**: Final merged dataset containing both dataset_a and MCQ2 entries + +## Running the Script + +```bash +python preprocess_dataset.py [--dataset_a PATH] [--mcq2_source PATH] [--output_dir PATH] [--mcq2_sample_size SIZE] [--mcq2_seed SEED] +``` + +### Parameters: +- `--dataset_a`: Path to the dataset_a.json file (default: ./dataset_a.json) +- `--mcq2_source`: Path to the MCQ2_TS.jsonl source file (default: ./MCQ_2_TS.jsonl) +- `--output_dir`: Directory to save processed datasets (default: ./processed) +- `--mcq2_sample_size`: Number of entries to sample from MCQ2 (default: 100) +- `--mcq2_seed`: Random seed for reproducibility of MCQ2 sampling (default: 42) + +## Processing Steps + +1. **Split dataset_a**: Divides multi-question entries into individual questions +2. **Filter dataset_a**: Keeps only entries with ability types containing 'causal', 'deductive', or 'inductive' +3. **Sample MCQ2**: Samples entries from the MCQ2_TS.jsonl file with a fixed seed for reproducibility +4. **Merge datasets**: Combines the processed dataset_a with sampled MCQ2 entries into a single dataset + +## Dependencies + +- Python 3.6+ +- Required libraries: + - numpy + - tqdm \ No newline at end of file diff --git a/dataset/preprocess_dataset.py b/dataset/preprocess_dataset.py new file mode 100755 index 0000000..63fb12f --- /dev/null +++ b/dataset/preprocess_dataset.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +""" +Script for preprocessing TSR knowledge injection datasets. + +This script combines the functionality of several separate scripts: +1. Split multi-question entries in dataset_a.json into individual questions +2. Filter entries based on ability types (causal, deductive, inductive) +3. Sample 100 entries from MCQ_2_TS.jsonl (with fixed seed for reproducibility) +4. Merge dataset_a with MCQ2 dataset + +Usage: + python preprocess_dataset.py [--dataset_a PATH] [--mcq2_source PATH] [--output_dir PATH] +""" + +import argparse +import json +import os +import re +import copy +import random +import numpy as np +from pathlib import Path +from tqdm import tqdm + + +def split_dataset(dataset): + """ + Split multi-question entries into individual questions. + + Args: + dataset: The original dataset as a list of dictionaries + + Returns: + List of dictionaries with one question per entry + """ + print("Splitting multi-question entries...") + + # New dataset to store individual questions + split_dataset = [] + + for entry in dataset: + # Common elements to be replicated for each subquestion + timeseries = entry.get('timeseries') + cols = entry.get('cols') + + # Extract the question text + question = entry.get('question', '') + + # Split the question into prefix and subquestions + question_parts = question.split("please analyze the time series features and answer the following questions:") + if len(question_parts) != 2: + # If the split doesn't work as expected, try an alternative approach + question_parts = question.split("please analyze the time series features and answer the following question:") + if len(question_parts) != 2: + print(f"Warning: Could not split question properly: {question[:100]}...") + continue + + question_prefix = question_parts[0] + "please analyze the time series features and answer the following question:" + subquestions_text = question_parts[1] + + # Remove the formatting instructions at the end + if "Now, based on the above questions" in subquestions_text: + subquestions_text = subquestions_text.split("Now, based on the above questions")[0] + + # Extract the actual questions (numbered items) + subquestions = re.findall(r'\n\d+\. (.*?)(?=\n\d+\.|$)', subquestions_text, re.DOTALL) + + # Extract answers + answer = entry.get('answer', '') + answers = re.findall(r'\d+\. (.*?)(?=\n\d+\.|\Z)', answer, re.DOTALL) + + # Extract attributes + attributes = entry.get('attributes', []) + + # Extract ability types + ability_types = entry.get('ability_types', []) + + # Create new entries for each subquestion + for i, subq in enumerate(subquestions): + if i >= len(answers): + print(f"Warning: Missing answer for subquestion {i+1} in entry") + continue + + # Create a new entry with a single question + new_entry = { + 'timeseries': copy.deepcopy(timeseries), + 'cols': copy.deepcopy(cols), + 'question': f"{question_prefix} {subq.strip()}", + 'answer': answers[i].strip(), + } + + # Handle attributes - try to match the attributes to the subquestion + if i < len(attributes): + new_entry['attributes'] = [attributes[i]] + else: + # If we can't match attributes directly, use empty list + new_entry['attributes'] = [] + + # Handle ability types - use the corresponding ability type if available + if i < len(ability_types): + new_entry['ability_types'] = [ability_types[i]] + else: + # If we can't match ability types directly, use the whole list + new_entry['ability_types'] = copy.deepcopy(ability_types) + + split_dataset.append(new_entry) + + return split_dataset + + +def filter_dataset(dataset): + """ + Filter the dataset to keep only entries with ability_types containing 'causal', 'deductive', or 'inductive'. + Also adds a standard closing line to all questions. + + Args: + dataset: The dataset to filter + + Returns: + Filtered dataset and statistics + """ + print("Filtering dataset based on ability types...") + + # Initialize counters + total_entries = len(dataset) + filtered_entries = 0 + causal_count = 0 + deductive_count = 0 + inductive_count = 0 + + # Filter the dataset + filtered_dataset = [] + + for entry in dataset: + ability_types = entry.get('ability_types', []) + + # Check if any ability type contains 'causal', 'deductive', or 'inductive' + has_target_ability = False + has_causal = False + has_deductive = False + has_inductive = False + + for ability in ability_types: + if isinstance(ability, str): + if 'causal' in ability: + has_causal = True + if 'deductive' in ability: + has_deductive = True + if 'inductive' in ability: + has_inductive = True + + has_target_ability = has_causal or has_deductive or has_inductive + + if has_target_ability: + # Add instruction to the end of the question for causal type + if has_causal: + question = entry.get('question', '') + if not question.endswith("In your answer, start by stating your chosen option and then provide your explanation in a separate sentence."): + entry['question'] = question + "\nIn your answer, start by stating your chosen option and then provide your explanation in a separate sentence." + + filtered_dataset.append(entry) + filtered_entries += 1 + + # Update counters + if has_causal: + causal_count += 1 + if has_deductive: + deductive_count += 1 + if has_inductive: + inductive_count += 1 + + stats = { + 'total': total_entries, + 'filtered': filtered_entries, + 'causal': causal_count, + 'deductive': deductive_count, + 'inductive': inductive_count + } + + return filtered_dataset, stats + + +def merge_datasets(dataset_a, mcq2_dataset): + """ + Merge dataset_a with MCQ2 dataset. + + Args: + dataset_a: Processed dataset_a + mcq2_dataset: MCQ2 dataset + + Returns: + Merged dataset + """ + print("Merging dataset_a with MCQ2 dataset...") + + # Simple append - MCQ2 entries after dataset_a entries + merged_dataset = dataset_a + mcq2_dataset + + return merged_dataset + + +def load_jsonl_data(file_path, limit=None): + """Load data from a JSONL file.""" + data = [] + with open(file_path, 'r') as f: + for i, line in enumerate(f): + if limit is not None and i >= limit: + break + data.append(json.loads(line)) + return data + + +def create_evaluation_entry(entry, idx): + """ + Create an evaluation entry for MCQ2 data. + """ + uuid = entry.get("uuid", f"mcq2_{idx}") + + # Extract the formatted question parts + description = entry.get("description", "") + question = entry.get("question", "") + options = entry.get("options", []) + + # Extract just the question part from the original question if needed + clean_question = question.split("Options:")[0] if "Options:" in question else question + clean_question = clean_question.split("Now, based on")[0] if "Now, based on" in clean_question else clean_question + + # Format the options as a string + options_str = ", ".join([f"\"{option}\"" for option in options]) + + # Assemble the formatted question (for evaluation, don't include tags instruction) + formatted_question = f"""You are a time series analysis expert. {description} {clean_question.strip()} Choose from: [{options_str}].""" + + # Get correct answer + correct_answer_index = entry.get("answer_index", 0) + selected_option = options[correct_answer_index] if options and 0 <= correct_answer_index < len(options) else "Unknown" + + # Extract time series data + original_series = entry.get("series", []) + new_series = entry.get("new_series", []) + + # Column names + column_names = ["original series", "updated series"] + + # Create evaluation format entry + result = { + "timeseries": [original_series, new_series], + "cols": column_names, + "question": formatted_question, + "answer": selected_option, + "attributes": [selected_option], + "ability_types": ["MCQ2"], + "id": uuid + } + + return result + + +def sample_mcq2_data(mcq2_path, sample_size=100, seed=42): + """ + Sample entries from the MCQ2_TS dataset with a fixed seed. + + Args: + mcq2_path: Path to the MCQ2_TS.jsonl file + sample_size: Number of entries to sample + seed: Random seed for reproducibility + + Returns: + List of sampled entries in the proper format for evaluation + """ + print(f"Sampling {sample_size} entries from MCQ2 dataset with seed {seed}") + + # Set random seed for reproducibility + random.seed(seed) + np.random.seed(seed) + + # Load original dataset + print(f"Loading MCQ2 data from {mcq2_path}") + data = load_jsonl_data(mcq2_path) + print(f"Loaded {len(data)} entries from the original MCQ2 dataset") + + # Sample entries with the fixed random seed + sample_size = min(sample_size, len(data)) + print(f"Sampling {sample_size} entries with seed {seed}") + sample_indices = random.sample(range(len(data)), sample_size) + + # Create formatted entries from sampled data + print(f"Creating formatted entries...") + sampled_entries = [] + for i, idx in enumerate(tqdm(sample_indices, desc="Processing MCQ2 entries")): + entry = data[idx] + eval_entry = create_evaluation_entry(entry, idx) + sampled_entries.append(eval_entry) + + print(f"Sampled {len(sampled_entries)} entries from {len(data)} total entries") + + # Print some statistics + if sampled_entries: + num_cols = [len(entry.get("cols", [])) for entry in sampled_entries] + avg_cols = sum(num_cols) / len(num_cols) + print(f"\nMCQ2 Sample Statistics:") + print(f"- Average number of columns: {avg_cols:.2f}") + print(f"- Range of columns: {min(num_cols)} to {max(num_cols)}") + print(f"- First sampled index: {sample_indices[0]}") + print(f"- Last sampled index: {sample_indices[-1]}") + + return sampled_entries + + +def process_datasets(dataset_a_path, mcq2_source_path, output_dir, mcq2_sample_size=100, mcq2_seed=42): + """ + Main function to process the datasets. + + Args: + dataset_a_path: Path to dataset_a.json + mcq2_path: Path to MCQ2 dataset + output_dir: Directory to save processed datasets + """ + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Step 1: Load dataset_a + print(f"\n=== Processing dataset_a ===\n") + print(f"Loading dataset_a from {dataset_a_path}") + with open(dataset_a_path, 'r') as f: + dataset_a = json.load(f) + print(f"Loaded {len(dataset_a)} entries from dataset_a") + + # Step 2: Split dataset_a + split_a = split_dataset(dataset_a) + print(f"Split dataset_a into {len(split_a)} entries") + + # Save split dataset + split_path = os.path.join(output_dir, "dataset_a_split.json") + with open(split_path, 'w') as f: + json.dump(split_a, f, indent=4) + print(f"Split dataset saved to {split_path}") + + # Step 3: Filter split dataset + filtered_a, filter_stats = filter_dataset(split_a) + print(f"Filtered dataset statistics:") + print(f" Original: {filter_stats['total']} entries") + print(f" Filtered: {filter_stats['filtered']} entries") + print(f" Causal: {filter_stats['causal']} entries") + print(f" Deductive: {filter_stats['deductive']} entries") + print(f" Inductive: {filter_stats['inductive']} entries") + + # Save filtered dataset + filtered_path = os.path.join(output_dir, "dataset_a_split_filtered.json") + with open(filtered_path, 'w') as f: + json.dump(filtered_a, f, indent=4) + print(f"Filtered dataset saved to {filtered_path}") + + # Step 4: Sample MCQ2 dataset + print(f"\n=== Processing MCQ2 dataset ===\n") + mcq2_dataset = sample_mcq2_data(mcq2_source_path, mcq2_sample_size, mcq2_seed) + + # Save sampled MCQ2 dataset + mcq2_path = os.path.join(output_dir, "mcq2_qa_eval_100.json") + with open(mcq2_path, 'w') as f: + json.dump(mcq2_dataset, f, indent=2) + print(f"Sampled MCQ2 dataset saved to {mcq2_path}") + + # Step 5: Merge datasets + print(f"\n=== Merging datasets ===\n") + merged_dataset = merge_datasets(filtered_a, mcq2_dataset) + print(f"Merged dataset contains {len(merged_dataset)} entries") + + # Save merged dataset + merged_path = os.path.join(output_dir, "dataset_a_with_mcq2.json") + with open(merged_path, 'w') as f: + json.dump(merged_dataset, f, indent=2) + print(f"Merged dataset saved to {merged_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Process TSR knowledge injection datasets") + parser.add_argument("--dataset_a", type=str, default="./dataset_a.json", + help="Path to the dataset_a.json file") + parser.add_argument("--mcq2_source", type=str, default="./MCQ_2_TS.jsonl", + help="Path to the MCQ2_TS.jsonl source file") + parser.add_argument("--output_dir", type=str, default="./processed", + help="Directory to save processed datasets") + parser.add_argument("--mcq2_sample_size", type=int, default=100, + help="Number of entries to sample from MCQ2") + parser.add_argument("--mcq2_seed", type=int, default=42, + help="Random seed for reproducibility of MCQ2 sampling") + + args = parser.parse_args() + + process_datasets( + args.dataset_a, + args.mcq2_source, + args.output_dir, + args.mcq2_sample_size, + args.mcq2_seed + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/dataset/synthetic/README.md b/dataset/synthetic/README.md new file mode 100644 index 0000000..1a2cbce --- /dev/null +++ b/dataset/synthetic/README.md @@ -0,0 +1,122 @@ +# Synthetic Time Series Benchmark Generation Pipeline + +This module generates synthetic time series data and benchmarks for training time series reasoning models. The pipeline transforms real-world time series data into diverse synthetic datasets that preserve key anomaly patterns. + +## Pipeline Overview + +The pipeline consists of 4 stages: + +### Stage 1: Iterative Model Generation (`iterative_ts_generation.py`) + +Uses Claude (via AWS Bedrock) to iteratively generate Python code that models anomaly patterns from real training data. + +- Input: `qa_benchmark_base_train.json` (only `what_happened` samples) +- For each sample: sends time series visualization + anomaly description to Claude +- Claude generates a `generate_synthetic_anomaly()` Python function +- 3 iterations per sample: execute code, visually compare, send feedback +- Output: `results/iterative_results/Sample_{id}/function_{1,2,3}.py` + +### Stage 2: Stochastic Refinement (`stochastic_ts_generation.py`) + +Simplifies Stage 1 models into sampling-based generators that produce diverse data. + +- Input: Stage 1 results + original dataset +- Asks Claude to replace hardcoded parameters with probabilistic sampling +- Generates 3 different versions per sample for selection +- Output: `results/stochastic_results/Sample_{id}/stochastic_function{1,2,3}.py` +- **Manual step**: review and select the best version for each sample + +### Stage 3: Dataset Generation (`generate_synthetic_dataset.py`) + +Generates synthetic time series data at scale using the selected stochastic models. + +- Input: Stochastic functions from Stage 2 + original dataset +- Dynamically loads Python functions via `importlib` +- Generates N samples per source with different random seeds +- Output: `results/synthetic_training_data/data_ts.json` + +### Stage 4: Benchmark Generation (3 scripts, orchestrated by shell script) + +Converts synthetic time series into a structured benchmark dataset: + +- **4a** `generate_qa_benchmark.py`: Uses LLM to diversify the original anomaly descriptions into varied observations, root causes, and corrective actions, then generates QA pairs +- **4b** `generate_mcq_benchmark.py`: Converts to multiple-choice format with distractors from other source samples +- **4c** `filter_mcq_benchmark.py`: Filters to keep only MCQ_obs and MCQ_cause questions + +Output: `results/synthetic_training_data/rme_synthetic_easy.json` + +## Quick Start + +### Prerequisites + +- Python 3.10+ +- AWS credentials configured for Bedrock access +- Required packages: `boto3`, `numpy`, `matplotlib`, `tenacity`, `scipy` + +### 1. Prepare Input Data + +Place your training data in `sample_data/`: +- `qa_benchmark_base_train.json`: Training samples with time series and anomaly descriptions + +See `sample_data/README.md` for format details. + +### 2. Run the Pipeline + +```bash +# Stage 1: Generate initial models (requires Claude API access) +./scripts/run_iterative_generation.sh + +# Stage 2: Refine into stochastic generators +./scripts/run_stochastic_refinement.sh + +# Manual step: review results/stochastic_results/ and keep best models + +# Stages 3-4: Generate full synthetic benchmark +./scripts/run_synthetic_benchmark.sh 100 # 100 samples per source +``` + +### 3. Outputs + +All outputs are saved under `results/synthetic_training_data/`: + +| File | Description | +|------|-------------| +| `data_ts.json` | Synthetic time series data | +| `dataset_summary.json` | Per-source sample statistics | +| `diversified_answers.json` | LLM-generated answer variations (cached) | +| `qa_synthetic_base.json` | QA benchmark (3 question types) | +| `rme_synthetic_easy.json` | Filtered MCQ benchmark | + +## File Structure + +``` +dataset/synthetic/ +├── README.md # This file +├── iterative_ts_generation.py # Stage 1: Iterative model generation +├── stochastic_ts_generation.py # Stage 2: Stochastic refinement +├── generate_synthetic_dataset.py # Stage 3: Dataset generation at scale +├── generate_qa_benchmark.py # Stage 4a: QA benchmark generation (LLM-based) +├── generate_mcq_benchmark.py # Stage 4b: MCQ benchmark generation +├── filter_mcq_benchmark.py # Stage 4c: Filter by ability types +├── sample_data/ # Input data +│ └── README.md # Input format documentation +└── results/ # Generated outputs (created by pipeline) + ├── iterative_results/ # Stage 1 outputs + ├── stochastic_results/ # Stage 2 outputs + └── synthetic_training_data/ # Stages 3-4 outputs +``` + +## Data Formats + +### Time Series Structure +- 3 channels: Acceleration, Velocity, Temperature +- Standardized using Median Absolute Deviation (MAD) +- Variable length (preserved from original data) + +### Question Types +- `what_happened` / `MCQ_obs`: Identify the anomaly pattern +- `how_happened` / `MCQ_cause`: Identify the root cause +- `suggested_fix` / `MCQ_fix`: Recommend corrective action + +### MCQ Format +Each MCQ sample includes 4 options (1 correct + 3 distractors from different source samples), with shuffled order and a `correct_index` field. diff --git a/dataset/synthetic/filter_mcq_benchmark.py b/dataset/synthetic/filter_mcq_benchmark.py new file mode 100644 index 0000000..421f554 --- /dev/null +++ b/dataset/synthetic/filter_mcq_benchmark.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +""" +Stage 4c: Filter MCQ Benchmark to Keep Only MCQ_obs and MCQ_cause Questions + +This script filters the generated MCQ benchmark to keep only questions with +ability_types of MCQ_obs and MCQ_cause, excluding MCQ_fix questions. +""" + +import os +import json +import argparse +from typing import List, Dict, Any + +# === PATH HANDLING === +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def resolve_path(path: str) -> str: + """Resolve a path relative to the script directory if not absolute.""" + if os.path.isabs(path): + return path + return os.path.join(SCRIPT_DIR, path) + + +def filter_mcq_benchmark( + input_file: str, + output_file: str, + keep_ability_types: List[str] = None +): + """Filter MCQ benchmark to keep only questions with specified ability types.""" + if keep_ability_types is None: + keep_ability_types = ["MCQ_obs", "MCQ_cause"] + + print(f"Filtering {input_file} to keep only {keep_ability_types}...") + + with open(input_file, 'r') as f: + benchmark = json.load(f) + + total_before = len(benchmark) + filtered_benchmark = [] + counts_by_type = {} + + for item in benchmark: + ability_types = item.get('ability_types', []) + for ability_type in ability_types: + counts_by_type[ability_type] = counts_by_type.get(ability_type, 0) + 1 + if any(ability_type in keep_ability_types for ability_type in ability_types): + filtered_benchmark.append(item) + + total_after = len(filtered_benchmark) + + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, 'w') as f: + json.dump(filtered_benchmark, f, indent=2) + + print(f"Original benchmark: {total_before} questions") + print(f"Filtered benchmark: {total_after} questions") + print(f"Removed: {total_before - total_after} questions") + print("Counts by ability type:") + for ability_type, count in counts_by_type.items(): + status = "KEPT" if ability_type in keep_ability_types else "REMOVED" + print(f" {ability_type}: {count} questions ({status})") + print(f"Saved to {output_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Stage 4c: Filter MCQ Benchmark by Ability Types") + parser.add_argument("--input_file", type=str, + default="./results/synthetic_training_data/rme_synthetic_easy_unfiltered.json", + help="Path to the input MCQ benchmark file") + parser.add_argument("--output_file", type=str, + default="./results/synthetic_training_data/rme_synthetic_easy.json", + help="Path to the output filtered benchmark file") + parser.add_argument("--keep_ability_types", type=str, nargs="+", + default=["MCQ_obs", "MCQ_cause"], + help="List of ability types to keep") + + args = parser.parse_args() + + filter_mcq_benchmark( + resolve_path(args.input_file), + resolve_path(args.output_file), + args.keep_ability_types + ) diff --git a/dataset/synthetic/generate_mcq_benchmark.py b/dataset/synthetic/generate_mcq_benchmark.py new file mode 100644 index 0000000..d468a58 --- /dev/null +++ b/dataset/synthetic/generate_mcq_benchmark.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +""" +Stage 4b: Generate Multiple Choice Questions (MCQs) for Synthetic Time Series Data + +This script: +1. Takes a QA benchmark dataset (Stage 4a output) as input +2. Converts each question to a multiple-choice format +3. Groups data by original source sample and question type to create distractors +4. Outputs an MCQ benchmark dataset in JSON format +""" + +import os +import json +import random +import argparse +import sys +from typing import Dict, List, Any +from collections import defaultdict + +# === PATH HANDLING === +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def resolve_path(path: str) -> str: + """Resolve a path relative to the script directory if not absolute.""" + if os.path.isabs(path): + return path + return os.path.join(SCRIPT_DIR, path) + + +# Constants +NUM_CHOICES = 4 +MCQ_INSTRUCTION = ("\nIn your answer, start by stating your chosen option " + "and then provide your explanation in a separate sentence.") + + +def load_json(file_path: str) -> List[Dict]: + """Load JSON data from file.""" + with open(file_path, 'r') as f: + return json.load(f) + + +def group_data_by_source_and_type( + data: List[Dict] +) -> Dict[str, Dict[str, List[Dict]]]: + """Group data by original source sample ID and question type.""" + grouped_data = defaultdict(lambda: defaultdict(list)) + for sample in data: + orig_id = sample.get("original_id", "") + question_type = sample.get("question_type", "") + if orig_id and question_type: + grouped_data[orig_id][question_type].append(sample) + return grouped_data + + +def get_distractor_pool( + grouped_data: Dict[str, Dict[str, List[Dict]]], + target_orig_id: str, question_type: str +) -> Dict[str, str]: + """Create a pool of distractor options from other source samples.""" + options_pool = {} + for orig_id, type_samples in grouped_data.items(): + if orig_id == target_orig_id: + continue + samples = type_samples.get(question_type, []) + if samples: + sample = random.choice(samples) + answer = sample.get("answer", "") + if answer: + options_pool[orig_id] = answer + return options_pool + + +def create_mcq_for_sample( + sample: Dict, grouped_data: Dict[str, Dict[str, List[Dict]]] +) -> Dict: + """Create a multiple-choice question version of a sample.""" + orig_id = sample.get("original_id", "") + question_type = sample.get("question_type", "") + correct_answer = sample.get("answer", "") + + mcq_sample = sample.copy() + + options_pool = get_distractor_pool(grouped_data, orig_id, question_type) + confusion_options = random.sample( + list(options_pool.values()), + min(NUM_CHOICES - 1, len(options_pool)) + ) + + while len(confusion_options) < NUM_CHOICES - 1: + if options_pool: + confusion_options.append(random.choice(list(options_pool.values()))) + else: + confusion_options.append("No clear pattern observed") + + all_options = [correct_answer] + confusion_options + random.shuffle(all_options) + correct_index = all_options.index(correct_answer) + + options_str = json.dumps(all_options) + mcq_sample["question"] = ( + f"{sample['question']} Choose from: {options_str}{MCQ_INSTRUCTION}" + ) + mcq_sample["options"] = all_options + mcq_sample["correct_index"] = correct_index + + if "time_series" in mcq_sample: + mcq_sample["timeseries"] = mcq_sample.pop("time_series") + + return mcq_sample + + +def main(qa_benchmark_path: str, output_path: str): + """Generate MCQ benchmark from QA benchmark data.""" + print(f"Loading data from {qa_benchmark_path}...") + qa_data = load_json(qa_benchmark_path) + + grouped_data = group_data_by_source_and_type(qa_data) + + print("Generating MCQ questions...") + mcq_data = [] + for sample in qa_data: + mcq_sample = create_mcq_for_sample(sample, grouped_data) + mcq_data.append(mcq_sample) + + print(f"Generated {len(mcq_data)} MCQ questions") + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w') as f: + json.dump(mcq_data, f, indent=2) + print(f"Saved to {output_path}") + + +if __name__ == "__main__": + random.seed(42) + + parser = argparse.ArgumentParser(description="Stage 4b: Generate MCQ Benchmark") + parser.add_argument("--qa_benchmark_path", type=str, + default="./results/synthetic_training_data/qa_synthetic_base.json", + help="Path to QA benchmark file from Stage 4a") + parser.add_argument("--output_path", type=str, + default="./results/synthetic_training_data/rme_synthetic_easy_unfiltered.json", + help="Path to output file") + + args = parser.parse_args() + + main( + resolve_path(args.qa_benchmark_path), + resolve_path(args.output_path) + ) diff --git a/dataset/synthetic/generate_qa_benchmark.py b/dataset/synthetic/generate_qa_benchmark.py new file mode 100644 index 0000000..c8858c6 --- /dev/null +++ b/dataset/synthetic/generate_qa_benchmark.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 +""" +Stage 4a: Generate QA Benchmark from Synthetic Time Series Data + +This script: +1. Loads synthetic time series data from data_ts.json (Stage 3 output) +2. Loads the original training data to retrieve anomaly descriptions +3. Calls Claude (via AWS Bedrock) to diversify each anomaly description into + multiple observation, root cause, and corrective action variations +4. Generates QA benchmark entries by sampling from the diversified answers +5. No external metadata files are required beyond the original training data +""" + +import os +import re +import json +import random +import argparse +import numpy as np +from typing import Dict, List, Any, Optional +import boto3 +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception +from botocore.exceptions import ClientError, ReadTimeoutError, ConnectTimeoutError + +# === PATH HANDLING === +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# === CONFIGURATION === +MODEL_ID = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" +MAX_TOKENS = 2048 +# ====================== + +TS_COLUMNS = ["Acceleration", "Velocity", "Temperature"] + + +def resolve_path(path: str) -> str: + """Resolve a path relative to the script directory if not absolute.""" + if os.path.isabs(path): + return path + return os.path.join(SCRIPT_DIR, path) + + +def load_json_data(file_path: str) -> Any: + """Load JSON data from file.""" + with open(file_path, 'r') as f: + return json.load(f) + + +def should_retry(exc): + if isinstance(exc, ClientError) and exc.response.get("Error", {}).get("Code") == "ThrottlingException": + return True + if isinstance(exc, (ReadTimeoutError, ConnectTimeoutError)): + return True + return False + + +@retry( + retry=retry_if_exception(should_retry), + stop=stop_after_attempt(20), + wait=wait_exponential(multiplier=1, min=2, max=10) +) +def invoke_claude(client, model_id, messages, system_prompt): + payload = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": MAX_TOKENS, + "temperature": 0.7, + "system": system_prompt, + "messages": messages + } + resp = client.invoke_model(body=json.dumps(payload), modelId=model_id) + return json.loads(resp['body'].read()) + + +def diversify_answers( + client, model_id: str, original_answer: str, num_variations: int = 5 +) -> Dict[str, List[str]]: + """Use Claude to generate diverse answer variations for all 3 question types.""" + system_prompt = ( + "You are an industrial machinery expert specializing in vibration analysis " + "and predictive maintenance. Respond ONLY with valid JSON, no markdown." + ) + prompt = f"""Given this anomaly observation from industrial vibration and temperature sensors: +"{original_answer}" + +Generate {num_variations} diverse textual variations for each of the following categories. + +1. OBSERVATIONS: Rephrase the anomaly pattern description in {num_variations} different ways. + Each should describe the same core pattern but with different wording and emphasis. + Keep descriptions concise (1-2 sentences). + +2. ROOT CAUSES: Provide {num_variations} plausible root causes that could lead to this + anomaly pattern in industrial rotating machinery with vibration and temperature sensors. + Each should be a concise statement (1-2 sentences). + +3. CORRECTIVE ACTIONS: Provide {num_variations} appropriate corrective actions to address + the anomaly. Each should be a concise recommendation (1-2 sentences). + +Return ONLY valid JSON with this exact structure: +{{ + "observations": ["...", "...", ...], + "root_causes": ["...", "...", ...], + "corrective_actions": ["...", "...", ...] +}}""" + + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + response = invoke_claude(client, model_id, messages, system_prompt) + response_text = response['content'][0]['text'] + + try: + json_match = re.search(r'\{[\s\S]*\}', response_text) + if json_match: + result = json.loads(json_match.group()) + # Validate structure + if all(k in result for k in ("observations", "root_causes", "corrective_actions")): + return result + except (json.JSONDecodeError, AttributeError): + pass + + # Fallback if parsing fails + print(f"Warning: Failed to parse LLM response, using fallback for: {original_answer[:60]}...") + return { + "observations": [original_answer], + "root_causes": [f"Degraded component condition leading to: {original_answer}"], + "corrective_actions": [f"Inspect and service the affected component"] + } + + +def median_absolute_deviation(data): + """Calculate the Median Absolute Deviation (MAD) of a dataset.""" + med = np.median(data) + return np.median(np.abs(np.array(data) - med)) + + +def standardize_timeseries(timeseries: List[List[float]]) -> List[List[float]]: + """ + Standardize time series data using median and MAD. + Formula: (x - median) / (1.4826 * MAD) + """ + standardized_ts = [] + for ts in timeseries: + ts_array = np.array(ts) + med = np.median(ts_array) + mad = median_absolute_deviation(ts_array) + if mad == 0: + mad = 1.0 + std_ts = (ts_array - med) / (1.4826 * mad) + standardized_ts.append(std_ts.tolist()) + return standardized_ts + + +def get_question_prompt(question_type: str) -> str: + """Get the appropriate prompt for the question type.""" + prompts = { + "what_happened": "What is the key anomalous pattern observed in these time series?", + "how_happened": "What is the most likely cause of the anomalous pattern in these time series?", + "suggested_fix": "What is the best corrective action for the event implied by the anomalous pattern in these time series?" + } + return prompts.get(question_type, "") + + +def generate_question_template(timeseries: List[List[float]], question_type: str) -> str: + """Generate the complete question template with dynamic length information.""" + lengths = [len(ts) for ts in timeseries] + template = ( + "You are a time series analysis expert. In a sensor monitoring system, the vibration " + "(measured in velocity and acceleration) and temperature of machines are collected for monitoring. " + "The time series data has been standardized using median and MAD (Median Absolute Deviation).\n" + ) + for col, length in zip(TS_COLUMNS, lengths): + template += f"\"{col}\" is a standardized time series with length of {length}: \n" + template += ( + "Please analyze the time series features and answer the following question: " + f"{get_question_prompt(question_type)}" + ) + return template + + +# Mapping from question type to diversified answer key +QUESTION_TYPE_TO_ANSWER_KEY = { + "what_happened": "observations", + "how_happened": "root_causes", + "suggested_fix": "corrective_actions" +} + +QUESTION_TYPE_TO_ABILITY = { + "what_happened": "MCQ_obs", + "how_happened": "MCQ_cause", + "suggested_fix": "MCQ_fix" +} + + +def main(data_ts_path: str, dataset_path: str, output_path: str, + region: str, num_variations: int): + """Generate QA benchmark using LLM-diversified answers.""" + print(f"Loading synthetic data from {data_ts_path}...") + data_ts = load_json_data(data_ts_path) + + print(f"Loading original training data from {dataset_path}...") + training_data = load_json_data(dataset_path) + + # Build lookup: original sample ID -> original answer (anomaly description) + original_answers = {} + for sample in training_data: + if sample.get('question_type') == 'what_happened': + original_answers[sample['id']] = sample['answer'] + + # Find unique original IDs in synthetic data + unique_originals = sorted(set( + entry['original_id'] for entry in data_ts if entry.get('original_id') + )) + print(f"Found {len(unique_originals)} unique source samples to diversify") + + # Call Claude to diversify answers for each unique source + client = boto3.client('bedrock-runtime', region_name=region) + diversified = {} + + for i, orig_id in enumerate(unique_originals): + original_answer = original_answers.get(orig_id) + if not original_answer: + print(f"Warning: No original answer found for source {orig_id}, skipping") + continue + + print(f"Diversifying answers for source {orig_id} ({i+1}/{len(unique_originals)})...") + result = diversify_answers(client, MODEL_ID, original_answer, num_variations) + diversified[orig_id] = result + + # Generate QA benchmark entries + benchmark_data = [] + for entry in data_ts: + ts_id = entry['id'] + raw_ts = entry.get('timeseries', []) + orig_id = entry.get('original_id') + + if not raw_ts or not orig_id or orig_id not in diversified: + continue + + std_ts = standardize_timeseries(raw_ts) + answers = diversified[orig_id] + + for q_type in ["what_happened", "how_happened", "suggested_fix"]: + answer_key = QUESTION_TYPE_TO_ANSWER_KEY[q_type] + answer_pool = answers.get(answer_key, []) + answer = random.choice(answer_pool) if answer_pool else "" + + sample = { + "id": ts_id, + "timeseries": std_ts, + "cols": TS_COLUMNS, + "question": generate_question_template(std_ts, q_type), + "question_type": q_type, + "answer": answer, + "attributes": [answer] if answer else [], + "ability_types": [QUESTION_TYPE_TO_ABILITY[q_type]], + "original_id": orig_id + } + benchmark_data.append(sample) + + print(f"Generated {len(benchmark_data)} benchmark samples") + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w') as f: + json.dump(benchmark_data, f, indent=2) + print(f"Saved to {output_path}") + + # Save diversified answers cache for reproducibility + cache_path = os.path.join(os.path.dirname(output_path), "diversified_answers.json") + with open(cache_path, 'w') as f: + json.dump(diversified, f, indent=2) + print(f"Saved diversified answers cache to {cache_path}") + + +if __name__ == "__main__": + random.seed(42) + + parser = argparse.ArgumentParser(description="Stage 4a: Generate QA Benchmark") + parser.add_argument("--data_ts_path", type=str, + default="./results/synthetic_training_data/data_ts.json", + help="Path to data_ts.json file from Stage 3") + parser.add_argument("--dataset_path", type=str, + default="./sample_data/qa_benchmark_base_train.json", + help="Path to original training dataset") + parser.add_argument("--output_path", type=str, + default="./results/synthetic_training_data/qa_synthetic_base.json", + help="Path to output file") + parser.add_argument("--region", type=str, default="us-west-2", + help="AWS region for Bedrock") + parser.add_argument("--num_variations", type=int, default=10, + help="Number of answer variations to generate per source") + + args = parser.parse_args() + + main( + resolve_path(args.data_ts_path), + resolve_path(args.dataset_path), + resolve_path(args.output_path), + args.region, + args.num_variations + ) diff --git a/dataset/synthetic/generate_synthetic_dataset.py b/dataset/synthetic/generate_synthetic_dataset.py new file mode 100644 index 0000000..8cbfabe --- /dev/null +++ b/dataset/synthetic/generate_synthetic_dataset.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +""" +Stage 3: Generate Synthetic Dataset at Scale + +This script: +1. Loads all stochastic generation functions from Stage 2 results +2. Extracts the sample ID and original timeseries for each function +3. Generates multiple synthetic time series examples per sample +4. Saves the result as a structured JSON file for downstream processing +""" + +import os +import sys +import json +import glob +import argparse +import importlib.util +import numpy as np +from typing import Dict, List, Tuple, Any, Optional, Callable +import traceback + +# === PATH HANDLING === +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def resolve_path(path: str) -> str: + """Resolve a path relative to the script directory if not absolute.""" + if os.path.isabs(path): + return path + return os.path.join(SCRIPT_DIR, path) + + +def extract_id_and_timeseries( + sample_dir: str, dataset_path: str +) -> Tuple[str, Optional[List[List[float]]]]: + """Extract the sample ID and original timeseries from the directory name and dataset.""" + sample_id = os.path.basename(sample_dir).replace('Sample_', '') + try: + with open(dataset_path, 'r') as f: + dataset = json.load(f) + for sample in dataset: + if sample['id'] == sample_id: + timeseries = sample.get('timeseries') + return sample_id, timeseries + except Exception as e: + print(f"Error loading dataset: {e}") + return sample_id, None + + +def load_generation_function(function_path: str) -> Optional[Callable]: + """Dynamically load a synthetic data generation function from a Python file.""" + try: + module_name = f"synthetic_function_{os.path.basename(function_path).replace('.', '_')}" + spec = importlib.util.spec_from_file_location(module_name, function_path) + if spec is None or spec.loader is None: + print(f"Error: Failed to create module spec from {function_path}") + return None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + if hasattr(module, 'generate_synthetic_anomaly'): + return module.generate_synthetic_anomaly + + for attr_name in dir(module): + if attr_name.startswith('generate_'): + attr = getattr(module, attr_name) + if callable(attr): + print(f"Found alternative generation function: {attr_name}") + return attr + + print(f"Error: No suitable generation function found in {function_path}") + return None + except Exception as e: + print(f"Error loading function from {function_path}: {e}") + traceback.print_exc() + return None + + +def generate_synthetic_samples( + function: Callable, + n_samples: int = 100, + count: int = 100, + base_seed: int = 42 +) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]: + """Generate multiple synthetic time series samples using the provided function.""" + samples = [] + for i in range(count): + try: + seed = base_seed + i + try: + result = function(n_samples=n_samples, seed=seed) + except TypeError: + try: + result = function(n_samples) + except TypeError: + try: + result = function(seed=seed) + except TypeError: + result = function() + + if result and isinstance(result, tuple) and len(result) == 3: + samples.append(result) + if (i + 1) % 10 == 0: + print(f"Generated {i+1}/{count} samples") + else: + print(f"Warning: Function returned invalid data for sample {i+1}") + except Exception as e: + print(f"Error generating sample {i+1}: {e}") + return samples + + +def process_sample_dir( + sample_dir: str, dataset_path: str, samples_per_source: int = 100 +) -> Optional[Dict[str, Any]]: + """Process a sample directory to extract ID and generate synthetic data.""" + sample_id, timeseries = extract_id_and_timeseries(sample_dir, dataset_path) + if not timeseries: + print(f"Warning: No timeseries data found for sample {sample_id}") + return None + + n_samples = len(timeseries[0]) if timeseries and len(timeseries) > 0 else 100 + print(f"Using {n_samples} time points from original timeseries for sample {sample_id}") + + function_paths = glob.glob(os.path.join(sample_dir, "stochastic_function*.py")) + if not function_paths: + print(f"Warning: No function files found for sample {sample_id}") + return None + + if len(function_paths) > 1: + print(f"Warning: Multiple function files found for sample {sample_id}: " + f"{[os.path.basename(p) for p in function_paths]}") + print(f"Using the first one: {os.path.basename(function_paths[0])}") + + function_path = function_paths[0] + function = load_generation_function(function_path) + if not function: + print(f"Warning: Could not load generation function for sample {sample_id}") + return None + + print(f"Generating {samples_per_source} samples for source {sample_id}...") + samples = generate_synthetic_samples(function, n_samples, samples_per_source) + if not samples: + print(f"Warning: No samples generated for source {sample_id}") + return None + + print(f"Successfully generated {len(samples)} samples for source {sample_id}") + return { + 'id': sample_id, + 'samples': samples + } + + +def create_data_ts_entries(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Create a list of synthetic time series entries from generation results.""" + data_ts = [] + sample_counter = 0 + + for result in results: + samples = result['samples'] + + for i, sample in enumerate(samples): + acceleration, velocity, temperature = sample + synthetic_id = f"synthetic_{sample_counter:06d}" + sample_counter += 1 + + entry = { + 'id': synthetic_id, + 'timeseries': [ + acceleration.tolist(), + velocity.tolist(), + temperature.tolist() + ], + 'cols': ['Acceleration', 'Velocity', 'Temperature'], + 'original_id': result['id'], + 'synthetic': True + } + data_ts.append(entry) + + return data_ts + + +def generate_synthetic_dataset( + stochastic_results_dir: str, + dataset_path: str, + output_dir: str, + samples_per_source: int = 100 +): + """Generate a synthetic dataset by processing all sample directories.""" + os.makedirs(output_dir, exist_ok=True) + + sample_dirs = glob.glob(os.path.join(stochastic_results_dir, "Sample_*")) + print(f"Found {len(sample_dirs)} sample directories") + + results = [] + for sample_dir in sample_dirs: + result = process_sample_dir(sample_dir, dataset_path, samples_per_source) + if result: + results.append(result) + + data_ts = create_data_ts_entries(results) + + data_ts_path = os.path.join(output_dir, "data_ts.json") + with open(data_ts_path, 'w') as f: + json.dump(data_ts, f, indent=2) + print(f"Saved {len(data_ts)} synthetic samples to {data_ts_path}") + + # Summary grouped by original source sample + source_counts = {} + for entry in data_ts: + orig_id = entry['original_id'] + source_counts[orig_id] = source_counts.get(orig_id, 0) + 1 + + summary = { + 'total_samples': len(data_ts), + 'sources': source_counts + } + summary_path = os.path.join(output_dir, "dataset_summary.json") + with open(summary_path, 'w') as f: + json.dump(summary, f, indent=2) + + print("Dataset summary:") + for orig_id, count in source_counts.items(): + print(f" Source {orig_id}: {count} samples") + print(f"Total: {len(data_ts)} samples") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Stage 3: Generate Synthetic Dataset at Scale") + parser.add_argument("--stochastic_results_dir", type=str, + default="./results/stochastic_results", + help="Directory containing stochastic results from Stage 2") + parser.add_argument("--dataset_path", type=str, + default="./sample_data/qa_benchmark_base_train.json", + help="Path to the original dataset JSON file") + parser.add_argument("--output_dir", type=str, + default="./results/synthetic_training_data", + help="Directory to save outputs") + parser.add_argument("--samples_per_source", type=int, default=100, + help="Number of synthetic samples to generate per source") + + args = parser.parse_args() + + generate_synthetic_dataset( + resolve_path(args.stochastic_results_dir), + resolve_path(args.dataset_path), + resolve_path(args.output_dir), + args.samples_per_source + ) diff --git a/dataset/synthetic/iterative_ts_generation.py b/dataset/synthetic/iterative_ts_generation.py new file mode 100644 index 0000000..3a20b3e --- /dev/null +++ b/dataset/synthetic/iterative_ts_generation.py @@ -0,0 +1,512 @@ +#!/usr/bin/env python3 +""" +Stage 1: Iterative Time Series Generation With Characteristics + +This script is Stage 1 of the synthetic data generation pipeline. It implements +an iterative procedure to generate synthetic time series data using Claude: +1. Load real data samples from qa_benchmark_base_train.json by IDs +2. Extract time series values and "what_happened" characteristics +3. Generate Python code with Claude for synthetic data generation +4. Execute the code and compare with original data +5. Provide feedback to Claude for improvement +6. Repeat for a specified number of iterations +7. Support parallel processing for multiple IDs +""" + +import os +import re +import json +import base64 +import argparse +import time +import numpy as np +import matplotlib.pyplot as plt +from typing import Dict, List, Tuple, Optional +import boto3 +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception +from botocore.exceptions import ClientError, ReadTimeoutError, ConnectTimeoutError +import concurrent.futures +from multiprocessing import cpu_count + +# === PATH HANDLING === +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def resolve_path(path: str) -> str: + """Resolve a path relative to the script directory if not absolute.""" + if os.path.isabs(path): + return path + return os.path.join(SCRIPT_DIR, path) + + +# === CONFIGURATION === +MODEL_ID = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" +MAX_TOKENS = 10240 +THINKING_BUDGET = 4096 +# ====================== + + +def load_dataset(dataset_path: str) -> List[Dict]: + """Load the benchmark dataset.""" + with open(dataset_path, 'r') as f: + return json.load(f) + + +def filter_samples_by_ids(dataset: List[Dict], ids: List[str]) -> List[Dict]: + """Filter dataset samples by specified IDs.""" + if not ids: + return dataset + return [sample for sample in dataset if sample['id'] in ids] + + +def filter_what_happened_samples(dataset: List[Dict]) -> List[Dict]: + """Filter dataset to keep only 'what_happened' question types.""" + return [sample for sample in dataset if sample.get('question_type') == 'what_happened'] + + +def should_retry(exc): + """Exception checker for retries.""" + if isinstance(exc, ClientError) and exc.response.get("Error", {}).get("Code") == "ThrottlingException": + return True + if isinstance(exc, (ReadTimeoutError, ConnectTimeoutError)) or "ReadTimeoutError" in str(exc) or "ConnectTimeoutError" in str(exc): + print(f"Encountered timeout error: {str(exc)}. Retrying...") + return True + return False + + +@retry( + retry=retry_if_exception(should_retry), + stop=stop_after_attempt(20), + wait=wait_exponential(multiplier=1, min=2, max=10) +) +def invoke_claude(client, model_id, messages, system_prompt, enable_thinking=False): + """Invoke Claude with the specified messages and system prompt.""" + payload = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": MAX_TOKENS, + "temperature": 1.0 if enable_thinking else 0.2, + "system": system_prompt, + "messages": messages + } + if enable_thinking: + payload["thinking"] = {"type": "enabled", "budget_tokens": THINKING_BUDGET} + resp = client.invoke_model(body=json.dumps(payload), modelId=model_id) + return json.loads(resp['body'].read()) + + +def parse_response(resp_body): + """Parse the response from Claude, extracting both thinking and text content.""" + thought_chunks, text_chunks = [], [] + for chunk in resp_body.get("content", []): + if chunk.get("type") == "thinking": + thought_chunks.append(chunk.get("thinking", "").strip()) + elif chunk.get("type") == "text": + text_chunks.append(chunk.get("text", "").strip()) + thought = "\n".join(thought_chunks) + response_text = "".join(text_chunks) + return thought, response_text + + +def extract_python_code(response_text: str) -> Tuple[str, List[str]]: + """Extract Python code blocks from Claude's response.""" + code_pattern = r"```(?:python)?\s*([\s\S]*?)```" + code_blocks = re.findall(code_pattern, response_text) + analysis_text = re.sub(code_pattern, "", response_text).strip() + return analysis_text, code_blocks + + +def generate_original_image(original_data: np.ndarray, feature_names: List[str], + output_dir: str, sample_id: str) -> str: + """Generate an image showing only the original data.""" + fig, axes = plt.subplots(3, 1, figsize=(10, 8)) + sample_dir = os.path.join(output_dir, f"Sample_{sample_id}") + os.makedirs(sample_dir, exist_ok=True) + for i in range(3): + ax = axes[i] + ax.plot(np.arange(original_data.shape[1]), original_data[i], 'b-', linewidth=2) + ax.set_title(feature_names[i], fontsize=12) + ax.set_ylabel('Value', fontsize=10) + ax.grid(True, alpha=0.3) + axes[-1].set_xlabel('Time Step', fontsize=11) + plt.suptitle(f'Original Time Series (Sample ID: {sample_id})', fontsize=14, fontweight='bold') + plt.tight_layout() + img_path = os.path.join(sample_dir, "original.png") + plt.savefig(img_path) + plt.close(fig) + return img_path + + +def generate_comparison_image(original_data: np.ndarray, synthetic_data: np.ndarray, + feature_names: List[str], output_dir: str, sample_id: str, + iteration: int) -> str: + """Generate a comparison image showing original and synthetic data.""" + fig, axes = plt.subplots(3, 2, figsize=(12, 10), sharey='row', sharex=True) + sample_dir = os.path.join(output_dir, f"Sample_{sample_id}") + os.makedirs(sample_dir, exist_ok=True) + axes[0, 0].set_title('Original Data', fontsize=12, fontweight='bold') + axes[0, 1].set_title('Synthetic Data', fontsize=12, fontweight='bold') + for i in range(3): + axes[i, 0].plot(np.arange(original_data.shape[1]), original_data[i], 'b-', linewidth=1.5) + axes[i, 0].set_ylabel(feature_names[i], fontsize=10) + axes[i, 0].grid(True, alpha=0.3) + synth = synthetic_data[i] + if len(synth) > original_data.shape[1]: + synth = synth[:original_data.shape[1]] + elif len(synth) < original_data.shape[1]: + synth = np.pad(synth, (0, original_data.shape[1] - len(synth))) + axes[i, 1].plot(np.arange(len(synth)), synth, 'r-', linewidth=1.5) + axes[i, 1].grid(True, alpha=0.3) + axes[2, 0].set_xlabel('Time Step') + axes[2, 1].set_xlabel('Time Step') + plt.suptitle(f'Original vs Synthetic (Sample ID: {sample_id}, Iteration {iteration})', + fontsize=14, fontweight='bold', y=0.98) + plt.tight_layout() + plt.subplots_adjust(top=0.92) + img_path = os.path.join(sample_dir, f"iteration_{iteration}.png") + plt.savefig(img_path) + plt.close(fig) + return img_path + + +def execute_function_code(function_code: str, n_samples: int = 100, seed: int = 42): + """Execute the function code directly and return the generated data.""" + try: + globals_dict = { + 'np': np, + 'plt': plt, + '__builtins__': __builtins__ + } + try: + import scipy + import scipy.signal as signal + import scipy.stats + globals_dict['scipy'] = scipy + globals_dict['signal'] = signal + globals_dict['stats'] = scipy.stats + except ImportError: + print("Warning: SciPy not available") + import_lines = [] + code_lines = [] + for line in function_code.splitlines(): + if line.strip().startswith(('import ', 'from ')): + import_lines.append(line) + else: + code_lines.append(line) + if import_lines: + import_code = '\n'.join(import_lines) + try: + exec(import_code, globals_dict) + except Exception as e: + print(f"Warning: Error executing imports: {e}") + function_only_code = '\n'.join(code_lines) + exec(function_only_code, globals_dict) + function_names = [ + 'generate_synthetic_anomaly', + 'generate_synthetic_data', + 'generate_time_series_data', + 'generate_ts_data' + ] + for function_name in function_names: + if function_name in globals_dict: + result = globals_dict[function_name](n_samples=n_samples, seed=seed) + return result + print("Error: No suitable generation function found") + return None + except Exception as e: + print(f"Error executing function: {e}") + return None + + +def create_initial_prompt(sample_characteristics: str) -> str: + """Create the initial prompt for Claude with the sample characteristics.""" + prompt = f""" +# Industrial Time Series Generation + +I'm analyzing these time series data that show three metrics from an industrial sensor monitoring system: +1. Acceleration vibration +2. Velocity vibration +3. Temperature + +The KEY ANOMALY PATTERN observed in these time series is: +"{sample_characteristics}" + +Based on this characteristic and the (normalized) time series visualization shown, please create a generative model that: +1. Reproduces the baseline pattern of the multivariate time series +2. Accurately replicates the specific anomaly pattern described above +3. Maintains the synchronous changes in acceleration and velocity (if any) +4. Properly captures ambient temperature changes over the day (24 time points) (if any) +5. Takes into consideration the consistent decreases in velocity and acceleration when the system stops working, and increases when it starts working again +6. Incorporates multiple layers of deterministic and/or random processes drawn from reasonable distributions to model various patterns (fluctuations, stops/starts, rises/decreases, sporadic spikes, etc.) + +The synthetic data should closely match the statistical properties and patterns of the original time series. +Please use only NumPy and basic SciPy functions in your implementation. + +Please implement a Python function with this signature: + +```python +import numpy as np +import scipy.signal as signal + +def generate_synthetic_anomaly(n_samples=100, seed=None): + if seed is not None: + np.random.seed(seed) + + # Your implementation here + + # Return all three time series + return acceleration, velocity, temperature +``` + +Analyze the image of the original time series carefully and develop a model that generates patterns as close as possible to the real data. +""" + return prompt + + +def create_improvement_prompt(sample_characteristics: str, function_code: str, + error_message: Optional[str] = None) -> str: + """Create a prompt for improving the generated code based on the comparison.""" + if error_message: + prompt = f""" +I tried to execute your time series generation function, but encountered this error: + +{error_message} + +Please fix the issues and provide an improved version of your function that correctly generates synthetic data. + +Remember, the KEY ANOMALY PATTERN we're trying to reproduce is: +"{sample_characteristics}" + +Here's your original code: + +```python +{function_code} +``` + +Make sure your revised function: +1. Uses correct import statements +2. Has proper error handling +3. Returns exactly three arrays in this order: (acceleration, velocity, temperature) +4. Maintains the statistical properties of each time series +5. Correctly implements the anomaly pattern described above +""" + else: + prompt = f""" +I've compared your generated synthetic data with the original time series in the attached image. + +Please carefully examine the visual comparison and refine your model to better match the real data patterns. The KEY ANOMALY PATTERN we need to capture accurately is: +"{sample_characteristics}" + +Please focus on improving: +1. The timing, magnitude, and shape of the specific anomaly pattern +2. The synchronous relationship between acceleration and velocity metrics +3. The proper representation of temperature patterns and their relationship to vibration metrics +4. The baseline behavior of the time series (including stops/starts of the system) +5. The statistical properties (variance, range, distribution) of each time series +6. The balance between deterministic patterns and stochastic variations + +Here's your current code: + +```python +{function_code} +``` + +Based on the visual comparison, provide an improved version that generates synthetic data more closely matching the patterns, correlations, and anomalies seen in the original data. +""" + return prompt + + +def process_sample(args): + """Process a single sample through the iterative generation process.""" + sample, output_dir, iterations, enable_thinking, region = args + sample_id = sample['id'] + ts_data = np.array(sample['timeseries']) + cols = sample['cols'] + characteristic = sample['answer'] + print(f"Processing sample {sample_id}...") + sample_dir = os.path.join(output_dir, f"Sample_{sample_id}") + os.makedirs(sample_dir, exist_ok=True) + client = boto3.client('bedrock-runtime', region_name=region) + + initial_system_prompt = """ +You are a time series expert specialized in industrial equipment sensor data analysis and modeling. + +Your task is to develop a generative model that can produce synthetic multivariate time series data that matches a specific anomaly pattern described to you. +""" + improvement_system_prompt = """ +You are a time series expert specialized in industrial equipment sensor data analysis and modeling. +Your task is to improve your previously generated synthetic data model based on the comparison with the original data. +Make sure all imports are correctly specified at the top of your code. +Wrap all code in ```python code blocks. +""" + + current_code = None + error_message = None + sample_start_time = time.time() + img_path = None + + for iter_num in range(1, iterations + 1): + print(f" Iteration {iter_num}/{iterations}") + if iter_num == 1: + prompt = create_initial_prompt(characteristic) + system_prompt = initial_system_prompt + img_path = generate_original_image(ts_data, cols, output_dir, sample_id) + with open(img_path, "rb") as img_f: + img_b64 = base64.b64encode(img_f.read()).decode("utf8") + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": img_b64}} + ] + }] + else: + prompt = create_improvement_prompt(characteristic, current_code, error_message) + if error_message: + print(f"\n=== ERROR MESSAGE SENT TO CLAUDE IN ITERATION {iter_num} ===") + print(f"{error_message}") + print("================================================\n") + system_prompt = improvement_system_prompt + error_message = None + with open(img_path, "rb") as img_f: + img_b64 = base64.b64encode(img_f.read()).decode("utf8") + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": img_b64}} + ] + }] + + resp = invoke_claude(client, MODEL_ID, messages, system_prompt, enable_thinking) + if enable_thinking: + thought, response_text = parse_response(resp) + if thought: + with open(os.path.join(sample_dir, f"thinking_{iter_num}.txt"), 'w') as f: + f.write(thought) + else: + response_text = resp['content'][0]['text'] + + _, code_blocks = extract_python_code(response_text) + with open(os.path.join(sample_dir, f"response_{iter_num}.txt"), 'w') as f: + f.write(response_text) + + function_code = None + if code_blocks: + for block in code_blocks: + if "def generate_synthetic" in block: + function_code = block + break + if not function_code: + function_code = code_blocks[0] + + if not function_code: + print(f" Warning: No function code found in iteration {iter_num}") + continue + + with open(os.path.join(sample_dir, f"function_{iter_num}.py"), 'w') as f: + f.write(function_code) + + error_message = None + try: + synthetic_data = execute_function_code(function_code, n_samples=ts_data.shape[1], seed=42) + if synthetic_data is None or len(synthetic_data) != 3: + error_message = f"Invalid output: expected tuple of 3 arrays, got {type(synthetic_data)}" + print(f" Error: {error_message}") + else: + img_path = generate_comparison_image( + ts_data, synthetic_data, cols, output_dir, sample_id, iter_num + ) + print(f" Saved comparison image to {img_path}") + with open(os.path.join(sample_dir, f"synthetic_data_{iter_num}.json"), 'w') as f: + json.dump({ + "synthetic_data": [arr.tolist() for arr in synthetic_data], + "cols": cols + }, f, indent=2) + except Exception as e: + error_message = str(e) + print(f" Error executing function: {error_message}") + + current_code = function_code + + sample_end_time = time.time() + sample_execution_time = sample_end_time - sample_start_time + print(f"Completed sample {sample_id} in {sample_execution_time:.2f}s ({sample_execution_time/60:.2f}min)") + return { + 'id': sample_id, + 'characteristic': characteristic, + 'execution_time': sample_execution_time, + 'iterations': iterations + } + + +def run_iterative_generation(dataset_path: str, output_dir: str, iterations: int = 3, + sample_ids: Optional[List[str]] = None, enable_thinking: bool = False, + max_workers: Optional[int] = None, region: str = 'us-west-2'): + """Run the iterative generation process for the specified samples.""" + start_time = time.time() + os.makedirs(output_dir, exist_ok=True) + print(f"Loading dataset from {dataset_path}...") + dataset = load_dataset(dataset_path) + print(f"Loaded dataset with {len(dataset)} samples") + dataset = filter_what_happened_samples(dataset) + print(f"Filtered to {len(dataset)} 'what_happened' samples") + if sample_ids: + dataset = filter_samples_by_ids(dataset, sample_ids) + print(f"Further filtered to {len(dataset)} samples based on provided IDs") + if not dataset: + print("Error: No samples to process after filtering.") + return + if max_workers is None: + max_workers = max(1, cpu_count() // 2) + print(f"Processing {len(dataset)} samples with up to {max_workers} workers...") + results = [] + args_list = [(sample, output_dir, iterations, enable_thinking, region) for sample in dataset] + with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: + for result in executor.map(process_sample, args_list): + results.append(result) + summary = { + 'total_samples': len(dataset), + 'iterations_per_sample': iterations, + 'samples': results + } + summary_path = os.path.join(output_dir, "generation_summary.json") + with open(summary_path, 'w') as f: + json.dump(summary, f, indent=2) + end_time = time.time() + execution_time = end_time - start_time + print(f"All samples processed in {execution_time:.2f}s ({execution_time/60:.2f}min)") + + +if __name__ == "__main__": + script_start_time = time.time() + parser = argparse.ArgumentParser(description="Stage 1: Iterative Time Series Generation") + parser.add_argument("--dataset_path", type=str, + default="./sample_data/qa_benchmark_base_train.json", + help="Path to the benchmark dataset JSON file") + parser.add_argument("--output_dir", type=str, default="./results/iterative_results", + help="Directory to save results") + parser.add_argument("--iterations", type=int, default=3, + help="Number of improvement iterations") + parser.add_argument("--sample_ids", type=str, nargs="*", + help="Specific sample IDs to process") + parser.add_argument("--thinking", action="store_true", + help="Enable thinking mode for Claude") + parser.add_argument("--max_workers", type=int, default=None, + help="Maximum number of worker processes (default: half of CPU cores)") + parser.add_argument("--region", type=str, default="us-west-2", + help="AWS region for Bedrock (default: us-west-2)") + args = parser.parse_args() + + run_iterative_generation( + resolve_path(args.dataset_path), + resolve_path(args.output_dir), + args.iterations, + args.sample_ids, + args.thinking, + args.max_workers, + args.region + ) + + script_end_time = time.time() + total_execution_time = script_end_time - script_start_time + print(f"\nTotal execution time: {total_execution_time:.2f}s ({total_execution_time/60:.2f}min)") diff --git a/dataset/synthetic/sample_data/README.md b/dataset/synthetic/sample_data/README.md new file mode 100644 index 0000000..2183b61 --- /dev/null +++ b/dataset/synthetic/sample_data/README.md @@ -0,0 +1,40 @@ +# Sample Data for Synthetic Generation + +This directory contains the input data for the synthetic data generation pipeline. + +## Required Input File + +### `qa_benchmark_base_train.json` + +A JSON array of training samples, each with the following fields: + +```json +{ + "id": "sample_001", + "timeseries": [[...], [...], [...]], + "cols": ["Acceleration", "Velocity", "Temperature"], + "question": "...", + "question_type": "what_happened", + "answer": "vibration amplitude increases gradually over time", + "attributes": ["vibration amplitude increases gradually over time"], + "ability_types": ["MCQ_obs"] +} +``` + +Fields: +- **`id`**: Unique sample identifier +- **`timeseries`**: List of 3 channels `[acceleration, velocity, temperature]`, each a list of floats +- **`cols`**: Column names (always `["Acceleration", "Velocity", "Temperature"]`) +- **`question`**: The question text with `` placeholders for time series +- **`question_type`**: One of `what_happened`, `how_happened`, `suggested_fix` +- **`answer`**: The correct answer text (anomaly description) +- **`attributes`**: List of ground-truth attributes +- **`ability_types`**: One of `["MCQ_obs"]`, `["MCQ_cause"]`, `["MCQ_fix"]` + +The synthetic generation pipeline (Stages 1-2) uses only `what_happened` samples from this file. In Stage 4a, the original `answer` fields are used as seeds for LLM-based diversification to generate varied observations, root causes, and corrective actions. + +## How to Prepare Your Data + +1. Prepare your real time series data with anomaly descriptions +2. Format as `qa_benchmark_base_train.json` following the schema above +3. Place the file in this directory (or specify the path via CLI arguments) diff --git a/dataset/synthetic/stochastic_ts_generation.py b/dataset/synthetic/stochastic_ts_generation.py new file mode 100644 index 0000000..cdcb1a9 --- /dev/null +++ b/dataset/synthetic/stochastic_ts_generation.py @@ -0,0 +1,391 @@ +#!/usr/bin/env python3 +""" +Stage 2: Stochastic Time Series Generation (Sampling-Based Refinement) + +This script is Stage 2 of the synthetic data generation pipeline. It builds on +the iterative generation results (Stage 1) by: +1. Loading the last iteration function code from Stage 1 +2. Asking Claude to simplify into sampling-based generators +3. Replacing hardcoded parameters with probabilistic distributions +4. Generating multiple samples from each model to validate diversity +5. Visualizing the diversity of generated samples +""" + +import os +import re +import json +import base64 +import argparse +import time +import numpy as np +import matplotlib.pyplot as plt +from typing import Dict, List, Tuple, Optional, Any +import boto3 +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception +from botocore.exceptions import ClientError, ReadTimeoutError, ConnectTimeoutError +import glob +from matplotlib.gridspec import GridSpec +import concurrent.futures +from multiprocessing import cpu_count + +# === PATH HANDLING === +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def resolve_path(path: str) -> str: + """Resolve a path relative to the script directory if not absolute.""" + if os.path.isabs(path): + return path + return os.path.join(SCRIPT_DIR, path) + + +# === CONFIGURATION === +MODEL_ID = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" +MAX_TOKENS = 4096 +# ====================== + + +def load_dataset(dataset_path: str) -> List[Dict]: + with open(dataset_path, 'r') as f: + return json.load(f) + + +def find_last_iteration_artifacts(sample_id: str, results_dir: str): + """Find the last iteration artifacts for a given sample ID from Stage 1.""" + sample_dir = os.path.join(results_dir, f"Sample_{sample_id}") + if not os.path.exists(sample_dir): + print(f"Warning: No previous results found for sample {sample_id}") + return None, None, None + function_files = sorted(glob.glob(os.path.join(sample_dir, "function_*.py"))) + if not function_files: + print(f"Warning: No function code found for sample {sample_id}") + return None, None, None + last_function = function_files[-1] + image_files = sorted(glob.glob(os.path.join(sample_dir, "iteration_*.png"))) + last_image = image_files[-1] if image_files else None + data_files = sorted(glob.glob(os.path.join(sample_dir, "synthetic_data_*.json"))) + sample_data = None + if data_files: + with open(data_files[-1], 'r') as f: + sample_data = json.load(f) + return last_function, last_image, sample_data + + +def load_function_code(function_path: str) -> str: + with open(function_path, 'r') as f: + return f.read() + + +def load_sample_info(dataset: List[Dict], sample_id: str) -> Optional[Dict]: + for sample in dataset: + if sample['id'] == sample_id: + return sample + return None + + +def should_retry(exc): + if isinstance(exc, ClientError) and exc.response.get("Error", {}).get("Code") == "ThrottlingException": + return True + if isinstance(exc, (ReadTimeoutError, ConnectTimeoutError)) or "ReadTimeoutError" in str(exc) or "ConnectTimeoutError" in str(exc): + print(f"Encountered timeout error: {str(exc)}. Retrying...") + return True + return False + + +@retry( + retry=retry_if_exception(should_retry), + stop=stop_after_attempt(20), + wait=wait_exponential(multiplier=1, min=2, max=10) +) +def invoke_claude(client, model_id, messages, system_prompt): + payload = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": MAX_TOKENS, + "temperature": 0.5, + "system": system_prompt, + "messages": messages + } + resp = client.invoke_model(body=json.dumps(payload), modelId=model_id) + return json.loads(resp['body'].read()) + + +def extract_python_code(response_text: str) -> Tuple[str, List[str]]: + code_pattern = r"```(?:python)?\s*([\s\S]*?)```" + code_blocks = re.findall(code_pattern, response_text) + analysis_text = re.sub(code_pattern, "", response_text).strip() + return analysis_text, code_blocks + + +def create_simplified_prompt(function_code: str, sample_characteristics: str) -> str: + """Create a prompt to encourage simplified, sampling-based generation.""" + prompt = f""" +# Simplified Industrial Time Series Generation + +I have a synthetic data generation function that currently produces time series data with the following key anomaly pattern: +"{sample_characteristics}" + +However, the current implementation is too complex and relies on hardcoded parameters. I need a simplified version that: + +1. Captures only the essential components of the pattern +2. Uses sampling instead of hardcoded values +3. Models machine operation periods as a hidden state from sampling rather than fixed time points +4. Is simpler and more generalizable than the current implementation for a much more diverse time series generation + +Here is the current implementation: + +```python +{function_code} +``` + +Please revise this code to create a more simplified, sampling-based generator that: +- Uses the EXACT function name 'generate_synthetic_anomaly' (this is required for compatibility with our system) +- Focuses on simplicity - the current implementation is overfitting to the original data +- Uses proper sampling for event (spike, rise, etc.) timing and magnitude rather than hardcoding specific time points +- Treats machine operations (on/off patterns) as hidden states that can be sampled +- Is concise, interpretable, and well-commented +- Preserves only the essential characteristics that fully characterize the anomaly pattern. + +NOTE: +1. For cases in `"both vibration and temperature rise sharply","a sudden parallel jump is observed in vibration and temperature","vibration and temperature increase abruptly at the same time"`, +Please properly model the jump sharply (almost vertically instead of in a sharp slope) and guarantee it is at the very end of the time series (155+). + +The function signature must be: +```python +def generate_synthetic_anomaly(n_samples=100, seed=None): + # Your code here + return acceleration, velocity, temperature +``` + +The goal is to create a simplified generator that captures the core pattern while enabling generation of diverse examples through sampling. +""" + return prompt + + +def execute_function_code(function_code: str, n_samples: int = 100, seed: Optional[int] = None): + """Execute the function code directly and return the generated data.""" + try: + globals_dict = { + 'np': np, + 'plt': plt, + '__builtins__': __builtins__ + } + try: + import scipy + import scipy.signal as signal + import scipy.stats + globals_dict['scipy'] = scipy + globals_dict['signal'] = signal + globals_dict['stats'] = scipy.stats + except ImportError: + print("Warning: SciPy not available") + import_lines = [] + code_lines = [] + for line in function_code.splitlines(): + if line.strip().startswith(('import ', 'from ')): + import_lines.append(line) + else: + code_lines.append(line) + if import_lines: + import_code = '\n'.join(import_lines) + try: + exec(import_code, globals_dict) + except Exception as e: + print(f"Warning: Error executing imports: {e}") + function_only_code = '\n'.join(code_lines) + exec(function_only_code, globals_dict) + function_names = [ + 'generate_synthetic_anomaly', + 'generate_synthetic_data', + 'generate_time_series_data', + 'generate_ts_data' + ] + for function_name in function_names: + if function_name in globals_dict: + kwargs = {'n_samples': n_samples} + if seed is not None: + kwargs['seed'] = seed + result = globals_dict[function_name](**kwargs) + return result + print("Error: No suitable generation function found") + return None + except Exception as e: + print(f"Error executing function: {e}") + return None + + +def generate_stochastic_model(sample_id: str, dataset_path: str, results_dir: str, + output_dir: str, num_claude_calls: int = 3, + region: str = 'us-west-2') -> Optional[Dict]: + """Generate a stochastic model for a given sample ID.""" + print(f"Processing sample {sample_id}...") + dataset = load_dataset(dataset_path) + sample_info = load_sample_info(dataset, sample_id) + if sample_info is None: + print(f"Error: Sample {sample_id} not found in dataset") + return None + sample_characteristic = sample_info['answer'] + function_path, image_path, _ = find_last_iteration_artifacts(sample_id, results_dir) + if function_path is None: + print(f"Error: No function code found for sample {sample_id}") + return None + function_code = load_function_code(function_path) + client = boto3.client('bedrock-runtime', region_name=region) + + system_prompt = """ +You are a time series expert specializing in industrial sensor data analysis and modeling. + +Your task is to improve the synthetic data generation code provided to you. + +Make sure all imports are correctly specified at the top of your code. +Wrap all code in ```python code blocks. +""" + + prompt = create_simplified_prompt(function_code, sample_characteristic) + img_b64 = None + if image_path: + with open(image_path, "rb") as img_f: + img_b64 = base64.b64encode(img_f.read()).decode("utf8") + sample_dir = os.path.join(output_dir, f"Sample_{sample_id}") + os.makedirs(sample_dir, exist_ok=True) + + stochastic_functions = [] + for call_index in range(num_claude_calls): + print(f"Invoking Claude (call {call_index+1}/{num_claude_calls})...") + if img_b64: + messages = [{"role": "user", "content": [ + {"type": "text", "text": prompt}, + {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": img_b64}} + ]}] + else: + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + + response = invoke_claude(client, MODEL_ID, messages, system_prompt) + response_text = response['content'][0]['text'] + with open(os.path.join(sample_dir, f"stochastic_response_{call_index+1}.txt"), 'w') as f: + f.write(response_text) + _, code_blocks = extract_python_code(response_text) + if not code_blocks: + print(f"Error: No code blocks found for call {call_index+1}") + continue + stochastic_function = None + for block in code_blocks: + if "def generate_synthetic_anomaly" in block: + stochastic_function = block + break + if not stochastic_function: + stochastic_function = code_blocks[0] + stochastic_function_path = os.path.join(sample_dir, f"stochastic_function{call_index+1}.py") + with open(stochastic_function_path, 'w') as f: + f.write(stochastic_function) + print(f"Saved stochastic function {call_index+1}") + stochastic_functions.append((stochastic_function, stochastic_function_path)) + + # Generate test samples and visualizations + ts_data = np.array(sample_info['timeseries']) + n_samples = ts_data.shape[1] + cols = sample_info['cols'] + samples_per_function = 3 + all_samples = [] + + for func_index, (stochastic_function, _) in enumerate(stochastic_functions): + print(f"Generating test samples for function {func_index+1}/{len(stochastic_functions)}...") + samples = [] + for i in range(samples_per_function): + result = execute_function_code(stochastic_function, n_samples, seed=42 + i) + if result is not None: + samples.append(result) + if samples: + n_rows = 3 + n_cols = len(samples) + 1 + fig = plt.figure(figsize=(4 * n_cols, 10), constrained_layout=True) + gs = GridSpec(n_rows, n_cols, figure=fig) + for i in range(3): + ax = fig.add_subplot(gs[i, 0]) + ax.plot(np.arange(ts_data.shape[1]), ts_data[i], 'b-', linewidth=1.5) + ax.set_ylabel(cols[i], fontsize=12) + ax.grid(True, alpha=0.3) + if i == 0: + ax.set_title('Original Data', fontsize=14, fontweight='bold') + for j, sample in enumerate(samples): + for i in range(3): + ax = fig.add_subplot(gs[i, j + 1]) + synth = sample[i] + if len(synth) > ts_data.shape[1]: + synth = synth[:ts_data.shape[1]] + elif len(synth) < ts_data.shape[1]: + synth = np.pad(synth, (0, ts_data.shape[1] - len(synth))) + ax.plot(np.arange(len(synth)), synth, 'r-', linewidth=1.5) + ax.grid(True, alpha=0.3) + if i == 0: + ax.set_title(f'Sample {j+1}', fontsize=14, fontweight='bold') + img_path = os.path.join(sample_dir, f"multiple_samples_func{func_index+1}.png") + plt.savefig(img_path, dpi=150, bbox_inches='tight') + plt.close(fig) + for i, sample in enumerate(samples): + with open(os.path.join(sample_dir, f"stochastic_sample_func{func_index+1}_sample{i+1}.json"), 'w') as f: + json.dump({"synthetic_data": [arr.tolist() for arr in sample], "cols": cols}, f, indent=2) + all_samples.extend(samples) + + return { + 'id': sample_id, + 'characteristic': sample_characteristic, + 'num_functions_generated': len(stochastic_functions), + 'num_samples_generated': len(all_samples) + } + + +def process_sample_wrapper(args): + sample_id, dataset_path, results_dir, output_dir, num_claude_calls, region = args + return generate_stochastic_model(sample_id, dataset_path, results_dir, output_dir, num_claude_calls, region) + + +def run_stochastic_generation(dataset_path: str, results_dir: str, output_dir: str, + sample_ids: Optional[List[str]] = None, + max_workers: Optional[int] = None, + num_claude_calls: int = 3, region: str = 'us-west-2'): + """Run stochastic generation for multiple samples in parallel.""" + os.makedirs(output_dir, exist_ok=True) + if not sample_ids: + sample_dirs = [os.path.basename(d) for d in glob.glob(os.path.join(results_dir, "Sample_*"))] + sample_ids = [d.replace("Sample_", "") for d in sample_dirs] + print(f"Processing {len(sample_ids)} samples") + if max_workers is None: + max_workers = max(1, cpu_count() // 2) + max_workers = min(len(sample_ids), max_workers) if sample_ids else 1 + print(f"Using {max_workers} parallel workers") + args_list = [(sid, dataset_path, results_dir, output_dir, num_claude_calls, region) for sid in sample_ids] + start_time = time.time() + results = [] + with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: + for result in executor.map(process_sample_wrapper, args_list): + if result: + results.append(result) + end_time = time.time() + print(f"All samples processed in {end_time - start_time:.2f}s") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Stage 2: Stochastic Time Series Generation") + parser.add_argument("--dataset_path", type=str, + default="./sample_data/qa_benchmark_base_train.json") + parser.add_argument("--results_dir", type=str, + default="./results/iterative_results") + parser.add_argument("--output_dir", type=str, + default="./results/stochastic_results") + parser.add_argument("--sample_ids", type=str, nargs="*") + parser.add_argument("--max_workers", type=int, default=None) + parser.add_argument("--num_claude_calls", type=int, default=3) + parser.add_argument("--region", type=str, default="us-west-2") + args = parser.parse_args() + + start_time = time.time() + run_stochastic_generation( + resolve_path(args.dataset_path), + resolve_path(args.results_dir), + resolve_path(args.output_dir), + args.sample_ids, + args.max_workers, + args.num_claude_calls, + args.region + ) + print(f"Total execution time: {time.time() - start_time:.2f}s") diff --git a/evaluation/evaluate_qa.py b/evaluation/evaluate_qa.py new file mode 100644 index 0000000..0f654fe --- /dev/null +++ b/evaluation/evaluate_qa.py @@ -0,0 +1,1191 @@ +import json +import os +import numpy as np +import re +from typing import * +from loguru import logger +from tqdm import tqdm +import traceback +from multiprocessing import Pool +from evaluation.my_ragas.score import calculate_ragas_score + + +def split_sentences(text): + """ + Split text into sentences while preserving decimal points and abbreviations. + """ + abbreviations = ['max.', 'eg.', 'Mrs.', 'Dr.', 'Mr.'] + + for abbr in abbreviations: + escaped_abbr = re.escape(abbr) + text = re.sub(escaped_abbr, abbr.replace('.', ''), text) + + # Protect decimal points in numbers like 2.75 by temporarily replacing them + text = re.sub(r'(\d+)\.(\d+)', r'\1\2', text) + + pattern = r'[.!?。!?,;,;](?!\d)' + sentences = re.split(pattern, text) + + # Restore decimal points and abbreviation dots + sentences = [s.strip().replace('', '.').replace('', '.') for s in sentences if s.strip()] + + return sentences + + +def split_period_sentences(text): + """ + Split text into sentences using only periods as delimiters, while preserving decimal points and abbreviations. + """ + abbreviations = ['max.', 'eg.', 'Mrs.', 'Dr.', 'Mr.'] + + for abbr in abbreviations: + escaped_abbr = re.escape(abbr) + text = re.sub(escaped_abbr, abbr.replace('.', ''), text) + + # Protect decimal points in numbers like 2.75 by temporarily replacing them + text = re.sub(r'(\d+)\.(\d+)', r'\1\2', text) + + pattern = r'[.。](?!\d)' + sentences = re.split(pattern, text) + + # Restore decimal points and abbreviation dots + sentences = [s.strip().replace('', '.').replace('', '.') for s in sentences if s.strip()] + + return sentences + + +def match_metric_name(metric: str, sentence: str) -> bool: + """ + Check if a metric name appears in a sentence after normalizing both strings. + Preserves numbers for MCQ numeric options. + """ + pattern = r'[^\u4e00-\u9fa5a-zA-Z0-9]' + sentence = re.sub(pattern, '', sentence).lower() + metric = re.sub(pattern, '', metric).lower() + + return metric in sentence + + +def evaluate_trend(answer: str, attribute: dict, cols: List[str]): + """ + Evaluate answers for trend-related questions. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + cate_correct = False + sentences = split_sentences(answer) + + if len(sentences) == 0: + return [0.0], [0.0], [], [] + + if 'steady' in attribute['type']: + if 'steady' in sentences[0]: + cate_correct = True + elif 'decrease' in attribute['type']: + if 'decreas' in sentences[0].lower(): + cate_correct = True + elif 'increase' in attribute['type']: + if 'increas' in sentences[0].lower(): + cate_correct = True + + num_correct = [] + + # Check start point + for sentence in sentences: + float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence))) + if float_numbers is None or len(float_numbers) == 0: + continue + if 'start' in sentence: + if abs(attribute['start']) < 0.5: + if abs(float_numbers[0]) < 0.5: + num_correct.append(1.0) + else: + num_correct.append(0.0) + else: + num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['start']) / abs(attribute['start'])))) + break + else: + num_correct.append(0.0) + + # Check amplitude + if attribute['type'] != 'keep steady': + for sentence in sentences: + float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence))) + if float_numbers is None or len(float_numbers) == 0: + continue + if 'change value' in sentence or 'from left to right' in sentence: + if abs(attribute['amplitude']) < 0.5: + if abs(float_numbers[0]) < 0.5: + num_correct.append(1.0) + else: + num_correct.append(0.0) + else: + num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['amplitude']) / abs(attribute['amplitude'])))) + break + else: + num_correct.append(0.0) + + return [cate_correct], num_correct, [], [] + + +def evaluate_season(answer: str, attribute: dict, cols: List[str]): + """ + Evaluate answers for seasonality-related questions. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + cate_correct = False + sentences = split_sentences(answer) + + if len(sentences) == 0: + return [0.0], [0.0], [], [] + + if 'no' in attribute['type']: + if 'no periodic' in sentences[0].lower(): + cate_correct = True + else: + if 'no' not in sentences[0].lower() and 'periodic' in sentences[0].lower(): + cate_correct = True + + num_correct = [] + + if attribute['type'] != 'no periodic fluctuation': + # Check period + for sentence in sentences: + float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence))) + if float_numbers is None or len(float_numbers) == 0: + continue + if 'each period' in sentence: + num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['period']) / abs(attribute['period'])))) + break + else: + num_correct.append(0.0) + + # Check amplitude + for sentence in sentences: + float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence))) + if float_numbers is None or len(float_numbers) == 0: + continue + if 'amplitude' in sentence: + num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['amplitude']) / abs(attribute['amplitude'])))) + break + else: + num_correct.append(0.0) + else: + num_correct = [] + + return [cate_correct], num_correct, [], [] + + +def evaluate_noise(answer: str, attribute: dict, cols: List[str]): + """ + Evaluate answers for noise-related questions. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + cate_correct = False + sentences = split_sentences(answer) + + if len(sentences) == 0: + return [0.0], [0.0], [], [] + + if 'almost no' in attribute['type']: + if 'no noise' in sentences[0].lower(): + cate_correct = True + else: + if 'noisy' in sentences[0].lower(): + cate_correct = True + + num_correct = [] + + # Check noise standard deviation + if 'noisy' in attribute['type']: + for sentence in sentences: + float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence))) + if float_numbers is None or len(float_numbers) == 0: + continue + if 'standard' in sentence.lower() or 'std' in sentence.lower(): + num_correct.append(max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - attribute['std']) / abs(attribute['std'])))) + break + else: + num_correct.append(0.0) + + return [cate_correct], num_correct, [], [] + + +def evaluate_local(answer: str, attribute: dict, cols: List[str]): + """ + Evaluate answers for local feature questions. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + cate_correct = [] + num_correct = [] + + # Split into facts + for feat in attribute: + matched_flag = False + pos_numerical = 0.0 + amp_numerical = 0.0 + for fact in re.split(r'[;;]', answer): + sentences = re.split(r'[,。,;;]', fact) + if type(feat['type']) == str: + feat['type'] = [feat['type']] + if any(i in sentences[0].lower() for i in feat['type']): + # Check period and amplitude + for sentence in sentences: + float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence))) + if float_numbers is None or len(float_numbers) == 0: + continue + if 'position' in sentence.lower() or 'around point' in sentence.lower(): + if abs(float_numbers[0] - feat['position']) > 64: + continue + pos_numerical = max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - feat['position']) / abs(feat['position']))) + matched_flag = True + if matched_flag and 'amplitude' in sentence.lower(): + amp_numerical = max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - feat['amplitude']) / abs(feat['amplitude']))) + if matched_flag: + break + cate_correct.append(matched_flag) + num_correct.append(pos_numerical) + num_correct.append(amp_numerical) + + return cate_correct, num_correct, [], [] + + +def evaluate_local_inductive(answer: str, attribute: dict, cols: List[str]): + """ + Evaluate answers for local feature questions with inductive reasoning. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + cate_correct = [] + num_correct = [] + reason_correct = [] + reason_details = [] + + # Split into facts + for feat in attribute: + matched_flag = False + pos_numerical = 0.0 + amp_numerical = 0.0 + reason_score = 0.0 + cur_detail = {} + for fact in re.split(r'[;;]', answer): + sentences = re.split(r'[,。,;;]', fact) + if type(feat['type']) == str: + feat['type'] = [feat['type']] + if any(i in sentences[0].lower() for i in feat['type']): + # Check period and amplitude + for sentence in sentences: + float_numbers = list(map(float, re.findall(r'-?\d+\.?\d*', sentence))) + if float_numbers is None or len(float_numbers) == 0: + continue + if 'position' in sentence.lower() or 'around point' in sentence.lower(): + if abs(float_numbers[0] - feat['position']) > 64: + continue + pos_numerical = max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - feat['position']) / abs(feat['position']))) + matched_flag = True + if matched_flag and 'amplitude' in sentence.lower(): + amp_numerical = max(0.0, min(1.0, 1.0 - abs(float_numbers[0] - feat['amplitude']) / abs(feat['amplitude']))) + if matched_flag: + # Evaluate the inductive reasoning + reason_score, cur_detail = calculate_ragas_score( + question='Please analyze the physical meaning of this local fluctuation in one sentence.', + response=split_period_sentences(fact)[-1], + label=feat['explain'] + ) + cur_detail.update({ + 'label': feat['explain'], + 'response': split_period_sentences(fact)[-1] + }) + break + cate_correct.append(matched_flag) + num_correct.append(pos_numerical) + num_correct.append(amp_numerical) + reason_correct.append(reason_score) + reason_details.append(cur_detail) + + return cate_correct, num_correct, reason_correct, reason_details + + +def evaluate_shape_correlation_inductive(answer: str, attribute: dict, cols: List[str]): + """ + Evaluate answers for shape correlation questions with inductive reasoning. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + cate_correct = False + sentences = split_sentences(answer) + + if len(sentences) == 0: + return [False], [], [0.0], [{}] + + if attribute['label']: + if 'yes' in sentences[0].lower(): + cate_correct = True + else: + if 'no' in sentences[0].lower(): + cate_correct = True + + num_correct = [] + reason_correct, reason_detail = calculate_ragas_score( + question='Explain why they are correlated/no correlated considering their physical meaning in one sentence.', + response=sentences[-1], + label=attribute['explain'] + ) + + return [cate_correct], num_correct, [reason_correct], [reason_detail] + + +def evaluate_local_correlation_inductive(answer: str, attribute: dict, cols: List[str]): + """ + Evaluate answers for local correlation questions with inductive reasoning. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + cate_correct = False + sentences = split_period_sentences(answer) + + # If there's nothing at all, return early + if not sentences: + logger.debug("evaluate_local_correlation_inductive: no sentences parsed from answer") + return [False], [], [0.0], [{}] + + # Prepare for the case where we need sentences[1] + has_second = len(sentences) > 1 + + if attribute.get('label', False): + # Expect a "yes" in the first sentence to proceed + if 'yes' in sentences[0].lower(): + if not has_second: + logger.debug("evaluate_local_correlation_inductive: expected a second sentence for fact extraction but got only one") + else: + # Check correlation type only when we have a second sentence + label_cols = set(map(tuple, attribute.get('pair', []))) + answer_cols = set() + + # Split into facts safely + for fact in sentences[1].split(';'): + items = [s.strip() for s in fact.split(',')] + if len(items) == 2: + metric, corr_type = items + for col in cols: + if match_metric_name(col, metric): + answer_cols.add((col, corr_type)) + + if label_cols == answer_cols: + cate_correct = True + + else: + # Negative case: first sentence should contain "no" + if 'no' in sentences[0].lower(): + cate_correct = True + + # For the RAG-as-a-service score, always pick the *last* sentence we have + explanation = sentences[-1] if sentences else "" + try: + reason_correct, reason_detail = calculate_ragas_score( + question="Explain why they are correlated/not correlated considering their physical meaning in one sentence.", + response=explanation, + label=attribute.get('explain', "") + ) + except Exception as e: + logger.error(f"evaluate_local_correlation_inductive: calculate_ragas_score failed: {e}") + reason_correct, reason_detail = 0.0, {} + + return [cate_correct], [], [reason_correct], [reason_detail] + + +def evaluate_shape_cluster_inductive(answer: str, attribute: dict, cols: List[str]): + """ + Evaluate answers for shape cluster questions with inductive reasoning. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + cate_correct = 0.0 + num_correct = [] + + label_cols = set(attribute['cols']) + answer_cols = set() + + sentences = split_period_sentences(answer) + + if len(sentences) == 0: + return [0.0], [], [0.0], [{}] + + # Split into facts + for fact in split_period_sentences(answer)[0].split(','): + fact = fact.strip() + for col in cols: + if match_metric_name(col, fact): + answer_cols.add(col) + + # Calculate f1-score for label and answer + tp = len(label_cols & answer_cols) + fp = len(answer_cols - label_cols) + fn = len(label_cols - answer_cols) + if tp + fp + fn > 0: + cate_correct = 2 * tp / (2 * tp + fp + fn) + + num_correct = [] + reason_correct, reason_detail = calculate_ragas_score( + question='Explain why they have similar overall trend considering their physical meaning in one sentence.', + response=split_period_sentences(answer)[-1], + label=attribute['explain'] + ) + + return [cate_correct], num_correct, [reason_correct], [reason_detail] + + +def evaluate_local_cluster_inductive(answer: str, attribute: dict, cols: List[str]): + """ + Evaluate answers for local cluster questions with inductive reasoning. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + cate_correct = 0.0 + num_correct = [] + + label_cols = set(zip(attribute['cols'], [i[1] for i in attribute['col_idx']])) + answer_cols = set() + + sentences = split_period_sentences(answer) + + if len(sentences) == 0: + return [0.0], [], [0.0], [{}] + + # Split into facts + for fact in split_period_sentences(answer)[0].split(';'): + items = fact.strip().split(',') + if len(items) == 2: + for col in cols: + if match_metric_name(col, items[0].strip()): + answer_cols.add((col, items[1].strip())) + + # Calculate f1-score for label and answer + tp = len(label_cols & answer_cols) + fp = len(answer_cols - label_cols) + fn = len(label_cols - answer_cols) + if tp + fp + fn > 0: + cate_correct = 2 * tp / (2 * tp + fp + fn) + + num_correct = [] + reason_correct, reason_detail = calculate_ragas_score( + question='Explain why they have similar local fluctuations considering their physical meaning in one sentence.', + response=split_period_sentences(answer)[-1], + label=attribute['explain'] + ) + + return [cate_correct], num_correct, [reason_correct], [reason_detail] + + +def evaluate_deductive(answer, attribute, cols): + """ + Evaluate a yes/no (True/False) deductive question, falling back to RAGAS scoring otherwise. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + try: + labels = split_sentences(attribute) + except Exception as e: + logger.error(f"evaluate_deductive: Error splitting attribute {attribute!r} into labels: {e}") + labels = [] + + try: + sentences = split_sentences(answer) + except Exception as e: + logger.error(f"evaluate_deductive: Error splitting answer {answer!r} into sentences: {e}") + sentences = [] + + cur_reason_correct = 0.0 + ragas_detail = {} + + # Normalize set of boolean labels + bool_labels = {'yes', 'no', 'true', 'false'} + + if labels and labels[0].lower().strip().rstrip('.,') in bool_labels: + label0 = labels[0].lower().strip().rstrip('.,') + if sentences: + resp0 = sentences[0].lower().strip().rstrip('.,') + cur_reason_correct = 1.0 if resp0 == label0 else 0.0 + ragas_detail = {"label": label0, "response": resp0} + else: + logger.warning(f"evaluate_deductive: No sentences to compare against label '{label0}', unparsable answer: {answer!r}") + ragas_detail = {"label": label0, "response": None} + else: + # fallback to RAGAS + logger.info(f"evaluate_deductive: Falling back to RAGAS for answer {answer!r}") + try: + ragas_score, detail = calculate_ragas_score( + question="According to the previous information, please answer True or False and explain it in detail.", + response=answer, + label=attribute + ) + cur_reason_correct = ragas_score + ragas_detail = detail + except Exception as e: + logger.error(f"evaluate_deductive: Error in calculate_ragas_score: {e}") + cur_reason_correct = 0.0 + ragas_detail = {} + + return [], [], [cur_reason_correct], [ragas_detail] + + +def evaluate_inductive(answer: str, attribute: str, cols: List[str]): + """ + Evaluate an open‐ended (inductive) question by running RAGAS on the + model's full answer. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + try: + # Prompt the RAGAS scorer to compare the full answer against the label + ragas_score, ragas_detail = calculate_ragas_score( + question="According to the data and the question, please provide a correct answer and explanation.", + response=answer, + label=attribute + ) + except Exception as e: + logger.error(f"evaluate_inductive: RAGAS scoring failed: {e}") + ragas_score, ragas_detail = 0.0, {} + + # Return only a "reason" score + detail, leaving the categorical/numerical slots empty + return [], [], [ragas_score], [ragas_detail] + + +def evaluate_inductive_rme(answer: str, attribute: str, cols: List[str]): + """ + Evaluate an open‐ended RME (inductive) question by running RAGAS on the + model's full answer. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + try: + # Prompt the RAGAS scorer to compare the full answer against the label + ragas_score, ragas_detail = calculate_ragas_score( + question="According to the data and the question, please provide a correct answer and explanation.", + response=answer, + label=attribute + ) + except Exception as e: + logger.error(f"evaluate_inductive_rme: RAGAS scoring failed: {e}") + ragas_score, ragas_detail = 0.0, {} + + # Return only a "reason" score + detail, leaving the categorical/numerical slots empty + return [], [], [ragas_score], [ragas_detail] + + +def evaluate_causal(answer: str, attribute: str, cols: List[str]): + """ + Evaluate answers for causal questions. + + Args: + answer: Model's response + attribute: Ground truth attribute + cols: Column names + + Returns: + Tuple of (categorical scores, numerical scores, reason scores, reason details) + """ + # 1) Basic input validation + if not answer: + logger.warning(f"evaluate_causal: empty answer for attribute={attribute!r}, cols={cols}") + return [], [], [0.0], [{'label': attribute, 'response': ''}] + + if not attribute: + logger.warning(f"evaluate_causal: empty attribute") + return [], [], [0.0], [{'label': '', 'response': answer}] + + # 2) Extract first part of both strings for comparison + # Use regex to extract the content before the first period that's not part of a number + def extract_first_part(text): + # Find first period not followed by a digit (to avoid splitting decimal numbers) + match = re.search(r'(?= 0: + ga_item = generated_answer[pos] + answer = ga_item.get('response', "") + thought = ga_item.get('thought', None) + else: + answer = "" + thought = None + + # Ground-truth label and question + label = sample['answer'] + question = sample['question'] + + # Do the per-ability evaluation + evaluation_result = evaluate_qa(answer, sample) + + return { + 'idx': idx, + 'application_domain': app_domain, + 'task_type': task_type, + 'question': question, + 'label': label, + 'thought': thought, + 'response': answer, + 'evaluation': evaluation_result + } + + +def evaluate_batch_qa(dataset, generated_answer, EXP, num_workers=8): + """ + Evaluate a batch of QA results. + + Args: + dataset: Dataset containing ground truth + generated_answer: Model's answers + EXP: Experiment name (for saving results) + num_workers: Number of parallel workers + + Returns: + Evaluation results + """ + detailed_result = [] + ability_result = {'categorical': {}, 'numerical': {}, 'reason': {}} + overall_result = {'categorical': [], 'numerical': [], 'reason': []} + + total = len(dataset) + # 1) Identify which indices had errors + error_count = sum(1 for item in generated_answer if '<50 samples + per_task_final = {} + for t, blk in summary_by_task.items(): + if blk['sample_count'] <= 50: + continue + + # Per-ability details + detail_categorical = { + ability: round(float(np.nanmean(vals['categorical'])), 4) + for ability, vals in blk['by_ability'].items() + } + detail_numerical = { + ability: round(float(np.nanmean(vals['numerical'])), 4) + for ability, vals in blk['by_ability'].items() + } + detail_reason = { + ability: round(float(np.nanmean(vals['reason'])), 4) + for ability, vals in blk['by_ability'].items() + } + + # Overall aggregates + all_cate = np.concatenate([v['categorical'] for v in blk['by_ability'].values()]) if blk['by_ability'] else np.array([]) + all_num = np.concatenate([v['numerical'] for v in blk['by_ability'].values()]) if blk['by_ability'] else np.array([]) + all_rea = np.concatenate([v['reason'] for v in blk['by_ability'].values()]) if blk['by_ability'] else np.array([]) + + per_task_final[t] = { + 'detail_categorical': detail_categorical, + 'detail_numerical': detail_numerical, + 'detail_reason': detail_reason, + 'overall_categorical': round(float(np.nanmean(all_cate)), 4) if all_cate.size else 0.0, + 'overall_numerical': round(float(np.nanmean(all_num)), 4) if all_num.size else 0.0, + 'overall_reason': round(float(np.nanmean(all_rea)), 4) if all_rea.size else 0.0, + 'avg_tokens': round(float(np.nanmean(blk['token_counts'])), 1) if blk['token_counts'] else 0.0, + 'error_ratio': round(blk['error_count']/blk['sample_count'], 4), + 'sample_count': blk['sample_count'] + } + + if per_task_final: + with open(f"exp/{EXP}/result_by_task_type.json", "w") as f: + json.dump(per_task_final, f, ensure_ascii=False, indent=4) + logger.info(f"Wrote per-task-type summary to exp/{EXP}/result_by_task_type.json") + else: + logger.info("No task_type groups exceeding 50 samples; skipping per-task summary.") \ No newline at end of file diff --git a/evaluation/evaluate_with_sampling.py b/evaluation/evaluate_with_sampling.py new file mode 100644 index 0000000..fddf550 --- /dev/null +++ b/evaluation/evaluate_with_sampling.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Evaluate generated answers against ground-truth data, handling sampling metadata. + +This script enhances the standard evaluation script by checking for sampling metadata +and ensuring that evaluations are only performed on the sampled subset of the dataset. +If no metadata is found, it falls back to processing the full dataset. + +Usage: + python evaluate_with_sampling.py --exp EXP_NAME --dataset DATASET_PATH --generated OUTPUT_PATH +""" + +import os +import json +import argparse +import logging +from tqdm import tqdm + +from evaluate_qa import evaluate_batch_qa + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + +def load_dataset_with_sampling(dataset_path, generated_path): + """ + Load the dataset and generated answers, handling sampling metadata if available. + Also matches generated answers to dataset entries based on question text. + + Args: + dataset_path: Path to the full dataset + generated_path: Path to generated answers + + Returns: + tuple: (dataset, generated) + """ + # Check if dataset exists + if not os.path.isfile(dataset_path): + raise FileNotFoundError(f"Dataset file not found: {dataset_path}") + + # Load the full dataset + with open(dataset_path, "r") as f: + full_dataset = json.load(f) + + # Load generated answers + if os.path.exists(generated_path): + with open(generated_path, "r") as f: + generated = json.load(f) + else: + generated = [] + logger.warning(f"No generated answers found at {generated_path}; proceeding with empty list.") + + # Check for sampling metadata + metadata_path = generated_path.replace('.json', '_sampling_metadata.json') + sampled_indices = None + + if os.path.exists(metadata_path): + try: + with open(metadata_path, "r") as f: + metadata = json.load(f) + if "sampled_indices" in metadata: + sampled_indices = metadata["sampled_indices"] + logger.info(f"Found sampling metadata: {len(sampled_indices)} samples out of {metadata['original_size']}") + except Exception as e: + logger.warning(f"Error loading sampling metadata: {e}") + + # If we have sampling metadata, filter the dataset + if sampled_indices is not None: + try: + # Create a subset of the dataset with only the sampled indices + filtered_dataset = [full_dataset[idx] for idx in sampled_indices] + logger.info(f"Using sampled dataset with {len(filtered_dataset)} entries (out of {len(full_dataset)} total)") + + # Create a mapping of questions to reindex the generated answers + aligned_generated = [] + for idx, sample in enumerate(filtered_dataset): + sample_question = sample.get('question', '') + # Find matching generated answer based on question text + found = False + for gen_answer in generated: + if gen_answer.get('question', '') == sample_question: + # Create a copy and set the idx to match the new dataset position + modified_answer = dict(gen_answer) + modified_answer['idx'] = idx + aligned_generated.append(modified_answer) + found = True + break + if not found: + logger.warning(f"No matching generated answer found for question: {sample_question[:50]}...") + + logger.info(f"Aligned {len(aligned_generated)} generated answers with dataset questions") + generated = aligned_generated + dataset = filtered_dataset + + except Exception as e: + logger.error(f"Error applying sampling indices: {e}, falling back to full dataset") + dataset = full_dataset + else: + # No sampling metadata found, use the full dataset + dataset = full_dataset + logger.info(f"No sampling metadata found, using full dataset with {len(dataset)} entries") + + return dataset, generated + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate generated QA pairs against a ground-truth dataset, with sampling support." + ) + parser.add_argument( + "--exp", + required=True, + help="Experiment name (used to create exp//fig folder)" + ) + parser.add_argument( + "--dataset", + required=True, + help="Path to the ground-truth dataset JSON (e.g. evaluation/dataset/dataset_a.json)" + ) + parser.add_argument( + "--generated", + required=True, + help="Path to the generated answers JSON (e.g. evaluation/exp//generated_answer.json)" + ) + parser.add_argument( + "--num_workers", + type=int, + default=2, + help="Number of parallel workers for evaluate_batch_qa (default: 2)" + ) + parser.add_argument( + "--ignore_sampling", + action="store_true", + help="Ignore sampling metadata and use the full dataset" + ) + + args = parser.parse_args() + + EXP = args.exp + DATASET_PATH = args.dataset + OUTPUT_JSON = args.generated + FIG_DIR = os.path.join("exp", EXP, "fig") + os.makedirs(FIG_DIR, exist_ok=True) + + # Load dataset and generated answers + if args.ignore_sampling: + # Standard loading without sampling + if not os.path.isfile(DATASET_PATH): + raise FileNotFoundError(f"Dataset file not found: {DATASET_PATH}") + + with open(DATASET_PATH, "r") as f: + dataset = json.load(f) + + if os.path.exists(OUTPUT_JSON): + with open(OUTPUT_JSON, "r") as f: + generated = json.load(f) + else: + generated = [] + logger.warning(f"No generated answers found at {OUTPUT_JSON}; proceeding with empty list.") + else: + # Enhanced loading with sampling support + dataset, generated = load_dataset_with_sampling(DATASET_PATH, OUTPUT_JSON) + + logger.info(f"Loaded {len(dataset)} examples from dataset") + logger.info(f"Loaded {len(generated)} generated answers from: {OUTPUT_JSON}") + logger.info(f"Evaluation figures will be saved under: {FIG_DIR}") + + # Run the batch QA evaluation + evaluate_batch_qa(dataset, generated, EXP, num_workers=args.num_workers) + logger.info("Evaluation complete.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/evaluation/my_ragas/config.py b/evaluation/my_ragas/config.py new file mode 100644 index 0000000..d55de2f --- /dev/null +++ b/evaluation/my_ragas/config.py @@ -0,0 +1,121 @@ +import os +import sys +import logging +from pip._vendor import tomli +from langchain_core.language_models import BaseLanguageModel +from langchain_core.embeddings import Embeddings +from langchain_core.messages.human import HumanMessage +from langchain_aws.chat_models import ChatBedrock + +CONFIG_PATH = os.getenv('CONFIG_PATH', os.path.join(os.path.dirname(os.path.abspath(__file__)), './config/config.toml')) + +logger = logging.getLogger(__name__) + + +def load_config() -> dict: + if not os.path.exists(CONFIG_PATH): + logger.error(f'Config file does not exist: {CONFIG_PATH}') + sys.exit(1) + + with open(CONFIG_PATH, 'r') as f: + cfg = tomli.loads(f.read()) + + return cfg + + +config = load_config() + + +#--- +class BedrockChatWithText(ChatBedrock): + def generate_text(self, prompts, **kwargs): + # Drop any OpenAI‐style args Bedrock doesn't like + kwargs.pop("n", None) + kwargs.pop("max_tokens", None) + + # Wrap each prompt in a single HumanMessage turn + messages = [[ HumanMessage(content=p) ] for p in prompts] + + # Call the normal chat interface, which returns a ChatResult + # that has `.generations`: a List[List[Generation]] + result = super().generate(messages, **kwargs) + return result +#--- + +def load_llm() -> BaseLanguageModel: + models_config = config.get('models') + llm_type = models_config.get('llm_type', 'openai') + if llm_type == 'openai': + os.environ["OPENAI_API_BASE"] = models_config.get('openai_api_base', '') + os.environ["OPENAI_API_KEY"] = models_config.get('openai_api_key', '') + + from langchain_openai.chat_models import ChatOpenAI + + return ChatOpenAI(model=models_config.get('llm_model', 'gpt-3.5-turbo-16k')) + + elif llm_type == 'tongyi': + os.environ["DASHSCOPE_API_KEY"] = models_config.get('dashscope_api_key', '') + + from langchain_community.chat_models.tongyi import ChatTongyi + + return ChatTongyi(model=models_config.get('llm_model', 'qwen1.5-72b-chat')) + + elif llm_type == 'glm': + os.environ["OPENAI_API_BASE"] = models_config.get('openai_api_base', '') + os.environ["OPENAI_API_KEY"] = models_config.get('openai_api_key', '') + from langchain_community.chat_models import ChatZhipuAI + + model = ChatZhipuAI( + temperature=models_config.get('temperature', 1), + api_key=models_config.get('openai_api_key', ''), + model=models_config.get('llm_model', 'gpt-3.5-turbo-16k') + ) + return model + + elif llm_type == 'bedrock': + + return BedrockChatWithText( + # credentials_profile_name=models_config['bedrock_profile'], + region_name=models_config['bedrock_region'], + endpoint_url=( + f"https://bedrock-runtime.{models_config['bedrock_region']}" + ".amazonaws.com" + ), + model_id=models_config['llm_model'], + temperature=0.0, + max_tokens=4095, + model_kwargs=models_config.get('bedrock_kwargs', {}), + ) + + logger.error(f'Unsupported LLM model: {llm_type}') + sys.exit(1) + + +def load_embeddings() -> Embeddings: + embedding_config = config.get('embedding') + emb_type = embedding_config.get('emb_type', 'openai') + if emb_type == 'openai': + os.environ["OPENAI_API_BASE"] = embedding_config.get('openai_api_base', '') + os.environ["OPENAI_API_KEY"] = embedding_config.get('openai_api_key', '') + + from langchain_openai.embeddings import OpenAIEmbeddings + + return OpenAIEmbeddings(model=embedding_config.get('embeddings_model', 'text-embedding-ada-002')) + + elif emb_type == 'dashscope': + os.environ["DASHSCOPE_API_KEY"] = embedding_config.get('dashscope_api_key', '') + + from langchain_community.embeddings.dashscope import DashScopeEmbeddings + + return DashScopeEmbeddings(model=embedding_config.get('embeddings_model', 'text-embedding-v2')) + + elif emb_type == 'bedrock': + from langchain_aws import BedrockEmbeddings + + return BedrockEmbeddings( + # credentials_profile_name=embedding_config['bedrock_profile'], + region_name=embedding_config['bedrock_region'], + ) + + logger.error(f'Unsupported Embeddings model: {emb_type}') + sys.exit(1) \ No newline at end of file diff --git a/evaluation/my_ragas/config/config.toml b/evaluation/my_ragas/config/config.toml new file mode 100644 index 0000000..fd97c1f --- /dev/null +++ b/evaluation/my_ragas/config/config.toml @@ -0,0 +1,52 @@ +# ──────────────────────────────────────────────────────────────────────────────── +# (unchanged) global settings +data_dir = "./evaluation/my_ragas/data" +default_quota = 150 +submission_interval = 10 + +[judge] +enabled = true +max_workers = 8 +correct_threshold = 0.6 +max_retries = 3 +max_execution_time = 60 +report_metric = "score" + +[langsmith] +enabled = false +api_key = "" +project_name = "" +endpoint = "https://api.smith.langchain.com" + +# ──────────────────────────────────────────────────────────────────────────────── +# Use AWS Bedrock +[models] +llm_type = "bedrock" +# 0818 - switch to 3.7 - inductive drop +# llm_model = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" +llm_model = "us.anthropic.claude-3-5-haiku-20241022-v1:0" +# AWS credentials/profile + region for Bedrock +bedrock_profile = "default" # ~/.aws/credentials profile +bedrock_region = "us-west-2" # the AWS region where Bedrock is served + +# Any extra kwargs +[models.bedrock_kwargs] +temperature = 0.4 + +# ──────────────────────────────────────────────────────────────────────────────── +# Embeddings: also via Bedrock +[embedding] +emb_type = "bedrock" +embeddings_model = "amazon.titan-embed-text" # or whichever Bedrock embed model you prefer + +# AWS credentials/profile + region for embeddings +bedrock_profile = "default" +bedrock_region = "us-west-2" + +# ──────────────────────────────────────────────────────────────────────────────── +[cms] +enabled = false +interval = 10 +contest_id = "" +username = "" +password = "" diff --git a/evaluation/my_ragas/data/gt_cache.json b/evaluation/my_ragas/data/gt_cache.json new file mode 100644 index 0000000..96e9055 --- /dev/null +++ b/evaluation/my_ragas/data/gt_cache.json @@ -0,0 +1,210 @@ +{ + "\"Please analyze the physical meaning of this local fluctuation in one sentence.\"|||\"The database is experiencing significant contention issues related to enqueue resources.\"": [ + "database contention" + ], + "\"Please analyze the physical meaning of this local fluctuation in one sentence.\"|||\"The ASM file metadata operation indicates a temporary increase in I/O workload or resource demand.\"": [ + "I/O workload", + "resource demand" + ], + "\"Explain why they have similar local fluctuations considering their physical meaning in one sentence.\"|||\"These metrics are all related to I/O contention and resource waiting times in the Oracle database system. The system may be experiencing high I/O load or contention, leading to delays in SQL execution and increased wait times.\"": [ + "I/O contention", + "resource waiting times" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are all solar radiation-related.\"": [ + "solar radiation" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are all related to solar radiation.\"": [ + "solar radiation" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"wv (m/s) represents wind speed which varies rapidly with atmospheric conditions, while rho (g/m**3) represents air density which changes slowly with temperature and pressure, leading to different trends and shapes.\"": [ + "wind speed", + "air density" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"log file sync is I/O-related, so its fluctuations indicate changes in writing to redo log files. But enq: CF - contention is related to control file access, which may not be affected by I/O writing fluctuations.\"": [ + "I/O", + "control file access" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Raining (s) indicates the duration of precipitation, while VPact (mbar) measures the actual vapor pressure in the air; they are not similar because rainfall duration does not directly affect the vapor pressure, which can vary independently based on humidity and temperature conditions.\"": [ + "rainfall duration", + "vapor pressure" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are all related to the intensity and duration of precipitation events.\"": [ + "precipitation intensity", + "precipitation duration" + ], + "\"Explain why they have similar local fluctuations considering their physical meaning in one sentence.\"|||\"These metrics are all I/O and contention-related. The system may be experiencing high contention and I/O bottlenecks, leading to increased latency and performance degradation.\"": [ + "I/O bottlenecks", + "contention" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"SQL Time Per Second is workload-related, reflecting the overall SQL execution time influenced by many factors like concurrency and resource usage, while gc buffer busy release is cluster-related, indicating contention in the buffer cache for shared resources, which may not vary with SQL execution time.\"": [ + "workload vs cluster", + "different resource contexts" + ], + "\"Please analyze the physical meaning of this local fluctuation in one sentence.\"|||\"There is a temporary but significant increase in table-level lock contention within the Oracle Database System.\"": [ + "lock contention" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are all precipitation-related.\"": [ + "precipitation" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Buffer busy waits are related to contention for data blocks in memory, whereas log file sync is related to committing transactions to disk, so they may not fluctuate together.\"": [ + "buffer busy waits", + "log file sync" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"enq: TM - contention is transaction-related, so it fluctuates with table-level locks, but ARCH wait for archivelog lock is archive-related, which may not be affected by table lock contention.\"": [ + "transaction lock", + "archive lock" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"SQL Time Per Second is related to overall database performance and workload, while enq: TM - contention is related to specific lock contention scenarios, so they may not fluctuate together.\"": [ + "database performance", + "lock contention" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Max wind velocity (m/s) and temperature (degC) are not similar because wind velocity reflects dynamic air movement, while temperature measures thermal energy, leading to different trends and shapes in a weather system.\"": [ + "dynamic air movement", + "thermal energy" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Tpot (K) represents the potential temperature reflecting the thermal state of the air, while VPact (mbar) indicates the actual vapor pressure representing humidity levels, which often vary independently in a weather system.\"": [ + "thermal state", + "humidity levels" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Both rho (g/m**3) and rh (%) are moisture-related metrics in a weather system.\"": [ + "moisture metrics" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"SQL Time Per Second is execution-related, so it reflects processing time fluctuations, but Log archive I/O is storage-related, which may not be directly impacted by execution time changes.\"": [ + "execution vs storage" + ], + "\"Explain why they have similar local fluctuations considering their physical meaning in one sentence.\"|||\"These metrics are all I/O and contention-related. The system may be experiencing significant data access bottlenecks, leading to waiting times and contention around I/O operations and resource management.\"": [ + "I/O", + "contention" + ], + "\"Explain why they have similar local fluctuations considering their physical meaning in one sentence.\"|||\"These metrics are all contention and IO-related. The system may be experiencing heavy disk I/O operations and locking contention, leading to potential performance degradation or delays.\"": [ + "contention", + "IO-related" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Sh (g/kg) and Tdew (°C) are both measures related to humidity in a weather system, where sh represents the specific humidity indicating the amount of moisture in the air, while Tdew is the temperature at which air becomes saturated and moisture begins to condense.\"": [ + "humidity", + "moisture" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are all temperature and moisture-related, as T (degC), Tpot (K), and Tdew (degC) reflect different aspects of air temperature, while VPmax (mbar), VPact (mbar), sh (g/kg), and H2OC (mmol/mol) measure various components of atmospheric moisture content and saturation.\"": [ + "temperature and moisture" + ], + "\"Explain why they have similar local fluctuations considering their physical meaning in one sentence.\"|||\"These metrics are all I/O and transaction-related. The system may be experiencing heavy disk I/O operations and lock contention, leading to performance bottlenecks.\"": [ + "I/O", + "transaction" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Wait Time Per Second is related to overall system wait times, so it can fluctuate due to various system-wide issues. But log_file_sync平均等待时间 is specifically related to committing transactions to the redo log, which may not be affected by broader system fluctuations.\"": [ + "system wait times", + "redo log transactions" + ], + "\"Explain why they have similar local fluctuations considering their physical meaning in one sentence.\"|||\"These metrics are all related to database I/O performance and contention. The system may be experiencing heavy load, leading to contention for resources and resulting in delays and inefficiencies in read and write operations.\"": [ + "database I/O performance", + "resource contention" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Both ##log file sync and ##log_file_sync Average Waiting Time (average wait time) are related to the writing of redo log data to disk, and high I/O contention or slow I/O performance can cause their fluctuations.\"": [ + "log file sync", + "disk I/O" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"##gc buffer busy release is related to RAC inter-instance contention, while ##LGWR wait for redo copy is related to redo log writing, thus they may fluctuate independently.\"": [ + "RAC contention", + "redo log writing" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Rho (g/m**3) represents air density while VPact (mbar) measures actual vapor pressure; their trends differ because air density mainly depends on temperature and pressure, whereas actual vapor pressure depends on the humidity level.\"": [ + "air density", + "vapor pressure" + ], + "\"Explain why they have similar local fluctuations considering their physical meaning in one sentence.\"|||\"These metrics are all related to the commit operation in Oracle. The system may be experiencing issues with writing redo log entries to disk, causing delays in commit processing.\"": [ + "commit operation", + "redo log" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are both humidity-related.\"": [ + "humidity" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Both \\\"log file sync\\\" and \\\"log_file_sync平均等待时间\\\" (average wait time) are related to redo log writes, and contention or delays in writing redo logs to disk can cause their fluctuations.\"": [ + "redo log writes", + "disk writes" + ], + "\"Please analyze the physical meaning of this local fluctuation in one sentence.\"|||\"A sudden increase in competition or demand for ad impressions, likely due to a high-traffic event, targeted campaign, or spike in advertiser bidding.\"": [ + "ad demand spike" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are both related to wind speed, indicating the average and peak velocities of the wind within a specific time frame.\"": [ + "wind speed" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Both rain (mm) and raining (s) are precipitation-related metrics in a weather system.\"": [ + "precipitation metrics" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are all wind-related, with wv (m/s) representing the average wind speed and max. wv (m/s) representing the maximum wind speed observed during a specific period.\"": [ + "wind-related" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Both ##Wait Time Per Second (wait time per second) and ##SQL Time Per Second (SQL time per second) are workload-related metrics, and high workload on the system may cause their fluctuations.\"": [ + "workload-related metrics" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are all wind velocity-related.\"": [ + "wind velocity" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Both VPact (mbar) and rh (%) are humidity-related metrics in a weather system.\"": [ + "humidity metrics" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Rain (mm) measures the total amount of precipitation, while rho (g/m³) represents air density; they are not similar because precipitation accumulates over time and is influenced by different factors than air density, which varies with temperature and pressure regardless of rainfall.\"": [ + "precipitation", + "air density" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"##SQL Time Per Second is execution-related, so it fluctuates with the volume of SQL activity, but ##buffer busy waits are contention-related, which may not be affected by changes in SQL execution volume.\"": [ + "execution vs contention" + ], + "\"Explain why they have similar local fluctuations considering their physical meaning in one sentence.\"|||\"These metrics are all related to I/O and contention issues within the Oracle database. The system may be experiencing heavy load or resource contention, leading to performance degradation or potential failure.\"": [ + "I/O and contention issues" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Both ##Disk file operations I/O and ##row cache lock fluctuations can occur together because heavy disk I/O can increase contention for row cache locks, which are used to manage access to cached dictionary information in Oracle.\"": [ + "disk I/O", + "row cache locks" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"SWDR (W/m) represents solar radiation energy received on the surface, while wind velocity (wv in m/s) indicates the speed of moving air, and these two can differ in trend and shape because solar radiation is influenced by sunlight and atmospheric conditions, whereas wind speed is driven by air pressure differences and temperature variations.\"": [ + "solar radiation", + "wind velocity" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Disk file operations I/O are related to physical read/write operations on disk, while enq: DL - contention is related to distributed lock management, hence fluctuations in disk I/O do not necessarily affect lock contention.\"": [ + "disk I/O", + "distributed lock" + ], + "\"Please analyze the physical meaning of this local fluctuation in one sentence.\"|||\"The Oracle Database System is experiencing synchronization issues with a temporary delay in processing.\"": [ + "synchronization issues" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Row cache lock is metadata-related, so it fluctuates with metadata activity, but enq: RO - fast object reuse is related to object reuse, which may not be impacted by metadata changes.\"": [ + "metadata", + "object reuse" + ], + "\"Please analyze the physical meaning of this local fluctuation in one sentence.\"|||\"A brief decrease in workload or computational demand, possibly due to the completion of a task or temporary reduction in user activity.\"": [ + "workload decrease" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Both raining (s) and rain (mm) are precipitation-related metrics in a weather system.\"": [ + "precipitation metrics" + ], + "\"Explain why they have similar local fluctuations considering their physical meaning in one sentence.\"|||\"These metrics are all related to database resource contention and I/O wait events. The system may be experiencing significant performance bottlenecks due to locking issues and high I/O latency, possibly indicative of resource saturation or contention during the failure period.\"": [ + "database resource contention", + "I/O wait events" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"##gc buffer busy release is related to contention in the Global Cache Service of an Oracle RAC environment, while ##Disk file operations I/O pertains to physical disk reads and writes, so contention in cache does not necessarily lead to fluctuations in disk I/O.\"": [ + "cache contention", + "disk I/O" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"H2OC (mmol/mol) measures water vapor concentration, while rho (g/m³) measures air density, and their trends differ because they are influenced by distinct factors like temperature and humidity.\"": [ + "water vapor", + "air density" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are all temperature-related and influence atmospheric thermodynamics.\"": [ + "temperature", + "thermodynamics" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Max. wv (m/s) represents wind speed variations while rho (g/m³) represents air density, which are influenced by different atmospheric factors, hence their trends and shapes do not align.\"": [ + "different atmospheric factors" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"##gc buffer busy acquire is related to global cache buffer busy waits in RAC environments, so it fluctuates with cluster-wide data sharing issues. But ##enq: TC - contention2 is related to tablespace contention in single instance environments, which may not be affected by cluster-related problems.\"": [ + "cluster-wide data sharing", + "single instance environments" + ], + "\"Explain why they have similar overall trend considering their physical meaning in one sentence.\"|||\"These metrics are all related to temperature and humidity dynamics in the atmosphere, with T representing air temperature, Tpot indicating potential temperature, VPmax reflecting the maximum vapor pressure, and VPdef signifying the vapor pressure deficit.\"": [ + "temperature and humidity dynamics" + ], + "\"Explain why they are correlated/no correlated considering their physical meaning in one sentence.\"|||\"Relative humidity (rh %) measures the amount of moisture in the air compared to its maximum capacity, while dew point temperature (Tdew in °C) indicates the temperature at which air becomes saturated and water vapor begins to condense, leading to differences in their trends and shapes in a weather system.\"": [ + "moisture saturation" + ] +} \ No newline at end of file diff --git a/evaluation/my_ragas/data/gt_cache.lock b/evaluation/my_ragas/data/gt_cache.lock new file mode 100644 index 0000000..e69de29 diff --git a/evaluation/my_ragas/metric.py b/evaluation/my_ragas/metric.py new file mode 100644 index 0000000..ceb7dc8 --- /dev/null +++ b/evaluation/my_ragas/metric.py @@ -0,0 +1,338 @@ +import logging +import difflib +import typing as t +from dataclasses import dataclass, field +import multiprocessing +import numpy as np +from ragas.llms.json_load import json_loader +from ragas.llms.prompt import Prompt +from ragas.metrics._answer_similarity import AnswerSimilarity +from ragas.metrics.base import EvaluationMode, MetricWithEmbeddings, MetricWithLLM +from filelock import FileLock +from ragas.run_config import RunConfig +from evaluation.my_ragas.config import config +import json +import os +from langchain_core.messages.human import HumanMessage + +logger = multiprocessing.get_logger() + +# Path to cache file +DATA_DIR = config.get('data_dir', 'data').rstrip('/') +CACHE_PATH = os.path.join(DATA_DIR, 'gt_cache.json') +CACHE_LOCK_PATH = CACHE_PATH.replace('.json', '.lock') + +def load_cache(): + # logger.info("Loading cache") + if os.path.exists(CACHE_PATH): + with FileLock(CACHE_LOCK_PATH): + with open(CACHE_PATH, 'r') as f: + return json.load(f) + return {} + +def save_cache(cache): + # logger.info("Saving cache") + with FileLock(CACHE_LOCK_PATH): + if os.path.exists(CACHE_PATH): + with open(CACHE_PATH, 'r') as f: + cache.update(json.load(f)) + with open(CACHE_PATH, 'w') as f: + json.dump(cache, f, ensure_ascii=False, indent=4) + +async def get_gt_keywords(llm, question, groundtruth, callbacks, is_async): + cache = load_cache() + cache_key = json.dumps(question, ensure_ascii=False) + '|||' + json.dumps(groundtruth, ensure_ascii=False) + + if cache_key in cache: + return cache[cache_key] + + # Use the LLM to generate gt_keywords + gt_keywords = await generate_gt_keywords_with_llm(llm, question, groundtruth, callbacks, is_async) + + cache[cache_key] = gt_keywords + save_cache(cache) + + return gt_keywords + +async def generate_gt_keywords_with_llm(llm, question, groundtruth, callbacks, is_async): + # 1) build the PromptValue + pv = GT_PROMPT.format(**{ + "question": question, + "ground_truth": groundtruth + }) + + # 2) extract its underlying string + prompt_str = pv.prompt_str # or: prompt_str = str(pv) + + # 3) wrap *that* in a HumanMessage + messages = [[ HumanMessage(content=prompt_str) ]] + + # 4) call generate + response = llm.generate(messages, callbacks=callbacks) + response = await json_loader.safe_load( + response.generations[0][0].text, llm, is_async=is_async + ) + if 'gt_keywords' in response and type(response['gt_keywords']) == list: + result = [] + # Remove keywords from questions + for key in response['gt_keywords']: + if key.lower() not in question.lower() or key.lower() in groundtruth.lower(): + result.append(key) + return result + else: + return [] + +GT_PROMPT = Prompt( + name="gt_keywords", + instruction="""Given a question and the ground truth, extract the following information: + "gt_keywords": Identify and return a list of keywords or phrases contained in the ground_truth. This list should only include the minimal combination of key points that directly answer the question, avoiding words unrelated to the question as much as possible. Only **1 or 2** keywords are needed in each list, and the keywords should be as concise as possible and **as short as possible**. + """, + examples=[ + { + "question": """What protocol does AMF use to ensure correct time?""", + "ground_truth": """The role of NTP is to synchronize the time of all clocked devices within the network, ensuring that the clock times of all devices in the network remain basically consistent, so that the devices can provide various applications based on unified time. Possible applicable network elements: AMF, MME, SGSN""", + "Extracted statements": { + "gt_keywords": ["NTP"] + }, + }, + { + "question": """What is the purpose of MME assigning a TA List to UE?""", + "ground_truth": """The purpose of MME assigning a TA List to UE is to manage the location of idle UEs. After MME assigns a TA List to UE, the UE will not initiate a TA update to MME when moving within this TA List. When the UE moves out of the TA List range, it will initiate a TA update, and the MME will know that the location of the idle UE is within the TA List range. When paging is needed, it can be done within the TA List range.""", + "Extracted statements": { + "gt_keywords": ["manage the location of idle UEs"] + }, + } + ], + input_keys=["question", "ground_truth"], + output_key="Extracted statements", + output_type="json", +) + + +ANSWER_PROMPT = Prompt( + name="answer_correctness_step_2", + instruction="""Given a question, an answer generated by the model, and the list of ground truth keywords (gt_keywords), extract the following information: + "overlapping_keywords": From the list of "gt_keywords", identify any terms or phrases that also appear in the model's answer. These overlapping keywords indicate the points of agreement or coverage between the model's answer and the ground truth.""", + examples=[ + { + "question": "What powers the sun and what is its primary function?", + "gt_keywords": ["nuclear fusion", "energy", "light", "essential for life", "climate system", "weather", "ocean currents"], + "answer": "The sun is powered by nuclear fission, similar to nuclear reactors on Earth, and its primary function is to provide light to the solar system.", + "Extracted statements": { + "overlapping_keywords": ["light"] + } + }, + { + "question": "What is the boiling point of water?", + "gt_keywords": ["100 degrees Celsius", "212 degrees Fahrenheit", "sea level", "change with altitude"], + "answer": "The boiling point of water is 100 degrees Celsius at sea level.", + "Extracted statements": { + "overlapping_keywords": ["100 degrees Celsius", "sea level"] + } + }, + { + "question": "What information should be submitted when contacting technical support for a communication technology company?", + "gt_keywords": ["fault details", "log files and alarm query results", "steps taken to address the issue", "commands executed", "results of those actions", "remote access details", "contact information for relevant personnel"], + "answer": "When contacting technical support for a communication technology company, the following information should be provided: 1. Fault details: Time, location, and event description.", + "Extracted statements": { + "overlapping_keywords": ["fault details"] + } + }, + { + "question": "What are the benefits of a balanced diet?", + "gt_keywords": ["provides essential nutrients", "maintains a healthy weight", "reduces risk of chronic diseases", "supports overall health"], + "answer": "A balanced diet helps maintain a healthy weight and supports overall health.", + "Extracted statements": { + "overlapping_keywords": ["maintain a healthy weight", "supports overall health"] + } + }, + { + "question": "What is the capital of France?", + "gt_keywords": ["Paris"], + "answer": "The capital of France is Paris.", + "Extracted statements": { + "overlapping_keywords": ["Paris"] + } + }, + { + "question": "How does photosynthesis work?", + "gt_keywords": ["process by which plants convert sunlight into energy", "involves chlorophyll", "produces oxygen", "occurs in the chloroplasts"], + "answer": "Photosynthesis is the process by which plants use sunlight to produce energy and oxygen.", + "Extracted statements": { + "overlapping_keywords": ["process by which plants convert sunlight into energy", "produces oxygen"] + } + }, + ], + input_keys=["question", "gt_keywords", "answer"], + output_key="Extracted statements", + output_type="json", +) + + +@dataclass +class AnswerCorrectness(MetricWithLLM, MetricWithEmbeddings): + name: str = "answer_correctness" # type: ignore[reportIncompatibleMethodOverride] + evaluation_mode: EvaluationMode = EvaluationMode.qga # type: ignore[reportIncompatibleMethodOverride] + answer_prompt: Prompt = field(default_factory=lambda: ANSWER_PROMPT) + weights: list[float] = field(default_factory=lambda: [1.0, 0.0]) + keyword_matching_threshold: float = field(default_factory=lambda: 0.6) + answer_similarity: AnswerSimilarity | None = None + + def __post_init__(self): + if len(self.weights) != 2: + raise ValueError( + "Expects a list of two weights. First for factuality, second for semantic similarity" + ) + if all([w == 0 for w in self.weights]): + raise ValueError("At least one weight must be non-zero") + if not all([w >= 0 for w in self.weights]): + raise ValueError("Weights must be non-negative") + + def init(self, run_config: RunConfig): + super().init(run_config) + if self.answer_similarity is None and self.weights[1] != 0: + self.answer_similarity = AnswerSimilarity( + llm=self.llm, embeddings=self.embeddings + ) + self.answer_detail = {} + + def _compute_statement_presence(self, gt_keywords, prediction: t.Any, question: str) -> float: + """ + Compute an F1‐style overlap score: + - gt_keywords: list[str] + - prediction: either a dict or a list of dicts from the LLM + - question: the prompt (for logging) + """ + assert self.llm is not None, "LLM must be set" + key = "overlapping_keywords" + + # 1) Normalize prediction → a single dict containing our key, or empty dict + if isinstance(prediction, dict): + pred_dict = prediction + elif isinstance(prediction, list): + pred_dict = next( + (p for p in prediction if isinstance(p, dict) and key in p), + {} + ) + else: + pred_dict = {} + + try: + # 2) Extract the overlapping_keywords list (or default to empty) + overlapping = pred_dict.get(key, []) + if not isinstance(overlapping, list): + overlapping = [] + + # 3) Normalize GT keywords + if not gt_keywords or (isinstance(gt_keywords, float) and np.isnan(gt_keywords)): + gt_list = [] + else: + gt_list = [k.lower() for k in gt_keywords] + + # 4) Lowercase the predictions + ow_list = [k.lower() for k in overlapping] + + # 5) Record in answer_detail + self.answer_detail[question] = { + "answer_keywords": "|".join(ow_list), + "gt_keywords": "|".join(gt_list), + "overlapping_keywords": "|".join([k for k in ow_list if self.match(gt_list, k)]) + } + + # 6) Compute counts + num_ok = sum(1 for k in ow_list if self.match(gt_list, k)) + num_all = len(gt_list) + + # 7) F1‐style score (0 when no GT keywords) + if num_all > 0: + return min(num_ok / num_all, 1.0) + else: + return 0.0 + + except Exception: + logger.error(f"_compute_statement_presence FAILED for question={question!r}") + logger.error(f" gt_keywords= {gt_keywords!r}") + logger.error(f" raw prediction={prediction!r}") + logger.error(" traceback:\n" + traceback.format_exc()) + return 0.0 + + + + def match(self, arr, k): + for item in arr: + if difflib.SequenceMatcher(None, item, k).ratio() >= self.keyword_matching_threshold: + return True + return False + + async def _ascore(self, row, callbacks, is_async): + assert self.llm is not None, "LLM must be set" + + q, a, g = row["question"], row["answer"], row["ground_truth"] + + if len(row["answer"].strip('"')) == 0: + return 0.0 + + gt_keywords = await get_gt_keywords(self.llm, q, g, callbacks, is_async=is_async) + + prompt_input = { + "question": q, + "gt_keywords": gt_keywords, + "answer": a + } + + # 1) Format the PromptValue and extract its raw string + prompt_val = self.answer_prompt.format(**prompt_input) + prompt_str = ( + prompt_val.prompt_str + if hasattr(prompt_val, "prompt_str") + else str(prompt_val) + ) + + # 2) Wrap in a HumanMessage and supply as a list-of-lists + from langchain.schema import HumanMessage + messages = [[ HumanMessage(content=prompt_str) ]] + + # 3) Call generate with the properly‐shaped messages + is_statement_present = self.llm.generate( + messages, + callbacks=callbacks + ) + + prediction = await json_loader.safe_load( + is_statement_present.generations[0][0].text, self.llm, is_async=is_async + ) + f1_score = self._compute_statement_presence(gt_keywords, prediction, q) + + if self.weights[1] == 0: + similarity_score = 0 + else: + assert self.answer_similarity is not None, "AnswerSimilarity must be set" + + callbacks = [] + similarity_score = await self.answer_similarity.ascore( + row, callbacks=callbacks, is_async=is_async + ) + if q not in self.answer_detail: + self.answer_detail[q] = {} + self.answer_detail[q]['answer_similarity'] = float(similarity_score) + self.answer_detail[q]['keywords_prediction'] = str(is_statement_present.generations[0][0].text) + score = np.average( + [f1_score, similarity_score], + weights=self.weights, + ) + + return float(score) + + def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None: + assert self.llm is not None, "llm must be set to compute score" + + logger.info(f"Adapting AnswerCorrectness metric to {language}") + self.correctness_prompt = self.answer_prompt.adapt( + language, self.llm, cache_dir + ) + + def save(self, cache_dir: t.Optional[str] = None) -> None: + self.answer_prompt.save(cache_dir) + + +answer_correctness = AnswerCorrectness() diff --git a/evaluation/my_ragas/requirements.txt b/evaluation/my_ragas/requirements.txt new file mode 100644 index 0000000..e77766e --- /dev/null +++ b/evaluation/my_ragas/requirements.txt @@ -0,0 +1,9 @@ +ragas==0.1.9 +langchain==0.2.15 +langchain-chroma==0.1.1 +langchain-community==0.2.15 +langchain-core==0.2.37 +langchain-experimental==0.0.64 +langchain-openai==0.1.23 +langchain-text-splitters==0.2.0 +langchainhub==0.1.16 diff --git a/evaluation/my_ragas/score.py b/evaluation/my_ragas/score.py new file mode 100644 index 0000000..2fc1960 --- /dev/null +++ b/evaluation/my_ragas/score.py @@ -0,0 +1,64 @@ +import copy +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type +) +from botocore.exceptions import ClientError +from evaluation.my_ragas.metric import AnswerCorrectness +from evaluation.my_ragas.config import load_llm, load_embeddings, config +from ragas import RunConfig + +#–– retry wrappers for the I/O‐heavy loaders –– +@retry( + stop=stop_after_attempt(15), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type(Exception) # you can narrow this to your specific I/O exceptions +) +def load_embeddings_with_retry(): + return load_embeddings() + +@retry( + stop=stop_after_attempt(15), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type(Exception) +) +def load_llm_with_retry(): + return load_llm() + +@retry( + stop=stop_after_attempt(15), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type(Exception) +) +def calculate_ragas_score(question: str, response: str, label: str): + """ + Compute the RAGAS score for a single QA pair, retrying resource loads on failure. + Returns: + - score (float) + - detail dict (deepcopied) + """ + # 1) load resources with retry guarantees + embeddings = load_embeddings_with_retry() + llm = load_llm_with_retry() + + # 2) instantiate the scorer + answer_correctness = AnswerCorrectness( + embeddings=embeddings, + llm=llm, + weights=[1.0, 0.0] + ) + answer_correctness.answer_detail = {} + + # 3) compute the score + score = answer_correctness.score( + row={ + 'question': question, + 'answer': response, + 'ground_truth': label + } + ) + + # 4) return both the numeric score and a snapshot of the detail dict + return float(score), copy.deepcopy(answer_correctness.answer_detail) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..cb678b3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,29 @@ +# AWS SDK for Claude API access +boto3>=1.28.0 + +# Core libraries +numpy>=1.23.0 +pandas>=2.0.0 +tqdm>=4.65.0 + +# Retry logic +tenacity>=8.2.0 + +# Scientific computing (used by synthetic data generation) +scipy>=1.10.0 + +# Plotting and visualization +matplotlib>=3.7.0 +seaborn>=0.12.0 + +# JSON and data processing +jsonlines>=3.1.0 + +# Parallel processing +concurrent-futures-extractor>=1.0.0 + +# API utilities +requests>=2.31.0 + +# HTML generation +jinja2>=3.1.0 \ No newline at end of file diff --git a/scripts/run_chatts_inference.sh b/scripts/run_chatts_inference.sh new file mode 100755 index 0000000..f418335 --- /dev/null +++ b/scripts/run_chatts_inference.sh @@ -0,0 +1,191 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# run_chatts_inference.sh +# +# Script to run ChatTS inference on a dataset. +# This script: +# 1. Checks if ChatTS server is running, and starts it if not +# 2. Runs the ChatTS inference script on the specified dataset +# 3. Creates output directories and saves results +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/scripts +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # project root + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +# Defaults (can be overridden with command line arguments) +DATASET_PATH="${PROJECT_ROOT}/dataset/processed/dataset_a_with_mcq2.json" +OUTPUT_DIR="${PROJECT_ROOT}/evaluation/results" +OUTPUT_NAME="chatts_results.json" + +# Server configuration +CHATTS_PORT=5000 +CHATTS_PID_FILE="/tmp/chatts_server_${CHATTS_PORT}.pid" +CHATTS_SERVER_URL="http://localhost:${CHATTS_PORT}" + +# Create required directories +mkdir -p "${OUTPUT_DIR}" +mkdir -p "${PROJECT_ROOT}/logs" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --dataset) + DATASET_PATH="$2" + shift 2 + ;; + --output) + OUTPUT_NAME="$2" + shift 2 + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --workers) + WORKERS="$2" + shift 2 + ;; + --max-samples) + MAX_SAMPLES="$2" + shift 2 + ;; + --help) + echo "Usage: $0 [--dataset PATH] [--output FILENAME] [--output-dir DIR] [--workers N] [--max-samples N]" + echo "" + echo "Options:" + echo " --dataset PATH Path to the dataset JSON file" + echo " --output FILENAME Name of the output JSON file (default: chatts_results.json)" + echo " --output-dir DIR Directory to save results (default: ../evaluation/results)" + echo " --workers N Number of parallel workers (default: 4)" + echo " --max-samples N Maximum number of samples to process (default: 200)" + echo "" + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Full paths for output +OUTPUT_PATH="${OUTPUT_DIR}/${OUTPUT_NAME}" +LOG_PATH="${PROJECT_ROOT}/logs/chatts_inference_$(date +%Y-%m-%d-%H-%M-%S).log" + +# Optional parameters for the inference script +OPTIONAL_PARAMS="" +if [[ -v WORKERS ]]; then + OPTIONAL_PARAMS="${OPTIONAL_PARAMS} --workers ${WORKERS}" +fi + +if [[ -v MAX_SAMPLES ]]; then + OPTIONAL_PARAMS="${OPTIONAL_PARAMS} --max_samples ${MAX_SAMPLES}" +fi + +echo "========================================" +echo "Configuration:" +echo "DATASET_PATH = ${DATASET_PATH}" +echo "OUTPUT_PATH = ${OUTPUT_PATH}" +echo "LOG_PATH = ${LOG_PATH}" +echo "SERVER_URL = ${CHATTS_SERVER_URL}" +echo "OPTIONAL_PARAMS = ${OPTIONAL_PARAMS}" +echo "========================================" +echo "" + +# ── Check if ChatTS server is running ──────────────────────────────────────── +check_server() { + echo "Checking if ChatTS server is running..." + + # Check if PID file exists + if [ -f "$CHATTS_PID_FILE" ]; then + PID=$(cat "$CHATTS_PID_FILE") + echo "Found PID file with PID: $PID" + + # Check if process is actually running + if kill -0 "$PID" 2>/dev/null; then + echo "ChatTS server is running with PID: $PID" + return 0 + else + echo "PID file exists but process is not running" + rm -f "$CHATTS_PID_FILE" + fi + fi + + # Check if port is in use (server might be running without PID file) + if nc -z localhost "$CHATTS_PORT" 2>/dev/null; then + echo "Port $CHATTS_PORT is in use, assuming ChatTS server is running" + return 0 + fi + + echo "ChatTS server is not running" + return 1 +} + +# ── Start ChatTS server if needed ────────────────────────────────────────────── +if check_server; then + echo "Using existing ChatTS server" +else + echo "Starting ChatTS server..." + CHATTS_UTILS_DIR="${PROJECT_ROOT}/src/chatts_utils" + + if [ ! -x "${CHATTS_UTILS_DIR}/start_chatts_server.sh" ]; then + echo "Error: ChatTS server start script not found or not executable: ${CHATTS_UTILS_DIR}/start_chatts_server.sh" + exit 1 + fi + + # Start the server + echo "Running: ${CHATTS_UTILS_DIR}/start_chatts_server.sh" + "${CHATTS_UTILS_DIR}/start_chatts_server.sh" + + # Check if server started successfully + if ! check_server; then + echo "Error: Failed to start ChatTS server" + exit 1 + fi +fi + +# ── Run ChatTS inference script ─────────────────────────────────────────────── +echo "Running ChatTS inference..." +echo "Dataset: ${DATASET_PATH}" +echo "Output: ${OUTPUT_PATH}" +echo "Log: ${LOG_PATH}" + +# Ensure inference script is executable +INFERENCE_SCRIPT="${PROJECT_ROOT}/src/chatts_inference.py" +if [ ! -x "$INFERENCE_SCRIPT" ]; then + chmod +x "$INFERENCE_SCRIPT" +fi + +# Run the inference script +echo "Command: python $INFERENCE_SCRIPT --dataset_path $DATASET_PATH --output_path $OUTPUT_PATH --server_url $CHATTS_SERVER_URL $OPTIONAL_PARAMS" +python "$INFERENCE_SCRIPT" \ + --dataset_path "$DATASET_PATH" \ + --output_path "$OUTPUT_PATH" \ + --server_url "$CHATTS_SERVER_URL" \ + $OPTIONAL_PARAMS | tee -a "$LOG_PATH" + +# Check if inference completed successfully +if [ $? -eq 0 ]; then + echo "========================================" + echo "ChatTS inference completed successfully!" + echo "Results saved to: ${OUTPUT_PATH}" + echo "Log saved to: ${LOG_PATH}" + echo "========================================" +else + echo "========================================" + echo "ChatTS inference failed with an error" + echo "Check the log for details: ${LOG_PATH}" + echo "========================================" + exit 1 +fi \ No newline at end of file diff --git a/scripts/run_claude_inference.sh b/scripts/run_claude_inference.sh new file mode 100755 index 0000000..4850091 --- /dev/null +++ b/scripts/run_claude_inference.sh @@ -0,0 +1,144 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# run_claude_inference.sh +# +# Runs Claude inference on a dataset. +# Handles setting up proper directories and configurations. +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/scripts +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # project root + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +# Defaults (can be overridden with command line arguments) +DATASET_PATH="${PROJECT_ROOT}/dataset/processed/dataset_a_with_mcq2.json" +OUTPUT_DIR="${PROJECT_ROOT}/evaluation/results" +TEXT_ONLY=false +MAX_SAMPLES=200 +WORKERS=4 + +# Create required directories +mkdir -p "${OUTPUT_DIR}" +mkdir -p "${PROJECT_ROOT}/logs" + +# Display usage information +show_help() { + echo "Usage: $0 [--dataset PATH] [--output-dir DIR] [--output-name NAME] [--workers N] [--max-samples N] [--text-only]" + echo "" + echo "Options:" + echo " --dataset PATH Path to the dataset JSON file" + echo " --output-dir DIR Directory to save results (default: ../evaluation/results)" + echo " --output-name NAME Name for the output JSON file (default: based on mode)" + echo " --workers N Number of parallel workers (default: 4)" + echo " --max-samples N Maximum number of samples to process (default: 200)" + echo " --text-only Run inference with text-only mode (for VLM)" + echo "" + exit 0 +} + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --dataset) + DATASET_PATH="$2" + shift 2 + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --output-name) + OUTPUT_NAME="$2" + shift 2 + ;; + --workers) + WORKERS="$2" + shift 2 + ;; + --max-samples) + MAX_SAMPLES="$2" + shift 2 + ;; + --text-only) + TEXT_ONLY=true + shift + ;; + --help) + show_help + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Determine output filename based on mode and other settings +if [[ -z "${OUTPUT_NAME-}" ]]; then + BASE_NAME="claude" + + # Add text-only suffix if used + if [[ "$TEXT_ONLY" == true ]]; then + BASE_NAME="${BASE_NAME}_text_only" + fi + + OUTPUT_NAME="${BASE_NAME}_results.json" +fi + +# Full paths for output +OUTPUT_PATH="${OUTPUT_DIR}/${OUTPUT_NAME}" +LOG_PATH="${PROJECT_ROOT}/logs/${BASE_NAME}_$(date +%Y-%m-%d-%H-%M-%S).log" + +echo "========================================" +echo "Configuration:" +echo "DATASET_PATH = ${DATASET_PATH}" +echo "OUTPUT_PATH = ${OUTPUT_PATH}" +echo "LOG_PATH = ${LOG_PATH}" +echo "TEXT_ONLY = ${TEXT_ONLY}" +echo "MAX_SAMPLES = ${MAX_SAMPLES}" +echo "WORKERS = ${WORKERS}" +echo "========================================" +echo "" + +# ── Run Claude inference script ─────────────────────────────────────────────── +# Construct the command +SCRIPT="${PROJECT_ROOT}/src/claude_thinking_inference.py" + +COMMAND="python ${SCRIPT} \ + --dataset_path ${DATASET_PATH} \ + --output_path ${OUTPUT_PATH} \ + --max_samples ${MAX_SAMPLES} \ + --workers ${WORKERS}" + +# Add optional arguments +if [[ "$TEXT_ONLY" == true ]]; then + COMMAND="${COMMAND} --text_only" +fi + +echo "Running command: ${COMMAND}" +eval "${COMMAND}" | tee -a "${LOG_PATH}" + +# Check if inference completed successfully +if [ $? -eq 0 ]; then + echo "========================================" + echo "Claude inference completed successfully!" + echo "Results saved to: ${OUTPUT_PATH}" + echo "Log saved to: ${LOG_PATH}" + echo "========================================" +else + echo "========================================" + echo "Claude inference failed with an error" + echo "Check the log for details: ${LOG_PATH}" + echo "========================================" + exit 1 +fi \ No newline at end of file diff --git a/scripts/run_injection_workflow.sh b/scripts/run_injection_workflow.sh new file mode 100755 index 0000000..38e02d0 --- /dev/null +++ b/scripts/run_injection_workflow.sh @@ -0,0 +1,302 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# run_injection_workflow.sh +# +# A script to run the Claude with ChatTS injection pipeline for time +# series analysis. +# +# This pipeline: +# 1. Starts the ChatTS server if not already running +# 2. Runs Claude with ChatTS injection, where ChatTS injects knowledge at +# the beginning of Claude's thinking process +# 3. Evaluates the results using the evaluation script +# 4. Optionally stops the server when complete +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/scripts +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # project root + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +# Defaults (can be overridden with command line arguments) +DATASET_PATH="${PROJECT_ROOT}/dataset/dataset_a_with_mcq2.json" +OUTPUT_DIR="${PROJECT_ROOT}/evaluation/results" +NUM_RUNS=1 +MAX_SAMPLES=200 +CLAUDE_WORKERS=10 +CHATTS_WORKERS=10 +EVAL_WORKERS=4 +CHATTS_PORT=5000 +CHATTS_SERVER_URL="http://localhost:${CHATTS_PORT}" +STOP_SERVER=false + +# Create required directories +mkdir -p "${OUTPUT_DIR}" + +# Display usage information +show_help() { + echo "Usage: $0 [--dataset PATH] [--output-dir DIR] [--num-runs N] [--workers N] [--max-samples N] [--stop-server]" + echo "" + echo "Options:" + echo " --dataset PATH Path to the dataset JSON file" + echo " --output-dir DIR Directory to save results (default: ../evaluation/results)" + echo " --num-runs N Number of runs to perform (default: 1)" + echo " --claude-workers N Number of Claude API parallel workers (default: 10)" + echo " --chatts-workers N Number of ChatTS API parallel workers (default: 10)" + echo " --eval-workers N Number of evaluation parallel workers (default: 4)" + echo " --max-samples N Maximum number of samples to process (default: 200)" + echo " --chatts-port N Port for ChatTS server (default: 5000)" + echo " --stop-server Stop the ChatTS server when done" + echo "" + exit 0 +} + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --dataset) + DATASET_PATH="$2" + shift 2 + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --num-runs) + NUM_RUNS="$2" + shift 2 + ;; + --claude-workers) + CLAUDE_WORKERS="$2" + shift 2 + ;; + --chatts-workers) + CHATTS_WORKERS="$2" + shift 2 + ;; + --eval-workers) + EVAL_WORKERS="$2" + shift 2 + ;; + --max-samples) + MAX_SAMPLES="$2" + shift 2 + ;; + --chatts-port) + CHATTS_PORT="$2" + CHATTS_SERVER_URL="http://localhost:${CHATTS_PORT}" + shift 2 + ;; + --stop-server) + STOP_SERVER=true + shift + ;; + --help) + show_help + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Extract dataset name without extension and path +DATASET_NAME=$(basename "$DATASET_PATH" .json) + +# Create log directory +mkdir -p "${PROJECT_ROOT}/logs" +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +LOG_PATH="${PROJECT_ROOT}/logs/injection_workflow_${TIMESTAMP}.log" + +echo "========================================" | tee -a "$LOG_PATH" +echo "Configuration:" | tee -a "$LOG_PATH" +echo "DATASET_PATH = ${DATASET_PATH}" | tee -a "$LOG_PATH" +echo "OUTPUT_DIR = ${OUTPUT_DIR}" | tee -a "$LOG_PATH" +echo "NUM_RUNS = ${NUM_RUNS}" | tee -a "$LOG_PATH" +echo "MAX_SAMPLES = ${MAX_SAMPLES}" | tee -a "$LOG_PATH" +echo "CLAUDE_WORKERS = ${CLAUDE_WORKERS}" | tee -a "$LOG_PATH" +echo "CHATTS_WORKERS = ${CHATTS_WORKERS}" | tee -a "$LOG_PATH" +echo "EVAL_WORKERS = ${EVAL_WORKERS}" | tee -a "$LOG_PATH" +echo "CHATTS_SERVER_URL = ${CHATTS_SERVER_URL}" | tee -a "$LOG_PATH" +echo "LOG_PATH = ${LOG_PATH}" | tee -a "$LOG_PATH" +echo "========================================" | tee -a "$LOG_PATH" +echo "" | tee -a "$LOG_PATH" + +# ── Initialize Conda in this shell ──────────── +export MKL_INTERFACE_LAYER=${MKL_INTERFACE_LAYER:-LP64} +if [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +elif [ -f "$(conda info --base)/etc/profile.d/conda.sh" ]; then + source "$(conda info --base)/etc/profile.d/conda.sh" +else + echo "ERROR: Cannot find conda.sh. Do you need to run 'conda init'?" | tee -a "$LOG_PATH" + exit 1 +fi + +# ------------------------------------------------------------------------------ +# 1) Record start time and verify ChatTS server is running +# ------------------------------------------------------------------------------ +# Record start time +STARTTIME=$(date +%s) + +echo "Checking if ChatTS server is running..." | tee -a "$LOG_PATH" + +if nc -z localhost $CHATTS_PORT 2>/dev/null; then + echo "ChatTS server running on port $CHATTS_PORT" | tee -a "$LOG_PATH" +else + echo "ERROR: ChatTS server is not running on port $CHATTS_PORT" | tee -a "$LOG_PATH" + echo "Please start it first with: src/chatts_utils/start_chatts_server.sh" | tee -a "$LOG_PATH" + exit 1 +fi + +# Run the pipeline multiple times with different output paths +for RUN_NUM in $(seq 1 $NUM_RUNS); do + echo "" | tee -a "$LOG_PATH" + echo "========================================" | tee -a "$LOG_PATH" + echo "Starting Injection Run #$RUN_NUM of $NUM_RUNS" | tee -a "$LOG_PATH" + echo "========================================" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" + + # Define output paths for this run + RUN_SUFFIX="run${RUN_NUM}" + + # Output paths with run number suffix + INJECTION_OUT="${OUTPUT_DIR}/claude-injection-${DATASET_NAME}-${RUN_SUFFIX}/generated_answer.json" + + # Create output directory + mkdir -p "$(dirname "$INJECTION_OUT")" + + # ------------------------------------------------------------------------------ + # 2) Generate ChatTS observations (shared across all runs if cached) + # ------------------------------------------------------------------------------ + CHATTS_OBS_DIR="${OUTPUT_DIR}/chatts-observations-${DATASET_NAME}" + mkdir -p "$CHATTS_OBS_DIR" + CHATTS_OBS="${CHATTS_OBS_DIR}/generated_answer.json" + + if [ -f "$CHATTS_OBS" ]; then + echo "Found existing ChatTS observations at $CHATTS_OBS" | tee -a "$LOG_PATH" + echo "Reusing existing observations for injection" | tee -a "$LOG_PATH" + else + echo "Generating ChatTS observations..." | tee -a "$LOG_PATH" + + conda activate evaluation + export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + + python "${PROJECT_ROOT}/src/chatts_injection.py" \ + --dataset_path "$DATASET_PATH" \ + --output_path "$CHATTS_OBS" \ + --server_url "$CHATTS_SERVER_URL" \ + --workers $CHATTS_WORKERS | tee -a "$LOG_PATH" + + echo "ChatTS observations saved to $CHATTS_OBS" | tee -a "$LOG_PATH" + fi + + # ------------------------------------------------------------------------------ + # 3) Run Claude with ChatTS injection + # ------------------------------------------------------------------------------ + echo "Running Claude with ChatTS injection (Run #$RUN_NUM)" | tee -a "$LOG_PATH" + echo " Dataset : $DATASET_PATH" | tee -a "$LOG_PATH" + echo " Injection : $CHATTS_OBS" | tee -a "$LOG_PATH" + echo " Output : $INJECTION_OUT" | tee -a "$LOG_PATH" + + conda activate evaluation + export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + + python "${PROJECT_ROOT}/src/claude_thinking_with_injection.py" \ + --dataset_path "$DATASET_PATH" \ + --injection_path "$CHATTS_OBS" \ + --output_path "$INJECTION_OUT" \ + --workers $CLAUDE_WORKERS | tee -a "$LOG_PATH" + + echo "Run #$RUN_NUM: Injection results saved to $INJECTION_OUT" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" + + # ------------------------------------------------------------------------------ + # 4) Evaluate injection results + # ------------------------------------------------------------------------------ + echo "Evaluating injection results (Run #$RUN_NUM)" | tee -a "$LOG_PATH" + + # Extract experiment name + INJECTION_EXP_NAME=$(basename "$(dirname "$INJECTION_OUT")") + + # Ensure we're in the evaluation environment + conda activate evaluation + + # Set PYTHONPATH to include the project root directory + export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + + # Evaluate injection results + echo " Evaluating injection results..." | tee -a "$LOG_PATH" + python "${PROJECT_ROOT}/evaluation/evaluate_with_sampling.py" \ + --exp "$INJECTION_EXP_NAME" \ + --dataset "$DATASET_PATH" \ + --generated "$INJECTION_OUT" \ + --num_workers $EVAL_WORKERS | tee -a "$LOG_PATH" + + echo "Run #$RUN_NUM: Injection evaluation complete" | tee -a "$LOG_PATH" + echo " Results in ${PROJECT_ROOT}/evaluation/exp/$INJECTION_EXP_NAME/" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" + + echo "Run #$RUN_NUM Complete!" | tee -a "$LOG_PATH" + echo "Results:" | tee -a "$LOG_PATH" + echo " - Injection results: $INJECTION_OUT" | tee -a "$LOG_PATH" + echo " - Injection evaluation: ${PROJECT_ROOT}/evaluation/exp/$INJECTION_EXP_NAME/" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" +done + +# ------------------------------------------------------------------------------ +# 4) Runtime Summary +# ------------------------------------------------------------------------------ +ENDTIME=$(date +%s) +RUNTIME=$((ENDTIME - STARTTIME)) +echo "=======================================" | tee -a "$LOG_PATH" +echo "Total runtime: $RUNTIME seconds ($(($RUNTIME / 60)) minutes)" | tee -a "$LOG_PATH" +echo "=======================================" | tee -a "$LOG_PATH" + +# ------------------------------------------------------------------------------ +# 5) Summary of all runs +# ------------------------------------------------------------------------------ +echo "========================================" | tee -a "$LOG_PATH" +echo "All $NUM_RUNS Runs Completed Successfully" | tee -a "$LOG_PATH" +echo "========================================" | tee -a "$LOG_PATH" + +echo "Summary of result locations:" | tee -a "$LOG_PATH" +for RUN_NUM in $(seq 1 $NUM_RUNS); do + RUN_SUFFIX="run${RUN_NUM}" + + INJECTION_OUT="${OUTPUT_DIR}/claude-injection-${DATASET_NAME}-${RUN_SUFFIX}/generated_answer.json" + INJECTION_EXP_NAME=$(basename "$(dirname "$INJECTION_OUT")") + + echo "Run #$RUN_NUM:" | tee -a "$LOG_PATH" + echo " - Injection results: $INJECTION_OUT" | tee -a "$LOG_PATH" + echo " - Injection evaluation: ${PROJECT_ROOT}/evaluation/exp/$INJECTION_EXP_NAME/" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" +done + +# ------------------------------------------------------------------------------ +# 6) Stop server if requested +# ------------------------------------------------------------------------------ +if [ "$STOP_SERVER" = true ]; then + echo "Stopping ChatTS server..." | tee -a "$LOG_PATH" + "${PROJECT_ROOT}/src/chatts_utils/stop_chatts_server.sh" || true + echo "ChatTS server stopped" | tee -a "$LOG_PATH" +else + echo "" | tee -a "$LOG_PATH" + echo "NOTE: The ChatTS server is still running." | tee -a "$LOG_PATH" + echo "When you are done with all evaluations, stop it using:" | tee -a "$LOG_PATH" + echo " ${PROJECT_ROOT}/src/chatts_utils/stop_chatts_server.sh" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" +fi + +echo "Log saved to: ${LOG_PATH}" | tee -a "$LOG_PATH" +echo "Workflow complete!" | tee -a "$LOG_PATH" diff --git a/scripts/run_iterative_generation.sh b/scripts/run_iterative_generation.sh new file mode 100755 index 0000000..6b96dd9 --- /dev/null +++ b/scripts/run_iterative_generation.sh @@ -0,0 +1,109 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# run_iterative_generation.sh +# +# Stage 1: Iterative synthetic data generation using Claude. +# Uses Claude (via AWS Bedrock) to iteratively generate Python code that models +# anomaly patterns from real training data. +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/scripts +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # project root +SYNTHETIC_DIR="${PROJECT_ROOT}/dataset/synthetic" + +echo "========================================" +echo "Stage 1: Iterative Synthetic Data Generation" +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "SYNTHETIC_DIR = $SYNTHETIC_DIR" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +DATASET_PATH="${SYNTHETIC_DIR}/sample_data/qa_benchmark_base_train.json" +OUTPUT_DIR="${SYNTHETIC_DIR}/results/iterative_results" +ITERATIONS=3 +MAX_WORKERS=10 +REGION="us-west-2" + +# ── Parse command line arguments ───────────────────────────────────────────── +SAMPLE_IDS=() +while [[ $# -gt 0 ]]; do + case $1 in + --dataset) + DATASET_PATH="$2" + shift 2 + ;; + --output_dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --iterations) + ITERATIONS="$2" + shift 2 + ;; + --max_workers) + MAX_WORKERS="$2" + shift 2 + ;; + --region) + REGION="$2" + shift 2 + ;; + --sample_ids) + shift + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + SAMPLE_IDS+=("$1") + shift + done + ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + +echo "Dataset: $DATASET_PATH" +echo "Output: $OUTPUT_DIR" +echo "Iterations: $ITERATIONS" +echo "Max workers: $MAX_WORKERS" +echo "Region: $REGION" +if [ ${#SAMPLE_IDS[@]} -gt 0 ]; then + echo "Sample IDs: ${SAMPLE_IDS[*]}" +else + echo "Sample IDs: (all what_happened samples)" +fi +echo "Start time: $(date)" +echo "" + +# ── Create output directory ────────────────────────────────────────────────── +mkdir -p "${OUTPUT_DIR}" + +# ── Build Python command ───────────────────────────────────────────────────── +CMD=( + python "${SYNTHETIC_DIR}/iterative_ts_generation.py" + --dataset_path="${DATASET_PATH}" + --output_dir="${OUTPUT_DIR}" + --iterations="${ITERATIONS}" + --max_workers="${MAX_WORKERS}" + --region="${REGION}" + --thinking +) + +if [ ${#SAMPLE_IDS[@]} -gt 0 ]; then + CMD+=(--sample_ids "${SAMPLE_IDS[@]}") +fi + +# ── Run ────────────────────────────────────────────────────────────────────── +"${CMD[@]}" + +echo "" +echo "========================================" +echo "Stage 1 Complete" +echo "Results: ${OUTPUT_DIR}" +echo "End time: $(date)" +echo "========================================" diff --git a/scripts/run_qwen3_injection_workflow.sh b/scripts/run_qwen3_injection_workflow.sh new file mode 100755 index 0000000..86fa2ac --- /dev/null +++ b/scripts/run_qwen3_injection_workflow.sh @@ -0,0 +1,330 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# run_qwen3_injection_workflow.sh +# +# A script to run the Qwen3 with Qwen-VL injection pipeline for time +# series analysis. +# +# This pipeline: +# 1. Ensures Qwen-VL server is running and generates VL observations +# 2. Ensures Qwen3 server is running +# 3. Runs Qwen3 with Qwen-VL injection (thoughts + answer injected into +# Qwen3's thinking via continue_final_message) +# 4. Evaluates the results using the evaluation script +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/scripts +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # project root + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +# Defaults (can be overridden with command line arguments) +DATASET_PATH="${PROJECT_ROOT}/dataset/dataset_a_with_mcq2.json" +OUTPUT_DIR="${PROJECT_ROOT}/evaluation/results" +NUM_RUNS=1 +MAX_SAMPLES=200 +QWEN3_WORKERS=10 +QWEN_VL_WORKERS=4 +EVAL_WORKERS=4 + +# Server configuration +QWEN_VL_PORT=5003 +QWEN_VL_SERVER_URL="http://localhost:${QWEN_VL_PORT}" +QWEN3_PORT=5001 +QWEN3_SERVER_URL="http://localhost:${QWEN3_PORT}" + +STOP_SERVER=false + +# Create required directories +mkdir -p "${OUTPUT_DIR}" + +# Display usage information +show_help() { + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " --dataset PATH Path to the dataset JSON file" + echo " --output-dir DIR Directory to save results (default: ../evaluation/results)" + echo " --num-runs N Number of runs to perform (default: 1)" + echo " --qwen3-workers N Number of Qwen3 API parallel workers (default: 10)" + echo " --qwen-vl-workers N Number of Qwen-VL parallel workers (default: 4)" + echo " --eval-workers N Number of evaluation parallel workers (default: 4)" + echo " --max-samples N Maximum number of samples to process (default: 200)" + echo " --qwen-vl-port N Port for Qwen-VL server (default: 5003)" + echo " --qwen3-port N Port for Qwen3 server (default: 5001)" + echo " --stop-server Stop servers when done" + echo "" + exit 0 +} + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --dataset) + DATASET_PATH="$2" + shift 2 + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --num-runs) + NUM_RUNS="$2" + shift 2 + ;; + --qwen3-workers) + QWEN3_WORKERS="$2" + shift 2 + ;; + --qwen-vl-workers) + QWEN_VL_WORKERS="$2" + shift 2 + ;; + --eval-workers) + EVAL_WORKERS="$2" + shift 2 + ;; + --max-samples) + MAX_SAMPLES="$2" + shift 2 + ;; + --qwen-vl-port) + QWEN_VL_PORT="$2" + QWEN_VL_SERVER_URL="http://localhost:${QWEN_VL_PORT}" + shift 2 + ;; + --qwen3-port) + QWEN3_PORT="$2" + QWEN3_SERVER_URL="http://localhost:${QWEN3_PORT}" + shift 2 + ;; + --stop-server) + STOP_SERVER=true + shift + ;; + --help) + show_help + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Extract dataset name without extension and path +DATASET_NAME=$(basename "$DATASET_PATH" .json) + +# Create log directory +mkdir -p "${PROJECT_ROOT}/logs" +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +LOG_PATH="${PROJECT_ROOT}/logs/qwen3_injection_workflow_${TIMESTAMP}.log" + +echo "========================================" | tee -a "$LOG_PATH" +echo "Configuration:" | tee -a "$LOG_PATH" +echo "DATASET_PATH = ${DATASET_PATH}" | tee -a "$LOG_PATH" +echo "OUTPUT_DIR = ${OUTPUT_DIR}" | tee -a "$LOG_PATH" +echo "NUM_RUNS = ${NUM_RUNS}" | tee -a "$LOG_PATH" +echo "MAX_SAMPLES = ${MAX_SAMPLES}" | tee -a "$LOG_PATH" +echo "QWEN3_WORKERS = ${QWEN3_WORKERS}" | tee -a "$LOG_PATH" +echo "QWEN_VL_WORKERS = ${QWEN_VL_WORKERS}" | tee -a "$LOG_PATH" +echo "EVAL_WORKERS = ${EVAL_WORKERS}" | tee -a "$LOG_PATH" +echo "QWEN_VL_SERVER_URL = ${QWEN_VL_SERVER_URL}" | tee -a "$LOG_PATH" +echo "QWEN3_SERVER_URL = ${QWEN3_SERVER_URL}" | tee -a "$LOG_PATH" +echo "LOG_PATH = ${LOG_PATH}" | tee -a "$LOG_PATH" +echo "========================================" | tee -a "$LOG_PATH" +echo "" | tee -a "$LOG_PATH" + +# ── Initialize Conda in this shell ──────────── +export MKL_INTERFACE_LAYER=${MKL_INTERFACE_LAYER:-LP64} +if [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +elif [ -f "$(conda info --base)/etc/profile.d/conda.sh" ]; then + source "$(conda info --base)/etc/profile.d/conda.sh" +else + echo "ERROR: Cannot find conda.sh. Do you need to run 'conda init'?" | tee -a "$LOG_PATH" + exit 1 +fi + +# ------------------------------------------------------------------------------ +# 1) Record start time and verify servers are running +# ------------------------------------------------------------------------------ +STARTTIME=$(date +%s) + +# Check Qwen-VL server +echo "Checking if Qwen-VL server is running..." | tee -a "$LOG_PATH" +if nc -z localhost $QWEN_VL_PORT 2>/dev/null; then + echo "Qwen-VL server running on port $QWEN_VL_PORT" | tee -a "$LOG_PATH" +else + echo "ERROR: Qwen-VL server is not running on port $QWEN_VL_PORT" | tee -a "$LOG_PATH" + echo "Please start it first with: src/qwen_utils/start_qwen_vl_server.sh" | tee -a "$LOG_PATH" + exit 1 +fi + +# Check Qwen3 server +echo "Checking if Qwen3 server is running..." | tee -a "$LOG_PATH" +if nc -z localhost $QWEN3_PORT 2>/dev/null; then + echo "Qwen3 server running on port $QWEN3_PORT" | tee -a "$LOG_PATH" +else + echo "ERROR: Qwen3 server is not running on port $QWEN3_PORT" | tee -a "$LOG_PATH" + echo "Please start it first with: src/qwen3_utils/start_qwen3_server.sh" | tee -a "$LOG_PATH" + exit 1 +fi + +# ------------------------------------------------------------------------------ +# 2) Generate Qwen-VL observations (shared across all runs) +# ------------------------------------------------------------------------------ +QWEN_VL_OUT_DIR="${OUTPUT_DIR}/qwen-vl-thinking-${DATASET_NAME}" +mkdir -p "$QWEN_VL_OUT_DIR" +QWEN_VL_OUT="${QWEN_VL_OUT_DIR}/generated_answer.json" + +if [ -f "$QWEN_VL_OUT" ]; then + echo "Found existing Qwen-VL observations at $QWEN_VL_OUT" | tee -a "$LOG_PATH" + echo "Reusing existing observations for injection" | tee -a "$LOG_PATH" +else + echo "Generating Qwen-VL observations..." | tee -a "$LOG_PATH" + + # Activate environment + conda activate evaluation + + # Set PYTHONPATH + export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + + python "${PROJECT_ROOT}/src/qwen_inference.py" \ + --dataset_path "$DATASET_PATH" \ + --output_path "$QWEN_VL_OUT" \ + --server_url "$QWEN_VL_SERVER_URL" \ + --workers $QWEN_VL_WORKERS \ + --max_samples $MAX_SAMPLES | tee -a "$LOG_PATH" + + echo "Qwen-VL observations saved to $QWEN_VL_OUT" | tee -a "$LOG_PATH" +fi + +# ------------------------------------------------------------------------------ +# 3) Run Qwen3 with injection for each run +# ------------------------------------------------------------------------------ +for RUN_NUM in $(seq 1 $NUM_RUNS); do + echo "" | tee -a "$LOG_PATH" + echo "========================================" | tee -a "$LOG_PATH" + echo "Starting Qwen3 Injection Run #$RUN_NUM of $NUM_RUNS" | tee -a "$LOG_PATH" + echo "========================================" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" + + # Define output paths for this run + RUN_SUFFIX="run${RUN_NUM}" + INJECTION_OUT_DIR="${OUTPUT_DIR}/qwen3-injection-${DATASET_NAME}-${RUN_SUFFIX}" + mkdir -p "$INJECTION_OUT_DIR" + INJECTION_OUT="${INJECTION_OUT_DIR}/generated_answer.json" + + # Activate Qwen3 environment + echo "Activating qwen3-vllm environment..." | tee -a "$LOG_PATH" + conda activate qwen3-vllm + + # Set PYTHONPATH + export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + + # Run Qwen3 with injection + echo "Running Qwen3 with Qwen-VL injection (Run #$RUN_NUM)" | tee -a "$LOG_PATH" + echo " Dataset : $DATASET_PATH" | tee -a "$LOG_PATH" + echo " Injection : $QWEN_VL_OUT" | tee -a "$LOG_PATH" + echo " Output : $INJECTION_OUT" | tee -a "$LOG_PATH" + + python "${PROJECT_ROOT}/src/qwen3_with_injection.py" \ + --dataset_path "$DATASET_PATH" \ + --injection_path "$QWEN_VL_OUT" \ + --output_path "$INJECTION_OUT" \ + --server_url "$QWEN3_SERVER_URL" \ + --workers $QWEN3_WORKERS | tee -a "$LOG_PATH" + + echo "Run #$RUN_NUM: Injection results saved to $INJECTION_OUT" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" + + # -------------------------------------------------------------------------- + # 4) Evaluate injection results + # -------------------------------------------------------------------------- + echo "Evaluating injection results (Run #$RUN_NUM)" | tee -a "$LOG_PATH" + + INJECTION_EXP_NAME=$(basename "$INJECTION_OUT_DIR") + + # Switch to evaluation environment + conda activate evaluation + + export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + + echo " Evaluating Qwen3 injection results..." | tee -a "$LOG_PATH" + python "${PROJECT_ROOT}/evaluation/evaluate_with_sampling.py" \ + --exp "$INJECTION_EXP_NAME" \ + --dataset "$DATASET_PATH" \ + --generated "$INJECTION_OUT" \ + --num_workers $EVAL_WORKERS | tee -a "$LOG_PATH" + + echo "Run #$RUN_NUM: Evaluation complete" | tee -a "$LOG_PATH" + echo " Results in ${PROJECT_ROOT}/evaluation/exp/$INJECTION_EXP_NAME/" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" + + echo "Run #$RUN_NUM Complete!" | tee -a "$LOG_PATH" + echo "Results:" | tee -a "$LOG_PATH" + echo " - Injection results: $INJECTION_OUT" | tee -a "$LOG_PATH" + echo " - Evaluation: ${PROJECT_ROOT}/evaluation/exp/$INJECTION_EXP_NAME/" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" +done + +# ------------------------------------------------------------------------------ +# 5) Runtime Summary +# ------------------------------------------------------------------------------ +ENDTIME=$(date +%s) +RUNTIME=$((ENDTIME - STARTTIME)) +echo "=======================================" | tee -a "$LOG_PATH" +echo "Total runtime: $RUNTIME seconds ($(($RUNTIME / 60)) minutes)" | tee -a "$LOG_PATH" +echo "=======================================" | tee -a "$LOG_PATH" + +# ------------------------------------------------------------------------------ +# 6) Summary of all runs +# ------------------------------------------------------------------------------ +echo "========================================" | tee -a "$LOG_PATH" +echo "All $NUM_RUNS Runs Completed Successfully" | tee -a "$LOG_PATH" +echo "========================================" | tee -a "$LOG_PATH" + +echo "Summary of result locations:" | tee -a "$LOG_PATH" +echo " - Qwen-VL observations: $QWEN_VL_OUT" | tee -a "$LOG_PATH" + +for RUN_NUM in $(seq 1 $NUM_RUNS); do + RUN_SUFFIX="run${RUN_NUM}" + INJECTION_OUT_DIR="${OUTPUT_DIR}/qwen3-injection-${DATASET_NAME}-${RUN_SUFFIX}" + INJECTION_OUT="${INJECTION_OUT_DIR}/generated_answer.json" + INJECTION_EXP_NAME=$(basename "$INJECTION_OUT_DIR") + + echo "Run #$RUN_NUM:" | tee -a "$LOG_PATH" + echo " - Injection results: $INJECTION_OUT" | tee -a "$LOG_PATH" + echo " - Evaluation: ${PROJECT_ROOT}/evaluation/exp/$INJECTION_EXP_NAME/" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" +done + +# ------------------------------------------------------------------------------ +# 7) Stop servers if requested +# ------------------------------------------------------------------------------ +if [ "$STOP_SERVER" = true ]; then + echo "Stopping servers..." | tee -a "$LOG_PATH" + "${PROJECT_ROOT}/src/qwen_utils/stop_qwen_vl_server.sh" || true + "${PROJECT_ROOT}/src/qwen3_utils/stop_qwen3_server.sh" || true + echo "Servers stopped" | tee -a "$LOG_PATH" +else + echo "" | tee -a "$LOG_PATH" + echo "NOTE: Both the Qwen-VL and Qwen3 servers are still running." | tee -a "$LOG_PATH" + echo "When you are done with all evaluations, stop them using:" | tee -a "$LOG_PATH" + echo " ${PROJECT_ROOT}/src/qwen_utils/stop_qwen_vl_server.sh" | tee -a "$LOG_PATH" + echo " ${PROJECT_ROOT}/src/qwen3_utils/stop_qwen3_server.sh" | tee -a "$LOG_PATH" + echo "" | tee -a "$LOG_PATH" +fi + +echo "Log saved to: ${LOG_PATH}" | tee -a "$LOG_PATH" +echo "Workflow complete!" | tee -a "$LOG_PATH" diff --git a/scripts/run_qwen_inference.sh b/scripts/run_qwen_inference.sh new file mode 100755 index 0000000..b0113c4 --- /dev/null +++ b/scripts/run_qwen_inference.sh @@ -0,0 +1,191 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# run_qwen_inference.sh +# +# Script to run Qwen2.5-VL inference on a dataset. +# This script: +# 1. Checks if Qwen2.5-VL server is running, and starts it if not +# 2. Runs the Qwen2.5-VL inference script on the specified dataset +# 3. Creates output directories and saves results +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/scripts +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # project root + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +# Defaults (can be overridden with command line arguments) +DATASET_PATH="${PROJECT_ROOT}/dataset/processed/dataset_a_with_mcq2.json" +OUTPUT_DIR="${PROJECT_ROOT}/evaluation/results" +OUTPUT_NAME="qwen_vl_results.json" + +# Server configuration +QWEN_PORT=5003 +QWEN_PID_FILE="/tmp/qwen_vl_server_${QWEN_PORT}.pid" +QWEN_SERVER_URL="http://localhost:${QWEN_PORT}" + +# Create required directories +mkdir -p "${OUTPUT_DIR}" +mkdir -p "${PROJECT_ROOT}/logs" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --dataset) + DATASET_PATH="$2" + shift 2 + ;; + --output) + OUTPUT_NAME="$2" + shift 2 + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --workers) + WORKERS="$2" + shift 2 + ;; + --max-samples) + MAX_SAMPLES="$2" + shift 2 + ;; + --help) + echo "Usage: $0 [--dataset PATH] [--output FILENAME] [--output-dir DIR] [--workers N] [--max-samples N]" + echo "" + echo "Options:" + echo " --dataset PATH Path to the dataset JSON file" + echo " --output FILENAME Name of the output JSON file (default: qwen_vl_results.json)" + echo " --output-dir DIR Directory to save results (default: ../evaluation/results)" + echo " --workers N Number of parallel workers (default: 4)" + echo " --max-samples N Maximum number of samples to process (default: 200)" + echo "" + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Full paths for output +OUTPUT_PATH="${OUTPUT_DIR}/${OUTPUT_NAME}" +LOG_PATH="${PROJECT_ROOT}/logs/qwen_inference_$(date +%Y-%m-%d-%H-%M-%S).log" + +# Optional parameters for the inference script +OPTIONAL_PARAMS="" +if [[ -v WORKERS ]]; then + OPTIONAL_PARAMS="${OPTIONAL_PARAMS} --workers ${WORKERS}" +fi + +if [[ -v MAX_SAMPLES ]]; then + OPTIONAL_PARAMS="${OPTIONAL_PARAMS} --max_samples ${MAX_SAMPLES}" +fi + +echo "========================================" +echo "Configuration:" +echo "DATASET_PATH = ${DATASET_PATH}" +echo "OUTPUT_PATH = ${OUTPUT_PATH}" +echo "LOG_PATH = ${LOG_PATH}" +echo "SERVER_URL = ${QWEN_SERVER_URL}" +echo "OPTIONAL_PARAMS = ${OPTIONAL_PARAMS}" +echo "========================================" +echo "" + +# ── Check if Qwen2.5-VL server is running ──────────────────────────────────── +check_server() { + echo "Checking if Qwen2.5-VL server is running..." + + # Check if PID file exists + if [ -f "$QWEN_PID_FILE" ]; then + PID=$(cat "$QWEN_PID_FILE") + echo "Found PID file with PID: $PID" + + # Check if process is actually running + if kill -0 "$PID" 2>/dev/null; then + echo "Qwen2.5-VL server is running with PID: $PID" + return 0 + else + echo "PID file exists but process is not running" + rm -f "$QWEN_PID_FILE" + fi + fi + + # Check if port is in use (server might be running without PID file) + if nc -z localhost "$QWEN_PORT" 2>/dev/null; then + echo "Port $QWEN_PORT is in use, assuming Qwen2.5-VL server is running" + return 0 + fi + + echo "Qwen2.5-VL server is not running" + return 1 +} + +# ── Start Qwen2.5-VL server if needed ──────────────────────────────────────── +if check_server; then + echo "Using existing Qwen2.5-VL server" +else + echo "Starting Qwen2.5-VL server..." + QWEN_UTILS_DIR="${PROJECT_ROOT}/src/qwen_utils" + + if [ ! -x "${QWEN_UTILS_DIR}/start_qwen_vl_server.sh" ]; then + echo "Error: Qwen2.5-VL server start script not found or not executable: ${QWEN_UTILS_DIR}/start_qwen_vl_server.sh" + exit 1 + fi + + # Start the server + echo "Running: ${QWEN_UTILS_DIR}/start_qwen_vl_server.sh" + "${QWEN_UTILS_DIR}/start_qwen_vl_server.sh" + + # Check if server started successfully + if ! check_server; then + echo "Error: Failed to start Qwen2.5-VL server" + exit 1 + fi +fi + +# ── Run Qwen2.5-VL inference script ───────────────────────────────────────── +echo "Running Qwen2.5-VL inference..." +echo "Dataset: ${DATASET_PATH}" +echo "Output: ${OUTPUT_PATH}" +echo "Log: ${LOG_PATH}" + +# Ensure inference script is executable +INFERENCE_SCRIPT="${PROJECT_ROOT}/src/qwen_inference.py" +if [ ! -x "$INFERENCE_SCRIPT" ]; then + chmod +x "$INFERENCE_SCRIPT" +fi + +# Run the inference script +echo "Command: python $INFERENCE_SCRIPT --dataset_path $DATASET_PATH --output_path $OUTPUT_PATH --server_url $QWEN_SERVER_URL $OPTIONAL_PARAMS" +python "$INFERENCE_SCRIPT" \ + --dataset_path "$DATASET_PATH" \ + --output_path "$OUTPUT_PATH" \ + --server_url "$QWEN_SERVER_URL" \ + $OPTIONAL_PARAMS | tee -a "$LOG_PATH" + +# Check if inference completed successfully +if [ $? -eq 0 ]; then + echo "========================================" + echo "Qwen2.5-VL inference completed successfully!" + echo "Results saved to: ${OUTPUT_PATH}" + echo "Log saved to: ${LOG_PATH}" + echo "========================================" +else + echo "========================================" + echo "Qwen2.5-VL inference failed with an error" + echo "Check the log for details: ${LOG_PATH}" + echo "========================================" + exit 1 +fi diff --git a/scripts/run_r1_injection_workflow.sh b/scripts/run_r1_injection_workflow.sh new file mode 100755 index 0000000..1697a3a --- /dev/null +++ b/scripts/run_r1_injection_workflow.sh @@ -0,0 +1,222 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# run_r1_injection_workflow.sh +# +# A script to run DeepSeek-R1 with Qwen-VL injection pipeline for time +# series analysis. Identical pattern to Qwen3 — DeepSeek-R1-Distill-Qwen-32B +# uses the same tokenizer, API, and continue_final_message injection. +# +# This pipeline: +# 1. Ensures Qwen-VL server is running and generates VL observations +# 2. Ensures DeepSeek-R1 server is running +# 3. Runs R1 with Qwen-VL injection +# 4. Evaluates the results +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +DATASET_PATH="${PROJECT_ROOT}/dataset/dataset_a_with_mcq2.json" +OUTPUT_DIR="${PROJECT_ROOT}/evaluation/results" +NUM_RUNS=1 +MAX_SAMPLES=200 +R1_WORKERS=10 +QWEN_VL_WORKERS=4 +EVAL_WORKERS=4 + +# Server configuration +QWEN_VL_PORT=5003 +QWEN_VL_SERVER_URL="http://localhost:${QWEN_VL_PORT}" +R1_PORT=5002 +R1_SERVER_URL="http://localhost:${R1_PORT}" + +STOP_SERVER=false + +mkdir -p "${OUTPUT_DIR}" + +show_help() { + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " --dataset PATH Path to the dataset JSON file" + echo " --output-dir DIR Directory to save results (default: ../evaluation/results)" + echo " --num-runs N Number of runs to perform (default: 1)" + echo " --r1-workers N Number of R1 API parallel workers (default: 10)" + echo " --qwen-vl-workers N Number of Qwen-VL parallel workers (default: 4)" + echo " --eval-workers N Number of evaluation parallel workers (default: 4)" + echo " --max-samples N Maximum number of samples to process (default: 200)" + echo " --qwen-vl-port N Port for Qwen-VL server (default: 5003)" + echo " --r1-port N Port for DeepSeek-R1 server (default: 5002)" + echo " --stop-server Stop servers when done" + echo "" + exit 0 +} + +while [[ $# -gt 0 ]]; do + case $1 in + --dataset) DATASET_PATH="$2"; shift 2 ;; + --output-dir) OUTPUT_DIR="$2"; shift 2 ;; + --num-runs) NUM_RUNS="$2"; shift 2 ;; + --r1-workers) R1_WORKERS="$2"; shift 2 ;; + --qwen-vl-workers) QWEN_VL_WORKERS="$2"; shift 2 ;; + --eval-workers) EVAL_WORKERS="$2"; shift 2 ;; + --max-samples) MAX_SAMPLES="$2"; shift 2 ;; + --qwen-vl-port) QWEN_VL_PORT="$2"; QWEN_VL_SERVER_URL="http://localhost:${QWEN_VL_PORT}"; shift 2 ;; + --r1-port) R1_PORT="$2"; R1_SERVER_URL="http://localhost:${R1_PORT}"; shift 2 ;; + --stop-server) STOP_SERVER=true; shift ;; + --help) show_help ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +DATASET_NAME=$(basename "$DATASET_PATH" .json) + +mkdir -p "${PROJECT_ROOT}/logs" +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +LOG_PATH="${PROJECT_ROOT}/logs/r1_injection_workflow_${TIMESTAMP}.log" + +echo "========================================" | tee -a "$LOG_PATH" +echo "Configuration:" | tee -a "$LOG_PATH" +echo "DATASET_PATH = ${DATASET_PATH}" | tee -a "$LOG_PATH" +echo "OUTPUT_DIR = ${OUTPUT_DIR}" | tee -a "$LOG_PATH" +echo "NUM_RUNS = ${NUM_RUNS}" | tee -a "$LOG_PATH" +echo "MAX_SAMPLES = ${MAX_SAMPLES}" | tee -a "$LOG_PATH" +echo "R1_WORKERS = ${R1_WORKERS}" | tee -a "$LOG_PATH" +echo "QWEN_VL_WORKERS = ${QWEN_VL_WORKERS}" | tee -a "$LOG_PATH" +echo "EVAL_WORKERS = ${EVAL_WORKERS}" | tee -a "$LOG_PATH" +echo "QWEN_VL_SERVER_URL = ${QWEN_VL_SERVER_URL}" | tee -a "$LOG_PATH" +echo "R1_SERVER_URL = ${R1_SERVER_URL}" | tee -a "$LOG_PATH" +echo "========================================" | tee -a "$LOG_PATH" +echo "" | tee -a "$LOG_PATH" + +# ── Initialize Conda ──────────── +export MKL_INTERFACE_LAYER=${MKL_INTERFACE_LAYER:-LP64} +if [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +elif [ -f "$(conda info --base)/etc/profile.d/conda.sh" ]; then + source "$(conda info --base)/etc/profile.d/conda.sh" +else + echo "ERROR: Cannot find conda.sh." | tee -a "$LOG_PATH" + exit 1 +fi + +# 1) Verify servers +STARTTIME=$(date +%s) + +echo "Checking if Qwen-VL server is running..." | tee -a "$LOG_PATH" +if nc -z localhost $QWEN_VL_PORT 2>/dev/null; then + echo "Qwen-VL server running on port $QWEN_VL_PORT" | tee -a "$LOG_PATH" +else + echo "ERROR: Qwen-VL server is not running on port $QWEN_VL_PORT" | tee -a "$LOG_PATH" + echo "Please start it first with: src/qwen_utils/start_qwen_vl_server.sh" | tee -a "$LOG_PATH" + exit 1 +fi + +echo "Checking if DeepSeek-R1 server is running..." | tee -a "$LOG_PATH" +if nc -z localhost $R1_PORT 2>/dev/null; then + echo "DeepSeek-R1 server running on port $R1_PORT" | tee -a "$LOG_PATH" +else + echo "ERROR: DeepSeek-R1 server is not running on port $R1_PORT" | tee -a "$LOG_PATH" + echo "Please start it first with: src/r1_utils/start_r1_server.sh" | tee -a "$LOG_PATH" + exit 1 +fi + +# 2) Generate Qwen-VL observations (shared across all runs) +QWEN_VL_OUT_DIR="${OUTPUT_DIR}/qwen-vl-thinking-${DATASET_NAME}" +mkdir -p "$QWEN_VL_OUT_DIR" +QWEN_VL_OUT="${QWEN_VL_OUT_DIR}/generated_answer.json" + +if [ -f "$QWEN_VL_OUT" ]; then + echo "Found existing Qwen-VL observations at $QWEN_VL_OUT" | tee -a "$LOG_PATH" + echo "Reusing existing observations for injection" | tee -a "$LOG_PATH" +else + echo "Generating Qwen-VL observations..." | tee -a "$LOG_PATH" + conda activate evaluation + export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + + python "${PROJECT_ROOT}/src/qwen_inference.py" \ + --dataset_path "$DATASET_PATH" \ + --output_path "$QWEN_VL_OUT" \ + --server_url "$QWEN_VL_SERVER_URL" \ + --workers $QWEN_VL_WORKERS \ + --max_samples $MAX_SAMPLES | tee -a "$LOG_PATH" + + echo "Qwen-VL observations saved to $QWEN_VL_OUT" | tee -a "$LOG_PATH" +fi + +# 3) Run R1 with injection for each run +for RUN_NUM in $(seq 1 $NUM_RUNS); do + echo "" | tee -a "$LOG_PATH" + echo "========================================" | tee -a "$LOG_PATH" + echo "Starting DeepSeek-R1 Injection Run #$RUN_NUM of $NUM_RUNS" | tee -a "$LOG_PATH" + echo "========================================" | tee -a "$LOG_PATH" + + RUN_SUFFIX="run${RUN_NUM}" + INJECTION_OUT_DIR="${OUTPUT_DIR}/r1-injection-${DATASET_NAME}-${RUN_SUFFIX}" + mkdir -p "$INJECTION_OUT_DIR" + INJECTION_OUT="${INJECTION_OUT_DIR}/generated_answer.json" + + conda activate qwen3-vllm + export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + + echo "Running DeepSeek-R1 with Qwen-VL injection (Run #$RUN_NUM)" | tee -a "$LOG_PATH" + python "${PROJECT_ROOT}/src/qwen3_with_injection.py" \ + --dataset_path "$DATASET_PATH" \ + --injection_path "$QWEN_VL_OUT" \ + --output_path "$INJECTION_OUT" \ + --server_url "$R1_SERVER_URL" \ + --model_name "r1" \ + --workers $R1_WORKERS | tee -a "$LOG_PATH" + + echo "Run #$RUN_NUM: Injection results saved to $INJECTION_OUT" | tee -a "$LOG_PATH" + + # 4) Evaluate + echo "Evaluating injection results (Run #$RUN_NUM)" | tee -a "$LOG_PATH" + INJECTION_EXP_NAME=$(basename "$INJECTION_OUT_DIR") + conda activate evaluation + export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + + python "${PROJECT_ROOT}/evaluation/evaluate_with_sampling.py" \ + --exp "$INJECTION_EXP_NAME" \ + --dataset "$DATASET_PATH" \ + --generated "$INJECTION_OUT" \ + --num_workers $EVAL_WORKERS | tee -a "$LOG_PATH" + + echo "Run #$RUN_NUM Complete!" | tee -a "$LOG_PATH" + echo " - Injection results: $INJECTION_OUT" | tee -a "$LOG_PATH" + echo " - Evaluation: ${PROJECT_ROOT}/evaluation/exp/$INJECTION_EXP_NAME/" | tee -a "$LOG_PATH" +done + +# 5) Summary +ENDTIME=$(date +%s) +RUNTIME=$((ENDTIME - STARTTIME)) +echo "=======================================" | tee -a "$LOG_PATH" +echo "Total runtime: $RUNTIME seconds ($(($RUNTIME / 60)) minutes)" | tee -a "$LOG_PATH" +echo "All $NUM_RUNS Runs Completed Successfully" | tee -a "$LOG_PATH" +echo "=======================================" | tee -a "$LOG_PATH" + +# 6) Stop servers if requested +if [ "$STOP_SERVER" = true ]; then + echo "Stopping servers..." | tee -a "$LOG_PATH" + "${PROJECT_ROOT}/src/qwen_utils/stop_qwen_vl_server.sh" || true + "${PROJECT_ROOT}/src/r1_utils/stop_r1_server.sh" || true + echo "Servers stopped" | tee -a "$LOG_PATH" +else + echo "" | tee -a "$LOG_PATH" + echo "NOTE: Servers are still running. Stop them with:" | tee -a "$LOG_PATH" + echo " ${PROJECT_ROOT}/src/qwen_utils/stop_qwen_vl_server.sh" | tee -a "$LOG_PATH" + echo " ${PROJECT_ROOT}/src/r1_utils/stop_r1_server.sh" | tee -a "$LOG_PATH" +fi + +echo "Log saved to: ${LOG_PATH}" | tee -a "$LOG_PATH" +echo "Workflow complete!" | tee -a "$LOG_PATH" diff --git a/scripts/run_stochastic_refinement.sh b/scripts/run_stochastic_refinement.sh new file mode 100755 index 0000000..1b312b4 --- /dev/null +++ b/scripts/run_stochastic_refinement.sh @@ -0,0 +1,120 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# run_stochastic_refinement.sh +# +# Stage 2: Stochastic refinement of synthetic data generators. +# Simplifies Stage 1 models into sampling-based generators that produce +# diverse synthetic time series data. +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/scripts +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # project root +SYNTHETIC_DIR="${PROJECT_ROOT}/dataset/synthetic" + +echo "========================================" +echo "Stage 2: Stochastic Refinement" +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "SYNTHETIC_DIR = $SYNTHETIC_DIR" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +DATASET_PATH="${SYNTHETIC_DIR}/sample_data/qa_benchmark_base_train.json" +RESULTS_DIR="${SYNTHETIC_DIR}/results/iterative_results" +OUTPUT_DIR="${SYNTHETIC_DIR}/results/stochastic_results" +NUM_CLAUDE_CALLS=3 +MAX_WORKERS="" +REGION="us-west-2" + +# ── Parse command line arguments ───────────────────────────────────────────── +SAMPLE_IDS=() +while [[ $# -gt 0 ]]; do + case $1 in + --dataset) + DATASET_PATH="$2" + shift 2 + ;; + --results_dir) + RESULTS_DIR="$2" + shift 2 + ;; + --output_dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --num_claude_calls) + NUM_CLAUDE_CALLS="$2" + shift 2 + ;; + --max_workers) + MAX_WORKERS="$2" + shift 2 + ;; + --region) + REGION="$2" + shift 2 + ;; + --sample_ids) + shift + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + SAMPLE_IDS+=("$1") + shift + done + ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + +echo "Dataset: $DATASET_PATH" +echo "Stage 1 results: $RESULTS_DIR" +echo "Output: $OUTPUT_DIR" +echo "Claude calls: $NUM_CLAUDE_CALLS" +echo "Region: $REGION" +if [ ${#SAMPLE_IDS[@]} -gt 0 ]; then + echo "Sample IDs: ${SAMPLE_IDS[*]}" +else + echo "Sample IDs: (all available)" +fi +echo "Start time: $(date)" +echo "" + +# ── Create output directory ────────────────────────────────────────────────── +mkdir -p "${OUTPUT_DIR}" + +# ── Build Python command ───────────────────────────────────────────────────── +CMD=( + python "${SYNTHETIC_DIR}/stochastic_ts_generation.py" + --dataset_path="${DATASET_PATH}" + --results_dir="${RESULTS_DIR}" + --output_dir="${OUTPUT_DIR}" + --num_claude_calls="${NUM_CLAUDE_CALLS}" + --region="${REGION}" +) + +if [ -n "${MAX_WORKERS}" ]; then + CMD+=(--max_workers="${MAX_WORKERS}") +fi + +if [ ${#SAMPLE_IDS[@]} -gt 0 ]; then + CMD+=(--sample_ids "${SAMPLE_IDS[@]}") +fi + +# ── Run ────────────────────────────────────────────────────────────────────── +"${CMD[@]}" + +echo "" +echo "========================================" +echo "Stage 2 Complete" +echo "Results: ${OUTPUT_DIR}" +echo "" +echo "NEXT STEP: Review the results in ${OUTPUT_DIR} and keep the best" +echo "stochastic function for each sample before running Stage 3." +echo "End time: $(date)" +echo "========================================" diff --git a/scripts/run_synthetic_benchmark.sh b/scripts/run_synthetic_benchmark.sh new file mode 100755 index 0000000..ba94932 --- /dev/null +++ b/scripts/run_synthetic_benchmark.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# run_synthetic_benchmark.sh +# +# Stages 3-4: Generate synthetic benchmark dataset. +# Orchestrates the full pipeline from synthetic data generation to filtered +# MCQ benchmark creation. +# +# Usage: ./scripts/run_synthetic_benchmark.sh [SAMPLES_PER_SOURCE] +# SAMPLES_PER_SOURCE defaults to 100 if not specified. +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/scripts +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # project root +SYNTHETIC_DIR="${PROJECT_ROOT}/dataset/synthetic" + +echo "========================================" +echo "Stages 3-4: Synthetic Benchmark Generation" +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "SYNTHETIC_DIR = $SYNTHETIC_DIR" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +STOCHASTIC_RESULTS_DIR="${SYNTHETIC_DIR}/results/stochastic_results" +TRAINING_DATA_DIR="${SYNTHETIC_DIR}/results/synthetic_training_data" +DATASET_PATH="${SYNTHETIC_DIR}/sample_data/qa_benchmark_base_train.json" +SAMPLES_PER_SOURCE=${1:-100} +REGION="${REGION:-us-west-2}" + +echo "Samples per source: $SAMPLES_PER_SOURCE" +echo "Stochastic results dir: $STOCHASTIC_RESULTS_DIR" +echo "Dataset path: $DATASET_PATH" +echo "Region: $REGION" +echo "Output dir: $TRAINING_DATA_DIR" +echo "Start time: $(date)" +echo "" + +# ── Create output directory ────────────────────────────────────────────────── +mkdir -p "${TRAINING_DATA_DIR}" + +# ── Step 1: Generate synthetic dataset ─────────────────────────────────────── +echo "Step 1: Generating synthetic dataset..." +python "${SYNTHETIC_DIR}/generate_synthetic_dataset.py" \ + --stochastic_results_dir="${STOCHASTIC_RESULTS_DIR}" \ + --dataset_path="${DATASET_PATH}" \ + --output_dir="${TRAINING_DATA_DIR}" \ + --samples_per_source="${SAMPLES_PER_SOURCE}" + +if [ ! -f "${TRAINING_DATA_DIR}/data_ts.json" ]; then + echo "Error: Failed to create data_ts.json" + exit 1 +fi +echo "Successfully created data_ts.json" +echo "" + +# ── Step 2: Generate QA benchmark (uses LLM for answer diversification) ────── +echo "Step 2: Generating QA benchmark (with LLM diversification)..." +python "${SYNTHETIC_DIR}/generate_qa_benchmark.py" \ + --data_ts_path="${TRAINING_DATA_DIR}/data_ts.json" \ + --dataset_path="${DATASET_PATH}" \ + --output_path="${TRAINING_DATA_DIR}/qa_synthetic_base.json" \ + --region="${REGION}" + +if [ ! -f "${TRAINING_DATA_DIR}/qa_synthetic_base.json" ]; then + echo "Error: Failed to create qa_synthetic_base.json" + exit 1 +fi +echo "Successfully created qa_synthetic_base.json" +echo "" + +# ── Step 3: Generate MCQ benchmark ─────────────────────────────────────────── +echo "Step 3: Generating MCQ benchmark..." +python "${SYNTHETIC_DIR}/generate_mcq_benchmark.py" \ + --qa_benchmark_path="${TRAINING_DATA_DIR}/qa_synthetic_base.json" \ + --output_path="${TRAINING_DATA_DIR}/rme_synthetic_easy_unfiltered.json" + +if [ ! -f "${TRAINING_DATA_DIR}/rme_synthetic_easy_unfiltered.json" ]; then + echo "Error: Failed to create rme_synthetic_easy_unfiltered.json" + exit 1 +fi +echo "Successfully created rme_synthetic_easy_unfiltered.json" +echo "" + +# ── Step 4: Filter MCQ benchmark ──────────────────────────────────────────── +echo "Step 4: Filtering MCQ benchmark (keeping MCQ_obs and MCQ_cause)..." +python "${SYNTHETIC_DIR}/filter_mcq_benchmark.py" \ + --input_file="${TRAINING_DATA_DIR}/rme_synthetic_easy_unfiltered.json" \ + --output_file="${TRAINING_DATA_DIR}/rme_synthetic_easy.json" \ + --keep_ability_types "MCQ_obs" "MCQ_cause" + +if [ ! -f "${TRAINING_DATA_DIR}/rme_synthetic_easy.json" ]; then + echo "Error: Failed to create filtered rme_synthetic_easy.json" + exit 1 +fi +echo "Successfully created rme_synthetic_easy.json" +echo "" + +# ── Print dataset statistics ───────────────────────────────────────────────── +echo "========================================" +echo "Synthetic Benchmark Dataset Statistics" +echo "========================================" +echo "From dataset_summary.json:" +cat "${TRAINING_DATA_DIR}/dataset_summary.json" +echo "" +echo "QA benchmark count: $(python -c "import json; print(len(json.load(open('${TRAINING_DATA_DIR}/qa_synthetic_base.json'))))")" +echo "MCQ benchmark count: $(python -c "import json; print(len(json.load(open('${TRAINING_DATA_DIR}/rme_synthetic_easy.json'))))")" +echo "" +echo "========================================" +echo "Stages 3-4 Complete" +echo "Results: ${TRAINING_DATA_DIR}" +echo "End time: $(date)" +echo "========================================" diff --git a/src/chatts_inference.py b/src/chatts_inference.py new file mode 100644 index 0000000..db4b538 --- /dev/null +++ b/src/chatts_inference.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +""" +ChatTS server-based inference benchmark script for time series datasets. + +This script: +1. Loads a time series dataset +2. Connects to a running ChatTS server (start with start_chatts_server.sh) +3. Runs inference with the ChatTS server using the OpenAI-compatible API +4. Saves the results to a JSON file + +Usage: + python chatts_inference_server.py --dataset_path /path/to/dataset.json --output_path /path/to/output.json +""" + +import os +import sys +import json +import time +import argparse +import logging +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor +import threading +from queue import Queue + +try: + from openai import OpenAI +except ImportError: + raise ImportError("OpenAI Python client not installed. Please install with 'pip install openai'") + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def parse_args(): + p = argparse.ArgumentParser(description="Evaluate ChatTS on a time-series QA dataset using a server") + p.add_argument("--server_url", default="http://localhost:5000", + help="URL of the ChatTS server") + p.add_argument("--dataset_path", "-d", required=True, + help="Path to the JSON evaluation set") + p.add_argument("--output_path", "-o", required=True, + help="Where to write generated answers JSON") + p.add_argument("--max_tokens", type=int, default=600, + help="Maximum tokens to generate") + p.add_argument("--max_samples", type=int, default=200, + help="Maximum number of samples to process") + p.add_argument("--seed", type=int, default=42, + help="Random seed for sampling") + p.add_argument("--checkpoint_interval", type=int, default=10, + help="Interval for saving checkpoints") + p.add_argument("--workers", type=int, default=4, + help="Number of parallel workers for processing samples") + p.add_argument("--timeout", type=int, default=120, + help="Timeout in seconds for API calls") + p.add_argument("--retry_delay", type=int, default=5, + help="Delay between retries in seconds") + p.add_argument("--max_retries", type=int, default=3, + help="Maximum number of retry attempts") + return p.parse_args() + +class ChatTSClient: + """Client for communicating with ChatTS server using OpenAI API.""" + + def __init__(self, server_url="http://localhost:5000", debug_mode=False): + """Initialize the ChatTS client.""" + self.server_url = server_url + self.debug_mode = debug_mode + self.client = OpenAI(base_url=f"{server_url}/v1", api_key="dummy-key") + + if debug_mode: + logger.setLevel(logging.DEBUG) + logger.info(f"ChatTSClient initialized in DEBUG mode with server URL: {server_url}") + + def check_server_health(self): + """Check if the server is healthy.""" + import requests + try: + logger.info(f"Checking health of ChatTS server at {self.server_url}...") + # Try to get the models list as a health check + response = requests.get(f"{self.server_url}/v1/models", timeout=10) + if response.status_code == 200: + logger.info(f"ChatTS server is healthy") + return True + else: + logger.warning(f"ChatTS server health check failed: {response.status_code}") + return False + except Exception as e: + logger.error(f"Error checking server health: {type(e).__name__}: {e}") + return False + + def query_chatts_with_timeseries( + self, + timeseries, + question, + max_tokens=600, + temperature=0.01, + timeout=120, + retry_delay=5, + max_retries=3, + ): + """ + Query ChatTS with time series data. + + Args: + timeseries: Time series data array + question: Question with markers + max_tokens: Maximum tokens to generate + temperature: Temperature for sampling + timeout: Request timeout in seconds + retry_delay: Delay between retries in seconds + max_retries: Maximum number of retry attempts + + Returns: + Model response as string + """ + # Format the question with chat template + # formatted_question = f"<|im_start|>system\nYou are a helpful assistant.\n<|im_end|><|im_start|>user\n{question}\n<|im_end|><|im_start|>assistant\n" + + # logger.debug(f"Formatted question with template: {formatted_question[:100]}...") + formatted_question = question + + # Create messages array with timeseries data + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": formatted_question} + ] + [{"timeseries": ts} for ts in timeseries] + } + ] + + # Make the request with retries + for attempt in range(max_retries): + try: + logger.info(f"Sending query to ChatTS server (attempt {attempt+1}/{max_retries})") + start_time = time.time() + + response = self.client.chat.completions.create( + model="chatts", + messages=messages, + max_tokens=max_tokens, + temperature=temperature + ) + + end_time = time.time() + elapsed = end_time - start_time + + # Get usage information if available + usage = getattr(response, 'usage', None) + if usage: + logger.info( + f"Query successful in {elapsed:.2f}s: " + f"prompt_tokens={usage.prompt_tokens}, " + f"completion_tokens={usage.completion_tokens}, " + f"total_tokens={usage.total_tokens}" + ) + else: + logger.info(f"Query successful, inference time: {elapsed:.2f}s") + + return response.choices[0].message.content + + except Exception as e: + logger.error(f"Query failed (attempt {attempt+1}/{max_retries}): {type(e).__name__}: {e}") + time.sleep(retry_delay) # Wait before retry + + # Increase retry delay for exponential backoff + retry_delay = min(retry_delay * 2, 60) # Cap at 60 seconds + + logger.critical(f"Failed to get response from ChatTS server after {max_retries} attempts") + raise RuntimeError("Failed to get response from ChatTS server after multiple attempts") + +def prepare_question_with_ts_placeholders(question, cols): + """ + Prepares a question with placeholders for each time series column. + + Args: + question: Original question + cols: List of column names + + Returns: + Modified question with placeholders + """ + # Check if the question already has placeholders + if "" in question: + return question + + # Add placeholders for each column + if cols and len(cols) > 0: + prefix = f"There are {len(cols)} time series collected: " + placeholder_text = ", ".join([f"{col}:" for col in cols]) + return f"{prefix}{placeholder_text}. Please analyze time series features and answer the following question:\n\n{question}" + else: + return question + +def process_sample(args, client, sample, idx): + """ + Process a single sample with the ChatTS client. + + Args: + args: Command line arguments + client: ChatTSClient instance + sample: Data sample + idx: Sample index + + Returns: + Result dictionary + """ + try: + # Extract data + cols = sample.get("cols", []) + ts_data = sample["timeseries"] + question = sample["question"] + + # Prepare question with placeholders if needed + question_with_ts = prepare_question_with_ts_placeholders(question, cols) + + # Query the ChatTS server + answer = client.query_chatts_with_timeseries( + timeseries=ts_data, + question=question_with_ts, + max_tokens=args.max_tokens, + timeout=args.timeout, + retry_delay=args.retry_delay, + max_retries=args.max_retries + ) + + # Return successful result + return { + "idx": idx, + "question": question, + "response": answer, + "success": True + } + + except Exception as e: + logger.error(f"Error processing sample {idx}: {str(e)}") + # Get question from sample, handling potential key errors safely + question = "" + try: + question = sample.get("question", "") + except Exception: + pass + + return { + "idx": idx, + "question": question, + "response": f"ERROR: {str(e)}", + "success": False + } + +def main(): + args = parse_args() + + # Initialize lock for thread-safe operations + results_lock = threading.Lock() + + # Initialize ChatTS client pool + clients = [ChatTSClient(server_url=args.server_url) for _ in range(args.workers)] + + # Check server health with first client + if not clients[0].check_server_health(): + logger.error("ChatTS server is not healthy. Please make sure it is running.") + logger.error("Run: ./start_chatts_server.sh to start the server.") + sys.exit(1) + else: + logger.info(f"Using {args.workers} workers for parallel processing") + + # Load evaluation set + logger.info(f"Loading dataset from {args.dataset_path}") + with open(args.dataset_path, "r") as f: + full_dataset = json.load(f) + + # Sample if needed + total_entries = len(full_dataset) + if total_entries > args.max_samples: + logger.info(f"Sampling {args.max_samples} entries from {total_entries} total") + import random + random.seed(args.seed) + indices = random.sample(range(total_entries), args.max_samples) + dataset = [full_dataset[i] for i in indices] + + # Save metadata about the sampling + metadata_file = args.output_path.replace(".json", "_sampling_metadata.json") + with open(metadata_file, "w") as f: + json.dump({ + "original_size": total_entries, + "sampled_size": args.max_samples, + "seed": args.seed, + "sampled_indices": indices + }, f, indent=2) + else: + dataset = full_dataset + + # Create output directory + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Load existing results if any + results = [] + if os.path.exists(args.output_path): + logger.info(f"Loading existing results from {args.output_path}") + with open(args.output_path, "r") as f: + results = json.load(f) + processed_indices = {r["idx"] for r in results} + logger.info(f"Resuming from {len(processed_indices)} already processed entries") + else: + processed_indices = set() + + # Track results and progress in thread-safe way + results_dict = {r["idx"]: r for r in results} + total_processed = len(results) + progress_queue = Queue() + + # Create a dictionary to hold futures for each worker + if args.workers > 1: + logger.info(f"Starting parallel processing with {args.workers} workers") + + # Function to process results and update progress + def process_results(): + nonlocal total_processed + + with tqdm(total=len(dataset), desc="Evaluating ChatTS", initial=len(results)) as pbar: + while True: + # Get the next completed task from the queue + idx, result = progress_queue.get() + + if idx == -1: # Sentinel value to exit + break + + # Update results with thread safety + with results_lock: + results_dict[idx] = result + total_processed += 1 + + # Checkpoint at specified intervals + if total_processed % args.checkpoint_interval == 0: + checkpoint_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(checkpoint_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Checkpoint saved with {total_processed} results") + + # Update progress bar + pbar.update(1) + progress_queue.task_done() + + # Start progress tracking thread + progress_thread = threading.Thread(target=process_results) + progress_thread.start() + + # Process samples in parallel + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = {} + client_idx = 0 + + # Submit tasks to the thread pool + for idx, sample in enumerate(dataset): + # Skip if already processed + if idx in processed_indices: + continue + + # Submit the task to the thread pool + future = executor.submit( + process_sample, + args, + clients[client_idx % len(clients)], # Round-robin client assignment + sample, + idx + ) + futures[future] = idx + client_idx += 1 + + # Wait for tasks to complete and collect results + for future in futures: + idx = futures[future] + try: + result = future.result() + progress_queue.put((idx, result)) + except Exception as e: + logger.error(f"Worker error on sample {idx}: {str(e)}") + progress_queue.put((idx, { + "idx": idx, + "question": dataset[idx].get("question", ""), + "response": f"WORKER ERROR: {str(e)}", + "success": False + })) + + # Signal progress thread to exit and wait for it to finish + progress_queue.put((-1, None)) + progress_thread.join() + + else: + # Sequential processing + logger.info("Using sequential processing (workers=1)") + for idx, sample in enumerate(tqdm(dataset, desc="Evaluating ChatTS")): + # Skip if already processed + if idx in processed_indices: + continue + + # Process the sample + result = process_sample(args, clients[0], sample, idx) + results_dict[idx] = result + + # Checkpoint at specified intervals + if len(results_dict) % args.checkpoint_interval == 0: + checkpoint_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(checkpoint_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Checkpoint saved with {len(results_dict)} results") + + # Final save with sorted results + final_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(final_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Saved {len(final_results)} answers to {args.output_path}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/chatts_injection.py b/src/chatts_injection.py new file mode 100644 index 0000000..1cdad40 --- /dev/null +++ b/src/chatts_injection.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +""" +ChatTS Injection Script + +This script uses a running ChatTS server to extract injection time series observations: +1. Connects to a running ChatTS server (start with start_chatts_server.sh) +2. Loads a dataset with questions and time series data +3. Submits each question to ChatTS to generate injection observations +4. Saves the observations to a JSON file for later use with Claude + +Usage: + python chatts_injection.py --dataset_path /path/to/dataset.json --output_path /path/to/observations.json +""" + +import os +import sys +import json +import time +import argparse +import logging +import threading +from queue import Queue +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor + +try: + from openai import OpenAI +except ImportError: + raise ImportError("OpenAI Python client not installed. Please install with 'pip install openai'") + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def parse_args(): + p = argparse.ArgumentParser(description="Use ChatTS server to extract injection time series observations") + p.add_argument("--server_url", default="http://localhost:5000", + help="URL of the ChatTS server") + p.add_argument("--dataset_path", "-d", required=True, + help="Path to the dataset JSON") + p.add_argument("--output_path", "-o", required=True, + help="Where to write injection observations JSON") + p.add_argument("--max_tokens", type=int, default=3072, + help="Maximum tokens to generate") + p.add_argument("--workers", type=int, default=4, + help="Number of parallel workers for processing samples") + p.add_argument("--checkpoint_interval", type=int, default=10, + help="Interval for saving checkpoints") + p.add_argument("--timeout", type=int, default=180, + help="Timeout in seconds for API calls") + p.add_argument("--retry_delay", type=int, default=5, + help="Delay between retries in seconds") + p.add_argument("--max_retries", type=int, default=3, + help="Maximum number of retry attempts") + return p.parse_args() + +# For injection case, we keep the full question + +def build_injection_prompt(question: str, ts_cols=None) -> str: + """Build a prompt for the ChatTS injection.""" + # Keep the full question for injection case + + # Check if the question already has placeholders + has_ts_placeholders = '' in question + + # If no placeholders and we have column info, add them + ts_placeholder_text = "" + if not has_ts_placeholders and ts_cols and len(ts_cols) > 0: + num_series = len(ts_cols) + ts_placeholder_text = f"There are {num_series} time series collected: " + for i, col_name in enumerate(ts_cols): + ts_placeholder_text += f"{col_name}:" + if i < len(ts_cols) - 1: + ts_placeholder_text += ", " + ts_placeholder_text += ". " + + prompt_body = ( + "<|im_start|>system\n" + "You are a helpful time series analysis assistant.\n" + "<|im_end|>" + "<|im_start|>user\n" + "You are analyzing a time series to extract key quantitative observations that help answer the question.\n\n" + f"{ts_placeholder_text}{question}\n\n" + "Provide detailed, objective numerical observations by following these guidelines:\n" + "1. Make numbered, precise observations about the quantitative aspects of the time series.\n" + "2. Be specific about values, positions, and magnitudes when describing features.\n" + "3. Begin each observation with \"Observation 1:\", \"Observation 2:\", etc.\n" + "\n" + "Start your response with:\n" + "\"To answer this question, I need to carefully analyze the time series. " + "Here are my observations: Observation 1... Observation 2...\"" + "<|im_end|>" + "<|im_start|>assistant\n" + ) + return prompt_body + +class ChatTSClient: + """Client for communicating with ChatTS server using OpenAI API.""" + + def __init__(self, server_url="http://localhost:5000", debug_mode=False): + """Initialize the ChatTS client.""" + self.server_url = server_url + self.debug_mode = debug_mode + self.client = OpenAI(base_url=f"{server_url}/v1", api_key="dummy-key") + + if debug_mode: + logger.setLevel(logging.DEBUG) + logger.info(f"ChatTSClient initialized in DEBUG mode with server URL: {server_url}") + + def check_server_health(self): + """Check if the server is healthy.""" + import requests + try: + logger.info(f"Checking health of ChatTS server at {self.server_url}...") + # Try to get the models list as a health check + response = requests.get(f"{self.server_url}/v1/models", timeout=10) + if response.status_code == 200: + logger.info(f"ChatTS server is healthy") + return True + else: + logger.warning(f"ChatTS server health check failed: {response.status_code}") + return False + except Exception as e: + logger.error(f"Error checking server health: {type(e).__name__}: {e}") + return False + + def generate_observations( + self, + timeseries, + question, + ts_cols=None, + max_tokens=1024, + temperature=0.01, + timeout=120, + retry_delay=5, + max_retries=3, + ): + """ + Submit question to ChatTS for generating observations. + + Args: + timeseries: Time series data + question: Original question + ts_cols: Column names for the time series + max_tokens: Maximum tokens to generate + temperature: Temperature for sampling + timeout: Request timeout in seconds + retry_delay: Delay between retries in seconds + max_retries: Maximum number of retry attempts + + Returns: + Generated observations as string + """ + # Build the prompt for the injection + prompt = build_injection_prompt(question, ts_cols) + + # Create messages array with timeseries data + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt} + ] + [{"timeseries": ts} for ts in timeseries] + } + ] + + # Make the request with retries + for attempt in range(max_retries): + try: + logger.info(f"Sending observation request to ChatTS server (attempt {attempt+1}/{max_retries})") + start_time = time.time() + + response = self.client.chat.completions.create( + model="chatts", + messages=messages, + max_tokens=max_tokens, + temperature=temperature + ) + + end_time = time.time() + elapsed = end_time - start_time + + # Get usage information if available + usage = getattr(response, 'usage', None) + if usage: + logger.info( + f"Observation generation successful in {elapsed:.2f}s: " + f"prompt_tokens={usage.prompt_tokens}, " + f"completion_tokens={usage.completion_tokens}, " + f"total_tokens={usage.total_tokens}" + ) + else: + logger.info(f"Observation generation successful, inference time: {elapsed:.2f}s") + + return response.choices[0].message.content + + except Exception as e: + logger.error(f"Observation request failed (attempt {attempt+1}/{max_retries}): {type(e).__name__}: {e}") + time.sleep(retry_delay) # Wait before retry + + # Increase retry delay for exponential backoff + retry_delay = min(retry_delay * 2, 60) # Cap at 60 seconds + + logger.critical(f"Failed to get response from ChatTS server after {max_retries} attempts") + raise RuntimeError("Failed to get response from ChatTS server after multiple attempts") + +def process_sample(args, client, dataset, idx): + """ + Process a single sample with the ChatTS client. + + Args: + args: Command line arguments + client: ChatTSClient instance + dataset: Full dataset + idx: Sample index + + Returns: + Result dictionary with observations + """ + try: + # Get the sample + sample = dataset[idx] + question = sample.get("question", "") + + if not question: + logger.warning(f"No question found for sample {idx}, skipping") + return None + + # Get column names and timeseries data from the sample + ts_cols = sample.get("cols", []) + timeseries = sample["timeseries"] + + # For better logging + if isinstance(timeseries, list): + ts_shape = f"List with {len(timeseries)} elements" + if len(timeseries) > 0 and isinstance(timeseries[0], list): + ts_shape += f", first element length: {len(timeseries[0])}" + else: + ts_shape = "Not a list" + logging.info(f"Sample {idx}: Timeseries shape: {ts_shape}") + + # Get the observations from the server + observations = client.generate_observations( + timeseries=timeseries, + question=question, + ts_cols=ts_cols, + max_tokens=args.max_tokens, + timeout=args.timeout, + retry_delay=args.retry_delay, + max_retries=args.max_retries + ) + + # Return the result + return { + "idx": idx, + "question": question, + "observations": observations, + "ability_types": sample.get("ability_types", []), # Preserve any metadata from sample + "attributes": sample.get("attributes", {}) # Preserve any metadata from sample + } + + except Exception as e: + logger.error(f"Error processing sample {idx}: {str(e)}") + return { + "idx": idx, + "question": sample.get("question", "") if sample else "", + "observations": f"ERROR: {str(e)}", + "ability_types": sample.get("ability_types", []) if sample else [], + "attributes": sample.get("attributes", {}) if sample else {} + } + +def main(): + args = parse_args() + + # Initialize lock for thread-safe operations + results_lock = threading.Lock() + + # Initialize ChatTS client pool + clients = [ChatTSClient(server_url=args.server_url) for _ in range(args.workers)] + + # Check server health with first client + if not clients[0].check_server_health(): + logger.error("ChatTS server is not healthy. Please make sure it is running.") + logger.error("Run: ./start_chatts_server.sh to start the server.") + sys.exit(1) + else: + logger.info(f"Using {args.workers} workers for parallel processing") + + # 1) Load data + logger.info(f"Loading dataset from {args.dataset_path}") + with open(args.dataset_path, "r") as f: + dataset = json.load(f) + + # Create output directory + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Load existing results if any + results = [] + if os.path.exists(args.output_path): + logger.info(f"Loading existing results from {args.output_path}") + with open(args.output_path, "r") as f: + results = json.load(f) + processed_indices = {r["idx"] for r in results} + logger.info(f"Resuming from {len(processed_indices)} already processed entries") + else: + processed_indices = set() + + # Track results and progress in thread-safe way + results_dict = {r["idx"]: r for r in results} + total_processed = len(results) + progress_queue = Queue() + + # Create a list of samples to process + samples_to_process = [idx for idx in range(len(dataset)) if idx not in processed_indices] + logger.info(f"Found {len(samples_to_process)} samples to process out of {len(dataset)} total") + + if not samples_to_process: + logger.info("No new samples to process") + return + + if args.workers > 1: + logger.info(f"Starting parallel processing with {args.workers} workers") + + # Function to process results and update progress + def process_results(): + nonlocal total_processed + + with tqdm(total=len(samples_to_process), desc="Generating observations", initial=0) as pbar: + while True: + # Get the next completed task from the queue + idx, result = progress_queue.get() + + if idx == -1: # Sentinel value to exit + break + + # Update results with thread safety + with results_lock: + if result is not None: # Skip None results (errors) + results_dict[idx] = result + total_processed += 1 + + # Checkpoint at specified intervals + if total_processed % args.checkpoint_interval == 0: + checkpoint_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(checkpoint_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Checkpoint saved with {total_processed} results") + + # Update progress bar + pbar.update(1) + progress_queue.task_done() + + # Start progress tracking thread + progress_thread = threading.Thread(target=process_results) + progress_thread.start() + + # Process samples in parallel + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = {} + client_idx = 0 + + # Submit tasks to the thread pool + for idx in samples_to_process: + # Submit the task to the thread pool + future = executor.submit( + process_sample, + args, + clients[client_idx % len(clients)], # Round-robin client assignment + dataset, + idx + ) + futures[future] = idx + client_idx += 1 + + # Wait for tasks to complete and collect results + for future in futures: + idx = futures[future] + try: + result = future.result() + progress_queue.put((idx, result)) + except Exception as e: + logger.error(f"Worker error on sample {idx}: {str(e)}") + # Create an error result + error_result = { + "idx": idx, + "question": dataset[idx].get("question", "") if idx < len(dataset) else "", + "observations": f"WORKER ERROR: {str(e)}", + "ability_types": dataset[idx].get("ability_types", []) if idx < len(dataset) else [], + "attributes": dataset[idx].get("attributes", {}) if idx < len(dataset) else {} + } + progress_queue.put((idx, error_result)) + + # Signal progress thread to exit and wait for it to finish + progress_queue.put((-1, None)) + progress_thread.join() + + else: + # Sequential processing + logger.info("Using sequential processing (workers=1)") + client = clients[0] # Use the first client + + for idx in tqdm(samples_to_process, desc="Generating observations"): + # Process the sample + result = process_sample(args, client, dataset, idx) + + # Update results + if result: + results_dict[idx] = result + total_processed += 1 + + # Checkpoint at specified intervals + if total_processed % args.checkpoint_interval == 0: + checkpoint_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(checkpoint_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Checkpoint saved with {total_processed} results") + + # Final save + final_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(final_results, outf, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(final_results)} injection observations to {args.output_path}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/chatts_utils/chatts_server.py b/src/chatts_utils/chatts_server.py new file mode 100755 index 0000000..1550710 --- /dev/null +++ b/src/chatts_utils/chatts_server.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +""" +ChatTS Server + +This script runs a vLLM server with OpenAI-compatible API for ChatTS inference. +It leverages the timeseries branch of vLLM for serving ChatTS models. +""" + +import os +import sys +import time +import signal +import argparse +import json +import subprocess +from pathlib import Path + +# Set environment variable for vLLM to allow insecure serialization +os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + +# Parse command line arguments +parser = argparse.ArgumentParser(description="ChatTS Server") +parser.add_argument("--model_path", type=str, required=True, help="Path to ChatTS model") +parser.add_argument("--chatts_path", type=str, required=True, help="Path to ChatTS directory") +parser.add_argument("--port", type=int, default=5000, help="Port to run server on") +parser.add_argument("--device", type=str, default="0", help="GPU device ID") +parser.add_argument("--context_length", type=int, default=6000, help="Max context length") +parser.add_argument("--pid_file", type=str, default="/tmp/chatts_server.pid", help="File to store server PID") +parser.add_argument("--log_file", type=str, default=None, help="File to log server output") +parser.add_argument("--initial_wait", type=int, default=120, help="Initial wait time in seconds") + +args = parser.parse_args() + +# Set up GPU +os.environ["CUDA_VISIBLE_DEVICES"] = args.device +print(f"Using CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}") +print(f"Using VLLM_ALLOW_INSECURE_SERIALIZATION={os.environ.get('VLLM_ALLOW_INSECURE_SERIALIZATION', '0')}") + +# Add ChatTS to Python path +sys.path.insert(0, args.chatts_path) + +# Check if vLLM is available +try: + import vllm + print(f"vLLM package found. Using vLLM for ChatTS server.") + # Try to run a simple vLLM command to test if it's properly installed + subprocess.run(["vllm", "--version"], capture_output=True, check=False) + print("vLLM CLI tool is available.") +except ImportError: + print("Error: vLLM is not installed. Please install the timeseries branch from https://github.com/xiez22/vllm") + sys.exit(1) +except subprocess.CalledProcessError: + print("Warning: vLLM CLI tool not found or not working properly. Continuing anyway...") +except FileNotFoundError: + print("Warning: vLLM CLI tool not found in PATH. Continuing anyway...") + +# Create log file directory if needed +if args.log_file: + log_dir = os.path.dirname(args.log_file) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_file = open(args.log_file, 'w') +else: + log_file = None + +# Write PID to file for cleanup +with open(args.pid_file, "w") as f: + f.write(str(os.getpid())) +print(f"Server PID {os.getpid()} written to {args.pid_file}") + +# Graceful shutdown handler +def signal_handler(sig, frame): + print(f"Received signal {sig}, shutting down...") + if server_process and server_process.poll() is None: + server_process.terminate() + server_process.wait(timeout=10) + + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + + if log_file: + log_file.close() + + sys.exit(0) + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +def start_vllm_server(): + """Start the vLLM server with OpenAI-compatible API""" + + # Ensure environment variables are passed to the subprocess + env = os.environ.copy() + env["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + # Determine data-parallel-size based on device parameter + try: + # If there are commas in the device parameter, split it + if ',' in args.device: + devices = [d.strip() for d in args.device.split(',') if d.strip()] + data_parallel_size = len(devices) + else: + # Single device + devices = [args.device.strip()] + data_parallel_size = 1 + + # Validate that we have at least one device + if not devices or data_parallel_size < 1: + print(f"Warning: Invalid device parameter '{args.device}'. Using data-parallel-size=1") + data_parallel_size = 1 + except Exception as e: + print(f"Error parsing device parameter: {e}. Using data-parallel-size=1") + devices = ["0"] + data_parallel_size = 1 + + print(f"Device parameter: {args.device}, detected {data_parallel_size} GPUs") + + cmd = [ + "vllm", "serve", args.model_path, + "--served-model-name", "chatts", + "--trust-remote-code", + "--hf-overrides", '{"model_type":"chatts"}', + "--max-model-len", str(args.context_length), + "--gpu-memory-utilization", "0.95", + "--limit-mm-per-prompt", f"timeseries=50", + "--allowed-local-media-path", os.path.abspath(os.getcwd()), + "--host", "0.0.0.0", + "--port", str(args.port), + "--uvicorn-log-level", "debug", + "--data-parallel-size", str(data_parallel_size) + ] + + print(f"Starting vLLM server with command: {' '.join(cmd)}") + print(f"Environment: VLLM_ALLOW_INSECURE_SERIALIZATION={env['VLLM_ALLOW_INSECURE_SERIALIZATION']}") + print(f"Data Parallel Configuration: {data_parallel_size} GPUs ({args.device})") + + # Start server process + process = subprocess.Popen( + cmd, + env=env, + stdout=log_file, + stderr=log_file if log_file else subprocess.STDOUT + ) + + return process + +def check_server_health(max_retries=60, retry_interval=5): + """Check if the server is healthy by polling the health endpoint""" + import requests + from requests.exceptions import ConnectionError + + # First, wait for the initial loading period + initial_wait = args.initial_wait # Default is 120 seconds + print(f"Waiting {initial_wait} seconds for initial model loading...") + time.sleep(initial_wait) + + print(f"Checking if server is ready at http://localhost:{args.port}/v1/models...") + + for i in range(max_retries): + try: + response = requests.get(f"http://localhost:{args.port}/v1/models", timeout=10) + if response.status_code == 200: + print("Server is ready!") + return True + except ConnectionError: + pass + except requests.exceptions.Timeout: + print("Request timed out. Server might be busy loading the model.") + + print(f"Server not ready yet, retrying in {retry_interval} seconds... ({i+1}/{max_retries})") + time.sleep(retry_interval) + + print("Server failed to start within the expected time") + return False + +if __name__ == "__main__": + # Start the vLLM server + server_process = start_vllm_server() + + # Check server health + if not check_server_health(): + print("Failed to start server, exiting") + if server_process and server_process.poll() is None: + server_process.terminate() + + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + + if log_file: + log_file.close() + + sys.exit(1) + + # Keep the script running until the server exits + try: + while server_process.poll() is None: + time.sleep(1) + except KeyboardInterrupt: + signal_handler(signal.SIGINT, None) + + # Server process exited + exit_code = server_process.returncode + print(f"Server process exited with code {exit_code}") + + # Clean up + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + + if log_file: + log_file.close() + + sys.exit(exit_code) \ No newline at end of file diff --git a/src/chatts_utils/start_chatts_server.sh b/src/chatts_utils/start_chatts_server.sh new file mode 100755 index 0000000..d9d795d --- /dev/null +++ b/src/chatts_utils/start_chatts_server.sh @@ -0,0 +1,177 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# start_chatts_server.sh +# +# Script to start a ChatTS server for faster inference. +# This script: +# 1. Initializes the environment +# 2. Starts the ChatTS server using vLLM +# 3. Checks that the server is running and operational +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/evaluation +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # project root + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +# ChatTS model paths +CHATTS_MODEL_PATH="" # Path to ChatTS model checkpoint +CHATTS_PATH="" # Path to ChatTS directory + +# Server configuration +CHATTS_PORT=5000 +CHATTS_PID_FILE="/tmp/chatts_server_${CHATTS_PORT}.pid" +export CHATTS_SERVER_PORT="${CHATTS_PORT}" + +# Device configuration +CHATTS_DEVICE="4,5,6,7" # Use 4 GPUs for ChatTS + +# Create log directory +LOG_DIR="${PROJECT_ROOT}/logs" +mkdir -p "$LOG_DIR" + +# ChatTS log files +CHATTS_LOG="${LOG_DIR}/chatts_server.$(date +%Y-%m-%d-%H-%M-%S).log" +CHATTS_CONSOLE_LOG="${LOG_DIR}/chatts_console.$(date +%Y-%m-%d-%H-%M-%S).log" + +# ── Initialize Conda in this shell ──────────── +export MKL_INTERFACE_LAYER=${MKL_INTERFACE_LAYER:-LP64} +if [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +elif [ -f "$(conda info --base)/etc/profile.d/conda.sh" ]; then + source "$(conda info --base)/etc/profile.d/conda.sh" +else + echo "ERROR: Cannot find conda.sh. Do you need to run 'conda init'?" + exit 1 +fi +# ─────────────────────────────────────────────────────────────────────────────── + +# ===== Start ChatTS Server ===== +echo "Starting ChatTS server with chatts-vllm environment..." + +# Activate environment for ChatTS +eval "$(conda shell.bash hook)" +conda activate chatts-vllm + +# Set environment variables +export VLLM_ALLOW_INSECURE_SERIALIZATION=1 +echo "VLLM_ALLOW_INSECURE_SERIALIZATION=$VLLM_ALLOW_INSECURE_SERIALIZATION" + +# Don't set CUDA_VISIBLE_DEVICES for ChatTS server, use explicit device selection instead +echo "Using explicit device selection for ChatTS server: cuda:${CHATTS_DEVICE}" + +# Check if a server is already running on the port +if nc -z localhost $CHATTS_PORT 2>/dev/null; then + echo "Warning: Port $CHATTS_PORT is already in use!" + echo "Another server might be running. Check with: lsof -i :$CHATTS_PORT" + + # Ask if we should continue or abort + read -p "Do you want to continue anyway? [y/N] " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Aborting server startup." + exit 1 + fi +fi + +# Clear existing PID file if it exists +if [ -f "$CHATTS_PID_FILE" ]; then + echo "Removing existing PID file: $CHATTS_PID_FILE" + rm -f "$CHATTS_PID_FILE" +fi + +# Copy chatts_server.py to local directory if it doesn't exist +CHATTS_SERVER_SCRIPT="${SCRIPT_DIR}/chatts_server.py" +if [ ! -f "$CHATTS_SERVER_SCRIPT" ]; then + echo "ERROR: Missing ChatTS server script: $CHATTS_SERVER_SCRIPT" + echo "Please make sure src/chatts_utils/chatts_server.py exists in this repository." + exit 1 +fi + +# Start ChatTS server +echo "Starting ChatTS server with log at ${CHATTS_LOG}" +"$CHATTS_SERVER_SCRIPT" \ + --model_path "${CHATTS_MODEL_PATH}" \ + --chatts_path "${CHATTS_PATH}" \ + --port "${CHATTS_PORT}" \ + --device "${CHATTS_DEVICE}" \ + --context_length 5000 \ + --pid_file "${CHATTS_PID_FILE}" \ + --log_file "${CHATTS_LOG}" \ + --initial_wait 180 \ + > "${CHATTS_CONSOLE_LOG}" 2>&1 & + +CHATTS_SERVER_PID=$! +echo "Started ChatTS server process with PID $CHATTS_SERVER_PID" + +# Wait briefly to make sure the process starts +sleep 10 + +# Check if the PID file was created +if [ -f "$CHATTS_PID_FILE" ]; then + FILE_PID=$(cat $CHATTS_PID_FILE) + echo "ChatTS server PID file created with PID ${FILE_PID}" +else + echo "ChatTS server PID file not created yet, writing our tracked PID" + echo $CHATTS_SERVER_PID > "$CHATTS_PID_FILE" +fi + +# Check if the server process is still running +if kill -0 $CHATTS_SERVER_PID 2>/dev/null; then + echo "ChatTS server process is running" +else + echo "Error: ChatTS server process exited unexpectedly" + echo "Check the logs:" + echo "Console log: $CHATTS_CONSOLE_LOG" + echo "Server log: $CHATTS_LOG" + exit 1 +fi + +# Wait for server initialization - a fixed time instead of relying on the server script's check +echo "Waiting for ChatTS server to initialize (240 seconds)..." +echo "You can monitor the logs with:" +echo "tail -f ${CHATTS_CONSOLE_LOG}" +echo "tail -f ${CHATTS_LOG}" +sleep 240 # 4 minute initial wait + +# ===== Test Server Connectivity ===== +echo "Testing ChatTS server connectivity..." +python -c " +from openai import OpenAI +client = OpenAI(base_url='http://localhost:${CHATTS_PORT}/v1', api_key='dummy-key') +try: + response = client.models.list() + print(f'ChatTS models available: {response}') + print('ChatTS server is operational!') + exit(0) +except Exception as e: + print(f'Error testing ChatTS server: {e}') + exit(1) +" +CHATTS_TEST_EXIT_CODE=$? + +if [ $CHATTS_TEST_EXIT_CODE -ne 0 ]; then + echo "Error: ChatTS server test failed." + echo "Check the logs:" + echo "Console log: $CHATTS_CONSOLE_LOG" + echo "Server log: $CHATTS_LOG" + exit 1 +else + echo "ChatTS server test passed successfully!" + echo "The server is running on port $CHATTS_PORT" + echo "To stop the server later, run: $SCRIPT_DIR/stop_chatts_server.sh" +fi + +echo "" +echo "========================================" +echo "ChatTS Server is ready for inference!" +echo "Server URL: http://localhost:$CHATTS_PORT" +echo "========================================" diff --git a/src/chatts_utils/stop_chatts_server.sh b/src/chatts_utils/stop_chatts_server.sh new file mode 100755 index 0000000..7f7344f --- /dev/null +++ b/src/chatts_utils/stop_chatts_server.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# stop_chatts_server.sh +# +# Script to stop a running ChatTS server. +# This script: +# 1. Finds the PID file for the ChatTS server +# 2. Sends a SIGTERM signal to gracefully shut down the server +# ============================================================================== + +# ChatTS server PID file location +CHATTS_PORT=5000 +CHATTS_PID_FILE="/tmp/chatts_server_${CHATTS_PORT}.pid" + +echo "Stopping ChatTS server..." + +# Check if PID file exists +if [ ! -f "$CHATTS_PID_FILE" ]; then + echo "No PID file found at $CHATTS_PID_FILE" + + # Check if there's a process listening on the ChatTS port + if nc -z localhost $CHATTS_PORT 2>/dev/null; then + echo "Warning: Port $CHATTS_PORT is in use but no PID file exists." + echo "Finding processes using port $CHATTS_PORT:" + + # Find and display processes using the port + if command -v lsof &> /dev/null; then + PROCS=$(lsof -i :$CHATTS_PORT -t) + if [ -n "$PROCS" ]; then + echo "Found processes: $PROCS" + + # Ask if we should kill these processes + read -p "Do you want to kill these processes? [y/N] " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + echo "Killing processes using port $CHATTS_PORT" + for pid in $PROCS; do + echo "Killing process $pid" + kill -9 $pid + done + else + echo "No processes were killed. Please manually stop the server." + fi + else + echo "No processes found using lsof." + fi + else + echo "lsof command not available, cannot find processes by port." + fi + else + echo "No process is listening on port $CHATTS_PORT" + fi + + exit 0 +fi + +# Read PID from file +PID=$(cat $CHATTS_PID_FILE) +echo "Found ChatTS server with PID: $PID" + +# Check if the process exists +if kill -0 $PID 2>/dev/null; then + echo "Sending SIGTERM to PID $PID" + kill -15 $PID + + # Wait for the process to terminate + echo "Waiting for server to shut down..." + for i in {1..30}; do + if ! kill -0 $PID 2>/dev/null; then + echo "Server shut down successfully." + break + fi + sleep 1 + done + + # If process still exists, force kill + if kill -0 $PID 2>/dev/null; then + echo "Server did not shut down gracefully, sending SIGKILL..." + kill -9 $PID + sleep 2 + fi +else + echo "Process with PID $PID does not exist or is not accessible." +fi + +# Remove PID file +if [ -f "$CHATTS_PID_FILE" ]; then + echo "Removing PID file: $CHATTS_PID_FILE" + rm -f "$CHATTS_PID_FILE" +fi + +# Final check +if nc -z localhost $CHATTS_PORT 2>/dev/null; then + echo "Warning: Port $CHATTS_PORT is still in use after stopping the server." + echo "You might need to manually kill the remaining processes." +else + echo "Port $CHATTS_PORT is now free." +fi + +echo "ChatTS server stop script completed." \ No newline at end of file diff --git a/src/claude_thinking_inference.py b/src/claude_thinking_inference.py new file mode 100644 index 0000000..241de2e --- /dev/null +++ b/src/claude_thinking_inference.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +""" +A two-step approach for Claude inference with time series data: +1. First generates and saves all figures sequentially +2. Then processes the figures with Claude in parallel + +This script: +1. Loads a multimodal dataset (timeseries + question). +2. Generate all figures sequentially to avoid thread-safety issues +3. Send each example to Claude in parallel with thinking mode enabled +4. Parse out "thought" (the chain-of-thought) and the answer. +5. Generate HTML reports for easy result inspection. + +Usage: + # Basic usage with images + python claude_thinking_inference.py --dataset_path path/to/dataset.json --output_path path/to/output.json + + # Text-only mode (no images) + python claude_thinking_inference.py --dataset_path path/to/dataset.json --output_path path/to/output.json --text_only + +""" + +import os +import json +import argparse +import random +import logging +import numpy as np +import boto3 +import base64 +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# Import utility functions +import sys +import os + +# Add the parent directory to the path so we can import claude_utils modules +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Now we can import from claude_utils +from claude_utils.ts_visualization import generate_image_from_timeseries +from claude_utils.claude_inference import ( + MODEL_ID, + invoke_claude, + parse_response, +) + +# Default configuration +WORKERS = 2 +FIG_DIR = "figures" # will be created under output's parent directory +CHECKPOINT_INTERVAL = 50 +MAX_SAMPLES = 450 # Default maximum number of samples to process + +def prepare_timeseries_data(ts): + """ + Prepare timeseries data for visualization. + + Args: + ts: Raw timeseries data + + Returns: + Processed timeseries data suitable for visualization + """ + # Ensure ts is a list of lists (for multiple series) + if not isinstance(ts, list): + ts = [ts] # Wrap single series + elif len(ts) > 0 and not isinstance(ts[0], list): + ts = [ts] # Wrap flat list into nested list + + # Check if we got empty data + if not ts or len(ts) == 0: + ts = [[0, 1, 2, 3, 4]] # Default dummy data + logging.warning(f"Entry has empty time series data, using dummy data") + + return ts + +def generate_all_images(data, to_process, fig_dir): + """ + Generate all images sequentially and return image paths. + + Args: + data: Dataset containing timeseries data + to_process: List of indices to process + fig_dir: Directory to save figures + + Returns: + Dictionary mapping indices to image paths + """ + image_paths = {} + + print(f"Generating {len(to_process)} figures sequentially...") + for idx in tqdm(to_process, desc="Generating figures"): + sample = data[idx] + ts = sample["timeseries"] + cols = sample.get("cols", []) + + # Prepare timeseries data + ts = prepare_timeseries_data(ts) + + # Generate and save image + path = os.path.join(fig_dir, f"{idx}.jpg") + try: + # Always save the image + _ = generate_image_from_timeseries( + case_idx=idx, + timeseries=ts, + cols=cols, + fig_dir=fig_dir, + save_image=True + ) + + # Check if the image was created successfully + if os.path.exists(path): + file_size = os.path.getsize(path) + if file_size > 0: + image_paths[idx] = path + else: + logging.warning(f"Empty image file generated for idx={idx}") + else: + logging.warning(f"Image file not created for idx={idx}") + except Exception as e: + logging.error(f"Error generating image for idx={idx}: {e}") + + print(f"Successfully generated {len(image_paths)} figures out of {len(to_process)} requested") + return image_paths + +def get_image_base64(image_path): + """ + Load an image from disk and convert to base64. + + Args: + image_path: Path to the image file + + Returns: + Base64 encoded string of the image + """ + with open(image_path, "rb") as f: + img_b64 = base64.b64encode(f.read()).decode("utf-8") + return img_b64 + +def process_sample_with_existing_image(idx, sample, client, image_path, text_only=False): + """ + Process a sample using a pre-generated image file. + + Args: + idx: Sample index + sample: Sample data containing question + client: Boto3 client for Bedrock + image_path: Path to the pre-generated image + text_only: If True, only use text input (no image) + + Returns: + Dict containing results (idx, question, thought, response, etc.) + """ + # Extract question + question = sample.get("question", "") + + # Read the image if not in text-only mode + img_b64 = None + if not text_only and image_path: + try: + img_b64 = get_image_base64(image_path) + except Exception as e: + logging.error(f"Error reading image for idx={idx}: {e}") + # Continue with text-only if image loading fails + text_only = True + + # Prepare message for Claude + if text_only: + # Text-only mode - just the question + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": question} + ] + }] + else: + # With image + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": question}, + {"type": "image", "source": { + "type": "base64", "media_type": "image/jpeg", "data": img_b64 + }} + ] + }] + + # Invoke Claude + resp = invoke_claude(client, messages) + thought, answer, ok = parse_response(resp) + + # Return results in a format compatible with HTML report generator + return { + "idx": idx, + "question": question, + "analysis": thought, # Use Claude's thought process as analysis for report + "thought": thought, + "response": answer, + "success": ok, + "image_path": image_path + } + +def main(): + # Parse command-line arguments + parser = argparse.ArgumentParser( + description="Run Claude inference with time series visualization (two-step approach)" + ) + parser.add_argument( + "--dataset_path", "-d", + required=True, + help="Path to the input dataset JSON" + ) + parser.add_argument( + "--output_path", "-o", + required=True, + help="Path to write the output JSON results" + ) + parser.add_argument( + "--workers", "-w", + type=int, + default=WORKERS, + help=f"Number of parallel workers (default: {WORKERS})" + ) + parser.add_argument( + "--checkpoint_interval", "-c", + type=int, + default=CHECKPOINT_INTERVAL, + help=f"Interval for saving checkpoints (default: {CHECKPOINT_INTERVAL})" + ) + parser.add_argument( + "--max_samples", "-m", + type=int, + default=MAX_SAMPLES, + help=f"Maximum number of samples to process (default: {MAX_SAMPLES})" + ) + parser.add_argument( + "--seed", "-s", + type=int, + default=42, + help="Random seed for sampling (default: 42)" + ) + parser.add_argument( + "--text_only", "-t", + action="store_true", + help="Run inference with text-only mode (no images)" + ) + args = parser.parse_args() + + # Create directories + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Set up experiment-specific figure directory + fig_dir = os.path.join(os.path.dirname(args.output_path), FIG_DIR) + os.makedirs(fig_dir, exist_ok=True) + print(f"Using figure directory: {fig_dir}") + + # Print working directory for debugging + print(f"Current working directory: {os.getcwd()}") + + # Load dataset + logging.info(f"Loading dataset from {args.dataset_path}") + with open(args.dataset_path, "r") as f: + full_dataset = json.load(f) + + # Sample if needed + total_entries = len(full_dataset) + if total_entries > args.max_samples: + logging.warning(f"Dataset has {total_entries} entries, which exceeds the maximum of {args.max_samples}.") + logging.warning(f"Randomly sampling {args.max_samples} entries with seed {args.seed}.") + + # Set random seed for reproducibility + random.seed(args.seed) + + # Sample entries + sampled_indices = random.sample(range(total_entries), args.max_samples) + data = [full_dataset[i] for i in sampled_indices] + + # Create a metadata file with the sampled indices + metadata_file = args.output_path.replace('.json', '_sampling_metadata.json') + with open(metadata_file, 'w') as f: + json.dump({ + "original_size": total_entries, + "sampled_size": args.max_samples, + "seed": args.seed, + "sampled_indices": sampled_indices + }, f, indent=2) + + logging.info(f"Sampling metadata saved to {metadata_file}") + else: + data = full_dataset + sampled_indices = list(range(len(data))) # All indices if no sampling + logging.info(f"Processing all {total_entries} entries (no sampling needed)") + + # Check for existing results to support resuming + existing_map = {} + if os.path.exists(args.output_path): + print(f"Loading existing results from {args.output_path}") + with open(args.output_path, "r") as f: + existing = json.load(f) + existing_map = {r["idx"]: r for r in existing} + print(f"Resuming from {len(existing_map)} / {len(data)} already done") + else: + print("Starting fresh run") + + # Determine which indices still need processing + to_process = [i for i in range(len(data)) if i not in existing_map] + + # STEP 1: Generate all images sequentially + if not args.text_only: + image_paths = generate_all_images(data, to_process, fig_dir) + image_count = len(image_paths) + print(f"Generated {image_count} images out of {len(to_process)} total samples") + else: + # In text-only mode, we don't need to generate images + image_paths = {idx: None for idx in to_process} + print("Skipping image generation in text-only mode") + + # Initialize Bedrock client + client = boto3.client("bedrock-runtime", region_name="us-west-2") + + # STEP 2: Process samples with Claude in parallel + print(f"Processing {len(to_process)} samples") + + if args.workers > 1: + # Parallel processing + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = {} + + for idx in to_process: + # Get the image path for this index (or None if not available) + image_path = image_paths.get(idx, None) + + # Skip if we don't have an image in non-text-only mode + if not args.text_only and image_path is None: + logging.warning(f"Skipping idx={idx} - no image available") + continue + + # Submit the task + futures[executor.submit( + process_sample_with_existing_image, + idx, + data[idx], + client, + image_path, + args.text_only, + )] = idx + + # Process futures as they complete + for count, fut in enumerate(tqdm(futures, desc="Inference"), start=1): + idx = futures[fut] + try: + res = fut.result() + existing_map[idx] = res + except Exception as e: + print(f"[ERROR idx={idx}] {e}") + + # Checkpoint at specified intervals + if count % args.checkpoint_interval == 0 or count == len(futures): + # Save JSON checkpoint + with open(args.output_path, "w") as outf: + json.dump( + [existing_map[k] for k in sorted(existing_map)], + outf, + indent=2, + ensure_ascii=False + ) + else: + # Sequential processing + for count, idx in enumerate(tqdm(to_process, desc="Inference"), start=1): + # Get the image path for this index (or None if not available) + image_path = image_paths.get(idx, None) + + # Skip if we don't have an image in non-text-only mode + if not args.text_only and image_path is None: + logging.warning(f"Skipping idx={idx} - no image available") + continue + + try: + res = process_sample_with_existing_image( + idx, data[idx], client, image_path, + args.text_only, + ) + + existing_map[idx] = res + except Exception as e: + print(f"[ERROR idx={idx}] {e}") + + # Checkpoint at specified intervals + if count % args.checkpoint_interval == 0 or count == len(to_process): + # Save JSON checkpoint + with open(args.output_path, "w") as outf: + json.dump( + [existing_map[k] for k in sorted(existing_map)], + outf, + indent=2, + ensure_ascii=False + ) + + # Final write + with open(args.output_path, "w") as outf: + json.dump( + [existing_map[k] for k in sorted(existing_map)], + outf, + indent=2, + ensure_ascii=False + ) + + print(f"Done: {len(existing_map)}/{len(data)} entries written to {args.output_path}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/claude_thinking_with_injection.py b/src/claude_thinking_with_injection.py new file mode 100644 index 0000000..312cd62 --- /dev/null +++ b/src/claude_thinking_with_injection.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +""" +A script to use ChatTS-generated injection observations with Claude to produce answers. + +For each example it will: + 1. Load the original timeseries + question from the dataset. + 2. Load the injection observations from ChatTS output JSON. + 3. Generate the plot and encode to base64. + 4. Build a multimodal prompt that injects: + - the question + - the injection observations from ChatTS + - a final instruction to produce a "Final Answer" + 5. Call Claude with thinking mode ON. + 6. Parse out the thought and the Final Answer. + 7. Save idx, question, injection observations, thought, answer, success flag. +""" + +import os +import re +import json +import base64 +import argparse +import numpy as np +# We import matplotlib.pyplot in generate_image_from_timeseries +import boto3 +from tqdm import tqdm +import concurrent.futures +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Pool +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception +from botocore.exceptions import ClientError +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# ─── CONFIG ──────────────────────────────────────────────────────────────────── +MODEL_ID = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" +MAX_TOKENS = 4096 +THINKING_BUDGET = 2048 +WORKERS = 2 +FIG_DIR = "figures" +# ──────────────────────────────────────────────────────────────────────────────── + +default_system = ( + "You are a time‐series expert. \n" + "Answer **only** with a JSON object that has exactly one key, \"Final Answer\",\n" + "whose value is the answer string. \n" +) + +def is_throttling(exc): + return ( + isinstance(exc, ClientError) and + exc.response.get("Error", {}).get("Code") == "ThrottlingException" + ) + +@retry( + retry=retry_if_exception(is_throttling), + stop=stop_after_attempt(20), + wait=wait_exponential(multiplier=1, min=2, max=10) +) +def invoke_claude(client, messages): + payload = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": MAX_TOKENS, + "temperature": 1.0, + "thinking": {"type": "enabled", "budget_tokens": THINKING_BUDGET}, + "system": default_system, + "messages": messages + } + resp = client.invoke_model(body=json.dumps(payload), modelId=MODEL_ID) + return json.loads(resp['body'].read()) + +def parse_response(resp_body): + thought_chunks, text_chunks = [], [] + for chunk in resp_body.get("content", []): + if chunk.get("type") == "thinking": + thought_chunks.append(chunk.get("thinking","").strip()) + elif chunk.get("type") == "text": + text_chunks.append(chunk.get("text","").strip()) + thought = "\n".join(thought_chunks) + raw = "".join(text_chunks) + clean = re.sub(r'```(?:json)?', '', raw).strip() + start = clean.find('{') + end = clean.rfind('}') + if start!=-1 and end>start: + json_str = clean[start:end+1] + else: + json_str = clean + try: + obj = json.loads(json_str) + answer = obj.get("Final Answer","") + success = "Final Answer" in obj + except Exception: + answer = json_str + success = False + return thought, answer, success + +def generate_image_from_timeseries(idx, ts, cols): + """ + Generate an image from timeseries data. + Uses the same styling as the ts_visualization utility. + """ + import sys + import os + + # Add the parent directory to the path to import utils + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from utils.ts_visualization import generate_image_from_timeseries as gen_img + + # Ensure directory exists + os.makedirs(FIG_DIR, exist_ok=True) + path = os.path.join(FIG_DIR, f"{idx}.jpg") + + # Convert numpy array to list if needed + if isinstance(ts, np.ndarray): + # Handle different dimensions + if len(ts.shape) == 1: + # Single series + ts_list = [ts.tolist()] + else: + # Multiple series + ts_list = [series.tolist() for series in ts] + else: + ts_list = ts + + # Ensure we have column names + if not cols or len(cols) != len(ts_list): + cols = [f"Series {i+1}" for i in range(len(ts_list))] + + # Call the utility function with save_image=True and get the base64 string + img_b64 = gen_img(idx, ts_list, cols, FIG_DIR, save_image=True) + + return path + +def build_prompt_with_injection(question: str, observations: str) -> str: + """ + Build a prompt that includes the question and injection observations from ChatTS, + then asks Claude to provide a final answer based on these observations. + """ + # Split off the "Now, based on ..." part if it exists + if "Now," in question: + q_part, rest = question.split("Now,", 1) + question_part = q_part.strip() + answer_format = "Now," + rest.strip() + else: + question_part = question.strip() + answer_format = "" + + # Build the new user prompt (instructional proxy for early injection) + body = ( + f"{question_part}\n\n" + f"{observations.strip()}\n\n" + "Wait, let me summarize and reflect on the previous observations from the time series, " + "and then continue my reasoning process to derive the final answer...\n\n" + "Please continue your thinking process from the observations above and provide your answer to the question" + + (f" following these instructions:\n\n{answer_format}\n" if answer_format else ".\n") + ) + return body + +def generate_and_save_image(idx, sample): + """ + Generate an image for a given sample and save it to disk. + + Args: + idx: Sample index + sample: Data sample with timeseries and columns + + Returns: + Path to the generated image file or None if there was an error + """ + try: + # Extract timeseries and column data + ts = sample["timeseries"] + cols = sample["cols"] + + # Log structure information for debugging + if isinstance(ts, list): + if ts and isinstance(ts[0], list): + lengths = [len(series) for series in ts] + print(f"Sample {idx}: Time series is list of lists with lengths {lengths}") + else: + print(f"Sample {idx}: Time series is single list with length {len(ts) if ts else 0}") + elif isinstance(ts, np.ndarray): + print(f"Sample {idx}: Time series is numpy array with shape {ts.shape}") + else: + print(f"Sample {idx}: Time series has unknown type {type(ts)}") + + # Make sure cols list is properly sized + if not cols or len(cols) != len(ts if isinstance(ts, list) else []): + if isinstance(ts, list): + cols = [f"Series {i+1}" for i in range(len(ts))] + else: + cols = ["Series 1"] + + # Generate and save the image + img_path = generate_image_from_timeseries(idx, ts, cols) + return img_path + except Exception as e: + print(f"[ERROR in generate_and_save_image idx={idx}] {e}") + return None + +def process_with_claude(idx, injection_entry, img_path, client): + """ + Process a sample using Claude with a pre-generated image and injection observations. + + Args: + idx: Sample index + injection_entry: Input entry with question and injection observations + img_path: Path to the pre-generated image + client: Boto3 client for Bedrock + + Returns: + Dict with results + """ + # Read the pre-generated image file + with open(img_path, "rb") as f: + img_data = f.read() + img_b64 = base64.b64encode(img_data).decode("utf8") + + # Build the prompt + prompt_text = build_prompt_with_injection( + injection_entry["question"], injection_entry["observations"] + ) + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": prompt_text}, + {"type": "image", "source": { + "type": "base64", "media_type": "image/jpeg", "data": img_b64 + }} + ] + }] + + # Invoke Claude + resp = invoke_claude(client, messages) + thought, answer, ok = parse_response(resp) + + return { + "idx": idx, + "question": injection_entry["question"], + "observations": injection_entry["observations"], + "thought": thought, + "response": answer, + "success": ok, + "ability_types": injection_entry.get("ability_types", []), # Preserve metadata + "attributes": injection_entry.get("attributes", {}) # Preserve metadata + } + +def main(): + p = argparse.ArgumentParser(description="Generate answers with injection observations from ChatTS") + p.add_argument("--dataset_path", "-d", required=True, + help="Path to the original dataset JSON") + p.add_argument("--injection_path", "-p", required=True, + help="Path to the injection observations JSON from ChatTS") + p.add_argument("--output_path", "-o", required=True, + help="Where to write final results JSON") + p.add_argument("--workers", "-w", type=int, default=WORKERS, + help=f"Number of parallel workers (default: {WORKERS})") + p.add_argument("--image_workers", "-iw", type=int, default=WORKERS, + help=f"Number of workers for image generation (default: {WORKERS})") + args = p.parse_args() + + # Load data & injection observations + print(f"Loading dataset from {args.dataset_path}") + data = json.load(open(args.dataset_path)) + print(f"Loading injection observations from {args.injection_path}") + injections = json.load(open(args.injection_path)) + injection_map = {h["idx"]: h for h in injections} + + # Prepare output dir + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Ensure figures directory exists + fig_dir = os.path.join(os.path.dirname(args.output_path), FIG_DIR) + os.makedirs(fig_dir, exist_ok=True) + + # Load existing results if any + if os.path.exists(args.output_path): + existing = json.load(open(args.output_path)) + existing_map = {r["idx"]: r for r in existing} + print(f"Resuming from {len(existing_map)} / {len(injection_map)} already done") + else: + existing_map = {} + print("Starting fresh run") + + # Determine which indices still need processing + to_process = [i for i in sorted(injection_map) if i not in existing_map] + + if not to_process: + print("No samples to process. All done!") + return + + # STEP 1: Generate all images in parallel + print(f"\nSTEP 1: Generating {len(to_process)} images in parallel with {args.image_workers} workers") + image_paths = {} + + if args.image_workers > 1: + # Use multiprocessing Pool instead of ThreadPoolExecutor + # Create arguments for pool.map as list of tuples + args_list = [(i, data[i]) for i in to_process if i < len(data)] + + # Create a multiprocessing pool + with Pool(processes=args.image_workers) as pool: + # Process each item and collect results + for i, result in enumerate(tqdm(pool.starmap(generate_and_save_image, args_list), + total=len(args_list), desc="Generating images")): + idx = to_process[i] + if result: # Check if image generation was successful + image_paths[idx] = result + else: + print(f"[ERROR generating image idx={idx}] Failed to generate image") + else: + # Sequential image generation + for idx in tqdm(to_process, desc="Generating images"): + if idx < len(data): # Make sure we have the sample in the dataset + try: + img_path = generate_and_save_image(idx, data[idx]) + if img_path: + image_paths[idx] = img_path + except Exception as e: + print(f"[ERROR generating image idx={idx}] {e}") + + # Report image generation results + print(f"Successfully generated {len(image_paths)}/{len(to_process)} images") + + # Filter to_process to only include samples with successful image generation + to_process_filtered = [i for i in to_process if i in image_paths] + if len(to_process_filtered) < len(to_process): + print(f"WARNING: {len(to_process) - len(to_process_filtered)} samples will be skipped due to image generation failures") + + # STEP 2: Process with Claude in parallel + print(f"\nSTEP 2: Processing {len(to_process_filtered)} samples with Claude using {args.workers} workers") + client = boto3.client("bedrock-runtime", region_name="us-west-2") + + # Process samples with Claude + if args.workers > 1: + with ThreadPoolExecutor(max_workers=args.workers) as exe: + futures = { + exe.submit(process_with_claude, i, injection_map[i], image_paths[i], client): i + for i in to_process_filtered + } + for count, fut in enumerate(tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing with Claude"), start=1): + idx = futures[fut] + try: + res = fut.result() + existing_map[idx] = res + except Exception as e: + print(f"[ERROR processing with Claude idx={idx}] {e}") + # checkpoint every 5 + if count % 5 == 0: + with open(args.output_path, "w") as outf: + json.dump( + [existing_map[k] for k in sorted(existing_map)], + outf, indent=2, ensure_ascii=False + ) + else: + for count, idx in enumerate(tqdm(to_process_filtered, desc="Processing with Claude"), start=1): + try: + res = process_with_claude(idx, injection_map[idx], image_paths[idx], client) + existing_map[idx] = res + except Exception as e: + print(f"[ERROR processing with Claude idx={idx}] {e}") + if count % 5 == 0: + with open(args.output_path, "w") as outf: + json.dump( + [existing_map[k] for k in sorted(existing_map)], + outf, indent=2, ensure_ascii=False + ) + + # final write + with open(args.output_path, "w") as outf: + json.dump( + [existing_map[k] for k in sorted(existing_map)], + outf, indent=2, ensure_ascii=False + ) + + print(f"Done: {len(existing_map)}/{len(injection_map)} entries written to {args.output_path}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/claude_utils/__init__.py b/src/claude_utils/__init__.py new file mode 100644 index 0000000..26b7295 --- /dev/null +++ b/src/claude_utils/__init__.py @@ -0,0 +1,11 @@ +""" +Utility functions for time series evaluation. +""" + +from .ts_visualization import generate_image_from_timeseries +from .api_utils import ask_via_llama_fac_api + +__all__ = [ + 'generate_image_from_timeseries', + 'ask_via_llama_fac_api' +] \ No newline at end of file diff --git a/src/claude_utils/api_utils.py b/src/claude_utils/api_utils.py new file mode 100644 index 0000000..ffd78a2 --- /dev/null +++ b/src/claude_utils/api_utils.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +""" +Utility functions for interacting with the LLaMA Factory API. +""" + +def ask_via_llama_fac_api(client, model, img_b64, question): + """ + Call the local llama-factory API with one image + grouped-QA text. + + Args: + client: OpenAI client object + model: Model name to use for inference + img_b64: Base64 encoded image string + question: Question to ask the model + + Returns: + Model response as a string + """ + data_uri = f"data:image/jpeg;base64,{img_b64}" + + # Build the multimodal message + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + { + "type": "image_url", + "image_url": {"url": data_uri} + } + ] + } + ] + + # Send with chat.completions + resp = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=4096 + ) + + return resp.choices[0].message.content \ No newline at end of file diff --git a/src/claude_utils/claude_inference.py b/src/claude_utils/claude_inference.py new file mode 100644 index 0000000..6b48735 --- /dev/null +++ b/src/claude_utils/claude_inference.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +""" +Utility functions for Claude inference with time series data. +This module provides common functions for working with Claude API, +including response parsing, retry logic, and prompt configuration. +""" + +import re +import json +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception +from botocore.exceptions import ClientError + +# ─── CONFIGURATION ──────────────────────────────────────────────────────────────── +MODEL_ID = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" +MAX_TOKENS = 4096 +THINKING_BUDGET = 2048 + +# System prompt that requests JSON output with a single "Final Answer" key +DEFAULT_SYSTEM_PROMPT = ( + "You are a time‐series expert. \n" + "Answer **only** with a JSON object that has exactly one key, 'Final Answer',\n" + "whose value is the answer string. \n" +) + +def is_throttling(exc): + """ + Check if the exception is due to throttling. + + Args: + exc: Exception to check + + Returns: + bool: True if the exception is a throttling exception + """ + return ( + isinstance(exc, ClientError) and + exc.response.get("Error", {}).get("Code") == "ThrottlingException" + ) + +@retry( + retry=retry_if_exception(is_throttling), + stop=stop_after_attempt(20), + wait=wait_exponential(multiplier=1, min=2, max=10) +) +def invoke_claude(client, messages, model_id=MODEL_ID, temperature=1.0, system=None): + """ + Invoke Claude with retry logic for throttling. + + Args: + client: Boto3 Bedrock client + messages: List of message dictionaries + model_id: Claude model ID to use + temperature: Sampling temperature (1.0 for thinking) + system: Optional system prompt (uses default if None) + + Returns: + dict: Claude API response + """ + payload = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": MAX_TOKENS, + "temperature": temperature, + "thinking": {"type": "enabled", "budget_tokens": THINKING_BUDGET}, + "system": system or DEFAULT_SYSTEM_PROMPT, + "messages": messages + } + resp = client.invoke_model(body=json.dumps(payload), modelId=model_id) + return json.loads(resp['body'].read()) + +def parse_response(resp_body): + """ + Parse Claude's response to extract thought and answer. + + Args: + resp_body: Claude API response body + + Returns: + tuple: (thought, answer, success) + """ + thought_chunks, text_chunks = [], [] + + for chunk in resp_body.get("content", []): + if chunk.get("type") == "thinking": + thought_chunks.append(chunk.get("thinking", "").strip()) + elif chunk.get("type") == "text": + text_chunks.append(chunk.get("text", "").strip()) + + thought = "\n".join(thought_chunks) + raw = "".join(text_chunks) + + # Remove markdown code fences + clean = re.sub(r'```(?:json)?', '', raw).strip() + + # Find first '{' and last '}' to extract JSON + start = clean.find('{') + end = clean.rfind('}') + + if start != -1 and end != -1 and end > start: + json_str = clean[start:end+1] + else: + json_str = clean + + try: + obj = json.loads(json_str) + answer = obj.get('Final Answer', "") + success = 'Final Answer' in obj + except Exception: + answer = json_str + success = False + + return thought, answer, success \ No newline at end of file diff --git a/src/claude_utils/ts_visualization.py b/src/claude_utils/ts_visualization.py new file mode 100644 index 0000000..9e563a4 --- /dev/null +++ b/src/claude_utils/ts_visualization.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +""" +Utility functions for time series visualization. +""" + +import os +import base64 +import numpy as np +import matplotlib.pyplot as plt + +def generate_image_from_timeseries(case_idx, timeseries, cols, fig_dir, save_image=False): + """ + Plot each channel in its own subplot, save as JPG, return base64 string. + + Args: + case_idx: Unique identifier for the figure + timeseries: Time series data to visualize (list or numpy array) + cols: Column names for the time series + fig_dir: Directory to save the figure + save_image: Whether to keep the saved image file (default: False) + + Returns: + Base64 encoded string of the image + """ + # Create directory if it doesn't exist + os.makedirs(fig_dir, exist_ok=True) + + # Ensure consistent naming scheme + path = os.path.join(fig_dir, f"{case_idx}.jpg") + + # Convert numpy array to list if needed + if isinstance(timeseries, np.ndarray): + # Handle different dimensions + if len(timeseries.shape) == 1: + # Single series + timeseries_list = [timeseries.tolist()] + else: + # Multiple series + timeseries_list = [series.tolist() for series in timeseries] + else: + timeseries_list = timeseries + + # Ensure we have column names + if not cols or len(cols) != len(timeseries_list): + cols = [f"Series {i+1}" for i in range(len(timeseries_list))] + + # Handle the case where we have multiple subplots + n = len(timeseries_list) + if n > 1: + figsize = (6, 2 * n) + # Create subplots with the determined figure size + fig, axes = plt.subplots(n, 1, figsize=figsize) + + for ax, series, title in zip(axes, timeseries_list, cols): + ax.plot(series, linewidth=2) + ax.set_title(title, fontsize=10, fontweight='bold') + else: + fig, ax = plt.subplots(figsize=(6, 2)) + ax.plot(timeseries_list[0], linewidth=2) + ax.set_title(cols[0], fontsize=10, fontweight='bold') + + # Save the figure with consistent settings + plt.tight_layout() + plt.savefig(path, format='jpg', dpi=100) + plt.close(fig) + + # Read the image and convert to base64 + with open(path, "rb") as f: + img_b64 = base64.b64encode(f.read()).decode("utf-8") + + # Delete the image file if not saving images + if not save_image: + try: + os.remove(path) + except OSError: + # Ignore errors if file cannot be deleted + pass + + return img_b64 \ No newline at end of file diff --git a/src/qwen3_utils/qwen3_server.py b/src/qwen3_utils/qwen3_server.py new file mode 100755 index 0000000..251a99e --- /dev/null +++ b/src/qwen3_utils/qwen3_server.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +""" +Qwen3 Server + +This script runs a vLLM server with OpenAI-compatible API for Qwen3 inference. +Qwen3 is a text-only reasoning model served via vLLM, used for thinking-based +time series analysis with injection support (continue_final_message). +""" + +import os +import sys +import time +import signal +import argparse +import subprocess +from pathlib import Path + +# Parse command line arguments +parser = argparse.ArgumentParser(description="Qwen3 Server") +parser.add_argument("--model_path", type=str, required=True, help="Path to Qwen3 model") +parser.add_argument("--port", type=int, default=5001, help="Port to run server on") +parser.add_argument("--device", type=str, default="0,1,2,3", help="GPU device IDs") +parser.add_argument("--data_parallel_size", type=int, default=2, help="Data parallel size") +parser.add_argument("--tensor_parallel_size", type=int, default=2, help="Tensor parallel size") +parser.add_argument("--context_length", type=int, default=32768, help="Max context length") +parser.add_argument("--pid_file", type=str, default="/tmp/qwen3_server.pid", help="File to store server PID") +parser.add_argument("--log_file", type=str, default=None, help="File to log server output") +parser.add_argument("--initial_wait", type=int, default=120, help="Initial wait time in seconds") +parser.add_argument("--chat_template", type=str, default=None, help="Path to custom chat template file") + +args = parser.parse_args() + +# Print all args for debugging +print("Arguments received:") +for arg in vars(args): + print(f" {arg}: {getattr(args, arg)}") + +# Set up GPU +os.environ["CUDA_VISIBLE_DEVICES"] = args.device +print(f"Using CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}") + +# Using V1 without multiprocessing +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +os.environ["VLLM_USE_V1"] = "1" +print("Enabled vLLM v1 engine with environment variables:") +print(f" VLLM_ENABLE_V1_MULTIPROCESSING={os.environ.get('VLLM_ENABLE_V1_MULTIPROCESSING')}") +print(f" VLLM_USE_V1={os.environ.get('VLLM_USE_V1')}") + +# Check if vLLM is available +try: + import vllm + print(f"vLLM package found. Using vLLM for Qwen3 server.") + subprocess.run(["vllm", "--version"], capture_output=True, check=False) + print("vLLM CLI tool is available.") +except ImportError: + print("Error: vLLM is not installed. Please install vllm.") + sys.exit(1) +except subprocess.CalledProcessError: + print("Warning: vLLM CLI tool not found or not working properly. Continuing anyway...") +except FileNotFoundError: + print("Warning: vLLM CLI tool not found in PATH. Continuing anyway...") + +# Create log file directory if needed +if args.log_file: + log_dir = os.path.dirname(args.log_file) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_file = open(args.log_file, 'w') +else: + log_file = None + +# Write PID to file for cleanup +with open(args.pid_file, "w") as f: + f.write(str(os.getpid())) +print(f"Server PID {os.getpid()} written to {args.pid_file}") + +# Graceful shutdown handler +def signal_handler(sig, frame): + print(f"Received signal {sig}, shutting down...") + if server_process and server_process.poll() is None: + server_process.terminate() + server_process.wait(timeout=10) + + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + + if log_file: + log_file.close() + + sys.exit(0) + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +def start_vllm_server(): + """Start the vLLM server with OpenAI-compatible API""" + + env = os.environ.copy() + + cmd = [ + "vllm", "serve", args.model_path, + "--served-model-name", "qwen3", + "--trust-remote-code", + "--max-model-len", str(args.context_length), + "--gpu-memory-utilization", "0.95", + # Uncomment for RoPE scaling when context exceeds 32k (e.g., TSEvol benchmark): + # "--rope-scaling", '{"rope_type":"yarn","factor":2.0,"original_max_position_embeddings":32768}', + "--host", "0.0.0.0", + "--port", str(args.port), + "--uvicorn-log-level", "debug", + "--data-parallel-size", str(args.data_parallel_size), + "--tensor-parallel-size", str(args.tensor_parallel_size), + ] + + # Add chat template if specified + if args.chat_template: + chat_template_path = os.path.abspath(args.chat_template) + if os.path.exists(chat_template_path): + cmd.extend(["--chat-template", chat_template_path]) + print(f"Using custom chat template from: {chat_template_path}") + else: + print(f"Warning: Specified chat template file '{chat_template_path}' not found. Using default template.") + + print(f"Starting vLLM server with command: {' '.join(cmd)}") + print(f"Data Parallel Size: {args.data_parallel_size}, Tensor Parallel Size: {args.tensor_parallel_size}") + print(f"GPU Configuration: {args.device}") + + process = subprocess.Popen( + cmd, + env=env, + stdout=log_file, + stderr=log_file if log_file else subprocess.STDOUT + ) + + return process + +def check_server_health(max_retries=60, retry_interval=5): + """Check if the server is healthy by polling the health endpoint""" + import requests + from requests.exceptions import ConnectionError + + initial_wait = args.initial_wait + print(f"Waiting {initial_wait} seconds for initial model loading...") + time.sleep(initial_wait) + + print(f"Checking if server is ready at http://localhost:{args.port}/v1/models...") + + for i in range(max_retries): + try: + response = requests.get(f"http://localhost:{args.port}/v1/models", timeout=10) + if response.status_code == 200: + print("Server is ready!") + return True + except ConnectionError: + pass + except requests.exceptions.Timeout: + print("Request timed out. Server might be busy loading the model.") + + print(f"Server not ready yet, retrying in {retry_interval} seconds... ({i+1}/{max_retries})") + time.sleep(retry_interval) + + print("Server failed to start within the expected time") + return False + +if __name__ == "__main__": + server_process = start_vllm_server() + + if not check_server_health(): + print("Failed to start server, exiting") + if server_process and server_process.poll() is None: + server_process.terminate() + + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + + if log_file: + log_file.close() + + sys.exit(1) + + try: + while server_process.poll() is None: + time.sleep(1) + except KeyboardInterrupt: + signal_handler(signal.SIGINT, None) + + exit_code = server_process.returncode + print(f"Server process exited with code {exit_code}") + + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + + if log_file: + log_file.close() + + sys.exit(exit_code) diff --git a/src/qwen3_utils/simple_chat_template.jinja b/src/qwen3_utils/simple_chat_template.jinja new file mode 100644 index 0000000..e717960 --- /dev/null +++ b/src/qwen3_utils/simple_chat_template.jinja @@ -0,0 +1,22 @@ +{%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} +{%- endif %} + +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + + {%- if message.role == "user" %} + {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} + {%- elif message.role == "assistant" %} + {# Just pass through the assistant message without modifying it #} + {{- '<|im_start|>assistant\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/src/qwen3_utils/start_qwen3_server.sh b/src/qwen3_utils/start_qwen3_server.sh new file mode 100755 index 0000000..182861f --- /dev/null +++ b/src/qwen3_utils/start_qwen3_server.sh @@ -0,0 +1,181 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# start_qwen3_server.sh +# +# Script to start a Qwen3 server for text-only reasoning inference. +# This script: +# 1. Initializes the environment +# 2. Starts the Qwen3 server using vLLM with a custom chat template +# 3. Checks that the server is running and operational +# +# The custom chat template (simple_chat_template.jinja) passes through +# assistant messages without modifying tags, which is required for +# the continue_final_message injection pattern. +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/src/qwen3_utils +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" # project root + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +# Qwen3-32B model path (GRLM) +QWEN3_MODEL_PATH="" # Path to Qwen3-32B checkpoint + +# Server configuration +QWEN3_PORT=5001 +QWEN3_PID_FILE="/tmp/qwen3_server_${QWEN3_PORT}.pid" +export QWEN3_SERVER_PORT="${QWEN3_PORT}" + +# Device configuration +QWEN3_DEVICE="0,1,2,3" +QWEN3_DATA_PARALLEL_SIZE=1 +QWEN3_TENSOR_PARALLEL_SIZE=4 + +# Custom chat template (required for continue_final_message injection) +QWEN3_CHAT_TEMPLATE="${SCRIPT_DIR}/simple_chat_template.jinja" +echo "Using chat template: ${QWEN3_CHAT_TEMPLATE}" +if [ -f "${QWEN3_CHAT_TEMPLATE}" ]; then + echo "Chat template file exists" +else + echo "WARNING: Chat template file does not exist!" +fi + +# Create log directory +LOG_DIR="${PROJECT_ROOT}/logs" +mkdir -p "$LOG_DIR" + +# Qwen3 log files +QWEN3_LOG="${LOG_DIR}/qwen3_server.$(date +%Y-%m-%d-%H-%M-%S).log" +QWEN3_CONSOLE_LOG="${LOG_DIR}/qwen3_console.$(date +%Y-%m-%d-%H-%M-%S).log" + +# ── Initialize Conda in this shell ──────────── +export MKL_INTERFACE_LAYER=${MKL_INTERFACE_LAYER:-LP64} +if [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +elif [ -f "$(conda info --base)/etc/profile.d/conda.sh" ]; then + source "$(conda info --base)/etc/profile.d/conda.sh" +else + echo "ERROR: Cannot find conda.sh. Do you need to run 'conda init'?" + exit 1 +fi +# ─────────────────────────────────────────────────────────────────────────────── + +# ===== Start Qwen3 Server ===== +echo "Starting Qwen3 server with qwen3-vllm environment..." + +# Activate environment for Qwen3 +eval "$(conda shell.bash hook)" +conda activate qwen3-vllm + +# Check if a server is already running on the port +if nc -z localhost $QWEN3_PORT 2>/dev/null; then + echo "Warning: Port $QWEN3_PORT is already in use!" + echo "Another server might be running. Check with: lsof -i :$QWEN3_PORT" + + # Ask if we should continue or abort + read -p "Do you want to continue anyway? [y/N] " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Aborting server startup." + exit 1 + fi +fi + +# Clear existing PID file if it exists +if [ -f "$QWEN3_PID_FILE" ]; then + echo "Removing existing PID file: $QWEN3_PID_FILE" + rm -f "$QWEN3_PID_FILE" +fi + +# Make server script executable +QWEN3_SERVER_SCRIPT="${SCRIPT_DIR}/qwen3_server.py" +chmod +x "$QWEN3_SERVER_SCRIPT" + +# Start Qwen3 server +echo "Starting Qwen3 server with log at ${QWEN3_LOG}" +"$QWEN3_SERVER_SCRIPT" \ + --model_path "${QWEN3_MODEL_PATH}" \ + --port "${QWEN3_PORT}" \ + --device "${QWEN3_DEVICE}" \ + --data_parallel_size "${QWEN3_DATA_PARALLEL_SIZE}" \ + --tensor_parallel_size "${QWEN3_TENSOR_PARALLEL_SIZE}" \ + --pid_file "${QWEN3_PID_FILE}" \ + --log_file "${QWEN3_LOG}" \ + --chat_template "${QWEN3_CHAT_TEMPLATE}" \ + --initial_wait 180 \ + > "${QWEN3_CONSOLE_LOG}" 2>&1 & + +QWEN3_SERVER_PID=$! +echo "Started Qwen3 server process with PID $QWEN3_SERVER_PID" + +# Wait briefly to make sure the process starts +sleep 10 + +# Check if the PID file was created +if [ -f "$QWEN3_PID_FILE" ]; then + FILE_PID=$(cat $QWEN3_PID_FILE) + echo "Qwen3 server PID file created with PID ${FILE_PID}" +else + echo "Qwen3 server PID file not created yet, writing our tracked PID" + echo $QWEN3_SERVER_PID > "$QWEN3_PID_FILE" +fi + +# Check if the server process is still running +if kill -0 $QWEN3_SERVER_PID 2>/dev/null; then + echo "Qwen3 server process is running" +else + echo "Error: Qwen3 server process exited unexpectedly" + echo "Check the logs:" + echo "Console log: $QWEN3_CONSOLE_LOG" + echo "Server log: $QWEN3_LOG" + exit 1 +fi + +# Wait for server initialization +echo "Waiting for Qwen3 server to initialize (240 seconds)..." +echo "You can monitor the logs with:" +echo "tail -f ${QWEN3_CONSOLE_LOG}" +echo "tail -f ${QWEN3_LOG}" +sleep 240 # 4 minute initial wait + +# ===== Test Server Connectivity ===== +echo "Testing Qwen3 server connectivity..." +python -c " +from openai import OpenAI +client = OpenAI(base_url='http://localhost:${QWEN3_PORT}/v1', api_key='dummy-key') +try: + response = client.models.list() + print(f'Qwen3 models available: {response}') + print('Qwen3 server is operational!') + exit(0) +except Exception as e: + print(f'Error testing Qwen3 server: {e}') + exit(1) +" +QWEN3_TEST_EXIT_CODE=$? + +if [ $QWEN3_TEST_EXIT_CODE -ne 0 ]; then + echo "Error: Qwen3 server test failed." + echo "Check the logs:" + echo "Console log: $QWEN3_CONSOLE_LOG" + echo "Server log: $QWEN3_LOG" + exit 1 +else + echo "Qwen3 server test passed successfully!" + echo "The server is running on port $QWEN3_PORT" + echo "To stop the server later, run: $SCRIPT_DIR/stop_qwen3_server.sh" +fi + +echo "" +echo "========================================" +echo "Qwen3 Server is ready for inference!" +echo "Server URL: http://localhost:$QWEN3_PORT" +echo "========================================" diff --git a/src/qwen3_utils/stop_qwen3_server.sh b/src/qwen3_utils/stop_qwen3_server.sh new file mode 100755 index 0000000..1e8f08e --- /dev/null +++ b/src/qwen3_utils/stop_qwen3_server.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# stop_qwen3_server.sh +# +# Script to stop a running Qwen3 server. +# This script: +# 1. Finds the PID file for the Qwen3 server +# 2. Sends a SIGTERM signal to gracefully shut down the server +# ============================================================================== + +# Qwen3 server PID file location +QWEN3_PORT=5001 +QWEN3_PID_FILE="/tmp/qwen3_server_${QWEN3_PORT}.pid" + +echo "Stopping Qwen3 server..." + +# Check if PID file exists +if [ ! -f "$QWEN3_PID_FILE" ]; then + echo "No PID file found at $QWEN3_PID_FILE" + + # Check if there's a process listening on the Qwen3 port + if nc -z localhost $QWEN3_PORT 2>/dev/null; then + echo "Warning: Port $QWEN3_PORT is in use but no PID file exists." + echo "Finding processes using port $QWEN3_PORT:" + + # Find and display processes using the port + if command -v lsof &> /dev/null; then + PROCS=$(lsof -i :$QWEN3_PORT -t) + if [ -n "$PROCS" ]; then + echo "Found processes: $PROCS" + + # Ask if we should kill these processes + read -p "Do you want to kill these processes? [y/N] " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + echo "Killing processes using port $QWEN3_PORT" + for pid in $PROCS; do + echo "Killing process $pid" + kill -9 $pid + done + else + echo "No processes were killed. Please manually stop the server." + fi + else + echo "No processes found using lsof." + fi + else + echo "lsof command not available, cannot find processes by port." + fi + else + echo "No process is listening on port $QWEN3_PORT" + fi + + exit 0 +fi + +# Read PID from file +PID=$(cat $QWEN3_PID_FILE) +echo "Found Qwen3 server with PID: $PID" + +# Check if the process exists +if kill -0 $PID 2>/dev/null; then + echo "Sending SIGTERM to PID $PID" + kill -15 $PID + + # Wait for the process to terminate + echo "Waiting for server to shut down..." + for i in {1..30}; do + if ! kill -0 $PID 2>/dev/null; then + echo "Server shut down successfully." + break + fi + sleep 1 + done + + # If process still exists, force kill + if kill -0 $PID 2>/dev/null; then + echo "Server did not shut down gracefully, sending SIGKILL..." + kill -9 $PID + sleep 2 + fi +else + echo "Process with PID $PID does not exist or is not accessible." +fi + +# Remove PID file +if [ -f "$QWEN3_PID_FILE" ]; then + echo "Removing PID file: $QWEN3_PID_FILE" + rm -f "$QWEN3_PID_FILE" +fi + +# Final check +if nc -z localhost $QWEN3_PORT 2>/dev/null; then + echo "Warning: Port $QWEN3_PORT is still in use after stopping the server." + echo "You might need to manually kill the remaining processes." +else + echo "Port $QWEN3_PORT is now free." +fi + +echo "Qwen3 server stop script completed." diff --git a/src/qwen3_with_injection.py b/src/qwen3_with_injection.py new file mode 100644 index 0000000..26a8189 --- /dev/null +++ b/src/qwen3_with_injection.py @@ -0,0 +1,610 @@ +#!/usr/bin/env python3 +""" +Text-Only GRLM with Qwen-VL Injection Script + +This script injects Qwen-VL observations (thoughts and answers) into a text-only +GRLM's (Qwen3-32B or DeepSeek-R1-Distill-Qwen-32B) thinking process for enhanced +time series reasoning. + +For each example it will: + 1. Load the original timeseries + question from the dataset. + 2. Load the initial thoughts and answers from Qwen-VL output JSON. + 3. Format the timeseries data as JSON text (Qwen3 is text-only). + 4. Build a prompt that: + - User: contains the question and time series data + - Assistant: begins with the Qwen-VL thoughts AND answer as part of the thinking + 5. Call Qwen3 with continue_final_message=true + 6. Parse out the complete answer. + 7. Save idx, question, initial thoughts, full thought, answer, success flag. + +Usage: + python qwen3_with_injection.py \\ + --dataset_path /path/to/dataset.json \\ + --injection_path /path/to/qwen_vl_output.json \\ + --output_path /path/to/output.json +""" + +import os +import sys +import json +import time +import re +import argparse +import logging +import threading +from queue import Queue +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor + +try: + from openai import OpenAI +except ImportError: + raise ImportError("OpenAI Python client not installed. Please install with 'pip install openai'") + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +def parse_args(): + p = argparse.ArgumentParser(description="Qwen3 with Qwen-VL injection for time series analysis") + p.add_argument("--server_url", default="http://localhost:5001", + help="URL of the GRLM server (default: Qwen3 on 5001, use 5002 for R1)") + p.add_argument("--model_name", default="qwen3", + help="Model name for the GRLM server (qwen3 or r1)") + p.add_argument("--dataset_path", "-d", required=True, + help="Path to the JSON evaluation set") + p.add_argument("--injection_path", "-p", required=True, + help="Path to the Qwen-VL output JSON (from qwen_inference.py)") + p.add_argument("--output_path", "-o", required=True, + help="Where to write generated answers JSON") + p.add_argument("--max_tokens", type=int, default=6144, + help="Maximum tokens to generate") + p.add_argument("--workers", type=int, default=4, + help="Number of parallel workers for processing samples") + p.add_argument("--checkpoint_interval", type=int, default=10, + help="Interval for saving checkpoints") + p.add_argument("--timeout", type=int, default=120, + help="Timeout in seconds for API calls") + p.add_argument("--retry_delay", type=int, default=5, + help="Delay between retries in seconds") + p.add_argument("--max_retries", type=int, default=3, + help="Maximum number of retry attempts") + return p.parse_args() + + +class GRLMClient: + """Client for communicating with a text-only GRLM server (Qwen3 or DeepSeek-R1) using OpenAI API.""" + + def __init__(self, server_url="http://localhost:5001", model_name="qwen3", debug_mode=False): + """Initialize the GRLM client.""" + self.server_url = server_url + self.model_name = model_name + self.debug_mode = debug_mode + self.client = OpenAI(base_url=f"{server_url}/v1", api_key="dummy-key") + + if debug_mode: + logger.setLevel(logging.DEBUG) + logger.info(f"GRLMClient initialized in DEBUG mode with server URL: {server_url}") + + def check_server_health(self): + """Check if the server is healthy.""" + import requests + try: + logger.info(f"Checking health of {self.model_name} server at {self.server_url}...") + response = requests.get(f"{self.server_url}/v1/models", timeout=10) + if response.status_code == 200: + logger.info(f"{self.model_name} server is healthy") + return True + else: + logger.warning(f"{self.model_name} server health check failed: {response.status_code}") + return False + except Exception as e: + logger.error(f"Error checking server health: {type(e).__name__}: {e}") + return False + + def query_with_injection( + self, + user_prompt, + assistant_start, + system_message=None, + max_tokens=6144, + temperature=0.6, + timeout=120, + retry_delay=5, + max_retries=3, + ): + """ + Query Qwen3 with an injected assistant start using continue_final_message. + + The assistant_start contains Qwen-VL's thoughts and answer, which Qwen3 + continues from. This requires a custom chat template that passes through + the assistant message without modifying tags. + + Args: + user_prompt: Text prompt for the user message + assistant_start: Initial text for the assistant response (Qwen-VL injection) + system_message: Optional system message + max_tokens: Maximum tokens to generate + temperature: Temperature for sampling + timeout: Request timeout in seconds + retry_delay: Delay between retries in seconds + max_retries: Maximum number of retry attempts + + Returns: + Model response continuation as string + """ + messages = [] + if system_message: + messages.append({"role": "system", "content": system_message}) + + messages.append({"role": "user", "content": user_prompt}) + messages.append({"role": "assistant", "content": assistant_start}) + + for attempt in range(max_retries): + try: + logger.info(f"Sending query to Qwen3 server (attempt {attempt+1}/{max_retries})") + start_time = time.time() + + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=0.95, + extra_body={ + "add_generation_prompt": False, + "continue_final_message": True, + "top_k": 20 + } + ) + + end_time = time.time() + elapsed = end_time - start_time + + usage = getattr(response, 'usage', None) + if usage: + logger.info( + f"Query successful in {elapsed:.2f}s: " + f"prompt_tokens={usage.prompt_tokens}, " + f"completion_tokens={usage.completion_tokens}, " + f"total_tokens={usage.total_tokens}" + ) + else: + logger.info(f"Query successful, inference time: {elapsed:.2f}s") + + return response.choices[0].message.content + + except Exception as e: + logger.error(f"Query failed (attempt {attempt+1}/{max_retries}): {type(e).__name__}: {e}") + time.sleep(retry_delay) + + retry_delay = min(retry_delay * 2, 60) + + logger.critical(f"Failed to get response from Qwen3 server after {max_retries} attempts") + raise RuntimeError("Failed to get response from Qwen3 server after multiple attempts") + + +def prepare_timeseries_data(ts): + """ + Prepare timeseries data for text formatting. + + Args: + ts: Raw timeseries data + + Returns: + Processed timeseries data suitable for formatting + """ + if not isinstance(ts, list): + ts = [ts] + elif len(ts) > 0 and not isinstance(ts[0], list): + ts = [ts] + + if not ts or len(ts) == 0: + ts = [[0, 1, 2, 3, 4]] + logger.warning("Entry has empty time series data, using dummy data") + + return ts + + +def format_timeseries_as_json(timeseries, cols=None): + """ + Format timeseries data as a JSON string for inclusion in prompts. + + Args: + timeseries: The time series data as a list of lists + cols: Optional column names for the time series + + Returns: + Formatted JSON string representation of the time series + """ + ts_data = prepare_timeseries_data(timeseries) + + if not cols or len(cols) != len(ts_data): + if cols and len(cols) != len(ts_data): + logger.error(f"Column count ({len(cols)}) doesn't match time series count ({len(ts_data)}). Using default column names.") + cols = [f"Series {i+1}" for i in range(len(ts_data))] + + max_length = max(len(series) for series in ts_data) + timestamps = list(range(1, max_length + 1)) + + ts_json = "{\n" + ts_json += f' "timestamps": {timestamps},\n' + + for i, (col, series) in enumerate(zip(cols, ts_data)): + series_values = [round(float(v), 2) for v in series] + ts_json += f' "{col}": {series_values}' + if i < len(ts_data) - 1: + ts_json += ",\n" + else: + ts_json += "\n" + ts_json += "}" + + return ts_json + + +def build_prompt_with_injection(question, initial_thought, qwen_answer, timeseries=None, cols=None): + """ + Build prompts that include the question as user prompt and both Qwen-VL + thoughts and answer as the start of the assistant's response. + + Args: + question: The original question text + initial_thought: Initial thoughts from Qwen-VL + qwen_answer: Answer from Qwen-VL + timeseries: Time series data (optional) + cols: Column names for time series (optional) + + Returns: + Tuple of (user_prompt, assistant_start_content) + """ + # Split off the "Now, based on ..." part if it exists + if "Now," in question: + q_part, rest = question.split("Now,", 1) + question_part = q_part.strip() + answer_format = "Now," + rest.strip() + else: + question_part = question.strip() + answer_format = "" + + # Format time series data if provided + ts_text = "" + if timeseries is not None and cols is not None: + ts_json = format_timeseries_as_json(timeseries, cols) + ts_text = f"Here is the time series data in JSON format:\n{ts_json}\n\n" + + # Build the user prompt - question and time series data + user_prompt = f"{ts_text}{question_part}" + + if answer_format: + user_prompt += f"\n\n{answer_format}" + + # Build the assistant's starting content with TSLM thinking trace only + # (assistant prefill for early injection) + assistant_start = ( + f"\n{initial_thought.strip()}\n\n" + f"Wait, let me summarize and reflect on the previous observations from the time series, " + f"and then continue my reasoning process to derive the final answer." + ) + + return user_prompt, assistant_start + + +def parse_response(response_content, assistant_start): + """ + Parse Qwen3 API response continuation. + + The response_content is the continuation after the assistant_start. + We look for to separate additional thinking from the final answer. + + Args: + response_content: The continuation text from Qwen3 + assistant_start: The initial assistant response text (for full thought reconstruction) + + Returns: + Tuple of (full_thought, answer, success) + """ + try: + parts = response_content.split('') + + if len(parts) > 1: + # Found - everything after is the answer + answer_part = parts[-1].strip() + + # Construct full thought: initial injection + model's continued thinking + clean_assistant_start = assistant_start.replace('', '', 1).strip() + full_thought = clean_assistant_start + "\n\n" + "\n\n".join(parts[:-1]) + else: + # No found - consider everything as thinking + full_thought = assistant_start.replace('', '', 1).strip() + "\n\n" + response_content + answer_part = "" + + clean_content = answer_part + + # Try to parse as JSON with "Final Answer" key + try: + raw_obj = json.loads(clean_content) + if isinstance(raw_obj, dict) and "Final Answer" in raw_obj: + return full_thought, raw_obj["Final Answer"], True + except json.JSONDecodeError: + pass + + # Try to extract JSON from code blocks + json_pattern = re.compile(r'```(?:json)?\s*({.*?})\s*```', re.DOTALL) + json_matches = json_pattern.findall(clean_content) + + if json_matches: + try: + obj = json.loads(json_matches[0]) + answer = obj.get("Final Answer", "") + success = "Final Answer" in obj + except Exception: + answer = clean_content + success = False + else: + # Check for inline JSON object + json_pattern_no_blocks = re.compile(r'{\s*"Final Answer"\s*:\s*".*?"}', re.DOTALL) + matches = json_pattern_no_blocks.findall(clean_content) + + if matches: + try: + obj = json.loads(matches[0]) + answer = obj.get("Final Answer", "") + success = "Final Answer" in obj + except Exception: + answer = clean_content + success = False + else: + answer = clean_content + success = True if clean_content.strip() else False + + return full_thought, answer, success + + except Exception as e: + logger.error(f"Error parsing response: {str(e)}") + return "", str(e), False + + +def process_sample(args, client, sample, idx, injection_entry): + """ + Process a single sample with the Qwen3 client using Qwen-VL injection. + + Args: + args: Command line arguments + client: GRLMClient instance + sample: Data sample with timeseries and columns + idx: Sample index + injection_entry: Qwen-VL output entry with thought and response + + Returns: + Result dictionary + """ + try: + ts = sample["timeseries"] + cols = sample.get("cols", []) + + # Prepare timeseries data + ts = prepare_timeseries_data(ts) + + # Ensure cols list is properly sized + if not cols or len(cols) != len(ts): + cols = [f"Series {i+1}" for i in range(len(ts))] + + # Build prompts with Qwen-VL injection + user_prompt, assistant_start = build_prompt_with_injection( + injection_entry["question"], + injection_entry["thought"], + injection_entry["response"], + ts, + cols + ) + + # System message + system_message = ( + "You are a time-series expert. \n" + "Answer **only** with a JSON object that has exactly one key, \"Final Answer\",\n" + "whose value is the answer string. \n" + ) + + # Query Qwen3 with injection + response_content = client.query_with_injection( + user_prompt, + assistant_start, + system_message, + max_tokens=args.max_tokens, + timeout=args.timeout, + retry_delay=args.retry_delay, + max_retries=args.max_retries + ) + + # Parse the response + full_thought, answer, ok = parse_response(response_content, assistant_start) + + return { + "idx": idx, + "question": injection_entry["question"], + "initial_thought": injection_entry["thought"], + "qwen_vl_answer": injection_entry["response"], + "thought": full_thought, + "response": answer, + "success": ok, + "ability_types": sample.get("ability_types", []), + "attributes": sample.get("attributes", {}) + } + + except Exception as e: + logger.error(f"Error processing sample {idx}: {str(e)}") + return { + "idx": idx, + "question": injection_entry.get("question", ""), + "initial_thought": injection_entry.get("thought", ""), + "qwen_vl_answer": injection_entry.get("response", ""), + "thought": "", + "response": f"ERROR: {str(e)}", + "success": False, + "ability_types": sample.get("ability_types", []), + "attributes": sample.get("attributes", {}) + } + + +def main(): + args = parse_args() + + # Initialize lock for thread-safe operations + results_lock = threading.Lock() + + # Initialize GRLM client pool (works for both Qwen3 and DeepSeek-R1) + clients = [GRLMClient(server_url=args.server_url, model_name=args.model_name) for _ in range(args.workers)] + + # Check server health with first client + if not clients[0].check_server_health(): + logger.error(f"{args.model_name} server is not healthy. Please make sure it is running.") + if args.model_name == "r1": + logger.error("Run: src/r1_utils/start_r1_server.sh to start the server.") + else: + logger.error("Run: src/qwen3_utils/start_qwen3_server.sh to start the server.") + sys.exit(1) + else: + logger.info(f"Using {args.workers} workers for parallel processing with {args.model_name}") + + # Load dataset + logger.info(f"Loading dataset from {args.dataset_path}") + with open(args.dataset_path, "r") as f: + dataset = json.load(f) + + # Load Qwen-VL injection outputs + logger.info(f"Loading Qwen-VL injection outputs from {args.injection_path}") + with open(args.injection_path, "r") as f: + injection_outputs = json.load(f) + injection_map = {h["idx"]: h for h in injection_outputs} + + # Create output directory + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Load existing results if any (resume support) + results = [] + if os.path.exists(args.output_path): + logger.info(f"Loading existing results from {args.output_path}") + with open(args.output_path, "r") as f: + results = json.load(f) + processed_indices = {r["idx"] for r in results} + logger.info(f"Resuming from {len(processed_indices)} already processed entries") + else: + processed_indices = set() + + # Track results + results_dict = {r["idx"]: r for r in results} + total_processed = len(results) + progress_queue = Queue() + + # Determine which indices still need processing + to_process = [i for i in sorted(injection_map) if i not in processed_indices] + + if not to_process: + logger.info("No samples to process. All done!") + return + + logger.info(f"Processing {len(to_process)} samples with {args.model_name} injection") + + if args.workers > 1: + logger.info(f"Starting parallel processing with {args.workers} workers") + + def process_results(): + nonlocal total_processed + + with tqdm(total=len(to_process), desc=f"{args.model_name} injection", initial=0) as pbar: + while True: + idx, result = progress_queue.get() + + if idx == -1: # Sentinel value to exit + break + + with results_lock: + if result is not None: + results_dict[idx] = result + total_processed += 1 + + if total_processed % args.checkpoint_interval == 0: + checkpoint_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(checkpoint_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Checkpoint saved with {total_processed} results") + + pbar.update(1) + progress_queue.task_done() + + # Start progress tracking thread + progress_thread = threading.Thread(target=process_results) + progress_thread.start() + + # Process samples in parallel + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = {} + client_idx = 0 + + for idx in to_process: + if idx >= len(dataset): + logger.warning(f"Skipping idx={idx} - index out of range for dataset") + continue + + future = executor.submit( + process_sample, + args, + clients[client_idx % len(clients)], + dataset[idx], + idx, + injection_map[idx] + ) + futures[future] = idx + client_idx += 1 + + for future in futures: + idx = futures[future] + try: + result = future.result() + progress_queue.put((idx, result)) + except Exception as e: + logger.error(f"Worker error on sample {idx}: {str(e)}") + progress_queue.put((idx, { + "idx": idx, + "question": injection_map[idx].get("question", ""), + "initial_thought": injection_map[idx].get("thought", ""), + "qwen_vl_answer": injection_map[idx].get("response", ""), + "thought": "", + "response": f"WORKER ERROR: {str(e)}", + "success": False, + "ability_types": dataset[idx].get("ability_types", []) if idx < len(dataset) else [], + "attributes": dataset[idx].get("attributes", {}) if idx < len(dataset) else {} + })) + + # Signal progress thread to exit + progress_queue.put((-1, None)) + progress_thread.join() + + else: + # Sequential processing + logger.info("Using sequential processing (workers=1)") + for idx in tqdm(to_process, desc=f"{args.model_name} injection"): + if idx >= len(dataset): + logger.warning(f"Skipping idx={idx} - index out of range for dataset") + continue + + result = process_sample(args, clients[0], dataset[idx], idx, injection_map[idx]) + results_dict[idx] = result + total_processed += 1 + + if total_processed % args.checkpoint_interval == 0: + checkpoint_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(checkpoint_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Checkpoint saved with {total_processed} results") + + # Final save + final_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(final_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Saved {len(final_results)} answers to {args.output_path}") + + +if __name__ == "__main__": + main() diff --git a/src/qwen_inference.py b/src/qwen_inference.py new file mode 100755 index 0000000..5e98910 --- /dev/null +++ b/src/qwen_inference.py @@ -0,0 +1,559 @@ +#!/usr/bin/env python3 +""" +Qwen2.5-VL server-based inference script for time series datasets. + +This script: +1. Loads a time series dataset +2. Generates all time series figures sequentially +3. Connects to a running Qwen2.5-VL server (start with start_qwen_vl_server.sh) +4. Processes figures with Qwen2.5-VL in parallel with thinking mode +5. Saves the results to a JSON file + +Usage: + python qwen_inference.py --dataset_path /path/to/dataset.json --output_path /path/to/output.json +""" + +import os +import sys +import json +import time +import re +import argparse +import logging +import base64 +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor +import threading +from queue import Queue + +try: + from openai import OpenAI +except ImportError: + raise ImportError("OpenAI Python client not installed. Please install with 'pip install openai'") + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Add the parent directory to the path so we can import claude_utils modules +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from claude_utils.ts_visualization import generate_image_from_timeseries + +def parse_args(): + p = argparse.ArgumentParser(description="Evaluate Qwen2.5-VL on a time-series QA dataset using a server") + p.add_argument("--server_url", default="http://localhost:5003", + help="URL of the Qwen2.5-VL server") + p.add_argument("--dataset_path", "-d", required=True, + help="Path to the JSON evaluation set") + p.add_argument("--output_path", "-o", required=True, + help="Where to write generated answers JSON") + p.add_argument("--max_tokens", type=int, default=1024, + help="Maximum tokens to generate") + p.add_argument("--max_samples", type=int, default=200, + help="Maximum number of samples to process") + p.add_argument("--seed", type=int, default=42, + help="Random seed for sampling") + p.add_argument("--checkpoint_interval", type=int, default=10, + help="Interval for saving checkpoints") + p.add_argument("--workers", type=int, default=4, + help="Number of parallel workers for processing samples") + p.add_argument("--timeout", type=int, default=120, + help="Timeout in seconds for API calls") + p.add_argument("--retry_delay", type=int, default=5, + help="Delay between retries in seconds") + p.add_argument("--max_retries", type=int, default=3, + help="Maximum number of retry attempts") + return p.parse_args() + +class QwenVLClient: + """Client for communicating with Qwen2.5-VL server using OpenAI API.""" + + def __init__(self, server_url="http://localhost:5003", debug_mode=False): + """Initialize the Qwen2.5-VL client.""" + self.server_url = server_url + self.debug_mode = debug_mode + self.client = OpenAI(base_url=f"{server_url}/v1", api_key="dummy-key") + + if debug_mode: + logger.setLevel(logging.DEBUG) + logger.info(f"QwenVLClient initialized in DEBUG mode with server URL: {server_url}") + + def check_server_health(self): + """Check if the server is healthy.""" + import requests + try: + logger.info(f"Checking health of Qwen2.5-VL server at {self.server_url}...") + response = requests.get(f"{self.server_url}/v1/models", timeout=10) + if response.status_code == 200: + logger.info(f"Qwen2.5-VL server is healthy") + return True + else: + logger.warning(f"Qwen2.5-VL server health check failed: {response.status_code}") + return False + except Exception as e: + logger.error(f"Error checking server health: {type(e).__name__}: {e}") + return False + + def query_qwen_with_image( + self, + image_b64, + question, + max_tokens=1024, + temperature=0, + timeout=120, + retry_delay=5, + max_retries=3, + ): + """ + Query Qwen2.5-VL with an image and question. + + Args: + image_b64: Base64 encoded image string + question: Question text + max_tokens: Maximum tokens to generate + temperature: Temperature for sampling + timeout: Request timeout in seconds + retry_delay: Delay between retries in seconds + max_retries: Maximum number of retry attempts + + Returns: + Model response as string + """ + # Create standard system message with thinking instructions + system_message = "You are a helpful assistant that analyzes time series data." + thinking_instruction = "First output the thinking process in tags and then output the final answer in tags" + + # Build image content + image_content = [] + if image_b64: + image_content = [ + {"type": "image_url", "image_url": { + "url": f"data:image/jpeg;base64,{image_b64}" + }} + ] + + messages = [ + {"role": "system", "content": system_message}, + {"role": "user", "content": [ + {"type": "text", "text": f"{thinking_instruction}\n\n{question}"}, + *image_content + ]} + ] + + # Make the request with retries + for attempt in range(max_retries): + try: + logger.info(f"Sending query to Qwen2.5-VL server (attempt {attempt+1}/{max_retries})") + start_time = time.time() + + response = self.client.chat.completions.create( + model="qwen_vl", + messages=messages, + max_tokens=max_tokens, + temperature=temperature + ) + + end_time = time.time() + elapsed = end_time - start_time + + # Get usage information if available + usage = getattr(response, 'usage', None) + if usage: + logger.info( + f"Query successful in {elapsed:.2f}s: " + f"prompt_tokens={usage.prompt_tokens}, " + f"completion_tokens={usage.completion_tokens}, " + f"total_tokens={usage.total_tokens}" + ) + else: + logger.info(f"Query successful, inference time: {elapsed:.2f}s") + + return response.choices[0].message.content + + except Exception as e: + logger.error(f"Query failed (attempt {attempt+1}/{max_retries}): {type(e).__name__}: {e}") + time.sleep(retry_delay) # Wait before retry + + # Increase retry delay for exponential backoff + retry_delay = min(retry_delay * 2, 60) # Cap at 60 seconds + + logger.critical(f"Failed to get response from Qwen2.5-VL server after {max_retries} attempts") + raise RuntimeError("Failed to get response from Qwen2.5-VL server after multiple attempts") + +def prepare_timeseries_data(ts): + """ + Prepare timeseries data for visualization. + + Args: + ts: Raw timeseries data + + Returns: + Processed timeseries data suitable for visualization + """ + # Ensure ts is a list of lists (for multiple series) + if not isinstance(ts, list): + ts = [ts] # Wrap single series + elif len(ts) > 0 and not isinstance(ts[0], list): + ts = [ts] # Wrap flat list into nested list + + # Check if we got empty data + if not ts or len(ts) == 0: + ts = [[0, 1, 2, 3, 4]] # Default dummy data + logger.warning("Entry has empty time series data, using dummy data") + + return ts + +def generate_all_images(data, to_process, fig_dir): + """ + Generate all images sequentially and return image paths. + + Args: + data: Dataset containing timeseries data + to_process: List of indices to process + fig_dir: Directory to save figures + + Returns: + Dictionary mapping indices to image paths + """ + image_paths = {} + + print(f"Generating {len(to_process)} figures sequentially...") + for idx in tqdm(to_process, desc="Generating figures"): + sample = data[idx] + ts = sample["timeseries"] + cols = sample.get("cols", []) + + # Prepare timeseries data + ts = prepare_timeseries_data(ts) + + # Generate and save image + path = os.path.join(fig_dir, f"{idx}.jpg") + try: + # Always save the image + _ = generate_image_from_timeseries( + case_idx=idx, + timeseries=ts, + cols=cols, + fig_dir=fig_dir, + save_image=True + ) + + # Check if the image was created successfully + if os.path.exists(path): + file_size = os.path.getsize(path) + if file_size > 0: + image_paths[idx] = path + else: + logger.warning(f"Empty image file generated for idx={idx}") + else: + logger.warning(f"Image file not created for idx={idx}") + except Exception as e: + logger.error(f"Error generating image for idx={idx}: {e}") + + print(f"Successfully generated {len(image_paths)} figures out of {len(to_process)} requested") + return image_paths + +def get_image_base64(image_path): + """ + Load an image from disk and convert to base64. + + Args: + image_path: Path to the image file + + Returns: + Base64 encoded string of the image + """ + try: + with open(image_path, "rb") as f: + img_b64 = base64.b64encode(f.read()).decode("utf-8") + return img_b64 + except Exception as e: + logger.error(f"Error reading image file {image_path}: {e}") + return None + +def parse_response(response): + """ + Parse response to extract thinking and answer parts. + + Args: + response: The raw response string from the model + + Returns: + Tuple of (thought, answer, success_flag) + """ + # Default values + thought = "" + answer = response + success = True + + # Extract thinking section + think_pattern = r"(.*?)" + think_match = re.search(think_pattern, response, re.DOTALL) + if think_match: + thought = think_match.group(1).strip() + + # Extract answer section + answer_pattern = r"(.*?)" + answer_match = re.search(answer_pattern, response, re.DOTALL) + if answer_match: + answer = answer_match.group(1).strip() + + # If we didn't find both sections, consider this a partial success + # but still return the full text as the answer + if not (think_match and answer_match): + logger.warning("Could not parse response into thinking and answer parts") + success = False + if not answer: + answer = response # Fallback to the full response + + return thought, answer, success + +def process_sample(args, client, sample, idx, image_path): + """ + Process a single sample with the Qwen2.5-VL client. + + Args: + args: Command line arguments + client: QwenVLClient instance + sample: Data sample + idx: Sample index + image_path: Path to the pre-generated image + + Returns: + Result dictionary + """ + try: + # Extract question + question = sample.get("question", "") + + # Load image as base64 + img_b64 = get_image_base64(image_path) if image_path else None + + # Query the Qwen2.5-VL server + raw_answer = client.query_qwen_with_image( + image_b64=img_b64, + question=question, + max_tokens=args.max_tokens, + timeout=args.timeout, + retry_delay=args.retry_delay, + max_retries=args.max_retries + ) + + # Parse the response to extract thinking and answer parts + thought, answer, parse_ok = parse_response(raw_answer) + + # Return successful result + return { + "idx": idx, + "question": question, + "thought": thought, + "response": answer, + "raw_response": raw_answer, + "success": True + } + + except Exception as e: + logger.error(f"Error processing sample {idx}: {str(e)}") + question = "" + try: + question = sample.get("question", "") + except Exception: + pass + + return { + "idx": idx, + "question": question, + "thought": "", + "response": f"ERROR: {str(e)}", + "raw_response": f"ERROR: {str(e)}", + "success": False + } + +def main(): + args = parse_args() + + # Initialize lock for thread-safe operations + results_lock = threading.Lock() + + # Initialize Qwen2.5-VL client pool + clients = [QwenVLClient(server_url=args.server_url) for _ in range(args.workers)] + + # Check server health with first client + if not clients[0].check_server_health(): + logger.error("Qwen2.5-VL server is not healthy. Please make sure it is running.") + logger.error("Run: src/qwen_utils/start_qwen_vl_server.sh to start the server.") + sys.exit(1) + else: + logger.info(f"Using {args.workers} workers for parallel processing") + + # Load evaluation set + logger.info(f"Loading dataset from {args.dataset_path}") + with open(args.dataset_path, "r") as f: + full_dataset = json.load(f) + + # Sample if needed + total_entries = len(full_dataset) + if total_entries > args.max_samples: + logger.info(f"Sampling {args.max_samples} entries from {total_entries} total") + import random + random.seed(args.seed) + indices = random.sample(range(total_entries), args.max_samples) + dataset = [full_dataset[i] for i in indices] + + # Save metadata about the sampling + metadata_file = args.output_path.replace(".json", "_sampling_metadata.json") + with open(metadata_file, "w") as f: + json.dump({ + "original_size": total_entries, + "sampled_size": args.max_samples, + "seed": args.seed, + "sampled_indices": indices + }, f, indent=2) + else: + dataset = full_dataset + + # Create output directory + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Set up experiment-specific figure directory + fig_dir = os.path.join(os.path.dirname(args.output_path), "figures") + os.makedirs(fig_dir, exist_ok=True) + logger.info(f"Using figure directory: {fig_dir}") + + # Load existing results if any (resume support) + results = [] + if os.path.exists(args.output_path): + logger.info(f"Loading existing results from {args.output_path}") + with open(args.output_path, "r") as f: + results = json.load(f) + processed_indices = {r["idx"] for r in results} + logger.info(f"Resuming from {len(processed_indices)} already processed entries") + else: + processed_indices = set() + + # Track results + results_dict = {r["idx"]: r for r in results} + total_processed = len(results) + progress_queue = Queue() + + # Determine which indices still need processing + to_process = [i for i in range(len(dataset)) if i not in processed_indices] + + # STEP 1: Generate all images sequentially + image_paths = generate_all_images(dataset, to_process, fig_dir) + image_count = len(image_paths) + logger.info(f"Generated {image_count} images out of {len(to_process)} total samples") + + # STEP 2: Process samples with Qwen2.5-VL + if args.workers > 1: + logger.info(f"Starting parallel processing with {args.workers} workers") + + # Function to process results and update progress + def process_results(): + nonlocal total_processed + + with tqdm(total=len(dataset), desc="Evaluating Qwen2.5-VL", initial=len(results)) as pbar: + while True: + idx, result = progress_queue.get() + + if idx == -1: # Sentinel value to exit + break + + with results_lock: + results_dict[idx] = result + total_processed += 1 + + # Checkpoint at specified intervals + if total_processed % args.checkpoint_interval == 0: + checkpoint_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(checkpoint_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Checkpoint saved with {total_processed} results") + + pbar.update(1) + progress_queue.task_done() + + # Start progress tracking thread + progress_thread = threading.Thread(target=process_results) + progress_thread.start() + + # Process samples in parallel + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = {} + client_idx = 0 + + for idx in to_process: + # Get the image path for this index + image_path = image_paths.get(idx, None) + + # Skip if we don't have an image + if image_path is None: + logger.warning(f"Skipping idx={idx} - no image available") + continue + + future = executor.submit( + process_sample, + args, + clients[client_idx % len(clients)], + dataset[idx], + idx, + image_path + ) + futures[future] = idx + client_idx += 1 + + for future in futures: + idx = futures[future] + try: + result = future.result() + progress_queue.put((idx, result)) + except Exception as e: + logger.error(f"Worker error on sample {idx}: {str(e)}") + progress_queue.put((idx, { + "idx": idx, + "question": dataset[idx].get("question", ""), + "thought": "", + "response": f"WORKER ERROR: {str(e)}", + "raw_response": f"WORKER ERROR: {str(e)}", + "success": False + })) + + # Signal progress thread to exit and wait for it to finish + progress_queue.put((-1, None)) + progress_thread.join() + + else: + # Sequential processing + logger.info("Using sequential processing (workers=1)") + for idx in tqdm(to_process, desc="Evaluating Qwen2.5-VL"): + # Skip if already processed + if idx in processed_indices: + continue + + # Get the image path for this index + image_path = image_paths.get(idx, None) + + # Skip if we don't have an image + if image_path is None: + logger.warning(f"Skipping idx={idx} - no image available") + continue + + result = process_sample(args, clients[0], dataset[idx], idx, image_path) + results_dict[idx] = result + + # Checkpoint at specified intervals + if len(results_dict) % args.checkpoint_interval == 0: + checkpoint_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(checkpoint_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Checkpoint saved with {len(results_dict)} results") + + # Final save with sorted results + final_results = [results_dict[k] for k in sorted(results_dict.keys())] + with open(args.output_path, "w") as outf: + json.dump(final_results, outf, indent=2, ensure_ascii=False) + logger.info(f"Saved {len(final_results)} answers to {args.output_path}") + +if __name__ == "__main__": + main() diff --git a/src/qwen_utils/qwen_vl_server.py b/src/qwen_utils/qwen_vl_server.py new file mode 100755 index 0000000..9d2bfd2 --- /dev/null +++ b/src/qwen_utils/qwen_vl_server.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +""" +Qwen2.5-VL Server + +This script runs a vLLM server with OpenAI-compatible API for Qwen2.5-VL inference. +It leverages vLLM for serving Qwen2.5-VL models with multimodal capabilities. +""" + +import os +import sys +import time +import signal +import argparse +import json +import subprocess +from pathlib import Path + +# Parse command line arguments +parser = argparse.ArgumentParser(description="Qwen2.5-VL Server") +parser.add_argument("--model_path", type=str, required=True, help="Path to Qwen2.5-VL model") +parser.add_argument("--port", type=int, default=5003, help="Port to run server on") +parser.add_argument("--device", type=str, default="0,1,2,3", help="GPU device IDs") +parser.add_argument("--data_parallel_size", type=int, default=2, help="Data parallel size") +parser.add_argument("--tensor_parallel_size", type=int, default=2, help="Tensor parallel size") +parser.add_argument("--context_length", type=int, default=32768, help="Max context length") +parser.add_argument("--pid_file", type=str, default="/tmp/qwen_vl_server.pid", help="File to store server PID") +parser.add_argument("--log_file", type=str, default=None, help="File to log server output") +parser.add_argument("--initial_wait", type=int, default=120, help="Initial wait time in seconds") + +args = parser.parse_args() + +# Print all args for debugging +print("Arguments received:") +for arg in vars(args): + print(f" {arg}: {getattr(args, arg)}") + +# Set up GPU +os.environ["CUDA_VISIBLE_DEVICES"] = args.device +print(f"Using CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}") + +# Using V1 without multiprocessing +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +os.environ["VLLM_USE_V1"] = "1" +print("Enabled vLLM v1 engine with environment variables:") +print(f" VLLM_ENABLE_V1_MULTIPROCESSING={os.environ.get('VLLM_ENABLE_V1_MULTIPROCESSING')}") +print(f" VLLM_USE_V1={os.environ.get('VLLM_USE_V1')}") + +# Check if vLLM is available +try: + import vllm + print(f"vLLM package found. Using vLLM for Qwen2.5-VL server.") + # Try to run a simple vLLM command to test if it's properly installed + subprocess.run(["vllm", "--version"], capture_output=True, check=False) + print("vLLM CLI tool is available.") +except ImportError: + print("Error: vLLM is not installed. Please install vllm.") + sys.exit(1) +except subprocess.CalledProcessError: + print("Warning: vLLM CLI tool not found or not working properly. Continuing anyway...") +except FileNotFoundError: + print("Warning: vLLM CLI tool not found in PATH. Continuing anyway...") + +# Create log file directory if needed +if args.log_file: + log_dir = os.path.dirname(args.log_file) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_file = open(args.log_file, 'w') +else: + log_file = None + +# Write PID to file for cleanup +with open(args.pid_file, "w") as f: + f.write(str(os.getpid())) +print(f"Server PID {os.getpid()} written to {args.pid_file}") + +# Graceful shutdown handler +def signal_handler(sig, frame): + print(f"Received signal {sig}, shutting down...") + if server_process and server_process.poll() is None: + server_process.terminate() + server_process.wait(timeout=10) + + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + + if log_file: + log_file.close() + + sys.exit(0) + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +def start_vllm_server(): + """Start the vLLM server with OpenAI-compatible API""" + + # Ensure environment variables are passed to the subprocess + env = os.environ.copy() + + cmd = [ + "vllm", "serve", args.model_path, + "--served-model-name", "qwen_vl", + "--trust-remote-code", + "--max-model-len", str(args.context_length), + "--gpu-memory-utilization", "0.95", + "--host", "0.0.0.0", + "--port", str(args.port), + "--uvicorn-log-level", "debug", + "--data-parallel-size", str(args.data_parallel_size), + "--tensor-parallel-size", str(args.tensor_parallel_size), + ] + + print(f"Starting vLLM server with command: {' '.join(cmd)}") + print(f"Data Parallel Size: {args.data_parallel_size}, Tensor Parallel Size: {args.tensor_parallel_size}") + print(f"GPU Configuration: {args.device}") + + # Start server process + process = subprocess.Popen( + cmd, + env=env, + stdout=log_file, + stderr=log_file if log_file else subprocess.STDOUT + ) + + return process + +def check_server_health(max_retries=60, retry_interval=5): + """Check if the server is healthy by polling the health endpoint""" + import requests + from requests.exceptions import ConnectionError + + # First, wait for the initial loading period + initial_wait = args.initial_wait # Default is 120 seconds + print(f"Waiting {initial_wait} seconds for initial model loading...") + time.sleep(initial_wait) + + print(f"Checking if server is ready at http://localhost:{args.port}/v1/models...") + + for i in range(max_retries): + try: + response = requests.get(f"http://localhost:{args.port}/v1/models", timeout=10) + if response.status_code == 200: + print("Server is ready!") + return True + except ConnectionError: + pass + except requests.exceptions.Timeout: + print("Request timed out. Server might be busy loading the model.") + + print(f"Server not ready yet, retrying in {retry_interval} seconds... ({i+1}/{max_retries})") + time.sleep(retry_interval) + + print("Server failed to start within the expected time") + return False + +if __name__ == "__main__": + # Start the vLLM server + server_process = start_vllm_server() + + # Check server health + if not check_server_health(): + print("Failed to start server, exiting") + if server_process and server_process.poll() is None: + server_process.terminate() + + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + + if log_file: + log_file.close() + + sys.exit(1) + + # Keep the script running until the server exits + try: + while server_process.poll() is None: + time.sleep(1) + except KeyboardInterrupt: + signal_handler(signal.SIGINT, None) + + # Server process exited + exit_code = server_process.returncode + print(f"Server process exited with code {exit_code}") + + # Clean up + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + + if log_file: + log_file.close() + + sys.exit(exit_code) diff --git a/src/qwen_utils/start_qwen_vl_server.sh b/src/qwen_utils/start_qwen_vl_server.sh new file mode 100755 index 0000000..5908109 --- /dev/null +++ b/src/qwen_utils/start_qwen_vl_server.sh @@ -0,0 +1,167 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# start_qwen_vl_server.sh +# +# Script to start a Qwen2.5-VL server for multimodal inference. +# This script: +# 1. Initializes the environment +# 2. Starts the Qwen2.5-VL server using vLLM +# 3. Checks that the server is running and operational +# ============================================================================== + +# ── Script path handling ─────────────────────────────────────────────────────── +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # e.g., …/src/qwen_utils +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" # project root + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" +echo "" + +# ── Configuration ──────────────────────────────────────────────────────────── +# Qwen2.5-VL-3B-Instruct model path (TSLM) +QWEN_VL_MODEL_PATH="" # Path to Qwen2.5-VL-3B-Instruct checkpoint (or SFT/RL fine-tuned variant) + +# Server configuration +QWEN_VL_PORT=5003 +QWEN_VL_PID_FILE="/tmp/qwen_vl_server_${QWEN_VL_PORT}.pid" +export QWEN_VL_SERVER_PORT="${QWEN_VL_PORT}" + +# Device configuration +QWEN_VL_DEVICE="4,5,6,7" +QWEN_VL_DATA_PARALLEL_SIZE=4 +QWEN_VL_TENSOR_PARALLEL_SIZE=1 + +# Create log directory +LOG_DIR="${PROJECT_ROOT}/logs" +mkdir -p "$LOG_DIR" + +# Qwen2.5-VL log files +QWEN_VL_LOG="${LOG_DIR}/qwen_vl_server.$(date +%Y-%m-%d-%H-%M-%S).log" +QWEN_VL_CONSOLE_LOG="${LOG_DIR}/qwen_vl_console.$(date +%Y-%m-%d-%H-%M-%S).log" + +# ── Initialize Conda in this shell ──────────── +export MKL_INTERFACE_LAYER=${MKL_INTERFACE_LAYER:-LP64} +if [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +elif [ -f "$(conda info --base)/etc/profile.d/conda.sh" ]; then + source "$(conda info --base)/etc/profile.d/conda.sh" +else + echo "ERROR: Cannot find conda.sh. Do you need to run 'conda init'?" + exit 1 +fi +# ─────────────────────────────────────────────────────────────────────────────── + +# ===== Start Qwen2.5-VL Server ===== +echo "Starting Qwen2.5-VL server with qwen3-vllm environment..." + +# Activate environment for Qwen2.5-VL +eval "$(conda shell.bash hook)" +conda activate qwen3-vllm + +# Check if a server is already running on the port +if nc -z localhost $QWEN_VL_PORT 2>/dev/null; then + echo "Warning: Port $QWEN_VL_PORT is already in use!" + echo "Another server might be running. Check with: lsof -i :$QWEN_VL_PORT" + + # Ask if we should continue or abort + read -p "Do you want to continue anyway? [y/N] " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Aborting server startup." + exit 1 + fi +fi + +# Clear existing PID file if it exists +if [ -f "$QWEN_VL_PID_FILE" ]; then + echo "Removing existing PID file: $QWEN_VL_PID_FILE" + rm -f "$QWEN_VL_PID_FILE" +fi + +# Make server script executable +QWEN_VL_SERVER_SCRIPT="${SCRIPT_DIR}/qwen_vl_server.py" +chmod +x "$QWEN_VL_SERVER_SCRIPT" + +# Start Qwen2.5-VL server +echo "Starting Qwen2.5-VL server with log at ${QWEN_VL_LOG}" +"$QWEN_VL_SERVER_SCRIPT" \ + --model_path "${QWEN_VL_MODEL_PATH}" \ + --port "${QWEN_VL_PORT}" \ + --device "${QWEN_VL_DEVICE}" \ + --data_parallel_size "${QWEN_VL_DATA_PARALLEL_SIZE}" \ + --tensor_parallel_size "${QWEN_VL_TENSOR_PARALLEL_SIZE}" \ + --pid_file "${QWEN_VL_PID_FILE}" \ + --log_file "${QWEN_VL_LOG}" \ + --initial_wait 180 \ + > "${QWEN_VL_CONSOLE_LOG}" 2>&1 & + +QWEN_VL_SERVER_PID=$! +echo "Started Qwen2.5-VL server process with PID $QWEN_VL_SERVER_PID" + +# Wait briefly to make sure the process starts +sleep 10 + +# Check if the PID file was created +if [ -f "$QWEN_VL_PID_FILE" ]; then + FILE_PID=$(cat $QWEN_VL_PID_FILE) + echo "Qwen2.5-VL server PID file created with PID ${FILE_PID}" +else + echo "Qwen2.5-VL server PID file not created yet, writing our tracked PID" + echo $QWEN_VL_SERVER_PID > "$QWEN_VL_PID_FILE" +fi + +# Check if the server process is still running +if kill -0 $QWEN_VL_SERVER_PID 2>/dev/null; then + echo "Qwen2.5-VL server process is running" +else + echo "Error: Qwen2.5-VL server process exited unexpectedly" + echo "Check the logs:" + echo "Console log: $QWEN_VL_CONSOLE_LOG" + echo "Server log: $QWEN_VL_LOG" + exit 1 +fi + +# Wait for server initialization - a fixed time instead of relying on the server script's check +echo "Waiting for Qwen2.5-VL server to initialize (240 seconds)..." +echo "You can monitor the logs with:" +echo "tail -f ${QWEN_VL_CONSOLE_LOG}" +echo "tail -f ${QWEN_VL_LOG}" +sleep 240 # 4 minute initial wait + +# ===== Test Server Connectivity ===== +echo "Testing Qwen2.5-VL server connectivity..." +python -c " +from openai import OpenAI +client = OpenAI(base_url='http://localhost:${QWEN_VL_PORT}/v1', api_key='dummy-key') +try: + response = client.models.list() + print(f'Qwen2.5-VL models available: {response}') + print('Qwen2.5-VL server is operational!') + exit(0) +except Exception as e: + print(f'Error testing Qwen2.5-VL server: {e}') + exit(1) +" +QWEN_VL_TEST_EXIT_CODE=$? + +if [ $QWEN_VL_TEST_EXIT_CODE -ne 0 ]; then + echo "Error: Qwen2.5-VL server test failed." + echo "Check the logs:" + echo "Console log: $QWEN_VL_CONSOLE_LOG" + echo "Server log: $QWEN_VL_LOG" + exit 1 +else + echo "Qwen2.5-VL server test passed successfully!" + echo "The server is running on port $QWEN_VL_PORT" + echo "To stop the server later, run: $SCRIPT_DIR/stop_qwen_vl_server.sh" +fi + +echo "" +echo "========================================" +echo "Qwen2.5-VL Server is ready for inference!" +echo "Server URL: http://localhost:$QWEN_VL_PORT" +echo "========================================" diff --git a/src/qwen_utils/stop_qwen_vl_server.sh b/src/qwen_utils/stop_qwen_vl_server.sh new file mode 100755 index 0000000..402132a --- /dev/null +++ b/src/qwen_utils/stop_qwen_vl_server.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# stop_qwen_vl_server.sh +# +# Script to stop a running Qwen2.5-VL server. +# This script: +# 1. Finds the PID file for the Qwen2.5-VL server +# 2. Sends a SIGTERM signal to gracefully shut down the server +# ============================================================================== + +# Qwen2.5-VL server PID file location +QWEN_VL_PORT=5003 +QWEN_VL_PID_FILE="/tmp/qwen_vl_server_${QWEN_VL_PORT}.pid" + +echo "Stopping Qwen2.5-VL server..." + +# Check if PID file exists +if [ ! -f "$QWEN_VL_PID_FILE" ]; then + echo "No PID file found at $QWEN_VL_PID_FILE" + + # Check if there's a process listening on the Qwen2.5-VL port + if nc -z localhost $QWEN_VL_PORT 2>/dev/null; then + echo "Warning: Port $QWEN_VL_PORT is in use but no PID file exists." + echo "Finding processes using port $QWEN_VL_PORT:" + + # Find and display processes using the port + if command -v lsof &> /dev/null; then + PROCS=$(lsof -i :$QWEN_VL_PORT -t) + if [ -n "$PROCS" ]; then + echo "Found processes: $PROCS" + + # Ask if we should kill these processes + read -p "Do you want to kill these processes? [y/N] " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + echo "Killing processes using port $QWEN_VL_PORT" + for pid in $PROCS; do + echo "Killing process $pid" + kill -9 $pid + done + else + echo "No processes were killed. Please manually stop the server." + fi + else + echo "No processes found using lsof." + fi + else + echo "lsof command not available, cannot find processes by port." + fi + else + echo "No process is listening on port $QWEN_VL_PORT" + fi + + exit 0 +fi + +# Read PID from file +PID=$(cat $QWEN_VL_PID_FILE) +echo "Found Qwen2.5-VL server with PID: $PID" + +# Check if the process exists +if kill -0 $PID 2>/dev/null; then + echo "Sending SIGTERM to PID $PID" + kill -15 $PID + + # Wait for the process to terminate + echo "Waiting for server to shut down..." + for i in {1..30}; do + if ! kill -0 $PID 2>/dev/null; then + echo "Server shut down successfully." + break + fi + sleep 1 + done + + # If process still exists, force kill + if kill -0 $PID 2>/dev/null; then + echo "Server did not shut down gracefully, sending SIGKILL..." + kill -9 $PID + sleep 2 + fi +else + echo "Process with PID $PID does not exist or is not accessible." +fi + +# Remove PID file +if [ -f "$QWEN_VL_PID_FILE" ]; then + echo "Removing PID file: $QWEN_VL_PID_FILE" + rm -f "$QWEN_VL_PID_FILE" +fi + +# Final check +if nc -z localhost $QWEN_VL_PORT 2>/dev/null; then + echo "Warning: Port $QWEN_VL_PORT is still in use after stopping the server." + echo "You might need to manually kill the remaining processes." +else + echo "Port $QWEN_VL_PORT is now free." +fi + +echo "Qwen2.5-VL server stop script completed." diff --git a/src/r1_utils/r1_server.py b/src/r1_utils/r1_server.py new file mode 100755 index 0000000..9db6f5f --- /dev/null +++ b/src/r1_utils/r1_server.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +""" +DeepSeek-R1 Server + +This script runs a vLLM server with OpenAI-compatible API for DeepSeek-R1 inference. +DeepSeek-R1-Distill-Qwen-32B is a text-only reasoning model served via vLLM, +used as a GRLM for knowledge injection (same API pattern as Qwen3). +""" + +import os +import sys +import time +import signal +import argparse +import subprocess +from pathlib import Path + +parser = argparse.ArgumentParser(description="DeepSeek-R1 Server") +parser.add_argument("--model_path", type=str, required=True, help="Path to DeepSeek-R1 model") +parser.add_argument("--port", type=int, default=5002, help="Port to run server on") +parser.add_argument("--device", type=str, default="0,1,2,3", help="GPU device IDs") +parser.add_argument("--data_parallel_size", type=int, default=1, help="Data parallel size") +parser.add_argument("--tensor_parallel_size", type=int, default=4, help="Tensor parallel size") +parser.add_argument("--context_length", type=int, default=56320, help="Max context length") +parser.add_argument("--pid_file", type=str, default="/tmp/r1_server.pid", help="File to store server PID") +parser.add_argument("--log_file", type=str, default=None, help="File to log server output") +parser.add_argument("--initial_wait", type=int, default=120, help="Initial wait time in seconds") +parser.add_argument("--chat_template", type=str, default=None, help="Path to custom chat template file") + +args = parser.parse_args() + +print("Arguments received:") +for arg in vars(args): + print(f" {arg}: {getattr(args, arg)}") + +os.environ["CUDA_VISIBLE_DEVICES"] = args.device +print(f"Using CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}") + +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +os.environ["VLLM_USE_V1"] = "1" + +try: + import vllm + print("vLLM package found.") + subprocess.run(["vllm", "--version"], capture_output=True, check=False) +except ImportError: + print("Error: vLLM is not installed.") + sys.exit(1) +except (subprocess.CalledProcessError, FileNotFoundError): + print("Warning: vLLM CLI tool issue. Continuing anyway...") + +if args.log_file: + log_dir = os.path.dirname(args.log_file) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_file = open(args.log_file, 'w') +else: + log_file = None + +with open(args.pid_file, "w") as f: + f.write(str(os.getpid())) +print(f"Server PID {os.getpid()} written to {args.pid_file}") + +def signal_handler(sig, frame): + print(f"Received signal {sig}, shutting down...") + if server_process and server_process.poll() is None: + server_process.terminate() + server_process.wait(timeout=10) + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + if log_file: + log_file.close() + sys.exit(0) + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +def start_vllm_server(): + env = os.environ.copy() + cmd = [ + "vllm", "serve", args.model_path, + "--served-model-name", "r1", + "--trust-remote-code", + "--max-model-len", str(args.context_length), + "--gpu-memory-utilization", "0.95", + # Uncomment for RoPE scaling when context exceeds 32k (e.g., TSEvol benchmark): + # "--rope-scaling", '{"rope_type":"yarn","factor":2.0,"original_max_position_embeddings":32768}', + "--host", "0.0.0.0", + "--port", str(args.port), + "--uvicorn-log-level", "debug", + "--data-parallel-size", str(args.data_parallel_size), + "--tensor-parallel-size", str(args.tensor_parallel_size), + ] + if args.chat_template: + chat_template_path = os.path.abspath(args.chat_template) + if os.path.exists(chat_template_path): + cmd.extend(["--chat-template", chat_template_path]) + print(f"Using custom chat template from: {chat_template_path}") + else: + print(f"Warning: Chat template '{chat_template_path}' not found.") + + print(f"Starting vLLM server with command: {' '.join(cmd)}") + return subprocess.Popen(cmd, env=env, stdout=log_file, stderr=log_file if log_file else subprocess.STDOUT) + +def check_server_health(max_retries=60, retry_interval=5): + import requests + from requests.exceptions import ConnectionError + print(f"Waiting {args.initial_wait} seconds for initial model loading...") + time.sleep(args.initial_wait) + print(f"Checking if server is ready at http://localhost:{args.port}/v1/models...") + for i in range(max_retries): + try: + response = requests.get(f"http://localhost:{args.port}/v1/models", timeout=10) + if response.status_code == 200: + print("Server is ready!") + return True + except (ConnectionError, requests.exceptions.Timeout): + pass + print(f"Server not ready yet, retrying... ({i+1}/{max_retries})") + time.sleep(retry_interval) + print("Server failed to start within the expected time") + return False + +if __name__ == "__main__": + server_process = start_vllm_server() + if not check_server_health(): + print("Failed to start server, exiting") + if server_process and server_process.poll() is None: + server_process.terminate() + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + if log_file: + log_file.close() + sys.exit(1) + + try: + while server_process.poll() is None: + time.sleep(1) + except KeyboardInterrupt: + signal_handler(signal.SIGINT, None) + + exit_code = server_process.returncode + print(f"Server process exited with code {exit_code}") + if os.path.exists(args.pid_file): + os.remove(args.pid_file) + if log_file: + log_file.close() + sys.exit(exit_code) diff --git a/src/r1_utils/simple_chat_template.jinja b/src/r1_utils/simple_chat_template.jinja new file mode 100644 index 0000000..e717960 --- /dev/null +++ b/src/r1_utils/simple_chat_template.jinja @@ -0,0 +1,22 @@ +{%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} +{%- endif %} + +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + + {%- if message.role == "user" %} + {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} + {%- elif message.role == "assistant" %} + {# Just pass through the assistant message without modifying it #} + {{- '<|im_start|>assistant\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/src/r1_utils/start_r1_server.sh b/src/r1_utils/start_r1_server.sh new file mode 100755 index 0000000..d04e209 --- /dev/null +++ b/src/r1_utils/start_r1_server.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# start_r1_server.sh +# +# Script to start a DeepSeek-R1-Distill-Qwen-32B server for text-only reasoning. +# Uses the same chat template and continue_final_message injection as Qwen3. +# ============================================================================== + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" + +echo "========================================" +echo "SCRIPT_DIR = $SCRIPT_DIR" +echo "PROJECT_ROOT = $PROJECT_ROOT" +echo "========================================" + +# ── Configuration ──────────────────────────────────────────────────────────── +# DeepSeek-R1-Distill-Qwen-32B model path (GRLM) +R1_MODEL_PATH="" # Path to DeepSeek-R1-Distill-Qwen-32B checkpoint + +R1_PORT=5002 +R1_PID_FILE="/tmp/r1_server_${R1_PORT}.pid" +export R1_SERVER_PORT="${R1_PORT}" + +R1_DEVICE="0,1,2,3" +R1_DATA_PARALLEL_SIZE=1 +R1_TENSOR_PARALLEL_SIZE=4 + +R1_CHAT_TEMPLATE="${SCRIPT_DIR}/simple_chat_template.jinja" +echo "Using chat template: ${R1_CHAT_TEMPLATE}" +[ -f "${R1_CHAT_TEMPLATE}" ] && echo "Chat template file exists" || echo "WARNING: Chat template file does not exist!" + +LOG_DIR="${PROJECT_ROOT}/logs" +mkdir -p "$LOG_DIR" +R1_LOG="${LOG_DIR}/r1_server.$(date +%Y-%m-%d-%H-%M-%S).log" +R1_CONSOLE_LOG="${LOG_DIR}/r1_console.$(date +%Y-%m-%d-%H-%M-%S).log" + +# ── Initialize Conda ──────────── +export MKL_INTERFACE_LAYER=${MKL_INTERFACE_LAYER:-LP64} +if [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +elif [ -f "$(conda info --base)/etc/profile.d/conda.sh" ]; then + source "$(conda info --base)/etc/profile.d/conda.sh" +else + echo "ERROR: Cannot find conda.sh." + exit 1 +fi + +echo "Starting DeepSeek-R1 server with qwen3-vllm environment..." +eval "$(conda shell.bash hook)" +conda activate qwen3-vllm + +if nc -z localhost $R1_PORT 2>/dev/null; then + echo "Warning: Port $R1_PORT is already in use!" + read -p "Continue anyway? [y/N] " -n 1 -r; echo + [[ ! $REPLY =~ ^[Yy]$ ]] && exit 1 +fi + +[ -f "$R1_PID_FILE" ] && rm -f "$R1_PID_FILE" + +R1_SERVER_SCRIPT="${SCRIPT_DIR}/r1_server.py" +chmod +x "$R1_SERVER_SCRIPT" + +echo "Starting DeepSeek-R1 server with log at ${R1_LOG}" +"$R1_SERVER_SCRIPT" \ + --model_path "${R1_MODEL_PATH}" \ + --port "${R1_PORT}" \ + --device "${R1_DEVICE}" \ + --data_parallel_size "${R1_DATA_PARALLEL_SIZE}" \ + --tensor_parallel_size "${R1_TENSOR_PARALLEL_SIZE}" \ + --pid_file "${R1_PID_FILE}" \ + --log_file "${R1_LOG}" \ + --chat_template "${R1_CHAT_TEMPLATE}" \ + --context_length 56320 \ + --initial_wait 180 \ + > "${R1_CONSOLE_LOG}" 2>&1 & + +R1_SERVER_PID=$! +echo "Started DeepSeek-R1 server process with PID $R1_SERVER_PID" + +sleep 10 + +if [ -f "$R1_PID_FILE" ]; then + echo "PID file created with PID $(cat $R1_PID_FILE)" +else + echo $R1_SERVER_PID > "$R1_PID_FILE" +fi + +kill -0 $R1_SERVER_PID 2>/dev/null || { echo "Error: Server exited unexpectedly. Check $R1_CONSOLE_LOG"; exit 1; } + +echo "Waiting for DeepSeek-R1 server to initialize (240 seconds)..." +echo "Monitor: tail -f ${R1_CONSOLE_LOG}" +sleep 240 + +echo "Testing DeepSeek-R1 server connectivity..." +python -c " +from openai import OpenAI +client = OpenAI(base_url='http://localhost:${R1_PORT}/v1', api_key='dummy-key') +try: + response = client.models.list() + print(f'DeepSeek-R1 models available: {response}') + print('DeepSeek-R1 server is operational!') + exit(0) +except Exception as e: + print(f'Error testing DeepSeek-R1 server: {e}') + exit(1) +" +[ $? -ne 0 ] && { echo "Error: Server test failed. Check $R1_CONSOLE_LOG"; exit 1; } + +echo "" +echo "========================================" +echo "DeepSeek-R1 Server is ready for inference!" +echo "Server URL: http://localhost:$R1_PORT" +echo "To stop: $SCRIPT_DIR/stop_r1_server.sh" +echo "========================================" diff --git a/src/r1_utils/stop_r1_server.sh b/src/r1_utils/stop_r1_server.sh new file mode 100755 index 0000000..85a90a3 --- /dev/null +++ b/src/r1_utils/stop_r1_server.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================================== +# stop_r1_server.sh — Stop a running DeepSeek-R1 server. +# ============================================================================== + +R1_PORT=5002 +R1_PID_FILE="/tmp/r1_server_${R1_PORT}.pid" + +echo "Stopping DeepSeek-R1 server..." + +if [ ! -f "$R1_PID_FILE" ]; then + echo "No PID file found at $R1_PID_FILE" + if nc -z localhost $R1_PORT 2>/dev/null; then + echo "Warning: Port $R1_PORT is in use but no PID file exists." + if command -v lsof &> /dev/null; then + PROCS=$(lsof -i :$R1_PORT -t) + if [ -n "$PROCS" ]; then + echo "Found processes: $PROCS" + read -p "Kill these processes? [y/N] " -n 1 -r; echo + [[ $REPLY =~ ^[Yy]$ ]] && for pid in $PROCS; do kill -9 $pid; done + fi + fi + else + echo "No process is listening on port $R1_PORT" + fi + exit 0 +fi + +PID=$(cat $R1_PID_FILE) +echo "Found DeepSeek-R1 server with PID: $PID" + +if kill -0 $PID 2>/dev/null; then + echo "Sending SIGTERM to PID $PID" + kill -15 $PID + for i in {1..30}; do + kill -0 $PID 2>/dev/null || { echo "Server shut down successfully."; break; } + sleep 1 + done + kill -0 $PID 2>/dev/null && { echo "Force killing..."; kill -9 $PID; sleep 2; } +else + echo "Process with PID $PID does not exist." +fi + +[ -f "$R1_PID_FILE" ] && rm -f "$R1_PID_FILE" && echo "Removed PID file." + +nc -z localhost $R1_PORT 2>/dev/null && echo "Warning: Port $R1_PORT still in use." || echo "Port $R1_PORT is now free." +echo "DeepSeek-R1 server stop script completed."