# resonance / train.py
# Commit 413b3cd "Add train.py" by ataeff (verified)
"""
Training script for Resonance 200M.
ClimbMix data, own BPE tokenizer (Rust backend), AdamW optimizer.
Shows BOTH train loss AND val loss. Always.
"""
import os
import sys
import time
import math
import struct
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from model import Resonance, RESONANCE_200M
from bpe_tokenizer import BPETokenizer
# ─────────────────────────────────────────────────────────────────────────────
# Data
# ─────────────────────────────────────────────────────────────────────────────
def download_climbmix_shards(data_dir, n_shards=100):
    """Download ClimbMix parquet shards from HuggingFace and flatten them to text.

    Each shard's ``text`` column is streamed into ``<data_dir>/texts.txt``,
    one document per line. Documents of 100 characters or fewer are dropped.
    Each parquet file is deleted after its text has been extracted, to save
    disk.

    The text is first written to ``texts.txt.tmp``. It is renamed to
    ``texts.txt`` only after every shard has been processed, so an
    interrupted run never leaves a truncated ``texts.txt`` that a later run
    would mistake for a finished download. Shards are likewise downloaded to
    ``<shard>.part`` and renamed on success.

    Args:
        data_dir: Directory for shards and the output text file (created if
            missing).
        n_shards: Number of shards to fetch, starting at ``shard_00000``.

    Returns:
        Path to ``texts.txt``. If that file already exists, it is returned
        immediately without downloading anything.

    Raises:
        RuntimeError: if ``n_shards > 0`` but no shard could be downloaded and
            read. This avoids caching an empty corpus.
        SystemExit: if ``pyarrow`` is not installed and a download is needed.
    """
    os.makedirs(data_dir, exist_ok=True)
    texts_path = os.path.join(data_dir, "texts.txt")
    # Check the cache first, so a finished download does not require pyarrow.
    if os.path.exists(texts_path):
        size = os.path.getsize(texts_path)
        print(f" [Data] texts.txt exists ({size/1e9:.2f} GB), skipping download")
        return texts_path

    try:
        import pyarrow.parquet as pq
    except ImportError:
        print("pip install pyarrow pandas")
        sys.exit(1)
    import urllib.request

    base_url = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
    tmp_path = texts_path + ".tmp"
    total_chars = 0  # counts characters, reported as "GB text"
    n_ok = 0  # shards successfully downloaded and extracted
    with open(tmp_path, 'w', encoding='utf-8') as out:
        for i in range(n_shards):
            shard_name = f"shard_{i:05d}.parquet"
            shard_path = os.path.join(data_dir, shard_name)
            if not os.path.exists(shard_path):
                print(f" [Data] Downloading shard {i+1}/{n_shards}...", end=" ", flush=True)
                part_path = shard_path + ".part"
                try:
                    urllib.request.urlretrieve(f"{base_url}/{shard_name}", part_path)
                except Exception as e:
                    print(f"FAIL: {e}")
                    # Drop any partial download so a retry doesn't parse a truncated file.
                    if os.path.exists(part_path):
                        os.remove(part_path)
                    continue
                os.replace(part_path, shard_path)
                print("OK")
            # Extract text
            try:
                texts = pq.read_table(shard_path, columns=['text']).column('text').to_pylist()
            except Exception as e:
                print(f" [Data] Error reading shard {i}: {e}")
                continue
            for text in texts:
                if text and len(text) > 100:
                    out.write(text + '\n')
                    total_chars += len(text)
            # Remove parquet to save disk
            os.remove(shard_path)
            n_ok += 1
            if (i + 1) % 10 == 0:
                print(f" [Data] {i+1}/{n_shards} shards, {total_chars/1e9:.2f} GB text")

    if n_shards > 0 and n_ok == 0:
        os.remove(tmp_path)
        raise RuntimeError("No ClimbMix shard could be downloaded/read; nothing to train on")
    # Publish only once complete; the cache check above trusts texts.txt blindly.
    os.replace(tmp_path, texts_path)
    print(f" [Data] Total: {total_chars/1e9:.2f} GB text from {n_ok}/{n_shards} shards")
    return texts_path
def tokenize_data(texts_path, tokenizer, data_dir, context_len, chunk_chars=10_000_000):
    """Tokenize a text file into ``train.bin`` / ``val.bin`` (uint16 token ids).

    The corpus is streamed to disk in chunks, so there is no OOM on 16GB+
    corpora. Before encoding, each chunk is extended to the end of its
    current line. Chunk boundaries therefore never split a line, or a word,
    across two ``encode`` calls. The token stream is split 95/5: the first
    95% becomes ``train.bin`` and the remainder becomes ``val.bin``.

    Args:
        texts_path: UTF-8 text file. Undecodable bytes are replaced.
        tokenizer: Object with ``encode(str)`` returning a sequence of ints.
        data_dir: Output directory for ``train.bin`` / ``val.bin``.
        context_len: Unused here; kept for interface compatibility.
        chunk_chars: Approximate number of characters per ``encode`` call.

    Returns:
        ``(train_tokens, val_tokens)``. If both ``.bin`` files already exist,
        they are not regenerated and the counts come from their sizes.

    Raises:
        ValueError: if the tokenizer emits an id outside ``[0, 65535]``, which
            cannot be stored as uint16. No partial output is left behind.
    """
    train_path = os.path.join(data_dir, "train.bin")
    val_path = os.path.join(data_dir, "val.bin")
    if os.path.exists(train_path) and os.path.exists(val_path):
        train_tokens = os.path.getsize(train_path) // 2
        val_tokens = os.path.getsize(val_path) // 2
        print(f" [Data] Tokenized data exists: train={train_tokens:,} val={val_tokens:,}")
        return train_tokens, val_tokens

    print(f" [Data] Tokenizing...")
    tmp_path = os.path.join(data_dir, "tokens_all.bin")
    uint16_max = int(np.iinfo(np.uint16).max)
    report_every = 100_000_000  # characters between progress lines
    next_report = report_every
    total_tokens = 0
    total_chars = 0
    t0 = time.time()
    try:
        with open(texts_path, 'r', encoding='utf-8', errors='replace') as f_in, \
                open(tmp_path, 'wb') as f_out:
            while True:
                text = f_in.read(chunk_chars)
                if not text:
                    break
                # Finish the current line so no word is split across encode() calls.
                if not text.endswith('\n'):
                    text += f_in.readline()
                ids = np.asarray(tokenizer.encode(text), dtype=np.int64)
                # Ids >= 65536 do not fit the uint16 on-disk format; fail loudly.
                if ids.size and (ids.min() < 0 or ids.max() > uint16_max):
                    raise ValueError(
                        f"token id out of uint16 range [0, {uint16_max}] "
                        f"(min={ids.min()}, max={ids.max()}); vocab too large for uint16 storage")
                f_out.write(ids.astype(np.uint16).tobytes())
                total_tokens += int(ids.size)
                total_chars += len(text)
                if total_chars >= next_report:
                    elapsed = max(time.time() - t0, 1e-9)
                    rate = total_chars / elapsed / 1e6
                    print(f" [Data] {total_chars/1e9:.2f} GB text → {total_tokens:,} tokens "
                          f"({rate:.1f} MB/s, {elapsed:.0f}s)")
                    next_report = (total_chars // report_every + 1) * report_every
    except BaseException:
        # Never leave a partial token dump behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    elapsed = time.time() - t0
    print(f" [Data] Tokenized {total_chars/1e9:.2f} GB → {total_tokens:,} tokens in {elapsed:.0f}s")

    # Split 95/5 train/val. Copy from a memmap in blocks to avoid loading everything into RAM.
    split = int(total_tokens * 0.95)
    val_tokens = total_tokens - split
    print(f" [Data] Splitting: train={split:,} val={val_tokens:,}")
    train_tmp = train_path + ".tmp"
    val_tmp = val_path + ".tmp"
    block = 50_000_000  # tokens per write
    # np.memmap cannot map an empty file, so only map when there is data.
    all_data = np.memmap(tmp_path, dtype=np.uint16, mode='r') if total_tokens else None
    with open(train_tmp, 'wb') as f_train, open(val_tmp, 'wb') as f_val:
        for f_dst, lo, hi in ((f_train, 0, split), (f_val, split, total_tokens)):
            for start in range(lo, hi, block):
                f_dst.write(all_data[start:min(start + block, hi)].tobytes())
    del all_data  # release the mapping before deleting its file
    # Publish both outputs only once both are complete; the cache check needs both.
    os.replace(val_tmp, val_path)
    os.replace(train_tmp, train_path)
    os.remove(tmp_path)

    print(f" [Data] train: {split:,} tokens ({split*2/1e9:.2f} GB)")
    print(f" [Data] val: {val_tokens:,} tokens ({val_tokens*2/1e9:.2f} GB)")
    return split, val_tokens
class DataLoader:
    """Random-window sampler over a flat uint16 token file opened via np.memmap.

    Each ``get_batch`` call draws ``batch_size`` start offsets uniformly at
    random, with replacement. It returns (input, target) windows where the
    target is the input shifted by one token.
    """

    def __init__(self, path, context_len, batch_size, device):
        self.data = np.memmap(path, dtype=np.uint16, mode='r')
        self.context_len = context_len
        self.batch_size = batch_size
        self.device = device
        self.n_tokens = len(self.data)

    def get_batch(self):
        """Return ``(x, y)`` contiguous int64 tensors of shape ``(batch_size, context_len)``.

        ``y[b] == data[i_b + 1 : i_b + context_len + 1]`` where
        ``x[b] == data[i_b : i_b + context_len]``. Every start offset
        ``0 .. n_tokens - context_len - 1`` can be sampled.

        Raises:
            ValueError: if the file holds fewer than ``context_len + 1`` tokens.
        """
        T = self.context_len
        B = self.batch_size
        # A window needs T + 1 tokens, so valid starts are 0 .. n_tokens - T - 1 inclusive.
        n_starts = self.n_tokens - T
        if n_starts < 1:
            raise ValueError(
                f"need at least context_len + 1 = {T + 1} tokens, file has {self.n_tokens}")
        ix = torch.randint(0, n_starts, (B,)).tolist()
        # One (T + 1)-token read per row; x and y are the two shifted slices of it.
        rows = torch.stack([torch.from_numpy(self.data[i:i + T + 1].astype(np.int64)) for i in ix])
        x = rows[:, :-1].contiguous()
        y = rows[:, 1:].contiguous()
        if torch.device(self.device).type == 'cuda':
            # Pinned host memory allows an asynchronous host-to-device copy.
            return (x.pin_memory().to(self.device, non_blocking=True),
                    y.pin_memory().to(self.device, non_blocking=True))
        return x.to(self.device), y.to(self.device)
# ─────────────────────────────────────────────────────────────────────────────
# Training
# ─────────────────────────────────────────────────────────────────────────────
def get_lr(step, warmup_steps, total_steps, max_lr, min_lr=0.0):
    """WSD schedule: warmup → stable → linear decay.

    - ``step < warmup_steps``: linear ramp from ``max_lr / warmup_steps`` up
      to ``max_lr``.
    - Then ``max_lr`` is held until
      ``decay_start = max(warmup_steps, total_steps // 2)``.
    - Then the rate decays linearly toward ``min_lr``. It reaches ``min_lr``
      at ``step == total_steps`` and stays clamped there for any later step.

    Returns:
        The learning rate for ``step`` (a float).
    """
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # Never begin decaying before warmup ends. Otherwise a short run jumps
    # from the warmup peak straight into the middle of the decay ramp.
    decay_start = max(warmup_steps, total_steps // 2)
    if step < decay_start:
        return max_lr
    decay_len = total_steps - decay_start
    if decay_len <= 0:
        return min_lr
    # Linear decay, clamped so the LR never drops below min_lr (or goes negative).
    progress = min(1.0, (step - decay_start) / decay_len)
    return max_lr * (1.0 - progress) + min_lr * progress
@torch.no_grad()
def evaluate(model, val_loader, n_batches=50):
"""Evaluate val loss. Returns average loss."""
model.eval()
losses = []
for _ in range(n_batches):
x, y = val_loader.get_batch()
with autocast('cuda', dtype=torch.bfloat16):
_, loss = model(x, y)
losses.append(loss.item())
model.train()
return sum(losses) / len(losses)
def save_checkpoint(model, optimizer, step, train_loss, val_loss, config, path):
    """Save a PyTorch training checkpoint to ``path``.

    The saved dict has keys ``model``, ``optimizer``, ``step``,
    ``train_loss``, ``val_loss`` and ``config``. If ``model`` is a
    ``torch.compile`` wrapper, it exposes the original module as
    ``_orig_mod``. In that case the state dict is taken from the original
    module, so keys carry no ``_orig_mod.`` prefix and load directly into an
    uncompiled model.

    The checkpoint is written to ``path + '.tmp'`` and then renamed over
    ``path``. An interrupted save therefore never corrupts an existing
    checkpoint.
    """
    raw_model = getattr(model, '_orig_mod', model)
    tmp_path = f"{path}.tmp"
    torch.save({
        'model': raw_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'config': config,
    }, tmp_path)
    os.replace(tmp_path, path)
def save_c_weights(model, tokenizer, config, path):
    """Save weights in the little-endian binary format consumed by resonance-bpe.c.

    Layout:
      - uint32: magic ``0x52533032`` ("RS02").
      - 9 x int32: n_embd, n_layer, context_len, n_head, head_dim,
        rrpram_rank, ffn_dim, vocab_size, kv_heads (written as n_head, MHA).
      - uint32: number of BPE merges, followed by ``(a, b, new_id)`` as
        3 x uint32 for each merge.
      - float32: every parameter in ``model.named_parameters()`` order, each
        flattened in row-major (C) order.

    Every field, including tensor data, is written little-endian regardless
    of host byte order. The file is written to ``path + '.tmp'`` and then
    renamed over ``path``.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'wb') as f:
        # Header: magic + config
        f.write(struct.pack('<I', 0x52533032))  # "RS02"
        f.write(struct.pack('<9i',
                            config['n_embd'], config['n_layer'], config['context_len'],
                            config['n_head'], config['head_dim'], config['rrpram_rank'],
                            config['ffn_dim'], config['vocab_size'],
                            config['n_head']))  # kv_heads = n_head (MHA)
        # BPE merges
        f.write(struct.pack('<I', len(tokenizer.merges)))
        for a, b, new_id in tokenizer.merges:
            f.write(struct.pack('<III', a, b, new_id))
        # All parameters in order. '<f4' forces little-endian float32 to match the header.
        for _, param in model.named_parameters():
            data = param.detach().float().cpu().numpy().astype('<f4', copy=False)
            f.write(data.tobytes())
    os.replace(tmp_path, path)
    size_mb = os.path.getsize(path) / 1e6
    print(f" [Save] C weights: {path} ({size_mb:.1f} MB)")
def train(args):
    """Run the full Resonance 200M pipeline and train the model.

    The pipeline has these stages:
      1. Download the ClimbMix text.
      2. Train or load the BPE tokenizer.
      3. Tokenize into train.bin / val.bin.
      4. Build and ``torch.compile`` the model.
      5. Train with AdamW, gradient accumulation and the WSD LR schedule.

    During training:
      - Every ``args.log_every`` optimizer steps, train and val loss are
        printed, and ``best.pt`` is saved whenever val loss improves.
      - Every ``args.save_every`` steps, ``step_N.pt`` plus C-format weights
        are written.
      - Every ``5 * args.log_every`` steps, the mean ``sigmoid(gate)`` of each
        block is printed.

    ``final.pt``, ``resonance_200m_final.bin`` and the tokenizer are saved at
    the end.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: if ``--micro-batch`` is smaller than one sequence
            (``context_len`` tokens) or larger than ``--batch-size``, or if
            the training split is too small for a single optimizer step.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")
    config = RESONANCE_200M.copy()
    if args.vocab_size:
        config['vocab_size'] = args.vocab_size

    # Validate batch geometry up front, before any long download/tokenization.
    T = config['context_len']
    micro_B = args.micro_batch // T  # sequences per micro-batch
    grad_accum = args.batch_size // args.micro_batch
    if micro_B < 1:
        raise ValueError(f"--micro-batch ({args.micro_batch}) must be >= context_len ({T})")
    if grad_accum < 1:
        raise ValueError(f"--batch-size ({args.batch_size}) must be >= --micro-batch ({args.micro_batch})")
    # Tokens actually consumed per optimizer step. This equals --batch-size
    # only when the sizes divide evenly.
    tokens_per_step = grad_accum * micro_B * T

    data_dir = args.data_dir
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(args.save_dir, exist_ok=True)

    # Step 1: Download ClimbMix
    print("\n[1] Data...")
    texts_path = download_climbmix_shards(data_dir, n_shards=args.n_shards)

    # Step 2: Train BPE tokenizer (or reuse a saved one)
    print("\n[2] BPE tokenizer...")
    tok_path = os.path.join(args.save_dir, "tokenizer.bin")
    tokenizer = BPETokenizer(max_merges=config['vocab_size'] - 256)
    if os.path.exists(tok_path):
        tokenizer.load(tok_path)
    else:
        # Train on first 200MB of text
        with open(texts_path, 'rb') as f:
            sample = f.read(200_000_000)
        tokenizer.train(sample, num_merges=config['vocab_size'] - 256, report_every=2000)
        tokenizer.save_copies(tok_path, n=3)
    # The model's vocab must match the tokenizer actually in use (it may have been loaded).
    config['vocab_size'] = tokenizer.vocab_size

    # Step 3: Tokenize data
    print("\n[3] Tokenizing data...")
    n_train, n_val = tokenize_data(texts_path, tokenizer, data_dir, config['context_len'])

    # Step 4: Build model
    print("\n[4] Model...")
    model = Resonance(config).to(device)
    model.set_gradient_checkpointing(True)
    model = torch.compile(model)
    # The compiled wrapper keeps the original module (with .blocks) at _orig_mod.
    raw_model = getattr(model, '_orig_mod', model)
    print(f" Gradient checkpointing: ON, torch.compile: ON")

    # Step 5: Optimizer
    print("\n[5] Optimizer...")
    # Weight decay only for parameters with dim >= 2; 1-D and scalar params are exempt.
    decay_params = [p for p in model.parameters() if p.dim() >= 2]
    no_decay_params = [p for p in model.parameters() if p.dim() < 2]
    optimizer = torch.optim.AdamW([
        {'params': decay_params, 'weight_decay': args.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ], lr=args.lr, betas=(0.9, 0.95), eps=1e-8)
    scaler = GradScaler('cuda')

    # Step 6: Data loaders (micro-batch for gradient accumulation)
    print(f"\n[6] DataLoader: effective_batch={args.batch_size} tokens "
          f"({grad_accum} x {args.micro_batch} micro), {micro_B} seq x {T} ctx")
    if tokens_per_step != args.batch_size:
        print(f" [warn] batch sizes do not divide evenly: {tokens_per_step} tokens per optimizer step")
    train_loader = DataLoader(os.path.join(data_dir, "train.bin"), T, micro_B, device)
    val_loader = DataLoader(os.path.join(data_dir, "val.bin"), T, micro_B, device)
    total_steps = n_train // tokens_per_step
    if total_steps < 1:
        raise ValueError(f"train split ({n_train:,} tokens) is smaller than one optimizer "
                         f"step ({tokens_per_step:,} tokens)")
    print(f" Total steps: {total_steps:,}")

    # Step 7: Train loop
    print(f"\n[7] Training resonance-200m...")
    print(f" {'step':>8} | {'train_loss':>10} | {'val_loss':>10} | {'lr':>10} | {'tok/s':>10} | {'time':>8}")
    print(" " + "-" * 75)
    best_val_loss = float('inf')
    val_loss = float('nan')  # most recent val loss; NaN until the first evaluation
    train_loss = float('nan')
    running_loss = 0.0
    t0 = time.time()
    tokens_seen = 0
    model.train()
    for step in range(total_steps):
        # LR schedule
        lr = get_lr(step, args.warmup_steps, total_steps, args.lr)
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # Gradient accumulation: grad_accum micro-batches per optimizer step
        optimizer.zero_grad(set_to_none=True)
        step_loss = 0.0
        for _ in range(grad_accum):
            x, y = train_loader.get_batch()
            with autocast('cuda', dtype=torch.bfloat16):
                _, loss = model(x, y)
            loss = loss / grad_accum
            scaler.scale(loss).backward()
            # Each term is already divided by grad_accum, so the sum is the mean micro-loss.
            step_loss += loss.item()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        train_loss = step_loss
        running_loss += train_loss
        tokens_seen += tokens_per_step

        # Log every N steps
        if (step + 1) % args.log_every == 0:
            avg_train = running_loss / args.log_every
            running_loss = 0.0
            elapsed = time.time() - t0
            tok_per_sec = tokens_seen / elapsed
            # Val loss
            val_loss = evaluate(model, val_loader, n_batches=args.val_batches)
            print(f" {step+1:>8} | {avg_train:>10.4f} | {val_loss:>10.4f} | "
                  f"{lr:>10.2e} | {tok_per_sec/1000:>8.1f}k | {elapsed:>7.0f}s")
            # Save best
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(model, optimizer, step, avg_train, val_loss, config,
                                os.path.join(args.save_dir, "best.pt"))

        # Checkpoint every N steps
        if (step + 1) % args.save_every == 0:
            save_checkpoint(model, optimizer, step, train_loss, val_loss,
                            config, os.path.join(args.save_dir, f"step_{step+1}.pt"))
            save_c_weights(model, tokenizer, config,
                           os.path.join(args.save_dir, f"resonance_200m_step{step+1}.bin"))

        # Gate monitoring every N steps
        if (step + 1) % (args.log_every * 5) == 0:
            gates = [torch.sigmoid(block.gate).detach().cpu().numpy().mean()
                     for block in raw_model.blocks]
            gate_str = " ".join(f"{g:.2f}" for g in gates)
            print(f" [gates] {gate_str}")

    # Final save
    elapsed = time.time() - t0
    print(f"\n Training complete. {elapsed/3600:.1f} hours, {tokens_seen:,} tokens")
    save_checkpoint(model, optimizer, total_steps, train_loss, best_val_loss, config,
                    os.path.join(args.save_dir, "final.pt"))
    save_c_weights(model, tokenizer, config,
                   os.path.join(args.save_dir, "resonance_200m_final.bin"))
    # Re-save tokenizer (paranoia)
    tokenizer.save_copies(os.path.join(args.save_dir, "tokenizer.bin"), n=3)
    print(f"\n Best val loss: {best_val_loss:.4f}")
    print(f" resonance is unbreakable.")
if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    # Each option is (flag, type, default, help). help=None matches argparse's own default.
    cli_options = (
        ('--data-dir', str, 'data/', None),
        ('--save-dir', str, 'checkpoints/', None),
        ('--n-shards', int, 65,
         'Number of ClimbMix shards to download (~65 for ~4B tokens)'),
        ('--vocab-size', int, None,
         'Override vocab size (default: 16384)'),
        ('--batch-size', int, 131072,
         'Effective batch size in tokens (default: 131072)'),
        ('--micro-batch', int, 65536,
         'Micro-batch size in tokens for grad accum (default: 65536)'),
        ('--lr', float, 3e-4, None),
        ('--warmup-steps', int, 800, None),
        ('--weight-decay', float, 0.1, None),
        ('--grad-clip', float, 1.0, None),
        ('--log-every', int, 100, None),
        ('--save-every', int, 2000, None),
        ('--val-batches', int, 50, None),
    )
    for flag, kind, default, help_text in cli_options:
        cli.add_argument(flag, type=kind, default=default, help=help_text)
    args = cli.parse_args()
    train(args)