# Wiki-Test2 / tokenization_binaryllm.py
# Author: PhysiQuanty
# Duplicated from PhysiQuanty/Patent-Test-Radix-65536-AutoTokenizer_FineTune
# (source revision: efda231)
#!/usr/bin/env python3
# tokenization_binaryllm.py
# ============================================================
# BinaryLLMTokenizer (AutoTokenizer compatible) — EXACTLY the same
# tokenization/decoding as llmTalk (base=65536 mode) + infer_tagged12/11:
#
# - Base: 65536
# - Radix IDs: 0..65535
# - BOS: 65536
# - EOS: 65537
# - UNK: alias of EOS (65537) (no new token beyond the base)
# - Encoding: UTF-8 bytes -> base-65536 digits, BIG-ENDIAN (2-byte chunks)
#   * if the byte length is odd: the last byte is encoded as a single
#     digit with value 0..255
# - Decoding: digits -> BIG-ENDIAN bytes -> UTF-8 (errors="replace")
#
# Important:
# - build_inputs_with_special_tokens: [BOS] + seq + [EOS] (classic HF style)
# - encode(..., add_special_tokens=False) returns ONLY the base-65536 digits
# - encode(..., add_special_tokens=True) adds BOS/EOS via build_inputs...
#
# This file alone is sufficient for `trust_remote_code=True` on the HF repo.
# ============================================================
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional, Tuple, Any
from transformers import PreTrainedTokenizer
class BinaryLLMTokenizer(PreTrainedTokenizer):
    """Radix-65536 tokenizer, AutoTokenizer-compatible (trust_remote_code).

    Encoding: UTF-8 bytes are packed big-endian into base-65536 digits
    (two bytes per digit); an odd trailing byte becomes a single digit in
    0..255.  Decoding reverses the packing and decodes UTF-8 with
    ``errors="replace"``.

    ID layout (kept strictly identical to llmTalk base=65536 / infer_tagged):
      * 0..65535 -> base digits, rendered as ``<UXXXX>`` (upper-case hex)
      * 65536    -> BOS
      * 65537    -> EOS (UNK and PAD are aliases of EOS; no extra ids)
    """

    model_input_names = ["input_ids", "attention_mask"]
    # String form of a base digit, e.g. "<U00FF>".
    TOKEN_RE = re.compile(r"^<U([0-9A-Fa-f]{4})>$")

    def __init__(
        self,
        bos_token: str = "<BOS>",
        eos_token: str = "<EOS>",
        unk_token: str = "<UNK>",
        pad_token: Optional[str] = None,
        **kwargs: Any,
    ):
        # Strict radix: ids 0..65535 are reserved for the base digits.
        self._base_vocab_size = 65536
        # Specials sit right after the base: base + 0/1.
        self._bos_id = 65536
        self._eos_id = 65537
        # UNK aliases EOS (no additional token in the layout).
        self._unk_id = self._eos_id
        self._bos_str = bos_token
        self._eos_str = eos_token
        self._unk_str = unk_token
        self._pad_str = pad_token
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )

    # ---------- vocab / ids ----------
    @property
    def vocab_size(self) -> int:
        # 65536 base digits + BOS + EOS (UNK/PAD are aliases, not new ids).
        return 65538

    def get_vocab(self) -> Dict[str, int]:
        """Return the FULL token -> id mapping (base digits + specials).

        Fix: the base digits were previously omitted, which made
        ``get_vocab()`` inconsistent with ``vocab_size`` and broke HF
        utilities that rely on the complete mapping (e.g. membership
        checks in ``add_tokens`` and vocab export helpers).
        IMPORTANT: never call ``self.unk_token_id`` here — it resolves
        through the vocab and would recurse.
        """
        v = {self._id_to_token_base(i): i for i in range(self._base_vocab_size)}
        v[self._bos_str] = self._bos_id
        v[self._eos_str] = self._eos_id
        v[self._unk_str] = self._unk_id
        if self.pad_token is not None:
            v[self.pad_token] = self._convert_token_to_id(self.pad_token)
        return v

    def _id_to_token_base(self, i: int) -> str:
        # Canonical string form of base digit i.
        return f"<U{i:04X}>"

    # ---------- core encode/decode (same logic as infer_tagged / llmTalk base) ----------
    def _encode_to_base65536_big_endian(self, text: str) -> List[int]:
        """UTF-8 encode *text* and pack the bytes into base-65536 digits.

        Two consecutive bytes form one big-endian digit; a trailing odd
        byte is emitted as a single digit in 0..255.  The empty string is
        represented as the single digit 0 (kept for llmTalk compatibility).
        """
        b = bytearray(text.encode("utf-8", errors="strict"))
        if len(b) == 0:
            return [0]
        out: List[int] = []
        i = 0
        n = len(b)
        while i + 1 < n:
            # 2 bytes -> 1 base-65536 digit, big-endian.
            out.append((b[i] << 8) | b[i + 1])
            i += 2
        if i < n:
            # Lone last byte -> digit 0..255.
            out.append(int(b[i]))
        return out

    def _decode_from_base65536_big_endian(self, ids: List[int]) -> str:
        """Unpack base-65536 digits back into bytes and decode as UTF-8.

        Digits <= 255 yield one byte (the odd-length rule); larger digits
        yield two big-endian bytes.  NOTE(review): a 2-byte chunk whose
        high byte is 0x00 (text containing U+0000) packs to a digit <= 255
        and loses the NUL on round-trip; this matches the upstream llmTalk
        scheme and is deliberately left unchanged.
        """
        bb = bytearray()
        for x in ids:
            xi = int(x) & 0xFFFFFFFF
            if 0 <= xi <= 255:
                bb.append(xi)
            else:
                bb.append((xi >> 8) & 0xFF)
                bb.append(xi & 0xFF)
        return bytes(bb).decode("utf-8", errors="replace")

    # ---------- HF tokenizer API overrides ----------
    def _tokenize(self, text: str) -> List[str]:
        # Text -> digit ids -> canonical "<UXXXX>" token strings.
        ids = self._encode_to_base65536_big_endian(text)
        return [self._id_to_token_base(i) for i in ids]

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token string to its id; unknown strings map to UNK (= EOS)."""
        if token == self._bos_str:
            return self._bos_id
        if token == self._eos_str:
            return self._eos_id
        if token == self._unk_str:
            return self._unk_id
        if self.pad_token is not None and token == self.pad_token:
            # No dedicated PAD id in this layout: PAD aliases EOS.  (The
            # original nested check was redundant — both paths returned
            # the EOS id.)
            return self._eos_id
        m = self.TOKEN_RE.match(token)
        if m:
            return int(m.group(1), 16)
        return self._unk_id

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its token string; out-of-range ids map to UNK."""
        if index == self._bos_id:
            return self._bos_str
        if index == self._eos_id:
            return self._eos_str
        if index == self._unk_id:
            return self._unk_str
        if self.pad_token is not None and index == self.pad_token_id:
            return self.pad_token
        if 0 <= index < self._base_vocab_size:
            return self._id_to_token_base(index)
        return self._unk_str

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Drop all special tokens, then decode the remaining base digits."""
        ids: List[int] = []
        for t in tokens:
            if t in (self._bos_str, self._eos_str, self._unk_str):
                continue
            if self.pad_token is not None and t == self.pad_token:
                continue
            m = self.TOKEN_RE.match(t)
            if m:
                ids.append(int(m.group(1), 16))
        return self._decode_from_base65536_big_endian(ids)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        # HF-style (simple): [BOS] seq [EOS]
        # Pair: [BOS] seq0 [EOS] seq1 [EOS]
        if token_ids_1 is None:
            return [self._bos_id] + token_ids_0 + [self._eos_id]
        return [self._bos_id] + token_ids_0 + [self._eos_id] + token_ids_1 + [self._eos_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return 1 for special-token positions, 0 for regular digits."""
        pad_id = self.pad_token_id if self.pad_token is not None else -1
        if already_has_special_tokens:
            return [
                1 if t in (self._bos_id, self._eos_id, self._unk_id, pad_id) else 0
                for t in token_ids_0
            ]
        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]
        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        # Single-segment model: every position (specials included) is type 0.
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + len(token_ids_1) + 3)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write a small JSON description of the radix scheme.

        The vocabulary is fully procedural (no lookup table), so only the
        parameters needed to reconstruct it are persisted.  Returns the
        1-tuple of the written file path, per the HF convention.
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        name = (filename_prefix + "-" if filename_prefix else "") + "binaryllm_vocab.json"
        path = os.path.join(save_directory, name)
        data = {
            "base_vocab_size": 65536,
            "vocab_size": 65538,
            "bos_token": self._bos_str,
            "bos_token_id": self._bos_id,
            "eos_token": self._eos_str,
            "eos_token_id": self._eos_id,
            "unk_token": self._unk_str,
            "unk_token_id": self._unk_id,
            "pad_token": self.pad_token,
            "pad_token_id": self.pad_token_id,
            "encoding": "utf-8",
            "radix": 65536,
            "endianness": "big",
            "odd_length_rule": "last_byte_as_single_digit_0_255",
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        return (path,)