Qwen3-4B-ARC-MLX-4bit

An MLX 4-bit quantized version of sorokin/qwen3_4b_grids15_sft139, a Qwen3-4B model fine-tuned for ARC-AGI-2 puzzle solving.

This model is designed to run on Apple Silicon (M1/M2/M3/M4) using the MLX framework and uses test-time training (TTT) -- per-puzzle QLoRA fine-tuning -- to solve ARC puzzles.

Key Details

| Property | Value |
|---|---|
| Architecture | Qwen3-4B (36 layers, 2560 hidden, 32 heads) |
| Vocabulary | 16 tokens (digits 0-9, newline, user/assistant/pad/eos) |
| Quantization | 4-bit (affine, group_size=64) via mlx-lm |
| Model Size | 1.9 GB |
| Base Model | sorokin/qwen3_4b_grids15_sft139 (Qwen3-4B fine-tuned on ~100K synthetic ARC grids) |
| Framework | MLX (Apple Silicon) |

What is ARC-AGI-2?

ARC-AGI-2 (Abstraction and Reasoning Corpus) is a benchmark for measuring artificial general intelligence. Each puzzle consists of:

  • Training pairs: 2-5 input/output grid examples demonstrating a transformation rule
  • Test input: A new grid where the model must predict the output
  • Grids are 2D arrays of colors (digits 0-9), up to 30x30

The challenge is that each puzzle has a unique, never-before-seen rule -- the model must infer the rule from just a few examples.
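To make the format concrete, here is a toy task in the standard ARC JSON layout. The rule (mirror each row left-to-right) is invented for illustration and is far simpler than real ARC-AGI-2 tasks:

```python
# A toy ARC-style task. A solver sees only the "train" pairs and must
# infer the transformation, then apply it to the "test" input.
puzzle = {
    "train": [
        {"input": [[1, 2, 3]], "output": [[3, 2, 1]]},
        {"input": [[4, 0], [5, 6]], "output": [[0, 4], [6, 5]]},
    ],
    "test": [{"input": [[7, 8, 9]]}],
}

def apply_rule(grid):
    # The hidden rule for this toy task: mirror each row
    return [row[::-1] for row in grid]
```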

How It Works: Test-Time Training (TTT)

This model cannot solve ARC puzzles via simple inference. It requires per-puzzle fine-tuning:

  1. Augment the puzzle's training pairs (rotations, transpositions, color permutations) to create 8-16 training sequences
  2. QLoRA fine-tune the model on these augmented examples (5 epochs, rank 32-128)
  3. Generate multiple candidate outputs via sampling
  4. Vote on the most common valid grid
  5. Reset LoRA weights before the next puzzle

This approach was pioneered by the NVARC team (1st place ARC Prize 2025) and adapted here for MLX.
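Step 1 (augmentation) can be sketched with plain grid transforms. The function names and the exact mix of variants here are illustrative, not taken from the actual pipeline:

```python
import random

def rot90(g):
    """Rotate a grid 90 degrees clockwise."""
    return [list(r) for r in zip(*g[::-1])]

def transpose(g):
    return [list(r) for r in zip(*g)]

def recolor(g, perm):
    """Remap colors via perm, a list where perm[c] replaces color c."""
    return [[perm[c] for c in row] for row in g]

def augment(train_pairs, n_color_perms=4, seed=0):
    """Return augmented copies of the training pairs: identity,
    three rotations, the transposition, and random color permutations."""
    rng = random.Random(seed)

    def apply(fn, pairs):
        return [{"input": fn(p["input"]), "output": fn(p["output"])}
                for p in pairs]

    variants = [train_pairs]
    g = train_pairs
    for _ in range(3):  # 90, 180, 270 degree rotations
        g = apply(rot90, g)
        variants.append(g)
    variants.append(apply(transpose, train_pairs))
    for _ in range(n_color_perms):
        perm = list(range(10))
        rng.shuffle(perm)
        variants.append(apply(lambda grid, p=perm: recolor(grid, p),
                              train_pairs))
    return variants
```

Each variant becomes its own training sequence, so even a 2-pair puzzle yields enough data for a few epochs of QLoRA.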

Grid Encoding

Grids are encoded as digit strings with newlines, wrapped in a chat-like format:

<user_token>\n{input_grid}<eos_token><assistant_token>\n{output_grid}<eos_token>

The 16-token vocabulary maps directly: 0-9 for digits, 10 for newline, 11 for user, 12 for assistant, 13 for pad, 15 for eos.
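The mapping can be verified with a tiny encoder (a sketch; the real pipeline's helper names may differ):

```python
# Token ids from the 16-token vocabulary described above:
# digits 0-9 map to ids 0-9, then the special tokens.
NEWLINE, USER, ASSISTANT, PAD, EOS = 10, 11, 12, 13, 15

def encode_grid(grid):
    """Encode a grid as digit tokens with NEWLINE between rows."""
    tokens = []
    for i, row in enumerate(grid):
        tokens.extend(row)
        if i < len(grid) - 1:
            tokens.append(NEWLINE)
    return tokens

def encode_pair(inp, out):
    """Encode one input/output pair in the chat-like format."""
    return ([USER, NEWLINE] + encode_grid(inp) + [EOS]
            + [ASSISTANT, NEWLINE] + encode_grid(out) + [EOS])
```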

Usage with MLX

Requirements

pip install "mlx-lm>=0.31.0"

Quick Start

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx_lm
from mlx_lm.tuner.utils import linear_to_lora_layers, remove_lora_layers

# Load model
model, tokenizer = mlx_lm.load("mlx-community/Qwen3-4B-ARC-MLX-4bit")

# Apply QLoRA (freeze the quantized base weights first so that only
# the LoRA adapter parameters receive gradients)
model.freeze()
n_layers = len(model.model.layers)
linear_to_lora_layers(model, num_layers=n_layers, config={
    'rank': 32, 'dropout': 0.0, 'scale': 10.0
})
mx.eval(model.parameters())

# Token IDs
USER, ASSISTANT, EOS, NEWLINE = 11, 12, 15, 10

def grid_to_tokens(grid):
    tokens = []
    for i, row in enumerate(grid):
        tokens.extend(int(c) for c in row)
        if i < len(grid) - 1:
            tokens.append(NEWLINE)
    return tokens

# Build training sequence from puzzle's training pairs
def build_sequence(train_pairs):
    tokens = []
    for pair in train_pairs:
        tokens += [USER, NEWLINE] + grid_to_tokens(pair["input"]) + [EOS]
        tokens += [ASSISTANT, NEWLINE] + grid_to_tokens(pair["output"]) + [EOS]
    return tokens

# Example puzzle
puzzle = {
    "train": [
        {"input": [[1, 2], [3, 4]], "output": [[4, 3], [2, 1]]},
        {"input": [[5, 6], [7, 8]], "output": [[8, 7], [6, 5]]},
    ],
    "test": [{"input": [[9, 0], [1, 2]]}]
}

# --- Test-Time Training ---
train_tokens = build_sequence(puzzle["train"])
labels = [-100] * len(train_tokens)

# Unmask assistant turns only
i = 0
while i < len(train_tokens):
    if train_tokens[i] == ASSISTANT:
        i += 2  # skip ASSISTANT + NEWLINE
        while i < len(train_tokens) and train_tokens[i] != EOS:
            labels[i] = train_tokens[i]
            i += 1
        if i < len(train_tokens):
            labels[i] = train_tokens[i]  # EOS
    i += 1

optimizer = optim.Adam(learning_rate=5e-5)
for epoch in range(5):
    inp = mx.array([train_tokens[:-1]])
    tgt = mx.array([labels[1:]])
    
    def loss_fn(model):
        out = model(inp)
        if isinstance(out, tuple): out = out[0]
        logits = out.reshape(-1, out.shape[-1])
        target = tgt.reshape(-1)
        mask = target != -100
        ce = nn.losses.cross_entropy(logits, target, reduction='none')
        return (ce * mask).sum() / mask.sum()
    
    loss, grads = nn.value_and_grad(model, loss_fn)(model)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    print(f"Epoch {epoch}: loss={loss.item():.4f}")

# --- Generate ---
prompt = build_sequence(puzzle["train"])
prompt += [USER, NEWLINE] + grid_to_tokens(puzzle["test"][0]["input"]) + [EOS]
prompt += [ASSISTANT, NEWLINE]

sampler = mlx_lm.sample_utils.make_sampler(temp=0.5)
response = mlx_lm.generate(model, tokenizer, prompt=prompt, max_tokens=200,
                           sampler=sampler, verbose=False)
print(f"Output: {response}")

# --- Reset for next puzzle ---
remove_lora_layers(model)
model.freeze()  # keep the base frozen; fresh adapters stay trainable
linear_to_lora_layers(model, num_layers=n_layers, config={
    'rank': 32, 'dropout': 0.0, 'scale': 10.0
})
mx.eval(model.parameters())
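The Quick Start samples a single response; the full pipeline (step 4 above) draws several samples and votes. A minimal sketch, assuming each response is a digit string with newline-separated rows as in the grid encoding:

```python
from collections import Counter

def parse_grid(text):
    """Parse a sampled response back into a grid. Returns None if any
    row contains a non-digit character or the rows have unequal lengths."""
    rows = [r for r in text.strip().split("\n") if r]
    if not rows or not all(r.isdigit() for r in rows):
        return None
    if len({len(r) for r in rows}) != 1:
        return None
    return [[int(c) for c in r] for r in rows]

def vote(responses):
    """Majority vote over the valid parsed grids."""
    counts = Counter()
    for text in responses:
        grid = parse_grid(text)
        if grid is not None:
            # Tuples are hashable, so they can serve as Counter keys
            counts[tuple(map(tuple, grid))] += 1
    if not counts:
        return None
    best, _ = counts.most_common(1)[0]
    return [list(row) for row in best]
```

Invalid samples (malformed or ragged grids) are simply discarded, so a handful of samples at temperature 0.5 plus one greedy decode is usually enough to surface a stable answer.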

Results

ARC-AGI-2 Training Set (304 tasks, grids up to 10x10)

| Metric | Value |
|---|---|
| Tasks evaluated | 304 |
| Test instances | 328 |
| Correct | 192 |
| Accuracy | 58.5% |
| Avg time per task | 479s (~8 min) |
| Total compute time | 40.5 hours |

This is an unofficial score on the ARC-AGI-2 training split (grids up to 10x10). For reference, the ARC-AGI-2 Kaggle leaderboard top scores are around 12% on the hidden test set, which contains harder puzzles with larger grids (up to 30x30). Training set tasks are generally easier, so this score is not directly comparable to leaderboard scores.

| Split | First 152 tasks | Last 152 tasks | Combined |
|---|---|---|---|
| Correct / Total | 94/163 | 98/165 | 192/328 |
| Accuracy | 57.7% | 59.4% | 58.5% |

The consistent accuracy across the two halves (run independently) suggests the results are reproducible.

Configuration

  • LoRA rank: 32
  • Learning rate: 5e-5
  • TTT epochs: 5
  • Augmentations: 4 color permutations + 3 rotations + 2 transpositions
  • Samples per task: 3 (+ 1 greedy)
  • Temperature: 0.5
  • Hardware: Apple M3 Max (36GB)

Comparison with Original NVARC

| Aspect | NVARC (Kaggle) | This MLX Port |
|---|---|---|
| Framework | PyTorch + Unsloth | MLX |
| Hardware | 2x NVIDIA T4 (32GB) | Apple Silicon (24-36GB) |
| Quantization | 4-bit via bitsandbytes | 4-bit via mlx-lm |
| LoRA rank | 256 | 32 |
| Decoding | Turbo DFS (tree search) | Sampling + voting |
| Speed per task | ~540s | ~480s |

Limitations

  • Requires TTT: The model produces garbage without per-puzzle fine-tuning
  • Speed: Large grids (20x20+) with many training pairs can take over an hour on Apple Silicon
  • No DFS decoding: The original NVARC uses a turbo DFS tree search for better candidate generation; this port uses simpler sampling + voting
  • Large grids: Only evaluated on grids up to 10x10; larger grids are significantly slower and may have lower accuracy

Acknowledgments

  • NVARC team for the original ARC Prize 2025 winning solution
  • sorokin for the pre-trained Qwen3-4B ARC model
  • MLX team for the Apple Silicon ML framework
  • ARC Prize for the benchmark