Cerebellum — Android UI Action Predictor

A LoRA adapter on top of google/gemma-4-E4B-it that predicts the next Android UI action from a screenshot and an accessibility tree.

Architecture: The LLM (or orchestrating agent) issues high-level intent. Cerebellum executes it locally by grounding intent to a specific UI element and action — without screenshot round-trips to a remote model.


What It Does

Given a task goal, the current screen (screenshot + accessibility tree), and optional action history, the model outputs a single compact action code indicating what to do next.


Input Format

The model uses a chat-style prompt (Gemma4 format). The user turn is structured as:

Task: {goal}

Step 1 (past): <|image|> -> {action_text}
Step 2 (past): <|image|> -> {action_text}
...
Current screen: <|image|>
{compressed_accessibility_tree}
[n zone]=tap-target(top-to-bottom left-to-right) zone=tl/tc/tr/ml/mc/mr/bl/bc/br  ed=text-input sr=scrollable fc=focused(use 'K your_text' to type here)
Actions: T{n}=tap element n, P{n}=long-press element n, K {text}=type text(space required), U/D/L/R=scroll(single token), B=back, H=home, W=wait, F=done, I=impossible
Next action:

Inputs:

  • goal — natural language task description (e.g. "Open the settings app and enable dark mode")
  • history — up to 4 past (screenshot, action) pairs; can be empty
  • current screenshot — PIL image of the current screen, resized to 896px on the long edge
  • compressed_accessibility_tree — compact text representation of the UI element tree (see below)
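
For illustration, a minimal sketch of assembling the user turn from these inputs (the helper name and the exact truncation placement are assumptions, not the project's actual code):

IMAGE_TOKEN = "<|image|>"  # placeholder the processor swaps for image embeddings

LEGEND = (
    "[n zone]=tap-target(top-to-bottom left-to-right) zone=tl/tc/tr/ml/mc/mr/bl/bc/br  "
    "ed=text-input sr=scrollable fc=focused(use 'K your_text' to type here)\n"
    "Actions: T{n}=tap element n, P{n}=long-press element n, K {text}=type text(space required), "
    "U/D/L/R=scroll(single token), B=back, H=home, W=wait, F=done, I=impossible"
)

def build_user_turn(goal, history, tree):
    """history: action_text strings for up to 4 past steps, one screenshot each."""
    lines = [f"Task: {goal}", ""]
    for i, action_text in enumerate(history[-4:], start=1):
        lines.append(f"Step {i} (past): {IMAGE_TOKEN} -> {action_text}")
    lines.append(f"Current screen: {IMAGE_TOKEN}")
    lines.append(tree[:4000])  # hard truncation, see Tree Compression Rules
    lines.append(LEGEND)
    lines.append("Next action:")
    return "\n".join(lines)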

Accessibility Tree Format

Each interactive element is one line:

[0 bt tl] Settings
[1 ed mc fc] Search...
[2 bt sr tr] More options

Fields per element:

  • [n] — element index (used in action codes)
  • type: bt=Button, ed=EditText, tx=TextView, im=ImageView, ck=CheckBox, sw=Switch, rd=RadioButton, sp=Spinner, sc=ScrollView, ls=ListView/RecyclerView, bar=Toolbar, tab=TabLayout, dw=DrawerLayout, vw=other
  • zone: screen position of element center — row (t=top, m=mid, b=bottom) + col (l=left, c=center, r=right), e.g. tl=top-left, mc=mid-center
  • fc — element has keyboard focus (K action types here)
  • ed — element is editable (text input)
  • sr — element is scrollable
  • hd — element supports long-press
  • ds — element is disabled
  • ck/uc — checkbox checked/unchecked
  • sl — element is selected
  • pw — password field
  • ...above (N nodes, scroll up) / ...below (N nodes, scroll down) — off-screen content indicators
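
As a rough illustration, one element line could be rendered like this (the function and its arguments are assumed, not taken from the project code):

def render_element(index, type_tag, zone, flags, label):
    """e.g. render_element(1, "ed", "mc", ["fc"], "Search...") -> '[1 ed mc fc] Search...'
    Non-indexable nodes would omit the leading index."""
    parts = ([] if index is None else [str(index)]) + [type_tag, zone] + list(flags)
    return f"[{' '.join(parts)}] {label}"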

Tree Compression Rules

The raw Android accessibility tree is compressed before being passed to the model:

  1. Node filtering — nodes without text, content description, resource ID, clickability, or scrollability are collapsed (their children are promoted up)
  2. Off-screen filtering — nodes fully outside the screen bounds (x2<=0, y1>=screen_height, etc.) are excluded; replaced with ...above (N) / ...below (N) scroll indicators
  3. Sibling deduplication — identical sibling subtrees are rendered only once (handles repeated list items)
  4. Multi-window deduplication — Android can return multiple root windows; duplicate root blocks are dropped. Roots whose tappable element IDs are fully covered by a larger root are also dropped (handles split-screen / overlay artifacts)
  5. Largest-first ordering — when multiple roots exist, the most complete window (most lines) is rendered first
  6. Element indexing — only clickable=true AND enabled=true nodes get a numeric index [n]. Non-clickable nodes are rendered without an index. Index order is top-to-bottom, left-to-right by element position
  7. Type abbreviation — class names are mapped to short tags (e.g. android.widget.Button → bt)
  8. Zone encoding — element center is bucketed into a 3×3 grid zone string (tl/tc/tr/ml/mc/mr/bl/bc/br)
  9. Label selection — text is preferred, falling back to content_desc, then to resource_id (last component after /); rules 8 and 9 are sketched in code after this list
  10. Hard truncation — tree is truncated at 4000 characters before tokenization to prevent OOM on dense screens
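
A minimal sketch of rules 8 and 9 (zone bucketing and label fallback); the node fields and function names are illustrative assumptions:

def zone(cx, cy, screen_w, screen_h):
    """Rule 8: bucket the element center into the 3x3 grid -> 'tl'..'br'."""
    row = "tmb"[min(2, int(3 * cy / screen_h))]
    col = "lcr"[min(2, int(3 * cx / screen_w))]
    return row + col

def label(node):
    """Rule 9: prefer text, then content_desc, then the last component of resource_id."""
    if node.get("text"):
        return node["text"]
    if node.get("content_desc"):
        return node["content_desc"]
    return node.get("resource_id", "").rsplit("/", 1)[-1]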

Output Format

A single action code (one forward pass, greedy decode):

Code       Action                              Example
T{n}       Tap element n                       T7
P{n}       Long-press element n                P3
K {text}   Type text into focused field        K hello world
U          Scroll up                           U
D          Scroll down                         D
L          Scroll left                         L
R          Scroll right                        R
B          System back                         B
H          Home button                         H
W          Wait (screen loading)               W
F          Done (task complete)                F
I          Impossible (task cannot complete)   I

Single-token actions (U/D/L/R/B/H/W/F/I) self-terminate — no EOS token follows. T/P generate up to 5 tokens (letter + digits + EOS). K generates until EOS.
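
A sketch of greedy decoding with a generous token cap and a tolerant parse of the result; the regex and parameter names are assumptions:

import re

ACTION_RE = re.compile(r"^(T\d+|P\d+|K .+|[UDLRBHWFI])$")

def decode_action(model, processor, inputs, max_new_tokens=32):
    """Greedy decode, keep the first line of new text, and validate it as an action code."""
    out = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
    new_tokens = out[0, inputs["input_ids"].shape[1]:]
    raw = processor.decode(new_tokens, skip_special_tokens=True).strip()
    first = raw.splitlines()[0].strip() if raw else ""
    return first if ACTION_RE.match(first) else None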


Inference-Time Error Recovery

The model occasionally produces malformed outputs (an action letter fused with the wrong content, e.g. 'B4', 'W3', 'T some text'). A lightweight validator detects these and retries with a disambiguating correction blurb appended to the prompt:

Next action:
'B4' is not valid. Did you mean 'B' (back) or 'T4' (tap element 4)? Try again:

This zero-shot correction resolves the majority of format errors without additional training.
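
A minimal sketch of this validate-and-retry step, reusing ACTION_RE from the decoding sketch above; `predict` stands in for a full prompt-to-text call and is an assumption:

import re

def correction_blurb(bad):
    """Disambiguate the common 'letter fused with digits' failure, e.g. 'B4'."""
    m = re.match(r"^([UDLRBHWFI])(\d+)$", bad or "")
    if m:
        letter, n = m.groups()
        return f"\n'{bad}' is not valid. Did you mean '{letter}' or 'T{n}' (tap element {n})? Try again:"
    return f"\n'{bad}' is not valid. Reply with a single action code. Try again:"

def predict_with_recovery(predict, prompt):
    action = predict(prompt)
    if action and ACTION_RE.match(action):
        return action
    return predict(prompt + correction_blurb(action))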


Performance (step 656)

Evaluated on the AndroidControl dataset (accessibility tree format, single-step predictions):

Metric                   Last 20 steps   Last 50 steps   All (102 steps)
Overall accuracy         95.0%           92.0%           88.2%
Element index accuracy   93.3%           88.6%           84.6%

Action type breakdown (last 20 steps):

Action            Accuracy
tap (T)           93%
scroll (U/D/L/R)  100%
back (B)          100%
type (K)          100%
wait (W)          100%

Remaining errors are primarily off-by-one element indices on tap targets, a known SFT ceiling that the planned RL stage (see Roadmap) is meant to address.


Training Process

Base model: google/gemma-4-E4B-it (4B MoE, 4-bit quantized during training via bitsandbytes)

LoRA config:

  • r=64, alpha=32, dropout=0.05
  • Target modules: all linear layers in the transformer
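
In peft terms this corresponds roughly to the config below; the "all-linear" shorthand for targeting every linear layer is an assumption about how it was specified:

from peft import LoraConfig

lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",  # all linear layers in the transformer
    task_type="CAUSAL_LM",
)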

Training data: AndroidControl dataset (accessibility tree variant), ~20 shards from GCS. Each sample is a single (screenshot, a11y tree, goal, history) → action step from a real Android interaction trajectory.

Key training decisions:

  • No label smoothing — removed after identifying it softened action type gradients
  • accum_steps=1 — every sample is its own gradient update (maximum signal density)
  • lr=5e-5, cosine schedule
  • Grammar-constrained loss: inference-time cap per action type (T/P: 5 tokens max, single-token actions: 1 token). Wrong action type predictions lose access to downstream element-index reward
  • Type token weights: tap=4.0, long_press=4.0, type=8.0, scrolls=8.0 (upweighted to prevent collapse)
  • Sample weights: rare actions (back/home/wait/done/impossible) upweighted 3× to prevent tap dominance
  • Rolling window diversity quota (window=20): ensures each action type appears proportionally in recent batches
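
A rough sketch of how a rolling-window diversity quota could be enforced (window size from the text; target shares and the rest are assumptions):

from collections import Counter, deque

class DiversityQuota:
    """Skip a sample if its action type is over-represented in the last `window` accepted samples."""
    def __init__(self, target_share, window=20):
        self.target_share = target_share  # e.g. {"tap": 0.5, "scroll": 0.2, ...}
        self.recent = deque(maxlen=window)

    def accept(self, action_type):
        if len(self.recent) == self.recent.maxlen:
            share = Counter(self.recent)[action_type] / len(self.recent)
            if share > self.target_share.get(action_type, 1.0):
                return False  # over quota: defer this sample
        self.recent.append(action_type)
        return True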

Training infrastructure:

  • Single RTX 3060 12GB
  • ~100s/step (full image + tree encoding + gradient update)
  • Milestone checkpoints every ~100 steps via sentinel file

To replicate from scratch:

  1. Download AndroidControl dataset (GCS, 20 shards, ~47GB)
  2. Preprocess with scripts/preprocess_a11y.py to extract accessibility trees
  3. Train: py -3.11 -u scripts/train_autoregressive.py --out checkpoints/autoreg/current
  4. Resume: py -3.11 -u scripts/train_autoregressive.py --resume checkpoints/autoreg/current/step_XXXXXXX --out checkpoints/autoreg/current
  5. Monitor: tail the log file for HIT/miss lines; ntfy.sh push notifications every 5 steps (topic: Cerebellum-Training)

Loading the Adapter

import torch
from transformers import AutoProcessor, Gemma4ForConditionalGeneration
from peft import PeftModel

# Load the base model in bfloat16, then attach the LoRA adapter and its processor.
base = Gemma4ForConditionalGeneration.from_pretrained(
    "google/gemma-4-E4B-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "dmitchelljackson/cerebellum-e4b-lora")
processor = AutoProcessor.from_pretrained("dmitchelljackson/cerebellum-e4b-lora")
model.eval()
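
After loading, a single-step prediction might look roughly like this; the processor call signature and generation settings are assumptions, and `user_turn` is the prompt text described in the Input Format section:

from PIL import Image

screenshot = Image.open("screen.png").convert("RGB")  # hypothetical current screen
inputs = processor(text=user_turn, images=[screenshot], return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(**inputs, do_sample=False, max_new_tokens=32)

new_tokens = out[0, inputs["input_ids"].shape[1]:]
print(processor.decode(new_tokens, skip_special_tokens=True).strip())  # e.g. "T7"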

Roadmap

  • SFT on AndroidControl (~88-95% single-step accuracy)
  • Inference-time error recovery (format validator + correction blurb)
  • RL fine-tuning (GRPO) on AndroidWorld tasks for multi-step accuracy and semantic recovery
  • Error recovery fine-tuning on collected failure cases