PyTorch
gpt2
gpt2-10M-syllitok-eng / syllabic_pretokenizer.py
achille-fusco's picture
Upload folder using huggingface_hub
064b576 verified
Raw
History Blame
10.9 kB
# syllabic_pretokenizer.py
import re
from typing import List, Tuple, Optional
###############################################################################
# Configurable text preprocessor (spacing, lowercase) with alignment tracking
###############################################################################
class Preprocessor:
def __init__(
self,
lowercase: bool = False,
separate_apostrophes: bool = True,
separate_digits: bool = True,
separate_punctuation: bool = True,
):
self.lowercase = lowercase
self.separate_apostrophes = separate_apostrophes
self.separate_digits = separate_digits
self.separate_punctuation = separate_punctuation
# Precompiled regexes
self._apos_re = re.compile(r"[’'`]")
self._punct_re = re.compile(r"[^A-Za-z0-9\s’'`]")
self._digit_re = re.compile(r"\d")
def preprocess_with_alignment(self, line: str) -> Tuple[str, List[Optional[int]]]:
"""
Apply normalization equivalent to your previous Preprocessor, but also
return a map from each character in the preprocessed string back to the
raw string index. Inserted spaces get None in the map.
"""
raw = line
if self.lowercase:
raw = raw.lower()
out_chars: List[str] = []
out2raw: List[Optional[int]] = []
def emit(ch: str, raw_idx: Optional[int]):
out_chars.append(ch)
out2raw.append(raw_idx)
i = 0
n = len(raw)
while i < n:
c = raw[i]
# Decide if we should isolate this char with spaces
isolate = (
(self.separate_apostrophes and self._apos_re.match(c) is not None)
or (self.separate_punctuation and self._punct_re.match(c) is not None)
or (self.separate_digits and self._digit_re.match(c) is not None)
)
if isolate:
emit(" ", None)
emit(c, i)
emit(" ", None)
else:
emit(c, i)
i += 1
# Collapse whitespace to single spaces and strip, keeping alignment
pre, pre_map = _collapse_whitespace_with_map(out_chars, out2raw)
return pre, pre_map
def _collapse_whitespace_with_map(
chars: List[str], idx_map: List[Optional[int]]
) -> Tuple[str, List[Optional[int]]]:
"""
Collapse runs of whitespace to a single space and trim leading and trailing
whitespace, while preserving a per-char map back to the original.
"""
assert len(chars) == len(idx_map)
result_chars: List[str] = []
result_map: List[Optional[int]] = []
def is_space(ch: str) -> bool:
return ch.isspace()
# First pass: collapse runs to single spaces
prev_space = False
for ch, m in zip(chars, idx_map):
if is_space(ch):
if not prev_space:
result_chars.append(" ")
result_map.append(None) # inserted or collapsed space has no single origin
prev_space = True
# else skip extra spaces
else:
result_chars.append(ch)
result_map.append(m)
prev_space = False
# Strip leading space
if result_chars and result_chars[0] == " ":
result_chars.pop(0)
result_map.pop(0)
# Strip trailing space
if result_chars and result_chars[-1] == " ":
result_chars.pop()
result_map.pop()
return "".join(result_chars), result_map
###############################################################################
# Syllabifier (rule-based, spelling-driven heuristic)
###############################################################################
class Syllabifier:
    """Heuristic, spelling-based syllable splitter.

    Vowels are ``a e i o u`` (case-insensitive; ``y`` counts as a consonant).
    Each syllable is an optional consonant onset, a vowel run and -- when more
    letters follow -- one coda consonant. Onsets are matched against a small
    table of consonant clusters. Single-character syllables are merged into
    the following one.
    """

    def __init__(self):
        # Case insensitive vowel matcher
        self._vowel_re = re.compile(r"[aeiou]+", re.I)
        # Known onset clusters. Matching uses ``str.startswith``, which is
        # case-sensitive, so capitalized clusters (e.g. "Str") are not found.
        self._consonant_clusters = {
            "bl", "br", "cl", "cr", "dr", "fl", "fr", "gl", "gr", "pl", "pr",
            "sc", "sk", "sl", "sm", "sn", "sp", "st", "sw", "tr", "tw", "th",
            "ch", "sh", "ph", "wh", "sch", "str", "spr", "spl", "scr", "thr"
        }

    @staticmethod
    def _cut_by_lengths(word: str, pieces: List[str]) -> List[str]:
        """Re-cut *word* into consecutive slices sized like *pieces*.

        Every piece but the last fixes a slice length (clipped to what is
        left). The last slice takes the remainder, so the result always
        concatenates back to *word*.
        """
        cut: List[str] = []
        pos = 0
        for piece in pieces[:-1]:
            width = min(len(piece), len(word) - pos)
            if width <= 0:
                break
            cut.append(word[pos:pos + width])
            pos += width
        if pos < len(word):
            cut.append(word[pos:])
        return cut

    def syllabify_word(self, word: str) -> List[str]:
        """Split *word* into syllable-like chunks.

        Words of length <= 2 are returned whole. The result always
        concatenates back to *word*.

        The heuristic below only decides syllable *lengths*. The strings it
        builds can move letters (a leftover after a known onset is appended
        to the previous syllable) or drop them (leftovers at the start of the
        word). The final re-cut therefore slices *word* itself using those
        lengths.
        """
        if len(word) <= 2:
            return [word]
        # Longest clusters first so e.g. "str" wins over "st". Sorted once per
        # word rather than once per consonant run.
        clusters = sorted(self._consonant_clusters, key=len, reverse=True)
        syllables: List[str] = []
        i = 0
        n = len(word)
        while i < n:
            syllable = ""
            # Collect initial consonant cluster
            consonant_start = i
            while i < n and not self._vowel_re.match(word[i]):
                i += 1
            if i > consonant_start:
                cluster = word[consonant_start:i]
                found = False
                for known in clusters:
                    if cluster.startswith(known):
                        syllable += known
                        cluster_rest = cluster[len(known):]
                        if cluster_rest and syllables:
                            # Only the length matters here; the re-cut below
                            # puts these letters back in word order.
                            syllables[-1] += cluster_rest
                        found = True
                        break
                if not found:
                    # Split unknown cluster roughly in half
                    split = len(cluster) // 2
                    if syllables:
                        syllables[-1] += cluster[:split]
                    syllable += cluster[split:]
            # Add vowel group
            vowel_start = i
            while i < n and self._vowel_re.match(word[i]):
                i += 1
            syllable += word[vowel_start:i]
            # Trailing consonants
            trailing_start = i
            while i < n and not self._vowel_re.match(word[i]):
                i += 1
            if i > trailing_start:
                consonants = word[trailing_start:i]
                if i < n:
                    # One consonant stays with current syllable; the rest are
                    # rescanned as the next syllable's onset.
                    syllable += consonants[0]
                    i = trailing_start + 1
                else:
                    # End of word, keep all
                    syllable += consonants
            if syllable:
                syllables.append(syllable)
        # Merge very short syllables into the following one
        merged: List[str] = []
        k = 0
        while k < len(syllables):
            cur = syllables[k]
            if len(cur) == 1 and k < len(syllables) - 1:
                merged.append(cur + syllables[k + 1])
                k += 2
            else:
                merged.append(cur)
                k += 1
        if not merged:
            return [word]
        return self._cut_by_lengths(word, merged)
###############################################################################
# End-to-end helpers: preprocessing + syllabification + alignment
###############################################################################
def preprocess_and_segment_with_alignment(
    text: str,
    preprocessor: Preprocessor,
    syllabifier: Optional[Syllabifier] = None,
) -> Tuple[str, List[Optional[int]]]:
    """Normalize, tokenize and syllabify *text* while tracking alignment.

    The text is passed through ``preprocessor.preprocess_with_alignment`` and
    split into whitespace-separated tokens. Each token is then split with
    ``syllabifier.syllabify_word`` (a new ``Syllabifier`` when none is given).
    Syllables and tokens are re-joined with single spaces.

    Only the syllable *lengths* are used to cut each token, so the emitted
    characters always come from the token itself.

    Returns:
        ``(segmented, seg_map)`` where ``seg_map[k]`` is the index in *text*
        that ``segmented[k]`` came from, or ``None`` for a space.
    """
    pre, pre2raw = preprocessor.preprocess_with_alignment(text)
    if syllabifier is None:
        syllabifier = Syllabifier()

    def fit_to_token(pieces: List[str], token: str) -> List[str]:
        # If the piece lengths do not add up to the token length, re-cut the
        # token with those lengths (clipped) and let a final piece absorb
        # whatever is left over.
        if sum(len(piece) for piece in pieces) == len(token):
            return pieces
        recut: List[str] = []
        offset = 0
        for piece in pieces[:-1]:
            width = min(len(piece), len(token) - offset)
            if width <= 0:
                break
            recut.append(token[offset:offset + width])
            offset += width
        if offset < len(token):
            recut.append(token[offset:])
        return [piece for piece in recut if piece]

    seg_chars: List[str] = []
    seg_map: List[Optional[int]] = []

    def emit_space() -> None:
        seg_chars.append(" ")
        seg_map.append(None)

    size = len(pre)
    cursor = 0
    emitted_token = False
    while cursor < size:
        if pre[cursor].isspace():
            cursor += 1
            continue
        stop = cursor
        while stop < size and not pre[stop].isspace():
            stop += 1
        token = pre[cursor:stop]
        token_map = pre2raw[cursor:stop]
        # Tokens are separated by exactly one inserted space.
        if emitted_token:
            emit_space()
        emitted_token = True

        pieces = fit_to_token(syllabifier.syllabify_word(token), token)
        offset = 0
        for number, piece in enumerate(pieces):
            if number:
                emit_space()
            for k in range(offset, offset + len(piece)):
                seg_chars.append(token[k])
                seg_map.append(token_map[k])
            offset += len(piece)
        cursor = stop

    # Final collapse (defensive) and strip.
    return _collapse_whitespace_with_map(seg_chars, seg_map)
def remap_offsets_to_raw(
    offsets: List[Tuple[int, int]],
    pre2raw: List[Optional[int]],
) -> List[Tuple[int, int]]:
    """Convert ``(start, end)`` spans on a processed string to raw spans.

    Each span is first clamped to ``[0, len(pre2raw)]``. The raw span runs
    from the first mapped character in the window to one past the last mapped
    character. A window that is empty, or that holds only inserted characters
    (``None`` entries), becomes the degenerate span ``(0, 0)``.
    """
    limit = len(pre2raw)

    def convert(start: int, end: int) -> Tuple[int, int]:
        lo = max(0, min(start, limit))
        hi = max(0, min(end, limit))
        window = pre2raw[lo:hi]
        first = next((r for r in window if r is not None), None)
        last = next((r for r in reversed(window) if r is not None), None)
        if first is None or last is None:
            return (0, 0)
        return (first, last + 1)  # exclusive end

    return [convert(s, e) for s, e in offsets]