Malayalam Scene Text OCR - PARSeq-style

A PARSeq-inspired OCR model (ViT encoder + Transformer decoder) for Malayalam scene text recognition.

Available Checkpoints

| Checkpoint | Description | Word Acc (real val) |
|---|---|---|
| parseq_best.pth | Pretrained on 950k synthetic images | 97.64% (synthetic) |
| parseq_finetuned_best.pth | Finetuned v1 (100 epochs, basic) | 84.15% |
| parseq_finetuned_v2.pth | Finetuned v2 (150 epochs, label smoothing + cosine restart) | 91.46% |

Use parseq_finetuned_v2.pth for best results on real Malayalam scene text.


Benchmark Results

Evaluated on 82 real Malayalam scene-text images (the val_ split of the IndicVignesh dataset). All scores are computed after normalizing Malayalam Unicode variants (chillu characters, ZWJ/ZWNJ).

| Model | Word Acc | Char Acc |
|---|---|---|
| GPT-5.4 | 34.15% | 63.17% |
| Claude Sonnet 4.5 | 35.37% | 56.18% |
| Claude Sonnet 4.6 | 84.15% | 93.57% |
| Gemini 3 Flash Preview | 85.37% | 94.76% |
| Claude Opus 4.6 | 86.59% | 95.41% |
| PARSeq v2 (Ours) | 91.46% | 97.18% |

Our 25M-parameter specialized model beats every frontier VLM we tested, including Claude Opus 4.6, and runs locally at zero inference cost.
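The exact metric definitions are not spelled out here; a common convention, assumed in this sketch, is exact-match word accuracy and character accuracy as 1 minus edit distance over ground-truth length, both computed on normalized text:

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance between two strings (insert / delete / substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]

def word_accuracy(preds, gts):
    """Fraction of predictions that exactly match the ground truth."""
    return sum(p == g for p, g in zip(preds, gts)) / len(gts)

def char_accuracy(preds, gts):
    """1 - CER: total edit distance divided by total ground-truth length."""
    errors = sum(levenshtein(p, g) for p, g in zip(preds, gts))
    total  = sum(len(g) for g in gts)
    return 1.0 - errors / total
```

Both functions expect strings that have already been passed through the Malayalam normalization described below.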


Model Architecture

  • Encoder: Vision Transformer (ViT) with 4×8 patches on 32×128 images → 128 patches
  • Decoder: Autoregressive Transformer decoder with causal masking
  • Parameters: ~25M
  • Vocab: 99 tokens (95 Malayalam characters + [PAD], [BOS], [EOS], [UNK])
  • Input: 32×128 RGB images (auto-resized)
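The 128-patch figure in the encoder bullet follows directly from the image and patch sizes; a quick sanity check:

```python
# Patch grid for a 32x128 image split into 4x8 patches (architecture above)
img_h, img_w = 32, 128
patch_h, patch_w = 4, 8

grid_h = img_h // patch_h          # 8 patch rows
grid_w = img_w // patch_w          # 16 patch columns
num_patches = grid_h * grid_w

print(num_patches)                 # 128 tokens fed to the ViT encoder
```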

Quick Start

Install

pip install torch torchvision huggingface_hub Pillow

Define the model classes

import math
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, img_h=32, img_w=128, patch_h=4, patch_w=8, in_chans=3, embed_dim=384):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
        self.norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        return self.norm(self.proj(x).flatten(2).transpose(1, 2))

class SinusoidalPE(nn.Module):
    def __init__(self, embed_dim, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe  = torch.zeros(max_len, embed_dim)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x):
        return self.dropout(x + self.pe[:, :x.size(1)])

class ViTEncoder(nn.Module):
    def __init__(self, img_h=32, img_w=128, patch_h=4, patch_w=8, in_chans=3,
                 embed_dim=384, depth=6, num_heads=6, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_h, img_w, patch_h, patch_w, in_chans, embed_dim)
        self.pos_enc     = SinusoidalPE(embed_dim, max_len=512, dropout=dropout)
        encoder_layer    = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                             dim_feedforward=int(embed_dim*mlp_ratio), dropout=dropout,
                             activation='gelu', batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.norm    = nn.LayerNorm(embed_dim)
    def forward(self, x):
        return self.norm(self.encoder(self.pos_enc(self.patch_embed(x))))

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=384, depth=6, num_heads=6,
                 mlp_ratio=4.0, dropout=0.1, max_label_len=26, pad_idx=0):
        super().__init__()
        self.pad_idx     = pad_idx
        self.token_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.pos_enc     = SinusoidalPE(embed_dim, max_len=max_label_len+2, dropout=dropout)
        decoder_layer    = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads,
                             dim_feedforward=int(embed_dim*mlp_ratio), dropout=dropout,
                             activation='gelu', batch_first=True, norm_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=depth)
        self.norm    = nn.LayerNorm(embed_dim)
        self.head    = nn.Linear(embed_dim, vocab_size)
    def _causal_mask(self, sz, device):
        return torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
    def forward(self, tgt_inp, memory):
        B, T = tgt_inp.shape
        x = self.pos_enc(self.token_embed(tgt_inp))
        x = self.decoder(tgt=x, memory=memory,
                         tgt_mask=self._causal_mask(T, tgt_inp.device),
                         tgt_key_padding_mask=(tgt_inp == self.pad_idx))
        return self.head(self.norm(x))

class PARSeqOCR(nn.Module):
    def __init__(self, vocab_size, img_h=32, img_w=128, patch_h=4, patch_w=8,
                 embed_dim=384, enc_depth=6, dec_depth=6, num_heads=6,
                 mlp_ratio=4.0, dropout=0.1, max_label_len=25, pad_idx=0):
        super().__init__()
        self.max_label_len = max_label_len
        self.pad_idx       = pad_idx
        self.encoder = ViTEncoder(img_h, img_w, patch_h, patch_w, 3, embed_dim,
                                  enc_depth, num_heads, mlp_ratio, dropout)
        self.decoder = TransformerDecoder(vocab_size, embed_dim, dec_depth, num_heads,
                                          mlp_ratio, dropout, max_label_len, pad_idx)

    def forward(self, images, tgt_inp):
        return self.decoder(tgt_inp, self.encoder(images))

    @torch.no_grad()
    def greedy_decode(self, images, bos_idx, eos_idx, max_len=None):
        """Fast greedy decoding β€” good for batches."""
        self.eval()
        max_len   = max_len or self.max_label_len
        B, device = images.size(0), images.device
        memory    = self.encoder(images)
        generated = torch.full((B, 1), bos_idx, dtype=torch.long, device=device)
        finished  = torch.zeros(B, dtype=torch.bool, device=device)
        for _ in range(max_len):
            next_token = self.decoder(generated, memory)[:, -1, :].argmax(-1)
            next_token = torch.where(finished, torch.full_like(next_token, self.pad_idx), next_token)
            generated  = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
            finished   = finished | (next_token == eos_idx)
            if finished.all(): break
        preds = []
        for seq in generated.tolist():
            seq = seq[1:]
            if eos_idx in seq: seq = seq[:seq.index(eos_idx)]
            preds.append(seq)
        return preds

    @torch.no_grad()
    def beam_decode(self, images, bos_idx, eos_idx, beam_size=5, max_len=None):
        """Beam search decoding β€” slightly more accurate, slower."""
        self.eval()
        max_len = max_len or self.max_label_len
        device  = images.device
        B       = images.size(0)
        memory  = self.encoder(images)
        all_preds = []
        for b in range(B):
            mem   = memory[b:b+1]
            beams = [(0.0, [bos_idx])]
            completed = []
            for _ in range(max_len):
                new_beams = []
                for score, tokens in beams:
                    if tokens[-1] == eos_idx:
                        completed.append((score, tokens))
                        continue
                    seq      = torch.tensor([tokens], dtype=torch.long, device=device)
                    logits   = self.decoder(seq, mem)
                    log_prob = torch.log_softmax(logits[0, -1, :], dim=-1)
                    topk_scores, topk_ids = log_prob.topk(beam_size)
                    for s, t in zip(topk_scores.tolist(), topk_ids.tolist()):
                        new_beams.append((score + s, tokens + [t]))
                new_beams.sort(key=lambda x: x[0], reverse=True)
                beams = new_beams[:beam_size]
                if all(t[-1] == eos_idx for _, t in beams):
                    completed.extend(beams)
                    break
            else:
                # for/else: only reached when max_len ran out without the break,
                # so unfinished beams are added exactly once (no duplicates)
                completed.extend(beams)
            completed.sort(key=lambda x: x[0] / max(len(x[1]), 1), reverse=True)
            best = completed[0][1][1:]
            if eos_idx in best: best = best[:best.index(eos_idx)]
            all_preds.append(best)
        return all_preds

Load model and run inference

import json, torch
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download

REPO   = 'magles/malayalam-ocr-parseq'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Download files
ckpt_path     = hf_hub_download(REPO, 'parseq_finetuned_v2.pth')
char2idx_path = hf_hub_download(REPO, 'char2idx.json')
idx2char_path = hf_hub_download(REPO, 'idx2char.json')

# Load vocab
with open(char2idx_path, encoding='utf-8') as f:
    char2idx = json.load(f)
with open(idx2char_path, encoding='utf-8') as f:
    idx2char = {int(k): v for k, v in json.load(f).items()}

# Load model
ckpt  = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
model = PARSeqOCR(vocab_size=len(char2idx))
model.load_state_dict(ckpt['model'])
model = model.to(DEVICE)
model.eval()
print(f"Model loaded β€” vocab: {len(char2idx)}")

# Preprocess
transform = T.Compose([
    T.Resize((32, 128)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

bos_idx = char2idx['[BOS]']
eos_idx = char2idx['[EOS]']

def predict(image_path, use_beam=True, beam_size=5):
    img    = Image.open(image_path).convert('RGB')
    tensor = transform(img).unsqueeze(0).to(DEVICE)
    if use_beam:
        indices = model.beam_decode(tensor, bos_idx, eos_idx, beam_size=beam_size)[0]
    else:
        indices = model.greedy_decode(tensor, bos_idx, eos_idx)[0]
    return ''.join(idx2char.get(i, '') for i in indices)

# Single image
print(predict('your_image.jpg'))

# Batch (greedy is faster for batches)
def predict_batch(image_paths):
    imgs = torch.stack([transform(Image.open(p).convert('RGB'))
                        for p in image_paths]).to(DEVICE)
    all_seqs = model.greedy_decode(imgs, bos_idx, eos_idx)
    return [''.join(idx2char.get(i, '') for i in seq) for seq in all_seqs]

Greedy vs Beam Search

# Greedy: fast, good for batches; ~84% word accuracy with v1, ~91% with v2
indices = model.greedy_decode(tensor, bos_idx, eos_idx)[0]

# Beam search: slightly more accurate, slower (processes one image at a time internally).
# beam_size=5 is the sweet spot; larger values give no further improvement.
indices = model.beam_decode(tensor, bos_idx, eos_idx, beam_size=5)[0]

Malayalam Unicode Normalization

Some ground-truth labels use different Unicode encodings for the same visual character (e.g. the chillu ൾ = U+0D7E versus the sequence ള്‍ = U+0D33 + U+0D4D + U+200D). Normalize both prediction and ground truth before comparing:

CHILLU_MAP = {
    '\u0d7a': '\u0d23\u0d4d',
    '\u0d7b': '\u0d28\u0d4d',
    '\u0d7c': '\u0d30\u0d4d',
    '\u0d7d': '\u0d32\u0d4d',
    '\u0d7e': '\u0d33\u0d4d',
    '\u0d7f': '\u0d15\u0d4d',
}

def normalize_malayalam(text):
    text = text.strip().replace('\u200c', '').replace('\u200d', '')
    for chillu, base in CHILLU_MAP.items():
        text = text.replace(chillu, base)
    return text

# Compare
normalize_malayalam(pred) == normalize_malayalam(gt)
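As a concrete check, the atomic chillu form of a word and its base + virama + ZWJ spelling compare equal after normalization. The snippet below restates the function with just the chillu-NA entry so it is self-contained; the example word is illustrative:

```python
# Minimal repeat of the normalization above, restricted to chillu NA
CHILLU_MAP = {'\u0d7b': '\u0d28\u0d4d'}   # chillu NA -> NA + virama

def normalize_malayalam(text):
    text = text.strip().replace('\u200c', '').replace('\u200d', '')
    for chillu, base in CHILLU_MAP.items():
        text = text.replace(chillu, base)
    return text

# The same word written with the atomic chillu vs. the ZWJ sequence
atomic = '\u0d05\u0d35\u0d7b'               # ...ൻ (U+0D7B)
zwj    = '\u0d05\u0d35\u0d28\u0d4d\u200d'   # ...ന + ് + ZWJ
print(normalize_malayalam(atomic) == normalize_malayalam(zwj))  # True
```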

Training Details

Pretraining (parseq_best.pth)

  • Dataset: 950,000 synthetic Malayalam scene text images
  • Epochs: 20 | Batch: 512 | LR: 1e-3 (OneCycleLR) | AMP: fp16
  • Hardware: RTX 5090 | Time: ~3 hours

Finetuning v1 (parseq_finetuned_best.pth)

  • Dataset: 915 real images (IndicVignesh finetune_ split)
  • Epochs: 100 | Batch: 32 | LR: 1e-4
  • Val: 82 real images (val_ split)

Finetuning v2 (parseq_finetuned_v2.pth) ← recommended

  • Dataset: 915 real images (same as v1)
  • Epochs: 150 | Batch: 32 | LR: 1e-4
  • Label smoothing: 0.1, to prevent overconfidence and improve generalization
  • Scheduler: CosineAnnealingWarmRestarts (T_0=50), giving 3 LR restart cycles
  • Encoder frozen for first 10 epochs, then unfrozen
  • Best epoch: 121
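The v2 recipe above can be sketched roughly as follows. This is not the actual training script: the tiny stand-in model, batch shapes, and freezing mechanics are assumptions made so the loop runs end to end; in practice the model would be PARSeqOCR with a real dataloader.

```python
import torch
import torch.nn as nn

# Tiny stand-in model so the sketch is runnable; swap in PARSeqOCR in practice
class TinyModel(nn.Module):
    def __init__(self, vocab_size=99, dim=16):
        super().__init__()
        self.encoder = nn.Linear(8, dim)      # stands in for the ViT encoder
        self.head    = nn.Linear(dim, vocab_size)
    def forward(self, x):
        return self.head(self.encoder(x))

model     = TinyModel()
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)  # smoothing 0.1
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50)

for epoch in range(3):                        # 150 epochs in the real recipe
    # Encoder frozen for the first 10 epochs, then unfrozen
    for p in model.encoder.parameters():
        p.requires_grad = epoch >= 10
    images  = torch.randn(4, 8)               # stand-in batch
    targets = torch.randint(1, 99, (4,))      # stand-in labels
    loss = criterion(model(images), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()                          # cosine decay, restarting every 50 epochs
```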

Citation

@inproceedings{bautista2022parseq,
  title={Scene Text Recognition with Permuted Autoregressive Sequence Models},
  author={Bautista, Darwin and Atienza, Rowel},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2022}
}