linyueqian commited on
Commit
a9d2a4d
·
verified ·
1 Parent(s): 513c9ed

fix: use hasattr for BatchEncoding compatibility

Browse files
Files changed (1) hide show
  1. tokenization_voxcpm2.py +2 -2
tokenization_voxcpm2.py CHANGED
@@ -57,7 +57,7 @@ class VoxCPM2Tokenizer(LlamaTokenizerFast):
57
 
58
  def __call__(self, text, *args, **kwargs):
59
  result = super().__call__(text, *args, **kwargs)
60
- if isinstance(result, dict) and "input_ids" in result:
61
  ids = result["input_ids"]
62
  if isinstance(ids, list) and ids and isinstance(ids[0], list):
63
  result["input_ids"] = [self._expand_ids(x) for x in ids]
@@ -65,7 +65,7 @@ class VoxCPM2Tokenizer(LlamaTokenizerFast):
65
  result["attention_mask"] = [
66
  [1] * len(x) for x in result["input_ids"]
67
  ]
68
- else:
69
  result["input_ids"] = self._expand_ids(ids)
70
  if "attention_mask" in result:
71
  result["attention_mask"] = [1] * len(result["input_ids"])
 
57
 
58
  def __call__(self, text, *args, **kwargs):
59
  result = super().__call__(text, *args, **kwargs)
60
+ if hasattr(result, "input_ids"):
61
  ids = result["input_ids"]
62
  if isinstance(ids, list) and ids and isinstance(ids[0], list):
63
  result["input_ids"] = [self._expand_ids(x) for x in ids]
 
65
  result["attention_mask"] = [
66
  [1] * len(x) for x in result["input_ids"]
67
  ]
68
+ elif isinstance(ids, list):
69
  result["input_ids"] = self._expand_ids(ids)
70
  if "attention_mask" in result:
71
  result["attention_mask"] = [1] * len(result["input_ids"])