MyanBERTa-BiLSTM-CRF-Joint

A joint Word Segmentation (WS) and Part-of-Speech (POS) Tagging model for the Myanmar (Burmese) language, built on top of MyanBERTa-Legal-Finetuned with an Asymmetric BiLSTM + Dual CRF head architecture.

Model Architecture

MyanBERTa-Legal (RoBERTa-based, 768-dim)
+ Position Embedding (64-dim, syllable position from sentence end)
→ Concatenated (832-dim)
→ Asymmetric BiLSTM:
    Forward LSTM:  832 → 256
    Backward LSTM: 832 → 512
→ Concatenated output (768-dim)
→ Dual Heads:
    Head 1 (WS):  Linear(768 → 4)  + CRF  [B, I, E, S]
    Head 2 (POS): Linear(768 → 68) + CRF  [B/I/E/S-UPOS]

  • Total Parameters: 112,858,104
  • Mixed Precision: FP16 (AMP)
  • Training Hardware: 2× NVIDIA Tesla T4 (DataParallel)

Training Data

  • Dataset: finalanalyzedposxpos.conllu
  • Format: CoNLL-U, converted to syllable-level BIES tags
  • Size: 103,833 sentences | 3,124,305 syllables
  • Split: 80/10/10 (Train/Val/Test)
    • Train: 83,066 sentences
    • Val: 10,383 sentences
    • Test: 10,384 sentences

Evaluation Results (Test Set)

| Task | Precision | Recall | F1 |
|---|---|---|---|
| Word Segmentation (WS) | 0.9222 | 0.9379 | 0.9300 |
| POS Tagging | 0.8702 | 0.8965 | 0.8832 |
| Combined (0.4×WS + 0.6×POS) | – | – | 0.9019 |
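The combined score is the weighted mean of the two task F1 scores; checking the arithmetic:

```python
# Combined F1 = 0.4 * WS F1 + 0.6 * POS F1
ws_f1, pos_f1 = 0.9300, 0.8832
combined = 0.4 * ws_f1 + 0.6 * pos_f1
print(round(combined, 4))  # 0.9019
```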

POS Tagging: Per-Class F1 (Test Set)

| POS Tag | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| ADJ | 0.7516 | 0.7925 | 0.7715 | 4,665 |
| ADP | 0.9351 | 0.9668 | 0.9507 | 25,367 |
| ADV | 0.7120 | 0.7646 | 0.7374 | 3,233 |
| AUX | 0.5961 | 0.6267 | 0.6110 | 1,776 |
| CCONJ | 0.8390 | 0.7859 | 0.8116 | 3,713 |
| DET | 0.8960 | 0.8870 | 0.8915 | 602 |
| INTJ | 0.5890 | 0.8600 | 0.6992 | 50 |
| NOUN | 0.8332 | 0.8651 | 0.8489 | 48,068 |
| NUM | 0.9370 | 0.9530 | 0.9449 | 4,869 |
| PART | 0.9118 | 0.9325 | 0.9220 | 44,738 |
| PRON | 0.9175 | 0.9318 | 0.9246 | 4,308 |
| PROPN | 0.6595 | 0.7156 | 0.6864 | 1,846 |
| PUNCT | 0.9803 | 0.9900 | 0.9852 | 17,331 |
| SCONJ | 0.7150 | 0.8154 | 0.7619 | 1,658 |
| SYM | 0.8269 | 0.8113 | 0.8190 | 53 |
| VERB | 0.8221 | 0.8497 | 0.8357 | 29,721 |
| X | 0.3845 | 0.4184 | 0.4007 | 521 |

Training Details

  • Base Model: sithu015/MyanBERTa-legal-finetuned
  • Tokenizer: MyanBERTa (RoBERTa-based, syllable-level)
  • Max Sequence Length: 300
  • Batch Size: 48 (total, 2 GPUs)
  • Epochs Trained: 12 (early stop, patience=3)
  • Best Val Combined F1: 0.9010 (Epoch 9)
  • Total Training Time: ~7.8 hours
  • Optimizer: AdamW
    • BERT layers: lr = 2e-5
    • BiLSTM / heads / CRF: lr = 1e-3
  • Scheduler: Linear warmup (10% steps)
  • Loss: CRF loss + 0.3 Γ— Cross-Entropy auxiliary loss (both heads)
  • Class Weighting: WeightedRandomSampler (4Γ— boost for rare POS tags: AUX, INTJ, SYM, PROPN, SCONJ, DET, X)

Training Progress (Validation Combined F1 per Epoch)

| Epoch | Loss | WS F1 | POS F1 | Combined F1 |
|---|---|---|---|---|
| 1 | 73.6700 | 0.8871 | 0.7974 | 0.8333 |
| 2 | 25.0627 | 0.9055 | 0.8418 | 0.8673 |
| 3 | 18.7039 | 0.9180 | 0.8636 | 0.8853 |
| 4 | 15.6097 | 0.9213 | 0.8656 | 0.8879 |
| 5 | 13.4530 | 0.9249 | 0.8718 | 0.8930 |
| 6 | 11.9508 | 0.9268 | 0.8762 | 0.8964 |
| 7 | 10.8137 | 0.9283 | 0.8787 | 0.8985 |
| 8 | 9.6801 | 0.9287 | 0.8811 | 0.9001 |
| 9 | 8.9287 | 0.9293 | 0.8821 | 0.9010 ✓ Best |
| 10 | 8.2989 | 0.9290 | 0.8794 | 0.8992 |
| 11 | 10.0370 | 0.9231 | 0.8679 | 0.8900 |
| 12 | 12.8928 | 0.9208 | 0.8631 | 0.8862 |

Usage

This model uses a custom architecture and is not compatible with the standard transformers `pipeline()`. Inference requires the `JointSegPosModel` class from the training code and a syllable segmentation function. The example below uses the model files provided in this repository.

```python
import torch
import json
from transformers import AutoTokenizer

# --- Define syllable_segment (example: split on space, or use your own segmenter) ---
def syllable_segment(text):
    """Split text into syllable-level tokens. Replace with your actual syllable segmenter."""
    return text.split()

# Load label maps
with open("wsid2label.json") as f:
    wsid2label = {int(k): v for k, v in json.load(f).items()}
with open("posid2label.json") as f:
    posid2label = {int(k): v for k, v in json.load(f).items()}
with open("wslabel2id.json") as f:
    wslabel2id = json.load(f)
with open("poslabel2id.json") as f:
    poslabel2id = json.load(f)

tokenizer = AutoTokenizer.from_pretrained("UCSYNLP/MyanBERTa")

# Build model (JointSegPosModel class required from training code)
model = JointSegPosModel(
    "sithu015/MyanBERTa-legal-finetuned",
    num_ws_labels=len(wslabel2id),
    num_pos_labels=len(poslabel2id)
)
model.load_state_dict(torch.load("bestmodel.pt", map_location="cpu"))
model.eval()

# Tokenize and predict
syllables = syllable_segment("ကို၏ နာမည် မှာ ကိုကို ဖြစ်သည်")  # syllable-level list
encoding = tokenizer(
    syllables,
    is_split_into_words=True,
    return_tensors="pt",
    truncation=True,
    max_length=300,
    padding="max_length"
)

with torch.no_grad():
    ws_preds, pos_preds = model(encoding["input_ids"], encoding["attention_mask"])
```
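The CRF decode yields syllable-level BIES tags; to recover (word, POS) pairs, consecutive syllables are merged at each `E` or `S` boundary. A minimal illustrative decoder (`merge_bies` is a hypothetical helper; it assumes well-formed tag sequences and POS tags shaped like `B-NOUN`):

```python
def merge_bies(syllables, ws_tags, pos_tags):
    """Merge syllable-level BIES predictions into (word, POS) pairs.
    Illustrative sketch: assumes well-formed tag sequences and
    POS tags of the form 'B-NOUN', 'E-NOUN', 'S-ADP', ..."""
    words, buf, pos = [], [], None
    for syl, ws, pt in zip(syllables, ws_tags, pos_tags):
        buf.append(syl)
        pos = pt.split("-", 1)[1] if "-" in pt else pt
        if ws in ("E", "S"):               # word boundary reached
            words.append(("".join(buf), pos))
            buf = []
    return words

print(merge_bies(["နာ", "မည်", "မှာ"],
                 ["B", "E", "S"],
                 ["B-NOUN", "E-NOUN", "S-ADP"]))
# [('နာမည်', 'NOUN'), ('မှာ', 'ADP')]
```

Note that in practice you would first map the predicted IDs through `wsid2label` / `posid2label` and align subword predictions back to syllables via the tokenizer's word IDs before merging.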

Files

| File | Description |
|---|---|
| bestmodel.pt | Trained model weights (PyTorch) |
| wslabel2id.json | WS label → ID mapping |
| wsid2label.json | WS ID → label mapping |
| poslabel2id.json | POS label → ID mapping |
| posid2label.json | POS ID → label mapping |
| modelmetadata.json | Training metadata and results |
| config.json | Model configuration |

Citation

If you use this model, please cite:

@misc{sithu015-myanberta-bilstm-crf-joint-2026,
  author    = {Sithu Aung},
  title     = {MyanBERTa-BiLSTM-CRF-Joint: Joint Word Segmentation and POS Tagging for Myanmar},
  year      = {2026},
  publisher = {Hugging Face},
  url       = {https://huggingface.co/sithu015/MyanBERTa-BiLSTM-CRF-Joint}
}

License

Apache 2.0
