# MLX LLaDA2.0-Uni
An MLX port of inclusionAI/LLaDA2.0-Uni, a 16B MoE unified multimodal diffusion LLM, running natively on Apple Silicon.
LLaDA2 doesn't generate left-to-right. It's a diffusion language model: fill a template with `<mask>` tokens, then iteratively un-mask them over multiple denoising steps using bidirectional attention, like image diffusion models but over discrete tokens. Images are represented as in-vocabulary VQ tokens (offset at index 157184), so the backbone is just a sequence-in / logits-out transformer that does text chat, image understanding (VQA), and text-to-image in one model.
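The un-masking loop can be sketched in a few lines. This is a toy pure-numpy illustration of the idea (reveal the most confident masked positions each step), not the repo's actual sampler; `MASK` and `probs_fn` are hypothetical stand-ins for the real mask token id and model forward pass.

```python
import numpy as np

MASK = -1  # hypothetical mask token id for this toy example

def denoise_block(probs_fn, seq, steps=4):
    """Toy block-diffusion un-masking: each step, commit the most
    confident masked positions (a sketch, not the repo's sampler)."""
    seq = seq.copy()
    n_masked = int((seq == MASK).sum())
    per_step = -(-n_masked // steps)  # ceil division
    for _ in range(steps):
        masked = np.flatnonzero(seq == MASK)
        if masked.size == 0:
            break
        probs = probs_fn(seq)                # [L, V] per-position distributions
        conf = probs.max(axis=-1)[masked]    # model confidence at masked slots
        pick = masked[np.argsort(-conf)[:per_step]]
        seq[pick] = probs.argmax(axis=-1)[pick]
    return seq
```

Because attention is bidirectional, every denoising step re-reads the whole template, already-revealed tokens included, which is what lets later steps refine low-confidence regions.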
| Sample | Prompt / settings |
|---|---|
| ![]() | A bowl of fresh strawberries on a wooden table, morning light, photorealistic (512×512, cfg 4.0, 8 MLX steps × 9 blocks → 25-step decoder) |
| ![]() | A medieval castle on a hill at sunset, oil painting style (same settings) |

(Generated at 512×512 for speed. LLaDA2 was designed for 1024×1024: running at the native 32×32 token grid produces significantly sharper results, but each generation takes several hours on M4 Pro. The 16×16 grid at 512×512 works well for still-life subjects; complex subjects like cat faces can show mild distortion at this reduced resolution.)
## Why this is interesting
llama.cpp doesn't support diffusion LLMs: every architecture there is a hand-coded graph assuming causal attention + next-token sampling. LLaDA2 needs bidirectional attention and a multi-step denoising loop, so GGUF was not an option. MLX is a general-purpose tensor framework: you write the forward pass and the denoising loop in Python, and it runs.
## Quick Start

```bash
git clone https://huggingface.co/treadon/mlx-llada2-uni
cd mlx-llada2-uni
git clone https://github.com/inclusionAI/LLaDA2.0-Uni llada2-uni-repo

python -m venv .venv && source .venv/bin/activate
pip install mlx "transformers==4.51" diffusers torch torchvision huggingface_hub safetensors Pillow accelerate torchdiffeq einops

# Text-only chat
python generate.py --prompt "Name three primary colors."

# Image understanding (VQA)
python image_understand.py --image some_photo.jpg --question "What's in this image?"

# Text-to-image (~18 min on M4 Pro for a 512×512)
python t2i.py --prompt "A photorealistic cat sitting on a wooden table"
```
First run pulls ~45 GB from inclusionAI/LLaDA2.0-Uni: the 32 GB MoE backbone, the 12 GB diffusion decoder, the 2.4 GB image tokenizer (ViT + VQVAE), and the 170 MB FLUX VAE.

Requirements:
- Apple Silicon (tested on M4 Pro 64 GB).
- Python 3.10+.
- `transformers==4.51` specifically: newer versions have hard `flash_attn` imports that fail on Apple Silicon.
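One way to survive a hard `flash_attn` import is to register a stub module before transformers loads. A minimal sketch; the placeholder attribute name is illustrative, and you'd add whatever attributes the model code actually touches.

```python
import sys
import types

# Sketch: register a dummy flash_attn module so that a later
# `import flash_attn` succeeds on Apple Silicon, where the real
# package cannot be installed.
stub = types.ModuleType("flash_attn")
stub.flash_attn_func = None  # placeholder attribute (illustrative)
sys.modules["flash_attn"] = stub

import flash_attn  # now resolves to the stub
```

This only papers over the import; any code path that actually calls into flash-attention still needs to be routed to a standard attention implementation.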
## What's in this repo
Just the MLX code + sample outputs. Weights are loaded directly from inclusionAI/LLaDA2.0-Uni; we don't redistribute them.
```
mlx-llada2-uni/
├── llada2/
│   ├── model.py           # 16B MoE backbone (GQA, partial-RoPE, DeepSeek-V2 MoE, bidirectional attention)
│   ├── weights.py         # HF safetensors → MLX loader (packs per-expert ModuleList into stacked arrays)
│   ├── generate.py        # Block-diffusion text generation
│   └── generate_image.py  # Block-diffusion VQ-token generation with CFG (for t2i)
├── generate.py            # Text CLI
├── image_understand.py    # VQA CLI (hybrid: PyTorch image tokenizer → MLX backbone)
├── t2i.py                 # Text-to-image CLI (hybrid: MLX VQ gen → PyTorch decoder)
└── samples/               # Example outputs
```
## Architecture

### Backbone (runs in MLX)
| Field | Value |
|---|---|
| Parameters | 16B total, ~1B active per token |
| hidden_size | 2048 |
| layers | 20 (1 dense + 19 MoE) |
| attention | 16Q / 4KV (GQA), head_dim 128, QK RMSNorm, bidirectional |
| RoPE | partial (rotates first 64 of 128 dims), θ = 600000 |
| MoE | 256 routed + 1 shared, top-8 per token, DeepSeek-V2 group-limited (n_group=8, topk_group=4) |
| moe_intermediate_size | 512 |
| vocab | 173568 (text + 16384-entry VQ codebook at offset 157184) |
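The group-limited routing in the table can be sketched in numpy. This is an assumed reading of the DeepSeek-V2 scheme (rank expert groups by their best score, keep the top groups, then pick top-k experts within them), not the repo's code; the real gate also applies a sigmoid and `expert_bias`, omitted here.

```python
import numpy as np

def group_limited_topk(scores, n_group=8, topk_group=4, top_k=8):
    """Sketch of DeepSeek-V2-style group-limited routing.
    scores: [N, E] router scores, E divisible by n_group."""
    N, E = scores.shape
    g = scores.reshape(N, n_group, E // n_group)
    group_best = g.max(axis=-1)                        # [N, n_group]
    keep = np.argsort(-group_best, axis=-1)[:, :topk_group]
    mask = np.zeros((N, n_group), dtype=bool)
    np.put_along_axis(mask, keep, True, axis=-1)
    # scores outside the surviving groups are knocked out with -inf
    masked = np.where(mask[:, :, None], g, -np.inf).reshape(N, E)
    # k-largest indices via argpartition, as in the MLX port
    return np.argpartition(-masked, top_k - 1, axis=-1)[:, :top_k]
```

With 256 experts in 8 groups, a token can only spend its 8 expert slots inside its 4 best groups, which bounds cross-device traffic in the original training setup.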
### Pipeline (hybrid)
| Component | Size | Where |
|---|---|---|
| LLaDA2 MoE backbone | 32 GB bf16 | MLX (GPU) |
| Image tokenizer (ViT + VQVAE) | 2.4 GB | PyTorch CPU |
| SigVQ prior | 400 MB fp32 | PyTorch MPS |
| ZImage decoder (6.2B, 30 layers) | 12 GB bf16 | PyTorch MPS |
| FLUX VAE | 170 MB bf16 | PyTorch MPS |
Porting the 6.2B ZImageTransformer2DModel decoder + ~4B image tokenizer to MLX would take weeks. They run in PyTorch on MPS instead. The heavy lifting (the 16B MoE LLM that does all three tasks) is the MLX-native part.
## Conversion notes
| Area | Issue | Fix |
|---|---|---|
| Packed QKV | `query_key_value` linear outputs (nh + 2·nkv)·d | Split along channel, per-head RMSNorm |
| Partial RoPE | HF rotates first `head_dim * partial_rotary_factor` dims | Skip concat if non-rotated half is empty |
| MoE gate | Sigmoid + expert_bias + group-limited top-k | `argpartition(-scores, k-1)[:, :k]` for k-largest indices |
| MoE dispatch | Need gathered `[N_slots, H_moe, H]` weights | Chunked at 512 slots to cap transient memory |
| Packed experts | HF stores per-expert ModuleList | Stack into `[E, out, in]` at load time |
| Bidirectional mask | HF uses bool tril-block-diag mask | Additive fp32 mask with −inf off-block |
| CFG uncond padding | Shared cond/uncond mask and positions (see below) | Separate forward passes; uncond gets its own attn_mask (pads masked) and position_ids (zeros on pad, start-at-0 for real tokens) |
| Decoder fp64 | MPS doesn't support fp64 | Patched `decoder/transport/transport.py` to use fp32 |
| `flash_attn` import | Not available on Apple Silicon | Stub `sys.modules['flash_attn']` + downgrade transformers to 4.51 |
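The bool-to-additive mask conversion is mechanical. A small numpy sketch, using a tiny tril-block-diagonal mask (two diffusion blocks of length 2) as the example; the real masks are built per-sequence in the MLX code.

```python
import numpy as np

def to_additive(bool_mask, dtype=np.float32):
    # True = may attend (bias 0); False = blocked (-inf added to scores)
    return np.where(bool_mask, 0.0, -np.inf).astype(dtype)

# Tril-block-diagonal pattern: tokens attend bidirectionally within
# their block and to all earlier blocks, never to later blocks.
blocks = np.kron(np.tril(np.ones((2, 2), dtype=bool)),
                 np.ones((2, 2), dtype=bool))
mask = to_additive(blocks)
```

The additive form is added straight onto the attention scores before the softmax, so blocked positions get zero probability without any boolean indexing.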
## The interesting bug
The cond/uncond CFG paths had different prompt lengths (cond = 36, uncond = 30 for a typical t2i prompt). My first version shared the same attn_mask and position_ids between them, meaning the uncond path:

- Attended to the left-pad `<mask>` tokens as if they were content.
- Saw its real uncond tokens at RoPE positions shifted by `pad_len` (so the model sees `<uncondition>` at position 6 rather than position 0, where it was trained).

CFG computes logits_uncond + 4·(logits_cond − logits_uncond), so small errors in the uncond direction got amplified 4× into visible vertical chromatic stripes in the decoded image.

Fix: separate forward passes for cond and uncond, so each path gets its own attention mask (pads set to −inf) and position_ids (real tokens start at RoPE position 0, matching training).
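The fix is small once stated. A sketch of the CFG combination and the corrected uncond position_ids, using the pad length from the example above; array shapes and the zeros-on-pad convention follow the description here, not the repo's exact code.

```python
import numpy as np

def cfg(logits_cond, logits_uncond, scale=4.0):
    # classifier-free guidance: extrapolate away from the uncond prediction
    return logits_uncond + scale * (logits_cond - logits_uncond)

# Uncond prompt from the example: 6 left-pad tokens, then 30 real tokens.
pad_len, total = 6, 36
position_ids = np.concatenate([
    np.zeros(pad_len, dtype=int),    # pads pinned to position 0
    np.arange(total - pad_len),      # real tokens start at RoPE position 0
])
```

With scale = 4.0, a 0.1-logit error in the uncond pass becomes a 0.4-logit error in the guided output, which is why the mismatch showed up so visibly.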
## Performance (M4 Pro, 64 GB)

### Text
- Short Q&A (e.g. "what is 2+2?"): ~15 s for 32 tokens.
- Load time: 5.9 s warm cache / 8 s cold.
### Image understanding
- 512×512 input → 1024 VQ tokens + 10 question tokens → MLX backbone → 48 gen tokens; ~90 s end-to-end.
### Text-to-image (512×512)

| Decoder | Decode steps | Decode time | Quality |
|---|---|---|---|
| `decoder-turbo` (distilled) | 8 | ~45 s | Stripes on our tokens (brittle to small VQ noise) |
| `decoder/` (full) | 50 | ~8 min | Clean, production-quality |
| `decoder/` (full) | 25 | ~4 min | Clean (tested) |
MLX VQ generation itself: ~10 min (8 steps × 9 blocks, CFG scale 4.0).

Default CLI: 50 steps for quality. Use `--decoder-steps 25` to halve the decode time with no visible loss.
## Links
- Original model: inclusionAI/LLaDA2.0-Uni
- Paper: arXiv:2604.20796
- Apple MLX: github.com/ml-explore/mlx
- Built by @treadon

