walidsobhie-code committed on
Commit ·
183b3b6
Parent(s): a8f2981
feat: add evaluation scripts, tool calling data generator, and 7B training configs
Created:
- test_model.py: 602 lines - basic code generation tests
- evaluate_model.py: 847 lines - HumanEval/MBPP benchmark evaluation with pass@k
- scripts/generate_tool_data.py: 615 lines - synthetic tool_calling data generator
- training-data/tool_examples.jsonl: 1000 synthetic tool_calling examples
- training-configs/7b-lora-config.yaml: A100 80GB full LoRA config
- training-configs/kaggle-7b-qlora.sh: QLoRA script for 7B on Kaggle T4
- docs/TRAINING_7B.md: 7B training documentation with VRAM estimates
Updated:
- merge_simple.py: Improved merge logic
- .gitignore: Clean ignores
Note: tool_examples.jsonl has 1000 synthetic tool_calling examples ready for fine-tuning
- docs/TRAINING_7B.md +189 -0
- evaluate_model.py +736 -250
- kaggle_train_stack29.ipynb +0 -288
- kaggle_train_stack29_final.ipynb +0 -198
- merge_simple.py +1 -1
- scripts/generate_tool_data.py +615 -0
- test_model.py +559 -110
- train_simple.py +0 -197
- training-configs/7b-lora-config.yaml +104 -0
- training-configs/kaggle-7b-qlora.sh +79 -0
- training-data/tool_examples.json +0 -0
- training-data/tool_examples.jsonl +3 -0
docs/TRAINING_7B.md
ADDED
@@ -0,0 +1,189 @@
# Training Stack 2.9 on Qwen2.5-Coder-7B

## Overview

This guide covers training Stack 2.9 on the Qwen2.5-Coder-7B model using LoRA/QLoRA fine-tuning.

## Hardware Requirements

### Minimum (QLoRA, 4-bit)

| GPU | VRAM | Batch Size | Notes |
|-----|------|------------|-------|
| T4 (Colab) | 15GB | 1 | Gradient accumulation = 16 |
| P100 (Kaggle) | 16GB | 1 | Gradient accumulation = 8 |
| RTX 3090 | 24GB | 2 | Full performance |

### Recommended (Full LoRA, bf16)

| GPU | VRAM | Batch Size | Notes |
|-----|------|------------|-------|
| A100 40GB | 40GB | 2 | 2x batch for better throughput |
| A100 80GB | 80GB | 4 | Best for production |
| H100 80GB | 80GB | 4 | Next-gen option |

## VRAM Estimates

| Configuration | Batch Size | Gradient Checkpointing | Est. VRAM |
|---------------|------------|------------------------|-----------|
| Full bf16 | 1 | No | 14GB |
| Full bf16 | 2 | Yes | 16GB |
| Full bf16 | 4 | Yes | 22GB |
| QLoRA (4-bit) | 1 | Yes | 5-6GB |
| QLoRA (4-bit) | 2 | Yes | 7-8GB |

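The estimates above follow a rough rule of thumb: quantized or bf16 weights dominate, plus a per-sample activation cost. The sketch below is illustrative only; the per-parameter and per-sample constants are assumptions chosen to land near the table, not measured values.

```python
# Rough VRAM rule of thumb for 7B fine-tuning. The constants are assumptions
# tuned to roughly match the table above, not measurements.
def estimate_vram_gb(four_bit: bool, batch_size: int) -> float:
    params_b = 7                                # model size in billions
    bytes_per_param = 0.5 if four_bit else 2.0  # 4-bit weights vs bf16 weights
    weights_gb = params_b * bytes_per_param     # 3.5GB quantized, 14GB bf16
    overhead_gb = 1.5                           # CUDA context + LoRA optimizer states
    act_gb_per_sample = 1.5                     # activations at max_length=4096 w/ checkpointing
    return weights_gb + overhead_gb + batch_size * act_gb_per_sample

print(estimate_vram_gb(four_bit=True, batch_size=1))   # 6.5 -- close to the 5-6GB row
print(estimate_vram_gb(four_bit=False, batch_size=2))  # 18.5 -- the table says 16GB; the rule is rough
```

Treat this as a sanity check, not a guarantee: actual usage depends on sequence length, attention implementation, and optimizer.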
## Quick Start

### Option 1: Kaggle (QLoRA)

```bash
cd /kaggle/working/stack-2.9
chmod +x training-configs/kaggle-7b-qlora.sh
./training-configs/kaggle-7b-qlora.sh
```

### Option 2: Local (Full LoRA)

```bash
cd /path/to/stack-2.9
python train_local.py \
    --config training-configs/7b-lora-config.yaml
```

### Option 3: Custom Training Script

```python
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Coder-7B",
    torch_dtype="bfloat16",
    device_map="auto"
)

# LoRA config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

## Configuration Reference

### LoRA Parameters

```yaml
lora:
  r: 16          # Rank (8-32 recommended for 7B)
  alpha: 32      # Usually 2*r
  dropout: 0.05
  target_modules:
    - q_proj
    - k_proj
    - v_proj
    - o_proj
    - gate_proj
    - up_proj
    - down_proj
```

### Training Parameters

```yaml
training:
  num_epochs: 3
  batch_size: 2              # A100: 2-4, T4/P100: 1
  gradient_accumulation: 8
  learning_rate: 1.0e-4
  warmup_steps: 100
  gradient_checkpointing: true
  bf16: true
```

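Two derived quantities are worth checking against these values (a small arithmetic sketch using the standard LoRA and gradient-accumulation relationships):

```python
# Derived quantities from the config values above.
batch_size = 2
gradient_accumulation = 8
r = 16
alpha = 32

# Samples consumed per optimizer step: gradients are accumulated
# over 8 micro-batches before each update.
effective_batch = batch_size * gradient_accumulation
print(effective_batch)  # 16

# LoRA scales its update by alpha / r; with alpha = 2*r this factor is 2.0.
lora_scaling = alpha / r
print(lora_scaling)  # 2.0
```

If you drop `batch_size` to 1 for a T4/P100, raising `gradient_accumulation` to 16 keeps the effective batch size at 16.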
## Expected Training Time

Based on ~10K samples, max_length=4096:

| Hardware | Config | Est. Time |
|----------|--------|-----------|
| T4 | 4-bit QLoRA | 4-6 hours |
| P100 | 4-bit QLoRA | 2-3 hours |
| A100 40GB | bf16 LoRA | 30-45 min |
| A100 80GB | bf16 LoRA | 20-30 min |

Times scale linearly with dataset size.

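The linear scaling makes extrapolation a one-liner. A sketch, where the base figure is the midpoint of a table range (an assumption; measure on your own hardware):

```python
# Extrapolate training time linearly from the table's ~10K-sample baseline.
BASE_SAMPLES = 10_000

def scaled_minutes(num_samples: int, base_minutes: float) -> float:
    return base_minutes * num_samples / BASE_SAMPLES

# A100 40GB bf16 LoRA: the table says 30-45 min at 10K samples; take ~38 min as midpoint.
print(scaled_minutes(20_000, 38))  # 76.0 minutes for a 20K-sample dataset
```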
## After Training

### Merge LoRA Adapter

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Coder-7B",
    torch_dtype="bfloat16"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")

# Merge adapter
model = PeftModel.from_pretrained(base_model, "./output/lora")
merged = model.merge_and_unload()

# Save
merged.save_pretrained("./output/merged")
tokenizer.save_pretrained("./output/merged")
```

### Test the Model

```python
from transformers import AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained("./output/merged")
pipe = pipeline("text-generation", model=merged, tokenizer=tokenizer)

result = pipe("def quick_sort(arr):", max_new_tokens=100)
print(result[0]["generated_text"])
```

## Troubleshooting

### OOM (Out of Memory)
- Reduce `batch_size` to 1
- Enable `gradient_checkpointing: true`
- Reduce `max_length` (4096 → 2048)
- Enable 4-bit quantization

### Training Is Slow
- Increase batch size if VRAM allows
- Enable `use_flash_attention: true` (A100/H100)
- Reduce gradient accumulation

### Loss Not Converging
- Check the learning rate (try 5e-5 or 2e-4)
- Increase epochs (3 → 5)
- Verify the data format matches the expected template

## Alternative: RunPod / Cloud Deployment

For faster training, see `runpod_deploy.sh` at the repo root.

```bash
# Example: RunPod A100
bash runpod_deploy.sh --gpu a100 --instance $ hourly
```

## Notes

- **A100 recommended**: Best balance of VRAM and speed
- **4-bit QLoRA**: Use only if VRAM < 20GB; slightly reduces quality
- **Gradient checkpointing**: Always enable; minimal performance impact for big memory savings
- **Flash Attention**: A100/H100 only; significant speed boost
evaluate_model.py
CHANGED
@@ -2,359 +2,845 @@
(old implementation removed: hardcoded HumanEval and MBPP sample-problem lists, separate per-benchmark evaluation loops reporting only pass@1, and a save_results helper combining both results into a summary JSON)
| 2 |
"""
|
| 3 |
HumanEval + MBPP Benchmark Evaluation for Stack 2.9
|
| 4 |
Tests code generation quality using pass@k metrics.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python evaluate_model.py --model-path /path/to/merged/model --num-samples 10
|
| 8 |
+
python evaluate_model.py --model-path /path/to/merged/model --output results.json
|
| 9 |
"""
|
| 10 |
|
| 11 |
import argparse
|
| 12 |
import os
|
| 13 |
import json
|
| 14 |
import time
|
| 15 |
+
import traceback
|
| 16 |
+
from typing import List, Dict, Tuple, Optional
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
import itertools
|
| 19 |
import torch
|
| 20 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 21 |
|
| 22 |
|
| 23 |
+
def load_model(model_path: str, max_memory: Optional[Dict] = None):
|
| 24 |
"""Load the fine-tuned model and tokenizer."""
|
| 25 |
print(f"Loading model from: {model_path}")
|
| 26 |
+
|
| 27 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 28 |
+
if tokenizer.pad_token is None:
|
| 29 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 30 |
+
|
| 31 |
+
kwargs = {
|
| 32 |
+
"torch_dtype": torch.float16,
|
| 33 |
+
"device_map": "auto",
|
| 34 |
+
"low_cpu_mem_usage": True,
|
| 35 |
+
"trust_remote_code": True,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
if max_memory:
|
| 39 |
+
kwargs["max_memory"] = max_memory
|
| 40 |
+
|
| 41 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
|
| 42 |
return model, tokenizer
|
| 43 |
|
| 44 |
|
| 45 |
+
def generate_solution(
|
| 46 |
+
model,
|
| 47 |
+
tokenizer,
|
| 48 |
+
prompt: str,
|
| 49 |
+
max_new_tokens: int = 256,
|
| 50 |
+
temperature: float = 0.8,
|
| 51 |
+
top_p: float = 0.95,
|
| 52 |
+
num_return_sequences: int = 1
|
| 53 |
+
) -> List[str]:
|
| 54 |
+
"""Generate solutions for a prompt.
|
| 55 |
+
|
| 56 |
+
Returns a list of generated completions.
|
| 57 |
+
"""
|
| 58 |
+
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
|
| 59 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 60 |
|
| 61 |
+
outputs = model.generate(
|
| 62 |
+
**inputs,
|
| 63 |
+
max_new_tokens=max_new_tokens,
|
| 64 |
+
temperature=temperature,
|
| 65 |
+
top_p=top_p,
|
| 66 |
+
do_sample=True,
|
| 67 |
+
repetition_penalty=1.1,
|
| 68 |
+
num_return_sequences=num_return_sequences,
|
| 69 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 70 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 71 |
+
)
|
| 72 |
|
| 73 |
+
completions = []
|
| 74 |
+
for output in outputs:
|
| 75 |
+
text = tokenizer.decode(output, skip_special_tokens=True)
|
| 76 |
+
# Remove the prompt from the completion
|
| 77 |
+
if text.startswith(prompt):
|
| 78 |
+
text = text[len(prompt):]
|
| 79 |
+
completions.append(text.strip())
|
| 80 |
|
| 81 |
+
return completions
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def extract_code(completion: str) -> str:
|
| 85 |
+
"""Extract code from completion, handling markdown code blocks."""
|
| 86 |
+
# Try to extract from ```python blocks
|
| 87 |
if "```python" in completion:
|
| 88 |
start = completion.find("```python") + len("```python")
|
| 89 |
end = completion.find("```", start)
|
| 90 |
if end != -1:
|
| 91 |
+
return completion[start:end].strip()
|
| 92 |
+
|
| 93 |
+
# Try ``` blocks (generic)
|
| 94 |
+
if "```" in completion:
|
| 95 |
start = completion.find("```") + len("```")
|
| 96 |
end = completion.find("```", start)
|
| 97 |
if end != -1:
|
| 98 |
+
return completion[start:end].strip()
|
| 99 |
|
| 100 |
+
# If no code blocks, return as-is but clean up
|
| 101 |
+
return completion.strip()
|
| 102 |
|
| 103 |
|
| 104 |
+
def execute_code(code: str, timeout: int = 5) -> Tuple[bool, str, Optional[any]]:
|
| 105 |
+
"""Safely execute code and return (success, error_msg, result).
|
| 106 |
+
|
| 107 |
+
Uses restricted builtins and timeout for safety.
|
| 108 |
+
"""
|
| 109 |
+
import signal
|
| 110 |
+
|
| 111 |
+
class TimeoutError(Exception):
|
| 112 |
+
pass
|
| 113 |
+
|
| 114 |
+
def timeout_handler(signum, frame):
|
| 115 |
+
raise TimeoutError("Execution timed out")
|
| 116 |
+
|
| 117 |
+
# Restricted globals for safe execution
|
| 118 |
+
safe_builtins = {
|
| 119 |
+
'print': print,
|
| 120 |
+
'len': len,
|
| 121 |
+
'range': range,
|
| 122 |
+
'str': str,
|
| 123 |
+
'int': int,
|
| 124 |
+
'float': float,
|
| 125 |
+
'bool': bool,
|
| 126 |
+
'list': list,
|
| 127 |
+
'dict': dict,
|
| 128 |
+
'set': set,
|
| 129 |
+
'tuple': tuple,
|
| 130 |
+
'sum': sum,
|
| 131 |
+
'min': min,
|
| 132 |
+
'max': max,
|
| 133 |
+
'abs': abs,
|
| 134 |
+
'sorted': sorted,
|
| 135 |
+
'reversed': reversed,
|
| 136 |
+
'enumerate': enumerate,
|
| 137 |
+
'zip': zip,
|
| 138 |
+
'map': map,
|
| 139 |
+
'filter': filter,
|
| 140 |
+
'any': any,
|
| 141 |
+
'all': all,
|
| 142 |
+
'isinstance': isinstance,
|
| 143 |
+
'type': type,
|
| 144 |
+
'round': round,
|
| 145 |
+
'pow': pow,
|
| 146 |
+
'divmod': divmod,
|
| 147 |
+
'ord': ord,
|
| 148 |
+
'chr': chr,
|
| 149 |
+
'hex': hex,
|
| 150 |
+
'bin': bin,
|
| 151 |
+
'id': id,
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
namespace = {
|
| 155 |
+
'__builtins__': safe_builtins,
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
try:
|
| 159 |
+
# Set timeout
|
| 160 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
| 161 |
+
signal.alarm(timeout)
|
| 162 |
+
|
| 163 |
exec(code, namespace)
|
| 164 |
|
| 165 |
+
# Cancel alarm
|
| 166 |
+
signal.alarm(0)
|
|
|
|
|
|
|
| 167 |
|
| 168 |
+
return True, "", namespace.get('result')
|
| 169 |
+
|
| 170 |
+
except TimeoutError as e:
|
| 171 |
+
signal.alarm(0)
|
| 172 |
+
return False, f"Timeout after {timeout}s", None
|
| 173 |
+
except SyntaxError as e:
|
| 174 |
+
signal.alarm(0)
|
| 175 |
+
return False, f"Syntax error: {e}", None
|
| 176 |
except Exception as e:
|
| 177 |
+
signal.alarm(0)
|
| 178 |
+
return False, f"{type(e).__name__}: {e}", None
|
| 179 |
|
| 180 |
|
| 181 |
+
def check_correctness(code: str, test_cases: List[Dict]) -> Tuple[bool, str]:
|
| 182 |
+
"""Check if generated code passes test cases.
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
+
Args:
|
| 185 |
+
code: The generated code to test
|
| 186 |
+
test_cases: List of dicts with 'input' and 'expected' keys
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
Tuple of (all_passed, failure_message)
|
| 190 |
+
"""
|
| 191 |
+
import types
|
| 192 |
+
|
| 193 |
+
# Create a new namespace for execution
|
| 194 |
+
namespace = {}
|
| 195 |
+
safe_builtins = {
|
| 196 |
+
'print': print,
|
| 197 |
+
'len': len,
|
| 198 |
+
'range': range,
|
| 199 |
+
'str': str,
|
| 200 |
+
'int': int,
|
| 201 |
+
'float': float,
|
| 202 |
+
'bool': bool,
|
| 203 |
+
'list': list,
|
| 204 |
+
'dict': dict,
|
| 205 |
+
'set': set,
|
| 206 |
+
'tuple': tuple,
|
| 207 |
+
'sum': sum,
|
| 208 |
+
'min': min,
|
| 209 |
+
'max': max,
|
| 210 |
+
'abs': abs,
|
| 211 |
+
'sorted': sorted,
|
| 212 |
+
'reversed': reversed,
|
| 213 |
+
'enumerate': enumerate,
|
| 214 |
+
'zip': zip,
|
| 215 |
+
'map': map,
|
| 216 |
+
'filter': filter,
|
| 217 |
+
'any': any,
|
| 218 |
+
'all': all,
|
| 219 |
+
'isinstance': isinstance,
|
| 220 |
+
'type': type,
|
| 221 |
+
'round': round,
|
| 222 |
+
'pow': pow,
|
| 223 |
+
}
|
| 224 |
+
namespace['__builtins__'] = safe_builtins
|
| 225 |
+
|
| 226 |
+
try:
|
| 227 |
+
exec(code, namespace)
|
| 228 |
+
except Exception as e:
|
| 229 |
+
return False, f"Execution failed: {type(e).__name__}: {e}"
|
| 230 |
+
|
| 231 |
+
for tc in test_cases:
|
| 232 |
+
func_name = tc.get('function', 'solution')
|
| 233 |
+
inputs = tc.get('input', ())
|
| 234 |
+
expected = tc.get('expected')
|
| 235 |
+
description = tc.get('description', '')
|
| 236 |
+
|
| 237 |
+
if func_name not in namespace:
|
| 238 |
+
return False, f"Function '{func_name}' not found in code"
|
| 239 |
+
|
| 240 |
+
func = namespace[func_name]
|
| 241 |
+
|
| 242 |
+
try:
|
| 243 |
+
if isinstance(inputs, tuple):
|
| 244 |
+
result = func(*inputs)
|
| 245 |
+
else:
|
| 246 |
+
result = func(inputs)
|
| 247 |
+
except Exception as e:
|
| 248 |
+
return False, f"Failed on {description or str(inputs)}: {type(e).__name__}: {e}"
|
| 249 |
+
|
| 250 |
+
if result != expected:
|
| 251 |
+
return False, f"Failed on {description or str(inputs)}: expected {expected}, got {result}"
|
| 252 |
+
|
| 253 |
+
return True, ""
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def calculate_pass_at_k(num_correct: int, num_samples: int, k: int) -> float:
|
| 257 |
+
"""Calculate pass@k metric.
|
| 258 |
+
|
| 259 |
+
Uses the estimator: 1 - C(n-c+k-1, k) / C(n+k-1, k)
|
| 260 |
+
where n = num_samples, c = num_correct, k = k
|
| 261 |
+
|
| 262 |
+
For small samples, this is more accurate than simple c/n.
|
| 263 |
+
"""
|
| 264 |
+
import math
|
| 265 |
+
|
| 266 |
+
if num_samples < k:
|
| 267 |
+
return 0.0
|
| 268 |
+
|
| 269 |
+
if num_samples == 0:
|
| 270 |
+
return 0.0
|
| 271 |
+
|
| 272 |
+
# Bootstrap-style calculation
|
| 273 |
+
# "At least one of k samples is correct" probability
|
| 274 |
+
try:
|
| 275 |
+
# Exact formula: 1 - (C(n-c, k) / C(n, k))
|
| 276 |
+
# But we use the complementary for numerical stability
|
| 277 |
+
correct = num_correct
|
| 278 |
+
n = num_samples
|
| 279 |
+
fail = n - correct
|
| 280 |
+
|
| 281 |
+
if fail >= k:
|
| 282 |
+
return 0.0
|
| 283 |
+
|
| 284 |
+
# Calculate probability that at least one succeeds
|
| 285 |
+
# P(at least 1 success) = 1 - P(all k fail)
|
| 286 |
+
# P(all k fail) = C(fail, k) / C(n, k)
|
| 287 |
+
|
| 288 |
+
numerator = 1.0
|
| 289 |
+
denominator = 1.0
|
| 290 |
+
|
| 291 |
+
for i in range(k):
|
| 292 |
+
numerator *= (fail - i)
|
| 293 |
+
denominator *= (n - i)
|
| 294 |
+
|
| 295 |
+
p_all_fail = numerator / denominator
|
| 296 |
+
p_at_least_1_success = 1 - p_all_fail
|
| 297 |
+
|
| 298 |
+
return p_at_least_1_success
|
| 299 |
+
except:
|
| 300 |
+
# Fallback to simple ratio
|
| 301 |
+
return num_correct / num_samples
|
+
+
+def evaluate_problems(
+    model,
+    tokenizer,
+    problems: List[Dict],
+    k_values: List[int] = [1, 10],
+    num_samples_per_problem: int = 10,
+    max_new_tokens: int = 256,
+) -> Dict:
+    """Evaluate model on a set of problems with pass@k metrics.
+
+    Args:
+        model: The language model
+        tokenizer: The tokenizer
+        problems: List of problem dicts with 'task_id', 'prompt', 'test_cases'
+        k_values: List of k values for pass@k calculation
+        num_samples_per_problem: Number of samples to generate per problem
+        max_new_tokens: Max tokens to generate
+
+    Returns:
+        Dictionary with evaluation results
+    """
+    all_results = []
+    total_correct_per_k = {k: 0 for k in k_values}
+    total_problems = len(problems)
+
+    for idx, problem in enumerate(problems):
+        task_id = problem['task_id']
+        prompt = problem['prompt']
+        test_cases = problem.get('test_cases', [])
+
+        print(f"\n[{idx+1}/{total_problems}] Processing: {task_id}")
+
+        # Generate multiple samples
+        start_time = time.time()
+        completions = generate_solution(
+            model, tokenizer, prompt,
+            max_new_tokens=max_new_tokens,
+            num_return_sequences=num_samples_per_problem
+        )
+        elapsed = time.time() - start_time
+
+        print(f"  Generated {len(completions)} samples in {elapsed:.2f}s")
+
+        # Check each completion against the test cases
+        correct_flags = []
+        for i, completion in enumerate(completions):
+            code = extract_code(completion)
+            passed, msg = check_correctness(code, test_cases)
+            correct_flags.append(passed)
+
+            if i == 0:  # Show first result detail
+                print(f"  Sample 1: {'✅' if passed else '❌'} {msg[:60] if msg else 'OK'}")
+
+        # Calculate pass@k for this problem
+        num_correct = sum(correct_flags)
+
+        problem_results = {
+            "task_id": task_id,
+            "prompt": prompt,
+            "num_samples": len(completions),
+            "num_correct": num_correct,
+            "pass@k": {},
+        }
+
+        for k in k_values:
+            if k <= num_samples_per_problem:
+                correct_in_k = sum(correct_flags[:k])
+                if k == 1:
+                    # pass@1 = whether the first sample is correct
+                    pass_at_k = 1.0 if correct_flags[0] else 0.0
+                else:
+                    # pass@k = estimated probability that at least one of k
+                    # samples is correct, given num_correct of all n samples
+                    pass_at_k = calculate_pass_at_k(num_correct, len(completions), k)
+
+                problem_results["pass@k"][f"pass@{k}"] = pass_at_k
+                total_correct_per_k[k] += correct_in_k
+
+        all_results.append(problem_results)
+
+        # Progress update
+        if 1 in k_values:
+            current_pass1 = total_correct_per_k.get(1, 0) / (idx + 1)
+            print(f"  Running Pass@1: {100*current_pass1:.1f}%")
+
+    # Aggregate results
+    summary = {
+        "total_problems": total_problems,
+        "total_samples_per_problem": num_samples_per_problem,
+    }
+
+    for k in k_values:
+        if k <= num_samples_per_problem:
+            # Overall pass@k: average the per-problem estimates
+            per_problem = [r["pass@k"][f"pass@{k}"] for r in all_results
+                           if f"pass@{k}" in r["pass@k"]]
+            solved = sum(1 for r in all_results if r["num_correct"] > 0)
+
+            summary[f"pass@{k}"] = sum(per_problem) / len(per_problem) if per_problem else 0
+            summary[f"pass@{k}_exact"] = solved
+            summary[f"total@{k}"] = len(per_problem)
+
+    return {
+        "summary": summary,
+        "details": all_results,
+    }
+
+
+def get_humaneval_problems() -> List[Dict]:
+    """Return HumanEval benchmark problems."""
+    return [
+        {
+            "task_id": "humaneval/1",
+            "prompt": '''def two_sum(nums, target):
+    """Given an array of integers nums and an integer target, return indices of the two numbers such that they add up to target.
+    You may assume that each input would have exactly one solution, and you may not use the same element twice.
+    """''',
+            "test_cases": [
+                {"function": "two_sum", "input": ([2,7,11,15], 9), "expected": [0,1], "description": "Basic case"},
+                {"function": "two_sum", "input": ([3,2,4], 6), "expected": [1,2], "description": "Middle elements"},
+                {"function": "two_sum", "input": ([3,3], 6), "expected": [0,1], "description": "Duplicate values"},
+            ],
+        },
+        {
+            "task_id": "humaneval/2",
+            "prompt": '''def is_palindrome(x):
+    """Check if an integer is a palindrome. An integer is a palindrome when it reads the same backward as forward."""''',
+            "test_cases": [
+                {"function": "is_palindrome", "input": 121, "expected": True, "description": "Positive palindrome"},
+                {"function": "is_palindrome", "input": -121, "expected": False, "description": "Negative number"},
+                {"function": "is_palindrome", "input": 10, "expected": False, "description": "Ends with 0"},
+            ],
+        },
+        {
+            "task_id": "humaneval/3",
+            "prompt": '''def fizz_buzz(n):
+    """Given number n, return a list of strings from 1 to n. For multiples of 3 add 'Fizz', for multiples of 5 add 'Buzz', for both add 'FizzBuzz'."""''',
+            "test_cases": [
+                {"function": "fizz_buzz", "input": 3, "expected": ["1", "2", "Fizz"], "description": "n=3"},
+                {"function": "fizz_buzz", "input": 5, "expected": ["1", "2", "Fizz", "4", "Buzz"], "description": "n=5"},
+                {"function": "fizz_buzz", "input": 15, "expected": ["1","2","Fizz","4","Buzz","Fizz","7","8","Fizz","Buzz","11","Fizz","13","14","FizzBuzz"], "description": "n=15"},
+            ],
+        },
+        {
+            "task_id": "humaneval/4",
+            "prompt": '''def fibonacci(n):
+    """Return the first n Fibonacci numbers starting from 0 and 1. So fibonacci(7) returns [0, 1, 1, 2, 3, 5, 8]."""''',
+            "test_cases": [
+                {"function": "fibonacci", "input": 1, "expected": [0], "description": "n=1"},
+                {"function": "fibonacci", "input": 7, "expected": [0, 1, 1, 2, 3, 5, 8], "description": "n=7"},
+                {"function": "fibonacci", "input": 10, "expected": [0, 1, 1, 2, 3, 5, 8, 13, 21, 34], "description": "n=10"},
+            ],
+        },
+        {
+            "task_id": "humaneval/5",
+            "prompt": '''def valid_parentheses(s):
+    """Given a string s containing just the characters '(', ')', '{', '}', '[' and ']', determine if the input string is valid. An input string is valid if: Open brackets must be closed by the same type of brackets, and Open brackets must be closed in the correct order."""''',
+            "test_cases": [
+                {"function": "valid_parentheses", "input": "()", "expected": True, "description": "Simple pair"},
+                {"function": "valid_parentheses", "input": "()[]{}", "expected": True, "description": "Multiple types"},
+                {"function": "valid_parentheses", "input": "(]", "expected": False, "description": "Mismatched"},
+                {"function": "valid_parentheses", "input": "([)]", "expected": False, "description": "Wrong order"},
+            ],
+        },
+        {
+            "task_id": "humaneval/6",
+            "prompt": '''def reverse_string(s):
+    """Return the reverse of string s."""''',
+            "test_cases": [
+                {"function": "reverse_string", "input": "hello", "expected": "olleh", "description": "Basic"},
+                {"function": "reverse_string", "input": "Hannah", "expected": "hannaH", "description": "Palindrome name"},
+            ],
+        },
+        {
+            "task_id": "humaneval/7",
+            "prompt": '''def merge_sorted_lists(l1, l2):
+    """Merge two sorted lists into one sorted list."""''',
+            "test_cases": [
+                {"function": "merge_sorted_lists", "input": ([1,3,5], [2,4,6]), "expected": [1,2,3,4,5,6], "description": "Interleaved"},
+                {"function": "merge_sorted_lists", "input": ([1,2,3], [4,5,6]), "expected": [1,2,3,4,5,6], "description": "Sequential"},
+            ],
+        },
+        {
+            "task_id": "humaneval/8",
+            "prompt": '''def maximum_subarray(nums):
+    """Find the contiguous subarray which has the largest sum and return its sum."""''',
+            "test_cases": [
+                {"function": "maximum_subarray", "input": [-2,1,-3,4,-1,2,1,-5,4], "expected": 6, "description": "Mixed"},
+                {"function": "maximum_subarray", "input": [1], "expected": 1, "description": "Single element"},
+                {"function": "maximum_subarray", "input": [5,4,-1,7,8], "expected": 23, "description": "Mostly positive"},
+            ],
+        },
+        {
+            "task_id": "humaneval/9",
+            "prompt": '''def climbing_stairs(n):
+    """You are climbing a staircase. It takes n steps to reach the top. Each time you can either climb 1 or 2 steps. In how many distinct ways can you climb to the top?"""''',
+            "test_cases": [
+                {"function": "climbing_stairs", "input": 2, "expected": 2, "description": "n=2"},
+                {"function": "climbing_stairs", "input": 3, "expected": 3, "description": "n=3"},
+                {"function": "climbing_stairs", "input": 5, "expected": 8, "description": "n=5"},
+            ],
+        },
+        {
+            "task_id": "humaneval/10",
+            "prompt": '''def contains_duplicate(nums):
+    """Given an integer array nums, return True if any value appears at least twice in the array, and False if every element is distinct."""''',
+            "test_cases": [
+                {"function": "contains_duplicate", "input": [1,2,3,1], "expected": True, "description": "Has duplicate"},
+                {"function": "contains_duplicate", "input": [1,2,3,4], "expected": False, "description": "All unique"},
+            ],
+        },
+        {
+            "task_id": "humaneval/11",
+            "prompt": '''def roman_to_int(s):
+    """Convert a Roman numeral to an integer."""''',
+            "test_cases": [
+                {"function": "roman_to_int", "input": "III", "expected": 3, "description": "Simple"},
+                {"function": "roman_to_int", "input": "IV", "expected": 4, "description": "Subtractive"},
+                {"function": "roman_to_int", "input": "MCMXCIV", "expected": 1994, "description": "Complex"},
+            ],
+        },
+        {
+            "task_id": "humaneval/12",
+            "prompt": '''def longest_common_prefix(strs):
+    """Write a function to find the longest common prefix string amongst an array of strings."""''',
+            "test_cases": [
+                {"function": "longest_common_prefix", "input": ["flower","flow","flight"], "expected": "fl", "description": "Basic"},
+                {"function": "longest_common_prefix", "input": ["dog","racecar","car"], "expected": "", "description": "No prefix"},
+            ],
+        },
+    ]


+def get_mbpp_problems() -> List[Dict]:
+    """Return MBPP (Mostly Basic Python Problems) benchmark problems."""
+    return [
+        {
+            "task_id": "mbpp/1",
+            "prompt": '''def add_numbers(a, b):
+    # Return the sum of a and b
+    pass''',
+            "test_cases": [
+                {"function": "add_numbers", "input": (2, 3), "expected": 5, "description": "Basic add"},
+                {"function": "add_numbers", "input": (-1, 1), "expected": 0, "description": "Opposite signs"},
+            ],
+        },
+        {
+            "task_id": "mbpp/2",
+            "prompt": '''def multiply_list(nums):
+    # Return the product of all numbers in the list
+    pass''',
+            "test_cases": [
+                {"function": "multiply_list", "input": ([1, 2, 3, 4],), "expected": 24, "description": "Basic"},
+                {"function": "multiply_list", "input": ([5],), "expected": 5, "description": "Single element"},
+                {"function": "multiply_list", "input": ([],), "expected": 1, "description": "Empty (identity)"},
+            ],
+        },
+        {
+            "task_id": "mbpp/3",
+            "prompt": '''def square(x):
+    # Return the square of x
+    pass''',
+            "test_cases": [
+                {"function": "square", "input": (5,), "expected": 25, "description": "Basic"},
+                {"function": "square", "input": (-3,), "expected": 9, "description": "Negative"},
+                {"function": "square", "input": (0,), "expected": 0, "description": "Zero"},
+            ],
+        },
+        {
+            "task_id": "mbpp/4",
+            "prompt": '''def is_even(n):
+    # Return True if n is even, False otherwise
+    pass''',
+            "test_cases": [
+                {"function": "is_even", "input": (4,), "expected": True, "description": "Even number"},
+                {"function": "is_even", "input": (7,), "expected": False, "description": "Odd number"},
+                {"function": "is_even", "input": (0,), "expected": True, "description": "Zero is even"},
+            ],
+        },
+        {
+            "task_id": "mbpp/5",
+            "prompt": '''def string_length(s):
+    # Return the length of string s
+    pass''',
+            "test_cases": [
+                {"function": "string_length", "input": ("hello",), "expected": 5, "description": "Basic"},
+                {"function": "string_length", "input": ("",), "expected": 0, "description": "Empty string"},
+            ],
+        },
+        {
+            "task_id": "mbpp/6",
+            "prompt": '''def get_max(nums):
+    # Return the maximum number from the list
+    pass''',
+            "test_cases": [
+                {"function": "get_max", "input": ([1, 5, 3],), "expected": 5, "description": "Basic"},
+                {"function": "get_max", "input": ([-1, -5, -3],), "expected": -1, "description": "Negative numbers"},
+            ],
+        },
+        {
+            "task_id": "mbpp/7",
+            "prompt": '''def get_min(nums):
+    # Return the minimum number from the list
+    pass''',
+            "test_cases": [
+                {"function": "get_min", "input": ([1, 5, 3],), "expected": 1, "description": "Basic"},
+                {"function": "get_min", "input": ([-1, -5, -3],), "expected": -5, "description": "Negative numbers"},
+            ],
+        },
+        {
+            "task_id": "mbpp/8",
+            "prompt": '''def count_zeros(nums):
+    # Return the count of zeros in the list
+    pass''',
+            "test_cases": [
+                {"function": "count_zeros", "input": ([0, 1, 0, 2, 0],), "expected": 3, "description": "Mixed"},
+                {"function": "count_zeros", "input": ([1, 2, 3],), "expected": 0, "description": "No zeros"},
+            ],
+        },
+        {
+            "task_id": "mbpp/9",
+            "prompt": '''def reverse_list(lst):
+    # Return a new list with elements in reverse order
+    pass''',
+            "test_cases": [
+                {"function": "reverse_list", "input": ([1, 2, 3],), "expected": [3, 2, 1], "description": "Basic"},
+                {"function": "reverse_list", "input": ([],), "expected": [], "description": "Empty"},
+            ],
+        },
+        {
+            "task_id": "mbpp/10",
+            "prompt": '''def unique_elements(lst):
+    # Return list of unique elements (preserving order)
+    pass''',
+            "test_cases": [
+                {"function": "unique_elements", "input": ([1, 2, 2, 3],), "expected": [1, 2, 3], "description": "With duplicates"},
+                {"function": "unique_elements", "input": ([1, 2, 3],), "expected": [1, 2, 3], "description": "All unique"},
+            ],
+        },
+        {
+            "task_id": "mbpp/11",
+            "prompt": '''def factorial(n):
+    # Return n! (factorial of n)
+    pass''',
+            "test_cases": [
+                {"function": "factorial", "input": (5,), "expected": 120, "description": "Basic"},
+                {"function": "factorial", "input": (0,), "expected": 1, "description": "Zero factorial"},
+                {"function": "factorial", "input": (1,), "expected": 1, "description": "One factorial"},
+            ],
+        },
+        {
+            "task_id": "mbpp/12",
+            "prompt": '''def is_prime(n):
+    # Return True if n is prime, False otherwise
+    pass''',
+            "test_cases": [
+                {"function": "is_prime", "input": (7,), "expected": True, "description": "Prime"},
+                {"function": "is_prime", "input": (4,), "expected": False, "description": "Not prime"},
+                {"function": "is_prime", "input": (1,), "expected": False, "description": "One is not prime"},
+            ],
+        },
+    ]
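Each MBPP prompt is a stub ending in `pass`, and a completion is judged purely by its test cases. As a worked example, a completion for `mbpp/2` that satisfies all three cases above, including the empty-list identity, could look like this (one possible reference solution, not part of the script):

```python
from functools import reduce
import operator


def multiply_list(nums):
    # Product of all numbers; the third initializer argument makes
    # the empty list return the multiplicative identity 1
    return reduce(operator.mul, nums, 1)


print(multiply_list([1, 2, 3, 4]))  # 24
print(multiply_list([]))            # 1
```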


+def save_results(results: Dict, output_path: str):
+    """Save evaluation results to JSON file."""
    with open(output_path, 'w') as f:
+        json.dump(results, f, indent=2)
    print(f"\n✅ Results saved to: {output_path}")
+
+
+def print_summary(results: Dict, benchmark_name: str):
+    """Print a summary of evaluation results."""
+    print(f"\n{'='*60}")
+    print(f"{benchmark_name} Results")
+    print('='*60)
+
+    summary = results.get("summary", {})
+    total = summary.get("total_problems", 0)
+
+    for key, value in summary.items():
+        if key.startswith("pass@") and not key.endswith("_exact"):
+            print(f"  {key}: {100*value:.2f}%")
+        elif key.endswith("_exact") or key.startswith("total@") or key in ("total_problems", "total_samples_per_problem"):
+            print(f"  {key}: {value}")
+
+    print(f"\n  Total problems evaluated: {total}")
+    print('='*60)
|
| 706 |
|
| 707 |
def main():
|
| 708 |
+
parser = argparse.ArgumentParser(
|
| 709 |
+
description="Evaluate Stack 2.9 model on HumanEval and MBPP benchmarks"
|
| 710 |
+
)
|
| 711 |
+
parser.add_argument(
|
| 712 |
+
"--model-path",
|
| 713 |
+
type=str,
|
| 714 |
+
required=True,
|
| 715 |
+
help="Path to the merged model directory"
|
| 716 |
+
)
|
| 717 |
+
parser.add_argument(
|
| 718 |
+
"--output",
|
| 719 |
+
type=str,
|
| 720 |
+
default="evaluation_results.json",
|
| 721 |
+
help="Output file for results (default: evaluation_results.json)"
|
| 722 |
+
)
|
| 723 |
+
parser.add_argument(
|
| 724 |
+
"--num-samples",
|
| 725 |
+
type=int,
|
| 726 |
+
default=10,
|
| 727 |
+
help="Number of samples per problem for pass@k (default: 10)"
|
| 728 |
+
)
|
| 729 |
+
parser.add_argument(
|
| 730 |
+
"--max-new-tokens",
|
| 731 |
+
type=int,
|
| 732 |
+
default=256,
|
| 733 |
+
help="Maximum new tokens to generate (default: 256)"
|
| 734 |
+
)
|
| 735 |
+
parser.add_argument(
|
| 736 |
+
"--k-values",
|
| 737 |
+
type=str,
|
| 738 |
+
default="1,10",
|
| 739 |
+
help="Comma-separated k values for pass@k (default: 1,10)"
|
| 740 |
+
)
|
| 741 |
+
parser.add_argument(
|
| 742 |
+
"--benchmark",
|
| 743 |
+
type=str,
|
| 744 |
+
choices=["humaneval", "mbpp", "both"],
|
| 745 |
+
default="both",
|
| 746 |
+
help="Which benchmark to run (default: both)"
|
| 747 |
+
)
|
| 748 |
+
parser.add_argument(
|
| 749 |
+
"--num-problems",
|
| 750 |
+
type=int,
|
| 751 |
+
default=None,
|
| 752 |
+
help="Limit number of problems per benchmark (default: all)"
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
args = parser.parse_args()
|
| 756 |
|
| 757 |
+
# Parse k values
|
| 758 |
+
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
| 759 |
+
|
| 760 |
print("="*60)
|
| 761 |
print("Stack 2.9 Model Evaluation")
|
| 762 |
print("="*60)
|
| 763 |
+
print(f"Model path: {args.model_path}")
|
| 764 |
+
print(f"Output: {args.output}")
|
| 765 |
+
print(f"Num samples per problem: {args.num_samples}")
|
| 766 |
+
print(f"Pass@k values: {k_values}")
|
| 767 |
+
print(f"Benchmark: {args.benchmark}")
|
| 768 |
|
| 769 |
+
# Load model
|
| 770 |
model, tokenizer = load_model(args.model_path)
|
| 771 |
model.eval()
|
| 772 |
|
| 773 |
+
all_results = {}
|
| 774 |
+
total_start = time.time()
|
|
|
|
| 775 |
|
| 776 |
+
# Run HumanEval
|
| 777 |
+
if args.benchmark in ["humaneval", "both"]:
|
| 778 |
+
print("\n" + "="*60)
|
| 779 |
+
print("Running HumanEval Benchmark")
|
| 780 |
+
print("="*60)
|
| 781 |
+
|
| 782 |
+
problems = get_humaneval_problems()
|
| 783 |
+
if args.num_problems:
|
| 784 |
+
problems = problems[:args.num_problems]
|
| 785 |
+
|
| 786 |
+
results = evaluate_problems(
|
| 787 |
+
model, tokenizer,
|
| 788 |
+
problems,
|
| 789 |
+
k_values=k_values,
|
| 790 |
+
num_samples_per_problem=args.num_samples,
|
| 791 |
+
max_new_tokens=args.max_new_tokens,
|
| 792 |
+
)
|
| 793 |
+
all_results["humaneval"] = results
|
| 794 |
+
print_summary(results, "HumanEval")
|
| 795 |
+
|
| 796 |
+
# Run MBPP
|
| 797 |
+
if args.benchmark in ["mbpp", "both"]:
|
| 798 |
+
print("\n" + "="*60)
|
| 799 |
+
print("Running MBPP Benchmark")
|
| 800 |
+
print("="*60)
|
| 801 |
+
|
| 802 |
+
problems = get_mbpp_problems()
|
| 803 |
+
if args.num_problems:
|
| 804 |
+
problems = problems[:args.num_problems]
|
| 805 |
+
|
| 806 |
+
results = evaluate_problems(
|
| 807 |
+
model, tokenizer,
|
| 808 |
+
problems,
|
| 809 |
+
k_values=k_values,
|
| 810 |
+
num_samples_per_problem=args.num_samples,
|
| 811 |
+
max_new_tokens=args.max_new_tokens,
|
| 812 |
+
)
|
| 813 |
+
all_results["mbpp"] = results
|
| 814 |
+
print_summary(results, "MBPP")
|
| 815 |
+
|
| 816 |
+
total_time = time.time() - total_start
|
| 817 |
|
| 818 |
+
# Final summary
|
| 819 |
print("\n" + "="*60)
|
| 820 |
print("FINAL SUMMARY")
|
| 821 |
print("="*60)
|
| 822 |
+
|
| 823 |
+
for bench_name in ["humaneval", "mbpp"]:
|
| 824 |
+
if bench_name in all_results:
|
| 825 |
+
summary = all_results[bench_name]["summary"]
|
| 826 |
+
for k in k_values:
|
| 827 |
+
key = f"pass@{k}"
|
| 828 |
+
if key in summary:
|
| 829 |
+
print(f" {bench_name.upper()} {key}: {100*summary[key]:.2f}%")
|
| 830 |
+
|
| 831 |
+
print(f"\n Total evaluation time: {total_time:.1f}s")
|
| 832 |
print("="*60)
|
| 833 |
+
|
| 834 |
+
# Add metadata to results
|
| 835 |
+
all_results["metadata"] = {
|
| 836 |
+
"model_path": args.model_path,
|
| 837 |
+
"num_samples": args.num_samples,
|
| 838 |
+
"k_values": k_values,
|
| 839 |
+
"total_time_seconds": total_time,
|
| 840 |
+
}
|
| 841 |
+
|
| 842 |
+
# Save results
|
| 843 |
+
save_results(all_results, args.output)
|
| 844 |
|
| 845 |
|
| 846 |
if __name__ == "__main__":
|
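With the CLI defined above, a typical invocation against a locally merged checkpoint might look like this (the `./merged` path and sample counts are illustrative, not fixed by the script):

```shell
python evaluate_model.py \
    --model-path ./merged \
    --benchmark both \
    --num-samples 10 \
    --k-values 1,10 \
    --output evaluation_results.json
```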
kaggle_train_stack29.ipynb
DELETED
|
@@ -1,288 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# 🚀 Stack 2.9 - Kaggle Training\n",
-    "\n",
-    "**Free GPU training on Kaggle**\n",
-    "\n",
-    "This notebook trains a LoRA adapter for Stack 2.9 on **Qwen2.5-Coder-7B** using Kaggle's free GPU.\n",
-    "\n",
-    "⏱️ **Expected runtime:** 2-4 hours\n",
-    "💾 **VRAM needed:** ~16GB (Kaggle P100 has 16GB)\n",
-    "\n",
-    "---\n",
-    "\n",
-    "**Instructions:**\n",
-    "1. Enable GPU: Settings → Accelerator → GPU T4\n",
-    "2. Run cells in order from the top\n",
-    "3. Model auto-downloads if not present\n",
-    "\n",
-    "---"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# STEP 1: Check GPU\n",
-    "import subprocess\n",
-    "subprocess.run([\"nvidia-smi\"], check=True)\n",
-    "print(\"✅ GPU ready!\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# STEP 2: Clone repo and setup paths\n",
-    "import os\n",
-    "import shutil\n",
-    "import subprocess\n",
-    "\n",
-    "os.chdir(\"/kaggle/working\")\n",
-    "\n",
-    "REPO_DIR = \"/kaggle/working/stack-2.9\"\n",
-    "MODEL_DIR = os.path.join(REPO_DIR, \"base_model_qwen7b\")\n",
-    "OUTPUT_DIR = os.path.join(REPO_DIR, \"training_output\")\n",
-    "\n",
-    "# Remove old repo if exists\n",
-    "if os.path.exists(REPO_DIR):\n",
-    "    shutil.rmtree(REPO_DIR)\n",
-    "\n",
-    "# Clone fresh\n",
-    "subprocess.run([\"git\", \"clone\", \"https://github.com/my-ai-stack/stack-2.9.git\", REPO_DIR], check=True)\n",
-    "os.chdir(REPO_DIR)\n",
-    "\n",
-    "print(f\"✅ Working in: {os.getcwd()}\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# STEP 3: Install dependencies\n",
-    "import subprocess\n",
-    "\n",
-    "subprocess.run([\"pip\", \"install\", \"-q\", \"torch\", \"torchvision\", \"torchaudio\", \"--index-url\", \"https://download.pytorch.org/whl/cu118\"], check=True)\n",
-    "subprocess.run([\"pip\", \"install\", \"-q\", \"transformers==4.40.0\", \"peft==0.10.0\", \"accelerate==0.34.0\", \"datasets\", \"pyyaml\", \"tqdm\", \"scipy\", \"bitsandbytes==0.43.0\"], check=True)\n",
-    "print(\"✅ Dependencies installed\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# STEP 4: Prepare training data\n",
-    "import os\n",
-    "import json\n",
-    "\n",
-    "# Check for available training data\n",
-    "REPO_TRAIN_DATA = os.path.join(REPO_DIR, \"training-data/final/train.jsonl\")\n",
-    "MINI_DATA_DIR = os.path.join(REPO_DIR, \"data_mini\")\n",
-    "MINI_DATA_FILE = os.path.join(MINI_DATA_DIR, \"train_mini.jsonl\")\n",
-    "SYNTHETIC_DATA_FILE = os.path.join(REPO_DIR, \"data/synthetic.jsonl\")\n",
-    "\n",
-    "print(\"🔍 Checking for training data...\")\n",
-    "\n",
-    "if os.path.exists(REPO_TRAIN_DATA):\n",
-    "    print(f\"  Found full dataset: {REPO_TRAIN_DATA}\")\n",
-    "    os.makedirs(MINI_DATA_DIR, exist_ok=True)\n",
-    "    if not os.path.exists(MINI_DATA_FILE):\n",
-    "        print(\"  Creating mini dataset (1000 samples)...\")\n",
-    "        import subprocess\n",
-    "        subprocess.run([\"python\", os.path.join(REPO_DIR, \"scripts/create_mini_dataset.py\"),\n",
-    "                        \"--size\", \"1000\", \"--output\", MINI_DATA_FILE, \"--source\", REPO_TRAIN_DATA], check=True)\n",
-    "    DATA_FILE = MINI_DATA_FILE\n",
-    "\n",
-    "elif os.path.exists(MINI_DATA_FILE):\n",
-    "    DATA_FILE = MINI_DATA_FILE\n",
-    "    print(f\"  Using existing mini dataset: {MINI_DATA_FILE}\")\n",
-    "\n",
-    "else:\n",
-    "    print(\"  No dataset found. Creating synthetic data...\")\n",
-    "\n",
-    "    # Simple code completion examples\n",
-    "    examples = [\n",
-    "        {\"instruction\": \"Write a Python function to reverse a string\",\n",
-    "         \"output\": \"def reverse_string(s):\\n    return s[::-1]\"},\n",
-    "        {\"instruction\": \"Write a function to check if a number is prime\",\n",
-    "         \"output\": \"def is_prime(n):\\n    if n <= 1:\\n        return False\\n    for i in range(2, int(n**0.5) + 1):\\n        if n % i == 0:\\n            return False\\n    return True\"},\n",
-    "        {\"instruction\": \"Write a binary search function\",\n",
-    "         \"output\": \"def binary_search(arr, target):\\n    left, right = 0, len(arr) - 1\\n    while left <= right:\\n        mid = (left + right) // 2\\n        if arr[mid] == target:\\n            return mid\\n        elif arr[mid] < target:\\n            left = mid + 1\\n        else:\\n            right = mid - 1\\n    return -1\"},\n",
-    "    ]\n",
-    "\n",
-    "    samples = []\n",
-    "    for i in range(1000):\n",
-    "        for ex in examples:\n",
-    "            samples.append(ex)\n",
-    "\n",
-    "    os.makedirs(os.path.dirname(SYNTHETIC_DATA_FILE), exist_ok=True)\n",
-    "    with open(SYNTHETIC_DATA_FILE, 'w') as f:\n",
-    "        for s in samples:\n",
-    "            f.write(json.dumps(s) + '\\n')\n",
-    "\n",
-    "    DATA_FILE = SYNTHETIC_DATA_FILE\n",
-    "    print(f\"  Created synthetic dataset: {len(samples)} samples\")\n",
-    "\n",
-    "print(f\"\\n✅ Using training data: {DATA_FILE}\")\n",
-    "print(f\"  Size: {os.path.getsize(DATA_FILE) / 1024:.1f} KB\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# STEP 5: Prepare config for training\n",
-    "import yaml\n",
-    "import os\n",
-    "\n",
-    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
-    "\n",
-    "config = {\n",
-    "    'model': {\n",
-    "        'name': 'Qwen/Qwen2.5-Coder-7B',\n",
-    "        'trust_remote_code': True,\n",
-    "        'torch_dtype': 'float16'\n",
-    "    },\n",
-    "    'data': {\n",
-    "        'input_path': DATA_FILE,\n",
-    "        'max_length': 2048,\n",
-    "        'train_split': 1.0\n",
-    "    },\n",
-    "    'lora': {\n",
-    "        'r': 16,\n",
-    "        'alpha': 32,\n",
-    "        'dropout': 0.05,\n",
-    "        'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
-    "        'bias': 'none',\n",
-    "        'task_type': 'CAUSAL_LM'\n",
-    "    },\n",
-    "    'training': {\n",
-    "        'num_epochs': 1,\n",
-    "        'batch_size': 2,\n",
-    "        'gradient_accumulation': 4,\n",
-    "        'learning_rate': 2e-4,\n",
-    "        'warmup_steps': 50,\n",
-    "        'weight_decay': 0.01,\n",
-    "        'max_grad_norm': 1.0,\n",
-    "        'logging_steps': 10,\n",
-    "        'save_steps': 100,\n",
-    "        'save_total_limit': 2,\n",
-    "        'fp16': True,\n",
-    "        'bf16': False,\n",
-    "        'gradient_checkpointing': True\n",
-    "    },\n",
-    "    'output': {\n",
-    "        'lora_dir': os.path.join(OUTPUT_DIR, 'lora'),\n",
-    "        'logging_dir': os.path.join(OUTPUT_DIR, 'logs')\n",
-    "    },\n",
-    "    'quantization': {'enabled': False},\n",
-    "    'hardware': {'device': 'cuda', 'num_gpus': 1, 'use_4bit': False, 'use_8bit': False}\n",
-    "}\n",
-    "\n",
-    "config_path = os.path.join(OUTPUT_DIR, \"train_config.yaml\")\n",
-    "with open(config_path, 'w') as f:\n",
-    "    yaml.dump(config, f, default_flow_style=False)\n",
-    "\n",
-    "print(f\"✅ Config saved: {config_path}\")\n",
-    "print(f\"  Model: {config['model']['name']}\")\n",
-    "print(f\"  Data: {config['data']['input_path']}\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# STEP 6: Train LoRA\n",
-    "import sys\n",
-    "sys.path.insert(0, os.path.join(REPO_DIR, \"stack_2_9_training\"))\n",
-    "\n",
-    "print(\"=\"*60)\n",
-    "print(\"STARTING TRAINING\")\n",
-    "print(\"=\"*60)\n",
-    "\n",
-    "from stack_2_9_training.train_lora import train_lora\n",
-    "\n",
-    "try:\n",
-    "    trainer = train_lora(config_path)\n",
-    "    print(\"\\n\" + \"=\"*60)\n",
-    "    print(\"TRAINING COMPLETED\")\n",
-    "    print(\"=\"*60)\n",
-    "except Exception as e:\n",
-    "    print(f\"\\n❌ Training failed: {e}\")\n",
-    "    import traceback\n",
-    "    traceback.print_exc()\n",
-    "    raise"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# STEP 7: Merge LoRA adapter\n",
-    "import sys\n",
-    "sys.path.insert(0, os.path.join(REPO_DIR, \"stack_2_9_training\"))\n",
-    "from stack_2_9_training.merge_adapter import merge_adapter\n",
-    "\n",
-    "lora_dir = config['output']['lora_dir']\n",
-    "merged_dir = os.path.join(OUTPUT_DIR, 'merged')\n",
-    "os.makedirs(merged_dir, exist_ok=True)\n",
-    "\n",
-    "print(\"=\"*60)\n",
-    "print(\"MERGING\")\n",
-    "print(\"=\"*60)\n",
-    "\n",
-    "try:\n",
-    "    merge_adapter(\n",
-    "        base_model_name_or_path=config['model']['name'],\n",
-    "        adapter_path=lora_dir,\n",
-    "        output_path=merged_dir,\n",
-    "        use_safetensors=True\n",
-    "    )\n",
-    "    print(\"\\n✅ Merge completed!\")\n",
-
" print(f\"Files: {os.listdir(merged_dir)}\")\n",
|
| 261 |
-
"except Exception as e:\n",
|
| 262 |
-
" print(f\"\\n❌ Merge failed: {e}\")\n",
|
| 263 |
-
" import traceback\n",
|
| 264 |
-
" traceback.print_exc()\n",
|
| 265 |
-
" raise\n",
|
| 266 |
-
"\n",
|
| 267 |
-
"print(\"\\n\" + \"=\"*60)\n",
|
| 268 |
-
"print(\"🎉 ALL DONE!\")\n",
|
| 269 |
-
"print(\"=\"*60)\n",
|
| 270 |
-
"print(f\"\\n📦 Model ready: {merged_dir}\")\n",
|
| 271 |
-
"print(\"\\n⏳ Download 'merged' folder from Kaggle Output panel before session ends!\")"
|
| 272 |
-
]
|
| 273 |
-
}
|
| 274 |
-
],
|
| 275 |
-
"metadata": {
|
| 276 |
-
"kaggle": {
|
| 277 |
-
"accelerator": "gpu",
|
| 278 |
-
"dataSources": [],
|
| 279 |
-
"kernelSpec": {
|
| 280 |
-
"displayName": "Python 3",
|
| 281 |
-
"language": "python",
|
| 282 |
-
"name": "python3"
|
| 283 |
-
}
|
| 284 |
-
}
|
| 285 |
-
},
|
| 286 |
-
"nbformat": 4,
|
| 287 |
-
"nbformat_minor": 0
|
| 288 |
-
}
|
kaggle_train_stack29_final.ipynb
DELETED
@@ -1,198 +0,0 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🚀 Stack 2.9 - Kaggle Training\n",
    "\n",
    "Free GPU training on Kaggle using Qwen2.5-Coder-7B.\n",
    "\n",
    "⏱️ **Runtime:** 2-4 hours | 💾 **VRAM:** ~16GB\n",
    "\n",
    "**Setup:**\n",
    "1. Settings → Accelerator → GPU **T4**\n",
    "2. Run all cells in order\n",
    "3. Download merged model from Output tab when done"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check GPU\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clone repository\n",
    "import os, shutil, subprocess\n",
    "\n",
    "os.chdir('/kaggle/working')\n",
    "REPO_DIR = '/kaggle/working/stack-2.9'\n",
    "OUTPUT_DIR = os.path.join(REPO_DIR, 'training_output')\n",
    "\n",
    "if os.path.exists(REPO_DIR):\n",
    "    shutil.rmtree(REPO_DIR)\n",
    "subprocess.run(['git', 'clone', 'https://github.com/my-ai-stack/stack-2.9.git', REPO_DIR], check=True)\n",
    "os.chdir(REPO_DIR)\n",
    "print('✅ Repo ready:', REPO_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install dependencies (single command)\n",
    "!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
    "!pip install -q transformers==4.40.0 peft==0.10.0 accelerate==0.34.0 datasets==3.0.0 pyyaml tqdm scipy bitsandbytes==0.43.0\n",
    "print('✅ Dependencies ready')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare training data (auto-detect or synthetic fallback)\n",
    "import os, json\n",
    "\n",
    "REPO_TRAIN_DATA = os.path.join(REPO_DIR, 'training-data/final/train.jsonl')\n",
    "MINI_DATA_DIR = os.path.join(REPO_DIR, 'data_mini')\n",
    "MINI_DATA_FILE = os.path.join(MINI_DATA_DIR, 'train_mini.jsonl')\n",
    "SYNTHETIC_FILE = os.path.join(REPO_DIR, 'data/synthetic.jsonl')\n",
    "\n",
    "print('🔍 Data check')\n",
    "\n",
    "if os.path.exists(REPO_TRAIN_DATA):\n",
    "    os.makedirs(MINI_DATA_DIR, exist_ok=True)\n",
    "    if not os.path.exists(MINI_DATA_FILE):\n",
    "        print('   Building mini dataset (1K samples) from full data...')\n",
    "        !python scripts/create_mini_dataset.py --size 1000 --output {MINI_DATA_FILE} --source {REPO_TRAIN_DATA}\n",
    "    DATA_FILE = MINI_DATA_FILE\n",
    "    print('   Using mini dataset')\n",
    "elif os.path.exists(MINI_DATA_FILE):\n",
    "    DATA_FILE = MINI_DATA_FILE\n",
    "    print('   Using existing mini dataset')\n",
    "else:\n",
    "    print('   Creating synthetic data (last resort)')\n",
    "    examples = [\n",
    "        {'instruction': 'Write a Python function to reverse a string', 'output': 'def reverse_string(s):\\n    return s[::-1]'},\n",
    "        {'instruction': 'Write a function to check if a number is prime', 'output': 'def is_prime(n):\\n    if n <= 1:\\n        return False\\n    for i in range(2, int(n**0.5) + 1):\\n        if n % i == 0:\\n            return False\\n    return True'},\n",
    "        {'instruction': 'Write a binary search function', 'output': 'def binary_search(arr, target):\\n    left, right = 0, len(arr) - 1\\n    while left <= right:\\n        mid = (left + right) // 2\\n        if arr[mid] == target:\\n            return mid\\n        elif arr[mid] < target:\\n            left = mid + 1\\n        else:\\n            right = mid - 1\\n    return -1'},\n",
    "    ]\n",
    "    samples = examples * 333\n",
    "    os.makedirs(os.path.dirname(SYNTHETIC_FILE), exist_ok=True)\n",
    "    with open(SYNTHETIC_FILE, 'w') as f:\n",
    "        for s in samples:\n",
    "            f.write(json.dumps(s) + '\\n')\n",
    "    DATA_FILE = SYNTHETIC_FILE\n",
    "    print(f'   Synthetic dataset: {len(samples)} examples')\n",
    "\n",
    "print(f'\\n✅ Data: {DATA_FILE}')\n",
    "!ls -lh {DATA_FILE}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate training configuration\n",
    "import yaml\n",
    "\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "config = {\n",
    "    'model': {'name': 'Qwen/Qwen2.5-Coder-7B', 'trust_remote_code': True, 'torch_dtype': 'float16'},\n",
    "    'data': {'input_path': DATA_FILE, 'max_length': 2048, 'train_split': 1.0},\n",
    "    'lora': {'r': 16, 'alpha': 32, 'dropout': 0.05, 'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'], 'bias': 'none', 'task_type': 'CAUSAL_LM'},\n",
    "    'training': {'num_epochs': 1, 'batch_size': 2, 'gradient_accumulation': 4, 'learning_rate': 2e-4, 'warmup_steps': 50, 'weight_decay': 0.01, 'max_grad_norm': 1.0, 'logging_steps': 10, 'save_steps': 100, 'save_total_limit': 2, 'fp16': True, 'bf16': False, 'gradient_checkpointing': True},\n",
    "    'output': {'lora_dir': os.path.join(OUTPUT_DIR, 'lora'), 'logging_dir': os.path.join(OUTPUT_DIR, 'logs')},\n",
    "    'quantization': {'enabled': False},\n",
    "    'hardware': {'device': 'cuda', 'num_gpus': 1, 'use_4bit': False, 'use_8bit': False}\n",
    "}\n",
    "\n",
    "config_path = os.path.join(OUTPUT_DIR, 'train_config.yaml')\n",
    "with open(config_path, 'w') as f:\n",
    "    yaml.dump(config, f, default_flow_style=False)\n",
    "\n",
    "print(f'✅ Config: {config_path}')\n",
    "print(f\"   Model: {config['model']['name']}\")\n",
    "print(f\"   Data: {config['data']['input_path']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train (using standalone train_simple.py - no package install needed)\n",
    "print('='*60)\n",
    "print('STARTING TRAINING')\n",
    "print('='*60)\n",
    "\n",
    "!cd {REPO_DIR} && python train_simple.py --config {config_path}\n",
    "\n",
    "print('\\n✅ Training step finished')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge LoRA adapter into final model\n",
    "lora_dir = os.path.join(OUTPUT_DIR, 'lora')\n",
    "merged_dir = os.path.join(OUTPUT_DIR, 'merged')\n",
    "\n",
    "print('='*60)\n",
    "print('MERGING LORA ADAPTER')\n",
    "print('='*60)\n",
    "\n",
    "!cd {REPO_DIR} && python merge_simple.py \\\n",
    "    --base-model {config['model']['name']} \\\n",
    "    --adapter-path {lora_dir} \\\n",
    "    --output-path {merged_dir} \\\n",
    "    --use-safetensors\n",
    "\n",
    "print('\\n✅ Merge complete!')\n",
    "print(f'Merged model: {merged_dir}')\n",
    "!ls -lh {merged_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📥 Download Model\n",
    "\n",
    "1. Open **Output** tab on the right\n",
    "2. Find `training_output/merged/`\n",
    "3. Select all files and **Download**\n",
    "\n",
    "⚠️ **Do this before Kaggle session ends!**"
   ]
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "gpu"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
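The deleted notebook's last-resort branch pads a few seed instruction/output records into a JSONL file. A minimal standalone sketch of that same pattern is below; the output path, seed example, and repeat count here are illustrative, not taken from the repo:

```python
import json
import os

# Illustrative output path (the notebook wrote to data/synthetic.jsonl inside the repo)
out_path = "data/synthetic_demo.jsonl"

# One seed example, repeated to pad the dataset (the notebook used `examples * 333`)
examples = [
    {
        "instruction": "Write a Python function to reverse a string",
        "output": "def reverse_string(s):\n    return s[::-1]",
    }
]
samples = examples * 3

os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w") as f:
    for s in samples:
        f.write(json.dumps(s) + "\n")  # one JSON object per line (JSONL)

print(f"Wrote {len(samples)} examples to {out_path}")
```

Writing one `json.dumps` record per line is what makes the file loadable by `datasets.load_dataset("json", ...)` and similar JSONL readers.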
merge_simple.py
CHANGED
@@ -34,7 +34,7 @@ def main():
     print("Loading base model...")
     model = AutoModelForCausalLM.from_pretrained(
         args.base_model,
-        torch_dtype=torch.
+        torch_dtype=torch.bfloat16,
         device_map="auto",
         trust_remote_code=True
     )
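The fixed line hardcodes `torch.bfloat16`, which needs hardware bf16 support (Ampere or newer); the Kaggle T4 targeted elsewhere in this commit does not have it, where `float16` is the usual fallback. A hedged sketch of that selection logic follows; in real code the two flags would come from `torch.cuda.is_available()` and `torch.cuda.is_bf16_supported()`, modeled here as plain parameters so the decision is explicit:

```python
def pick_dtype(cuda_available: bool, bf16_supported: bool) -> str:
    """Return the dtype name to pass as torch_dtype when loading the base model.

    bfloat16 requires hardware support (Ampere-or-newer GPUs); older cards
    such as the Kaggle T4 fall back to float16, and CPU-only runs to float32.
    """
    if cuda_available and bf16_supported:
        return "bfloat16"
    if cuda_available:
        return "float16"
    return "float32"

# e.g. on a T4: CUDA is available but bf16 compute is not
print(pick_dtype(cuda_available=True, bf16_supported=False))  # → float16
```

Merging in a dtype the GPU cannot execute either errors out or silently falls back to slow emulated paths, so guarding the dtype choice keeps one merge script usable across A100 and T4 runs.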
scripts/generate_tool_data.py
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Synthetic Tool-Calling Training Data Generator for Stack 2.9
|
| 4 |
+
Generates training examples in Qwen2.5-Coder format with tool_calls.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import random
|
| 9 |
+
import argparse
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, List, Any
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
|
| 14 |
+
# ============================================================================
|
| 15 |
+
# Tool Definitions (Qwen2.5-Coder format)
|
| 16 |
+
# ============================================================================
|
| 17 |
+
|
| 18 |
+
TOOL_DEFINITIONS = [
|
| 19 |
+
{
|
| 20 |
+
"type": "function",
|
| 21 |
+
"function": {
|
| 22 |
+
"name": "Bash",
|
| 23 |
+
"description": "Execute bash commands in the terminal. Use for running shell commands, scripts, git operations, package managers, and system commands.",
|
| 24 |
+
"parameters": {
|
| 25 |
+
"type": "object",
|
| 26 |
+
"properties": {
|
| 27 |
+
"command": {
|
| 28 |
+
"type": "string",
|
| 29 |
+
"description": "The bash command to execute"
|
| 30 |
+
},
|
| 31 |
+
"timeout": {
|
| 32 |
+
"type": "integer",
|
| 33 |
+
"description": "Timeout in seconds (default: 30)"
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"required": ["command"]
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"type": "function",
|
| 42 |
+
"function": {
|
| 43 |
+
"name": "FileRead",
|
| 44 |
+
"description": "Read the contents of a file from the filesystem. Use for viewing source code, configuration files, documentation, or any text-based files.",
|
| 45 |
+
"parameters": {
|
| 46 |
+
"type": "object",
|
| 47 |
+
"properties": {
|
| 48 |
+
"path": {
|
| 49 |
+
"type": "string",
|
| 50 |
+
"description": "Path to the file to read"
|
| 51 |
+
},
|
| 52 |
+
"offset": {
|
| 53 |
+
"type": "integer",
|
| 54 |
+
"description": "Line number to start reading from (1-indexed)"
|
| 55 |
+
},
|
| 56 |
+
"limit": {
|
| 57 |
+
"type": "integer",
|
| 58 |
+
"description": "Maximum number of lines to read"
|
| 59 |
+
}
|
| 60 |
+
},
|
| 61 |
+
"required": ["path"]
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"type": "function",
|
| 67 |
+
"function": {
|
| 68 |
+
"name": "FileWrite",
|
| 69 |
+
"description": "Create or overwrite a file with content. Use for creating new files, updating existing files, or writing code, configuration, or documentation.",
|
| 70 |
+
"parameters": {
|
| 71 |
+
"type": "object",
|
| 72 |
+
"properties": {
|
| 73 |
+
"path": {
|
| 74 |
+
"type": "string",
|
| 75 |
+
"description": "Path where the file should be created or written"
|
| 76 |
+
},
|
| 77 |
+
"content": {
|
| 78 |
+
"type": "string",
|
| 79 |
+
"description": "The content to write to the file"
|
| 80 |
+
},
|
| 81 |
+
"append": {
|
| 82 |
+
"type": "boolean",
|
| 83 |
+
"description": "Append to existing file instead of overwriting (default: false)"
|
| 84 |
+
}
|
| 85 |
+
},
|
| 86 |
+
"required": ["path", "content"]
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"type": "function",
|
| 92 |
+
"function": {
|
| 93 |
+
"name": "WebSearch",
|
| 94 |
+
"description": "Search the web for information. Use for finding documentation, looking up error messages, researching libraries, or getting up-to-date information.",
|
| 95 |
+
"parameters": {
|
| 96 |
+
"type": "object",
|
| 97 |
+
"properties": {
|
| 98 |
+
"query": {
|
| 99 |
+
"type": "string",
|
| 100 |
+
"description": "The search query to look up on the web"
|
| 101 |
+
},
|
| 102 |
+
"count": {
|
| 103 |
+
"type": "integer",
|
| 104 |
+
"description": "Number of results to return (default: 5)"
|
| 105 |
+
}
|
| 106 |
+
},
|
| 107 |
+
"required": ["query"]
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"type": "function",
|
| 113 |
+
"function": {
|
| 114 |
+
"name": "Grep",
|
| 115 |
+
"description": "Search for patterns in files. Use for finding specific code, function definitions, imports, TODO comments, error patterns, or any text across the codebase.",
|
| 116 |
+
"parameters": {
|
| 117 |
+
"type": "object",
|
| 118 |
+
"properties": {
|
| 119 |
+
"pattern": {
|
| 120 |
+
"type": "string",
|
| 121 |
+
"description": "The search pattern or regex to match"
|
| 122 |
+
},
|
| 123 |
+
"path": {
|
| 124 |
+
"type": "string",
|
| 125 |
+
"description": "Directory or file path to search in (default: current directory)"
|
| 126 |
+
},
|
| 127 |
+
"recursive": {
|
| 128 |
+
"type": "boolean",
|
| 129 |
+
"description": "Search recursively in subdirectories (default: true)"
|
| 130 |
+
},
|
| 131 |
+
"file_pattern": {
|
| 132 |
+
"type": "string",
|
| 133 |
+
"description": "File pattern to filter results (e.g., '*.py', '*.js')"
|
| 134 |
+
}
|
| 135 |
+
},
|
| 136 |
+
"required": ["pattern"]
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
# ============================================================================
|
| 143 |
+
# Template Data for Generation
|
| 144 |
+
# ============================================================================
|
| 145 |
+
|
| 146 |
+
FILE_PATHS = [
|
| 147 |
+
"src/main.py", "src/utils.py", "src/config.py", "src/models.py",
|
| 148 |
+
"src/api.py", "src/handlers.py", "src/middleware.py",
|
| 149 |
+
"tests/test_main.py", "tests/test_utils.py", "tests/conftest.py",
|
| 150 |
+
"README.md", "LICENSE", "package.json", "requirements.txt",
|
| 151 |
+
"config.yaml", "config.json", ".env.example",
|
| 152 |
+
"src/components/Button.tsx", "src/components/Header.jsx",
|
| 153 |
+
"src/styles.css", "src/index.js", "src/app.js",
|
| 154 |
+
"docs/API.md", "docs/ARCHITECTURE.md", "docs/CONTRIBUTING.md",
|
| 155 |
+
"scripts/setup.sh", "scripts/deploy.py", "Makefile"
|
| 156 |
+
]
|
| 157 |
+
|
| 158 |
+
CODE_SNIPPETS = {
|
| 159 |
+
"python": [
|
| 160 |
+
"def hello():\n print('Hello, World!')",
|
| 161 |
+
"class MyClass:\n def __init__(self):\n self.value = 42",
|
| 162 |
+
"import os\nos.path.join('a', 'b')",
|
| 163 |
+
"async def fetch_data():\n async with aiohttp.ClientSession() as session:\n return await session.get(url)",
|
| 164 |
+
],
|
| 165 |
+
"javascript": [
|
| 166 |
+
"const fetch = require('node-fetch');\nconst data = await fetch(url);",
|
| 167 |
+
"function handleClick() {\n setCount(count + 1);\n}",
|
| 168 |
+
"export default function App() {\n return <div>Hello</div>;\n}",
|
| 169 |
+
"const [state, setState] = useState(null);",
|
| 170 |
+
],
|
| 171 |
+
"bash": [
|
| 172 |
+
"npm install",
|
| 173 |
+
"git status",
|
| 174 |
+
"pytest -v",
|
| 175 |
+
"python -m pytest tests/",
|
| 176 |
+
"make build",
|
| 177 |
+
"docker build -t myapp .",
|
| 178 |
+
"ls -la",
|
| 179 |
+
"curl -X GET https://api.example.com",
|
| 180 |
+
]
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
WEB_SEARCH_QUERIES = [
|
| 184 |
+
"python async await best practices",
|
| 185 |
+
"javascript array methods map filter reduce",
|
| 186 |
+
"TypeScript generics tutorial",
|
| 187 |
+
"React hooks useEffect dependency array",
|
| 188 |
+
"Node.js error handling best practices",
|
| 189 |
+
"Docker vs Kubernetes differences",
|
| 190 |
+
"Git rebase vs merge",
|
| 191 |
+
"SQL join types explained",
|
| 192 |
+
"REST API design principles",
|
| 193 |
+
"Python list comprehension examples",
|
| 194 |
+
"JavaScript promise async await",
|
| 195 |
+
"CSS flexbox vs grid",
|
| 196 |
+
"MongoDB vs PostgreSQL",
|
| 197 |
+
"AWS Lambda cold start optimization",
|
| 198 |
+
"Docker compose tutorial",
|
| 199 |
+
]
|
| 200 |
+
|
| 201 |
+
GREP_PATTERNS = [
|
| 202 |
+
("TODO", "*.py", "Find TODO comments in Python files"),
|
| 203 |
+
("function ", "*.js", "Find function definitions in JS"),
|
| 204 |
+
("import ", "*.py", "Find imports in Python"),
|
| 205 |
+
("class ", "*.ts", "Find class definitions in TypeScript"),
|
| 206 |
+
("def ", "*.py", "Find function definitions in Python"),
|
| 207 |
+
("const ", "*.js", "Find const declarations in JS"),
|
| 208 |
+
("interface ", "*.ts", "Find interface definitions"),
|
| 209 |
+
("#pragma", "*.c", "Find pragma directives in C"),
|
| 210 |
+
("SELECT ", "*.sql", "Find SQL queries"),
|
| 211 |
+
("TODO", "*.md", "Find TODOs in documentation"),
|
| 212 |
+
]
|
| 213 |
+
|
| 214 |
+
# ============================================================================
|
| 215 |
+
# Scenario Templates
|
| 216 |
+
# ============================================================================
|
| 217 |
+
|
| 218 |
+
SCENARIOS = {
|
| 219 |
+
"FileRead": [
|
| 220 |
+
{
|
| 221 |
+
"user": "Read the file at src/main.py and show me what's in it",
|
| 222 |
+
"tool_call": {"name": "FileRead", "arguments": {"path": "src/main.py"}},
|
| 223 |
+
"result": "Successfully read file: src/main.py\n```python\ndef main():\n print('Hello from Stack 2.9!')\n\nif __name__ == '__main__':\n main()\n```",
|
| 224 |
+
"response": "Here's the contents of `src/main.py`:\n\n```python\ndef main():\n print('Hello from Stack 2.9!')\n\nif __name__ == '__main__':\n main()\n```\n\nThis is a simple Python script that prints a greeting when executed directly."
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"user": "Can you show me the README.md file?",
|
| 228 |
+
"tool_call": {"name": "FileRead", "arguments": {"path": "README.md"}},
|
| 229 |
+
"result": "Successfully read file: README.md\n```markdown\n# My Project\n\nA sample project for Stack 2.9.\n\n## Installation\n\npip install -r requirements.txt\n```",
|
| 230 |
+
"response": "Here's the README.md:\n\n```markdown\n# My Project\n\nA sample project for Stack 2.9.\n\n## Installation\n\npip install -r requirements.txt\n```"
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
"user": "What's in the config.yaml file?",
|
| 234 |
+
"tool_call": {"name": "FileRead", "arguments": {"path": "config.yaml"}},
|
| 235 |
+
"result": "Successfully read file: config.yaml\n```yaml\ndatabase:\n host: localhost\n port: 5432\n name: myapp\n\napp:\n port: 3000\n debug: true\n```",
|
| 236 |
+
"response": "The config.yaml contains:\n\n```yaml\ndatabase:\n host: localhost\n port: 5432\n name: myapp\n\napp:\n port: 3000\n debug: true\n```\n\nIt has database connection settings and app configuration."
|
| 237 |
+
},
|
| 238 |
+
{
|
| 239 |
+
"user": "Check the package.json to see the dependencies",
|
| 240 |
+
"tool_call": {"name": "FileRead", "arguments": {"path": "package.json"}},
|
| 241 |
+
"result": "Successfully read file: package.json\n```json\n{\n \"name\": \"my-app\",\n \"version\": \"1.0.0\",\n \"dependencies\": {\n \"express\": \"^4.18.0\",\n \"react\": \"^18.2.0\"\n }\n}\n```",
|
| 242 |
+
"response": "The package.json shows these dependencies:\n\n```json\n{\n \"name\": \"my-app\",\n \"version\": \"1.0.0\",\n \"dependencies\": {\n \"express\": \"^4.18.0\",\n \"react\": \"^18.2.0\"\n }\n}\n```"
|
| 243 |
+
},
|
| 244 |
+
],
|
| 245 |
+
"FileWrite": [
|
| 246 |
+
{
|
| 247 |
+
"user": "Create a new file called hello.py with a hello world function",
|
| 248 |
+
"tool_call": {"name": "FileWrite", "arguments": {"path": "hello.py", "content": "def hello():\n print('Hello, World!')\n\nif __name__ == '__main__':\n hello()\n"}},
|
| 249 |
+
"result": "Successfully wrote file: hello.py",
            "response": "Created `hello.py` with a hello world function:\n\n```python\ndef hello():\n    print('Hello, World!')\n\nif __name__ == '__main__':\n    hello()\n```\n\nYou can run it with `python hello.py`."
        },
        {
            "user": "Write a simple React component to src/components/Button.jsx",
            "tool_call": {"name": "FileWrite", "arguments": {"path": "src/components/Button.jsx", "content": "import React from 'react';\n\nexport default function Button({ onClick, children }) {\n  return (\n    <button onClick={onClick} className=\"btn\">\n      {children}\n    </button>\n  );\n}\n"}},
            "result": "Successfully wrote file: src/components/Button.jsx",
            "response": "Created `src/components/Button.jsx`:\n\n```jsx\nimport React from 'react';\n\nexport default function Button({ onClick, children }) {\n  return (\n    <button onClick={onClick} className=\"btn\">\n      {children}\n    </button>\n  );\n}\n```"
        },
        {
            "user": "Create a config.json file with some settings",
            "tool_call": {"name": "FileWrite", "arguments": {"path": "config.json", "content": "{\n  \"app\": {\n    \"name\": \"MyApp\",\n    \"version\": \"1.0.0\"\n  },\n  \"server\": {\n    \"port\": 8080,\n    \"host\": \"localhost\"\n  }\n}\n"}},
            "result": "Successfully wrote file: config.json",
            "response": "Created `config.json`:\n\n```json\n{\n  \"app\": {\n    \"name\": \"MyApp\",\n    \"version\": \"1.0.0\"\n  },\n  \"server\": {\n    \"port\": 8080,\n    \"host\": \"localhost\"\n  }\n}\n```"
        },
    ],
    "Bash": [
        {
            "user": "Run npm install to install dependencies",
            "tool_call": {"name": "Bash", "arguments": {"command": "npm install"}},
            "result": "added 245 packages in 3.5s",
            "response": "Dependencies installed successfully! Added 245 packages."
        },
        {
            "user": "Check git status",
            "tool_call": {"name": "Bash", "arguments": {"command": "git status"}},
            "result": "On branch main\nYour branch is up to date with 'origin/main'.\n\nChanges not staged for commit:\n  modified:   README.md\n  modified:   src/main.py\n\nno changes added to commit (but untracked files present)",
            "response": "Git status shows:\n- Modified: README.md, src/main.py (not staged)\n- Untracked files present"
        },
        {
            "user": "Run the tests with pytest",
            "tool_call": {"name": "Bash", "arguments": {"command": "pytest -v"}},
            "result": "tests/test_main.py::test_hello PASSED\ntests/test_utils.py::test_add PASSED\n\n2 passed in 0.5s",
            "response": "All tests passed!\n- `tests/test_main.py::test_hello` ✓\n- `tests/test_utils.py::test_add` ✓"
        },
        {
            "user": "List all files in the current directory",
            "tool_call": {"name": "Bash", "arguments": {"command": "ls -la"}},
            "result": "total 64\ndrwxr-xr-x  12 user  staff  384 Apr  6 00:58 .\ndrwxr-xr-x   9 user  staff  288 Apr  6 00:50 ..\n-rw-r--r--   1 user  staff  220 Apr  6 00:50 README.md\ndrwxr-xr-x   3 user  staff   96 Apr  6 00:51 src\ndrwxr-xr-x   3 user  staff   96 Apr  6 00:51 tests\n...",
            "response": "Files in current directory:\n- README.md\n- src/\n- tests/\n- And more (permissions, dates shown)"
        },
        {
            "user": "Build the project with make",
            "tool_call": {"name": "Bash", "arguments": {"command": "make build"}},
            "result": "Building...\ngcc -o main main.c\nBuild successful!",
            "response": "Build completed successfully!"
        },
    ],
    "WebSearch": [
        {
            "user": "How do I use async/await in Python?",
            "tool_call": {"name": "WebSearch", "arguments": {"query": "python async await best practices", "count": 5}},
            "result": "Search results for 'python async await best practices':\n1. Real Python - Async IO in Python: A Complete Walkthrough\n2. Python Docs - asyncio - Asynchronous I/O\n3. Towards Data Science - Mastering Python Async/Await\n4. Stack Overflow - Common async/await mistakes\n5. Medium - Async/Await Patterns in Python",
            "response": "Here are some resources on Python async/await:\n\n1. **Real Python** - Async IO in Python: A Complete Walkthrough\n2. **Python Docs** - asyncio - Asynchronous I/O\n3. **Towards Data Science** - Mastering Python Async/Await\n\nKey points:\n- Use `async def` to define async functions\n- Use `await` to call other async functions\n- `asyncio.run()` to execute the main coroutine"
        },
        {
            "user": "What's the difference between Docker and Kubernetes?",
            "tool_call": {"name": "WebSearch", "arguments": {"query": "Docker vs Kubernetes differences", "count": 5}},
            "result": "Search results for 'Docker vs Kubernetes differences':\n1. Docker Blog - Containers 101\n2. Kubernetes.io - What is Kubernetes\n3. Red Hat - Docker vs Kubernetes explained\n4. DevOps.com - Container orchestration comparison\n5. Container Journal - When to use each",
            "response": "**Docker** is a container platform - it packages applications into containers.\n\n**Kubernetes** is a container orchestrator - it manages multiple containers across machines.\n\nThink of Docker as individual shipping containers, and Kubernetes as the port management system that coordinates all of them."
        },
        {
            "user": "How do I use React hooks properly?",
            "tool_call": {"name": "WebSearch", "arguments": {"query": "React hooks useEffect dependency array", "count": 5}},
            "result": "Search results for 'React hooks useEffect dependency array':\n1. React Docs - Synchronizing with Effects\n2. Kent C. Dodds - A Complete Guide to useEffect\n3. Robin Wieruch - React useEffect Hook Tutorial\n4. Stack Overflow - Common useEffect mistakes\n5. LogRocket - useEffect dependency array deep dive",
            "response": "**React Hooks Guide:**\n\n1. **useState** - For state management\n2. **useEffect** - For side effects, takes dependency array\n3. **useContext** - For global state\n\nThe dependency array controls when the effect runs. Empty `[]` = once on mount. Include variables to re-run when they change."
        },
    ],
    "Grep": [
        {
            "user": "Find all TODO comments in the Python files",
            "tool_call": {"name": "Grep", "arguments": {"pattern": "TODO", "path": ".", "recursive": True, "file_pattern": "*.py"}},
            "result": "Found 3 matches:\nsrc/main.py:14:  # TODO: Add error handling\nsrc/utils.py:22:  # TODO: Implement caching\ntests/test_main.py:8:  # TODO: Add more tests",
            "response": "Found 3 TODO comments in Python files:\n- `src/main.py:14` - Add error handling\n- `src/utils.py:22` - Implement caching\n- `tests/test_main.py:8` - Add more tests"
        },
        {
            "user": "Search for function definitions in the JavaScript files",
            "tool_call": {"name": "Grep", "arguments": {"pattern": "function ", "path": "src", "recursive": True, "file_pattern": "*.js"}},
            "result": "Found 5 matches:\nsrc/index.js:3:  function setup() {\nsrc/handlers.js:12:  function handleClick(event) {\nsrc/handlers.js:25:  function handleSubmit(form) {\nsrc/utils.js:8:  function formatDate(date) {\nsrc/utils.js:15:  function parseJSON(str) {",
            "response": "Found 5 function definitions in src/:\n- `src/index.js:3` - setup()\n- `src/handlers.js:12` - handleClick(event)\n- `src/handlers.js:25` - handleSubmit(form)\n- `src/utils.js:8` - formatDate(date)\n- `src/utils.js:15` - parseJSON(str)"
        },
        {
            "user": "Find all imports in the Python codebase",
            "tool_call": {"name": "Grep", "arguments": {"pattern": "import ", "path": ".", "recursive": True, "file_pattern": "*.py"}},
            "result": "Found 12 matches:\nsrc/main.py:1:  import os\nsrc/main.py:2:  import sys\nsrc/main.py:3:  import json\nsrc/utils.py:1:  from typing import Dict\nsrc/api.py:5:  import requests\n...",
            "response": "Found 12 import statements across Python files. Main imports:\n- `os`, `sys`, `json` in src/main.py\n- `typing.Dict` in src/utils.py\n- `requests` in src/api.py"
        },
        {
            "user": "Look for class definitions in TypeScript files",
            "tool_call": {"name": "Grep", "arguments": {"pattern": "class ", "path": "src", "recursive": True, "file_pattern": "*.ts"}},
            "result": "Found 2 matches:\nsrc/models/User.ts:3:  class User {\nsrc/models/Config.ts:8:  class Config {",
            "response": "Found 2 class definitions:\n- `src/models/User.ts:3` - User class\n- `src/models/Config.ts:8` - Config class"
        },
    ],
}

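Every entry in a scenario table like the one above has to carry the same four keys (`user`, `tool_call`, `result`, `response`), and the `tool_call` name must match the tool it is filed under, or the generator below produces malformed conversations. A minimal standalone sanity check one might run over such a table (the miniature `scenarios` dict here is illustrative, not the full table):

```python
# Structural sanity check for a scenario table shaped like SCENARIOS above.
# The `scenarios` dict below is a tiny illustrative stand-in.
REQUIRED_KEYS = {"user", "tool_call", "result", "response"}

scenarios = {
    "Bash": [
        {
            "user": "Check git status",
            "tool_call": {"name": "Bash", "arguments": {"command": "git status"}},
            "result": "On branch main",
            "response": "You are on branch main with a clean tree.",
        },
    ],
}

def validate(table):
    """Return a list of (tool, index, problem) tuples; empty means well-formed."""
    problems = []
    for tool, entries in table.items():
        for i, entry in enumerate(entries):
            missing = REQUIRED_KEYS - entry.keys()
            if missing:
                problems.append((tool, i, missing))
            elif entry["tool_call"]["name"] != tool:
                problems.append((tool, i, {"tool_call name mismatch"}))
    return problems

print(validate(scenarios))  # [] when the table is well-formed
```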
# ============================================================================
# Data Augmentation Helpers
# ============================================================================

def vary_file_paths(template: Dict) -> Dict:
    """Vary file paths in a scenario to create diversity."""
    import copy
    result = copy.deepcopy(template)

    # Swap the path mentioned in the user message and the one in the tool
    # call for the same random replacement so the two stay consistent.
    if "path" in result["tool_call"]["arguments"]:
        for old_path in FILE_PATHS:
            if old_path in result["user"]:
                new_path = random.choice(FILE_PATHS)
                result["user"] = result["user"].replace(old_path, new_path)
                result["tool_call"]["arguments"]["path"] = new_path
                break

    return result


def vary_content(template: Dict) -> Dict:
    """Vary content in a scenario."""
    import copy
    result = copy.deepcopy(template)

    if "content" in result["tool_call"]["arguments"]:
        # Vary code snippets
        lang = random.choice(["python", "javascript"])
        result["tool_call"]["arguments"]["content"] = random.choice(CODE_SNIPPETS[lang])

    return result


def vary_bash_command(template: Dict) -> Dict:
    """Vary bash commands."""
    import copy
    result = copy.deepcopy(template)

    if "command" in result["tool_call"]["arguments"]:
        command = result["tool_call"]["arguments"]["command"]
        original = command.split()[0] if command else ""

        if "npm" in original:
            commands = ["npm install", "npm run build", "npm test", "npm start"]
        elif "git" in original:
            commands = ["git status", "git log --oneline -5", "git diff", "git branch -a"]
        elif "pytest" in original:
            commands = ["pytest -v", "pytest tests/", "pytest -xvs", "pytest --cov"]
        elif "ls" in original:
            commands = ["ls -la", "ls -1", "ls -lah"]
        elif "make" in original:
            commands = ["make build", "make clean", "make test", "make install"]
        else:
            commands = ["echo 'hello'", "pwd", "whoami", "date"]

        result["tool_call"]["arguments"]["command"] = random.choice(commands)

    return result


def vary_search_query(template: Dict) -> Dict:
    """Vary web search queries."""
    import copy
    result = copy.deepcopy(template)

    if "query" in result["tool_call"]["arguments"]:
        result["tool_call"]["arguments"]["query"] = random.choice(WEB_SEARCH_QUERIES)

    return result


def vary_grep_pattern(template: Dict) -> Dict:
    """Vary grep patterns."""
    import copy
    result = copy.deepcopy(template)

    pattern, file_pattern, _ = random.choice(GREP_PATTERNS)
    result["tool_call"]["arguments"]["pattern"] = pattern
    result["tool_call"]["arguments"]["file_pattern"] = file_pattern

    return result

# ============================================================================
# Main Generation Functions
# ============================================================================

def create_tool_call_message(tool_call: Dict, tool_call_id: str) -> Dict:
    """Create a tool_calls message in Qwen format."""
    return {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": tool_call_id,
                "type": "function",
                "function": {
                    "name": tool_call["name"],
                    "arguments": json.dumps(tool_call["arguments"])
                }
            }
        ]
    }


def create_tool_message(role: str, tool_call_id: str, tool_name: str, result: str) -> Dict:
    """Create a tool message (result of tool execution)."""
    return {
        "role": role,  # typically "tool"
        "content": result,
        "tool_call_id": tool_call_id,
        "name": tool_name
    }


def generate_example(scenario: Dict, system_prompt: str = None) -> Dict:
    """Generate a single training example in Qwen2.5-Coder format."""
    if system_prompt is None:
        system_prompt = "You are a helpful AI assistant that can use tools to help users solve problems. When you need to perform actions like reading files, running commands, searching the web, or searching code, use the appropriate tool."

    tool_call_id = f"call_{random.randint(1000, 9999)}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": scenario["user"]},
        create_tool_call_message(scenario["tool_call"], tool_call_id),
        create_tool_message("tool", tool_call_id, scenario["tool_call"]["name"], scenario["result"]),
        {"role": "assistant", "content": scenario["response"]}
    ]

    return {
        "messages": messages,
        "tools": TOOL_DEFINITIONS
    }

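Each generated record carries a fixed five-message conversation: system, user, an assistant turn holding the tool call, the tool result, and the final assistant reply. A minimal standalone sketch of that shape, with illustrative id and contents:

```python
# Sketch of the five-message Qwen-style tool-call conversation each training
# example contains. The id and contents here are illustrative placeholders.
import json

tool_call_id = "call_1234"
messages = [
    {"role": "system", "content": "You are a helpful AI assistant that can use tools."},
    {"role": "user", "content": "Check git status"},
    {   # assistant turn: no text, only the structured tool call
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": tool_call_id,
            "type": "function",
            "function": {"name": "Bash",
                         "arguments": json.dumps({"command": "git status"})},
        }],
    },
    {   # tool turn: execution result, linked back via tool_call_id
        "role": "tool",
        "content": "On branch main",
        "tool_call_id": tool_call_id,
        "name": "Bash",
    },
    {"role": "assistant", "content": "You are on branch main."},
]

roles = [m["role"] for m in messages]
print(roles)  # ['system', 'user', 'assistant', 'tool', 'assistant']
```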
def augment_scenario(scenario: Dict, tool_name: str) -> Dict:
    """Apply random augmentations to a scenario."""
    augmented = scenario.copy()

    if tool_name == "FileRead":
        augmented = vary_file_paths(augmented)
    elif tool_name == "FileWrite":
        augmented = vary_file_paths(augmented)
        augmented = vary_content(augmented)
    elif tool_name == "Bash":
        augmented = vary_bash_command(augmented)
    elif tool_name == "WebSearch":
        augmented = vary_search_query(augmented)
    elif tool_name == "Grep":
        augmented = vary_grep_pattern(augmented)

    return augmented


def generate_dataset(num_examples: int = 1000, output_path: str = None) -> List[Dict]:
    """Generate the complete dataset."""
    examples = []
    tools = list(SCENARIOS.keys())

    # Split examples evenly across tools, distributing any remainder
    examples_per_tool = num_examples // len(tools)
    remainder = num_examples % len(tools)

    for i, tool_name in enumerate(tools):
        # Determine how many examples for this tool
        count = examples_per_tool + (1 if i < remainder else 0)

        base_scenarios = SCENARIOS[tool_name]

        for j in range(count):
            # Cycle through the base scenarios
            base = base_scenarios[j % len(base_scenarios)]

            # Once every base scenario has been used verbatim, augment for variety
            if j >= len(base_scenarios):
                scenario = augment_scenario(base, tool_name)
            else:
                scenario = base

            example = generate_example(scenario)
            examples.append(example)

    # Shuffle for better training
    random.shuffle(examples)

    return examples


def save_jsonl(examples: List[Dict], output_path: str):
    """Save examples to JSONL format."""
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as f:
        for example in examples:
            f.write(json.dumps(example, ensure_ascii=False) + '\n')


def save_json(examples: List[Dict], output_path: str):
    """Save examples to JSON format."""
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(examples, f, ensure_ascii=False, indent=2)

def main():
    parser = argparse.ArgumentParser(description="Generate synthetic tool-calling training data")
    parser.add_argument("--num-examples", type=int, default=1000, help="Number of examples to generate")
    parser.add_argument("--output-dir", type=str, default="training-data", help="Output directory")
    parser.add_argument("--output-format", choices=["jsonl", "json", "both"], default="jsonl", help="Output format")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    # Set seed for reproducibility
    random.seed(args.seed)

    print(f"🎯 Generating {args.num_examples} tool-calling training examples...")
    print(f"   Output directory: {args.output_dir}")
    print(f"   Format: {args.output_format}")
    print()

    # Generate dataset
    examples = generate_dataset(args.num_examples)

    output_dir = Path(args.output_dir)

    # Save based on format
    if args.output_format in ["jsonl", "both"]:
        jsonl_path = output_dir / "tool_examples.jsonl"
        save_jsonl(examples, str(jsonl_path))
        print(f"✅ Saved JSONL: {jsonl_path}")

    if args.output_format in ["json", "both"]:
        json_path = output_dir / "tool_examples.json"
        save_json(examples, str(json_path))
        print(f"✅ Saved JSON: {json_path}")

    # Statistics
    print("\n📊 Statistics:")
    print(f"   Total examples: {len(examples)}")

    # Count by tool
    tool_counts = {}
    for ex in examples:
        for msg in ex["messages"]:
            if msg.get("tool_calls"):
                tool_name = msg["tool_calls"][0]["function"]["name"]
                tool_counts[tool_name] = tool_counts.get(tool_name, 0) + 1

    print("   Examples by tool:")
    for tool, count in sorted(tool_counts.items(), key=lambda x: x[1], reverse=True):
        print(f"   - {tool}: {count}")

    # Show sample
    print("\n📝 Sample example (first in dataset):")
    sample = examples[0]
    print(f"   Tools defined: {len(sample['tools'])}")
    print(f"   Messages: {len(sample['messages'])}")
    print(f"   First user message: {sample['messages'][1]['content'][:60]}...")

    print("\n✨ Generation complete!")


if __name__ == "__main__":
    main()
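The JSONL written by `save_jsonl` above is one JSON object per line, each with `messages` and `tools` keys, and can be round-tripped as below. The sample record and temporary path are illustrative, not the real training file:

```python
# Round-trip check for the one-object-per-line JSONL layout used above.
# The example record and temp path are illustrative placeholders.
import json
import tempfile
from pathlib import Path

example = {
    "messages": [{"role": "user", "content": "Check git status"}],
    "tools": [],
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "tool_examples.jsonl"
    # Write: one json.dumps() per line, newline-terminated
    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps(example, ensure_ascii=False) + "\n")

    # Read: json.loads() per line recovers the original records
    with open(path, encoding="utf-8") as f:
        loaded = [json.loads(line) for line in f]

assert loaded == [example]
print(len(loaded))  # 1
```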
test_model.py
CHANGED

@@ -1,10 +1,17 @@
 #!/usr/bin/env python3
 """
-
-Tests
 """

 import argparse
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -12,141 +19,583 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 def load_model(model_path: str):
     """Load the fine-tuned model and tokenizer."""
     print(f"Loading model from: {model_path}")
-
     model = AutoModelForCausalLM.from_pretrained(
         model_path,
         torch_dtype=torch.float16,
         device_map="auto",
         low_cpu_mem_usage=True,
     )
     return model, tokenizer


-def
-
-
     inputs = {k: v.to(model.device) for k, v in inputs.items()}

-
-
-
-
-
-
-
-
-
-
-
-    # Remove the prompt from the completion
-    if completion.startswith(prompt):
-        completion = completion[len(prompt):].strip()
-    return completion
-
-
-def run_tests(model_path: str):
-    """Run all code generation tests."""
-    model, tokenizer = load_model(model_path)
-    model.eval()

-
-
-
-
-
-
-        },
-        {
-            "name": "Binary Search",
-            "prompt": "def binary_search(arr, target):",
-            "max_tokens": 100,
-            "expected_keywords": ["while", "left", "right", "mid"],
-        },
-        {
-            "name": "Fibonacci",
-            "prompt": "def fibonacci(n):",
-            "max_tokens": 80,
-            "expected_keywords": ["return", "if", "else", "fib"],
-        },
-        {
-            "name": "Factorial",
-            "prompt": "def factorial(n):",
-            "max_tokens": 60,
-            "expected_keywords": ["return", "if", "*"],
-        },
-        {
-            "name": "Is Prime",
-            "prompt": "def is_prime(n):",
-            "max_tokens": 80,
-            "expected_keywords": ["if", "return", "for", "%"],
-        },
-        {
-            "name": "List Sum",
-            "prompt": "def list_sum(nums):",
-            "max_tokens": 50,
-            "expected_keywords": ["return", "sum", "+"],
-        },
-        {
-            "name": "Merge Sort",
-            "prompt": "def merge_sort(arr):",
-            "max_tokens": 150,
-            "expected_keywords": ["if", "len", "return", "merge"],
-        },
-        {
-            "name": "Quick Sort",
-            "prompt": "def quick_sort(arr):",
-            "max_tokens": 150,
-            "expected_keywords": ["if", "len", "return", "pivot"],
-        },
-    ]

-
-
-

     try:
-
-
-            test['prompt'],
-            test['max_tokens']
-        )
-        print(f"Completion:\n{completion[:300]}")
-
-        # Check for expected keywords
-        keywords_found = sum(1 for kw in test['expected_keywords'] if kw.lower() in completion.lower())
-        if keywords_found >= len(test['expected_keywords']) // 2:
-            print(f"✅ PASS (found {keywords_found}/{len(test['expected_keywords'])} keywords)")
-            passed += 1
         else:
-
-            passed += 1  # Still count as pass if some keywords found
-        print()
-
     except Exception as e:
-
-
-

-
-
     print("="*60)

-


 def main():
-    parser = argparse.ArgumentParser(
-
     args = parser.parse_args()

-


 if __name__ == "__main__":
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Basic Code Generation Tests for Stack 2.9 Model
|
| 4 |
+
Tests common algorithms and data structures.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python test_model.py --model-path /path/to/merged/model
|
| 8 |
+
python test_model.py --model-path /path/to/merged/model --output test_results.json
|
| 9 |
"""
|
| 10 |
|
| 11 |
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import time
|
| 14 |
+
from typing import List, Dict, Tuple, Optional
|
| 15 |
import torch
|
| 16 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 17 |
|
|
|
|
| 19 |
def load_model(model_path: str):
|
| 20 |
"""Load the fine-tuned model and tokenizer."""
|
| 21 |
print(f"Loading model from: {model_path}")
|
| 22 |
+
|
| 23 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 24 |
+
if tokenizer.pad_token is None:
|
| 25 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 26 |
+
|
| 27 |
model = AutoModelForCausalLM.from_pretrained(
|
| 28 |
model_path,
|
| 29 |
torch_dtype=torch.float16,
|
| 30 |
device_map="auto",
|
| 31 |
low_cpu_mem_usage=True,
|
| 32 |
+
trust_remote_code=True,
|
| 33 |
)
|
| 34 |
+
|
| 35 |
return model, tokenizer
|
| 36 |
|
| 37 |
|
| 38 |
+
def generate_completion(
|
| 39 |
+
model,
|
| 40 |
+
tokenizer,
|
| 41 |
+
prompt: str,
|
| 42 |
+
max_new_tokens: int = 128,
|
| 43 |
+
temperature: float = 0.2,
|
| 44 |
+
num_samples: int = 1
|
| 45 |
+
) -> List[str]:
|
| 46 |
+
"""Generate code completion(s) for a prompt."""
|
| 47 |
+
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
|
| 48 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 49 |
|
| 50 |
+
outputs = model.generate(
|
| 51 |
+
**inputs,
|
| 52 |
+
max_new_tokens=max_new_tokens,
|
| 53 |
+
temperature=temperature,
|
| 54 |
+
top_p=0.95,
|
| 55 |
+
do_sample=True,
|
| 56 |
+
repetition_penalty=1.1,
|
| 57 |
+
num_return_sequences=num_samples,
|
| 58 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 59 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 60 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
completions = []
|
| 63 |
+
for output in outputs:
|
| 64 |
+
text = tokenizer.decode(output, skip_special_tokens=True)
|
| 65 |
+
if text.startswith(prompt):
|
| 66 |
+
text = text[len(prompt):]
|
| 67 |
+
completions.append(text.strip())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
return completions
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def extract_code(completion: str) -> str:
|
| 73 |
+
"""Extract code from completion, handling markdown code blocks."""
|
| 74 |
+
# Try ```python blocks first
|
| 75 |
+
if "```python" in completion:
|
| 76 |
+
start = completion.find("```python") + len("```python")
|
| 77 |
+
end = completion.find("```", start)
|
| 78 |
+
if end != -1:
|
| 79 |
+
return completion[start:end].strip()
|
| 80 |
+
|
| 81 |
+
# Try generic ``` blocks
|
| 82 |
+
if "```" in completion:
|
| 83 |
+
start = completion.find("```") + len("```")
|
| 84 |
+
# Skip language identifier if present
|
| 85 |
+
if completion[start:start+10].strip():
|
| 86 |
+
start = completion.find("\n", start) + 1
|
| 87 |
+
end = completion.find("```", start)
|
| 88 |
+
if end != -1:
|
| 89 |
+
return completion[start:end].strip()
|
| 90 |
+
|
| 91 |
+
return completion.strip()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def execute_code(code: str, timeout: int = 5) -> Tuple[bool, str, Optional[any]]:
|
| 95 |
+
"""Safely execute code and return (success, error_msg, result)."""
|
| 96 |
+
import signal
|
| 97 |
+
|
| 98 |
+
class TimeoutError(Exception):
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
def timeout_handler(signum, frame):
|
| 102 |
+
raise TimeoutError("Execution timed out")
|
| 103 |
+
|
| 104 |
+
safe_builtins = {
|
| 105 |
+
'print': print,
|
| 106 |
+
'len': len,
|
| 107 |
+
'range': range,
|
| 108 |
+
'str': str,
|
| 109 |
+
'int': int,
|
| 110 |
+
'float': float,
|
| 111 |
+
'bool': bool,
|
| 112 |
+
'list': list,
|
| 113 |
+
'dict': dict,
|
| 114 |
+
'set': set,
|
| 115 |
+
'tuple': tuple,
|
| 116 |
+
'sum': sum,
|
| 117 |
+
'min': min,
|
| 118 |
+
'max': max,
|
| 119 |
+
'abs': abs,
|
| 120 |
+
'sorted': sorted,
|
| 121 |
+
'reversed': reversed,
|
| 122 |
+
'enumerate': enumerate,
|
| 123 |
+
'zip': zip,
|
| 124 |
+
'map': map,
|
| 125 |
+
'filter': filter,
|
| 126 |
+
'any': any,
|
| 127 |
+
'all': all,
|
| 128 |
+
'isinstance': isinstance,
|
| 129 |
+
'type': type,
|
| 130 |
+
'round': round,
|
| 131 |
+
'pow': pow,
|
| 132 |
+
'divmod': divmod,
|
| 133 |
+
'ord': ord,
|
| 134 |
+
'chr': chr,
|
| 135 |
+
'hex': hex,
|
| 136 |
+
'bin': bin,
|
| 137 |
+
'id': id,
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
namespace = {'__builtins__': safe_builtins}
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
| 144 |
+
signal.alarm(timeout)
|
| 145 |
+
exec(code, namespace)
|
| 146 |
+
signal.alarm(0)
|
| 147 |
+
return True, "", namespace.get('result')
|
| 148 |
+
except TimeoutError as e:
|
| 149 |
+
signal.alarm(0)
|
| 150 |
+
return False, f"Timeout after {timeout}s", None
|
| 151 |
+
except SyntaxError as e:
|
| 152 |
+
signal.alarm(0)
|
| 153 |
+
return False, f"Syntax error: {e}", None
|
| 154 |
+
except Exception as e:
|
| 155 |
+
signal.alarm(0)
|
| 156 |
+
return False, f"{type(e).__name__}: {e}", None
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def check_function_output(code: str, func_name: str, test_cases: List[Dict]) -> Tuple[bool, str]:
    """Test a function with given test cases.

    Args:
        code: The generated code
        func_name: Name of function to test
        test_cases: List of dicts with 'input' (tuple), 'expected', 'description'

    Returns:
        Tuple of (all_passed, failure_message)
    """
    namespace = {'__builtins__': {
        'print': print, 'len': len, 'range': range, 'str': str,
        'int': int, 'float': float, 'bool': bool, 'list': list,
        'dict': dict, 'set': set, 'tuple': tuple, 'sum': sum,
        'min': min, 'max': max, 'abs': abs, 'sorted': sorted,
    }}

    try:
        exec(code, namespace)
    except Exception as e:
        return False, f"Code execution failed: {type(e).__name__}: {e}"

    if func_name not in namespace:
        return False, f"Function '{func_name}' not found"

    func = namespace[func_name]

    for tc in test_cases:
        inputs = tc.get('input', ())
        expected = tc.get('expected')
        desc = tc.get('description', str(inputs))

        try:
            if isinstance(inputs, tuple):
                result = func(*inputs)
            else:
                result = func(inputs)
        except Exception as e:
            return False, f"Failed on {desc}: {type(e).__name__}: {e}"

        if result != expected:
            return False, f"Failed on {desc}: expected {expected}, got {result}"

    return True, ""

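Each test-case dict passes 'input' as a tuple of positional arguments; the checking loop above can be sketched standalone as:

```python
def run_cases(func, test_cases):
    """Apply func to each {'input', 'expected'} dict; stop at the first mismatch."""
    for tc in test_cases:
        args = tc.get('input', ())
        # A tuple is splatted as positional args; anything else is a single arg.
        result = func(*args) if isinstance(args, tuple) else func(args)
        if result != tc.get('expected'):
            return False, f"expected {tc.get('expected')}, got {result}"
    return True, ""

cases = [
    {"input": ("hello",), "expected": "olleh"},
    {"input": ("Python",), "expected": "nohtyP"},
]
print(run_cases(lambda s: s[::-1], cases))    # (True, '')
print(run_cases(lambda s: s.upper(), cases))  # fails on the first case
```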
# Common algorithm test cases
ALGORITHM_TESTS = [
    {
        "name": "Two Sum",
        "prompt": "def two_sum(nums, target):\n    \"\"\"Return indices of two numbers that add up to target.\"\"\"\n",
        "function": "two_sum",
        "max_tokens": 100,
        "test_cases": [
            {"input": ([2,7,11,15], 9), "expected": [0,1], "description": "Basic"},
            {"input": ([3,2,4], 6), "expected": [1,2], "description": "Middle"},
        ],
        "keywords": ["for", "in", "range", "enumerate"],
    },
    {
        "name": "Reverse String",
        "prompt": "def reverse_string(s):\n    \"\"\"Return the reverse of string s.\"\"\"\n",
        "function": "reverse_string",
        "max_tokens": 50,
        "test_cases": [
            {"input": ("hello",), "expected": "olleh", "description": "Basic"},
            {"input": ("Python",), "expected": "nohtyP", "description": "Mixed case"},
        ],
        "keywords": ["[::-1]", "reversed"],
    },
    {
        "name": "Fibonacci",
        "prompt": "def fibonacci(n):\n    \"\"\"Return first n Fibonacci numbers.\"\"\"\n",
        "function": "fibonacci",
        "max_tokens": 100,
        "test_cases": [
            {"input": (7,), "expected": [0,1,1,2,3,5,8], "description": "n=7"},
            {"input": (1,), "expected": [0], "description": "n=1"},
        ],
        "keywords": ["for", "while", "append", "range"],
    },
    {
        "name": "Factorial",
        "prompt": "def factorial(n):\n    \"\"\"Return n! (factorial).\"\"\"\n",
        "function": "factorial",
        "max_tokens": 60,
        "test_cases": [
            {"input": (5,), "expected": 120, "description": "5!"},
            {"input": (0,), "expected": 1, "description": "0!"},
        ],
        "keywords": ["for", "while", "range", "*"],
    },
    {
        "name": "Is Palindrome",
        "prompt": "def is_palindrome(x):\n    \"\"\"Check if integer x is a palindrome.\"\"\"\n",
        "function": "is_palindrome",
        "max_tokens": 60,
        "test_cases": [
            {"input": (121,), "expected": True, "description": "121"},
            {"input": (-121,), "expected": False, "description": "-121"},
        ],
        "keywords": ["str", "[::-1]"],
    },
    {
        "name": "Binary Search",
        "prompt": "def binary_search(arr, target):\n    \"\"\"Return index of target in sorted array, or -1 if not found.\"\"\"\n",
        "function": "binary_search",
        "max_tokens": 120,
        "test_cases": [
            {"input": ([1,2,3,4,5], 3), "expected": 2, "description": "Found"},
            {"input": ([1,2,3,4,5], 6), "expected": -1, "description": "Not found"},
        ],
        "keywords": ["while", "left", "right", "<=", ">"],
    },
    {
        "name": "Merge Sort",
        "prompt": "def merge_sort(arr):\n    \"\"\"Return sorted copy of array using merge sort.\"\"\"\n",
        "function": "merge_sort",
        "max_tokens": 200,
        "test_cases": [
            {"input": ([3,1,4,1,5,9,2,6],), "expected": [1,1,2,3,4,5,6,9], "description": "Mixed"},
            {"input": ([1,2,3],), "expected": [1,2,3], "description": "Already sorted"},
        ],
        "keywords": ["def merge_sort", "if", "len", "return", "merge"],
    },
    {
        "name": "Quick Sort",
        "prompt": "def quick_sort(arr):\n    \"\"\"Return sorted copy of array using quick sort.\"\"\"\n",
        "function": "quick_sort",
        "max_tokens": 200,
        "test_cases": [
            {"input": ([3,1,4,1,5,9,2,6],), "expected": [1,1,2,3,4,5,6,9], "description": "Mixed"},
        ],
        "keywords": ["def quick_sort", "if", "len", "return"],
    },
    {
        "name": "Maximum Subarray (Kadane's)",
        "prompt": "def max_subarray(nums):\n    \"\"\"Return maximum sum of contiguous subarray.\"\"\"\n",
        "function": "max_subarray",
        "max_tokens": 100,
        "test_cases": [
            {"input": ([-2,1,-3,4,-1,2,1,-5,4],), "expected": 6, "description": "Mixed"},
            {"input": ([1],), "expected": 1, "description": "Single"},
        ],
        "keywords": ["for", "max", "+"],
    },
    {
        "name": "Valid Parentheses",
        "prompt": "def valid_parentheses(s):\n    \"\"\"Check if string has valid bracket matching.\"\"\"\n",
        "function": "valid_parentheses",
        "max_tokens": 100,
        "test_cases": [
            {"input": ("()",), "expected": True, "description": "Simple"},
            {"input": ("([)]",), "expected": False, "description": "Wrong order"},
        ],
        "keywords": ["stack", "if", "for", "in", "pop", "append"],
    },
    {
        "name": "Climbing Stairs",
        "prompt": "def climb_stairs(n):\n    \"\"\"Count ways to climb n stairs (1 or 2 steps at a time).\"\"\"\n",
        "function": "climb_stairs",
        "max_tokens": 80,
        "test_cases": [
            {"input": (5,), "expected": 8, "description": "n=5"},
            {"input": (2,), "expected": 2, "description": "n=2"},
        ],
        "keywords": ["for", "while", "+", "="],
    },
    {
        "name": "List Sum",
        "prompt": "def list_sum(nums):\n    \"\"\"Return sum of all numbers in list.\"\"\"\n",
        "function": "list_sum",
        "max_tokens": 50,
        "test_cases": [
            {"input": ([1,2,3,4],), "expected": 10, "description": "Basic"},
            {"input": ([],), "expected": 0, "description": "Empty"},
        ],
        "keywords": ["for", "in", "+", "sum", "0"],
    },
    {
        "name": "List Average",
        "prompt": "def list_avg(nums):\n    \"\"\"Return average of numbers in list.\"\"\"\n",
        "function": "list_avg",
        "max_tokens": 60,
        "test_cases": [
            {"input": ([1,2,3,4,5],), "expected": 3.0, "description": "Basic"},
            {"input": ([5],), "expected": 5.0, "description": "Single"},
        ],
        "keywords": ["sum", "len", "/", "float"],
    },
    {
        "name": "Find Maximum",
        "prompt": "def find_max(nums):\n    \"\"\"Return maximum value in list.\"\"\"\n",
        "function": "find_max",
        "max_tokens": 60,
        "test_cases": [
            {"input": ([3,1,4,1,5,9],), "expected": 9, "description": "Basic"},
            {"input": ([-1,-5,-3],), "expected": -1, "description": "Negatives"},
        ],
        "keywords": ["for", "in", "max", ">", "<"],
    },
    {
        "name": "Count Zeros",
        "prompt": "def count_zeros(nums):\n    \"\"\"Count zeros in list.\"\"\"\n",
        "function": "count_zeros",
        "max_tokens": 50,
        "test_cases": [
            {"input": ([0,1,0,2,0],), "expected": 3, "description": "Mixed"},
            {"input": ([1,2,3],), "expected": 0, "description": "No zeros"},
        ],
        "keywords": ["for", "in", "count", "==", "+"],
    },
    {
        "name": "Unique Elements",
        "prompt": "def unique_elements(lst):\n    \"\"\"Return list of unique elements preserving order.\"\"\"\n",
        "function": "unique_elements",
        "max_tokens": 80,
        "test_cases": [
            {"input": ([1,2,2,3,1],), "expected": [1,2,3], "description": "With dups"},
            {"input": ([1,2,3],), "expected": [1,2,3], "description": "All unique"},
        ],
        "keywords": ["for", "in", "if", "append", "set"],
    },
]

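The expected values above can be sanity-checked with straightforward reference implementations; for example, for Two Sum and Climbing Stairs:

```python
def two_sum(nums, target):
    """Return indices of the two numbers that add up to target."""
    seen = {}  # value -> index
    for i, n in enumerate(nums):
        if target - n in seen:
            return [seen[target - n], i]
        seen[n] = i

def climb_stairs(n):
    """Count ways to climb n stairs taking 1 or 2 steps (Fibonacci recurrence)."""
    a, b = 1, 1  # ways to reach step 0 and step 1
    for _ in range(n - 1):
        a, b = b, a + b
    return b

print(two_sum([2, 7, 11, 15], 9))  # [0, 1]
print(two_sum([3, 2, 4], 6))       # [1, 2]
print(climb_stairs(5))             # 8
print(climb_stairs(2))             # 2
```

Both outputs match the "expected" fields in the corresponding ALGORITHM_TESTS entries.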
def run_test(model, tokenizer, test_config: Dict) -> Dict:
    """Run a single test and return results."""
    name = test_config["name"]
    prompt = test_config["prompt"]
    func_name = test_config["function"]
    max_tokens = test_config["max_tokens"]
    test_cases = test_config["test_cases"]
    keywords = test_config.get("keywords", [])

    print(f"\n  Test: {name}")
    print(f"  Prompt: {prompt.strip()[:60]}...")

    start_time = time.time()

    # Generate completion
    completions = generate_completion(model, tokenizer, prompt, max_tokens=max_tokens)
    elapsed = time.time() - start_time

    # Extract and check code
    code = extract_code(completions[0])

    print(f"  Generated in {elapsed:.2f}s")
    print(f"  Code preview: {code[:100]}...")

    # Check syntax and execution
    success, error_msg = check_function_output(code, func_name, test_cases)

    # Check keywords
    keywords_found = sum(1 for kw in keywords if kw.lower() in code.lower())
    keyword_ratio = keywords_found / len(keywords) if keywords else 0

    result = {
        "name": name,
        "passed": success,
        "keywords_found": keywords_found,
        "keywords_total": len(keywords),
        "keyword_ratio": keyword_ratio,
        "execution_time": elapsed,
        "error": error_msg if not success else None,
        "generated_code": code[:500],  # Truncate for storage
    }

    if success:
        print("  Result: ✅ PASS")
    else:
        print(f"  Result: ❌ FAIL - {error_msg[:60]}")

    return result

def calculate_pass_at_k(results: List[Dict], k: int) -> float:
    """Calculate pass@k across all tests.

    For simplicity, a test passes if it passes the functional test.
    """
    if not results or k <= 0:
        return 0.0

    passed = sum(1 for r in results if r["passed"])
    total = len(results)

    # With one sample per test, pass@1 is simply the pass rate; the same
    # holds when k covers all tests.
    if k == 1 or k >= total:
        return passed / total if total > 0 else 0.0

    # For k > 1, estimate the probability that at least one of k
    # independent samples passes: 1 - (1 - p)^k with p = passed/total.
    p = passed / total if total > 0 else 0.0
    return 1 - (1 - p) ** k

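`calculate_pass_at_k` extrapolates from a single sample per task. When several completions per task are available, the standard unbiased estimator from the HumanEval paper is pass@k = 1 - C(n-c, k) / C(n, k), for n samples of which c passed; a sketch:

```python
from math import comb

def unbiased_pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate for n samples with c correct (Chen et al., 2021)."""
    if n - c < k:
        return 1.0  # every size-k subset must contain a correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)

print(round(unbiased_pass_at_k(10, 2, 1), 3))  # 0.2
print(round(unbiased_pass_at_k(10, 2, 5), 3))  # 0.778
print(unbiased_pass_at_k(10, 9, 2))            # 1.0
```

Unlike the 1 - (1 - p)^k approximation above, this estimator does not assume independent samples drawn from a single pooled pass rate; it is exact for sampling k of the n actual completions without replacement.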
def print_results(results: List[Dict], k_values: List[int] = [1, 10]):
    """Print test results summary."""
    print("\n" + "="*60)
    print("TEST RESULTS SUMMARY")
    print("="*60)

    passed = sum(1 for r in results if r["passed"])
    total = len(results)

    print(f"\n  Total tests: {total}")
    print(f"  Passed: {passed}")
    print(f"  Failed: {total - passed}")
    print(f"  Pass rate: {100*passed/total:.1f}%")

    for k in k_values:
        pass_at_k = calculate_pass_at_k(results, k)
        print(f"\n  Pass@{k}: {100*pass_at_k:.1f}%")

    print("\n  Individual Results:")
    for r in results:
        status = "✅" if r["passed"] else "❌"
        print(f"  {status} {r['name']} (keywords: {r['keywords_found']}/{r['keywords_total']})")

def save_results(results: List[Dict], output_path: str):
    """Save test results to JSON."""
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\n✅ Results saved to: {output_path}")

def main():
    parser = argparse.ArgumentParser(
        description="Test Stack 2.9 model on common algorithm tasks"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the merged model directory"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="test_results.json",
        help="Output file for results (default: test_results.json)"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=200,
        help="Max new tokens per generation (default: 200)"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.2,
        help="Sampling temperature (default: 0.2)"
    )
    parser.add_argument(
        "--test-names",
        type=str,
        default=None,
        help="Comma-separated test names to run (default: all)"
    )
    parser.add_argument(
        "--k-values",
        type=str,
        default="1,10",
        help="Comma-separated k values for pass@k (default: 1,10)"
    )

    args = parser.parse_args()

    k_values = [int(k.strip()) for k in args.k_values.split(",")]

    print("="*60)
    print("Stack 2.9 Model - Algorithm Tests")
    print("="*60)
    print(f"Model path: {args.model_path}")
    print(f"Output: {args.output}")
    print(f"Max tokens: {args.max_tokens}")
    print(f"Temperature: {args.temperature}")

    # Load model
    model, tokenizer = load_model(args.model_path)
    model.eval()

    # Select tests to run
    tests_to_run = ALGORITHM_TESTS
    if args.test_names:
        names = [n.strip() for n in args.test_names.split(",")]
        tests_to_run = [t for t in tests_to_run if t["name"] in names]
        print(f"Running tests: {[t['name'] for t in tests_to_run]}")

    if not tests_to_run:
        print("No tests to run!")
        return

    # Override max_tokens for each test
    for test in tests_to_run:
        if args.max_tokens:
            test["max_tokens"] = args.max_tokens

    # Run tests
    print("\n" + "="*60)
    print(f"Running {len(tests_to_run)} Tests")
    print("="*60)

    results = []
    total_start = time.time()

    for i, test in enumerate(tests_to_run, 1):
        print(f"\n[{i}/{len(tests_to_run)}]")
        result = run_test(model, tokenizer, test)
        results.append(result)

    total_time = time.time() - total_start

    # Print summary
    print_results(results, k_values)
    print(f"\n  Total time: {total_time:.1f}s")

    # Save results
    save_results(results, args.output)


if __name__ == "__main__":
    main()
train_simple.py
DELETED
@@ -1,197 +0,0 @@
#!/usr/bin/env python3
"""
Simple standalone training script for Stack 2.9.
No package installation required: just run `python train_simple.py --config train_config.yaml`
"""

import argparse
import os
import sys
from pathlib import Path

import yaml
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, TaskType
import torch


def load_config(config_path: str) -> dict:
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)


def load_model_and_tokenizer(model_name: str, trust_remote_code: bool = True, use_4bit: bool = False):
    """Load base model and tokenizer."""
    if use_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True
        )
    else:
        quantization_config = None

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        torch_dtype=torch.float16,
        trust_remote_code=trust_remote_code,
        device_map="auto"
    )
    return model, tokenizer


def load_data(data_path: str, tokenizer, max_length: int = 2048, train_split: float = 0.9):
    """Load and tokenize dataset."""
    raw_dataset = load_dataset("json", data_files=data_path, split="train")

    def tokenize_function(examples):
        # Combine instruction and output
        texts = []
        for instr, out in zip(examples.get("instruction", [""]), examples.get("output", [""])):
            if instr and out:
                texts.append(f"### Instruction:\n{instr}\n\n### Response:\n{out}")
            elif out:
                texts.append(out)
            elif instr:
                texts.append(instr)
            else:
                texts.append("")

        tokenized = tokenizer(texts, truncation=True, max_length=max_length, padding="max_length")
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=raw_dataset.column_names)
    split = tokenized_dataset.train_test_split(train_size=train_split)
    return split["train"], split["test"]


def train(config: dict):
    """Main training function."""
    model_config = config["model"]
    data_config = config["data"]
    lora_config = config["lora"]
    training_config = config["training"]
    output_config = config["output"]
    hardware_config = config["hardware"]

    # Load model and tokenizer
    print(f"Loading model: {model_config['name']}")
    model, tokenizer = load_model_and_tokenizer(
        model_name=model_config["name"],
        trust_remote_code=model_config.get("trust_remote_code", True),
        use_4bit=hardware_config.get("use_4bit", False)
    )

    # Load data
    print(f"Loading dataset: {data_config['input_path']}")
    train_dataset, eval_dataset = load_data(
        data_path=data_config["input_path"],
        tokenizer=tokenizer,
        max_length=data_config.get("max_length", 2048),
        train_split=data_config.get("train_split", 0.9)
    )
    print(f"  Train samples: {len(train_dataset)}")
    print(f"  Eval samples: {len(eval_dataset)}")

    # Apply LoRA
    peft_config = LoraConfig(
        r=lora_config["r"],
        lora_alpha=lora_config["alpha"],
        lora_dropout=lora_config["dropout"],
        target_modules=lora_config["target_modules"],
        bias=lora_config["bias"],
        task_type=TaskType.CAUSAL_LM
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Training arguments
    output_dir = output_config["lora_dir"]
    os.makedirs(output_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=training_config["num_epochs"],
        per_device_train_batch_size=training_config["batch_size"],
        gradient_accumulation_steps=training_config["gradient_accumulation"],
        learning_rate=training_config["learning_rate"],
        warmup_steps=training_config.get("warmup_steps", 100),
        weight_decay=training_config.get("weight_decay", 0.01),
        max_grad_norm=training_config.get("max_grad_norm", 1.0),
        logging_steps=training_config.get("logging_steps", 10),
        save_steps=training_config.get("save_steps", 100),
        save_total_limit=training_config.get("save_total_limit", 2),
        fp16=training_config.get("fp16", True),
        bf16=training_config.get("bf16", False),
        gradient_checkpointing=training_config.get("gradient_checkpointing", True),
        evaluation_strategy="steps" if eval_dataset else "no",
        eval_steps=training_config.get("eval_steps", 100) if eval_dataset else None,
        report_to="none",  # No WandB
    )

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    print("="*60)
    print("Starting training...")
    print("="*60)
    trainer.train()
    print("Training completed!")

    # Save final adapter
    trainer.save_model(output_dir)
    print(f"✅ Adapter saved to {output_dir}")

    return trainer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    args = parser.parse_args()

    print("="*60)
    print("Stack 2.9 Simple Training")
    print("="*60)

    config = load_config(args.config)
    print(f"Config loaded: {args.config}")
    print(f"Model: {config['model']['name']}")
    print(f"Data: {config['data']['input_path']}")

    try:
        train(config)
        print("\n" + "="*60)
        print("✅ TRAINING SUCCESS")
        print("="*60)
    except Exception as e:
        print("\n" + "="*60)
        print(f"❌ TRAINING FAILED: {e}")
        print("="*60)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
training-configs/7b-lora-config.yaml
ADDED
@@ -0,0 +1,104 @@
# Stack 2.9 Training Configuration for Qwen2.5-Coder-7B
# Optimized for A100 80GB or multi-GPU setup (2x A100 40GB minimum)
# For single GPU training with 7B model

model:
  name: "Qwen/Qwen2.5-Coder-7B"
  trust_remote_code: true
  torch_dtype: "bfloat16"      # bf16 for A100, better numerical stability than fp16
  use_flash_attention: true    # A100 supports FlashAttention 2

tokenizer:
  model_max_length: 8192       # Standard 8K context for coding tasks
  padding_side: "right"
  truncation_side: "right"

# LoRA Configuration - Optimized for 7B model
lora:
  r: 16            # LoRA rank: 16 is sufficient for 7B and saves memory
  alpha: 32        # alpha = 2*r is standard
  dropout: 0.05
  target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "gate_proj"
    - "up_proj"
    - "down_proj"
  bias: "none"
  task_type: "CAUSAL_LM"

# Training Configuration
training:
  num_epochs: 3              # 3 epochs is typically sufficient
  batch_size: 2              # 2 per GPU for A100 80GB
  gradient_accumulation: 8   # Effective batch = 16 (2 * 8)
  learning_rate: 1.0e-4      # Standard for LoRA
  warmup_steps: 100
  weight_decay: 0.01
  max_grad_norm: 1.0
  logging_steps: 10
  eval_steps: 500
  save_steps: 1000
  save_total_limit: 3
  fp16: false
  bf16: true                     # Use bf16 on A100
  gradient_checkpointing: true   # Enable for memory savings
  optim: "adamw_torch"           # Standard optimizer

# Data Configuration
data:
  train_files:
    - "/Users/walidsobhi/.openclaw/workspace/stack-2.9/training-data/final/train.jsonl"
  val_file: "/Users/walidsobhi/.openclaw/workspace/stack-2.9/training-data/final/val.jsonl"
  test_file: "/Users/walidsobhi/.openclaw/workspace/stack-2.9/training-data/final/test.jsonl"
  max_length: 8192
  train_split: 0.9
  test_split: 0.1

# Output Configuration
output:
  lora_dir: "/Users/walidsobhi/.openclaw/workspace/stack-2.9-training/output/stack-2.9-7b-lora"
  merged_dir: "/Users/walidsobhi/.openclaw/workspace/stack-2.9-training/output/stack-2.9-7b-merged"
  awq_dir: "/Users/walidsobhi/.openclaw/workspace/stack-2.9-training/output/stack-2.9-7b-awq"

# Quantization - Optional for 7B
# Enable 4-bit only if running on limited VRAM (< 40GB per GPU)
quantization:
  enabled: false   # Disabled for A100 80GB - not needed
  bits: 4
  group_size: 128

# Logging Configuration
logging:
  report_to: "none"   # Set to "wandb" for Weights & Biases
  wandb_project: "stack-2.9-training"
  run_name: null

# Hardware Configuration
hardware:
  device: "cuda"
  num_gpus: 1        # Set to 2, 4, or 8 for multi-GPU
  use_4bit: false    # Not needed for A100 80GB
  use_8bit: false    # Optional: enable for memory savings
  zero_stage: 0      # ZeRO stage (0=disabled; 1, 2, 3 for multi-GPU)

# Multi-GPU Configuration (use when num_gpus > 1)
multigpu:
  zero_stage: 2      # ZeRO-2 recommended for 2-4 GPUs, ZeRO-3 for 8+
  gradient_bucket_size: 5e8
  prefetch_bucket_size: 5e8
  steps_per_print: 10

# Memory Estimates (7B LoRA):
#   - Full bf16, batch_size=2:    ~18GB VRAM
#   - Full bf16, batch_size=4:    ~22GB VRAM
#   - 4-bit QLoRA, batch_size=1:  ~6GB VRAM
#   - 4-bit QLoRA, batch_size=2:  ~8GB VRAM
# An A100 80GB leaves ample headroom for any of these.

# Merge Configuration (run after training)
merge:
  enabled: true
  output_dir: "/Users/walidsobhi/.openclaw/workspace/stack-2.9-training/output/stack-2.9-7b-merged"
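A quick way to check the numbers in this config: effective batch size is batch_size times gradient_accumulation. A small sketch (the dict stands in for the parsed YAML; the 16,000-example dataset size is hypothetical):

```python
# Hypothetical excerpt of the parsed YAML above (PyYAML load omitted).
config = {
    "training": {"batch_size": 2, "gradient_accumulation": 8, "num_epochs": 3},
}

t = config["training"]
effective_batch = t["batch_size"] * t["gradient_accumulation"]
print(effective_batch)  # 16

# Optimizer steps per epoch for a hypothetical 16,000-example train set.
train_examples = 16_000
steps_per_epoch = train_examples // effective_batch
print(steps_per_epoch)                     # 1000
print(steps_per_epoch * t["num_epochs"])   # 3000 total optimizer steps
```

With warmup_steps: 100 and eval_steps: 500 from the config, warmup would then cover the first ~3% of training and evaluation would run twice per epoch.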
training-configs/kaggle-7b-qlora.sh
ADDED
@@ -0,0 +1,79 @@
#!/bin/bash
# Kaggle QLoRA Training Script for Stack 2.9 on Qwen2.5-Coder-7B
#
# Hardware: Kaggle T4 GPU (16GB VRAM)
# Expected runtime: 2-4 hours depending on dataset size
# Memory: ~6GB VRAM with QLoRA + gradient checkpointing
#
# Usage:
#   chmod +x kaggle-7b-qlora.sh
#   ./kaggle-7b-qlora.sh

set -e

# === Configuration ===
MODEL_NAME="Qwen/Qwen2.5-Coder-7B"
OUTPUT_DIR="/kaggle/working/stack-2.9-output/lora-7b-qlora"
DATA_DIR="/kaggle/working/stack-2.9/data"

# === Hyperparameters ===
# LoRA config (r=16, alpha=32)
LORA_R=16
LORA_ALPHA=32
LORA_DROPOUT=0.05

# Training config
EPOCHS=2
BATCH_SIZE=1
GRAD_ACCUM=8      # Effective batch = 8
LR=1e-4
MAX_LENGTH=4096   # Reduced for Kaggle memory

# === Environment Setup ===
echo "=== Setting up environment ==="
cd /kaggle/working/stack-2.9

# Install dependencies if needed
pip install -q transformers datasets peft bitsandbytes accelerate torch

# === Training Command ===
echo "=== Starting QLoRA training ==="
echo "Model: $MODEL_NAME"
echo "Output: $OUTPUT_DIR"
echo "Epochs: $EPOCHS"

# T4 GPUs have no native bf16 support, so fp16 is used for both
# mixed-precision training and the 4-bit compute dtype.
python -m trl.sft \
    --model_name "$MODEL_NAME" \
    --train_files "$DATA_DIR/train.jsonl" \
    --validation_files "$DATA_DIR/val.jsonl" \
    --output_dir "$OUTPUT_DIR" \
    --per_device_train_batch_size $BATCH_SIZE \
    --gradient_accumulation_steps $GRAD_ACCUM \
    --num_train_epochs $EPOCHS \
    --learning_rate $LR \
    --max_seq_length $MAX_LENGTH \
    --logging_steps 10 \
    --save_steps 500 \
    --save_total_limit 2 \
    --fp16 true \
    --gradient_checkpointing true \
    --lora_r $LORA_R \
    --lora_alpha $LORA_ALPHA \
    --lora_dropout $LORA_DROPOUT \
    --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \
    --load_in_4bit true \
    --bnb_4bit_compute_dtype "float16" \
    --bnb_4bit_quant_type "nf4" \
    --bnb_4bit_use_double_quant true

echo "=== Training complete ==="
echo "LoRA adapters saved to: $OUTPUT_DIR"

# Optional: Merge and save merged model
echo "=== Merging adapter ==="
python merge_simple.py \
    --base_model "$MODEL_NAME" \
    --lora_adapter "$OUTPUT_DIR" \
    --output_dir "$OUTPUT_DIR/merged"

echo "=== Done ==="
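The merge step folds each adapter into its base weight as W' = W + (alpha/r)·B·A. A toy sketch with plain nested lists illustrates the arithmetic (the real merge_simple.py operates on model tensors, for example via peft's merge_and_unload):

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def merge_lora(W, A, B, r, alpha):
    """Return W + (alpha / r) * B @ A, the merged weight matrix."""
    scale = alpha / r
    BA = matmul(B, A)
    return [[W[i][j] + scale * BA[i][j] for j in range(len(W[0]))]
            for i in range(len(W))]

W = [[1.0, 0.0], [0.0, 1.0]]   # 2x2 base weight (identity, for clarity)
A = [[1.0, 2.0]]               # r x in_features  (rank r = 1)
B = [[0.5], [0.25]]            # out_features x r
print(merge_lora(W, A, B, r=1, alpha=2))  # [[2.0, 2.0], [0.5, 2.0]]
```

After merging, inference needs no adapter weights at all, which is why the merged directory can be loaded as an ordinary model.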
training-data/tool_examples.json
ADDED

The diff for this file is too large to render. See raw diff.
training-data/tool_examples.jsonl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1043720a918f5fe0f70cc013c108710570c37ae6c9cee6f504e49dc359af5a2a
size 3779800
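These three lines are a Git LFS pointer, not the JSONL data itself; the roughly 3.8 MB file with the 1000 tool_calling examples lives in LFS storage, keyed by the sha256 oid. A sketch parsing the pointer's simple key/value format:

```python
# A Git LFS pointer file is a short list of "key value" lines.
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:1043720a918f5fe0f70cc013c108710570c37ae6c9cee6f504e49dc359af5a2a
size 3779800
"""

def parse_lfs_pointer(text):
    """Split each 'key value' line of a Git LFS pointer into a dict."""
    return dict(line.split(" ", 1) for line in text.strip().splitlines())

meta = parse_lfs_pointer(pointer)
print(meta["size"])                       # 3779800
print(meta["oid"].startswith("sha256:"))  # True
```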