# Cerebellum — Android UI Action Predictor
LoRA adapter on top of `google/gemma-4-E4B-it` that predicts the next Android UI action given a screenshot and accessibility tree.

Architecture: the LLM (or orchestrating agent) issues high-level intent; Cerebellum executes it locally by grounding that intent to a specific UI element and action, without screenshot round-trips to a remote model.
## What It Does
Given a task goal, the current screen (screenshot + accessibility tree), and optional action history, the model outputs a single compact action code indicating what to do next.
## Input Format

The model uses a chat-style prompt (Gemma4 format). The user turn is structured as:

```
Task: {goal}
Step 1 (past): <|image|> -> {action_text}
Step 2 (past): <|image|> -> {action_text}
...
Current screen: <|image|>
{compressed_accessibility_tree}
[n zone]=tap-target(top-to-bottom left-to-right) zone=tl/tc/tr/ml/mc/mr/bl/bc/br ed=text-input sr=scrollable fc=focused(use 'K your_text' to type here)
Actions: T{n}=tap element n, P{n}=long-press element n, K {text}=type text(space required), U/D/L/R=scroll(single token), B=back, H=home, W=wait, F=done, I=impossible
Next action:
```
Inputs:

- `goal` — natural language task description (e.g. "Open the settings app and enable dark mode")
- `history` — up to 4 past (screenshot, action) pairs; can be empty
- current screenshot — PIL image of the current screen, resized to 896px on the long edge
- `compressed_accessibility_tree` — compact text representation of the UI element tree (see below)
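A minimal sketch of how a user turn could be assembled from these inputs. Function and variable names here are illustrative, not taken from the training scripts; past screenshots would be passed to the processor alongside the text:

```python
# Illustrative prompt assembly; names are hypothetical, not from the repo.
LEGEND = (
    "[n zone]=tap-target(top-to-bottom left-to-right) "
    "zone=tl/tc/tr/ml/mc/mr/bl/bc/br ed=text-input sr=scrollable "
    "fc=focused(use 'K your_text' to type here)\n"
    "Actions: T{n}=tap element n, P{n}=long-press element n, "
    "K {text}=type text(space required), U/D/L/R=scroll(single token), "
    "B=back, H=home, W=wait, F=done, I=impossible"
)

def build_user_turn(goal, history, tree_text):
    """history: list of (screenshot, action_text) pairs, oldest first, max 4."""
    lines = [f"Task: {goal}"]
    for i, (_, action_text) in enumerate(history, start=1):
        lines.append(f"Step {i} (past): <|image|> -> {action_text}")
    lines.append("Current screen: <|image|>")
    lines.append(tree_text[:4000])  # hard truncation, per Tree Compression Rules
    lines.append(LEGEND)
    lines.append("Next action:")
    return "\n".join(lines)
```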
### Accessibility Tree Format

Each interactive element is one line:

```
[0 bt tl] Settings
[1 ed mc fc] Search...
[2 bt sr tr] More options
```
Fields per element:

- `[n]` — element index (used in action codes)
- type: `bt`=Button, `ed`=EditText, `tx`=TextView, `im`=ImageView, `ck`=CheckBox, `sw`=Switch, `rd`=RadioButton, `sp`=Spinner, `sc`=ScrollView, `ls`=ListView/RecyclerView, `bar`=Toolbar, `tab`=TabLayout, `dw`=DrawerLayout, `vw`=other
- zone: screen position of element center — row (`t`=top, `m`=mid, `b`=bottom) + col (`l`=left, `c`=center, `r`=right), e.g. `tl`=top-left, `mc`=mid-center
- `fc` — element has keyboard focus (K action types here)
- `ed` — element is editable (text input)
- `sr` — element is scrollable
- `hd` — element supports long-press
- `ds` — element is disabled
- `ck`/`uc` — checkbox checked/unchecked
- `sl` — element is selected
- `pw` — password field
- `...above (N nodes, scroll up)` / `...below (N nodes, scroll down)` — off-screen content indicators
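A sketch of how one element line could be rendered from these fields. The attribute names and exact field ordering are assumptions (the example lines above vary slightly in flag/zone order), not the repo's implementation:

```python
# Hypothetical renderer for one element line: "[n type zone flags] label".
def render_element(el: dict, index: int | None = None) -> str:
    parts = []
    if index is not None:      # only clickable+enabled nodes get an index
        parts.append(str(index))
    parts.append(el["type"])   # short type tag, e.g. "bt", "ed"
    parts.append(el["zone"])   # 3x3 grid cell, e.g. "mc"
    for flag in ("fc", "ed", "sr", "hd", "ds", "sl", "pw"):
        if el.get(flag):       # append only the flags that are set
            parts.append(flag)
    return f"[{' '.join(parts)}] {el['label']}"

el = {"type": "ed", "zone": "mc", "fc": True, "label": "Search..."}
print(render_element(el, index=1))  # -> "[1 ed mc fc] Search..."
```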
### Tree Compression Rules

The raw Android accessibility tree is compressed before being passed to the model:

- Node filtering — nodes without text, content description, resource ID, clickability, or scrollability are collapsed (their children are promoted up)
- Off-screen filtering — nodes fully outside the screen bounds (`x2 <= 0`, `y1 >= screen_height`, etc.) are excluded; replaced with `...above (N)` / `...below (N)` scroll indicators
- Sibling deduplication — identical sibling subtrees are rendered only once (handles repeated list items)
- Multi-window deduplication — Android can return multiple root windows; duplicate root blocks are dropped. Roots whose tappable element IDs are fully covered by a larger root are also dropped (handles split-screen / overlay artifacts)
- Largest-first ordering — when multiple roots exist, the most complete window (most lines) is rendered first
- Element indexing — only `clickable=true AND enabled=true` nodes get a numeric index `[n]`. Non-clickable nodes are rendered without an index. Index order is top-to-bottom, left-to-right by element position
- Type abbreviation — class names are mapped to short tags (e.g. `android.widget.Button` → `bt`)
- Zone encoding — element center is bucketed into a 3×3 grid zone string (`tl/tc/tr/ml/mc/mr/bl/bc/br`); see the sketch after this list
- Label selection — `text` is preferred; falls back to `content_desc`; falls back to `resource_id` (last component after `/`)
- Hard truncation — tree is truncated at 4000 characters before tokenization to prevent OOM on dense screens
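The zone encoding is simple enough to show directly. This is a sketch of the 3×3 bucketing described above, assuming pixel coordinates with the origin at the top-left:

```python
# Bucket an element's center point into one of the nine zone strings.
def zone(cx: float, cy: float, width: int, height: int) -> str:
    row = "tmb"[min(int(3 * cy / height), 2)]  # top / mid / bottom third
    col = "lcr"[min(int(3 * cx / width), 2)]   # left / center / right third
    return row + col

assert zone(10, 10, 900, 1600) == "tl"
assert zone(450, 800, 900, 1600) == "mc"
```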
## Output Format

A single action code (one forward pass, greedy decode):

| Code | Action | Example |
|---|---|---|
| `T{n}` | Tap element n | `T7` |
| `P{n}` | Long-press element n | `P3` |
| `K {text}` | Type text into focused field | `K hello world` |
| `U` | Scroll up | `U` |
| `D` | Scroll down | `D` |
| `L` | Scroll left | `L` |
| `R` | Scroll right | `R` |
| `B` | System back | `B` |
| `H` | Home button | `H` |
| `W` | Wait (screen loading) | `W` |
| `F` | Done (task complete) | `F` |
| `I` | Impossible (task cannot be completed) | `I` |
Single-token actions (U/D/L/R/B/H/W/F/I) self-terminate — no EOS token follows. T/P generate up to 5 tokens (letter + digits + EOS). K generates until EOS.
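A validator for this grammar can be expressed as a single regex. This is my reconstruction from the table above, not the repo's actual validator:

```python
import re

# Matches exactly one well-formed action code from the table above.
ACTION_RE = re.compile(r"T\d+|P\d+|K .+|[UDLRBHWFI]")

def is_valid_action(code: str) -> bool:
    return ACTION_RE.fullmatch(code.strip()) is not None

assert is_valid_action("T7")
assert is_valid_action("K hello world")
assert not is_valid_action("B4")      # fused letter + digits is malformed
assert not is_valid_action("Khello")  # K requires a space after it
```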
## Inference-Time Error Recovery

The model occasionally produces malformed outputs (an action letter fused with the wrong content, e.g. `B4`, `W3`, `T some text`). A lightweight validator detects these and retries with a disambiguating correction blurb appended to the prompt:

```
Next action:
'B4' is not valid. Did you mean 'B' (back) or 'T4' (tap element 4)? Try again:
```
This zero-shot correction resolves the majority of format errors without additional training.
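A sketch of the recovery loop, using `is_valid_action` from the validator sketch above. `generate_action` is a hypothetical wrapper around the model's greedy decode, and the blurb construction mirrors the example rather than the repo's exact logic:

```python
import re

def correction_blurb(bad: str) -> str:
    m = re.fullmatch(r"([UDLRBHWFI])(\d+)", bad)
    if m:  # fused single-token action + element index, e.g. "B4"
        letter, n = m.groups()
        return (f"'{bad}' is not valid. Did you mean '{letter}' "
                f"or 'T{n}' (tap element {n})? Try again:")
    return f"'{bad}' is not valid. Answer with one action code. Try again:"

def predict_with_recovery(prompt, generate_action):
    out = generate_action(prompt)
    if not is_valid_action(out):  # retry once with the correction appended
        out = generate_action(prompt + "\n" + correction_blurb(out))
    return out
```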
## Performance (step 656)

Evaluated on the AndroidControl dataset (accessibility tree format, single-step predictions):
| Metric | Last 20 steps | Last 50 steps | All (102 steps) |
|---|---|---|---|
| Overall accuracy | 95.0% | 92.0% | 88.2% |
| Element index accuracy | 93.3% | 88.6% | 84.6% |
Action type breakdown (last 20 steps):
| Action | Accuracy |
|---|---|
| tap (T) | 93% |
| scroll (U/D/L/R) | 100% |
| back (B) | 100% |
| type (K) | 100% |
| wait (W) | 100% |
Remaining errors are primarily off-by-one element indices on tap targets, a known SFT ceiling that the planned RL stage is intended to address.
## Training Process

Base model: `google/gemma-4-E4B-it` (4B MoE, 4-bit quantized during training via bitsandbytes)

LoRA config:

- `r=64`, `alpha=32`, `dropout=0.05`
- Target modules: all linear layers in the transformer
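Expressed as a PEFT config, this might look like the following. `target_modules="all-linear"` is PEFT's shorthand for targeting all linear layers; the `task_type` value is an assumption:

```python
from peft import LoraConfig

# LoRA hyperparameters as stated above; task_type is an assumption.
lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",  # all linear layers in the transformer
    task_type="CAUSAL_LM",
)
```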
Training data: AndroidControl dataset (accessibility tree variant), ~20 shards from GCS. Each sample is a single (screenshot, a11y tree, goal, history) → action step from a real Android interaction trajectory.
Key training decisions:

- No label smoothing — removed after identifying it softened action type gradients
- `accum_steps=1` — every sample is its own gradient update (maximum signal density)
- `lr=5e-5`, cosine schedule
- Grammar-constrained loss — inference-time cap per action type (T/P: 5 tokens max, single-token actions: 1 token); wrong action type predictions lose access to the downstream element-index reward
- Type token weights — tap=4.0, long_press=4.0, type=8.0, scrolls=8.0 (upweighted to prevent collapse)
- Sample weights — rare actions (back/home/wait/done/impossible) upweighted 3× to prevent tap dominance
- Rolling window diversity quota (window=20) — ensures each action type appears proportionally in recent batches; see the sketch after this list
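One way to implement the rolling-window quota. The admission rule and the `max_share` cap are my reading of the description, not the repo's exact logic:

```python
from collections import Counter, deque

# Hypothetical rolling-window diversity quota (window=20).
class DiversityQuota:
    def __init__(self, window: int = 20, max_share: float = 0.5):
        self.recent = deque(maxlen=window)
        self.max_share = max_share  # assumed cap per action type

    def admit(self, action_type: str) -> bool:
        counts = Counter(self.recent)
        full = len(self.recent) == self.recent.maxlen
        if full and counts[action_type] / len(self.recent) >= self.max_share:
            return False  # skip sample: this type is over-represented
        self.recent.append(action_type)
        return True
```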
Training infrastructure:
- Single RTX 3060 12GB
- ~100s/step (full image + tree encoding + gradient update)
- Milestone checkpoints every ~100 steps via sentinel file
To replicate from scratch:

1. Download the AndroidControl dataset (GCS, 20 shards, ~47GB)
2. Preprocess with `scripts/preprocess_a11y.py` to extract accessibility trees
3. Train: `py -3.11 -u scripts/train_autoregressive.py --out checkpoints/autoreg/current`
4. Resume: `py -3.11 -u scripts/train_autoregressive.py --resume checkpoints/autoreg/current/step_XXXXXXX --out checkpoints/autoreg/current`
5. Monitor: tail the log file for HIT/miss lines; ntfy.sh push notifications every 5 steps (topic: Cerebellum-Training)
## Loading the Adapter

```python
import torch
from transformers import AutoProcessor, Gemma4ForConditionalGeneration
from peft import PeftModel

# Load the base model in bf16, then attach the LoRA adapter.
base = Gemma4ForConditionalGeneration.from_pretrained(
    "google/gemma-4-E4B-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "dmitchelljackson/cerebellum-e4b-lora")
processor = AutoProcessor.from_pretrained("dmitchelljackson/cerebellum-e4b-lora")
model.eval()
```
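A minimal single-step inference sketch following the loading snippet. The chat-template call follows the standard transformers multimodal API; how the inline `<|image|>` placeholders interact with the processor's own image tokens is an assumption, and the `user_text` literal is a toy stand-in for a real prompt built per "Input Format":

```python
import torch
from PIL import Image

image = Image.open("screen.png")  # current screenshot
user_text = (
    "Task: Open the settings app\n"
    "Current screen: <|image|>\n"
    "[0 bt tl] Settings\n"  # toy compressed accessibility tree
    "Next action:"
)
messages = [{"role": "user", "content": [
    {"type": "image", "image": image},
    {"type": "text", "text": user_text},
]}]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(model.device)

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16, do_sample=False)  # greedy
action = processor.decode(
    out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True
).strip()
print(action)  # e.g. "T0"
```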
## Roadmap
- SFT on AndroidControl (~88-95% single-step accuracy)
- Inference-time error recovery (format validator + correction blurb)
- RL fine-tuning (GRPO) on AndroidWorld tasks for multi-step accuracy and semantic recovery
- Error recovery fine-tuning on collected failure cases