# Performance Optimizations (vs. main)
Summary of optimizations on branch `perf/pipelined-distributed-muon-clean` relative to `main`.
---
## 1. Batched Momentum (`core.py`)
**Before:** Per-param `update_g()` – one `torch.add` + optional `torch.add_` per parameter.
**After:** `_batch_pre_ortho()` – `_foreach_mul_` and `_foreach_add_` on lists of local tensors (unwrapped from DTensor). One fused kernel per batch instead of N individual kernels.
**Impact:** Eliminates per-param Python-loop overhead and N small kernel launches; savings scale with parameter count.
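A minimal sketch of the batched form, assuming the momentum buffers and grads have already been unwrapped to plain local tensors (the function name and signature here are illustrative, not the branch's actual API):

```python
import torch

def batch_pre_ortho(momenta, grads, beta=0.95, nesterov=True):
    """Sketch of a batched momentum update (hypothetical signature).

    momenta / grads are lists of plain local tensors (DTensors already
    unwrapped). Each _foreach_* call is one fused launch over the whole
    list, replacing N per-param kernels.
    """
    # buf = beta * buf + grad, fused across the list
    torch._foreach_mul_(momenta, beta)
    torch._foreach_add_(momenta, grads)
    if nesterov:
        # g = grad + beta * buf, again one fused launch
        return torch._foreach_add(grads, momenta, alpha=beta)
    return momenta
```

The per-param loop version computes exactly the same values; only the number of kernel launches changes.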
---
## 2. Pipeline Buffer Packing (`pipeline.py`)
### Gather send buffer
**Before:** Per-param `.to(COMM_DTYPE).contiguous()` followed by per-destination `append` to list, then `torch.cat` on the per-dst lists.
**After:** Collect all grad slices in destination order in a single pass, then one `torch.cat` call. Avoids intermediate per-destination lists and redundant dtype conversions.
### Scatter send buffer
**Before:** Per-param, per-destination-rank: index `u_full[indices].flatten()`, append to per-dst list, then flatten+cat.
**After:** Cache `u_full` conversions (avoid redundant `.to()` per dst_rank). Collect all slices in dst order in one pass, single `torch.cat`.
**Impact:** Fewer kernel launches, less Python overhead, reduced intermediate allocations.
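The single-pass packing pattern for both buffers looks roughly like this (sketch only; `pack_send_buffer`, its input layout, and the `COMM_DTYPE` value are assumptions, not the branch's actual code):

```python
import torch

COMM_DTYPE = torch.bfloat16  # assumed communication dtype

def pack_send_buffer(slices_by_dst):
    """Pack tensor slices for all destination ranks into one flat buffer.

    slices_by_dst: list (ordered by destination rank) of lists of tensors.
    All slices are collected in destination order in a single pass, then
    one torch.cat + one dtype cast is issued, instead of a cat/cast per
    destination rank.
    """
    flat, splits = [], []
    for slices in slices_by_dst:
        n = 0
        for t in slices:
            flat.append(t.reshape(-1))
            n += t.numel()
        splits.append(n)  # per-rank split sizes, e.g. for all_to_all_single
    send_buf = torch.cat(flat).to(COMM_DTYPE)
    return send_buf, splits
```

The per-rank `splits` list is what a split-size-aware collective would consume to carve the flat buffer back up on the receiving side.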
---
## 3. Zero-Copy Scatter (`pipeline.py`)
**Before:** `_launch_scatter` pre-allocates `torch.empty_like(p.to_local())` for every param. `_complete_scatter` copies from recv_buf into these pre-allocated tensors via `copy_()`.
**After:** `_complete_scatter` assigns **views** into `recv_buf` directly (via `recv_buf.narrow(...).view_as(...)`). No pre-allocation, no copy. The recv_buf storage stays alive through the views until `_update_params` consumes them.
**Impact:** Eliminates N `empty_like` allocations + N `copy_` kernel launches per scatter stage.
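The zero-copy idea can be sketched as follows (hypothetical helper; the real code carves views per param out of the recv buffer in the same way):

```python
import torch

def views_from_recv_buf(recv_buf, shapes):
    """Return zero-copy views into a flat recv buffer, one per local shape.

    No empty_like allocation and no copy_: each result is a narrow()+view
    into recv_buf, so its storage stays alive until the consumers
    (the parameter-update step) are done with the views.
    """
    outs, offset = [], 0
    for shape in shapes:
        n = 1
        for d in shape:
            n *= d
        outs.append(recv_buf.narrow(0, offset, n).view(shape))
        offset += n
    return outs
```

Because the results are views, mutating `recv_buf` is visible through them, which is why the buffer must not be reused until `_update_params` has consumed the views.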
---
## 4. Batched Parameter Update (`pipeline.py`)
**Before:** Per-param loop calling `update_p()` (which unwraps DTensor, applies weight decay, applies update individually).
**After:** Batched using `_foreach_mul_` (weight decay) and `_foreach_add_` (Muon update), grouped by `adjusted_lr` to preserve float32 alpha precision. Single kernel per group instead of per param.
**Impact:** Reduces N per-param kernel launches to 1-2 batched kernel launches.
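A sketch of the grouped update, assuming plain local tensors and a per-param `adjusted_lr` (names and signature are illustrative):

```python
from collections import defaultdict

import torch

def batched_update(params, updates, lrs, weight_decay=0.0):
    """Apply p <- p * (1 - lr*wd) - lr * u, batched per distinct lr.

    Grouping by adjusted_lr keeps a single float32 alpha per _foreach_
    call, matching the per-param path's precision; in practice there are
    typically only 1-2 distinct lr groups, hence 1-2 batched launches.
    """
    groups = defaultdict(list)
    for p, u, lr in zip(params, updates, lrs):
        groups[lr].append((p, u))
    for lr, pairs in groups.items():
        ps = [p for p, _ in pairs]
        us = [u for _, u in pairs]
        if weight_decay:
            # decoupled weight decay, fused over the group
            torch._foreach_mul_(ps, 1.0 - lr * weight_decay)
        # Muon update, fused over the group
        torch._foreach_add_(ps, us, alpha=-lr)
```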
---
## 5. Parallel Metadata Caching (`muon.py`)
**Before:** `init_state_and_assign_params()` called every step – sorts params by FLOP cost, assigns ownership via round-robin, precomputes per-rank indices/numels for all-to-all.
**After:** `_parallel_cache` keyed by `tuple(names)`. The first call computes and caches `ordered_names`, `name_to_state`, `rank`, and `chunk_size`. Subsequent calls reuse the cached metadata and only rebuild `param_to_state` with current `id(p)` keys (param names are stable across steps, but param object ids may change after QK clip updates).
**Impact:** Eliminates repeated sorting, mesh construction, and index precomputation on every step.
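The caching pattern, reduced to its essentials (the `build_fn` callback stands in for the expensive sort/ownership/index work and is an assumption of this sketch):

```python
_parallel_cache = {}

def get_parallel_metadata(names, params, build_fn):
    """Cache expensive per-step metadata, keyed by the tuple of param names.

    build_fn (assumed) performs the one-time sort / ownership assignment
    and returns (ordered_names, name_to_state, rank, chunk_size). Only the
    id(p) -> state map is rebuilt each step, since param object ids can
    change while the name order stays stable.
    """
    key = tuple(names)
    if key not in _parallel_cache:
        _parallel_cache[key] = build_fn(names, params)
    ordered_names, name_to_state, rank, chunk_size = _parallel_cache[key]
    by_name = dict(zip(names, params))
    # cheap per-step rebuild: map current param ids to cached states
    param_to_state = {id(by_name[n]): name_to_state[n] for n in ordered_names}
    return ordered_names, param_to_state, rank, chunk_size
```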
---
## 6. Expert Param Expansion Caching (`muon.py`)
**Before:** `_expand_expert_params()` called every step – for each expert param `(E, out, in)`, creates E `nn.Parameter` wrappers (triggers `aten::detach`), indexes data and grad (`aten::select`), and wraps in DTensor for TP.
**After:** `_expert_expand_cache` keyed by `tuple(id(p) for p in params)`. Cold path runs `_expand_expert_params` once and caches:
- `expanded_names` / `expanded_params` – the `nn.Parameter` wrappers with stable data views
- `grad_info` – per-expert-group metadata (orig param index, num experts, expanded start index, DTensor flag, TP mesh/placements)
Hot path reuses cached nn.Parameter objects (data views are stable since optimizer updates happen in-place on the same storage). Only updates `.grad` on each cached expert param by slicing the current step's gradient.
**Eliminated on hot path:**
- `nn.Parameter()` construction – removes `aten::detach`
- `local_data[i]` data slicing – removes half of `aten::select` + `aten::as_strided`
- `DTensor.from_local()` for data – only needed for grad now
- `is_expert_param()` name matching per step
**Still required per step:**
- `local_grad[i]` – the grad tensor changes each step (Nesterov)
- `DTensor.from_local(slice_grad, ...)` – for TP expert grads
- `p.grad = None` – frees the original 3D grad storage
**Impact:** ~8ms CPU overhead reduction per step at production scale (64 GPUs, 48 local experts).
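A stripped-down sketch of the cache's hot/cold split. Here `expand_fn` stands in for the expensive cold path and `grad_info` is simplified to `(orig_index, num_experts)` pairs; the real metadata carries more fields (start index, DTensor flag, TP mesh/placements):

```python
import torch

_expert_expand_cache = {}

def expand_experts(params, expand_fn):
    """Reuse cached per-expert nn.Parameter views; refresh only .grad.

    expand_fn (assumed) is the cold path: it splits each (E, out, in)
    expert param into E nn.Parameter wrappers whose .data are stable views
    into the original storage, plus grad-slicing metadata.
    """
    key = tuple(id(p) for p in params)
    entry = _expert_expand_cache.get(key)
    if entry is None:
        entry = expand_fn(params)  # -> (expanded_params, grad_info)
        _expert_expand_cache[key] = entry
    expanded, grad_info = entry
    i = 0
    for pi, num_experts in grad_info:
        grad = params[pi].grad
        for e in range(num_experts):
            # hot path: only the grad slice is refreshed each step;
            # .data views are stable because updates are in-place
            expanded[i].grad = grad[e]
            i += 1
        params[pi].grad = None  # free the original 3D grad storage
    return expanded
```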
---
## 7. Newton-Schulz Compile + CUDA Graph (`newton_schulz.py`)
**Before:** `_zeropower_via_newtonschulz5()` called directly every time.
**After:** `zeropower_via_newtonschulz5()` wrapper with per-shape `torch.compile` caching + CUDA graph (`triton.cudagraphs=True`). Each unique shape gets its own compiled function stored in `_ns_per_shape`. Toggled via `set_ns_compile(enabled)`.
**Impact:** After warmup, NS iterations run as CUDA graphs – eliminates per-step compilation overhead and CPU-GPU synchronization.
---
## 8. Removed `small_param_numel_threshold` (`muon.py`)
**Before:** Small sharded DTensors (below threshold, default 65536) fell back to `distributed_muon()` which used per-param `full_tensor()` + redistribute.
**After:** All sharded DTensors go to `parallel()`. `distributed_muon()` is retained as a test-only reference implementation. Uneven shard splits (e.g., MoE gate weights with fewer rows than shard ranks) are handled inline via `full_tensor()` fallback within the batched distributed_muon path.
**Impact:** Simpler routing, no silent fallback to slower path.
---
## Summary Table
| Optimization | Location | Category | Kernel Launches Saved |
|---|---|---|---|
| Batched momentum | `core.py` | CPU + GPU | N per-param → 2-3 batched |
| Buffer packing (gather) | `pipeline.py` | CPU + GPU | N cat+cast → 1 cat+cast |
| Buffer packing (scatter) | `pipeline.py` | CPU + GPU | N cat → 1 cat |
| Zero-copy scatter | `pipeline.py` | GPU memory | N alloc+copy → 0 |
| Batched param update | `pipeline.py` | CPU + GPU | N update → 1-2 batched |
| Parallel metadata cache | `muon.py` | CPU | Sort+index per step → once |
| Expert expand cache | `muon.py` | CPU | N detach+select → grad-only |
| NS compile + CUDA graph | `newton_schulz.py` | GPU | JIT warmup → graph replay |
| Remove small_param_threshold | `muon.py` | Routing | Simpler, unified path |