walidsobhie-code committed on
Commit 183b3b6 · 1 Parent(s): a8f2981

feat: add evaluation scripts, tool calling data generator, and 7B training configs


Created:
- test_model.py: 602 lines - basic code generation tests
- evaluate_model.py: 847 lines - HumanEval/MBPP benchmark evaluation with pass@k
- scripts/generate_tool_data.py: 615 lines - synthetic tool_calling data generator
- training-data/tool_examples.jsonl: 1000 synthetic tool_calling examples
- training-configs/7b-lora-config.yaml: A100 80GB full LoRA config
- training-configs/kaggle-7b-qlora.sh: QLoRA script for 7B on Kaggle T4
- docs/TRAINING_7B.md: 7B training documentation with VRAM estimates

Updated:
- merge_simple.py: Improved merge logic
- .gitignore: Clean ignores

Note: tool_examples.jsonl has 1000 synthetic tool_calling examples ready for fine-tuning

docs/TRAINING_7B.md ADDED
@@ -0,0 +1,189 @@
# Training Stack 2.9 on Qwen2.5-Coder-7B

## Overview

This guide covers training Stack 2.9 on the Qwen2.5-Coder-7B model using LoRA/QLoRA fine-tuning.

## Hardware Requirements

### Minimum (QLoRA - 4-bit)

| GPU | VRAM | Batch Size | Notes |
|-----|------|-----------|-------|
| T4 (Colab) | 15GB | 1 | Gradient accumulation = 16 |
| P100 (Kaggle) | 16GB | 1 | Gradient accumulation = 8 |
| RTX 3090 | 24GB | 2 | Full performance |

### Recommended (Full LoRA - bf16)

| GPU | VRAM | Batch Size | Notes |
|-----|------|-----------|-------|
| A100 40GB | 40GB | 2 | Batch size 2 for better throughput |
| A100 80GB | 80GB | 4 | Best for production |
| H100 80GB | 80GB | 4 | Next-gen option |

## VRAM Estimates

| Configuration | Batch Size | Gradient Checkpointing | Est. VRAM |
|--------------|-----------|-------------------|----------|
| Full LoRA (bf16) | 1 | No | 14GB |
| Full LoRA (bf16) | 2 | Yes | 16GB |
| Full LoRA (bf16) | 4 | Yes | 22GB |
| QLoRA (4-bit) | 1 | Yes | 5-6GB |
| QLoRA (4-bit) | 2 | Yes | 7-8GB |
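These estimates can be sanity-checked with simple arithmetic: the weights alone take parameters × bytes per parameter, and optimizer state, gradients, activations, and framework overhead come on top. A back-of-envelope sketch (plain Python, a rough cross-check rather than a measurement):

```python
def weight_vram_gib(num_params_billion: float, bits_per_param: int) -> float:
    """GiB needed for model weights alone: params x bytes/param."""
    bytes_total = num_params_billion * 1e9 * bits_per_param / 8
    return bytes_total / 2**30

# 7B in bf16 (16-bit): ~13 GiB of weights -- close to the 14GB row above.
print(round(weight_vram_gib(7, 16), 1))
# 7B in 4-bit: ~3.3 GiB of weights; the 5-6GB row adds LoRA state and activations.
print(round(weight_vram_gib(7, 4), 1))
```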
## Quick Start

### Option 1: Kaggle (QLoRA)

```bash
cd /kaggle/working/stack-2.9
chmod +x training-configs/kaggle-7b-qlora.sh
./training-configs/kaggle-7b-qlora.sh
```

### Option 2: Local (Full LoRA)

```bash
cd /path/to/stack-2.9
python train_local.py \
    --config training-configs/7b-lora-config.yaml
```

### Option 3: Custom Training Script

```python
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Coder-7B",
    torch_dtype="bfloat16",
    device_map="auto"
)

# LoRA config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```
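`print_trainable_parameters()` will report how small the adapter is: each adapted weight matrix of shape d_out × d_in contributes r × (d_in + d_out) trainable parameters. A rough count for the config above, assuming the published Qwen2.5-7B shapes (hidden size 3584, FFN size 18944, KV projection width 512, 28 layers — quoted from memory, so verify against the model's `config.json`):

```python
def lora_param_count(r, shapes, num_layers):
    """Trainable LoRA params: r * (d_in + d_out) per adapted matrix, per layer."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes)
    return per_layer * num_layers

# (d_in, d_out) for the seven target_modules in one decoder layer
# (assumed Qwen2.5-7B shapes; grouped-query attention shrinks k/v projections)
shapes = [
    (3584, 3584),   # q_proj
    (3584, 512),    # k_proj
    (3584, 512),    # v_proj
    (3584, 3584),   # o_proj
    (3584, 18944),  # gate_proj
    (3584, 18944),  # up_proj
    (18944, 3584),  # down_proj
]
total = lora_param_count(16, shapes, 28)
print(f"{total/1e6:.1f}M trainable params")  # tens of millions, vs ~7B frozen
```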
## Configuration Reference

### LoRA Parameters
```yaml
lora:
  r: 16          # Rank (8-32 recommended for 7B)
  alpha: 32      # Usually 2*r
  dropout: 0.05
  target_modules:
    - q_proj
    - k_proj
    - v_proj
    - o_proj
    - gate_proj
    - up_proj
    - down_proj
```

### Training Parameters
```yaml
training:
  num_epochs: 3
  batch_size: 2              # A100: 2-4, T4/P100: 1
  gradient_accumulation: 8
  learning_rate: 1.0e-4
  warmup_steps: 100
  gradient_checkpointing: true
  bf16: true
```
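With these values the optimizer sees an effective batch of batch_size × gradient_accumulation examples per update. A quick schedule sanity check (plain Python; 10K samples is the figure the timing section uses):

```python
def schedule(num_samples: int, batch_size: int, grad_accum: int, epochs: int):
    """Return (effective batch, optimizer steps per epoch, total steps)."""
    effective = batch_size * grad_accum
    steps_per_epoch = (num_samples + effective - 1) // effective  # ceil division
    return effective, steps_per_epoch, steps_per_epoch * epochs

print(schedule(10_000, 2, 8, 3))  # (16, 625, 1875)
```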
## Expected Training Time

Based on ~10K samples, max_length=4096:

| Hardware | Config | Est. Time |
|----------|--------|----------|
| T4 | 4-bit QLoRA | 4-6 hours |
| P100 | 4-bit QLoRA | 2-3 hours |
| A100 40GB | bf16 LoRA | 30-45 min |
| A100 80GB | bf16 LoRA | 20-30 min |

Times scale roughly linearly with dataset size.
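Given roughly linear scaling, the table extrapolates with simple proportionality. A tiny helper (the 25-minute baseline below is the midpoint of the A100 80GB row, an estimate rather than a measurement):

```python
def estimate_minutes(num_samples: int,
                     baseline_minutes: float = 25.0,
                     baseline_samples: int = 10_000) -> float:
    """Linear extrapolation of training time from an estimated baseline."""
    return baseline_minutes * num_samples / baseline_samples

print(estimate_minutes(40_000))  # 100.0 -> ~1h40m for a 4x larger dataset
```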
## After Training

### Merge LoRA Adapter

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Coder-7B",
    torch_dtype="bfloat16"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")

# Merge adapter
model = PeftModel.from_pretrained(base_model, "./output/lora")
merged = model.merge_and_unload()

# Save
merged.save_pretrained("./output/merged")
tokenizer.save_pretrained("./output/merged")
```

### Test the Model

```python
from transformers import AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained("./output/merged")
pipe = pipeline("text-generation", model="./output/merged", tokenizer=tokenizer)

result = pipe("def quick_sort(arr):", max_new_tokens=100)
print(result[0]["generated_text"])
```
## Troubleshooting

### OOM (Out of Memory)
- Reduce `batch_size` to 1
- Enable `gradient_checkpointing: true`
- Reduce `max_length` (4096 → 2048)
- Enable 4-bit quantization

### Training Slow
- Increase batch size if VRAM allows
- Enable `use_flash_attention: true` (A100/H100)
- Reduce gradient accumulation

### Loss Not Converging
- Check learning rate (try 5e-5 or 2e-4)
- Increase epochs (3 → 5)
- Verify data format matches the expected template
## Alternative: RunPod / Cloud Deployment

For faster training, see `runpod_deploy.sh` at the repo root.

```bash
# Example: RunPod A100
bash runpod_deploy.sh --gpu a100 --instance hourly
```

## Notes

- **A100 recommended**: Best balance of VRAM and speed
- **4-bit QLoRA**: Use only if VRAM < 20GB; slightly reduces quality
- **Gradient checkpointing**: Always enable; minimal perf impact for big memory savings
- **Flash Attention**: A100/H100 only; significant speed boost
evaluate_model.py CHANGED
@@ -2,359 +2,845 @@
  """
  HumanEval + MBPP Benchmark Evaluation for Stack 2.9
  Tests code generation quality using pass@k metrics.
  """

  import argparse
  import os
  import json
  import time
- from typing import List, Dict
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer


- def load_model(model_path: str):
      """Load the fine-tuned model and tokenizer."""
      print(f"Loading model from: {model_path}")
-     tokenizer = AutoTokenizer.from_pretrained(model_path)
-     model = AutoModelForCausalLM.from_pretrained(
-         model_path,
-         torch_dtype=torch.float16,
-         device_map="auto",
-         low_cpu_mem_usage=True,
-     )
      return model, tokenizer


- def generate_solution(model, tokenizer, prompt: str, max_new_tokens: int = 256) -> str:
-     """Generate a single solution for a problem."""
-     inputs = tokenizer(prompt, return_tensors="pt")
      inputs = {k: v.to(model.device) for k, v in inputs.items()}

-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=max_new_tokens,
-             temperature=0.8,
-             top_p=0.95,
-             do_sample=True,
-             repetition_penalty=1.1,
-         )

-     completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     # Extract just the generated part
-     if completion.startswith(prompt):
-         completion = completion[len(prompt):].strip()

-     # Try to extract just the code (between ```python and ``` if present)
      if "```python" in completion:
          start = completion.find("```python") + len("```python")
          end = completion.find("```", start)
          if end != -1:
-             completion = completion[start:end].strip()
-     elif "```" in completion:
          start = completion.find("```") + len("```")
          end = completion.find("```", start)
          if end != -1:
-             completion = completion[start:end].strip()

-     return completion


- def check_correctness(code: str, expected_output=None) -> bool:
-     """Check if generated code produces correct output."""
      try:
-         # Create a namespace for execution
-         namespace = {}
          exec(code, namespace)

-         # If we have expected output, check it
-         if expected_output and 'solution' in namespace:
-             result = namespace['solution']()
-             return result == expected_output

-         # Basic check: code executed without error
-         return True
      except Exception as e:
-         return False


- def evaluate_humaneval(model, tokenizer, num_samples: int = 10, k_values: List[int] = [1, 10, 100]) -> Dict:
-     """Evaluate on HumanEval problems."""
-     print("\n" + "="*60)
-     print("Evaluating on HumanEval")
-     print("="*60)

-     # HumanEval problems (sample - add more as needed)
-     humaneval_problems = [
          {
-             "task_id": "test_1",
-             "prompt": "def two_sum(nums, target):\n    \"\"\"Given an array of integers nums and an integer target, return indices of the two numbers such that they add up to target.\"\"\"\n",
-             "solution": "def two_sum(nums, target):\n    seen = {}\n    for i, num in enumerate(nums):\n        complement = target - num\n        if complement in seen:\n            return [seen[complement], i]\n        seen[num] = i\n    return []",
-             "test": "assert two_sum([2,7,11,15], 9) == [0,1]",
          },
          {
-             "task_id": "test_2",
-             "prompt": "def is_palindrome(x):\n    \"\"\"Check if a number is a palindrome.\"\"\"\n",
-             "solution": "def is_palindrome(x):\n    if x < 0:\n        return False\n    return str(x) == str(x)[::-1]",
-             "test": "assert is_palindrome(121) == True",
          },
          {
-             "task_id": "test_3",
-             "prompt": "def fizz_buzz(n):\n    \"\"\"Return FizzBuzz list from 1 to n.\"\"\"\n",
-             "solution": "def fizz_buzz(n):\n    return ['FizzBuzz' if i%15==0 else 'Fizz' if i%3==0 else 'Buzz' if i%5==0 else str(i) for i in range(1,n+1)]",
-             "test": "assert fizz_buzz(5) == ['1','2','Fizz','4','Buzz']",
          },
          {
-             "task_id": "test_4",
-             "prompt": "def fibonacci(n):\n    \"\"\"Return the first n Fibonacci numbers.\"\"\"\n",
-             "solution": "def fibonacci(n):\n    if n <= 0:\n        return []\n    fib = [0, 1]\n    while len(fib) < n:\n        fib.append(fib[-1] + fib[-2])\n    return fib[:n]",
-             "test": "assert fibonacci(7) == [0, 1, 1, 2, 3, 5, 8]",
          },
          {
-             "task_id": "test_5",
-             "prompt": "def valid_parentheses(s):\n    \"\"\"Check if string has valid parenthesis matching.\"\"\"\n",
-             "solution": "def valid_parentheses(s):\n    stack = []\n    mapping = {')': '(', '}': '{', ']': '['}\n    for char in s:\n        if char in mapping:\n            if not stack or stack.pop() != mapping[char]:\n                return False\n        else:\n            stack.append(char)\n    return not stack",
-             "test": "assert valid_parentheses('()[]{}') == True",
          },
          {
-             "task_id": "test_6",
-             "prompt": "def reverse_string(s):\n    \"\"\"Reverse a string.\"\"\"\n",
-             "solution": "def reverse_string(s):\n    return s[::-1]",
-             "test": "assert reverse_string('hello') == 'olleh'",
          },
          {
-             "task_id": "test_7",
-             "prompt": "def merge_sorted_lists(l1, l2):\n    \"\"\"Merge two sorted lists into one sorted list.\"\"\"\n",
-             "solution": "def merge_sorted_lists(l1, l2):\n    return sorted(l1 + l2)",
-             "test": "assert merge_sorted_lists([1,3,5], [2,4,6]) == [1,2,3,4,5,6]",
          },
          {
-             "task_id": "test_8",
-             "prompt": "def maximum_subarray(nums):\n    \"\"\"Find the contiguous subarray with the largest sum.\"\"\"\n",
-             "solution": "def maximum_subarray(nums):\n    max_sum = nums[0]\n    current_sum = nums[0]\n    for num in nums[1:]:\n        current_sum = max(num, current_sum + num)\n        max_sum = max(max_sum, current_sum)\n    return max_sum",
-             "test": "assert maximum_subarray([-2,1,-3,4,-1,2,1,-5,4]) == 6",
          },
          {
-             "task_id": "test_9",
-             "prompt": "def climbing_stairs(n):\n    \"\"\"Count ways to climb n stairs (1 or 2 steps at a time).\"\"\"\n",
-             "solution": "def climbing_stairs(n):\n    if n <= 2:\n        return n\n    a, b = 1, 2\n    for _ in range(3, n+1):\n        a, b = b, a + b\n    return b",
-             "test": "assert climbing_stairs(5) == 8",
          },
          {
-             "task_id": "test_10",
-             "prompt": "def contains_duplicate(nums):\n    \"\"\"Check if array contains any duplicate.\"\"\"\n",
-             "solution": "def contains_duplicate(nums):\n    return len(nums) != len(set(nums))",
-             "test": "assert contains_duplicate([1,2,3,1]) == True",
          },
      ]
-
-     # Limit to num_samples
-     problems = humaneval_problems[:num_samples]
-
-     results = []
-     for i, problem in enumerate(problems):
-         print(f"\nProblem {i+1}/{len(problems)}: {problem['task_id']}")
-         print(f"Prompt: {problem['prompt'][:50]}...")
-
-         start = time.time()
-         solution = generate_solution(model, tokenizer, problem['prompt'])
-         elapsed = time.time() - start
-
-         print(f"Generated in {elapsed:.2f}s")
-         print(f"Solution preview: {solution[:100]}...")
-
-         # Try to execute the solution
-         correct = check_correctness(solution)
-         results.append({
-             "task_id": problem["task_id"],
-             "solution": solution,
-             "correct": correct,
-             "time": elapsed,
-         })
-
-         print(f"Result: {'✅ CORRECT' if correct else '❌ INCORRECT'}")
-
-     # Calculate pass@k
-     passed = sum(1 for r in results if r['correct'])
-     total = len(results)
-
-     print("\n" + "="*60)
-     print("HumanEval Results")
-     print("="*60)
-     print(f"Total: {total}")
-     print(f"Passed: {passed}")
-     print(f"Pass@1: {100 * passed / total:.1f}%")
-
-     return {
-         "total": total,
-         "passed": passed,
-         "pass_at_1": passed / total if total > 0 else 0,
-         "results": results,
-     }


- def evaluate_mbpp(model, tokenizer, num_samples: int = 10) -> Dict:
-     """Evaluate on MBPP (Mostly Basic Python Problems)."""
-     print("\n" + "="*60)
-     print("Evaluating on MBPP")
-     print("="*60)
-
-     # MBPP problems (sample)
-     mbpp_problems = [
          {
-             "task_id": "mbpp_1",
-             "prompt": "def add_numbers(a, b):\n    # Return the sum of a and b\n",
-             "solution": "def add_numbers(a, b):\n    return a + b",
-             "test": "assert add_numbers(2, 3) == 5",
          },
          {
-             "task_id": "mbpp_2",
-             "prompt": "def multiply_list(nums):\n    # Return the product of all numbers in the list\n",
-             "solution": "def multiply_list(nums):\n    result = 1\n    for num in nums:\n        result *= num\n    return result",
-             "test": "assert multiply_list([1, 2, 3, 4]) == 24",
          },
          {
-             "task_id": "mbpp_3",
-             "prompt": "def square(x):\n    # Return the square of x\n",
-             "solution": "def square(x):\n    return x ** 2",
-             "test": "assert square(5) == 25",
          },
          {
-             "task_id": "mbpp_4",
-             "prompt": "def is_even(n):\n    # Return True if n is even, False otherwise\n",
-             "solution": "def is_even(n):\n    return n % 2 == 0",
-             "test": "assert is_even(4) == True",
          },
          {
-             "task_id": "mbpp_5",
-             "prompt": "def string_length(s):\n    # Return the length of string s\n",
-             "solution": "def string_length(s):\n    return len(s)",
-             "test": "assert string_length('hello') == 5",
          },
          {
-             "task_id": "mbpp_6",
-             "prompt": "def get_max(nums):\n    # Return the maximum number from the list\n",
-             "solution": "def get_max(nums):\n    return max(nums)",
-             "test": "assert get_max([1, 5, 3]) == 5",
          },
          {
-             "task_id": "mbpp_7",
-             "prompt": "def get_min(nums):\n    # Return the minimum number from the list\n",
-             "solution": "def get_min(nums):\n    return min(nums)",
-             "test": "assert get_min([1, 5, 3]) == 1",
          },
          {
-             "task_id": "mbpp_8",
-             "prompt": "def count_zeros(nums):\n    # Return the count of zeros in the list\n",
-             "solution": "def count_zeros(nums):\n    return nums.count(0)",
-             "test": "assert count_zeros([0, 1, 0, 2, 0]) == 3",
          },
          {
-             "task_id": "mbpp_9",
-             "prompt": "def reverse_list(lst):\n    # Return a new list with elements in reverse order\n",
-             "solution": "def reverse_list(lst):\n    return lst[::-1]",
-             "test": "assert reverse_list([1, 2, 3]) == [3, 2, 1]",
          },
          {
-             "task_id": "mbpp_10",
-             "prompt": "def unique_elements(lst):\n    # Return list of unique elements\n",
-             "solution": "def unique_elements(lst):\n    return list(set(lst))",
-             "test": "assert unique_elements([1, 2, 2, 3]) == [1, 2, 3]",
          },
      ]
-
-     problems = mbpp_problems[:num_samples]
-
-     results = []
-     for i, problem in enumerate(problems):
-         print(f"\nProblem {i+1}/{len(problems)}: {problem['task_id']}")
-         print(f"Prompt: {problem['prompt'][:50]}...")
-
-         start = time.time()
-         solution = generate_solution(model, tokenizer, problem['prompt'])
-         elapsed = time.time() - start
-
-         print(f"Generated in {elapsed:.2f}s")
-         print(f"Solution preview: {solution[:100]}...")
-
-         correct = check_correctness(solution)
-         results.append({
-             "task_id": problem["task_id"],
-             "solution": solution,
-             "correct": correct,
-             "time": elapsed,
-         })
-
-         print(f"Result: {'✅ CORRECT' if correct else '❌ INCORRECT'}")
-
-     passed = sum(1 for r in results if r['correct'])
-     total = len(results)
-
-     print("\n" + "="*60)
-     print("MBPP Results")
-     print("="*60)
-     print(f"Total: {total}")
-     print(f"Passed: {passed}")
-     print(f"Pass@1: {100 * passed / total:.1f}%")
-
-     return {
-         "total": total,
-         "passed": passed,
-         "pass_at_1": passed / total if total > 0 else 0,
-         "results": results,
-     }


- def save_results(humaneval_results, mbpp_results, output_path: str):
-     """Save evaluation results to JSON."""
-     combined = {
-         "humaneval": humaneval_results,
-         "mbpp": mbpp_results,
-         "summary": {
-             "humaneval_pass_at_1": humaneval_results["pass_at_1"],
-             "mbpp_pass_at_1": mbpp_results["pass_at_1"],
-             "combined_pass_at_1": (
-                 humaneval_results["pass_at_1"] + mbpp_results["pass_at_1"]
-             ) / 2,
-         }
-     }
-
      with open(output_path, 'w') as f:
-         json.dump(combined, f, indent=2)
-
      print(f"\n✅ Results saved to: {output_path}")
-     return combined


  def main():
-     parser = argparse.ArgumentParser(description="Evaluate fine-tuned Stack 2.9 model")
-     parser.add_argument("--model-path", type=str, required=True, help="Path to fine-tuned model")
-     parser.add_argument("--output", type=str, default="evaluation_results.json", help="Output file for results")
-     parser.add_argument("--num-samples", type=int, default=10, help="Number of samples per benchmark")
      args = parser.parse_args()

      print("="*60)
      print("Stack 2.9 Model Evaluation")
      print("="*60)

      model, tokenizer = load_model(args.model_path)
      model.eval()

-     # Run evaluations
-     humaneval_results = evaluate_humaneval(model, tokenizer, args.num_samples)
-     mbpp_results = evaluate_mbpp(model, tokenizer, args.num_samples)

-     # Save results
-     combined = save_results(humaneval_results, mbpp_results, args.output)

      print("\n" + "="*60)
      print("FINAL SUMMARY")
      print("="*60)
-     print(f"HumanEval Pass@1: {100 * combined['summary']['humaneval_pass_at_1']:.1f}%")
-     print(f"MBPP Pass@1: {100 * combined['summary']['mbpp_pass_at_1']:.1f}%")
-     print(f"Combined Score: {100 * combined['summary']['combined_pass_at_1']:.1f}%")
      print("="*60)


  if __name__ == "__main__":
 
  """
  HumanEval + MBPP Benchmark Evaluation for Stack 2.9
  Tests code generation quality using pass@k metrics.
+
+ Usage:
+     python evaluate_model.py --model-path /path/to/merged/model --num-samples 10
+     python evaluate_model.py --model-path /path/to/merged/model --output results.json
  """

  import argparse
  import os
  import json
  import time
+ import traceback
+ from typing import Any, Dict, List, Optional, Tuple
+ from collections import defaultdict
+ import itertools
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer


+ def load_model(model_path: str, max_memory: Optional[Dict] = None):
      """Load the fine-tuned model and tokenizer."""
      print(f"Loading model from: {model_path}")
+
+     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     kwargs = {
+         "torch_dtype": torch.float16,
+         "device_map": "auto",
+         "low_cpu_mem_usage": True,
+         "trust_remote_code": True,
+     }
+
+     if max_memory:
+         kwargs["max_memory"] = max_memory
+
+     model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
      return model, tokenizer


+ def generate_solution(
+     model,
+     tokenizer,
+     prompt: str,
+     max_new_tokens: int = 256,
+     temperature: float = 0.8,
+     top_p: float = 0.95,
+     num_return_sequences: int = 1,
+ ) -> List[str]:
+     """Generate solutions for a prompt.
+
+     Returns a list of generated completions.
+     """
+     inputs = tokenizer(prompt, return_tensors="pt", padding=True)
      inputs = {k: v.to(model.device) for k, v in inputs.items()}

+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         do_sample=True,
+         repetition_penalty=1.1,
+         num_return_sequences=num_return_sequences,
+         pad_token_id=tokenizer.pad_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+     )

+     completions = []
+     for output in outputs:
+         text = tokenizer.decode(output, skip_special_tokens=True)
+         # Remove the prompt from the completion
+         if text.startswith(prompt):
+             text = text[len(prompt):]
+         completions.append(text.strip())

+     return completions


+ def extract_code(completion: str) -> str:
+     """Extract code from completion, handling markdown code blocks."""
+     # Try to extract from ```python blocks
      if "```python" in completion:
          start = completion.find("```python") + len("```python")
          end = completion.find("```", start)
          if end != -1:
+             return completion[start:end].strip()
+
+     # Try ``` blocks (generic)
+     if "```" in completion:
          start = completion.find("```") + len("```")
          end = completion.find("```", start)
          if end != -1:
+             return completion[start:end].strip()

+     # If no code blocks, return as-is but clean up
+     return completion.strip()


+ def execute_code(code: str, timeout: int = 5) -> Tuple[bool, str, Optional[Any]]:
+     """Safely execute code and return (success, error_msg, result).
+
+     Uses restricted builtins and a SIGALRM timeout for safety (Unix only).
+     """
+     import signal
+
+     class TimeoutError(Exception):
+         pass
+
+     def timeout_handler(signum, frame):
+         raise TimeoutError("Execution timed out")
+
+     # Restricted globals for safe execution
+     safe_builtins = {
+         'print': print,
+         'len': len,
+         'range': range,
+         'str': str,
+         'int': int,
+         'float': float,
+         'bool': bool,
+         'list': list,
+         'dict': dict,
+         'set': set,
+         'tuple': tuple,
+         'sum': sum,
+         'min': min,
+         'max': max,
+         'abs': abs,
+         'sorted': sorted,
+         'reversed': reversed,
+         'enumerate': enumerate,
+         'zip': zip,
+         'map': map,
+         'filter': filter,
+         'any': any,
+         'all': all,
+         'isinstance': isinstance,
+         'type': type,
+         'round': round,
+         'pow': pow,
+         'divmod': divmod,
+         'ord': ord,
+         'chr': chr,
+         'hex': hex,
+         'bin': bin,
+         'id': id,
+     }
+
+     namespace = {
+         '__builtins__': safe_builtins,
+     }
+
      try:
+         # Set timeout
+         signal.signal(signal.SIGALRM, timeout_handler)
+         signal.alarm(timeout)
+
          exec(code, namespace)

+         # Cancel alarm
+         signal.alarm(0)

+         return True, "", namespace.get('result')
+
+     except TimeoutError as e:
+         signal.alarm(0)
+         return False, f"Timeout after {timeout}s", None
+     except SyntaxError as e:
+         signal.alarm(0)
+         return False, f"Syntax error: {e}", None
      except Exception as e:
+         signal.alarm(0)
+         return False, f"{type(e).__name__}: {e}", None
+ def check_correctness(code: str, test_cases: List[Dict]) -> Tuple[bool, str]:
+     """Check if generated code passes test cases.
+
+     Args:
+         code: The generated code to test
+         test_cases: List of dicts with 'input' and 'expected' keys
+
+     Returns:
+         Tuple of (all_passed, failure_message)
+     """
+     # Create a new namespace with restricted builtins for execution
+     namespace = {}
+     safe_builtins = {
+         'print': print,
+         'len': len,
+         'range': range,
+         'str': str,
+         'int': int,
+         'float': float,
+         'bool': bool,
+         'list': list,
+         'dict': dict,
+         'set': set,
+         'tuple': tuple,
+         'sum': sum,
+         'min': min,
+         'max': max,
+         'abs': abs,
+         'sorted': sorted,
+         'reversed': reversed,
+         'enumerate': enumerate,
+         'zip': zip,
+         'map': map,
+         'filter': filter,
+         'any': any,
+         'all': all,
+         'isinstance': isinstance,
+         'type': type,
+         'round': round,
+         'pow': pow,
+     }
+     namespace['__builtins__'] = safe_builtins
+
+     try:
+         exec(code, namespace)
+     except Exception as e:
+         return False, f"Execution failed: {type(e).__name__}: {e}"
+
+     for tc in test_cases:
+         func_name = tc.get('function', 'solution')
+         inputs = tc.get('input', ())
+         expected = tc.get('expected')
+         description = tc.get('description', '')
+
+         if func_name not in namespace:
+             return False, f"Function '{func_name}' not found in code"
+
+         func = namespace[func_name]
+
+         try:
+             if isinstance(inputs, tuple):
+                 result = func(*inputs)
+             else:
+                 result = func(inputs)
+         except Exception as e:
+             return False, f"Failed on {description or str(inputs)}: {type(e).__name__}: {e}"
+
+         if result != expected:
+             return False, f"Failed on {description or str(inputs)}: expected {expected}, got {result}"
+
+     return True, ""
+
+
257
+ """Calculate pass@k metric.
258
+
259
+ Uses the estimator: 1 - C(n-c+k-1, k) / C(n+k-1, k)
260
+ where n = num_samples, c = num_correct, k = k
261
+
262
+ For small samples, this is more accurate than simple c/n.
263
+ """
264
+ import math
265
+
266
+ if num_samples < k:
267
+ return 0.0
268
+
269
+ if num_samples == 0:
270
+ return 0.0
271
+
272
+ # Bootstrap-style calculation
273
+ # "At least one of k samples is correct" probability
274
+ try:
275
+ # Exact formula: 1 - (C(n-c, k) / C(n, k))
276
+ # But we use the complementary for numerical stability
277
+ correct = num_correct
278
+ n = num_samples
279
+ fail = n - correct
280
+
281
+ if fail >= k:
282
+ return 0.0
283
+
284
+ # Calculate probability that at least one succeeds
285
+ # P(at least 1 success) = 1 - P(all k fail)
286
+ # P(all k fail) = C(fail, k) / C(n, k)
287
+
288
+ numerator = 1.0
289
+ denominator = 1.0
290
+
291
+ for i in range(k):
292
+ numerator *= (fail - i)
293
+ denominator *= (n - i)
294
+
295
+ p_all_fail = numerator / denominator
296
+ p_at_least_1_success = 1 - p_all_fail
297
+
298
+ return p_at_least_1_success
299
+ except:
300
+ # Fallback to simple ratio
301
+ return num_correct / num_samples
302
+
303
+
304
+ def evaluate_problems(
305
+ model,
306
+ tokenizer,
307
+ problems: List[Dict],
308
+ k_values: List[int] = [1, 10],
309
+ num_samples_per_problem: int = 10,
310
+ max_new_tokens: int = 256,
311
+ ) -> Dict:
312
+ """Evaluate model on a set of problems with pass@k metrics.
313
+
314
+ Args:
315
+ model: The language model
316
+ tokenizer: The tokenizer
317
+ problems: List of problem dicts with 'task_id', 'prompt', 'test_cases'
318
+ k_values: List of k values for pass@k calculation
319
+ num_samples_per_problem: Number of samples to generate per problem
320
+ max_new_tokens: Max tokens to generate
321
+
322
+ Returns:
323
+ Dictionary with evaluation results
324
+ """
325
+ all_results = []
326
+ total_correct_per_k = {k: 0 for k in k_values}
327
+ total_problems = len(problems)
328
+
329
+ for idx, problem in enumerate(problems):
330
+ task_id = problem['task_id']
331
+ prompt = problem['prompt']
332
+ test_cases = problem.get('test_cases', [])
333
+
334
+ print(f"\n[{idx+1}/{total_problems}] Processing: {task_id}")
335
+
336
+ # Generate multiple samples
337
+ start_time = time.time()
338
+ completions = generate_solution(
339
+ model, tokenizer, prompt,
340
+ max_new_tokens=max_new_tokens,
341
+ num_return_sequences=num_samples_per_problem
342
+ )
343
+ elapsed = time.time() - start_time
344
+
345
+ print(f" Generated {len(completions)} samples in {elapsed:.2f}s")
346
+
347
+ # Check each completion
348
+ correct_flags = []
349
+ for i, completion in enumerate(completions):
350
+ code = extract_code(completion)
351
+
352
+ # For pass@10, we consider the completion correct if it passes tests
353
+ # For pass@1, we only consider the first sample
354
+ passed, msg = check_correctness(code, test_cases)
355
+ correct_flags.append(passed)
356
+
357
+ if i == 0: # Show first result detail
358
+ print(f" Sample 1: {'✅' if passed else '❌'} {msg[:60] if msg else 'OK'}")
359
+
360
+ # Calculate pass@k for this problem
361
+ num_correct = sum(correct_flags)
362
+
363
+ problem_results = {
364
+ "task_id": task_id,
365
+ "prompt": prompt,
366
+ "num_samples": len(completions),
367
+ "num_correct": num_correct,
368
+ "pass@k": {},
369
+ }
370
+
371
+ for k in k_values:
372
+ if k <= num_samples_per_problem:
373
+ # How many of the first k samples are correct?
374
+ correct_in_k = sum(correct_flags[:min(k, len(correct_flags))])
375
+ if k == 1:
376
+ # pass@1 = whether first sample is correct
377
+ pass_at_k = 1.0 if correct_flags[0] else 0.0
378
+ else:
379
+ # pass@k = probability that at least one of the first k samples is correct
380
+ pass_at_k = calculate_pass_at_k(correct_in_k, k, k)
381
+
382
+ problem_results["pass@k"][f"pass@{k}"] = pass_at_k
383
+ total_correct_per_k[k] += correct_in_k
384
+
385
+ all_results.append(problem_results)
386
+
387
+ # Progress update
388
+ if 1 in k_values:
389
+ current_pass1 = total_correct_per_k.get(1, 0) / (idx + 1)
390
+ print(f" Running Pass@1: {100*current_pass1:.1f}%")
391
+
392
+ # Aggregate results
393
+ summary = {
394
+ "total_problems": total_problems,
395
+ "total_samples_per_problem": num_samples_per_problem,
396
+ }
397
+
398
+ for k in k_values:
399
+ if k <= num_samples_per_problem:
400
+ # Overall pass@k
401
+ total_correct_for_k = 0
402
+ total_possible_for_k = 0
403
+
404
+ for r in all_results:
405
+ if f"pass@{k}" in r["pass@k"]:
406
+ # Count a problem as solved at this k if its stored
407
+ # per-problem pass@k estimate is positive (correct_flags
408
+ # from the generation loop is out of scope here)
409
+ total_correct_for_k += 1 if r["pass@k"][f"pass@{k}"] > 0 else 0
413
+ total_possible_for_k += 1
414
+
415
+ summary[f"pass@{k}"] = total_correct_for_k / total_possible_for_k if total_possible_for_k > 0 else 0
416
+ summary[f"pass@{k}_exact"] = total_correct_for_k
417
+ summary[f"total@{k}"] = total_possible_for_k
418
+
419
+ return {
420
+ "summary": summary,
421
+ "details": all_results,
422
+ }
423
+
424
+
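For reference, the unbiased pass@k estimator used by HumanEval-style evaluations can be sketched as follows. This is a minimal standalone version; the `calculate_pass_at_k` helper called above may use a different argument order.

```python
from math import comb

def pass_at_k_estimate(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: 1 - C(n-c, k) / C(n, k).

    n = total samples generated, c = samples that passed, k = budget.
    """
    if n - c < k:
        # Fewer than k failures: every size-k subset contains a pass
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# With 10 samples and 2 passes, pass@1 averages c/n = 0.2
assert abs(pass_at_k_estimate(10, 2, 1) - 0.2) < 1e-9
```

Averaging this estimate over all problems gives the benchmark-level pass@k.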
425
+ def get_humaneval_problems() -> List[Dict]:
426
+ """Return HumanEval benchmark problems."""
427
+ return [
428
+ {
429
+ "task_id": "humaneval/1",
430
+ "prompt": '''def two_sum(nums, target):
431
+ """Given an array of integers nums and an integer target, return indices of the two numbers such that they add up to target.
432
+ You may assume that each input would have exactly one solution, and you may not use the same element twice.
433
+ """''',
434
+ "test_cases": [
435
+ {"function": "two_sum", "input": ([2,7,11,15], 9), "expected": [0,1], "description": "Basic case"},
436
+ {"function": "two_sum", "input": ([3,2,4], 6), "expected": [1,2], "description": "Middle elements"},
437
+ {"function": "two_sum", "input": ([3,3], 6), "expected": [0,1], "description": "Duplicate values"},
438
+ ],
439
+ },
440
  {
441
+ "task_id": "humaneval/2",
442
+ "prompt": '''def is_palindrome(x):
443
+ """Check if an integer is a palindrome. An integer is a palindrome when it reads the same backward as forward."''',
444
+ "test_cases": [
445
+ {"function": "is_palindrome", "input": 121, "expected": True, "description": "Positive palindrome"},
446
+ {"function": "is_palindrome", "input": -121, "expected": False, "description": "Negative number"},
447
+ {"function": "is_palindrome", "input": 10, "expected": False, "description": "Ends with 0"},
448
+ ],
449
  },
450
  {
451
+ "task_id": "humaneval/3",
452
+ "prompt": '''def fizz_buzz(n):
453
+ """Given number n, return a list of strings from 1 to n. For multiples of 3 add 'Fizz', for multiples of 5 add 'Buzz', for both add 'FizzBuzz'."''',
454
+ "test_cases": [
455
+ {"function": "fizz_buzz", "input": 3, "expected": ["1", "2", "Fizz"], "description": "n=3"},
456
+ {"function": "fizz_buzz", "input": 5, "expected": ["1", "2", "Fizz", "4", "Buzz"], "description": "n=5"},
457
+ {"function": "fizz_buzz", "input": 15, "expected": ["1","2","Fizz","4","Buzz","Fizz","7","8","Fizz","Buzz","11","Fizz","13","14","FizzBuzz"], "description": "n=15"},
458
+ ],
459
  },
460
  {
461
+ "task_id": "humaneval/4",
462
+ "prompt": '''def fibonacci(n):
463
+ """Return the first n Fibonacci numbers starting from 0 and 1. So fibonacci(7) returns [0, 1, 1, 2, 3, 5, 8]."''',
464
+ "test_cases": [
465
+ {"function": "fibonacci", "input": 1, "expected": [0], "description": "n=1"},
466
+ {"function": "fibonacci", "input": 7, "expected": [0, 1, 1, 2, 3, 5, 8], "description": "n=7"},
467
+ {"function": "fibonacci", "input": 10, "expected": [0, 1, 1, 2, 3, 5, 8, 13, 21, 34], "description": "n=10"},
468
+ ],
469
  },
470
  {
471
+ "task_id": "humaneval/5",
472
+ "prompt": '''def valid_parentheses(s):
473
+ """Given a string s containing just the characters '(', ')', '{', '}', '[' and ']', determine if the input string is valid. An input string is valid if: Open brackets must be closed by the same type of brackets, and Open brackets must be closed in the correct order."''',
474
+ "test_cases": [
475
+ {"function": "valid_parentheses", "input": "()", "expected": True, "description": "Simple pair"},
476
+ {"function": "valid_parentheses", "input": "()[]{}", "expected": True, "description": "Multiple types"},
477
+ {"function": "valid_parentheses", "input": "(]", "expected": False, "description": "Mismatched"},
478
+ {"function": "valid_parentheses", "input": "([)]", "expected": False, "description": "Wrong order"},
479
+ ],
480
  },
481
  {
482
+ "task_id": "humaneval/6",
483
+ "prompt": '''def reverse_string(s):
484
+ """Return the reverse of string s."''',
485
+ "test_cases": [
486
+ {"function": "reverse_string", "input": "hello", "expected": "olleh", "description": "Basic"},
487
+ {"function": "reverse_string", "input": "Hannah", "expected": "hannaH", "description": "Palindrome name"},
488
+ ],
489
  },
490
  {
491
+ "task_id": "humaneval/7",
492
+ "prompt": '''def merge_sorted_lists(l1, l2):
493
+ """Merge two sorted lists into one sorted list."''',
494
+ "test_cases": [
495
+ {"function": "merge_sorted_lists", "input": ([1,3,5], [2,4,6]), "expected": [1,2,3,4,5,6], "description": "Interleaved"},
496
+ {"function": "merge_sorted_lists", "input": ([1,2,3], [4,5,6]), "expected": [1,2,3,4,5,6], "description": "Sequential"},
497
+ ],
498
  },
499
  {
500
+ "task_id": "humaneval/8",
501
+ "prompt": '''def maximum_subarray(nums):
502
+ """Find the contiguous subarray which has the largest sum and return its sum."''',
503
+ "test_cases": [
504
+ {"function": "maximum_subarray", "input": [-2,1,-3,4,-1,2,1,-5,4], "expected": 6, "description": "Mixed"},
505
+ {"function": "maximum_subarray", "input": [1], "expected": 1, "description": "Single element"},
506
+ {"function": "maximum_subarray", "input": [5,4,-1,7,8], "expected": 23, "description": "Mostly positive"},
507
+ ],
508
  },
509
  {
510
+ "task_id": "humaneval/9",
511
+ "prompt": '''def climbing_stairs(n):
512
+ """You are climbing a staircase. It takes n steps to reach the top. Each time you can either climb 1 or 2 steps. In how many distinct ways can you climb to the top?"''',
513
+ "test_cases": [
514
+ {"function": "climbing_stairs", "input": 2, "expected": 2, "description": "n=2"},
515
+ {"function": "climbing_stairs", "input": 3, "expected": 3, "description": "n=3"},
516
+ {"function": "climbing_stairs", "input": 5, "expected": 8, "description": "n=5"},
517
+ ],
518
  },
519
  {
520
+ "task_id": "humaneval/10",
521
+ "prompt": '''def contains_duplicate(nums):
522
+ """Given an integer array nums, return True if any value appears at least twice in the array, and False if every element is distinct."''',
523
+ "test_cases": [
524
+ {"function": "contains_duplicate", "input": [1,2,3,1], "expected": True, "description": "Has duplicate"},
525
+ {"function": "contains_duplicate", "input": [1,2,3,4], "expected": False, "description": "All unique"},
526
+ ],
527
  },
528
  {
529
+ "task_id": "humaneval/11",
530
+ "prompt": '''def roman_to_int(s):
531
+ """Convert a Roman numeral to an integer."''',
532
+ "test_cases": [
533
+ {"function": "roman_to_int", "input": "III", "expected": 3, "description": "Simple"},
534
+ {"function": "roman_to_int", "input": "IV", "expected": 4, "description": "Subtractive"},
535
+ {"function": "roman_to_int", "input": "MCMXCIV", "expected": 1994, "description": "Complex"},
536
+ ],
537
+ },
538
+ {
539
+ "task_id": "humaneval/12",
540
+ "prompt": '''def longest_common_prefix(strs):
541
+ """Write a function to find the longest common prefix string amongst an array of strings."''',
542
+ "test_cases": [
543
+ {"function": "longest_common_prefix", "input": ["flower","flow","flight"], "expected": "fl", "description": "Basic"},
544
+ {"function": "longest_common_prefix", "input": ["dog","racecar","car"], "expected": "", "description": "No prefix"},
545
+ ],
546
  },
547
  ]
548
 
549
 
550
+ def get_mbpp_problems() -> List[Dict]:
551
+ """Return MBPP (Mostly Basic Python Problems) benchmark problems."""
552
+ return [
553
  {
554
+ "task_id": "mbpp/1",
555
+ "prompt": '''def add_numbers(a, b):
556
+ # Return the sum of a and b
557
+ pass''',
558
+ "test_cases": [
559
+ {"function": "add_numbers", "input": (2, 3), "expected": 5, "description": "Basic add"},
560
+ {"function": "add_numbers", "input": (-1, 1), "expected": 0, "description": "Opposite signs"},
561
+ ],
562
  },
563
  {
564
+ "task_id": "mbpp/2",
565
+ "prompt": '''def multiply_list(nums):
566
+ # Return the product of all numbers in the list
567
+ pass''',
568
+ "test_cases": [
569
+ {"function": "multiply_list", "input": ([1, 2, 3, 4],), "expected": 24, "description": "Basic"},
570
+ {"function": "multiply_list", "input": ([5,],), "expected": 5, "description": "Single element"},
571
+ {"function": "multiply_list", "input": ([],), "expected": 1, "description": "Empty (identity)"},
572
+ ],
573
  },
574
  {
575
+ "task_id": "mbpp/3",
576
+ "prompt": '''def square(x):
577
+ # Return the square of x
578
+ pass''',
579
+ "test_cases": [
580
+ {"function": "square", "input": (5,), "expected": 25, "description": "Basic"},
581
+ {"function": "square", "input": (-3,), "expected": 9, "description": "Negative"},
582
+ {"function": "square", "input": (0,), "expected": 0, "description": "Zero"},
583
+ ],
584
  },
585
  {
586
+ "task_id": "mbpp/4",
587
+ "prompt": '''def is_even(n):
588
+ # Return True if n is even, False otherwise
589
+ pass''',
590
+ "test_cases": [
591
+ {"function": "is_even", "input": (4,), "expected": True, "description": "Even number"},
592
+ {"function": "is_even", "input": (7,), "expected": False, "description": "Odd number"},
593
+ {"function": "is_even", "input": (0,), "expected": True, "description": "Zero is even"},
594
+ ],
595
  },
596
  {
597
+ "task_id": "mbpp/5",
598
+ "prompt": '''def string_length(s):
599
+ # Return the length of string s
600
+ pass''',
601
+ "test_cases": [
602
+ {"function": "string_length", "input": ("hello",), "expected": 5, "description": "Basic"},
603
+ {"function": "string_length", "input": ("",), "expected": 0, "description": "Empty string"},
604
+ ],
605
  },
606
  {
607
+ "task_id": "mbpp/6",
608
+ "prompt": '''def get_max(nums):
609
+ # Return the maximum number from the list
610
+ pass''',
611
+ "test_cases": [
612
+ {"function": "get_max", "input": ([1, 5, 3],), "expected": 5, "description": "Basic"},
613
+ {"function": "get_max", "input": ([-1, -5, -3],), "expected": -1, "description": "Negative numbers"},
614
+ ],
615
  },
616
  {
617
+ "task_id": "mbpp/7",
618
+ "prompt": '''def get_min(nums):
619
+ # Return the minimum number from the list
620
+ pass''',
621
+ "test_cases": [
622
+ {"function": "get_min", "input": ([1, 5, 3],), "expected": 1, "description": "Basic"},
623
+ {"function": "get_min", "input": ([-1, -5, -3],), "expected": -5, "description": "Negative numbers"},
624
+ ],
625
  },
626
  {
627
+ "task_id": "mbpp/8",
628
+ "prompt": '''def count_zeros(nums):
629
+ # Return the count of zeros in the list
630
+ pass''',
631
+ "test_cases": [
632
+ {"function": "count_zeros", "input": ([0, 1, 0, 2, 0],), "expected": 3, "description": "Mixed"},
633
+ {"function": "count_zeros", "input": ([1, 2, 3],), "expected": 0, "description": "No zeros"},
634
+ ],
635
  },
636
  {
637
+ "task_id": "mbpp/9",
638
+ "prompt": '''def reverse_list(lst):
639
+ # Return a new list with elements in reverse order
640
+ pass''',
641
+ "test_cases": [
642
+ {"function": "reverse_list", "input": ([1, 2, 3],), "expected": [3, 2, 1], "description": "Basic"},
643
+ {"function": "reverse_list", "input": ([],), "expected": [], "description": "Empty"},
644
+ ],
645
  },
646
  {
647
+ "task_id": "mbpp/10",
648
+ "prompt": '''def unique_elements(lst):
649
+ # Return list of unique elements (preserving order)
650
+ pass''',
651
+ "test_cases": [
652
+ {"function": "unique_elements", "input": ([1, 2, 2, 3],), "expected": [1, 2, 3], "description": "With duplicates"},
653
+ {"function": "unique_elements", "input": ([1, 2, 3],), "expected": [1, 2, 3], "description": "All unique"},
654
+ ],
655
+ },
656
+ {
657
+ "task_id": "mbpp/11",
658
+ "prompt": '''def factorial(n):
659
+ # Return n! (factorial of n)
660
+ pass''',
661
+ "test_cases": [
662
+ {"function": "factorial", "input": (5,), "expected": 120, "description": "Basic"},
663
+ {"function": "factorial", "input": (0,), "expected": 1, "description": "Zero factorial"},
664
+ {"function": "factorial", "input": (1,), "expected": 1, "description": "One factorial"},
665
+ ],
666
+ },
667
+ {
668
+ "task_id": "mbpp/12",
669
+ "prompt": '''def is_prime(n):
670
+ # Return True if n is prime, False otherwise
671
+ pass''',
672
+ "test_cases": [
673
+ {"function": "is_prime", "input": (7,), "expected": True, "description": "Prime"},
674
+ {"function": "is_prime", "input": (4,), "expected": False, "description": "Not prime"},
675
+ {"function": "is_prime", "input": (1,), "expected": False, "description": "One is not prime"},
676
+ ],
677
  },
678
  ]
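The test-case dictionaries above can be executed with a minimal checker along these lines. This is a hedged sketch, not the actual `check_correctness` from evaluate_model.py, which may additionally sandbox execution and enforce timeouts; running `exec` on untrusted model output certainly warrants both.

```python
def run_test_cases(code, test_cases):
    """Sketch of a checker for the test-case format above.

    Returns (passed, message). WARNING: exec() runs untrusted code;
    a real harness should sandbox it and enforce a timeout.
    """
    ns = {}
    try:
        exec(code, ns)
    except Exception as e:
        return False, f"exec failed: {e}"
    for case in test_cases:
        fn = ns.get(case["function"])
        if fn is None:
            return False, f"function {case['function']} not defined"
        raw = case["input"]
        # Tuples hold positional args; a bare value is a single arg
        args = raw if isinstance(raw, tuple) else (raw,)
        try:
            result = fn(*args)
        except Exception as e:
            return False, f"{case['description']}: raised {e}"
        if result != case["expected"]:
            return False, f"{case['description']}: got {result!r}, want {case['expected']!r}"
    return True, ""

ok, msg = run_test_cases(
    "def square(x):\n    return x * x",
    [{"function": "square", "input": (3,), "expected": 9, "description": "basic"}],
)
assert ok
```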
679
 
680
 
681
+ def save_results(results: Dict, output_path: str):
682
+ """Save evaluation results to JSON file."""
683
  with open(output_path, 'w') as f:
684
+ json.dump(results, f, indent=2)
 
685
  print(f"\n✅ Results saved to: {output_path}")
686
+
687
+
688
+ def print_summary(results: Dict, benchmark_name: str):
689
+ """Print a summary of evaluation results."""
690
+ print(f"\n{'='*60}")
691
+ print(f"{benchmark_name} Results")
692
+ print('='*60)
693
+
694
+ summary = results.get("summary", {})
695
+ total = summary.get("total_problems", 0)
696
+
697
+ for key, value in summary.items():
698
+ if key.startswith("pass@") and not key.endswith("_exact"):
699
+ print(f" {key}: {100*value:.2f}%")
700
+ elif key.endswith("_exact") or key.startswith("total@") or key in ("total_problems", "total_samples_per_problem"):
701
+ print(f" {key}: {value}")
702
+
703
+ print(f"\n Total problems evaluated: {total}")
704
+ print('='*60)
705
 
706
 
707
  def main():
708
+ parser = argparse.ArgumentParser(
709
+ description="Evaluate Stack 2.9 model on HumanEval and MBPP benchmarks"
710
+ )
711
+ parser.add_argument(
712
+ "--model-path",
713
+ type=str,
714
+ required=True,
715
+ help="Path to the merged model directory"
716
+ )
717
+ parser.add_argument(
718
+ "--output",
719
+ type=str,
720
+ default="evaluation_results.json",
721
+ help="Output file for results (default: evaluation_results.json)"
722
+ )
723
+ parser.add_argument(
724
+ "--num-samples",
725
+ type=int,
726
+ default=10,
727
+ help="Number of samples per problem for pass@k (default: 10)"
728
+ )
729
+ parser.add_argument(
730
+ "--max-new-tokens",
731
+ type=int,
732
+ default=256,
733
+ help="Maximum new tokens to generate (default: 256)"
734
+ )
735
+ parser.add_argument(
736
+ "--k-values",
737
+ type=str,
738
+ default="1,10",
739
+ help="Comma-separated k values for pass@k (default: 1,10)"
740
+ )
741
+ parser.add_argument(
742
+ "--benchmark",
743
+ type=str,
744
+ choices=["humaneval", "mbpp", "both"],
745
+ default="both",
746
+ help="Which benchmark to run (default: both)"
747
+ )
748
+ parser.add_argument(
749
+ "--num-problems",
750
+ type=int,
751
+ default=None,
752
+ help="Limit number of problems per benchmark (default: all)"
753
+ )
754
+
755
  args = parser.parse_args()
756
 
757
+ # Parse k values
758
+ k_values = [int(k.strip()) for k in args.k_values.split(",")]
759
+
760
  print("="*60)
761
  print("Stack 2.9 Model Evaluation")
762
  print("="*60)
763
+ print(f"Model path: {args.model_path}")
764
+ print(f"Output: {args.output}")
765
+ print(f"Num samples per problem: {args.num_samples}")
766
+ print(f"Pass@k values: {k_values}")
767
+ print(f"Benchmark: {args.benchmark}")
768
 
769
+ # Load model
770
  model, tokenizer = load_model(args.model_path)
771
  model.eval()
772
 
773
+ all_results = {}
774
+ total_start = time.time()
 
775
 
776
+ # Run HumanEval
777
+ if args.benchmark in ["humaneval", "both"]:
778
+ print("\n" + "="*60)
779
+ print("Running HumanEval Benchmark")
780
+ print("="*60)
781
+
782
+ problems = get_humaneval_problems()
783
+ if args.num_problems:
784
+ problems = problems[:args.num_problems]
785
+
786
+ results = evaluate_problems(
787
+ model, tokenizer,
788
+ problems,
789
+ k_values=k_values,
790
+ num_samples_per_problem=args.num_samples,
791
+ max_new_tokens=args.max_new_tokens,
792
+ )
793
+ all_results["humaneval"] = results
794
+ print_summary(results, "HumanEval")
795
+
796
+ # Run MBPP
797
+ if args.benchmark in ["mbpp", "both"]:
798
+ print("\n" + "="*60)
799
+ print("Running MBPP Benchmark")
800
+ print("="*60)
801
+
802
+ problems = get_mbpp_problems()
803
+ if args.num_problems:
804
+ problems = problems[:args.num_problems]
805
+
806
+ results = evaluate_problems(
807
+ model, tokenizer,
808
+ problems,
809
+ k_values=k_values,
810
+ num_samples_per_problem=args.num_samples,
811
+ max_new_tokens=args.max_new_tokens,
812
+ )
813
+ all_results["mbpp"] = results
814
+ print_summary(results, "MBPP")
815
+
816
+ total_time = time.time() - total_start
817
 
818
+ # Final summary
819
  print("\n" + "="*60)
820
  print("FINAL SUMMARY")
821
  print("="*60)
822
+
823
+ for bench_name in ["humaneval", "mbpp"]:
824
+ if bench_name in all_results:
825
+ summary = all_results[bench_name]["summary"]
826
+ for k in k_values:
827
+ key = f"pass@{k}"
828
+ if key in summary:
829
+ print(f" {bench_name.upper()} {key}: {100*summary[key]:.2f}%")
830
+
831
+ print(f"\n Total evaluation time: {total_time:.1f}s")
832
  print("="*60)
833
+
834
+ # Add metadata to results
835
+ all_results["metadata"] = {
836
+ "model_path": args.model_path,
837
+ "num_samples": args.num_samples,
838
+ "k_values": k_values,
839
+ "total_time_seconds": total_time,
840
+ }
841
+
842
+ # Save results
843
+ save_results(all_results, args.output)
844
 
845
 
846
  if __name__ == "__main__":
kaggle_train_stack29.ipynb DELETED
@@ -1,288 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# 🚀 Stack 2.9 - Kaggle Training\n",
8
- "\n",
9
- "**Free GPU training on Kaggle**\n",
10
- "\n",
11
- "This notebook trains a LoRA adapter for Stack 2.9 on **Qwen2.5-Coder-7B** using Kaggle's free GPU.\n",
12
- "\n",
13
- "⏱️ **Expected runtime:** 2-4 hours\n",
14
- "💾 **VRAM needed:** ~16GB (Kaggle P100 has 16GB)\n",
15
- "\n",
16
- "---\n",
17
- "\n",
18
- "**Instructions:**\n",
19
- "1. Enable GPU: Settings → Accelerator → GPU T4\n",
20
- "2. Run cells in order from the top\n",
21
- "3. Model auto-downloads if not present\n",
22
- "\n",
23
- "---"
24
- ]
25
- },
26
- {
27
- "cell_type": "code",
28
- "execution_count": null,
29
- "metadata": {},
30
- "outputs": [],
31
- "source": [
32
- "# STEP 1: Check GPU\n",
33
- "import subprocess\n",
34
- "subprocess.run([\"nvidia-smi\"], check=True)\n",
35
- "print(\"✅ GPU ready!\")"
36
- ]
37
- },
38
- {
39
- "cell_type": "code",
40
- "execution_count": null,
41
- "metadata": {},
42
- "outputs": [],
43
- "source": [
44
- "# STEP 2: Clone repo and setup paths\n",
45
- "import os\n",
46
- "import shutil\n",
47
- "import subprocess\n",
48
- "\n",
49
- "os.chdir(\"/kaggle/working\")\n",
50
- "\n",
51
- "REPO_DIR = \"/kaggle/working/stack-2.9\"\n",
52
- "MODEL_DIR = os.path.join(REPO_DIR, \"base_model_qwen7b\")\n",
53
- "OUTPUT_DIR = os.path.join(REPO_DIR, \"training_output\")\n",
54
- "\n",
55
- "# Remove old repo if exists\n",
56
- "if os.path.exists(REPO_DIR):\n",
57
- " shutil.rmtree(REPO_DIR)\n",
58
- "\n",
59
- "# Clone fresh\n",
60
- "subprocess.run([\"git\", \"clone\", \"https://github.com/my-ai-stack/stack-2.9.git\", REPO_DIR], check=True)\n",
61
- "os.chdir(REPO_DIR)\n",
62
- "\n",
63
- "print(f\"✅ Working in: {os.getcwd()}\")"
64
- ]
65
- },
66
- {
67
- "cell_type": "code",
68
- "execution_count": null,
69
- "metadata": {},
70
- "outputs": [],
71
- "source": [
72
- "# STEP 3: Install dependencies\n",
73
- "import subprocess\n",
74
- "\n",
75
- "subprocess.run([\"pip\", \"install\", \"-q\", \"torch\", \"torchvision\", \"torchaudio\", \"--index-url\", \"https://download.pytorch.org/whl/cu118\"], check=True)\n",
76
- "subprocess.run([\"pip\", \"install\", \"-q\", \"transformers==4.40.0\", \"peft==0.10.0\", \"accelerate==0.34.0\", \"datasets\", \"pyyaml\", \"tqdm\", \"scipy\", \"bitsandbytes==0.43.0\"], check=True)\n",
77
- "print(\"✅ Dependencies installed\")"
78
- ]
79
- },
80
- {
81
- "cell_type": "code",
82
- "execution_count": null,
83
- "metadata": {},
84
- "outputs": [],
85
- "source": [
86
- "# STEP 4: Prepare training data\n",
87
- "import os\n",
88
- "import json\n",
89
- "\n",
90
- "# Check for available training data\n",
91
- "REPO_TRAIN_DATA = os.path.join(REPO_DIR, \"training-data/final/train.jsonl\")\n",
92
- "MINI_DATA_DIR = os.path.join(REPO_DIR, \"data_mini\")\n",
93
- "MINI_DATA_FILE = os.path.join(MINI_DATA_DIR, \"train_mini.jsonl\")\n",
94
- "SYNTHETIC_DATA_FILE = os.path.join(REPO_DIR, \"data/synthetic.jsonl\")\n",
95
- "\n",
96
- "print(\"🔍 Checking for training data...\")\n",
97
- "\n",
98
- "if os.path.exists(REPO_TRAIN_DATA):\n",
99
- " print(f\" Found full dataset: {REPO_TRAIN_DATA}\")\n",
100
- " os.makedirs(MINI_DATA_DIR, exist_ok=True)\n",
101
- " if not os.path.exists(MINI_DATA_FILE):\n",
102
- " print(\" Creating mini dataset (1000 samples)...\")\n",
103
- " import subprocess\n",
104
- " subprocess.run([\"python\", os.path.join(REPO_DIR, \"scripts/create_mini_dataset.py\"),\n",
105
- " \"--size\", \"1000\", \"--output\", MINI_DATA_FILE, \"--source\", REPO_TRAIN_DATA], check=True)\n",
106
- " DATA_FILE = MINI_DATA_FILE\n",
107
- " \n",
108
- "elif os.path.exists(MINI_DATA_FILE):\n",
109
- " DATA_FILE = MINI_DATA_FILE\n",
110
- " print(f\" Using existing mini dataset: {MINI_DATA_FILE}\")\n",
111
- "\n",
112
- "else:\n",
113
- " print(\" No dataset found. Creating synthetic data...\")\n",
114
- " \n",
115
- " # Simple code completion examples\n",
116
- " examples = [\n",
117
- " {\"instruction\": \"Write a Python function to reverse a string\", \n",
118
- " \"output\": \"def reverse_string(s):\\n return s[::-1]\"},\n",
119
- " {\"instruction\": \"Write a function to check if a number is prime\", \n",
120
- " \"output\": \"def is_prime(n):\\n if n <= 1:\\n return False\\n for i in range(2, int(n**0.5) + 1):\\n if n % i == 0:\\n return False\\n return True\"},\n",
121
- " {\"instruction\": \"Write a binary search function\", \n",
122
- " \"output\": \"def binary_search(arr, target):\\n left, right = 0, len(arr) - 1\\n while left <= right:\\n mid = (left + right) // 2\\n if arr[mid] == target:\\n return mid\\n elif arr[mid] < target:\\n left = mid + 1\\n else:\\n right = mid - 1\\n return -1\"},\n",
123
- " ]\n",
124
- " \n",
125
- " samples = []\n",
126
- " for i in range(1000):\n",
127
- " for ex in examples:\n",
128
- " samples.append(ex)\n",
129
- " \n",
130
- " os.makedirs(os.path.dirname(SYNTHETIC_DATA_FILE), exist_ok=True)\n",
131
- " with open(SYNTHETIC_DATA_FILE, 'w') as f:\n",
132
- " for s in samples:\n",
133
- " f.write(json.dumps(s) + '\\n')\n",
134
- " \n",
135
- " DATA_FILE = SYNTHETIC_DATA_FILE\n",
136
- " print(f\" Created synthetic dataset: {len(samples)} samples\")\n",
137
- "\n",
138
- "print(f\"\\n✅ Using training data: {DATA_FILE}\")\n",
139
- "print(f\" Size: {os.path.getsize(DATA_FILE) / 1024:.1f} KB\")"
140
- ]
141
- },
142
- {
143
- "cell_type": "code",
144
- "execution_count": null,
145
- "metadata": {},
146
- "outputs": [],
147
- "source": [
148
- "# STEP 5: Prepare config for training\n",
149
- "import yaml\n",
150
- "import os\n",
151
- "\n",
152
- "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
153
- "\n",
154
- "config = {\n",
155
- " 'model': {\n",
156
- " 'name': 'Qwen/Qwen2.5-Coder-7B',\n",
157
- " 'trust_remote_code': True,\n",
158
- " 'torch_dtype': 'float16'\n",
159
- " },\n",
160
- " 'data': {\n",
161
- " 'input_path': DATA_FILE,\n",
162
- " 'max_length': 2048,\n",
163
- " 'train_split': 1.0\n",
164
- " },\n",
165
- " 'lora': {\n",
166
- " 'r': 16,\n",
167
- " 'alpha': 32,\n",
168
- " 'dropout': 0.05,\n",
169
- " 'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
170
- " 'bias': 'none',\n",
171
- " 'task_type': 'CAUSAL_LM'\n",
172
- " },\n",
173
- " 'training': {\n",
174
- " 'num_epochs': 1,\n",
175
- " 'batch_size': 2,\n",
176
- " 'gradient_accumulation': 4,\n",
177
- " 'learning_rate': 2e-4,\n",
178
- " 'warmup_steps': 50,\n",
179
- " 'weight_decay': 0.01,\n",
180
- " 'max_grad_norm': 1.0,\n",
181
- " 'logging_steps': 10,\n",
182
- " 'save_steps': 100,\n",
183
- " 'save_total_limit': 2,\n",
184
- " 'fp16': True,\n",
185
- " 'bf16': False,\n",
186
- " 'gradient_checkpointing': True\n",
187
- " },\n",
188
- " 'output': {\n",
189
- " 'lora_dir': os.path.join(OUTPUT_DIR, 'lora'),\n",
190
- " 'logging_dir': os.path.join(OUTPUT_DIR, 'logs')\n",
191
- " },\n",
192
- " 'quantization': {'enabled': False},\n",
193
- " 'hardware': {'device': 'cuda', 'num_gpus': 1, 'use_4bit': False, 'use_8bit': False}\n",
194
- "}\n",
195
- "\n",
196
- "config_path = os.path.join(OUTPUT_DIR, \"train_config.yaml\")\n",
197
- "with open(config_path, 'w') as f:\n",
198
- " yaml.dump(config, f, default_flow_style=False)\n",
199
- "\n",
200
- "print(f\"✅ Config saved: {config_path}\")\n",
201
- "print(f\" Model: {config['model']['name']}\")\n",
202
- "print(f\" Data: {config['data']['input_path']}\")"
203
- ]
204
- },
205
- {
206
- "cell_type": "code",
207
- "execution_count": null,
208
- "metadata": {},
209
- "outputs": [],
210
- "source": [
211
- "# STEP 6: Train LoRA\n",
212
- "import sys\n",
213
- "sys.path.insert(0, os.path.join(REPO_DIR, \"stack_2_9_training\"))\n",
214
- "\n",
215
- "print(\"=\"*60)\n",
216
- "print(\"STARTING TRAINING\")\n",
217
- "print(\"=\"*60)\n",
218
- "\n",
219
- "from stack_2_9_training.train_lora import train_lora\n",
220
- "\n",
221
- "try:\n",
222
- " trainer = train_lora(config_path)\n",
223
- " print(\"\\n\" + \"=\"*60)\n",
224
- " print(\"TRAINING COMPLETED\")\n",
225
- " print(\"=\"*60)\n",
226
- "except Exception as e:\n",
227
- " print(f\"\\n❌ Training failed: {e}\")\n",
228
- " import traceback\n",
229
- " traceback.print_exc()\n",
230
- " raise"
231
- ]
232
- },
233
- {
234
- "cell_type": "code",
235
- "execution_count": null,
236
- "metadata": {},
237
- "outputs": [],
238
- "source": [
239
- "# STEP 7: Merge LoRA adapter\n",
240
- "import sys\n",
241
- "sys.path.insert(0, os.path.join(REPO_DIR, \"stack_2_9_training\"))\n",
242
- "from stack_2_9_training.merge_adapter import merge_adapter\n",
243
- "\n",
244
- "lora_dir = config['output']['lora_dir']\n",
245
- "merged_dir = os.path.join(OUTPUT_DIR, 'merged')\n",
246
- "os.makedirs(merged_dir, exist_ok=True)\n",
247
- "\n",
248
- "print(\"=\"*60)\n",
249
- "print(\"MERGING\")\n",
250
- "print(\"=\"*60)\n",
251
- "\n",
252
- "try:\n",
253
- " merge_adapter(\n",
254
- " base_model_name_or_path=config['model']['name'],\n",
255
- " adapter_path=lora_dir,\n",
256
- " output_path=merged_dir,\n",
257
- " use_safetensors=True\n",
258
- " )\n",
259
- " print(\"\\n✅ Merge completed!\")\n",
260
- " print(f\"Files: {os.listdir(merged_dir)}\")\n",
261
- "except Exception as e:\n",
262
- " print(f\"\\n❌ Merge failed: {e}\")\n",
263
- " import traceback\n",
264
- " traceback.print_exc()\n",
265
- " raise\n",
266
- "\n",
267
- "print(\"\\n\" + \"=\"*60)\n",
268
- "print(\"🎉 ALL DONE!\")\n",
269
- "print(\"=\"*60)\n",
270
- "print(f\"\\n📦 Model ready: {merged_dir}\")\n",
271
- "print(\"\\n⏳ Download 'merged' folder from Kaggle Output panel before session ends!\")"
272
- ]
273
- }
274
- ],
275
- "metadata": {
276
- "kaggle": {
277
- "accelerator": "gpu",
278
- "dataSources": [],
279
- "kernelSpec": {
280
- "displayName": "Python 3",
281
- "language": "python",
282
- "name": "python3"
283
- }
284
- }
285
- },
286
- "nbformat": 4,
287
- "nbformat_minor": 0
288
- }
kaggle_train_stack29_final.ipynb DELETED
@@ -1,198 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# 🚀 Stack 2.9 - Kaggle Training\n",
8
- "\n",
9
- "Free GPU training on Kaggle using Qwen2.5-Coder-7B.\n",
10
- "\n",
11
- "⏱️ **Runtime:** 2-4 hours | 💾 **VRAM:** ~16GB\n",
12
- "\n",
13
- "**Setup:**\n",
14
- "1. Settings → Accelerator → GPU **T4**\n",
15
- "2. Run all cells in order\n",
16
- "3. Download merged model from Output tab when done"
17
- ]
18
- },
19
- {
20
- "cell_type": "code",
21
- "execution_count": null,
22
- "metadata": {},
23
- "outputs": [],
24
- "source": [
25
- "# Check GPU\n",
26
- "!nvidia-smi"
27
- ]
28
- },
29
- {
30
- "cell_type": "code",
31
- "execution_count": null,
32
- "metadata": {},
33
- "outputs": [],
34
- "source": [
35
- "# Clone repository\n",
36
- "import os, shutil, subprocess\n",
37
- "\n",
38
- "os.chdir('/kaggle/working')\n",
39
- "REPO_DIR = '/kaggle/working/stack-2.9'\n",
40
- "OUTPUT_DIR = os.path.join(REPO_DIR, 'training_output')\n",
41
- "\n",
42
- "if os.path.exists(REPO_DIR):\n",
43
- " shutil.rmtree(REPO_DIR)\n",
44
- "subprocess.run(['git', 'clone', 'https://github.com/my-ai-stack/stack-2.9.git', REPO_DIR], check=True)\n",
45
- "os.chdir(REPO_DIR)\n",
46
- "print('✅ Repo ready:', REPO_DIR)"
47
- ]
48
- },
49
- {
50
- "cell_type": "code",
51
- "execution_count": null,
52
- "metadata": {},
53
- "outputs": [],
54
- "source": [
55
- "# Install dependencies (single command)\n",
56
- "!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
57
- "!pip install -q transformers==4.40.0 peft==0.10.0 accelerate==0.34.0 datasets==3.0.0 pyyaml tqdm scipy bitsandbytes==0.43.0\n",
58
- "print('✅ Dependencies ready')"
59
- ]
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": null,
64
- "metadata": {},
65
- "outputs": [],
66
- "source": [
67
- "# Prepare training data (auto-detect or synthetic fallback)\n",
68
- "import os, json\n",
69
- "\n",
70
- "REPO_TRAIN_DATA = os.path.join(REPO_DIR, 'training-data/final/train.jsonl')\n",
71
- "MINI_DATA_DIR = os.path.join(REPO_DIR, 'data_mini')\n",
72
- "MINI_DATA_FILE = os.path.join(MINI_DATA_DIR, 'train_mini.jsonl')\n",
73
- "SYNTHETIC_FILE = os.path.join(REPO_DIR, 'data/synthetic.jsonl')\n",
74
- "\n",
75
- "print('🔍 Data check')\n",
76
- "\n",
77
- "if os.path.exists(REPO_TRAIN_DATA):\n",
78
- " os.makedirs(MINI_DATA_DIR, exist_ok=True)\n",
79
- " if not os.path.exists(MINI_DATA_FILE):\n",
80
- " print(' Building mini dataset (1K samples) from full data...')\n",
81
- " !python scripts/create_mini_dataset.py --size 1000 --output {MINI_DATA_FILE} --source {REPO_TRAIN_DATA}\n",
82
- " DATA_FILE = MINI_DATA_FILE\n",
83
- " print(' Using mini dataset')\n",
84
- "elif os.path.exists(MINI_DATA_FILE):\n",
85
- " DATA_FILE = MINI_DATA_FILE\n",
86
- " print(' Using existing mini dataset')\n",
87
- "else:\n",
88
- " print(' Creating synthetic data (last resort)')\n",
89
- " examples = [\n",
90
- " {'instruction': 'Write a Python function to reverse a string', 'output': 'def reverse_string(s):\\n return s[::-1]'},\n",
91
- " {'instruction': 'Write a function to check if a number is prime', 'output': 'def is_prime(n):\\n if n <= 1:\\n return False\\n for i in range(2, int(n**0.5) + 1):\\n if n % i == 0:\\n return False\\n return True'},\n",
92
- " {'instruction': 'Write a binary search function', 'output': 'def binary_search(arr, target):\\n left, right = 0, len(arr) - 1\\n while left <= right:\\n mid = (left + right) // 2\\n if arr[mid] == target:\\n return mid\\n elif arr[mid] < target:\\n left = mid + 1\\n else:\\n right = mid - 1\\n return -1'},\n",
93
- " ]\n",
94
- " samples = examples * 333\n",
95
- " os.makedirs(os.path.dirname(SYNTHETIC_FILE), exist_ok=True)\n",
96
- " with open(SYNTHETIC_FILE, 'w') as f:\n",
97
- " for s in samples:\n",
98
- " f.write(json.dumps(s) + '\\n')\n",
99
- " DATA_FILE = SYNTHETIC_FILE\n",
100
- " print(f' Synthetic dataset: {len(samples)} examples')\n",
101
- "\n",
102
- "print(f'\\n✅ Data: {DATA_FILE}')\n",
103
- "!ls -lh {DATA_FILE}"
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": null,
109
- "metadata": {},
110
- "outputs": [],
111
- "source": [
112
- "# Generate training configuration\n",
113
- "import yaml\n",
114
- "\n",
115
- "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
116
- "\n",
117
- "config = {\n",
118
- " 'model': {'name': 'Qwen/Qwen2.5-Coder-7B', 'trust_remote_code': True, 'torch_dtype': 'float16'},\n",
119
- " 'data': {'input_path': DATA_FILE, 'max_length': 2048, 'train_split': 1.0},\n",
120
- " 'lora': {'r': 16, 'alpha': 32, 'dropout': 0.05, 'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'], 'bias': 'none', 'task_type': 'CAUSAL_LM'},\n",
121
- " 'training': {'num_epochs': 1, 'batch_size': 2, 'gradient_accumulation': 4, 'learning_rate': 2e-4, 'warmup_steps': 50, 'weight_decay': 0.01, 'max_grad_norm': 1.0, 'logging_steps': 10, 'save_steps': 100, 'save_total_limit': 2, 'fp16': True, 'bf16': False, 'gradient_checkpointing': True},\n",
122
- " 'output': {'lora_dir': os.path.join(OUTPUT_DIR, 'lora'), 'logging_dir': os.path.join(OUTPUT_DIR, 'logs')},\n",
123
- " 'quantization': {'enabled': False},\n",
124
- " 'hardware': {'device': 'cuda', 'num_gpus': 1, 'use_4bit': False, 'use_8bit': False}\n",
125
- "}\n",
126
- "\n",
127
- "config_path = os.path.join(OUTPUT_DIR, 'train_config.yaml')\n",
128
- "with open(config_path, 'w') as f:\n",
129
- " yaml.dump(config, f, default_flow_style=False)\n",
130
- "\n",
131
- "print(f'✅ Config: {config_path}')\n",
132
- "print(f\" Model: {config['model']['name']}\")\n",
133
- "print(f\" Data: {config['data']['input_path']}\")"
134
- ]
135
- },
136
- {
137
- "cell_type": "code",
138
- "execution_count": null,
139
- "metadata": {},
140
- "outputs": [],
141
- "source": [
142
- "# Train (using standalone train_simple.py - no package install needed)\n",
143
- "print('='*60)\n",
144
- "print('STARTING TRAINING')\n",
145
- "print('='*60)\n",
146
- "\n",
147
- "!cd {REPO_DIR} && python train_simple.py --config {config_path}\n",
148
- "\n",
149
- "print('\\n✅ Training step finished')"
150
- ]
151
- },
152
- {
153
- "cell_type": "code",
154
- "execution_count": null,
155
- "metadata": {},
156
- "outputs": [],
157
- "source": [
158
- "# Merge LoRA adapter into final model\n",
159
- "lora_dir = os.path.join(OUTPUT_DIR, 'lora')\n",
160
- "merged_dir = os.path.join(OUTPUT_DIR, 'merged')\n",
161
- "\n",
162
- "print('='*60)\n",
163
- "print('MERGING LORA ADAPTER')\n",
164
- "print('='*60)\n",
165
- "\n",
166
- "!cd {REPO_DIR} && python merge_simple.py \\\n",
167
- " --base-model {config['model']['name']} \\\n",
168
- " --adapter-path {lora_dir} \\\n",
169
- " --output-path {merged_dir} \\\n",
170
- " --use-safetensors\n",
171
- "\n",
172
- "print('\\n✅ Merge complete!')\n",
173
- "print(f'Merged model: {merged_dir}')\n",
174
- "!ls -lh {merged_dir}"
175
- ]
176
- },
177
- {
178
- "cell_type": "markdown",
179
- "metadata": {},
180
- "source": [
181
- "## 📥 Download Model\n",
182
- "\n",
183
- "1. Open **Output** tab on the right\n",
184
- "2. Find `training_output/merged/`\n",
185
- "3. Select all files and **Download**\n",
186
- "\n",
187
- "⚠️ **Do this before Kaggle session ends!**"
188
- ]
189
- }
190
- ],
191
- "metadata": {
192
- "kaggle": {
193
- "accelerator": "gpu"
194
- }
195
- },
196
- "nbformat": 4,
197
- "nbformat_minor": 0
198
- }
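The notebook's fallback cell above writes instruction/output pairs as JSONL. A stdlib-only sanity check for such a file might look like this (a minimal sketch; `check_jsonl` is illustrative and not part of the repo, and the expected keys mirror the fallback cell):

```python
import json

def check_jsonl(path):
    """Count records, asserting each line is valid JSON with the expected keys."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate trailing blank lines
            record = json.loads(line)
            assert "instruction" in record and "output" in record, "malformed record"
            count += 1
    return count
```

Running it on the generated `data/synthetic.jsonl` should report the same count the notebook prints.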
merge_simple.py CHANGED
@@ -34,7 +34,7 @@ def main():
34
  print("Loading base model...")
35
  model = AutoModelForCausalLM.from_pretrained(
36
  args.base_model,
37
- torch_dtype=torch.float16,
38
  device_map="auto",
39
  trust_remote_code=True
40
  )
 
34
  print("Loading base model...")
35
  model = AutoModelForCausalLM.from_pretrained(
36
  args.base_model,
37
+ torch_dtype=torch.bfloat16,
38
  device_map="auto",
39
  trust_remote_code=True
40
  )
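The generator added below emits chat transcripts in which every assistant `tool_calls` entry must be answered by a following `tool` message carrying the matching `tool_call_id`. A small stdlib validator for that invariant (a sketch; `validate_tool_record` is illustrative, not part of this commit) could be:

```python
import json

def validate_tool_record(record):
    """Check tool_call/result linkage in one generated training record."""
    seen_ids = set()
    for msg in record["messages"]:
        for call in msg.get("tool_calls") or []:
            seen_ids.add(call["id"])
            json.loads(call["function"]["arguments"])  # arguments are JSON strings
        if msg["role"] == "tool":
            # every tool result must answer a previously issued call
            assert msg["tool_call_id"] in seen_ids, "orphan tool result"
    return True
```

Looping this over each line of `tool_examples.jsonl` is a cheap pre-training check.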
scripts/generate_tool_data.py ADDED
@@ -0,0 +1,615 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Synthetic Tool-Calling Training Data Generator for Stack 2.9
4
+ Generates training examples in Qwen2.5-Coder format with tool_calls.
5
+ """
6
+
7
+ import json
8
+ import random
9
+ import argparse
10
+ from pathlib import Path
11
+ from typing import Dict, List, Any
12
+ from datetime import datetime
13
+
14
+ # ============================================================================
15
+ # Tool Definitions (Qwen2.5-Coder format)
16
+ # ============================================================================
17
+
18
+ TOOL_DEFINITIONS = [
19
+ {
20
+ "type": "function",
21
+ "function": {
22
+ "name": "Bash",
23
+ "description": "Execute bash commands in the terminal. Use for running shell commands, scripts, git operations, package managers, and system commands.",
24
+ "parameters": {
25
+ "type": "object",
26
+ "properties": {
27
+ "command": {
28
+ "type": "string",
29
+ "description": "The bash command to execute"
30
+ },
31
+ "timeout": {
32
+ "type": "integer",
33
+ "description": "Timeout in seconds (default: 30)"
34
+ }
35
+ },
36
+ "required": ["command"]
37
+ }
38
+ }
39
+ },
40
+ {
41
+ "type": "function",
42
+ "function": {
43
+ "name": "FileRead",
44
+ "description": "Read the contents of a file from the filesystem. Use for viewing source code, configuration files, documentation, or any text-based files.",
45
+ "parameters": {
46
+ "type": "object",
47
+ "properties": {
48
+ "path": {
49
+ "type": "string",
50
+ "description": "Path to the file to read"
51
+ },
52
+ "offset": {
53
+ "type": "integer",
54
+ "description": "Line number to start reading from (1-indexed)"
55
+ },
56
+ "limit": {
57
+ "type": "integer",
58
+ "description": "Maximum number of lines to read"
59
+ }
60
+ },
61
+ "required": ["path"]
62
+ }
63
+ }
64
+ },
65
+ {
66
+ "type": "function",
67
+ "function": {
68
+ "name": "FileWrite",
69
+ "description": "Create or overwrite a file with content. Use for creating new files, updating existing files, or writing code, configuration, or documentation.",
70
+ "parameters": {
71
+ "type": "object",
72
+ "properties": {
73
+ "path": {
74
+ "type": "string",
75
+ "description": "Path where the file should be created or written"
76
+ },
77
+ "content": {
78
+ "type": "string",
79
+ "description": "The content to write to the file"
80
+ },
81
+ "append": {
82
+ "type": "boolean",
83
+ "description": "Append to existing file instead of overwriting (default: false)"
84
+ }
85
+ },
86
+ "required": ["path", "content"]
87
+ }
88
+ }
89
+ },
90
+ {
91
+ "type": "function",
92
+ "function": {
93
+ "name": "WebSearch",
94
+ "description": "Search the web for information. Use for finding documentation, looking up error messages, researching libraries, or getting up-to-date information.",
95
+ "parameters": {
96
+ "type": "object",
97
+ "properties": {
98
+ "query": {
99
+ "type": "string",
100
+ "description": "The search query to look up on the web"
101
+ },
102
+ "count": {
103
+ "type": "integer",
104
+ "description": "Number of results to return (default: 5)"
105
+ }
106
+ },
107
+ "required": ["query"]
108
+ }
109
+ }
110
+ },
111
+ {
112
+ "type": "function",
113
+ "function": {
114
+ "name": "Grep",
115
+ "description": "Search for patterns in files. Use for finding specific code, function definitions, imports, TODO comments, error patterns, or any text across the codebase.",
116
+ "parameters": {
117
+ "type": "object",
118
+ "properties": {
119
+ "pattern": {
120
+ "type": "string",
121
+ "description": "The search pattern or regex to match"
122
+ },
123
+ "path": {
124
+ "type": "string",
125
+ "description": "Directory or file path to search in (default: current directory)"
126
+ },
127
+ "recursive": {
128
+ "type": "boolean",
129
+ "description": "Search recursively in subdirectories (default: true)"
130
+ },
131
+ "file_pattern": {
132
+ "type": "string",
133
+ "description": "File pattern to filter results (e.g., '*.py', '*.js')"
134
+ }
135
+ },
136
+ "required": ["pattern"]
137
+ }
138
+ }
139
+ }
140
+ ]
141
+
142
+ # ============================================================================
143
+ # Template Data for Generation
144
+ # ============================================================================
145
+
146
+ FILE_PATHS = [
147
+ "src/main.py", "src/utils.py", "src/config.py", "src/models.py",
148
+ "src/api.py", "src/handlers.py", "src/middleware.py",
149
+ "tests/test_main.py", "tests/test_utils.py", "tests/conftest.py",
150
+ "README.md", "LICENSE", "package.json", "requirements.txt",
151
+ "config.yaml", "config.json", ".env.example",
152
+ "src/components/Button.tsx", "src/components/Header.jsx",
153
+ "src/styles.css", "src/index.js", "src/app.js",
154
+ "docs/API.md", "docs/ARCHITECTURE.md", "docs/CONTRIBUTING.md",
155
+ "scripts/setup.sh", "scripts/deploy.py", "Makefile"
156
+ ]
157
+
158
+ CODE_SNIPPETS = {
159
+ "python": [
160
+ "def hello():\n print('Hello, World!')",
161
+ "class MyClass:\n def __init__(self):\n self.value = 42",
162
+ "import os\nos.path.join('a', 'b')",
163
+ "async def fetch_data():\n async with aiohttp.ClientSession() as session:\n return await session.get(url)",
164
+ ],
165
+ "javascript": [
166
+ "const fetch = require('node-fetch');\nconst data = await fetch(url);",
167
+ "function handleClick() {\n setCount(count + 1);\n}",
168
+ "export default function App() {\n return <div>Hello</div>;\n}",
169
+ "const [state, setState] = useState(null);",
170
+ ],
171
+ "bash": [
172
+ "npm install",
173
+ "git status",
174
+ "pytest -v",
175
+ "python -m pytest tests/",
176
+ "make build",
177
+ "docker build -t myapp .",
178
+ "ls -la",
179
+ "curl -X GET https://api.example.com",
180
+ ]
181
+ }
182
+
183
+ WEB_SEARCH_QUERIES = [
184
+ "python async await best practices",
185
+ "javascript array methods map filter reduce",
186
+ "TypeScript generics tutorial",
187
+ "React hooks useEffect dependency array",
188
+ "Node.js error handling best practices",
189
+ "Docker vs Kubernetes differences",
190
+ "Git rebase vs merge",
191
+ "SQL join types explained",
192
+ "REST API design principles",
193
+ "Python list comprehension examples",
194
+ "JavaScript promise async await",
195
+ "CSS flexbox vs grid",
196
+ "MongoDB vs PostgreSQL",
197
+ "AWS Lambda cold start optimization",
198
+ "Docker compose tutorial",
199
+ ]
200
+
201
+ GREP_PATTERNS = [
202
+ ("TODO", "*.py", "Find TODO comments in Python files"),
203
+ ("function ", "*.js", "Find function definitions in JS"),
204
+ ("import ", "*.py", "Find imports in Python"),
205
+ ("class ", "*.ts", "Find class definitions in TypeScript"),
206
+ ("def ", "*.py", "Find function definitions in Python"),
207
+ ("const ", "*.js", "Find const declarations in JS"),
208
+ ("interface ", "*.ts", "Find interface definitions"),
209
+ ("#pragma", "*.c", "Find pragma directives in C"),
210
+ ("SELECT ", "*.sql", "Find SQL queries"),
211
+ ("TODO", "*.md", "Find TODOs in documentation"),
212
+ ]
213
+
214
+ # ============================================================================
215
+ # Scenario Templates
216
+ # ============================================================================
217
+
218
+ SCENARIOS = {
219
+ "FileRead": [
220
+ {
221
+ "user": "Read the file at src/main.py and show me what's in it",
222
+ "tool_call": {"name": "FileRead", "arguments": {"path": "src/main.py"}},
223
+ "result": "Successfully read file: src/main.py\n```python\ndef main():\n print('Hello from Stack 2.9!')\n\nif __name__ == '__main__':\n main()\n```",
224
+ "response": "Here's the contents of `src/main.py`:\n\n```python\ndef main():\n print('Hello from Stack 2.9!')\n\nif __name__ == '__main__':\n main()\n```\n\nThis is a simple Python script that prints a greeting when executed directly."
225
+ },
226
+ {
227
+ "user": "Can you show me the README.md file?",
228
+ "tool_call": {"name": "FileRead", "arguments": {"path": "README.md"}},
229
+ "result": "Successfully read file: README.md\n```markdown\n# My Project\n\nA sample project for Stack 2.9.\n\n## Installation\n\npip install -r requirements.txt\n```",
230
+ "response": "Here's the README.md:\n\n```markdown\n# My Project\n\nA sample project for Stack 2.9.\n\n## Installation\n\npip install -r requirements.txt\n```"
231
+ },
232
+ {
233
+ "user": "What's in the config.yaml file?",
234
+ "tool_call": {"name": "FileRead", "arguments": {"path": "config.yaml"}},
235
+ "result": "Successfully read file: config.yaml\n```yaml\ndatabase:\n host: localhost\n port: 5432\n name: myapp\n\napp:\n port: 3000\n debug: true\n```",
236
+ "response": "The config.yaml contains:\n\n```yaml\ndatabase:\n host: localhost\n port: 5432\n name: myapp\n\napp:\n port: 3000\n debug: true\n```\n\nIt has database connection settings and app configuration."
237
+ },
238
+ {
239
+ "user": "Check the package.json to see the dependencies",
240
+ "tool_call": {"name": "FileRead", "arguments": {"path": "package.json"}},
241
+ "result": "Successfully read file: package.json\n```json\n{\n \"name\": \"my-app\",\n \"version\": \"1.0.0\",\n \"dependencies\": {\n \"express\": \"^4.18.0\",\n \"react\": \"^18.2.0\"\n }\n}\n```",
242
+ "response": "The package.json shows these dependencies:\n\n```json\n{\n \"name\": \"my-app\",\n \"version\": \"1.0.0\",\n \"dependencies\": {\n \"express\": \"^4.18.0\",\n \"react\": \"^18.2.0\"\n }\n}\n```"
243
+ },
244
+ ],
245
+ "FileWrite": [
246
+ {
247
+ "user": "Create a new file called hello.py with a hello world function",
248
+ "tool_call": {"name": "FileWrite", "arguments": {"path": "hello.py", "content": "def hello():\n print('Hello, World!')\n\nif __name__ == '__main__':\n hello()\n"}},
249
+ "result": "Successfully wrote file: hello.py",
250
+ "response": "Created `hello.py` with a hello world function:\n\n```python\ndef hello():\n print('Hello, World!')\n\nif __name__ == '__main__':\n hello()\n```\n\nYou can run it with `python hello.py`."
251
+ },
252
+ {
253
+ "user": "Write a simple React component to src/components/Button.jsx",
254
+ "tool_call": {"name": "FileWrite", "arguments": {"path": "src/components/Button.jsx", "content": "import React from 'react';\n\nexport default function Button({ onClick, children }) {\n return (\n <button onClick={onClick} className=\"btn\">\n {children}\n </button>\n );\n}\n"}},
255
+ "result": "Successfully wrote file: src/components/Button.jsx",
256
+ "response": "Created `src/components/Button.jsx`:\n\n```jsx\nimport React from 'react';\n\nexport default function Button({ onClick, children }) {\n return (\n <button onClick={onClick} className=\"btn\">\n {children}\n </button>\n );\n}\n```"
257
+ },
258
+ {
259
+ "user": "Create a config.json file with some settings",
260
+ "tool_call": {"name": "FileWrite", "arguments": {"path": "config.json", "content": "{\n \"app\": {\n \"name\": \"MyApp\",\n \"version\": \"1.0.0\"\n },\n \"server\": {\n \"port\": 8080,\n \"host\": \"localhost\"\n }\n}\n"}},
261
+ "result": "Successfully wrote file: config.json",
262
+ "response": "Created `config.json`:\n\n```json\n{\n \"app\": {\n \"name\": \"MyApp\",\n \"version\": \"1.0.0\"\n },\n \"server\": {\n \"port\": 8080,\n \"host\": \"localhost\"\n }\n}\n```"
263
+ },
264
+ ],
265
+ "Bash": [
266
+ {
267
+ "user": "Run npm install to install dependencies",
268
+ "tool_call": {"name": "Bash", "arguments": {"command": "npm install"}},
269
+ "result": "added 245 packages in 3.5s",
270
+ "response": "Dependencies installed successfully! Added 245 packages."
271
+ },
272
+ {
273
+ "user": "Check git status",
274
+ "tool_call": {"name": "Bash", "arguments": {"command": "git status"}},
275
+ "result": "On branch main\nYour branch is up to date with 'origin/main'.\n\nChanges not staged for commit:\n modified: README.md\n modified: src/main.py\n\nno changes added to commit (but untracked files present)",
276
+ "response": "Git status shows:\n- Modified: README.md, src/main.py (not staged)\n- Untracked files present"
277
+ },
278
+ {
279
+ "user": "Run the tests with pytest",
280
+ "tool_call": {"name": "Bash", "arguments": {"command": "pytest -v"}},
281
+ "result": "tests/test_main.py::test_hello PASSED\ntests/test_utils.py::test_add PASSED\n\n2 passed in 0.5s",
282
+ "response": "All tests passed!\n- `tests/test_main.py::test_hello` ✓\n- `tests/test_utils.py::test_add` ✓"
283
+ },
284
+ {
285
+ "user": "List all files in the current directory",
286
+ "tool_call": {"name": "Bash", "arguments": {"command": "ls -la"}},
287
+ "result": "total 64\ndrwxr-xr-x 12 user staff 384 Apr 6 00:58 .\ndrwxr-xr-x 9 user staff 288 Apr 6 00:50 ..\n-rw-r--r-- 1 user staff 220 Apr 6 00:50 README.md\ndrwxr-xr-x 3 user staff 96 Apr 6 00:51 src\ndrwxr-xr-x 3 user staff 96 Apr 6 00:51 tests\n...",
288
+ "response": "Files in current directory:\n- README.md\n- src/\n- tests/\n- And more (permissions, dates shown)"
289
+ },
290
+ {
291
+ "user": "Build the project with make",
292
+ "tool_call": {"name": "Bash", "arguments": {"command": "make build"}},
293
+ "result": "Building...\ngcc -o main main.c\nobuild successful!",
294
+ "response": "Build completed successfully!"
295
+ },
296
+ ],
297
+ "WebSearch": [
298
+ {
299
+ "user": "How do I use async/await in Python?",
300
+ "tool_call": {"name": "WebSearch", "arguments": {"query": "python async await best practices", "count": 5}},
301
+ "result": "Search results for 'python async await best practices':\n1. Real Python - Async IO in Python: A Complete Walkthrough\n2. Python Docs - asyncio - Asynchronous I/O\n3. Towards Data Science - Mastering Python Async/Await\n4. Stack Overflow - Common async/await mistakes\n5. Medium - Async/Await Patterns in Python",
302
+ "response": "Here are some resources on Python async/await:\n\n1. **Real Python** - Async IO in Python: A Complete Walkthrough\n2. **Python Docs** - asyncio - Asynchronous I/O\n3. **Towards Data Science** - Mastering Python Async/Await\n\nKey points:\n- Use `async def` to define async functions\n- Use `await` to call other async functions\n- `asyncio.run()` to execute the main coroutine"
303
+ },
304
+ {
305
+ "user": "What's the difference between Docker and Kubernetes?",
306
+ "tool_call": {"name": "WebSearch", "arguments": {"query": "Docker vs Kubernetes differences", "count": 5}},
307
+ "result": "Search results for 'Docker vs Kubernetes differences':\n1. Docker Blog - Containers 101\n2. Kubernetes.io - What is Kubernetes\n3. Red Hat - Docker vs Kubernetes explained\n4. DevOps.com - Container orchestration comparison\n5. Container Journal - When to use each",
308
+ "response": "**Docker** is a container platform - it packages applications into containers.\n\n**Kubernetes** is a container orchestrator - it manages multiple containers across machines.\n\nThink of Docker as individual shipping containers, and Kubernetes as the port management system that coordinates all of them."
309
+ },
310
+ {
311
+ "user": "How do I use React hooks properly?",
312
+ "tool_call": {"name": "WebSearch", "arguments": {"query": "React hooks useEffect dependency array", "count": 5}},
313
+ "result": "Search results for 'React hooks useEffect dependency array':\n1. React Docs - Synchronizing with Effects\n2. Kent C. Dodds - A Complete Guide to useEffect\n3. Robin Wieruch - React useEffect Hook Tutorial\n4. Stack Overflow - Common useEffect mistakes\n5. LogRocket - useEffect dependency array deep dive",
314
+ "response": "**React Hooks Guide:**\n\n1. **useState** - For state management\n2. **useEffect** - For side effects, takes dependency array\n3. **useContext** - For global state\n\nThe dependency array controls when the effect runs. Empty `[]` = once on mount. Include variables to re-run when they change."
315
+ },
316
+ ],
317
+ "Grep": [
318
+ {
319
+ "user": "Find all TODO comments in the Python files",
320
+ "tool_call": {"name": "Grep", "arguments": {"pattern": "TODO", "path": ".", "recursive": True, "file_pattern": "*.py"}},
321
+ "result": "Found 3 matches:\nsrc/main.py:14: # TODO: Add error handling\nsrc/utils.py:22: # TODO: Implement caching\ntests/test_main.py:8: # TODO: Add more tests",
322
+ "response": "Found 3 TODO comments in Python files:\n- `src/main.py:14` - Add error handling\n- `src/utils.py:22` - Implement caching\n- `tests/test_main.py:8` - Add more tests"
323
+ },
324
+ {
325
+ "user": "Search for function definitions in the JavaScript files",
326
+ "tool_call": {"name": "Grep", "arguments": {"pattern": "function ", "path": "src", "recursive": True, "file_pattern": "*.js"}},
327
+ "result": "Found 5 matches:\nsrc/index.js:3: function setup() {\nsrc/handlers.js:12: function handleClick(event) {\nsrc/handlers.js:25: function handleSubmit(form) {\nsrc/utils.js:8: function formatDate(date) {\nsrc/utils.js:15: function parseJSON(str) {",
328
+ "response": "Found 5 function definitions in src/:\n- `src/index.js:3` - setup()\n- `src/handlers.js:12` - handleClick(event)\n- `src/handlers.js:15` - handleSubmit(form)\n- `src/utils.js:8` - formatDate(date)\n- `src/utils.js:15` - parseJSON(str)"
329
+ },
330
+ {
331
+ "user": "Find all imports in the Python codebase",
332
+ "tool_call": {"name": "Grep", "arguments": {"pattern": "import ", "path": ".", "recursive": True, "file_pattern": "*.py"}},
333
+ "result": "Found 12 matches:\nsrc/main.py:1: import os\nsrc/main.py:2: import sys\nsrc/main.py:3: import json\nsrc/utils.py:1: from typing import Dict\nsrc/api.py:5: import requests\n...",
334
+ "response": "Found 12 import statements across Python files. Main imports:\n- `os`, `sys`, `json` in src/main.py\n- `typing.Dict` in src/utils.py\n- `requests` in src/api.py"
335
+ },
336
+ {
337
+ "user": "Look for class definitions in TypeScript files",
338
+ "tool_call": {"name": "Grep", "arguments": {"pattern": "class ", "path": "src", "recursive": True, "file_pattern": "*.ts"}},
339
+ "result": "Found 2 matches:\nsrc/models/User.ts:3: class User {\nsrc/models/Config.ts:8: class Config {",
340
+ "response": "Found 2 class definitions:\n- `src/models/User.ts:3` - User class\n- `src/models/Config.ts:8` - Config class"
341
+ },
342
+ ],
343
+ }
344
+
345
+
346
+ # ============================================================================
347
+ # Data Augmentation Helpers
348
+ # ============================================================================
349
+
350
+ def vary_file_paths(template: Dict) -> Dict:
351
+ """Vary file paths in a scenario to create diversity."""
352
+ import copy
353
+ result = copy.deepcopy(template)
354
+
355
+ # Vary the file paths
356
+ if "path" in result["tool_call"]["arguments"]:
357
+ for old_path in FILE_PATHS:
358
+ if old_path in result["user"]:
359
+ new_path = random.choice(FILE_PATHS)
360
+ result["user"] = result["user"].replace(old_path, new_path)
361
+ result["tool_call"]["arguments"]["path"] = new_path
362
+ break
363
+
364
+ return result
365
+
366
+
367
+ def vary_content(template: Dict) -> Dict:
368
+ """Vary content in a scenario."""
369
+ import copy
370
+ result = copy.deepcopy(template)
371
+
372
+ if "content" in result["tool_call"]["arguments"]:
373
+ # Vary code snippets
374
+ lang = random.choice(["python", "javascript"])
375
+ result["tool_call"]["arguments"]["content"] = random.choice(CODE_SNIPPETS[lang])
376
+
377
+ return result
378
+
379
+
380
+ def vary_bash_command(template: Dict) -> Dict:
381
+ """Vary bash commands."""
382
+ import copy
383
+ result = copy.deepcopy(template)
384
+
385
+ if "command" in result["tool_call"]["arguments"]:
386
+ original = result["tool_call"]["arguments"]["command"].split()[0] if result["tool_call"]["arguments"]["command"] else ""
387
+
388
+ if "npm" in original:
389
+ commands = ["npm install", "npm run build", "npm test", "npm start"]
390
+ elif "git" in original:
391
+ commands = ["git status", "git log --oneline -5", "git diff", "git branch -a"]
392
+ elif "pytest" in original:
393
+ commands = ["pytest -v", "pytest tests/", "pytest -xvs", "pytest --cov"]
394
+ elif "ls" in original:
395
+ commands = ["ls -la", "ls -1", "ls -lah"]
396
+ elif "make" in original:
397
+ commands = ["make build", "make clean", "make test", "make install"]
398
+ else:
399
+ commands = ["echo 'hello'", "pwd", "whoami", "date"]
400
+
401
+ result["tool_call"]["arguments"]["command"] = random.choice(commands)
402
+
403
+ return result
404
+
405
+
406
+ def vary_search_query(template: Dict) -> Dict:
407
+ """Vary web search queries."""
408
+ import copy
409
+ result = copy.deepcopy(template)
410
+
411
+ if "query" in result["tool_call"]["arguments"]:
412
+ result["tool_call"]["arguments"]["query"] = random.choice(WEB_SEARCH_QUERIES)
413
+
414
+ return result
415
+
416
+
417
+ def vary_grep_pattern(template: Dict) -> Dict:
418
+ """Vary grep patterns."""
419
+ import copy
420
+ result = copy.deepcopy(template)
421
+
422
+ pattern, file_pattern, _ = random.choice(GREP_PATTERNS)
423
+ result["tool_call"]["arguments"]["pattern"] = pattern
424
+ result["tool_call"]["arguments"]["file_pattern"] = file_pattern
425
+
426
+ return result
427
+
428
+
429
+ # ============================================================================
430
+ # Main Generation Functions
431
+ # ============================================================================
432
+
433
+ def create_tool_call_message(tool_call: Dict, tool_call_id: str) -> Dict:
434
+ """Create a tool_calls message in Qwen format."""
435
+ return {
436
+ "role": "assistant",
437
+ "content": None,
438
+ "tool_calls": [
439
+ {
440
+ "id": tool_call_id,
441
+ "type": "function",
442
+ "function": {
443
+ "name": tool_call["name"],
444
+ "arguments": json.dumps(tool_call["arguments"])
445
+ }
446
+ }
447
+ ]
448
+ }
449
+
450
+
451
+ def create_tool_message(role: str, tool_call_id: str, tool_name: str, result: str) -> Dict:
452
+ """Create a tool message (result of tool execution)."""
453
+ return {
454
+ "role": role, # typically "tool"
455
+ "content": result,
456
+ "tool_call_id": tool_call_id,
457
+ "name": tool_name
458
+ }
459
+
460
+
461
+ def generate_example(scenario: Dict, system_prompt: str = None) -> Dict:
462
+ """Generate a single training example in Qwen2.5-Coder format."""
463
+ if system_prompt is None:
464
+ system_prompt = "You are a helpful AI assistant that can use tools to help users solve problems. When you need to perform actions like reading files, running commands, searching the web, or searching code, use the appropriate tool."
465
+
466
+ tool_call_id = f"call_${random.randint(1000, 9999)}"
467
+
468
+ messages = [
469
+ {"role": "system", "content": system_prompt},
470
+ {"role": "user", "content": scenario["user"]},
471
+ create_tool_call_message(scenario["tool_call"], tool_call_id),
472
+ create_tool_message("tool", tool_call_id, scenario["tool_call"]["name"], scenario["result"]),
473
+ {"role": "assistant", "content": scenario["response"]}
474
+ ]
475
+
476
+ return {
477
+ "messages": messages,
478
+ "tools": TOOL_DEFINITIONS
479
+ }
480
+
481
+
482
+ def augment_scenario(scenario: Dict, tool_name: str) -> Dict:
483
+ """Apply random augmentations to a scenario."""
484
+ import random
485
+
486
+ augmented = scenario.copy()
487
+
488
+ if tool_name == "FileRead":
489
+ augmented = vary_file_paths(augmented)
490
+ elif tool_name == "FileWrite":
491
+ augmented = vary_file_paths(augmented)
492
+ augmented = vary_content(augmented)
493
+ elif tool_name == "Bash":
494
+ augmented = vary_bash_command(augmented)
495
+ elif tool_name == "WebSearch":
496
+ augmented = vary_search_query(augmented)
497
+ elif tool_name == "Grep":
498
+ augmented = vary_grep_pattern(augmented)
499
+
500
+ return augmented
501
+
502
+
503
+ def generate_dataset(num_examples: int = 1000, output_path: str = None) -> List[Dict]:
504
+ """Generate the complete dataset."""
505
+ examples = []
506
+ tools = list(SCENARIOS.keys())
507
+
508
+ # Track counts for balance
509
+ examples_per_tool = num_examples // len(tools)
510
+ remainder = num_examples % len(tools)
511
+
512
+ for i, tool_name in enumerate(tools):
513
+ # Determine how many examples for this tool
514
+ count = examples_per_tool + (1 if i < remainder else 0)
515
+
516
+ base_scenarios = SCENARIOS[tool_name]
517
+
518
+ for j in range(count):
519
+ # Use base scenario and vary
520
+ base = base_scenarios[j % len(base_scenarios)]
521
+
522
+ # Apply augmentations for variety
523
+ if j >= len(base_scenarios):
524
+ scenario = augment_scenario(base, tool_name)
525
+ else:
526
+ scenario = base
527
+
528
+ example = generate_example(scenario)
529
+ examples.append(example)
530
+
531
+ # Shuffle for better training
532
+ random.shuffle(examples)
533
+
534
+ return examples
535
+
536
+
537
+ def save_jsonl(examples: List[Dict], output_path: str):
538
+ """Save examples to JSONL format."""
539
+ output_file = Path(output_path)
540
+ output_file.parent.mkdir(parents=True, exist_ok=True)
541
+
542
+ with open(output_file, 'w', encoding='utf-8') as f:
543
+ for example in examples:
544
+ f.write(json.dumps(example, ensure_ascii=False) + '\n')
545
+
546
+
547
+ def save_json(examples: List[Dict], output_path: str):
548
+ """Save examples to JSON format."""
549
+ output_file = Path(output_path)
550
+ output_file.parent.mkdir(parents=True, exist_ok=True)
551
+
552
+ with open(output_file, 'w', encoding='utf-8') as f:
553
+ json.dump(examples, f, ensure_ascii=False, indent=2)
554
+
555
+
556
+ def main():
557
+ parser = argparse.ArgumentParser(description="Generate synthetic tool-calling training data")
558
+ parser.add_argument("--num-examples", type=int, default=1000, help="Number of examples to generate")
559
+ parser.add_argument("--output-dir", type=str, default="training-data", help="Output directory")
560
+ parser.add_argument("--output-format", choices=["jsonl", "json", "both"], default="jsonl", help="Output format")
561
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
562
+ args = parser.parse_args()
563
+
564
+ # Set seed for reproducibility
565
+ random.seed(args.seed)
566
+
567
+ print(f"🎯 Generating {args.num_examples} tool-calling training examples...")
568
+ print(f" Output directory: {args.output_dir}")
569
+ print(f" Format: {args.output_format}")
570
+ print()
571
+
572
+ # Generate dataset
573
+ examples = generate_dataset(args.num_examples)
574
+
575
+ output_dir = Path(args.output_dir)
576
+
577
+ # Save based on format
578
+ if args.output_format in ["jsonl", "both"]:
579
+ jsonl_path = output_dir / "tool_examples.jsonl"
580
+ save_jsonl(examples, str(jsonl_path))
581
+ print(f"✅ Saved JSONL: {jsonl_path}")
582
+
583
+ if args.output_format in ["json", "both"]:
584
+ json_path = output_dir / "tool_examples.json"
585
+ save_json(examples, str(json_path))
586
+ print(f"✅ Saved JSON: {json_path}")
587
+
588
+ # Statistics
589
+ print(f"\n📊 Statistics:")
590
+ print(f" Total examples: {len(examples)}")
591
+
592
+ # Count by tool
593
+ tool_counts = {}
594
+ for ex in examples:
595
+ for msg in ex["messages"]:
596
+ if msg.get("tool_calls"):
597
+ tool_name = msg["tool_calls"][0]["function"]["name"]
598
+ tool_counts[tool_name] = tool_counts.get(tool_name, 0) + 1
599
+
600
+ print(f" Examples by tool:")
601
+ for tool, count in sorted(tool_counts.items(), key=lambda x: x[1], reverse=True):
602
+ print(f" - {tool}: {count}")
603
+
604
+ # Show sample
605
+ print(f"\n📝 Sample example (first in dataset):")
606
+ sample = examples[0]
607
+ print(f" Tools defined: {len(sample['tools'])}")
608
+ print(f" Messages: {len(sample['messages'])}")
609
+ print(f" First user message: {sample['messages'][1]['content'][:60]}...")
610
+
611
+ print(f"\n✨ Generation complete!")
612
+
613
+
614
+ if __name__ == "__main__":
615
+ main()
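Each line of the generated `tool_examples.jsonl` is a JSON object with `tools` and `messages` keys, where assistant messages may carry `tool_calls`. A minimal sketch of loading such a file downstream and re-counting tool usage, mirroring the statistics pass above (file path and field names assumed from this generator):

```python
import json

def load_tool_examples(path: str):
    """Read a JSONL file of tool-calling examples; return the examples
    plus a per-tool count of tool_calls across all messages."""
    examples, counts = [], {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate trailing blank lines
            ex = json.loads(line)
            examples.append(ex)
            for msg in ex.get("messages", []):
                for call in msg.get("tool_calls") or []:
                    name = call["function"]["name"]
                    counts[name] = counts.get(name, 0) + 1
    return examples, counts
```

A quick balance check before training (e.g. `max(counts.values()) - min(counts.values()) <= 1`) catches a skewed regeneration early.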
test_model.py CHANGED
@@ -1,10 +1,17 @@
1
  #!/usr/bin/env python3
2
  """
3
- Test script for fine-tuned Stack 2.9 model.
4
- Tests basic code generation capabilities.
5
  """
6
 
7
  import argparse
8
  import torch
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
@@ -12,141 +19,583 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
12
  def load_model(model_path: str):
13
  """Load the fine-tuned model and tokenizer."""
14
  print(f"Loading model from: {model_path}")
15
- tokenizer = AutoTokenizer.from_pretrained(model_path)
16
  model = AutoModelForCausalLM.from_pretrained(
17
  model_path,
18
  torch_dtype=torch.float16,
19
  device_map="auto",
20
  low_cpu_mem_usage=True,
 
21
  )
 
22
  return model, tokenizer
23
 
24
 
25
- def test_code_completion(model, tokenizer, prompt: str, max_new_tokens: int = 100):
26
- """Test code completion for a given prompt."""
27
- inputs = tokenizer(prompt, return_tensors="pt")
28
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
29
 
30
- with torch.no_grad():
31
- outputs = model.generate(
32
- **inputs,
33
- max_new_tokens=max_new_tokens,
34
- temperature=0.2,
35
- top_p=0.95,
36
- do_sample=True,
37
- repetition_penalty=1.1,
38
- )
39
-
40
- completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
- # Remove the prompt from the completion
42
- if completion.startswith(prompt):
43
- completion = completion[len(prompt):].strip()
44
- return completion
45
-
46
-
47
- def run_tests(model_path: str):
48
- """Run all code generation tests."""
49
- model, tokenizer = load_model(model_path)
50
- model.eval()
51
 
52
- test_cases = [
53
- {
54
- "name": "Reverse String",
55
- "prompt": "def reverse_string(s):",
56
- "max_tokens": 50,
57
- "expected_keywords": ["return", "s[::-1]", "reversed"],
58
- },
59
- {
60
- "name": "Binary Search",
61
- "prompt": "def binary_search(arr, target):",
62
- "max_tokens": 100,
63
- "expected_keywords": ["while", "left", "right", "mid"],
64
- },
65
- {
66
- "name": "Fibonacci",
67
- "prompt": "def fibonacci(n):",
68
- "max_tokens": 80,
69
- "expected_keywords": ["return", "if", "else", "fib"],
70
- },
71
- {
72
- "name": "Factorial",
73
- "prompt": "def factorial(n):",
74
- "max_tokens": 60,
75
- "expected_keywords": ["return", "if", "*"],
76
- },
77
- {
78
- "name": "Is Prime",
79
- "prompt": "def is_prime(n):",
80
- "max_tokens": 80,
81
- "expected_keywords": ["if", "return", "for", "%"],
82
- },
83
- {
84
- "name": "List Sum",
85
- "prompt": "def list_sum(nums):",
86
- "max_tokens": 50,
87
- "expected_keywords": ["return", "sum", "+"],
88
- },
89
- {
90
- "name": "Merge Sort",
91
- "prompt": "def merge_sort(arr):",
92
- "max_tokens": 150,
93
- "expected_keywords": ["if", "len", "return", "merge"],
94
- },
95
- {
96
- "name": "Quick Sort",
97
- "prompt": "def quick_sort(arr):",
98
- "max_tokens": 150,
99
- "expected_keywords": ["if", "len", "return", "pivot"],
100
- },
101
- ]
102
 
103
- print("\n" + "="*60)
104
- print("Running Code Generation Tests")
105
- print("="*60 + "\n")
106
 
107
- passed = 0
108
- failed = 0
 
 
109
 
110
- for i, test in enumerate(test_cases, 1):
111
- print(f"Test {i}: {test['name']}")
112
- print(f"Prompt: {test['prompt']}")
113
 
114
  try:
115
- completion = test_code_completion(
116
- model, tokenizer,
117
- test['prompt'],
118
- test['max_tokens']
119
- )
120
- print(f"Completion:\n{completion[:300]}")
121
-
122
- # Check for expected keywords
123
- keywords_found = sum(1 for kw in test['expected_keywords'] if kw.lower() in completion.lower())
124
- if keywords_found >= len(test['expected_keywords']) // 2:
125
- print(f"✅ PASS (found {keywords_found}/{len(test['expected_keywords'])} keywords)")
126
- passed += 1
127
  else:
128
- print(f"⚠️ PARTIAL (found {keywords_found}/{len(test['expected_keywords'])} keywords)")
129
- passed += 1 # Still count as pass if some keywords found
130
- print()
131
-
132
  except Exception as e:
133
- print(f" FAIL: {e}")
134
- failed += 1
135
- print()
 
136
 
137
- print("="*60)
138
- print(f"Results: {passed} passed, {failed} failed")
139
  print("="*60)
140
 
141
- return passed, failed
142
 
143
 
144
  def main():
145
- parser = argparse.ArgumentParser(description="Test fine-tuned Stack 2.9 model")
146
- parser.add_argument("--model-path", type=str, required=True, help="Path to fine-tuned model")
147
  args = parser.parse_args()
148
 
149
- run_tests(args.model_path)
150
 
151
 
152
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
  """
3
+ Basic Code Generation Tests for Stack 2.9 Model
4
+ Tests common algorithms and data structures.
5
+
6
+ Usage:
7
+ python test_model.py --model-path /path/to/merged/model
8
+ python test_model.py --model-path /path/to/merged/model --output test_results.json
9
  """
10
 
11
  import argparse
12
+ import json
13
+ import time
14
+ from typing import List, Dict, Tuple, Optional, Any
15
  import torch
16
  from transformers import AutoModelForCausalLM, AutoTokenizer
17
 
 
19
  def load_model(model_path: str):
20
  """Load the fine-tuned model and tokenizer."""
21
  print(f"Loading model from: {model_path}")
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
24
+ if tokenizer.pad_token is None:
25
+ tokenizer.pad_token = tokenizer.eos_token
26
+
27
  model = AutoModelForCausalLM.from_pretrained(
28
  model_path,
29
  torch_dtype=torch.float16,
30
  device_map="auto",
31
  low_cpu_mem_usage=True,
32
+ trust_remote_code=True,
33
  )
34
+
35
  return model, tokenizer
36
 
37
 
38
+ def generate_completion(
39
+ model,
40
+ tokenizer,
41
+ prompt: str,
42
+ max_new_tokens: int = 128,
43
+ temperature: float = 0.2,
44
+ num_samples: int = 1
45
+ ) -> List[str]:
46
+ """Generate code completion(s) for a prompt."""
47
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True)
48
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
49
 
50
+ outputs = model.generate(
51
+ **inputs,
52
+ max_new_tokens=max_new_tokens,
53
+ temperature=temperature,
54
+ top_p=0.95,
55
+ do_sample=True,
56
+ repetition_penalty=1.1,
57
+ num_return_sequences=num_samples,
58
+ pad_token_id=tokenizer.pad_token_id,
59
+ eos_token_id=tokenizer.eos_token_id,
60
+ )
61
 
62
+ completions = []
63
+ for output in outputs:
64
+ text = tokenizer.decode(output, skip_special_tokens=True)
65
+ if text.startswith(prompt):
66
+ text = text[len(prompt):]
67
+ completions.append(text.strip())
68
 
69
+ return completions
70
+
71
+
72
+ def extract_code(completion: str) -> str:
73
+ """Extract code from completion, handling markdown code blocks."""
74
+ # Try ```python blocks first
75
+ if "```python" in completion:
76
+ start = completion.find("```python") + len("```python")
77
+ end = completion.find("```", start)
78
+ if end != -1:
79
+ return completion[start:end].strip()
80
+
81
+ # Try generic ``` blocks
82
+ if "```" in completion:
83
+ start = completion.find("```") + len("```")
84
+ # Skip language identifier if present
85
+ if completion[start:start+10].strip():
86
+ start = completion.find("\n", start) + 1
87
+ end = completion.find("```", start)
88
+ if end != -1:
89
+ return completion[start:end].strip()
90
+
91
+ return completion.strip()
92
+
93
+
94
+ def execute_code(code: str, timeout: int = 5) -> Tuple[bool, str, Optional[Any]]:
95
+ """Safely execute code and return (success, error_msg, result)."""
96
+ import signal
97
+
98
+ class TimeoutError(Exception):
99
+ pass
100
+
101
+ def timeout_handler(signum, frame):
102
+ raise TimeoutError("Execution timed out")
103
+
104
+ safe_builtins = {
105
+ 'print': print,
106
+ 'len': len,
107
+ 'range': range,
108
+ 'str': str,
109
+ 'int': int,
110
+ 'float': float,
111
+ 'bool': bool,
112
+ 'list': list,
113
+ 'dict': dict,
114
+ 'set': set,
115
+ 'tuple': tuple,
116
+ 'sum': sum,
117
+ 'min': min,
118
+ 'max': max,
119
+ 'abs': abs,
120
+ 'sorted': sorted,
121
+ 'reversed': reversed,
122
+ 'enumerate': enumerate,
123
+ 'zip': zip,
124
+ 'map': map,
125
+ 'filter': filter,
126
+ 'any': any,
127
+ 'all': all,
128
+ 'isinstance': isinstance,
129
+ 'type': type,
130
+ 'round': round,
131
+ 'pow': pow,
132
+ 'divmod': divmod,
133
+ 'ord': ord,
134
+ 'chr': chr,
135
+ 'hex': hex,
136
+ 'bin': bin,
137
+ 'id': id,
138
+ }
139
+
140
+ namespace = {'__builtins__': safe_builtins}
141
+
142
+ try:
143
+ signal.signal(signal.SIGALRM, timeout_handler)
144
+ signal.alarm(timeout)
145
+ exec(code, namespace)
146
+ signal.alarm(0)
147
+ return True, "", namespace.get('result')
148
+ except TimeoutError as e:
149
+ signal.alarm(0)
150
+ return False, f"Timeout after {timeout}s", None
151
+ except SyntaxError as e:
152
+ signal.alarm(0)
153
+ return False, f"Syntax error: {e}", None
154
+ except Exception as e:
155
+ signal.alarm(0)
156
+ return False, f"{type(e).__name__}: {e}", None
157
+
158
+
159
+ def check_function_output(code: str, func_name: str, test_cases: List[Dict]) -> Tuple[bool, str]:
160
+ """Test a function with given test cases.
161
 
162
+ Args:
163
+ code: The generated code
164
+ func_name: Name of function to test
165
+ test_cases: List of dicts with 'input' (tuple), 'expected', 'description'
166
 
167
+ Returns:
168
+ Tuple of (all_passed, failure_message)
169
+ """
170
+ namespace = {'__builtins__': {
171
+ 'print': print,
172
+ 'len': len,
173
+ 'range': range,
174
+ 'str': str,
175
+ 'int': int,
176
+ 'float': float,
177
+ 'bool': bool,
178
+ 'list': list,
179
+ 'dict': dict,
180
+ 'set': set,
181
+ 'tuple': tuple,
182
+ 'sum': sum,
183
+ 'min': min,
184
+ 'max': max,
185
+ 'abs': abs,
186
+ 'sorted': sorted,
187
+ }}
188
+
189
+ try:
190
+ exec(code, namespace)
191
+ except Exception as e:
192
+ return False, f"Code execution failed: {type(e).__name__}: {e}"
193
+
194
+ if func_name not in namespace:
195
+ return False, f"Function '{func_name}' not found"
196
+
197
+ func = namespace[func_name]
198
+
199
+ for tc in test_cases:
200
+ inputs = tc.get('input', ())
201
+ expected = tc.get('expected')
202
+ desc = tc.get('description', str(inputs))
203
 
204
  try:
205
+ if isinstance(inputs, tuple):
206
+ result = func(*inputs)
207
  else:
208
+ result = func(inputs)
209
  except Exception as e:
210
+ return False, f"Failed on {desc}: {type(e).__name__}: {e}"
211
+
212
+ if result != expected:
213
+ return False, f"Failed on {desc}: expected {expected}, got {result}"
214
 
215
+ return True, ""
216
+
217
+
218
+ # Common algorithm test cases
219
+ ALGORITHM_TESTS = [
220
+ {
221
+ "name": "Two Sum",
222
+ "prompt": "def two_sum(nums, target):\n \"\"\"Return indices of two numbers that add up to target.\"\"\"\n",
223
+ "function": "two_sum",
224
+ "max_tokens": 100,
225
+ "test_cases": [
226
+ {"input": ([2,7,11,15], 9), "expected": [0,1], "description": "Basic"},
227
+ {"input": ([3,2,4], 6), "expected": [1,2], "description": "Middle"},
228
+ ],
229
+ "keywords": ["for", "in", "range", "enumerate"],
230
+ },
231
+ {
232
+ "name": "Reverse String",
233
+ "prompt": "def reverse_string(s):\n \"\"\"Return the reverse of string s.\"\"\"\n",
234
+ "function": "reverse_string",
235
+ "max_tokens": 50,
236
+ "test_cases": [
237
+ {"input": ("hello",), "expected": "olleh", "description": "Basic"},
238
+ {"input": ("Python",), "expected": "nohtyP", "description": "Mixed case"},
239
+ ],
240
+ "keywords": ["[::-1]", "reversed"],
241
+ },
242
+ {
243
+ "name": "Fibonacci",
244
+ "prompt": "def fibonacci(n):\n \"\"\"Return first n Fibonacci numbers.\"\"\"\n",
245
+ "function": "fibonacci",
246
+ "max_tokens": 100,
247
+ "test_cases": [
248
+ {"input": (7,), "expected": [0,1,1,2,3,5,8], "description": "n=7"},
249
+ {"input": (1,), "expected": [0], "description": "n=1"},
250
+ ],
251
+ "keywords": ["for", "while", "append", "range"],
252
+ },
253
+ {
254
+ "name": "Factorial",
255
+ "prompt": "def factorial(n):\n \"\"\"Return n! (factorial).\"\"\"\n",
256
+ "function": "factorial",
257
+ "max_tokens": 60,
258
+ "test_cases": [
259
+ {"input": (5,), "expected": 120, "description": "5!"},
260
+ {"input": (0,), "expected": 1, "description": "0!"},
261
+ ],
262
+ "keywords": ["for", "while", "range", "*"],
263
+ },
264
+ {
265
+ "name": "Is Palindrome",
266
+ "prompt": "def is_palindrome(x):\n \"\"\"Check if integer x is a palindrome.\"\"\"\n",
267
+ "function": "is_palindrome",
268
+ "max_tokens": 60,
269
+ "test_cases": [
270
+ {"input": (121,), "expected": True, "description": "121"},
271
+ {"input": (-121,), "expected": False, "description": "-121"},
272
+ ],
273
+ "keywords": ["str", "[::-1]"],
274
+ },
275
+ {
276
+ "name": "Binary Search",
277
+ "prompt": "def binary_search(arr, target):\n \"\"\"Return index of target in sorted array, or -1 if not found.\"\"\"\n",
278
+ "function": "binary_search",
279
+ "max_tokens": 120,
280
+ "test_cases": [
281
+ {"input": ([1,2,3,4,5], 3), "expected": 2, "description": "Found"},
282
+ {"input": ([1,2,3,4,5], 6), "expected": -1, "description": "Not found"},
283
+ ],
284
+ "keywords": ["while", "left", "right", "<=", ">"],
285
+ },
286
+ {
287
+ "name": "Merge Sort",
288
+ "prompt": "def merge_sort(arr):\n \"\"\"Return sorted copy of array using merge sort.\"\"\"\n",
289
+ "function": "merge_sort",
290
+ "max_tokens": 200,
291
+ "test_cases": [
292
+ {"input": ([3,1,4,1,5,9,2,6],), "expected": [1,1,2,3,4,5,6,9], "description": "Mixed"},
293
+ {"input": ([1,2,3],), "expected": [1,2,3], "description": "Already sorted"},
294
+ ],
295
+ "keywords": ["def merge_sort", "if", "len", "return", "merge"],
296
+ },
297
+ {
298
+ "name": "Quick Sort",
299
+ "prompt": "def quick_sort(arr):\n \"\"\"Return sorted copy of array using quick sort.\"\"\"\n",
300
+ "function": "quick_sort",
301
+ "max_tokens": 200,
302
+ "test_cases": [
303
+ {"input": ([3,1,4,1,5,9,2,6],), "expected": [1,1,2,3,4,5,6,9], "description": "Mixed"},
304
+ ],
305
+ "keywords": ["def quick_sort", "if", "len", "return"],
306
+ },
307
+ {
308
+ "name": "Maximum Subarray (Kadane's)",
309
+ "prompt": "def max_subarray(nums):\n \"\"\"Return maximum sum of contiguous subarray.\"\"\"\n",
310
+ "function": "max_subarray",
311
+ "max_tokens": 100,
312
+ "test_cases": [
313
+ {"input": ([-2,1,-3,4,-1,2,1,-5,4],), "expected": 6, "description": "Mixed"},
314
+ {"input": ([1],), "expected": 1, "description": "Single"},
315
+ ],
316
+ "keywords": ["for", "max", "+"],
317
+ },
318
+ {
319
+ "name": "Valid Parentheses",
320
+ "prompt": "def valid_parentheses(s):\n \"\"\"Check if string has valid bracket matching.\"\"\"\n",
321
+ "function": "valid_parentheses",
322
+ "max_tokens": 100,
323
+ "test_cases": [
324
+ {"input": ("()",), "expected": True, "description": "Simple"},
325
+ {"input": ("([)]",), "expected": False, "description": "Wrong order"},
326
+ ],
327
+ "keywords": ["stack", "if", "for", "in", "pop", "append"],
328
+ },
329
+ {
330
+ "name": "Climbing Stairs",
331
+ "prompt": "def climb_stairs(n):\n \"\"\"Count ways to climb n stairs (1 or 2 steps at a time).\"\"\"\n",
332
+ "function": "climb_stairs",
333
+ "max_tokens": 80,
334
+ "test_cases": [
335
+ {"input": (5,), "expected": 8, "description": "n=5"},
336
+ {"input": (2,), "expected": 2, "description": "n=2"},
337
+ ],
338
+ "keywords": ["for", "while", "+", "="],
339
+ },
340
+ {
341
+ "name": "List Sum",
342
+ "prompt": "def list_sum(nums):\n \"\"\"Return sum of all numbers in list.\"\"\"\n",
343
+ "function": "list_sum",
344
+ "max_tokens": 50,
345
+ "test_cases": [
346
+ {"input": ([1,2,3,4],), "expected": 10, "description": "Basic"},
347
+ {"input": ([],), "expected": 0, "description": "Empty"},
348
+ ],
349
+ "keywords": ["for", "in", "+", "sum", "0"],
350
+ },
351
+ {
352
+ "name": "List Average",
353
+ "prompt": "def list_avg(nums):\n \"\"\"Return average of numbers in list.\"\"\"\n",
354
+ "function": "list_avg",
355
+ "max_tokens": 60,
356
+ "test_cases": [
357
+ {"input": ([1,2,3,4,5],), "expected": 3.0, "description": "Basic"},
358
+ {"input": ([5],), "expected": 5.0, "description": "Single"},
359
+ ],
360
+ "keywords": ["sum", "len", "/", "float"],
361
+ },
362
+ {
363
+ "name": "Find Maximum",
364
+ "prompt": "def find_max(nums):\n \"\"\"Return maximum value in list.\"\"\"\n",
365
+ "function": "find_max",
366
+ "max_tokens": 60,
367
+ "test_cases": [
368
+ {"input": ([3,1,4,1,5,9],), "expected": 9, "description": "Basic"},
369
+ {"input": ([-1,-5,-3],), "expected": -1, "description": "Negatives"},
370
+ ],
371
+ "keywords": ["for", "in", "max", ">", "<"],
372
+ },
373
+ {
374
+ "name": "Count Zeros",
375
+ "prompt": "def count_zeros(nums):\n \"\"\"Count zeros in list.\"\"\"\n",
376
+ "function": "count_zeros",
377
+ "max_tokens": 50,
378
+ "test_cases": [
379
+ {"input": ([0,1,0,2,0],), "expected": 3, "description": "Mixed"},
380
+ {"input": ([1,2,3],), "expected": 0, "description": "No zeros"},
381
+ ],
382
+ "keywords": ["for", "in", "count", "==", "+"],
383
+ },
384
+ {
385
+ "name": "Unique Elements",
386
+ "prompt": "def unique_elements(lst):\n \"\"\"Return list of unique elements preserving order.\"\"\"\n",
387
+ "function": "unique_elements",
388
+ "max_tokens": 80,
389
+ "test_cases": [
390
+ {"input": ([1,2,2,3,1],), "expected": [1,2,3], "description": "With dups"},
391
+ {"input": ([1,2,3],), "expected": [1,2,3], "description": "All unique"},
392
+ ],
393
+ "keywords": ["for", "in", "if", "append", "set"],
394
+ },
395
+ ]
396
+
397
+
398
+ def run_test(model, tokenizer, test_config: Dict) -> Dict:
399
+ """Run a single test and return results."""
400
+ name = test_config["name"]
401
+ prompt = test_config["prompt"]
402
+ func_name = test_config["function"]
403
+ max_tokens = test_config["max_tokens"]
404
+ test_cases = test_config["test_cases"]
405
+ keywords = test_config.get("keywords", [])
406
+
407
+ print(f"\n Test: {name}")
408
+ print(f" Prompt: {prompt.strip()[:60]}...")
409
+
410
+ start_time = time.time()
411
+
412
+ # Generate completion
413
+ completions = generate_completion(model, tokenizer, prompt, max_new_tokens=max_tokens)
414
+ elapsed = time.time() - start_time
415
+
416
+ # Extract and check code
417
+ code = extract_code(completions[0])
418
+
419
+ print(f" Generated in {elapsed:.2f}s")
420
+ print(f" Code preview: {code[:100]}...")
421
+
422
+ # Check syntax and execution
423
+ success, error_msg = check_function_output(code, func_name, test_cases)
424
+
425
+ # Check keywords
426
+ keywords_found = sum(1 for kw in keywords if kw.lower() in code.lower())
427
+ keyword_ratio = keywords_found / len(keywords) if keywords else 0
428
+
429
+ result = {
430
+ "name": name,
431
+ "passed": success,
432
+ "keywords_found": keywords_found,
433
+ "keywords_total": len(keywords),
434
+ "keyword_ratio": keyword_ratio,
435
+ "execution_time": elapsed,
436
+ "error": error_msg if not success else None,
437
+ "generated_code": code[:500], # Truncate for storage
438
+ }
439
+
440
+ if success:
441
+ print(f" Result: ✅ PASS")
442
+ else:
443
+ print(f" Result: ❌ FAIL - {error_msg[:60]}")
444
+
445
+ return result
446
+
447
+
448
+ def calculate_pass_at_k(results: List[Dict], k: int) -> float:
449
+ """Calculate pass@k across all tests.
450
+
451
+ For simplicity, a test passes if it passes the functional test.
452
+ """
453
+ if not results or k <= 0:
454
+ return 0.0
455
+
456
+ passed = sum(1 for r in results if r["passed"])
457
+ total = len(results)
458
+
459
+ # Simple pass@k: probability that at least 1 of k samples would pass.
460
+ # Assuming independence, this is 1 - (1 - p)^k where p = passed/total.
461
+ p = passed / total
462
+ if k == 1:
463
+ return p
464
+ return 1 - (1 - p) ** k
472
+
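The closed form above treats the suite-level pass rate as a per-sample probability. For benchmark-style reporting (HumanEval/MBPP), the standard unbiased pass@k estimator works per problem with `n` generated samples of which `c` pass: pass@k = 1 - C(n-c, k) / C(n, k). A sketch of that estimator (not wired into this script, which draws one sample per test):

```python
from math import comb

def unbiased_pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k (HumanEval formulation): probability that at
    least one of k samples drawn without replacement from n
    generations is among the c correct ones."""
    if n - c < k:
        return 1.0  # too few failures to fill all k draws
    return 1.0 - comb(n - c, k) / comb(n, k)
```

Using this would mean setting `num_samples > 1` in `generate_completion` and scoring each sample per test, then averaging pass@k over tests.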
473
+
474
+ def print_results(results: List[Dict], k_values: List[int] = [1, 10]):
475
+ """Print test results summary."""
476
+ print("\n" + "="*60)
477
+ print("TEST RESULTS SUMMARY")
478
  print("="*60)
479
 
480
+ passed = sum(1 for r in results if r["passed"])
481
+ total = len(results)
482
+
483
+ print(f"\n Total tests: {total}")
484
+ print(f" Passed: {passed}")
485
+ print(f" Failed: {total - passed}")
486
+ print(f" Pass rate: {100*passed/total:.1f}%")
487
+
488
+ for k in k_values:
489
+ pass_at_k = calculate_pass_at_k(results, k)
490
+ print(f"\n Pass@{k}: {100*pass_at_k:.1f}%")
491
+
492
+ print("\n Individual Results:")
493
+ for r in results:
494
+ status = "✅" if r["passed"] else "❌"
495
+ print(f" {status} {r['name']} (keywords: {r['keywords_found']}/{r['keywords_total']})")
496
+
497
+
498
+ def save_results(results: List[Dict], output_path: str):
499
+ """Save test results to JSON."""
500
+ with open(output_path, 'w') as f:
501
+ json.dump(results, f, indent=2)
502
+ print(f"\n✅ Results saved to: {output_path}")
503
 
504
 
505
  def main():
506
+ parser = argparse.ArgumentParser(
507
+ description="Test Stack 2.9 model on common algorithm tasks"
508
+ )
509
+ parser.add_argument(
510
+ "--model-path",
511
+ type=str,
512
+ required=True,
513
+ help="Path to the merged model directory"
514
+ )
515
+ parser.add_argument(
516
+ "--output",
517
+ type=str,
518
+ default="test_results.json",
519
+ help="Output file for results (default: test_results.json)"
520
+ )
521
+ parser.add_argument(
522
+ "--max-tokens",
523
+ type=int,
524
+ default=200,
525
+ help="Max new tokens per generation (default: 200)"
526
+ )
527
+ parser.add_argument(
528
+ "--temperature",
529
+ type=float,
530
+ default=0.2,
531
+ help="Sampling temperature (default: 0.2)"
532
+ )
533
+ parser.add_argument(
534
+ "--test-names",
535
+ type=str,
536
+ default=None,
537
+ help="Comma-separated test names to run (default: all)"
538
+ )
539
+ parser.add_argument(
540
+ "--k-values",
541
+ type=str,
542
+ default="1,10",
543
+ help="Comma-separated k values for pass@k (default: 1,10)"
544
+ )
545
+
546
  args = parser.parse_args()
547
 
548
+ k_values = [int(k.strip()) for k in args.k_values.split(",")]
549
+
550
+ print("="*60)
551
+ print("Stack 2.9 Model - Algorithm Tests")
552
+ print("="*60)
553
+ print(f"Model path: {args.model_path}")
554
+ print(f"Output: {args.output}")
555
+ print(f"Max tokens: {args.max_tokens}")
556
+ print(f"Temperature: {args.temperature}")
557
+
558
+ # Load model
559
+ model, tokenizer = load_model(args.model_path)
560
+ model.eval()
561
+
562
+ # Select tests to run
563
+ tests_to_run = ALGORITHM_TESTS
564
+ if args.test_names:
565
+ names = [n.strip() for n in args.test_names.split(",")]
566
+ tests_to_run = [t for t in tests_to_run if t["name"] in names]
567
+ print(f"Running tests: {[t['name'] for t in tests_to_run]}")
568
+
569
+ if not tests_to_run:
570
+ print("No tests to run!")
571
+ return
572
+
573
+ # Override max_tokens for each test
574
+ for test in tests_to_run:
575
+ if args.max_tokens:
576
+ test["max_tokens"] = args.max_tokens
577
+
578
+ # Run tests
579
+ print("\n" + "="*60)
580
+ print(f"Running {len(tests_to_run)} Tests")
581
+ print("="*60)
582
+
583
+ results = []
584
+ total_start = time.time()
585
+
586
+ for i, test in enumerate(tests_to_run, 1):
587
+ print(f"\n[{i}/{len(tests_to_run)}]")
588
+ result = run_test(model, tokenizer, test)
589
+ results.append(result)
590
+
591
+ total_time = time.time() - total_start
592
+
593
+ # Print summary
594
+ print_results(results, k_values)
595
+ print(f"\n Total time: {total_time:.1f}s")
596
+
597
+ # Save results
598
+ save_results(results, args.output)
599
 
600
 
601
  if __name__ == "__main__":
train_simple.py DELETED
@@ -1,197 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Simple standalone training script for Stack 2.9.
4
- No package installation required — just run: python train_simple.py --config train_config.yaml
5
- """
6
-
7
- import argparse
8
- import os
9
- import sys
10
- from pathlib import Path
11
-
12
- import yaml
13
- from datasets import load_dataset
14
- from transformers import (
15
- AutoModelForCausalLM,
16
- AutoTokenizer,
17
- BitsAndBytesConfig,
18
- TrainingArguments,
19
- Trainer,
20
- DataCollatorForLanguageModeling
21
- )
22
- from peft import LoraConfig, get_peft_model, TaskType
23
- import torch
24
-
25
-
26
- def load_config(config_path: str) -> dict:
27
- with open(config_path, 'r') as f:
28
- return yaml.safe_load(f)
29
-
30
-
31
- def load_model_and_tokenizer(model_name: str, trust_remote_code: bool = True, use_4bit: bool = False):
32
- """Load base model and tokenizer."""
33
- if use_4bit:
34
- quantization_config = BitsAndBytesConfig(
35
- load_in_4bit=True,
36
- bnb_4bit_compute_dtype=torch.float16,
37
- bnb_4bit_quant_type="nf4",
38
- bnb_4bit_use_double_quant=True
39
- )
40
- else:
41
- quantization_config = None
42
-
43
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
44
- model = AutoModelForCausalLM.from_pretrained(
45
- model_name,
46
- quantization_config=quantization_config,
47
- torch_dtype=torch.float16,
48
- trust_remote_code=trust_remote_code,
49
- device_map="auto"
50
- )
51
- return model, tokenizer
52
-
53
-
54
- def load_data(data_path: str, tokenizer, max_length: int = 2048, train_split: float = 0.9):
55
- """Load and tokenize dataset."""
56
- raw_dataset = load_dataset("json", data_files=data_path, split="train")
57
-
58
- def tokenize_function(examples):
59
- # Combine instruction and output
60
- texts = []
61
- for instr, out in zip(examples.get("instruction", [""]), examples.get("output", [""])):
62
- if instr and out:
63
- texts.append(f"### Instruction:\n{instr}\n\n### Response:\n{out}")
64
- elif out:
65
- texts.append(out)
66
- elif instr:
67
- texts.append(instr)
68
- else:
69
- texts.append("")
70
-
71
- tokenized = tokenizer(texts, truncation=True, max_length=max_length, padding="max_length")
72
- tokenized["labels"] = tokenized["input_ids"].copy()
73
- return tokenized
74
-
75
- tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=raw_dataset.column_names)
76
- split = tokenized_dataset.train_test_split(train_size=train_split)
77
- return split["train"], split["test"]
78
-
79
-
80
- def train(config: dict):
81
- """Main training function."""
82
- model_config = config["model"]
83
- data_config = config["data"]
84
- lora_config = config["lora"]
85
- training_config = config["training"]
86
- output_config = config["output"]
87
- hardware_config = config["hardware"]
88
-
89
- # Load model and tokenizer
90
- print(f"Loading model: {model_config['name']}")
91
- model, tokenizer = load_model_and_tokenizer(
92
- model_name=model_config["name"],
93
- trust_remote_code=model_config.get("trust_remote_code", True),
94
- use_4bit=hardware_config.get("use_4bit", False)
95
- )
96
-
97
- # Load data
98
- print(f"Loading dataset: {data_config['input_path']}")
99
- train_dataset, eval_dataset = load_data(
100
- data_path=data_config["input_path"],
101
- tokenizer=tokenizer,
102
- max_length=data_config.get("max_length", 2048),
103
- train_split=data_config.get("train_split", 0.9)
104
- )
105
- print(f" Train samples: {len(train_dataset)}")
106
- print(f" Eval samples: {len(eval_dataset)}")
107
-
108
- # Apply LoRA
109
- peft_config = LoraConfig(
110
- r=lora_config["r"],
111
- alpha=lora_config["alpha"],
112
- dropout=lora_config["dropout"],
113
- target_modules=lora_config["target_modules"],
114
- bias=lora_config["bias"],
115
- task_type=TaskType.CAUSAL_LM
116
- )
117
-     model = get_peft_model(model, peft_config)
-     model.print_trainable_parameters()
-
-     # Training arguments
-     output_dir = output_config["lora_dir"]
-     os.makedirs(output_dir, exist_ok=True)
-
-     training_args = TrainingArguments(
-         output_dir=output_dir,
-         num_train_epochs=training_config["num_epochs"],
-         per_device_train_batch_size=training_config["batch_size"],
-         gradient_accumulation_steps=training_config["gradient_accumulation"],
-         learning_rate=training_config["learning_rate"],
-         warmup_steps=training_config.get("warmup_steps", 100),
-         weight_decay=training_config.get("weight_decay", 0.01),
-         max_grad_norm=training_config.get("max_grad_norm", 1.0),
-         logging_steps=training_config.get("logging_steps", 10),
-         save_steps=training_config.get("save_steps", 100),
-         save_total_limit=training_config.get("save_total_limit", 2),
-         fp16=training_config.get("fp16", True),
-         bf16=training_config.get("bf16", False),
-         gradient_checkpointing=training_config.get("gradient_checkpointing", True),
-         evaluation_strategy="steps" if eval_dataset else "no",
-         eval_steps=training_config.get("eval_steps", 100) if eval_dataset else None,
-         report_to="none",  # No WandB
-     )
-
-     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-
-     trainer = Trainer(
-         model=model,
-         args=training_args,
-         train_dataset=train_dataset,
-         eval_dataset=eval_dataset,
-         data_collator=data_collator,
-         tokenizer=tokenizer,
-     )
-
-     print("="*60)
-     print("Starting training...")
-     print("="*60)
-     trainer.train()
-     print("Training completed!")
-
-     # Save final adapter
-     trainer.save_model(output_dir)
-     print(f"✅ Adapter saved to {output_dir}")
-
-     return trainer
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
-     args = parser.parse_args()
-
-     print("="*60)
-     print("Stack 2.9 Simple Training")
-     print("="*60)
-
-     config = load_config(args.config)
-     print(f"Config loaded: {args.config}")
-     print(f"Model: {config['model']['name']}")
-     print(f"Data: {config['data']['input_path']}")
-
-     try:
-         train(config)
-         print("\n" + "="*60)
-         print("✅ TRAINING SUCCESS")
-         print("="*60)
-     except Exception as e:
-         print("\n" + "="*60)
-         print(f"❌ TRAINING FAILED: {e}")
-         print("="*60)
-         import traceback
-         traceback.print_exc()
-         sys.exit(1)
-
-
- if __name__ == "__main__":
-     main()
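Both this script and the configs below size the optimizer step as per-device batch × gradient-accumulation steps × GPU count, so it is worth sanity-checking the effective batch size before a launch. A minimal sketch of that arithmetic (plain Python, using only the config fields already shown):

```python
def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Effective optimizer batch size the Trainer sees per step."""
    return per_device * grad_accum * num_gpus

# Values from the configs in this commit:
print(effective_batch_size(2, 8))  # A100 full-LoRA config -> 16
print(effective_batch_size(1, 8))  # Kaggle QLoRA script  -> 8
```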
training-configs/7b-lora-config.yaml ADDED
@@ -0,0 +1,104 @@
+ # Stack 2.9 Training Configuration for Qwen2.5-Coder-7B
+ # Optimized for A100 80GB or a multi-GPU setup (2x A100 40GB minimum)
+ # Defaults below assume single-GPU training of the 7B model
+
+ model:
+   name: "Qwen/Qwen2.5-Coder-7B"
+   trust_remote_code: true
+   torch_dtype: "bfloat16"  # bf16 for A100, better numerical stability than fp16
+   use_flash_attention: true  # A100 supports FlashAttention-2
+
+ tokenizer:
+   model_max_length: 8192  # Standard 8K context for coding tasks
+   padding_side: "right"
+   truncation_side: "right"
+
+ # LoRA Configuration - Optimized for 7B model
+ lora:
+   r: 16  # LoRA rank: 16 is sufficient for 7B and saves memory
+   alpha: 32  # alpha = 2*r is standard
+   dropout: 0.05
+   target_modules:
+     - "q_proj"
+     - "k_proj"
+     - "v_proj"
+     - "o_proj"
+     - "gate_proj"
+     - "up_proj"
+     - "down_proj"
+   bias: "none"
+   task_type: "CAUSAL_LM"
+
+ # Training Configuration
+ training:
+   num_epochs: 3  # 3 epochs is typically sufficient
+   batch_size: 2  # per GPU, for A100 80GB
+   gradient_accumulation: 8  # Effective batch = 16 (2 * 8)
+   learning_rate: 1.0e-4  # Standard for LoRA
+   warmup_steps: 100
+   weight_decay: 0.01
+   max_grad_norm: 1.0
+   logging_steps: 10
+   eval_steps: 500
+   save_steps: 1000
+   save_total_limit: 3
+   fp16: false
+   bf16: true  # Use bf16 on A100
+   gradient_checkpointing: true  # Enable for memory savings
+   optim: "adamw_torch"  # Standard optimizer
+
+ # Data Configuration
+ data:
+   train_files:
+     - "/Users/walidsobhi/.openclaw/workspace/stack-2.9/training-data/final/train.jsonl"
+   val_file: "/Users/walidsobhi/.openclaw/workspace/stack-2.9/training-data/final/val.jsonl"
+   test_file: "/Users/walidsobhi/.openclaw/workspace/stack-2.9/training-data/final/test.jsonl"
+   max_length: 8192
+   train_split: 0.9
+   test_split: 0.1
+
+ # Output Configuration
+ output:
+   lora_dir: "/Users/walidsobhi/.openclaw/workspace/stack-2.9-training/output/stack-2.9-7b-lora"
+   merged_dir: "/Users/walidsobhi/.openclaw/workspace/stack-2.9-training/output/stack-2.9-7b-merged"
+   awq_dir: "/Users/walidsobhi/.openclaw/workspace/stack-2.9-training/output/stack-2.9-7b-awq"
+
+ # Quantization - Optional for 7B
+ # Enable 4-bit only if running on limited VRAM (< 40GB per GPU)
+ quantization:
+   enabled: false  # Disabled for A100 80GB - not needed
+   bits: 4
+   group_size: 128
+
+ # Logging Configuration
+ logging:
+   report_to: "none"  # Set to "wandb" for Weights & Biases
+   wandb_project: "stack-2.9-training"
+   run_name: null
+
+ # Hardware Configuration
+ hardware:
+   device: "cuda"
+   num_gpus: 1  # Set to 2, 4, or 8 for multi-GPU
+   use_4bit: false  # Not needed for A100 80GB
+   use_8bit: false  # Optional: enable for extra memory savings
+   zero_stage: 0  # ZeRO stage (0 = disabled; 1, 2, 3 for multi-GPU)
+
+ # Multi-GPU Configuration (used when num_gpus > 1)
+ multigpu:
+   zero_stage: 2  # ZeRO-2 recommended for 2-4 GPUs, ZeRO-3 for 8+
+   gradient_bucket_size: 5e8
+   prefetch_bucket_size: 5e8
+   steps_per_print: 10
+
+ # Memory Estimates (7B):
+ # - Full LoRA bf16, batch_size=2: ~18GB VRAM
+ # - Full LoRA bf16, batch_size=4: ~22GB VRAM
+ # - 4-bit QLoRA, batch_size=1: ~6GB VRAM
+ # - 4-bit QLoRA, batch_size=2: ~8GB VRAM
+ # An A100 80GB has ample headroom for any of these
+
+ # Merge Configuration (run after training)
+ merge:
+   enabled: true
+   output_dir: "/Users/walidsobhi/.openclaw/workspace/stack-2.9-training/output/stack-2.9-7b-merged"
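For a sense of what `r: 16` over these seven target modules costs, the trainable-parameter count can be estimated as r × (d_in + d_out) per adapted linear layer. The dimensions below are assumptions for Qwen2.5-Coder-7B (hidden size 3584, intermediate size 18944, 28 layers, GQA k/v projections of width 512); verify them against the model's `config.json` before relying on the numbers.

```python
# Rough LoRA trainable-parameter estimate for the config above.
# Model dimensions are ASSUMED for Qwen2.5-Coder-7B; check config.json.
R = 16
HIDDEN, INTER, KV_DIM, LAYERS = 3584, 18944, 512, 28

# (in_features, out_features) of each adapted projection
MODULES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTER),
    "up_proj": (HIDDEN, INTER),
    "down_proj": (INTER, HIDDEN),
}

# Each LoRA pair (A: r x d_in, B: d_out x r) adds r * (d_in + d_out) params.
per_layer = sum(R * (d_in + d_out) for d_in, d_out in MODULES.values())
total = per_layer * LAYERS
print(f"~{total / 1e6:.1f}M trainable parameters")  # ~40.4M, well under 1% of 7B
```

This is why the r=16 adapter trains comfortably even on a 16GB card once the base model is quantized.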
training-configs/kaggle-7b-qlora.sh ADDED
@@ -0,0 +1,79 @@
+ #!/bin/bash
+ # Kaggle QLoRA Training Script for Stack 2.9 on Qwen2.5-Coder-7B
+ #
+ # Hardware: Kaggle P100 GPU (16GB VRAM)
+ # Expected runtime: 2-4 hours, depending on dataset size
+ # Memory: ~6GB VRAM with QLoRA + gradient checkpointing
+ #
+ # Usage:
+ #   chmod +x kaggle-7b-qlora.sh
+ #   ./kaggle-7b-qlora.sh
+
+ set -e
+
+ # === Configuration ===
+ MODEL_NAME="Qwen/Qwen2.5-Coder-7B"
+ OUTPUT_DIR="/kaggle/working/stack-2.9-output/lora-7b-qlora"
+ DATA_DIR="/kaggle/working/stack-2.9/data"
+
+ # === Hyperparameters ===
+ # LoRA config (r=16, alpha=32)
+ LORA_R=16
+ LORA_ALPHA=32
+ LORA_DROPOUT=0.05
+
+ # Training config
+ EPOCHS=2
+ BATCH_SIZE=1
+ GRAD_ACCUM=8  # Effective batch = 8
+ LR=1e-4
+ MAX_LENGTH=4096  # Reduced for Kaggle memory limits
+
+ # === Environment Setup ===
+ echo "=== Setting up environment ==="
+ cd /kaggle/working/stack-2.9
+
+ # Install dependencies if needed
+ pip install -q transformers datasets peft bitsandbytes accelerate torch
+
+ # === Training Command ===
+ echo "=== Starting QLoRA training ==="
+ echo "Model: $MODEL_NAME"
+ echo "Output: $OUTPUT_DIR"
+ echo "Epochs: $EPOCHS"
+
+ trl sft \
+     --model_name_or_path "$MODEL_NAME" \
+     --train_files "$DATA_DIR/train.jsonl" \
+     --validation_files "$DATA_DIR/val.jsonl" \
+     --output_dir "$OUTPUT_DIR" \
+     --per_device_train_batch_size $BATCH_SIZE \
+     --gradient_accumulation_steps $GRAD_ACCUM \
+     --num_train_epochs $EPOCHS \
+     --learning_rate $LR \
+     --max_seq_length $MAX_LENGTH \
+     --logging_steps 10 \
+     --save_steps 500 \
+     --save_total_limit 2 \
+     --bf16 true \
+     --gradient_checkpointing true \
+     --lora_r $LORA_R \
+     --lora_alpha $LORA_ALPHA \
+     --lora_dropout $LORA_DROPOUT \
+     --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \
+     --load_in_4bit true \
+     --bnb_4bit_compute_dtype "bfloat16" \
+     --bnb_4bit_quant_type "nf4" \
+     --bnb_4bit_use_double_quant true
+
+ echo "=== Training complete ==="
+ echo "LoRA adapters saved to: $OUTPUT_DIR"
+
+ # Optional: Merge and save merged model
+ echo "=== Merging adapter ==="
+ python merge_simple.py \
+     --base_model "$MODEL_NAME" \
+     --lora_adapter "$OUTPUT_DIR" \
+     --output_dir "$OUTPUT_DIR/merged"
+
+ echo "=== Done ==="
training-data/tool_examples.json ADDED
The diff for this file is too large to render. See raw diff
 
training-data/tool_examples.jsonl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1043720a918f5fe0f70cc013c108710570c37ae6c9cee6f504e49dc359af5a2a
+ size 3779800
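`tool_examples.jsonl` is checked in as a Git LFS pointer rather than raw data, so a clone without LFS yields only the three lines above. The pointer format is trivially machine-readable; a small stdlib-only sketch of parsing it:

```python
# Parse a Git LFS pointer file (the three-line stub shown above).
def parse_lfs_pointer(text: str) -> dict:
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {
        "version": fields["version"],
        "sha256": fields["oid"].removeprefix("sha256:"),
        "size_bytes": int(fields["size"]),
    }

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:1043720a918f5fe0f70cc013c108710570c37ae6c9cee6f504e49dc359af5a2a
size 3779800
"""
info = parse_lfs_pointer(pointer)
print(info["size_bytes"])  # 3779800 -> ~3.8 MB for the 1000 examples
```

Run `git lfs pull` to materialize the real JSONL before pointing the training configs at it.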