| """
|
| Panini Tokenizer V3 - Morphology-Aware Sanskrit Tokenizer
|
| HuggingFace PreTrainedTokenizer compatible.
|
| """
|
|
|
| import json
|
| import os
|
| from typing import Dict, List, Optional, Tuple, Union
|
| from collections import OrderedDict
|
|
|
|
|
| try:
|
| from transformers import PreTrainedTokenizer
|
| from transformers.tokenization_utils_base import AddedToken
|
| HAS_TRANSFORMERS = True
|
| except ImportError:
|
| HAS_TRANSFORMERS = False
|
| PreTrainedTokenizer = object
|
|
|
| from .analyzer import VidyutAnalyzer, MorphParse
|
| from .splitter import SamasaSplitter, CompoundSplit
|
|
|
|
|
class PaniniTokenizerV3(PreTrainedTokenizer if HAS_TRANSFORMERS else object):
    """Morphology-aware Sanskrit tokenizer using Vidyut.

    Pipeline:
        1. Vidyut analysis     -> extract morphological structure
        2. Compound splitting  -> split at samāsa boundaries
        3. Vibhakti separation -> separate inflection from stem
        4. Dynamic vocab       -> Kosha-backed vocabulary
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        pad_token: str = "<pad>",
        sep_token: str = "<sep>",
        cls_token: str = "<cls>",
        mask_token: str = "<mask>",
        add_prefix_space: bool = True,
        freeze_vocab: bool = False,
        **kwargs
    ):
        self.add_prefix_space = add_prefix_space
        self.freeze_vocab = freeze_vocab

        # Morphological backends (project-local modules).
        self.analyzer = VidyutAnalyzer(preload_cache=True)
        self.splitter = SamasaSplitter(self.analyzer)

        # token -> id and id -> token maps, kept in sync at all times.
        self._vocab: Dict[str, int] = {}
        self._id_to_token: Dict[int, str] = {}

        # Prefer an on-disk vocab when one is supplied and exists;
        # otherwise seed the vocabulary with special tokens + morphemes.
        if vocab_file and os.path.exists(vocab_file):
            self._load_vocab(vocab_file)
        else:
            self._build_initial_vocab()

        # Initialise the HF base class last: it may inspect the vocab
        # (vocab_size / get_vocab) during its own setup.
        if HAS_TRANSFORMERS:
            super().__init__(
                unk_token=unk_token,
                bos_token=bos_token,
                eos_token=eos_token,
                pad_token=pad_token,
                sep_token=sep_token,
                cls_token=cls_token,
                mask_token=mask_token,
                add_prefix_space=add_prefix_space,
                **kwargs
            )
|
|
|
| def _build_initial_vocab(self):
|
| """Build initial vocabulary with special tokens and common morphemes."""
|
|
|
| special = ["<unk>", "<s>", "</s>", "<pad>", "<sep>", "<cls>", "<mask>", "▁"]
|
| for i, tok in enumerate(special):
|
| self._vocab[tok] = i
|
| self._id_to_token[i] = tok
|
|
|
|
|
| vibhaktis = [
|
| "H", "m", "am", "At", "Aya", "asya", "e", "O", "ayoH",
|
| "AH", "An", "eByo", "EH", "ezu", "ena", "ABym",
|
| "A", "AyAH", "AyAm", "ayA", "Ani", "AnAm",
|
| "sya", "ya", "aH", "iH", "uH",
|
| ]
|
|
|
|
|
| pratyayas = [
|
| "tvA", "ya", "ta", "tavat", "at", "Ana", "tum",
|
| "ti", "ana", "aka", "in", "tf", "tva", "tA",
|
| "maya", "vat", "mat", "ika", "Iya",
|
| ]
|
|
|
|
|
| upasargas = [
|
| "pra", "parA", "apa", "sam", "anu", "ava", "nis", "nir",
|
| "vi", "A", "ni", "aDi", "api", "ati", "su", "ut", "ud",
|
| "aBi", "prati", "pari", "upa", "dur", "dus",
|
| ]
|
|
|
|
|
| next_id = len(self._vocab)
|
| for morpheme_list in [vibhaktis, pratyayas, upasargas]:
|
| for m in morpheme_list:
|
| if m not in self._vocab:
|
| self._vocab[m] = next_id
|
| self._id_to_token[next_id] = m
|
| next_id += 1
|
|
|
| spaced = "▁" + m
|
| if spaced not in self._vocab:
|
| self._vocab[spaced] = next_id
|
| self._id_to_token[next_id] = spaced
|
| next_id += 1
|
|
|
| print(f" PaniniTokenizerV3: Initial vocab size = {len(self._vocab)}")
|
|
|
| def _load_vocab(self, vocab_file: str):
|
| """Load vocabulary from JSON file."""
|
| with open(vocab_file, "r", encoding="utf-8") as f:
|
| self._vocab = json.load(f)
|
| self._id_to_token = {v: k for k, v in self._vocab.items()}
|
| print(f" PaniniTokenizerV3: Loaded vocab size = {len(self._vocab)}")
|
|
|
| def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| """Save vocabulary to directory."""
|
| if not os.path.isdir(save_directory):
|
| os.makedirs(save_directory, exist_ok=True)
|
|
|
| vocab_file = os.path.join(
|
| save_directory,
|
| (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
|
| )
|
|
|
| with open(vocab_file, "w", encoding="utf-8") as f:
|
| json.dump(self._vocab, f, ensure_ascii=False, indent=2)
|
|
|
| return (vocab_file,)
|
|
|
| def save_pretrained(self, save_directory: str, **kwargs):
|
| """
|
| Save the tokenizer to a directory (HuggingFace compatible).
|
| Creates: vocab.json, tokenizer_config.json, special_tokens_map.json
|
| """
|
| os.makedirs(save_directory, exist_ok=True)
|
|
|
|
|
| vocab_file = os.path.join(save_directory, "vocab.json")
|
| with open(vocab_file, "w", encoding="utf-8") as f:
|
| json.dump(self._vocab, f, ensure_ascii=False, indent=2)
|
|
|
|
|
| config = {
|
| "tokenizer_class": "PaniniTokenizerV3",
|
| "vocab_size": len(self._vocab),
|
| "unk_token": "<unk>",
|
| "bos_token": "<s>",
|
| "eos_token": "</s>",
|
| "pad_token": "<pad>",
|
| "sep_token": "<sep>",
|
| "cls_token": "<cls>",
|
| "mask_token": "<mask>",
|
| "add_prefix_space": self.add_prefix_space,
|
| "freeze_vocab": self.freeze_vocab,
|
| }
|
| config_file = os.path.join(save_directory, "tokenizer_config.json")
|
| with open(config_file, "w", encoding="utf-8") as f:
|
| json.dump(config, f, ensure_ascii=False, indent=2)
|
|
|
|
|
| special_tokens = {
|
| "unk_token": "<unk>",
|
| "bos_token": "<s>",
|
| "eos_token": "</s>",
|
| "pad_token": "<pad>",
|
| "sep_token": "<sep>",
|
| "cls_token": "<cls>",
|
| "mask_token": "<mask>",
|
| }
|
| special_file = os.path.join(save_directory, "special_tokens_map.json")
|
| with open(special_file, "w", encoding="utf-8") as f:
|
| json.dump(special_tokens, f, ensure_ascii=False, indent=2)
|
|
|
| print(f"✅ Saved PaniniTokenizerV3 to {save_directory}/")
|
| print(f" vocab.json: {len(self._vocab)} tokens")
|
| return save_directory
|
|
|
| @classmethod
|
| def from_pretrained(cls, pretrained_path: str, **kwargs):
|
| """
|
| Load a tokenizer from a directory (HuggingFace compatible).
|
| """
|
| vocab_file = os.path.join(pretrained_path, "vocab.json")
|
| config_file = os.path.join(pretrained_path, "tokenizer_config.json")
|
|
|
|
|
| config = {}
|
| if os.path.exists(config_file):
|
| with open(config_file, "r", encoding="utf-8") as f:
|
| config = json.load(f)
|
|
|
|
|
| tokenizer = cls(
|
| vocab_file=vocab_file,
|
| freeze_vocab=config.get("freeze_vocab", True),
|
| add_prefix_space=config.get("add_prefix_space", True),
|
| **kwargs
|
| )
|
|
|
| print(f"✅ Loaded PaniniTokenizerV3 from {pretrained_path}/")
|
| print(f" vocab.json: {len(tokenizer._vocab)} tokens")
|
| return tokenizer
|
|
|
| @property
|
| def vocab_size(self) -> int:
|
| return len(self._vocab)
|
|
|
| def get_vocab(self) -> Dict[str, int]:
|
| return dict(self._vocab)
|
|
|
| def _add_to_vocab(self, token: str) -> int:
|
| """Dynamically add a token to vocabulary."""
|
| if token in self._vocab:
|
| return self._vocab[token]
|
|
|
| new_id = len(self._vocab)
|
| self._vocab[token] = new_id
|
| self._id_to_token[new_id] = token
|
| return new_id
|
|
|
| def _convert_token_to_id(self, token: str) -> int:
|
| """Convert token to ID, adding to vocab if needed (dynamic vocab)."""
|
| if token in self._vocab:
|
| return self._vocab[token]
|
|
|
|
|
| if self.freeze_vocab:
|
| return self._vocab.get("<unk>", 0)
|
|
|
|
|
| return self._add_to_vocab(token)
|
|
|
| def _convert_id_to_token(self, index: int) -> str:
|
| """Convert ID to token."""
|
| return self._id_to_token.get(index, self.unk_token)
|
|
|
| def _tokenize_word(self, word: str) -> List[str]:
|
| """
|
| Tokenize a single word using morphological analysis.
|
|
|
| New Grammar-Safe Pipeline (Rule A, B, C):
|
| 1. Parse with Vidyut (Collapse spines)
|
| 2. Iterative Samasa Splitting
|
| 3. No SP fallback for valid stems
|
| """
|
| if not word:
|
| return []
|
|
|
|
|
|
|
| if self.analyzer._is_verb_form(word):
|
| return ["▁" + word]
|
|
|
|
|
| parse = self.analyzer.get_best_parse(word)
|
| stem = parse.token_form()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| final_tokens = []
|
|
|
|
|
|
|
| current_components = [stem]
|
|
|
|
|
| def merge_known_compounds(parts):
|
| """Merge adjacent parts that together form a known compound."""
|
| merged = []
|
| i = 0
|
| while i < len(parts):
|
| if i + 1 < len(parts):
|
|
|
| left = parts[i]
|
| right = parts[i + 1]
|
|
|
| if left.endswith('A'):
|
| candidate = left[:-1] + 'a' + right
|
| else:
|
| candidate = left + right
|
|
|
|
|
|
|
| candidates = [candidate]
|
| if left.endswith('A') and not right.startswith(('a', 'A', 'i', 'I', 'u', 'U', 'e', 'E', 'o', 'O')):
|
|
|
| candidates.append(left + 'A' + right)
|
| if self.analyzer._in_kosha(candidate):
|
| merged.append(candidate)
|
| i += 2
|
| continue
|
|
|
| atman_candidate = left[:-1] + 'an' if left.endswith('A') else left + 'an'
|
| if right.endswith('A'):
|
| atman_full = atman_candidate + right[:-1] + 'a'
|
| else:
|
| atman_full = atman_candidate
|
| if len(atman_candidate) > 3 and self.analyzer._in_kosha(atman_candidate):
|
| merged.append(atman_candidate)
|
|
|
| merged.append(right)
|
| i += 2
|
| continue
|
| merged.append(parts[i])
|
| i += 1
|
| return merged
|
|
|
|
|
| MAX_PASSES = 6
|
| for _ in range(MAX_PASSES):
|
| new_components = []
|
| changed = False
|
|
|
|
|
| for comp in current_components:
|
|
|
| split_res = self.splitter.split(comp)
|
| if split_res.is_compound and len(split_res.components) > 1:
|
| new_components.extend(split_res.components)
|
| changed = True
|
| else:
|
|
|
|
|
|
|
|
|
| if (len(comp) > 3 and
|
| comp[0] not in 'aAiIuUeEoO' and
|
| not self.splitter._is_valid_stem(comp)):
|
| restored = 'A' + comp
|
| restored_res = self.splitter.split(restored)
|
| if restored_res.is_compound and len(restored_res.components) > 1:
|
|
|
| new_components.extend(restored_res.components)
|
| changed = True
|
| continue
|
| new_components.append(comp)
|
|
|
|
|
| merged_components = merge_known_compounds(new_components)
|
| if len(merged_components) != len(new_components):
|
| changed = True
|
|
|
| if not changed:
|
| break
|
| current_components = merged_components
|
|
|
|
|
| for i, comp in enumerate(current_components):
|
|
|
|
|
|
|
|
|
| prefix = "▁" if i == 0 else ""
|
|
|
| if self.analyzer._in_kosha(comp):
|
|
|
| final_tokens.append(prefix + comp)
|
| else:
|
|
|
|
|
| final_tokens.append(prefix + comp)
|
|
|
|
|
|
|
| if parse.vibhakti and final_tokens:
|
| last_token = final_tokens[-1].lstrip('▁')
|
|
|
| if not last_token.endswith(parse.vibhakti):
|
| final_tokens.append(parse.vibhakti)
|
|
|
| return final_tokens
|
|
|
| def tokenize(self, text: str, **kwargs) -> List[str]:
|
| """
|
| Tokenize text into morphological tokens.
|
|
|
| This is the main entry point for tokenization.
|
| """
|
| if not text:
|
| return []
|
|
|
|
|
| words = text.split()
|
|
|
| all_tokens = []
|
| for i, word in enumerate(words):
|
| word_tokens = self._tokenize_word(word)
|
| all_tokens.extend(word_tokens)
|
|
|
| return all_tokens
|
|
|
| def _encode_impl(self, text: str) -> List[int]:
|
| """Internal encode implementation."""
|
| tokens = self.tokenize(text)
|
| return [self._convert_token_to_id(t) for t in tokens]
|
|
|
| def encode(
|
| self,
|
| text: Union[str, List[str]],
|
| add_special_tokens: bool = True,
|
| **kwargs
|
| ) -> List[int]:
|
| """Encode text to token IDs."""
|
| if isinstance(text, list):
|
| text = " ".join(text)
|
|
|
| ids = self._encode_impl(text)
|
|
|
| if add_special_tokens:
|
| bos_id = self._vocab.get("<s>", 1)
|
| eos_id = self._vocab.get("</s>", 2)
|
| ids = [bos_id] + ids + [eos_id]
|
|
|
| return ids
|
|
|
| def decode(
|
| self,
|
| token_ids: List[int],
|
| skip_special_tokens: bool = True,
|
| **kwargs
|
| ) -> str:
|
| """Decode token IDs back to text."""
|
| special_ids = {0, 1, 2, 3, 4, 5, 6}
|
|
|
| tokens = []
|
| for tid in token_ids:
|
| if skip_special_tokens and tid in special_ids:
|
| continue
|
| token = self._convert_id_to_token(tid)
|
| tokens.append(token)
|
|
|
|
|
| text = ""
|
| for t in tokens:
|
| if t.startswith("▁"):
|
| text += " " + t[1:]
|
| else:
|
| text += t
|
|
|
| return text.strip()
|
|
|
| def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| """Convert token list back to string."""
|
| text = ""
|
| for t in tokens:
|
| if t.startswith("▁"):
|
| text += " " + t[1:]
|
| else:
|
| text += t
|
| return text.strip()
|
|
|
|
|
|
|
def create_tokenizer(vocab_path: Optional[str] = None) -> PaniniTokenizerV3:
    """Factory helper: build a PaniniTokenizerV3, optionally from a vocab file."""
    return PaniniTokenizerV3(vocab_file=vocab_path)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: tokenize, encode, and round-trip a few sample inputs.
    print("\n" + "=" * 60)
    print(" Testing PaniniTokenizerV3")
    print("=" * 60)

    tokenizer = PaniniTokenizerV3()

    test_cases = [
        "rAmaH gacCati",
        "hfdpadmagataM paramAtma",
        "sopAdhikapratyagAtmAbhAsabhedAbhedavicAraH",
    ]

    for text in test_cases:
        tokens = tokenizer.tokenize(text)
        ids = tokenizer.encode(text, add_special_tokens=False)
        decoded = tokenizer.decode(ids)

        print(f"\n Input: {text}")
        print(f" Tokens: {tokens}")
        if len(ids) > 10:
            print(f" IDs: {ids[:10]}...")
        else:
            print(f" IDs: {ids}")
        print(f" Decoded: {decoded}")
|
|
|