Khmer Word Tokenizer — BiLSTM-CRF

A character-level BiLSTM + CRF model for Khmer (ភាសាខ្មែរ) word segmentation.

Model Architecture

  • Type: Bidirectional LSTM + CRF (Conditional Random Field)
  • Input: Character sequence
  • Output: BIO tags (B=word begin, I=word inside, O=space/punct)
  • Embedding dim: 128
  • LSTM hidden dim: 256 (x2 bidirectional)
  • LSTM layers: 2
  • Vocab size: 364 characters

Training Data

| Dataset | Sentences |
|---|---|
| ye-kyaw-thu/khPOS (corpus-draft-ver-1.0) | ~14,000 |
| Asian-Language-Treebank/ALT-Parallel-Corpus | ~20,000 |
| phylypo/segmentation-crf-khmer (kh_data_1000) | ~9,700 |

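Usage

The snippet below is written for a Jupyter/Colab notebook; the `!pip` line is notebook shell syntax. In a plain Python script, run `pip install pytorch-crf huggingface_hub` in a terminal instead and omit that line.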

# ── Step 0: Install dependencies ────────────────────────────
!pip install -q pytorch-crf huggingface_hub

# ── Step 1: Imports ─────────────────────────────────────────
import torch
import torch.nn as nn
from torchcrf import CRF
import json
from huggingface_hub import hf_hub_download

# Run on the default CUDA device when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ── Step 2: Define the model class ──────────────────────────
class KhmerBiLSTMCRF(nn.Module):
    """Character-level BiLSTM encoder with a linear-chain CRF output layer.

    Characters are embedded, encoded by a stacked bidirectional LSTM,
    projected to one emission score per tag, and scored/decoded by a CRF.

    Args:
        vocab_size: Number of entries in the character vocabulary.
        embed_dim: Size of each character embedding.
        hidden_dim: LSTM hidden size per direction (the classifier sees
            ``2 * hidden_dim`` features).
        num_layers: Number of stacked LSTM layers.
        num_tags: Number of output tags.
        pad_idx: Vocabulary index of the padding character (passed to
            ``nn.Embedding`` as ``padding_idx``).
        dropout: Dropout applied to embeddings and LSTM outputs, and between
            LSTM layers when ``num_layers > 1``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim,
                 num_layers, num_tags, pad_idx, dropout=0.3):
        super().__init__()
        self.pad_idx   = pad_idx
        self.num_tags  = num_tags
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(
            input_size=embed_dim, hidden_size=hidden_dim,
            num_layers=num_layers, batch_first=True,
            bidirectional=True,
            # nn.LSTM's dropout only acts between stacked layers, so it is
            # disabled for a single-layer LSTM.
            dropout=dropout if num_layers > 1 else 0,
        )
        self.dropout    = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim * 2, num_tags)
        self.crf        = CRF(num_tags, batch_first=True)

    def _get_emissions(self, chars, mask):
        """Return tag scores of shape ``(batch, seq_len, num_tags)``.

        ``chars`` is a ``(batch, seq_len)`` tensor of character indices and
        ``mask`` a same-shaped tensor that is true for real characters. The
        per-row length is taken as ``mask.sum(dim=1)``, so the mask is assumed
        to be true on a left-aligned prefix of each row.
        """
        emb     = self.dropout(self.embedding(chars))
        # pack_padded_sequence requires the lengths tensor on the CPU.
        lengths = mask.sum(dim=1).cpu()
        packed  = nn.utils.rnn.pack_padded_sequence(
            emb, lengths, batch_first=True, enforce_sorted=False)
        out, _  = self.lstm(packed)
        # Pad back to the full input width. Without total_length the output is
        # only as long as the longest unmasked row, so the emissions would not
        # line up with `tags`/`mask` whenever every row ends in padding.
        out, _  = nn.utils.rnn.pad_packed_sequence(
            out, batch_first=True, total_length=chars.size(1))
        out     = self.dropout(out)
        return self.classifier(out)

    def forward(self, chars, tags, mask):
        """Return the training loss: negative CRF log-likelihood of ``tags``.

        The CRF is called with ``reduction='mean'``.
        """
        return -self.crf(self._get_emissions(chars, mask), tags,
                         mask=mask, reduction='mean')

    @torch.no_grad()
    def decode(self, chars, mask):
        """Return the best tag-index sequence per batch row via ``CRF.decode``."""
        return self.crf.decode(self._get_emissions(chars, mask), mask=mask)


# ── Step 3: Define the tokenizer function ───────────────────
@torch.no_grad()
def tokenize_khmer(text, model, char2idx, idx2tag, device, max_len=512):
    """Segment *text* into words using a character tagger.

    The stripped text is cut into chunks of at most ``max_len`` characters.
    Each chunk is tagged with ``model.decode`` and the tags are turned into
    words:

    * ``'B'`` starts a new word;
    * ``'I'`` extends the current word, also across a chunk boundary, so a
      word is not force-split just because it straddles two chunks;
    * any other tag closes the current word, and the character is emitted as
      a token of its own.

    Whitespace characters always separate words and never appear in a
    returned token, whatever tag the model predicts for them.

    Args:
        text: Input string.
        model: Object with ``decode(char_ids, mask)`` returning one list of
            tag indices per batch row (e.g. ``KhmerBiLSTMCRF`` in eval mode).
        char2idx: Character -> vocabulary index. Unknown characters map to
            ``char2idx['<UNK>']`` (or 1 if that key is absent).
        idx2tag: Tag index -> tag string. Unknown indices are treated as 'B'.
        device: Torch device the input tensors are created on.
        max_len: Maximum number of characters sent to the model at once.

    Returns:
        List of non-empty tokens.

    Raises:
        ValueError: If ``max_len`` is smaller than 1.
    """
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    text = text.strip()
    if not text:
        return []

    unk     = char2idx.get('<UNK>', 1)
    words   = []
    current = []  # characters of the word currently being built

    def flush():
        # Close the word in progress, if any.
        if current:
            words.append(''.join(current))
            current.clear()

    for start in range(0, len(text), max_len):
        chunk     = text[start:start + max_len]
        char_ids  = torch.tensor(
            [[char2idx.get(ch, unk) for ch in chunk]],
            dtype=torch.long, device=device)
        mask      = torch.ones(1, len(chunk), dtype=torch.bool, device=device)
        pred_tags = model.decode(char_ids, mask)[0]
        for ch, tag_idx in zip(chunk, pred_tags):
            if ch.isspace():
                flush()
                continue
            tag = idx2tag.get(tag_idx, 'B')
            if tag == 'B':
                flush()
                current.append(ch)
            elif tag == 'I':
                current.append(ch)
            else:
                flush()
                words.append(ch)
    flush()
    return words


# ── Step 4: Download and load ────────────────────────────────
model_path = hf_hub_download('phonsobon/khmer_word_tokenizer', 'khmer_tokenizer_full.pt')

# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary objects,
# i.e. run code embedded in the file. Only load checkpoints you trust.
# NOTE(review): the keys read below look like plain dicts/ints/strs plus a
# state_dict, so weights_only=True may also work — untested.
ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
cfg  = ckpt['cfg']

char2idx = ckpt['char2idx']
# Force integer keys: tokenize_khmer looks tags up by the integer index coming
# out of CRF decoding, and a str-keyed map (e.g. one round-tripped through
# JSON) would silently turn every prediction into its 'B' fallback.
idx2tag  = {int(k): v for k, v in ckpt['idx2tag'].items()}

model = KhmerBiLSTMCRF(
    vocab_size = ckpt['vocab_size'],
    embed_dim  = cfg['EMBED_DIM'],
    hidden_dim = cfg['HIDDEN_DIM'],
    num_layers = cfg['NUM_LAYERS'],
    num_tags   = ckpt['num_tags'],
    pad_idx    = char2idx['<PAD>'],
    dropout    = 0.0,  # inference only; eval() below also disables dropout
).to(DEVICE)

model.load_state_dict(ckpt['model_state'])
model.eval()

# ── Step 5: Tokenize ─────────────────────────────────────────
# Example sentence: "I love (the country of) Cambodia".
words = tokenize_khmer('ខ្ញុំស្រលាញ់ប្រទេសកម្ពុជា', model, char2idx, idx2tag, DEVICE)
print(words)

Citation

If you use this model, please cite the khPOS and ALT datasets.

Downloads last month
36
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support