SmartCoderMoE / modeling_smartcoder_moe.py
Johnblick187's picture
Update modeling_smartcoder_moe.py
f10d44f verified
# modeling_smartcoder_moe.py
# Architecture (from tensor inspection):
# - vocab_size: 65536, hidden: 2048, layers: 40
# - Attention: q [2048, 2048], k/v [512, 2048] -> 16 heads, 4 KV heads, head_dim = 128
# - MLP (hybrid dense + MoE):
#     dense_fc:     [8192, 2048]      dense up-projection
#     dense_proj:   [2048, 8192]      dense down-projection
#     experts_fc:   [32, 512, 2048]   expert up-projection (batched over experts)
#     experts_proj: [32, 2048, 512]   expert down-projection (batched over experts)
#     router:       [32, 2048]        router logits
# - LayerNorm: weight + bias (input_layernorm, post_attention_layernorm)
# - Final norm: model.norm.weight / model.norm.bias
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
# ── Config ────────────────────────────────────────────────────────────────────
class SmartCoderMoEConfig(PretrainedConfig):
    """Configuration for the SmartCoder hybrid dense + mixture-of-experts decoder.

    Args:
        vocab_size: Number of rows in the token embedding / LM head.
        hidden_size: Width of the residual stream.
        num_hidden_layers: Number of decoder layers.
        num_attention_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads (grouped-query attention).
        dense_intermediate_size: Inner width of the dense MLP branch.
        num_experts: Number of experts in the MoE branch.
        expert_intermediate_size: Inner width of each expert.
        num_experts_per_tok: How many experts the router selects per token.
        max_position_embeddings: Forwarded to the rotary embedding module.
        rope_theta: Base of the rotary frequency schedule.
        rms_norm_eps: Epsilon handed to the LayerNorm modules in this file
            (the name is historical; the norms carry weight and bias).
        pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings:
            Forwarded unchanged to ``PretrainedConfig``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.

    ``head_dim`` is derived as ``hidden_size // num_attention_heads``.
    """

    model_type = "smartcoder_moe"

    def __init__(
        self,
        vocab_size=65536,
        hidden_size=2048,
        num_hidden_layers=40,
        num_attention_heads=16,
        num_key_value_heads=4,
        dense_intermediate_size=8192,
        num_experts=32,
        expert_intermediate_size=512,
        num_experts_per_tok=2,
        max_position_embeddings=16384,
        rope_theta=10000.0,
        rms_norm_eps=1e-5,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=0,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Embedding table and network depth/width.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers

        # Attention geometry; the per-head width is derived, not configurable.
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads

        # Dense branch and MoE branch of the feed-forward block.
        self.dense_intermediate_size = dense_intermediate_size
        self.num_experts = num_experts
        self.expert_intermediate_size = expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok

        # Positional encoding and normalisation.
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps

        # Special tokens and embedding tying are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# ── RoPE ──────────────────────────────────────────────────────────────────────
def rotate_half(x):
    """Return ``x`` with the halves of its last dimension swapped and the
    (original) second half negated: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``.
    """
    split = x.shape[-1] // 2
    front = x[..., :split]
    back = x[..., split:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_emb(q, k, cos, sin):
    """Apply rotary position embedding to a query/key pair.

    Each input is combined as ``t * cos + rotate_half(t) * sin``; ``cos`` and
    ``sin`` must broadcast against ``q`` and ``k``.

    Returns:
        A ``(q_rotated, k_rotated)`` tuple.
    """
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated
class RotaryEmbedding(nn.Module):
    """Lazily cached cos/sin tables for rotary position embeddings.

    Args:
        dim: Size of the rotated (per-head) dimension.
        max_pos: Accepted for interface compatibility; the tables are grown on
            demand, so this value does not bound or pre-size anything.
        base: Base of the inverse-frequency schedule.

    ``forward(seq_len, device)`` returns ``(cos, sin)``; for an even ``dim``
    each has shape ``[1, 1, seq_len, dim]``.
    """

    def __init__(self, dim, max_pos=16384, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Declare the cache buffers up front (empty) so that forward() can
        # test for them instead of hitting AttributeError on a first call
        # that does not trigger a build (e.g. seq_len == 0).
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)
        self._cached_len = 0

    def _build_cache(self, seq_len, device):
        """(Re)build the cos/sin tables for ``seq_len`` positions on ``device``."""
        positions = torch.arange(seq_len, device=device).to(torch.float32)
        # inv_freq is a floating buffer, so a module-wide ``.to(bfloat16)``
        # down-casts it; bring it back to float32 so the outer product never
        # mixes dtypes and the tables are always computed in float32.
        inv_freq = self.inv_freq.to(device=device, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        self._cached_len = seq_len

    def _cache_is_usable(self, seq_len, device):
        """True when the cached tables cover ``seq_len`` and live on ``device``."""
        cached = self.cos_cached
        if cached is None or seq_len > self._cached_len:
            return False
        if cached.device.type != device.type:
            return False
        # An index-less request such as "cuda" matches any index of that type.
        return device.index is None or cached.device.index == device.index

    def forward(self, seq_len, device):
        """Return ``(cos, sin)`` tables truncated to the first ``seq_len`` positions."""
        device = torch.device(device)
        if not self._cache_is_usable(seq_len, device):
            # Never shrink the cache when rebuilding only because of a device change.
            self._build_cache(max(seq_len, self._cached_len), device)
        return self.cos_cached[:, :, :seq_len, :], \
            self.sin_cached[:, :, :seq_len, :]
# ── LayerNorm with bias ───────────────────────────────────────────────────────
class LayerNormWithBias(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        hidden_size: Length of the ``weight`` (initialised to ones) and
            ``bias`` (initialised to zeros) vectors.
        eps: Epsilon passed to ``F.layer_norm``.
    """

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Keep registration order weight -> bias so parameter ordering is unchanged.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        """Normalise ``x`` across its last dimension."""
        normalized_shape = (x.size(-1),)
        return F.layer_norm(
            x,
            normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
# ── Attention ─────────────────────────────────────────────────────────────────
class SmartCoderAttention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    Queries use ``num_attention_heads`` heads and keys/values use
    ``num_key_value_heads`` heads; each key/value head is repeated
    ``num_heads // num_kv_heads`` times to line up with the query heads.
    All four projections carry a bias.
    """

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=True)
        self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=True)
        self.rotary_emb = RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)

    @staticmethod
    def _merge_attention_mask(scores, attention_mask):
        """Fold an optional user mask into ``scores`` of shape ``[B, heads, T, T]``.

        * ``None``: scores are returned unchanged.
        * 2-D ``[B, T]`` mask (any dtype) or a boolean mask of any rank:
          treated as a keep-mask (non-zero / True = attend). Blocked key
          positions are filled with the most negative finite value of the
          score dtype, so a row whose keys are all padding still yields a
          finite softmax instead of NaN.
        * Any other mask: treated as additive and added as-is, which is the
          behaviour callers supplying a pre-built 4-D float mask rely on.
        """
        if attention_mask is None:
            return scores
        if attention_mask.dim() == 2 or attention_mask.dtype == torch.bool:
            keep = attention_mask.to(device=scores.device, dtype=torch.bool)
            if keep.dim() == 2:
                keep = keep[:, None, None, :]
            return scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        return scores + attention_mask

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Run causal self-attention over ``hidden_states`` of shape ``[B, T, hidden]``.

        Args:
            hidden_states: Input activations, ``[B, T, hidden]``.
            attention_mask: Optional mask; see ``_merge_attention_mask`` for
                the accepted forms.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``[B, T, hidden]``.
        """
        B, T, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(T, hidden_states.device)
        # The rotary tables are built in float32. Multiplying them into
        # bf16/fp16 q and k would promote those to float32 while v stays in
        # the lower precision, and the final matmul would fail on mismatched
        # dtypes, so match the tables to the activations first.
        cos = cos.to(device=q.device, dtype=q.dtype)
        sin = sin.to(device=q.device, dtype=q.dtype)
        q, k = apply_rotary_emb(q, k, cos, sin)

        # Expand the KV heads so every query head has a matching key/value head.
        k = k.repeat_interleave(self.num_kv_groups, dim=1)
        v = v.repeat_interleave(self.num_kv_groups, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal = torch.triu(torch.full((T, T), float("-inf"), device=q.device, dtype=q.dtype), diagonal=1)
        scores = scores + causal.unsqueeze(0).unsqueeze(0)
        scores = self._merge_attention_mask(scores, attention_mask)

        # Softmax in float32 for stability, then return to the activation dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(probs, v).transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(out)
# ── MoE MLP ───────────────────────────────────────────────────────────────────
class SmartCoderMoEMLP(nn.Module):
    """Hybrid feed-forward block: an always-on dense MLP plus a top-k routed MoE.

    The output is ``dense(x) + sum_k w_k * expert_{e_k}(x)`` where the ``k``
    experts per token are chosen by a softmax router and their weights are
    renormalised to sum to one. Both branches use GELU; experts have no bias.
    """

    def __init__(self, config: "SmartCoderMoEConfig"):
        super().__init__()
        H = config.hidden_size
        DI = config.dense_intermediate_size
        NE = config.num_experts
        EI = config.expert_intermediate_size
        self.num_experts = NE
        self.top_k = config.num_experts_per_tok
        self.dense_fc = nn.Linear(H, DI, bias=True)
        self.dense_proj = nn.Linear(DI, H, bias=True)
        # Expert weights are stored batched over the expert axis.
        self.experts_fc = nn.Parameter(torch.empty(NE, EI, H))
        self.experts_proj = nn.Parameter(torch.empty(NE, H, EI))
        self.router = nn.Linear(H, NE, bias=False)
        # torch.empty leaves arbitrary memory behind (possibly NaN/Inf), so
        # give the raw expert parameters a defined starting point. Loading a
        # checkpoint overwrites these values.
        nn.init.normal_(self.experts_fc, mean=0.0, std=0.02)
        nn.init.normal_(self.experts_proj, mean=0.0, std=0.02)

    def forward(self, x):
        """Apply the dense and routed branches to ``x`` of shape ``[B, T, H]``.

        Returns:
            Tensor of shape ``[B, T, H]``.
        """
        B, T, H = x.shape
        dense_out = self.dense_proj(F.gelu(self.dense_fc(x)))

        router_probs = F.softmax(self.router(x), dim=-1)
        top_weights, top_indices = router_probs.topk(self.top_k, dim=-1)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)

        # reshape (not view) so non-contiguous inputs are accepted as well.
        tokens = x.reshape(B * T, H)
        flat_indices = top_indices.reshape(B * T, self.top_k)
        flat_weights = top_weights.reshape(B * T, self.top_k)
        expert_out = torch.zeros_like(tokens)

        # Group tokens by expert instead of gathering one full expert weight
        # matrix per token: indexing experts_fc with per-token ids would
        # materialise a [B*T, EI, H] tensor, whereas this loop only ever
        # touches each expert's weights once.
        for expert_id in range(self.num_experts):
            token_idx, slot_idx = torch.where(flat_indices == expert_id)
            if token_idx.numel() == 0:
                continue
            routed = tokens[token_idx]
            hidden = F.gelu(routed @ self.experts_fc[expert_id].t())
            out = hidden @ self.experts_proj[expert_id].t()
            gate = flat_weights[token_idx, slot_idx].unsqueeze(-1)
            expert_out.index_add_(0, token_idx, (out * gate).to(expert_out.dtype))

        return dense_out + expert_out.view(B, T, H)
# ── Decoder Layer ─────────────────────────────────────────────────────────────
class SmartCoderDecoderLayer(nn.Module):
    """One pre-norm decoder layer: attention sub-block, then MoE MLP sub-block.

    Each sub-block normalises its input, applies the module, and adds the
    result back onto the un-normalised input (residual connection).
    """

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        norm_eps = config.rms_norm_eps
        width = config.hidden_size
        self.input_layernorm = LayerNormWithBias(width, norm_eps)
        self.self_attn = SmartCoderAttention(config)
        self.post_attention_layernorm = LayerNormWithBias(width, norm_eps)
        self.mlp = SmartCoderMoEMLP(config)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Transform ``hidden_states``; extra keyword arguments are ignored."""
        # Attention sub-block with residual.
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward (dense + MoE) sub-block with residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
# ── Model ─────────────────────────────────────────────────────────────────────
class SmartCoderMoEModel(nn.Module):
    """Decoder backbone: token embedding, a stack of decoder layers, final norm."""

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        decoder_layers = [
            SmartCoderDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Embed ``input_ids`` and run them through every layer and the final norm.

        ``attention_mask`` is handed to each layer unchanged; any other keyword
        arguments are ignored.
        """
        x = self.embed_tokens(input_ids)
        for decoder_layer in self.layers:
            x = decoder_layer(x, attention_mask=attention_mask)
        return self.norm(x)
# ── CausalLM ──────────────────────────────────────────────────────────────────
class SmartCoderMoEForCausalLM(PreTrainedModel, GenerationMixin):
    """SmartCoder MoE backbone with a language-modelling head.

    No key/value cache is implemented: every forward pass re-processes the
    full ``input_ids`` and the returned ``past_key_values`` is always ``None``.
    """

    config_class = SmartCoderMoEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False

    # Checkpoint key fragment -> parameter name used by this implementation
    # (the expert tensors are raw nn.Parameters, so they have no ".weight").
    _EXPERT_KEY_RENAMES = (
        ("experts_fc.weight", "experts_fc"),
        ("experts_proj.weight", "experts_proj"),
    )

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__(config)
        self.model = SmartCoderMoEModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Rename checkpoint expert keys before the regular loading logic runs.

        The rename is done in place on ``state_dict``: building a renamed copy
        and passing it only to ``super()`` would affect this top-level module
        alone, and the MoE submodules that actually own the expert tensors
        would still be handed the original, un-renamed keys.
        """
        for key in list(state_dict.keys()):
            renamed = key
            for old, new in self._EXPERT_KEY_RENAMES:
                renamed = renamed.replace(old, new)
            if renamed != key:
                state_dict[renamed] = state_dict.pop(key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.model.embed_tokens

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    @staticmethod
    def _to_additive_mask(attention_mask, dtype):
        """Turn a 2-D ``[B, T]`` keep-mask into a ``[B, 1, 1, T]`` additive mask.

        Non-zero entries mean "attend". Blocked positions receive the most
        negative finite value of ``dtype`` so that rows consisting only of
        padding still produce a finite softmax. ``None`` and masks that are not
        2-D are returned unchanged.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return attention_mask
        keep = attention_mask.to(torch.bool)
        additive = torch.zeros(keep.shape, dtype=dtype, device=keep.device)
        additive = additive.masked_fill(~keep, torch.finfo(dtype).min)
        return additive[:, None, None, :]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        **kwargs,
    ):
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids, ``[B, T]``. Required.
            attention_mask: Optional mask. A 2-D ``[B, T]`` padding mask
                (non-zero = real token) is converted to an additive mask;
                anything else is passed to the backbone unchanged.
            past_key_values, inputs_embeds, use_cache, **kwargs: Accepted for
                API compatibility and ignored.
            labels: Optional target ids, ``[B, T]``; positions equal to
                ``-100`` are ignored. Targets are shifted by one inside.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``), ``logits``
            and ``past_key_values=None``.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError(
                "SmartCoderMoEForCausalLM.forward requires `input_ids`; "
                "`inputs_embeds` is not supported."
            )

        additive_mask = self._to_additive_mask(
            attention_mask, self.model.embed_tokens.weight.dtype
        )
        hidden_states = self.model(input_ids, attention_mask=additive_mask)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the forward kwargs for one generation step.

        The whole ``input_ids`` sequence is always passed (there is no cache).
        A 2-D ``attention_mask`` whose length matches ``input_ids`` is forwarded
        too, so padded batches do not attend to pad tokens; any other mask is
        dropped.
        """
        model_inputs = {"input_ids": input_ids}
        attention_mask = kwargs.get("attention_mask")
        if (
            attention_mask is not None
            and attention_mask.dim() == 2
            and attention_mask.shape[-1] == input_ids.shape[-1]
        ):
            model_inputs["attention_mask"] = attention_mask
        return model_inputs
# ── Loader ────────────────────────────────────────────────────────────────────
def load_smartcoder_moe(model_id="Johnblick187/SmartCoderMoE", dtype=torch.bfloat16):
    """Download a SmartCoder MoE checkpoint and load it into a fresh model.

    The model is built from a default ``SmartCoderMoEConfig()``, every
    ``*.safetensors`` file in the snapshot directory is merged into one state
    dict, expert keys are renamed to match this implementation, and the
    weights are loaded non-strictly before the model is cast to ``dtype``.

    Args:
        model_id: Hub repository id passed to ``snapshot_download``.
        dtype: Dtype the loaded model is converted to.

    Returns:
        ``(model, config)``.

    Raises:
        FileNotFoundError: If the snapshot contains no ``*.safetensors`` file.
    """
    import os
    from pathlib import Path

    # Set before the hub import below.
    # NOTE(review): if huggingface_hub was already imported elsewhere (e.g. by
    # transformers) and reads this variable at import time, the override may
    # have no effect - confirm.
    os.environ["HF_HUB_DISABLE_XET"] = "1"

    from huggingface_hub import snapshot_download
    from safetensors.torch import load_file

    print(f"Downloading {model_id}...")
    model_dir = snapshot_download(model_id)

    sf_files = sorted(Path(model_dir).glob("*.safetensors"))
    if not sf_files:
        # Without this check load_state_dict(strict=False) would accept an
        # empty state dict and hand back a randomly initialised model.
        raise FileNotFoundError(f"No .safetensors files found in {model_dir}")

    config = SmartCoderMoEConfig()
    print("Initializing model...")
    model = SmartCoderMoEForCausalLM(config)

    print("Loading weights...")
    state_dict = {}
    for shard in sf_files:
        for key, tensor in load_file(str(shard)).items():
            # Remap expert keys - safetensors has .weight suffix, our params don't
            if 'experts_fc.weight' in key:
                key = key.replace('experts_fc.weight', 'experts_fc')
            elif 'experts_proj.weight' in key:
                key = key.replace('experts_proj.weight', 'experts_proj')
            state_dict[key] = tensor

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Missing: {missing[:3]}{'...' if len(missing)>3 else ''}")
    if unexpected:
        print(f"Unexpected: {unexpected[:3]}{'...' if len(unexpected)>3 else ''}")

    model = model.to(dtype)
    print(f"Loaded! Params: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
    return model, config
# Module-level side effect: runs once when this module is imported.
# Registers the config class under the "smartcoder_moe" model type and
# associates the causal-LM class with that config class in the Auto* registries.
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("smartcoder_moe", SmartCoderMoEConfig)
AutoModelForCausalLM.register(SmartCoderMoEConfig, SmartCoderMoEForCausalLM)