Token Classification
Transformers
French
German
ocr_qa_assessment
ocr
bloomfilter
unigram
impresso
quality-assessment
v1.0.6
custom_code
Instructions to use impresso-project/ocr-quality-assessor-unigram-light with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use impresso-project/ocr-quality-assessor-unigram-light with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="impresso-project/ocr-quality-assessor-unigram-light", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("impresso-project/ocr-quality-assessor-unigram-light", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel | |
| import logging | |
| import floret | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| from .configuration_ocrqa import ImpressoConfig | |
| logger = logging.getLogger(__name__) | |
| from pybloomfilter import BloomFilter | |
| from transformers import pipeline | |
| import unicodedata | |
| from typing import Optional | |
# Character classes that text normalization replaces with a single space.
QUOTES_PUNCT = "„•<>!\"#%&'’"
ASCII_PUNCT = "()*,./:;?"
BRACKETS_SPECIAL = "[]\\~_{}"
UNICODE_PUNCT = "\xa1\xab\xb7\xbb\xbf"  # ¡ « · » ¿
DASH_CARET = "—^`"
SPECIAL_SYMBOLS = "¦§£="
HYPHEN = "-"

# Every digit is collapsed to "0".
DIGITS = "0123456789"

# Translation table for str.translate: each punctuation/symbol character
# above maps to " ", and each digit maps to "0".
NORMALIZATION_TABLE = str.maketrans(
    {
        **dict.fromkeys(
            "".join(
                (
                    QUOTES_PUNCT,
                    ASCII_PUNCT,
                    BRACKETS_SPECIAL,
                    UNICODE_PUNCT,
                    DASH_CARET,
                    SPECIAL_SYMBOLS,
                    HYPHEN,
                )
            ),
            " ",
        ),
        **dict.fromkeys(DIGITS, "0"),
    }
)
def normalize_text(s: str, unicode_normalize: Optional[str] = "NFKC") -> str:
    """Return the canonical form of *s* used for lexicon lookups.

    If *unicode_normalize* is truthy, it is used as the form name passed to
    ``unicodedata.normalize`` ("NFKC" by default), and the normalized text is
    then lowercased. When *unicode_normalize* is falsy, both steps are
    skipped, including the lowercasing.

    Finally, every character in ``NORMALIZATION_TABLE`` is substituted:
    listed punctuation/symbols become a space and digits become ``"0"``.
    """
    prepared = (
        unicodedata.normalize(unicode_normalize, s).lower()
        if unicode_normalize
        else s
    )
    return prepared.translate(NORMALIZATION_TABLE)
def filter_text(text: str, bloom_filter: BloomFilter) -> dict[str, set[str]]:
    """Partition the distinct normalized tokens of *text* by lexicon membership.

    *text* is passed through ``normalize_text`` and split on whitespace.
    Each distinct token is then looked up in *bloom_filter*. As with any
    Bloom filter, a lookup can report a false positive.

    Returns:
        ``{"known": <tokens found in the filter>,
        "unknown": <tokens not found>}``. Both values are sets, so repeated
        tokens are counted once.
    """
    distinct_tokens = set(normalize_text(text).split())
    found = {token for token in distinct_tokens if token in bloom_filter}
    return {"known": found, "unknown": distinct_tokens - found}
class QAAssessmentModel(PreTrainedModel):
    """Weight-free OCR quality assessor backed by per-language Bloom-filter lexicons.

    For each input text:
    1. The language is identified with the
       ``impresso-project/language-identifier`` pipeline.
    2. The matching lexicon is selected, falling back to ``"other"``.
    3. The score is the share of distinct normalized tokens that the lexicon
       recognizes (see ``filter_text``).
    """

    config_class = ImpressoConfig

    # Lexicon keys loaded at construction time. "other" doubles as the
    # fallback when detection fails or returns a language without a lexicon.
    _LANGUAGES = ("en", "de", "fr", "other")

    def get_bloomfilter(self, model_id: str, filename: str):
        """Download *filename* from Hub repo *model_id* and open it as a BloomFilter."""
        return BloomFilter.open(hf_hub_download(repo_id=model_id, filename=filename))

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Dummy parameter so the module owns at least one tensor; the
        # ``device`` property reports its placement.
        self.dummy_param = nn.Parameter(torch.zeros(1))

        # ``from_pretrained`` builds ImpressoConfig(**kwargs), so the loaded
        # configuration sits one level down as ``config.config``.
        # Presumably this is the ``config=`` kwarg forwarded by the Auto*
        # loader -- confirm.
        hub_config = self.config.config
        self.ocrqa_assessors = {}
        for lang in self._LANGUAGES:
            model_filename = hub_config.filename[lang]
            logger.info("Loading model for %s: %s", lang, model_filename)
            self.ocrqa_assessors[lang] = self.get_bloomfilter(
                model_id=hub_config._name_or_path,
                filename=model_filename,
            )

        # Language identification is pinned to CPU.
        self.lang_pipeline = pipeline(
            "langident",
            model="impresso-project/language-identifier",
            trust_remote_code=True,
            device="cpu",
        )

    def _detect_language(self, text: str) -> str:
        """Identify the language of *text* and map it to a key of ``ocrqa_assessors``."""
        langs = self.lang_pipeline(text)
        # NOTE(review): indexing with 'language' assumes the pipeline returns
        # a mapping for a single string input. Confirm against the
        # language-identifier model's output format.
        if len(langs) > 0:
            lang = langs["language"]
            logger.info("Detected language: %s", lang)
        else:
            lang = "other"
            logger.warning("Language detection failed, using 'other' as default.")
        if lang not in self.ocrqa_assessors:
            logger.warning(
                "Language '%s' not found in bin_filename, using 'other' as default.",
                lang,
            )
            lang = "other"
        return lang

    def forward(self, input_ids, **kwargs):
        """Score the OCR quality of one or more texts.

        Args:
            input_ids: A single string or a list/tuple of strings. These are
                raw texts, not token ids.
            **kwargs: Ignored.

        Returns:
            A list with one float per text:
            ``known / (known + unknown + 1e-6)``, computed over distinct
            normalized tokens. The value lies in [0, 1) and is 0.0 for a text
            with no tokens.

        Raises:
            ValueError: If *input_ids* is neither a string nor a list/tuple
                of strings.
        """
        if isinstance(input_ids, str):
            texts = [input_ids]
        elif isinstance(input_ids, (list, tuple)) and all(
            isinstance(t, str) for t in input_ids
        ):
            texts = list(input_ids)
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        predictions = []
        for text in texts:
            # Detect per text; the previous code passed the whole batch here.
            lang = self._detect_language(text)
            result = filter_text(text, self.ocrqa_assessors[lang])
            known_count = len(result["known"])
            unknown_count = len(result["unknown"])
            # The epsilon guards against division by zero for token-less texts.
            predictions.append(known_count / (known_count + unknown_count + 0.000001))
        return predictions

    @property
    def device(self):
        """Device of the model's parameters.

        The only parameter registered in ``__init__`` is ``dummy_param``.
        This must be a property so that ``model.device`` yields a device
        rather than a bound method, matching ``PreTrainedModel``.
        """
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build the model from keyword arguments only; no weights are loaded.

        Positional arguments (e.g. the repo id) are ignored. All keyword
        arguments are forwarded to ``ImpressoConfig``. ``__init__`` then reads
        the loaded configuration from the resulting ``config`` attribute
        (``self.config.config``).
        """
        config = ImpressoConfig(**kwargs)
        return cls(config)