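"""A byte-level tokenizer for Hugging Face ``transformers``.

Text is encoded as UTF-8 and the byte stream is packed, big-endian, into
base-65536 digits: ids 0..65535 cover every two-byte pair, a trailing odd
byte is kept as a single digit in 0..255, and BOS (65536) / EOS (65537)
sit just above the base range. UNK, and PAD when configured, reuse the
EOS id so the vocabulary stays at 65,538 entries.
"""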
from __future__ import annotations

import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer


class BinaryLLMTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask"]

    # Base-vocabulary tokens are rendered as "<U0000>" .. "<UFFFF>".
    TOKEN_RE = re.compile(r"^<U([0-9A-Fa-f]{4})>$")

    def __init__(
        self,
        bos_token: str = "<BOS>",
        eos_token: str = "<EOS>",
        unk_token: str = "<UNK>",
        pad_token: Optional[str] = None,
        **kwargs: Any,
    ):
        # One id per base-65536 digit (a big-endian pair of UTF-8 bytes).
        self._base_vocab_size = 65536

        # BOS and EOS sit directly above the base range.
        self._bos_id = 65536
        self._eos_id = 65537

        # UNK deliberately shares the EOS id, keeping the vocabulary at 65,538.
        self._unk_id = self._eos_id

        self._bos_str = bos_token
        self._eos_str = eos_token
        self._unk_str = unk_token
        self._pad_str = pad_token

        # The attributes above must exist before the parent constructor runs,
        # since it may call get_vocab() / _convert_token_to_id() while it
        # registers the special tokens.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        # 65,536 base digits plus BOS and EOS (UNK and PAD reuse the EOS id).
        return self._base_vocab_size + 2

    def get_vocab(self) -> Dict[str, int]:
        # Materialise the implicit base vocabulary plus the special tokens;
        # PreTrainedTokenizer.get_vocab() is expected to return the full
        # token -> id mapping.
        v = {self._id_to_token_base(i): i for i in range(self._base_vocab_size)}
        v[self._bos_str] = self._bos_id
        v[self._eos_str] = self._eos_id
        v[self._unk_str] = self._unk_id
        if self.pad_token is not None:
            v[self.pad_token] = self._convert_token_to_id(self.pad_token)
        return v

    def _id_to_token_base(self, i: int) -> str:
        return f"<U{i:04X}>"

    def _encode_to_base65536_big_endian(self, text: str) -> List[int]:
        b = text.encode("utf-8", errors="strict")
        if len(b) == 0:
            return []

        out: List[int] = []
        i = 0
        n = len(b)

        # Pack consecutive byte pairs, big-endian, into one 16-bit digit each.
        while i + 1 < n:
            out.append((b[i] << 8) | b[i + 1])
            i += 2

        # A trailing odd byte is kept as a single digit in 0..255.
        if i < n:
            out.append(int(b[i]))

        return out

    def _decode_from_base65536_big_endian(self, ids: List[int]) -> str:
        bb = bytearray()
        last = len(ids) - 1
        for pos, x in enumerate(ids):
            xi = int(x) & 0xFFFF
            if pos == last and xi <= 255:
                # Per the odd-length rule, only the final digit may stand for
                # a single trailing byte; this is ambiguous only when the last
                # byte pair of the text starts with 0x00 (a NUL character).
                bb.append(xi)
            else:
                bb.append((xi >> 8) & 0xFF)
                bb.append(xi & 0xFF)
        return bytes(bb).decode("utf-8", errors="replace")

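    # Worked example of the packing scheme (added for illustration): "abc"
    # encodes to the UTF-8 bytes 0x61 0x62 0x63, which pack into the digits
    # [0x6162, 0x63]; decoding treats the final small digit as the lone
    # trailing byte and restores the original three bytes.
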
    def _tokenize(self, text: str) -> List[str]:
        # Each base-65536 digit becomes one "<UXXXX>" token string.
        ids = self._encode_to_base65536_big_endian(text)
        return [self._id_to_token_base(i) for i in ids]

    def _convert_token_to_id(self, token: str) -> int:
        if token == self._bos_str:
            return self._bos_id
        if token == self._eos_str:
            return self._eos_id
        if token == self._unk_str:
            return self._unk_id

        # There is no dedicated PAD slot in the vocabulary, so a configured
        # pad token is mapped onto the EOS id.
        if self.pad_token is not None and token == self.pad_token:
            return self._eos_id

        m = self.TOKEN_RE.match(token)
        if m:
            return int(m.group(1), 16)

        return self._unk_id

    def _convert_id_to_token(self, index: int) -> str:
        if index == self._bos_id:
            return self._bos_str
        if index == self._eos_id:
            return self._eos_str
        if index == self._unk_id:
            return self._unk_str

        if self.pad_token is not None and index == self.pad_token_id:
            return self.pad_token

        if 0 <= index < self._base_vocab_size:
            return self._id_to_token_base(index)

        # Ids outside the known range fall back to the UNK string.
        return self._unk_str

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Drop special tokens, then unpack the remaining base-65536 digits.
        ids: List[int] = []
        for t in tokens:
            if t in (self._bos_str, self._eos_str, self._unk_str):
                continue
            if self.pad_token is not None and t == self.pad_token:
                continue
            m = self.TOKEN_RE.match(t)
            if m:
                ids.append(int(m.group(1), 16))
        return self._decode_from_base65536_big_endian(ids)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        # Single sequence: <BOS> A <EOS>
        # Sequence pair:   <BOS> A <EOS> B <EOS>
        if token_ids_1 is None:
            return [self._bos_id] + token_ids_0 + [self._eos_id]
        return [self._bos_id] + token_ids_0 + [self._eos_id] + token_ids_1 + [self._eos_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        pad_id = self.pad_token_id if self.pad_token is not None else -1

        if already_has_special_tokens:
            return [
                1 if t in (self._bos_id, self._eos_id, self._unk_id, pad_id) else 0
                for t in token_ids_0
            ]

        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]
        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        # All positions get segment id 0; the lengths match the sequences
        # produced by build_inputs_with_special_tokens.
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + len(token_ids_1) + 3)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)

        name = (filename_prefix + "-" if filename_prefix else "") + "binaryllm_vocab.json"
        path = os.path.join(save_directory, name)

        # The vocabulary itself is implicit, so only the encoding metadata is
        # written out.
        data = {
            "base_vocab_size": 65536,
            "vocab_size": 65538,
            "bos_token": self._bos_str,
            "bos_token_id": self._bos_id,
            "eos_token": self._eos_str,
            "eos_token_id": self._eos_id,
            "unk_token": self._unk_str,
            "unk_token_id": self._unk_id,
            "pad_token": self.pad_token,
            "pad_token_id": self.pad_token_id,
            "encoding": "utf-8",
            "radix": 65536,
            "endianness": "big",
            "odd_length_rule": "last_byte_as_single_digit_0_255",
        }

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        return (path,)
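

# A short usage sketch (added for illustration, not part of the original
# module): it round-trips a small string through the standard tokenizer API
# to show the packing scheme; the exact id values assume the defaults above.
if __name__ == "__main__":
    tok = BinaryLLMTokenizer()

    text = "hi!"  # UTF-8 bytes 0x68 0x69 0x21 -> digits 0x6869 and 0x21
    enc = tok(text)

    # With add_special_tokens=True (the default) the digits are wrapped in
    # BOS/EOS, e.g. [65536, 26729, 33, 65537].
    print(enc["input_ids"])

    # Decoding skips the special ids and unpacks the digits back to "hi!".
    print(tok.decode(enc["input_ids"], skip_special_tokens=True))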