| """ |
| Training script for Resonance 200M. |
| ClimbMix data, own BPE tokenizer (Rust backend), AdamW optimizer. |
| Shows BOTH train loss AND val loss. Always. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import math |
| import struct |
| import argparse |
| import numpy as np |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.amp import autocast, GradScaler |
|
|
| from model import Resonance, RESONANCE_200M |
| from bpe_tokenizer import BPETokenizer |
|
|
|
|
| |
| |
| |
|
|
def download_climbmix_shards(data_dir, n_shards=100):
    """Download ClimbMix parquet shards from HuggingFace."""
    os.makedirs(data_dir, exist_ok=True)

    try:
        import pyarrow.parquet as pq
    except ImportError:
        print("pip install pyarrow pandas")
        sys.exit(1)

    base_url = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
    texts_path = os.path.join(data_dir, "texts.txt")

    if os.path.exists(texts_path):
        size = os.path.getsize(texts_path)
        print(f" [Data] texts.txt exists ({size/1e9:.2f} GB), skipping download")
        return texts_path
    # Unverified SSL context for the download; install an opener so that
    # urlretrieve below actually routes through it.
    import urllib.request
    import ssl
    ctx = ssl.create_default_context()
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    urllib.request.install_opener(
        urllib.request.build_opener(urllib.request.HTTPSHandler(context=ctx)))
    total_bytes = 0
    with open(texts_path, 'w', encoding='utf-8') as out:
        for i in range(n_shards):
            shard_name = f"shard_{i:05d}.parquet"
            shard_path = os.path.join(data_dir, shard_name)
            url = f"{base_url}/{shard_name}"

            if not os.path.exists(shard_path):
                print(f" [Data] Downloading shard {i+1}/{n_shards}...", end=" ", flush=True)
                try:
                    urllib.request.urlretrieve(url, shard_path)
                    print("OK")
                except Exception as e:
                    print(f"FAIL: {e}")
                    continue

            # Extract the text column and append it to texts.txt.
            try:
                table = pq.read_table(shard_path, columns=['text'])
                texts = table.column('text').to_pylist()
                for text in texts:
                    if text and len(text) > 100:
                        out.write(text + '\n')
                        total_bytes += len(text)
                # Delete the parquet shard to keep disk usage bounded.
                os.remove(shard_path)
            except Exception as e:
                print(f" [Data] Error reading shard {i}: {e}")
                continue

            if (i + 1) % 10 == 0:
                print(f" [Data] {i+1}/{n_shards} shards, {total_bytes/1e9:.2f} GB text")

    print(f" [Data] Total: {total_bytes/1e9:.2f} GB text from {n_shards} shards")
    return texts_path
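
# Note: each parquet shard is deleted as soon as its text column has been
# appended to texts.txt, so peak disk usage stays near texts.txt plus one shard.
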
def tokenize_data(texts_path, tokenizer, data_dir, context_len):
    """Tokenize text into binary shards (uint16 for vocab < 65536).
    Streams to disk, so no OOM on 16GB+ corpora."""
    train_path = os.path.join(data_dir, "train.bin")
    val_path = os.path.join(data_dir, "val.bin")

    if os.path.exists(train_path) and os.path.exists(val_path):
        train_tokens = os.path.getsize(train_path) // 2
        val_tokens = os.path.getsize(val_path) // 2
        print(f" [Data] Tokenized data exists: train={train_tokens:,} val={val_tokens:,}")
        return train_tokens, val_tokens

    print(f" [Data] Tokenizing...")
    tmp_path = os.path.join(data_dir, "tokens_all.bin")
    total_tokens = 0
    t0 = time.time()

    with open(texts_path, 'r', encoding='utf-8', errors='replace') as f_in, \
         open(tmp_path, 'wb') as f_out:
        chunk_size = 10_000_000
        total_chars = 0
        while True:
            text = f_in.read(chunk_size)
            if not text:
                break
            ids = tokenizer.encode(text)
            arr = np.array(ids, dtype=np.uint16)
            f_out.write(arr.tobytes())
            total_tokens += len(ids)
            total_chars += len(text)
            if total_chars % 100_000_000 < chunk_size:
                elapsed = time.time() - t0
                rate = total_chars / elapsed / 1e6
                print(f" [Data] {total_chars/1e9:.2f} GB text → {total_tokens:,} tokens "
                      f"({rate:.1f} MB/s, {elapsed:.0f}s)")

    elapsed = time.time() - t0
    print(f" [Data] Tokenized {total_chars/1e9:.2f} GB → {total_tokens:,} tokens in {elapsed:.0f}s")

    # 95/5 train/val split.
    split = int(total_tokens * 0.95)
    print(f" [Data] Splitting: train={split:,} val={total_tokens - split:,}")

    all_data = np.memmap(tmp_path, dtype=np.uint16, mode='r')

    # Copy the train split out in chunks to avoid loading everything into RAM.
    chunk = 50_000_000
    with open(train_path, 'wb') as f:
        for start in range(0, split, chunk):
            end = min(start + chunk, split)
            f.write(all_data[start:end].tobytes())

    # Remaining tokens become the val split.
    with open(val_path, 'wb') as f:
        for start in range(split, total_tokens, chunk):
            end = min(start + chunk, total_tokens)
            f.write(all_data[start:end].tobytes())

    del all_data
    os.remove(tmp_path)

    train_tokens = split
    val_tokens = total_tokens - split
    print(f" [Data] train: {train_tokens:,} tokens ({train_tokens*2/1e9:.2f} GB)")
    print(f" [Data] val: {val_tokens:,} tokens ({val_tokens*2/1e9:.2f} GB)")
    return train_tokens, val_tokens
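
# train.bin / val.bin are raw uint16 token streams (2 bytes per token) in the
# tokenizer's id space; the DataLoader below memory-maps them directly.
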
class DataLoader:
    """Simple random-chunk dataloader from mmap'd binary file."""

    def __init__(self, path, context_len, batch_size, device):
        self.data = np.memmap(path, dtype=np.uint16, mode='r')
        self.context_len = context_len
        self.batch_size = batch_size
        self.device = device
        self.n_tokens = len(self.data)

    def get_batch(self):
        T = self.context_len
        B = self.batch_size
        ix = torch.randint(0, self.n_tokens - T - 1, (B,))
        x = torch.stack([torch.from_numpy(self.data[i:i+T].astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy(self.data[i+1:i+T+1].astype(np.int64)) for i in ix])
        return x.to(self.device), y.to(self.device)
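
# DataLoader.get_batch returns x, y: two (batch_size, context_len) int64 tensors,
# with y shifted one token ahead of x (next-token prediction targets).
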
def get_lr(step, warmup_steps, total_steps, max_lr, min_lr=0.0):
    """WSD schedule: warmup → stable → linear decay."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    decay_start = total_steps // 2
    if step < decay_start:
        return max_lr
    # Linear decay from max_lr to min_lr over the second half of training.
    progress = (step - decay_start) / (total_steps - decay_start)
    return max_lr * (1.0 - progress) + min_lr * progress
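
# Illustrative values (warmup_steps=800, total_steps=30000, max_lr=3e-4, min_lr=0):
#   step 399   -> 3e-4 * 400/800 = 1.5e-4   (still warming up)
#   step 800   -> 3e-4                      (stable plateau, until step 15000)
#   step 22500 -> 1.5e-4                    (halfway through the linear decay)
#   step 29999 -> ~0                        (end of training)
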
@torch.no_grad()
def evaluate(model, val_loader, n_batches=50):
    """Evaluate val loss. Returns average loss."""
    model.eval()
    losses = []
    for _ in range(n_batches):
        x, y = val_loader.get_batch()
        with autocast('cuda', dtype=torch.bfloat16):
            _, loss = model(x, y)
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses)
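
# evaluate() averages the loss over n_batches randomly sampled validation windows,
# so successive readings fluctuate slightly; raise --val-batches for a smoother estimate.
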
def save_checkpoint(model, optimizer, step, train_loss, val_loss, config, path):
    """Save PyTorch checkpoint."""
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'config': config,
    }, path)
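
# Note: when the model passed in is the torch.compile wrapper (as in train() below),
# the saved state_dict keys may carry an '_orig_mod.' prefix; strip it (or compile
# before loading) when restoring these weights into a plain Resonance model.
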
def save_c_weights(model, tokenizer, config, path):
    """Save weights in C-compatible binary format for resonance-bpe.c."""
    with open(path, 'wb') as f:
        # Header: magic number followed by the model config as int32s.
        f.write(struct.pack('<I', 0x52533032))
        f.write(struct.pack('<9i',
                config['n_embd'], config['n_layer'], config['context_len'],
                config['n_head'], config['head_dim'], config['rrpram_rank'],
                config['ffn_dim'], config['vocab_size'], config['n_head']))

        # Tokenizer merge table: count, then (a, b, new_id) triples.
        f.write(struct.pack('<I', len(tokenizer.merges)))
        for a, b, new_id in tokenizer.merges:
            f.write(struct.pack('<III', a, b, new_id))

        # Model weights as raw float32, in named_parameters() order.
        for name, param in model.named_parameters():
            data = param.detach().float().cpu().numpy()
            f.write(data.tobytes())

    size_mb = os.path.getsize(path) / 1e6
    print(f" [Save] C weights: {path} ({size_mb:.1f} MB)")
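
# Layout of the exported file:
#   uint32       magic (0x52533032)
#   int32  x 9   config fields, in the order packed above (little-endian)
#   uint32       number of BPE merges
#   uint32 x 3   per merge: (a, b, new_id)
#   float32 ...  every parameter tensor, flattened, in named_parameters() order
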
def train(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    config = RESONANCE_200M.copy()
    if args.vocab_size:
        config['vocab_size'] = args.vocab_size

    data_dir = args.data_dir
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(args.save_dir, exist_ok=True)

    # Step 1: download / assemble the raw text corpus.
    print("\n[1] Data...")
    texts_path = download_climbmix_shards(data_dir, n_shards=args.n_shards)

    # Step 2: train or load the BPE tokenizer.
    print("\n[2] BPE tokenizer...")
    tok_path = os.path.join(args.save_dir, "tokenizer.bin")
    tokenizer = BPETokenizer(max_merges=config['vocab_size'] - 256)

    if os.path.exists(tok_path):
        tokenizer.load(tok_path)
    else:
        # Train the tokenizer on a 200 MB sample of the corpus.
        with open(texts_path, 'rb') as f:
            sample = f.read(200_000_000)
        tokenizer.train(sample, num_merges=config['vocab_size'] - 256, report_every=2000)
        tokenizer.save_copies(tok_path, n=3)

    config['vocab_size'] = tokenizer.vocab_size

    # Step 3: tokenize the corpus into train.bin / val.bin.
    print("\n[3] Tokenizing data...")
    n_train, n_val = tokenize_data(texts_path, tokenizer, data_dir, config['context_len'])

    # Step 4: build the model.
    print("\n[4] Model...")
    model = Resonance(config).to(device)
    model.set_gradient_checkpointing(True)
    model = torch.compile(model)
    print(f" Gradient checkpointing: ON, torch.compile: ON")

    # Step 5: optimizer.
    print("\n[5] Optimizer...")
    # Weight decay on matrices/embeddings (dim >= 2), none on biases and norms.
    decay_params = []
    no_decay_params = []
    for name, p in model.named_parameters():
        if p.dim() >= 2:
            decay_params.append(p)
        else:
            no_decay_params.append(p)

    optimizer = torch.optim.AdamW([
        {'params': decay_params, 'weight_decay': args.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ], lr=args.lr, betas=(0.9, 0.95), eps=1e-8)

    scaler = GradScaler('cuda')
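    # GradScaler only matters for float16 autocast; with the bfloat16 autocast used
    # below it is redundant but harmless (bf16 shares float32's exponent range).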

    # Batch geometry: sequences per micro-batch and number of grad-accumulation steps.
    T = config['context_len']
    micro_B = args.micro_batch // T
    grad_accum = args.batch_size // args.micro_batch
    print(f"\n[6] DataLoader: effective_batch={args.batch_size} tokens "
          f"({grad_accum} x {args.micro_batch} micro), {micro_B} seq x {T} ctx")
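    # With the defaults (131072-token batch, 65536-token micro-batch) this gives
    # grad_accum = 2 and micro_B = 65536 // context_len sequences per micro-step.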

    train_loader = DataLoader(os.path.join(data_dir, "train.bin"), T, micro_B, device)
    val_loader = DataLoader(os.path.join(data_dir, "val.bin"), T, micro_B, device)

    # One pass over the training tokens.
    total_steps = n_train // args.batch_size
    print(f" Total steps: {total_steps:,}")

    # Step 7: training loop.
    print(f"\n[7] Training resonance-200m...")
    print(f" {'step':>8} | {'train_loss':>10} | {'val_loss':>10} | {'lr':>10} | {'tok/s':>10} | {'time':>8}")
    print(" " + "-" * 75)

    best_val_loss = float('inf')
    val_loss = float('inf')
    running_loss = 0.0
    t0 = time.time()
    tokens_seen = 0

    model.train()
    for step in range(total_steps):
        # Learning rate for this step.
        lr = get_lr(step, args.warmup_steps, total_steps, args.lr)
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # Forward/backward with gradient accumulation.
        optimizer.zero_grad(set_to_none=True)
        step_loss = 0.0
        for micro_step in range(grad_accum):
            x, y = train_loader.get_batch()
            with autocast('cuda', dtype=torch.bfloat16):
                _, loss = model(x, y)
            loss = loss / grad_accum
            scaler.scale(loss).backward()
            step_loss += loss.item() * grad_accum

        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        train_loss = step_loss / grad_accum
        running_loss += train_loss
        tokens_seen += args.batch_size

        # Logging and validation.
        if (step + 1) % args.log_every == 0:
            avg_train = running_loss / args.log_every
            running_loss = 0.0
            elapsed = time.time() - t0
            tok_per_sec = tokens_seen / elapsed

            # Val loss at every log point: train AND val, always.
            val_loss = evaluate(model, val_loader, n_batches=args.val_batches)

            print(f" {step+1:>8} | {avg_train:>10.4f} | {val_loss:>10.4f} | "
                  f"{lr:>10.2e} | {tok_per_sec/1000:>8.1f}k | {elapsed:>7.0f}s")

            # Keep the best checkpoint by val loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(model, optimizer, step, avg_train, val_loss, config,
                                os.path.join(args.save_dir, "best.pt"))

        # Periodic checkpoint + C weights export.
        if (step + 1) % args.save_every == 0:
            save_checkpoint(model, optimizer, step, train_loss, val_loss,
                            config, os.path.join(args.save_dir, f"step_{step+1}.pt"))
            save_c_weights(model, tokenizer, config,
                           os.path.join(args.save_dir, f"resonance_200m_step{step+1}.bin"))

        # Occasionally print the mean of each block's (sigmoid) gate.
        if (step + 1) % (args.log_every * 5) == 0:
            raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
            gates = []
            for block in raw_model.blocks:
                g = torch.sigmoid(block.gate).detach().cpu().numpy()
                gates.append(g.mean())
            gate_str = " ".join(f"{g:.2f}" for g in gates)
            print(f" [gates] {gate_str}")

    # Final checkpoint and C weights export.
    elapsed = time.time() - t0
    print(f"\n Training complete. {elapsed/3600:.1f} hours, {tokens_seen:,} tokens")

    save_checkpoint(model, optimizer, total_steps, train_loss, best_val_loss, config,
                    os.path.join(args.save_dir, "final.pt"))
    save_c_weights(model, tokenizer, config,
                   os.path.join(args.save_dir, "resonance_200m_final.bin"))

    # Keep a copy of the tokenizer next to the checkpoints.
    tokenizer.save_copies(os.path.join(args.save_dir, "tokenizer.bin"), n=3)

    print(f"\n Best val loss: {best_val_loss:.4f}")
    print(f" resonance is unbreakable.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='data/')
    parser.add_argument('--save-dir', type=str, default='checkpoints/')
    parser.add_argument('--n-shards', type=int, default=65,
                        help='Number of ClimbMix shards to download (~65 for ~4B tokens)')
    parser.add_argument('--vocab-size', type=int, default=None,
                        help='Override vocab size (default: 16384)')
    parser.add_argument('--batch-size', type=int, default=131072,
                        help='Effective batch size in tokens (default: 131072)')
    parser.add_argument('--micro-batch', type=int, default=65536,
                        help='Micro-batch size in tokens for grad accum (default: 65536)')
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--warmup-steps', type=int, default=800)
    parser.add_argument('--weight-decay', type=float, default=0.1)
    parser.add_argument('--grad-clip', type=float, default=1.0)
    parser.add_argument('--log-every', type=int, default=100)
    parser.add_argument('--save-every', type=int, default=2000)
    parser.add_argument('--val-batches', type=int, default=50)
    args = parser.parse_args()
    train(args)
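
# Example invocation (assuming this file is saved as train.py; adjust paths/flags to taste):
#   python train.py --data-dir data/ --save-dir checkpoints/ --n-shards 65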