MLX LLaDA2.0-Uni

An MLX port of inclusionAI/LLaDA2.0-Uni, a 16B MoE unified multimodal diffusion LLM, running natively on Apple Silicon.

LLaDA2 doesn't generate left-to-right. It's a diffusion language model: fill a template with <mask> tokens, then iteratively un-mask them over multiple denoising steps using bidirectional attention, like image diffusion models but over discrete tokens. Images are represented as in-vocabulary VQ tokens (offset at index 157184), so the backbone is just a sequence-in / logits-out transformer that does text chat, image understanding (VQA), and text-to-image in one model.
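The un-masking loop can be sketched in a few lines. This is an illustrative toy, not the repo's generate.py: the model, the mask id, and the confidence-based unmasking schedule here are all stand-ins.

```python
import numpy as np

MASK_ID = -1  # hypothetical mask-token id for this sketch


def fake_model(tokens):
    """Stand-in for the bidirectional transformer: returns a predicted
    token and a confidence score for every position. A real model
    returns logits; argmax + softmax-max would give these."""
    rng = np.random.default_rng(0)
    preds = rng.integers(0, 100, size=len(tokens))
    conf = rng.random(len(tokens))
    return preds, conf


def diffusion_generate(prompt, gen_len=8, steps=4):
    # Start with the prompt followed by a fully masked template.
    tokens = np.array(prompt + [MASK_ID] * gen_len)
    per_step = gen_len // steps
    for _ in range(steps):
        preds, conf = fake_model(tokens)
        masked = np.where(tokens == MASK_ID)[0]
        # Commit the most confident masked positions; keep the rest
        # masked for the next denoising pass.
        keep = masked[np.argsort(-conf[masked])[:per_step]]
        tokens[keep] = preds[keep]
    return tokens


out = diffusion_generate([1, 2, 3])
assert MASK_ID not in out  # every masked slot filled after `steps` passes
```

The prompt tokens are never touched; only masked positions are re-predicted, which is what lets the model see the whole template bidirectionally at every step.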


A bowl of fresh strawberries on a wooden table, morning light, photorealistic (512×512, cfg 4.0, 8 MLX steps × 9 blocks → 25-step decoder)

A medieval castle on a hill at sunset, oil painting style (same settings)

(Generated at 512×512 for speed. LLaDA2 was designed for 1024×1024: running at the native 32×32 token grid produces significantly sharper results, but each generation takes several hours on M4 Pro. The 16×16 grid at 512×512 works well for still-life subjects; complex subjects like cat faces can show mild distortion at this reduced resolution.)

Why this is interesting

llama.cpp doesn't support diffusion LLMs: every architecture there is a hand-coded graph assuming causal attention + next-token sampling. LLaDA2 needs bidirectional attention and a multi-step denoising loop, so GGUF was not an option. MLX is a general-purpose tensor framework: you write the forward pass and the denoising loop in Python, and it runs.

Quick Start

git clone https://huggingface.co/treadon/mlx-llada2-uni
cd mlx-llada2-uni
git clone https://github.com/inclusionAI/LLaDA2.0-Uni llada2-uni-repo
python -m venv .venv && source .venv/bin/activate
pip install mlx "transformers==4.51" diffusers torch torchvision huggingface_hub safetensors Pillow accelerate torchdiffeq einops

# Text-only chat
python generate.py --prompt "Name three primary colors."

# Image understanding (VQA)
python image_understand.py --image some_photo.jpg --question "What's in this image?"

# Text-to-image (≈18 min on M4 Pro for a 512×512)
python t2i.py --prompt "A photorealistic cat sitting on a wooden table"

First run pulls ~45 GB from inclusionAI/LLaDA2.0-Uni: the 32 GB MoE backbone, 12 GB diffusion decoder, 2.4 GB image tokenizer (ViT + VQVAE), and the 170 MB FLUX VAE. Requirements:

  • Apple Silicon (tested on M4 Pro 64 GB).
  • Python 3.10+.
  • transformers==4.51 specifically β€” newer versions have hard flash_attn imports that fail on Apple Silicon.

What's in this repo

Just the MLX code + sample outputs. Weights are loaded directly from inclusionAI/LLaDA2.0-Uni; we don't redistribute them.

mlx-llada2-uni/
├── llada2/
│   ├── model.py           # 16B MoE backbone (GQA, partial-RoPE, DeepSeek-V2 MoE, bidirectional attention)
│   ├── weights.py         # HF safetensors → MLX loader (packs per-expert ModuleList into stacked arrays)
│   ├── generate.py        # Block-diffusion text generation
│   └── generate_image.py  # Block-diffusion VQ-token generation with CFG (for t2i)
├── generate.py            # Text CLI
├── image_understand.py    # VQA CLI (hybrid: PyTorch image tokenizer → MLX backbone)
├── t2i.py                 # Text-to-image CLI (hybrid: MLX VQ gen → PyTorch decoder)
└── samples/               # Example outputs

Architecture

Backbone (runs in MLX)

| Field | Value |
|---|---|
| Parameters | 16B total, ~1B active per token |
| hidden_size | 2048 |
| layers | 20 (1 dense + 19 MoE) |
| attention | 16Q / 4KV (GQA), head_dim 128, QK RMSNorm, bidirectional |
| RoPE | partial (rotates first 64 of 128 dims), θ = 600000 |
| MoE | 256 routed + 1 shared, top-8 per token, DeepSeek-V2 group-limited (n_group=8, topk_group=4) |
| moe_intermediate_size | 512 |
| vocab | 173568 (text + 16384-entry VQ codebook at offset 157184) |
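Because the VQ codebook lives inside the text vocabulary, converting between image-codebook indices and token ids is pure arithmetic. A minimal sketch using the constants from the table above (helper names are ours, not the repo's):

```python
VQ_OFFSET = 157184       # first VQ token id in the unified vocab
VQ_CODEBOOK_SIZE = 16384
VOCAB_SIZE = 173568


def vq_to_token(code: int) -> int:
    """Map an image-codebook index to its vocabulary token id."""
    assert 0 <= code < VQ_CODEBOOK_SIZE
    return VQ_OFFSET + code


def token_to_vq(token_id: int) -> int:
    """Inverse mapping; only valid for ids inside the VQ range."""
    assert VQ_OFFSET <= token_id < VQ_OFFSET + VQ_CODEBOOK_SIZE
    return token_id - VQ_OFFSET


assert vq_to_token(0) == 157184
assert token_to_vq(vq_to_token(16383)) == 16383
assert VQ_OFFSET + VQ_CODEBOOK_SIZE == VOCAB_SIZE  # VQ block ends exactly at vocab top
```

The last assertion is why t2i works with an unmodified LM head: the 16384 VQ entries are simply the final rows of the output projection.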

Pipeline (hybrid)

| Component | Size | Where |
|---|---|---|
| LLaDA2 MoE backbone | 32 GB bf16 | MLX (GPU) |
| Image tokenizer (ViT + VQVAE) | 2.4 GB | PyTorch CPU |
| SigVQ prior | 400 MB fp32 | PyTorch MPS |
| ZImage decoder (6.2B, 30 layers) | 12 GB bf16 | PyTorch MPS |
| FLUX VAE | 170 MB bf16 | PyTorch MPS |

Porting the 6.2B ZImageTransformer2DModel decoder + ~4B image tokenizer to MLX would take weeks. They run in PyTorch on MPS instead. The heavy lifting (the 16B MoE LLM that does all three tasks) is the MLX-native part.

Conversion notes

| Area | Issue | Fix |
|---|---|---|
| Packed QKV | query_key_value linear outputs (nh + 2·nkv)·d | Split along channel, per-head RMSNorm |
| Partial RoPE | HF rotates first head_dim * partial_rotary_factor dims | Skip concat if non-rotated half is empty |
| MoE gate | Sigmoid + expert_bias + group-limited top-k | argpartition(-scores, k-1)[:, :k] for k-largest indices |
| MoE dispatch | Need to gather [N_slots, H_moe, H] weights | Chunked at 512 slots to cap transient memory |
| Packed experts | HF stores per-expert ModuleList | Stack into [E, out, in] at load time |
| Bidirectional mask | HF uses bool tril-block-diag mask | Additive fp32 mask with -inf off-block |
| CFG uncond padding | Shared attn_mask/position_ids between cond and uncond (see below) | Separate forward passes; uncond gets its own attn_mask (pads masked) and position_ids (zeros on pad, start-at-0 for real tokens) |
| Decoder fp64 | MPS doesn't support fp64 | Patched decoder/transport/transport.py to use fp32 |
| flash_attn import | Not available on Apple Silicon | Stub sys.modules['flash_attn'] + downgrade transformers to 4.51 |
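The argpartition trick from the MoE-gate row is worth spelling out. Shown here in NumPy for illustration (the port uses the MLX equivalent; this is only the k-largest index selection, not the full group-limited routing):

```python
import numpy as np


def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of the k largest entries per row, unordered.
    argpartition on the negated scores places the k largest in the
    first k slots in O(n), avoiding a full O(n log n) sort, which is
    all a top-k expert router needs (order within top-k is irrelevant
    because routing weights come from the scores themselves)."""
    return np.argpartition(-scores, k - 1, axis=-1)[:, :k]


scores = np.array([[0.1, 0.9, 0.4, 0.8, 0.2]])
idx = topk_indices(scores, 2)
assert set(idx[0].tolist()) == {1, 3}  # the two highest-scoring experts
```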

The interesting bug

The cond/uncond CFG paths had different prompt lengths (cond=36, uncond=30 for a typical t2i prompt). My first version shared the same attn_mask and position_ids between them, meaning the uncond path:

  1. Attended to the left-pad <mask> tokens as if they were content.
  2. Saw its real uncond tokens at RoPE positions shifted by pad_len (so the model sees <uncondition> at position 6 rather than position 0, where it was trained).

CFG computes logits_uncond + 4·(logits_cond − logits_uncond), so small errors in the uncond direction got amplified 4× into visible vertical chromatic stripes in the decoded image.
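The guidance step itself is a one-liner once both forward passes exist. A NumPy sketch (function and variable names are ours):

```python
import numpy as np

CFG_SCALE = 4.0  # guidance scale used for t2i here


def classifier_free_guidance(logits_cond, logits_uncond, scale=CFG_SCALE):
    """Blend conditional and unconditional logits. Any error in
    logits_uncond is multiplied by `scale`, which is how small logit
    noise from a wrong uncond mask became visible stripes."""
    return logits_uncond + scale * (logits_cond - logits_uncond)


cond = np.array([1.0, 2.0, 3.0])
uncond = np.array([1.0, 1.0, 1.0])
out = classifier_free_guidance(cond, uncond)
assert np.allclose(out, [1.0, 5.0, 9.0])  # uncond + 4*(cond - uncond)
```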

Fix: separate forward passes for cond and uncond so each path gets its own attention mask (pads masked to -inf) and position_ids (real tokens start at RoPE position 0, matching training).
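Concretely, building the corrected uncond-side inputs looks something like this (a sketch with hypothetical names; the real code operates on MLX arrays and a 2-D attention bias):

```python
import numpy as np


def uncond_inputs(real_len: int, total_len: int):
    """Left-pad the uncond prompt to total_len, mask the pads out of
    attention with an additive -inf bias, and restart RoPE positions
    at 0 for the real tokens (pads get position 0)."""
    pad_len = total_len - real_len
    attn_bias = np.zeros(total_len, dtype=np.float32)
    attn_bias[:pad_len] = -np.inf            # pads contribute nothing
    position_ids = np.concatenate([
        np.zeros(pad_len, dtype=np.int64),   # zeros on pad
        np.arange(real_len, dtype=np.int64), # real tokens start at 0
    ])
    return attn_bias, position_ids


# The cond=36 / uncond=30 case from above: pad_len is 6.
bias, pos = uncond_inputs(real_len=30, total_len=36)
assert np.isneginf(bias[:6]).all() and (bias[6:] == 0).all()
assert pos[6] == 0 and pos[-1] == 29  # <uncondition> back at position 0
```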

Performance (M4 Pro, 64 GB)

Text

  • Short Q&A (e.g. "what is 2+2?") β€” 15s for 32 tokens.
  • Load time: 5.9s warm cache / 8s cold.

Image understanding

  • 512Γ—512 input β†’ 1024 VQ tokens + 10 question tokens β†’ MLX backbone β†’ 48 gen tokens β‰ˆ 90s end-to-end.

Text-to-image (512Γ—512)

| Decoder | Decode steps | Decode time | Quality |
|---|---|---|---|
| decoder-turbo (distilled) | 8 | ~45 s | Stripes on our tokens (brittle to small VQ noise) |
| decoder/ (full) | 50 | ~8 min | ✅ Clean, production-quality |
| decoder/ (full) | 25 | ~4 min | ✅ Clean (tested) |

MLX VQ generation itself: ~10 min (8 steps × 9 blocks, CFG scale 4.0).

Default CLI: 50 steps for quality. Use --decoder-steps 25 to halve the decode time with no visible loss.
