# modeling_smartcoder_moe.py
#
# Architecture (from tensor inspection):
# - vocab_size: 65536, hidden: 2048, layers: 40
# - Attention: q[2048,2048], k/v[512,2048] - 16 heads, 4 KV heads, head_dim=128
# - MLP (hybrid dense + MoE):
#     dense_fc:     [8192, 2048]      up
#     dense_proj:   [2048, 8192]      down
#     experts_fc:   [32, 512, 2048]   expert up (batched)
#     experts_proj: [32, 2048, 512]   expert down (batched)
#     router:       [32, 2048]        router logits
# - LayerNorm: weight+bias (input_layernorm, post_attention_layernorm)
# - Final norm: model.norm.weight/bias

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast


# ── Config ────────────────────────────────────────────────────────────────────

class SmartCoderMoEConfig(PretrainedConfig):
    """Configuration for the SmartCoder hybrid dense + MoE decoder.

    The defaults mirror the shapes listed in the header comment of this file.

    Args:
        vocab_size: Number of rows of the token embedding / LM head.
        hidden_size: Width of the residual stream.
        num_hidden_layers: Number of decoder layers.
        num_attention_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads (grouped-query
            attention); must divide ``num_attention_heads``.
        dense_intermediate_size: Width of the always-on dense MLP branch.
        num_experts: Number of routed experts per layer.
        expert_intermediate_size: Width of each expert's hidden layer.
        num_experts_per_tok: How many experts each token is routed to.
        max_position_embeddings: Stored on the config and forwarded to the
            rotary embedding.
        rope_theta: Base of the rotary frequency schedule.
        rms_norm_eps: Epsilon handed to the layer norms. The name is kept
            for compatibility even though the norms carry a bias.
        pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings:
            Forwarded unchanged to ``PretrainedConfig``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.

    Raises:
        ValueError: If the query heads cannot be split evenly across the
            key/value heads, or if ``num_experts_per_tok`` is not in
            ``1..num_experts``.
    """

    model_type = "smartcoder_moe"

    def __init__(
        self,
        vocab_size=65536,
        hidden_size=2048,
        num_hidden_layers=40,
        num_attention_heads=16,
        num_key_value_heads=4,
        dense_intermediate_size=8192,
        num_experts=32,
        expert_intermediate_size=512,
        num_experts_per_tok=2,
        max_position_embeddings=16384,
        rope_theta=10000.0,
        rms_norm_eps=1e-5,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=0,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Fail at construction time instead of with an opaque shape error
        # deep inside the attention / routing code.
        if num_key_value_heads <= 0 or num_attention_heads % num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) must be a positive "
                f"multiple of num_key_value_heads ({num_key_value_heads})"
            )
        if not 0 < num_experts_per_tok <= num_experts:
            raise ValueError(
                f"num_experts_per_tok ({num_experts_per_tok}) must be between 1 "
                f"and num_experts ({num_experts})"
            )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        # Derived here; a ``head_dim`` entry in **kwargs is applied by the
        # base class afterwards.
        self.head_dim = hidden_size // num_attention_heads
        self.dense_intermediate_size = dense_intermediate_size
        self.num_experts = num_experts
        self.expert_intermediate_size = expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# ── RoPE ──────────────────────────────────────────────────────────────────────

def rotate_half(x):
    """Map the two halves of the last dimension ``(x1, x2)`` to ``(-x2, x1)``."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary_emb(q, k, cos, sin):
    """Apply rotary position embeddings to ``q`` and ``k``.

    The rotation is evaluated in the promoted dtype of the operands and then
    cast back to each input's own dtype. Without the cast, float32 ``cos`` /
    ``sin`` silently promote half-precision ``q`` / ``k`` to float32, which
    later makes ``matmul(attn, v)`` fail on mismatched dtypes.

    Returns:
        Tuple ``(q_rotated, k_rotated)`` with the shapes and dtypes of the
        inputs.
    """
    def _rotate(x):
        return ((x * cos) + (rotate_half(x) * sin)).to(x.dtype)

    return _rotate(q), _rotate(k)


class RotaryEmbedding(nn.Module):
    """Lazily built cos/sin tables for rotary position embeddings.

    Args:
        dim: Size of the rotated (head) dimension.
        max_pos: Accepted for interface compatibility; the table is sized
            on demand from the sequence lengths actually requested.
        base: Base of the geometric frequency schedule.
    """

    def __init__(self, dim, max_pos=16384, base=10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        # Kept as a (persistent) buffer so state-dict keys are unchanged.
        # The tables are NOT derived from it: ``model.to(bfloat16)`` casts
        # this buffer and would degrade the frequencies.
        self.register_buffer("inv_freq", self._inv_freq())
        self._cached_len = 0

    def _inv_freq(self, device=None):
        """Return the float32 inverse frequencies on ``device``."""
        exponents = torch.arange(0, self.dim, 2, device=device).float() / self.dim
        return 1.0 / (self.base ** exponents)

    def _build_cache(self, seq_len, device):
        """(Re)build float32 cos/sin tables of shape ``[1, 1, seq_len, dim]``."""
        positions = torch.arange(seq_len, device=device).float()
        freqs = torch.outer(positions, self._inv_freq(device))
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        self._cached_len = seq_len

    def forward(self, seq_len, device):
        """Return ``(cos, sin)`` tables covering positions ``0..seq_len-1``."""
        cached = getattr(self, "cos_cached", None)
        stale = (
            cached is None                      # first call (also when seq_len == 0)
            or seq_len > self._cached_len       # table too short
            or cached.device != torch.device(device)
            or cached.dtype != torch.float32    # buffers were cast by module.to(dtype)
        )
        if stale:
            self._build_cache(max(seq_len, self._cached_len), device)
        return self.cos_cached[:, :, :seq_len, :], self.sin_cached[:, :, :seq_len, :]


# ── LayerNorm with bias ───────────────────────────────────────────────────────

class LayerNormWithBias(nn.Module):
    """Layer normalisation over the last dimension with learnable weight and bias."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.weight, self.bias, self.eps)


# ── Attention ─────────────────────────────────────────────────────────────────

def _to_additive_mask(attention_mask, dtype, device):
    """Normalise ``attention_mask`` to something that can be added to logits.

    A 2-D mask is treated as a ``[batch, seq]`` keep-mask (non-zero = attend,
    zero = padding) and expanded to ``[batch, 1, 1, seq]`` holding ``0`` for
    kept keys and the most negative finite value of ``dtype`` for masked keys.
    A finite value (rather than ``-inf``) keeps rows whose visible keys are
    all padding from turning into NaN after the softmax.

    Masks of any other rank are assumed to be additive already and are only
    moved to ``dtype`` / ``device``.
    """
    if attention_mask.dim() == 2:
        keep = attention_mask[:, None, None, :].to(device=device, dtype=torch.bool)
        additive = torch.zeros(keep.shape, dtype=dtype, device=device)
        return additive.masked_fill(~keep, torch.finfo(dtype).min)
    return attention_mask.to(device=device, dtype=dtype)


class SmartCoderAttention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    No key/value cache is kept: every call attends over the full input.
    """

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        q_width = config.num_attention_heads * config.head_dim
        kv_width = config.num_key_value_heads * config.head_dim
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=True)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=True)
        self.rotary_emb = RotaryEmbedding(
            config.head_dim, config.max_position_embeddings, config.rope_theta
        )

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """Attend over ``hidden_states`` of shape ``[batch, seq, hidden]``.

        Args:
            hidden_states: Input activations.
            attention_mask: Optional mask; either a 2-D ``[batch, seq]``
                keep-mask or an additive mask broadcastable to
                ``[batch, heads, seq, seq]``.
            **kwargs: Ignored.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        B, T, _ = hidden_states.shape

        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(T, hidden_states.device)
        q, k = apply_rotary_emb(q, k, cos, sin)

        # Expand the KV heads so every query head has a matching key/value head.
        k = k.repeat_interleave(self.num_kv_groups, dim=1)
        v = v.repeat_interleave(self.num_kv_groups, dim=1)

        # Masking and softmax run in float32 so that adding large negative
        # mask values cannot overflow in half precision.
        scores = torch.matmul(q, k.transpose(-2, -1)).float() / math.sqrt(self.head_dim)
        causal = torch.triu(
            torch.full((T, T), float("-inf"), device=scores.device, dtype=scores.dtype),
            diagonal=1,
        )
        scores = scores + causal
        if attention_mask is not None:
            scores = scores + _to_additive_mask(attention_mask, scores.dtype, scores.device)

        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype)
        out = torch.matmul(probs, v).transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(out)


# ── MoE MLP ───────────────────────────────────────────────────────────────────

class SmartCoderMoEMLP(nn.Module):
    """Hybrid feed-forward block: an always-on dense MLP plus top-k routed experts.

    The output is ``dense(x) + sum_k w_k * expert_k(x)`` where the ``w_k`` are
    the softmax router probabilities of the selected experts, renormalised to
    sum to one.
    """

    # Parameters that some checkpoints store with an extra ``.weight`` suffix.
    _EXPERT_PARAMS = ("experts_fc", "experts_proj")

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        H = config.hidden_size
        DI = config.dense_intermediate_size
        NE = config.num_experts
        EI = config.expert_intermediate_size
        self.num_experts = NE
        self.top_k = config.num_experts_per_tok

        self.dense_fc = nn.Linear(H, DI, bias=True)
        self.dense_proj = nn.Linear(DI, H, bias=True)
        self.experts_fc = nn.Parameter(torch.empty(NE, EI, H))
        self.experts_proj = nn.Parameter(torch.empty(NE, H, EI))
        self.router = nn.Linear(H, NE, bias=False)

        # torch.empty leaves arbitrary memory behind; give the experts a
        # defined starting point in case a checkpoint does not provide them.
        nn.init.normal_(self.experts_fc, mean=0.0, std=0.02)
        nn.init.normal_(self.experts_proj, mean=0.0, std=0.02)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Accept ``experts_fc.weight`` / ``experts_proj.weight`` checkpoint keys.

        The rename has to happen on the module that owns the parameters and
        in place on ``state_dict``: each module only loads its own direct
        parameters, and unexpected keys are computed from this same dict.
        """
        for name in self._EXPERT_PARAMS:
            legacy_key = f"{prefix}{name}.weight"
            if legacy_key in state_dict:
                tensor = state_dict.pop(legacy_key)
                state_dict.setdefault(f"{prefix}{name}", tensor)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``[batch, seq, hidden]``."""
        B, T, H = x.shape
        dense_out = self.dense_proj(F.gelu(self.dense_fc(x)))

        router_probs = F.softmax(self.router(x), dim=-1)
        top_weights, top_indices = router_probs.topk(self.top_k, dim=-1)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)

        tokens = x.reshape(B * T, H)
        flat_indices = top_indices.reshape(B * T, self.top_k)
        flat_weights = top_weights.reshape(B * T, self.top_k)
        expert_out = torch.zeros_like(tokens)

        # Group tokens by expert instead of gathering a private copy of the
        # expert weights for every token (which costs B*T full weight
        # matrices of memory per routing slot).
        for expert in torch.unique(flat_indices).tolist():
            rows, slots = torch.nonzero(flat_indices == expert, as_tuple=True)
            hidden = F.gelu(F.linear(tokens[rows], self.experts_fc[expert]))
            out = F.linear(hidden, self.experts_proj[expert])
            gate = flat_weights[rows, slots].unsqueeze(-1)
            expert_out.index_add_(0, rows, (out * gate).to(expert_out.dtype))

        return dense_out + expert_out.view(B, T, H)


# ── Decoder Layer ─────────────────────────────────────────────────────────────

class SmartCoderDecoderLayer(nn.Module):
    """Pre-norm decoder layer: attention and MoE MLP, each with a residual."""

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        self.input_layernorm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
        self.self_attn = SmartCoderAttention(config)
        self.post_attention_layernorm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
        self.mlp = SmartCoderMoEMLP(config)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states), attention_mask=attention_mask
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


# ── Model ─────────────────────────────────────────────────────────────────────

class SmartCoderMoEModel(nn.Module):
    """Token embedding, the decoder stack and the final layer norm."""

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [SmartCoderDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """Run the decoder stack.

        Args:
            input_ids: Token ids ``[batch, seq]``; ignored when
                ``inputs_embeds`` is given.
            attention_mask: Optional 2-D keep-mask or additive mask.
            inputs_embeds: Optional pre-computed embeddings
                ``[batch, seq, hidden]``.
            **kwargs: Ignored.

        Returns:
            Final hidden states ``[batch, seq, hidden]``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either input_ids or inputs_embeds must be provided")
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        # Convert the mask once here rather than once per layer.
        if attention_mask is not None:
            attention_mask = _to_additive_mask(
                attention_mask, torch.float32, hidden_states.device
            )

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask=attention_mask)
        return self.norm(hidden_states)


# ── CausalLM ──────────────────────────────────────────────────────────────────

class SmartCoderMoEForCausalLM(PreTrainedModel, GenerationMixin):
    """SmartCoder MoE decoder with a language-modelling head.

    There is no key/value cache: ``past_key_values`` and ``use_cache`` are
    accepted for interface compatibility and ignored, and generation re-runs
    the full sequence at every step.
    """

    config_class = SmartCoderMoEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__(config)
        self.model = SmartCoderMoEModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def get_output_embeddings(self):
        return self.lm_head

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        **kwargs,
    ):
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        ``labels`` are shifted internally, so they align with ``input_ids``;
        positions labelled ``-100`` are ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``), ``logits``
            and ``past_key_values=None``.
        """
        hidden_states = self.model(
            input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Cross-entropy in float32: half-precision log-sum-exp over the
            # vocabulary is too coarse for a stable loss.
            shift_logits = logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Build the per-step model inputs for ``generate``.

        The whole sequence is fed every step (no cache). The attention mask
        is passed through so padded batches are masked during generation.
        """
        return {"input_ids": input_ids, "attention_mask": attention_mask}
# ── Loader ────────────────────────────────────────────────────────────────────

def load_smartcoder_moe(model_id="Johnblick187/SmartCoderMoE", dtype=torch.bfloat16):
    """Download a checkpoint from the Hugging Face Hub and load it.

    The model is built from the default ``SmartCoderMoEConfig`` (any
    ``config.json`` in the repository is not consulted), every top-level
    ``*.safetensors`` file in the snapshot is merged into one state dict, and
    the result is loaded non-strictly before being cast to ``dtype``.

    Args:
        model_id: Hub repository id to download.
        dtype: Dtype the loaded model is cast to.

    Returns:
        Tuple ``(model, config)``.

    Raises:
        FileNotFoundError: If the snapshot contains no ``*.safetensors`` file.
    """
    import os
    from pathlib import Path

    # Set before importing huggingface_hub below.
    # NOTE(review): the hub library has most likely been imported already by
    # the module-level ``transformers`` import; if it reads this variable at
    # import time the assignment has no effect here — confirm, and export the
    # variable in the environment instead if Xet must be disabled.
    os.environ["HF_HUB_DISABLE_XET"] = "1"

    from huggingface_hub import snapshot_download
    from safetensors.torch import load_file

    print(f"Downloading {model_id}...")
    model_dir = snapshot_download(model_id)

    config = SmartCoderMoEConfig()
    print("Initializing model...")
    model = SmartCoderMoEForCausalLM(config)

    print("Loading weights...")
    sf_files = sorted(Path(model_dir).glob("*.safetensors"))
    if not sf_files:
        # With strict=False below, an empty state dict would otherwise
        # "load" without error and return a randomly initialised model.
        raise FileNotFoundError(f"No .safetensors files found in {model_dir}")

    state_dict = {}
    for shard in sf_files:
        state_dict.update(load_file(str(shard)))

    # Remap expert keys — safetensors has .weight suffix, our params don't
    renames = (
        ("experts_fc.weight", "experts_fc"),
        ("experts_proj.weight", "experts_proj"),
    )
    remapped = {}
    for key, tensor in state_dict.items():
        for old, new in renames:
            if old in key:
                key = key.replace(old, new)
                break
        remapped[key] = tensor
    state_dict = remapped

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Missing: {missing[:3]}{'...' if len(missing)>3 else ''}")
    if unexpected:
        print(f"Unexpected: {unexpected[:3]}{'...' if len(unexpected)>3 else ''}")

    model = model.to(dtype)
    print(f"Loaded! Params: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
    return model, config


# Register with the Auto* factories so ``model_type = "smartcoder_moe"``
# resolves to the classes defined in this file.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("smartcoder_moe", SmartCoderMoEConfig)
AutoModelForCausalLM.register(SmartCoderMoEConfig, SmartCoderMoEForCausalLM)