PyTorch
gpt2
gpt2-10M-syllitok-eng / tokenizer.py
achille-fusco's picture
Update tokenizer.py
6966297 verified
Raw
History Blame
4.32 kB
# tokenizer.py
from typing import List, Tuple, Optional, Union, Dict, Any
import json
import os
import warnings

from transformers import PreTrainedTokenizerFast

from .syllabic_pretokenizer import (
    Preprocessor,
    preprocess_and_segment_with_alignment,
    remap_offsets_to_raw,
)
class SyllabicTokenizerWrapper(PreTrainedTokenizerFast):
    """
    A HF-compatible tokenizer that FIRST applies syllabic segmentation,
    then delegates to the underlying fast tokenizer loaded from tokenizer.json.

    Required files in the same directory:
      - tokenizer.json, tokenizer_config.json, special_tokens_map.json
      - preprocess_config.json (keyword arguments for ``Preprocessor``)

    Segmentation is applied by the ``__call__`` and ``tokenize`` overrides
    below. NOTE(review): other inherited entry points (``encode``,
    ``encode_plus``, ``batch_encode_plus``, ...) are not overridden here;
    confirm whether they reach segmentation before relying on them.
    """

    slow_tokenizer_class = None  # required by HF when no slow version exists

    def __init__(self, *args, **kwargs):
        """
        Load the fast tokenizer and the saved preprocessing configuration.

        If neither ``tokenizer_file`` nor ``tokenizer_object`` is supplied,
        ``tokenizer.json`` is looked up inside ``name_or_path`` (the keyword
        argument, or the first positional argument when it is a string).

        Raises:
            FileNotFoundError: if tokenizer.json is required but missing, or if
                preprocess_config.json is not found in any searched directory.
        """
        # Ensure we load the fast tokenizer directly (no slow->fast conversion).
        name_or_path = kwargs.get("name_or_path") or (
            args[0] if args and isinstance(args[0], str) else None
        )
        needs_tokenizer_file = (
            kwargs.get("tokenizer_file") is None
            and kwargs.get("tokenizer_object") is None
        )
        if needs_tokenizer_file and name_or_path:
            tf = os.path.join(name_or_path, "tokenizer.json")
            if not os.path.isfile(tf):
                raise FileNotFoundError(f"Expected tokenizer.json at {tf}")
            kwargs["tokenizer_file"] = tf
        super().__init__(*args, **kwargs)

        # Load the preprocessing flags saved during training. The file name is
        # joined on its own: the previous code also joined ``revision`` as a
        # path component, which raised TypeError whenever revision was None.
        search_dirs = self._preprocess_config_dirs(name_or_path, kwargs)
        candidates = [os.path.join(d, "preprocess_config.json") for d in search_dirs]
        cfg_path = next((p for p in candidates if os.path.isfile(p)), None)
        if cfg_path is None:
            raise FileNotFoundError(
                f"Missing preprocess_config.json in {', '.join(map(str, search_dirs))}. "
                f"Did you save it during tokenizer training?"
            )
        with open(cfg_path, "r", encoding="utf-8") as f:
            self.pre_cfg = json.load(f)
        self.preprocessor = Preprocessor(**self.pre_cfg)

    def _preprocess_config_dirs(self, name_or_path, kwargs) -> List[str]:
        """
        Return the directories to search for preprocess_config.json, in order.

        ``name_or_path`` is tried first, then the directory of the
        ``tokenizer_file`` that was handed to the base class (presumably the
        resolved local path when loading via ``from_pretrained``, where
        ``name_or_path`` may be a Hub repo id rather than a directory — verify).
        "." is used only when no other candidate is available.
        """
        raw_candidates = (
            kwargs.get("name_or_path", getattr(self, "name_or_path", None)),
            name_or_path,
            os.path.dirname(kwargs.get("tokenizer_file") or ""),
            os.path.dirname(getattr(self, "tokenizer_file", None) or ""),
        )
        dirs: List[str] = []
        for d in raw_candidates:
            if d and d not in dirs:
                dirs.append(d)
        return dirs or ["."]

    # --- core segmentation helpers ---
    def _segment_one(self, text: str) -> Tuple[str, List[Optional[int]]]:
        """Segment one raw string; returns (segmented_text, alignment_map)."""
        return preprocess_and_segment_with_alignment(text, self.preprocessor)

    def _segment_input(
        self, text: Union[str, List[str]], error_message: str
    ) -> Union[str, List[str]]:
        """Segment a str, or each element of a list/tuple; else raise TypeError."""
        if isinstance(text, str):
            return self._segment_one(text)[0]
        if isinstance(text, (list, tuple)):
            return [self._segment_one(t)[0] for t in text]
        raise TypeError(error_message)

    # --- public API overrides ---
    def __call__(self, text: Union[str, List[str]], **kwargs) -> Dict[str, Any]:
        """
        Segment the input, then call the fast tokenizer (super) on the result.

        Args:
            text: a single string, or a list/tuple of strings.
            **kwargs: forwarded to the parent ``__call__``. A ``text_pair``
                keyword, if given, is segmented the same way as ``text``.
                ``return_offset_mapping`` is dropped with a warning.

        Raises:
            TypeError: if ``text`` or ``text_pair`` is not a str or list/tuple.
        """
        if kwargs.pop("return_offset_mapping", False):
            # The fast tokenizer only ever sees the segmented string, so its
            # offsets would not index into the caller's raw text.
            # TODO(review): remap with remap_offsets_to_raw once its contract
            # is confirmed, instead of dropping the request.
            warnings.warn(
                "return_offset_mapping is not supported by "
                "SyllabicTokenizerWrapper and is ignored.",
                stacklevel=2,
            )
        segmented = self._segment_input(text, "text must be str or List[str]")
        if kwargs.get("text_pair") is not None:
            kwargs["text_pair"] = self._segment_input(
                kwargs["text_pair"], "text_pair must be str or List[str]"
            )
        return super().__call__(segmented, **kwargs)

    def tokenize(self, text: Union[str, List[str]], **kwargs) -> List[str]:
        """
        Also intercept manual .tokenize() so segmentation happens first.

        A list/tuple input is tokenized element by element and the token
        lists are concatenated. A string ``pair`` keyword is segmented too.

        Raises:
            TypeError: if ``text`` is not a str or list/tuple.
        """
        if isinstance(kwargs.get("pair"), str):
            kwargs["pair"] = self._segment_one(kwargs["pair"])[0]
        if isinstance(text, str):
            return super().tokenize(self._segment_one(text)[0], **kwargs)
        if isinstance(text, (list, tuple)):
            tokens: List[str] = []
            # Plain loop: zero-argument super() is not usable inside a
            # comprehension scope on Python < 3.12.
            for item in text:
                tokens.extend(super().tokenize(self._segment_one(item)[0], **kwargs))
            return tokens
        raise TypeError("tokenize() expects str or List[str]")