"""Gradio demo for the Impresso stacked multilingual NER model.

The tokenizer, configuration and model are loaded once at import time; the
Gradio interface then highlights the named entities found in the input text.
"""

import html

import gradio as gr
import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer

from configuration_stacked import ImpressoConfig
from modeling_stacked import ExtendedMultitaskTimeModelForTokenClassification

# Define the model name
MODEL_NAME = "impresso-project/ner-stacked-bert-multilingual-v1.1.0"

print("Loading tokenizer...")
ner_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

print("Loading model configuration...")
config = ImpressoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)

print("Loading model...")
model = ExtendedMultitaskTimeModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    config=config,
    trust_remote_code=True,
)
model.eval()  # inference mode for the demo

# Load label maps
label_map = config.label_map
print(f"Label map: {label_map}")

# Reverse label maps (id -> tag string) used to decode argmax predictions.
id2label_coarse = {v: k for k, v in label_map["NE-COARSE-LIT"].items()}
id2label_fine = {v: k for k, v in label_map["NE-FINE-COMP"].items()}

print("Model loaded successfully!")
print(f"Using model: {MODEL_NAME}")
print(f"Using tokenizer: {MODEL_NAME}")


def format_entities_as_html(entities):
    """Render entity dicts as an HTML fragment, one ``<div>`` per entity.

    Every key except ``start``, ``end`` and ``index`` is shown as
    ``Key: value`` followed by ``<br>``; float values use two decimals.
    Keys and values are HTML-escaped because values such as ``surface`` are
    copied from user input.

    Args:
        entities: Iterable of dicts (e.g. the ``entities`` list built by
            ``extract_entities``).

    Returns:
        The HTML fragment as a string.
    """
    excluded_keys = {"start", "end", "index"}  # Keys to exclude from the output
    parts = ["<div>"]
    for entity in entities:
        parts.append("<div>")  # Each entity in a separate div
        # Dynamically add all fields except the excluded ones
        for key, value in entity.items():
            if key in excluded_keys:
                continue
            shown = f"{value:.2f}" if isinstance(value, float) else str(value)
            parts.append(
                f"{html.escape(key.capitalize())}: {html.escape(shown)}<br>"
            )
        parts.append("</div>")
    parts.append("</div>")
    return "".join(parts)


def _fine_attribute(label_fine):
    """Return the attribute name encoded in a fine-grained tag, or None.

    ``O`` tags yield None. Otherwise a leading ``B-``/``I-`` prefix is
    dropped and the remaining type is split on ``.``; only a two-part type
    yields an attribute (its second part, e.g. 'name', 'title', 'function').
    """
    if label_fine.startswith("O"):
        return None
    fine_type = label_fine[2:] if label_fine.startswith(("B-", "I-")) else label_fine
    fine_parts = fine_type.split(".")
    return fine_parts[1] if len(fine_parts) == 2 else None


def predict_entities(text):
    """Run NER prediction on text and return entities in the expected format.

    Coarse ``NE-COARSE-LIT`` BIO tags define the entity spans; fine
    ``NE-FINE-COMP`` tags inside a span add attribute keys to that entity.

    Args:
        text: Raw input string. It is tokenized with ``truncation=True`` and
            ``max_length=512``, so entities beyond that token budget are not
            reported.

    Returns:
        List of dicts with keys ``type``, ``surface``, ``lOffset``/``rOffset``
        (character offsets into *text*), ``confidence_ner`` (softmax
        probability of the entity's first coarse tag, times 100) and, when
        present, one key per fine attribute holding that attribute's own
        surface text.
    """
    inputs = ner_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_offsets_mapping=True,
    )
    # Offsets are removed before the forward pass and used to map tokens
    # back to character spans in *text*.
    offset_mapping = inputs.pop("offset_mapping")[0]

    with torch.no_grad():
        outputs = model(**inputs)

    logits_coarse = outputs.logits["NE-COARSE-LIT"]
    logits_fine = outputs.logits["NE-FINE-COMP"]
    predictions_coarse = torch.argmax(logits_coarse, dim=-1)[0]
    predictions_fine = torch.argmax(logits_fine, dim=-1)[0]
    # Per-token probabilities for the coarse task (confidence score).
    scores_coarse = torch.softmax(logits_coarse, dim=-1)[0]

    entities = []
    current_entity = None
    # Attribute name -> character offset where that attribute started inside
    # the current entity; kept outside the entity dict so it is not output.
    attr_starts = {}

    for idx, (pred_coarse, pred_fine) in enumerate(
        zip(predictions_coarse, predictions_fine)
    ):
        # Skip the first position (presumably the leading special token) and
        # any position without an offset entry.
        if idx == 0 or idx >= len(offset_mapping):
            continue
        start, end = (int(x) for x in offset_mapping[idx])
        if start == end:  # Skip special tokens (empty character span)
            continue

        label_coarse = id2label_coarse[pred_coarse.item()]
        label_fine = id2label_fine[pred_fine.item()]
        attr_name = _fine_attribute(label_fine)

        if label_coarse.startswith("B-"):
            # A new entity begins: flush the previous one.
            if current_entity is not None:
                entities.append(current_entity)
            current_entity = {
                "type": label_coarse[2:],  # Remove 'B-' prefix
                "surface": text[start:end],
                "lOffset": start,
                "rOffset": end,
                "confidence_ner": float(
                    scores_coarse[idx][pred_coarse].item() * 100
                ),
            }
            attr_starts = {}
            if attr_name is not None:
                attr_starts[attr_name] = start
                current_entity[attr_name] = text[start:end]
        elif (
            label_coarse.startswith("I-")
            and current_entity is not None
            and label_coarse[2:] == current_entity["type"]
        ):
            # Continue the current entity.
            current_entity["rOffset"] = end
            current_entity["surface"] = text[current_entity["lOffset"]:end]
            if attr_name is not None:
                # Extend from where THIS attribute began, not from the entity
                # start (e.g. 'President George Washington' -> name
                # 'George Washington').
                attr_start = attr_starts.setdefault(attr_name, start)
                current_entity[attr_name] = text[attr_start:end]
        else:
            # 'O', an orphan I- tag, or an I- tag of another type: the open
            # entity ends here so it cannot stretch across this token.
            if current_entity is not None:
                entities.append(current_entity)
            current_entity = None

    # Don't forget the last entity
    if current_entity is not None:
        entities.append(current_entity)

    return entities


# Function to process the sentence and extract entities
def extract_entities(sentence):
    """Run NER on *sentence* and shape the result for ``gr.HighlightedText``.

    Each returned entity gets ``entity`` (display label: type plus any
    title/name/function attributes), ``start`` and ``end`` keys on top of the
    fields produced by ``predict_entities``. Entities whose exact
    (lOffset, rOffset) span was already seen are dropped.

    Returns:
        ``{"text": sentence, "entities": [...]}``; blank input yields
        ``{"text": "", "entities": []}``.
    """
    if not sentence or not sentence.strip():
        return {"text": "", "entities": []}

    results = predict_entities(sentence)

    # Debugging the result format
    print(f"NER results: {results}")

    entities = []
    seen_spans = set()  # Exact spans already added (duplicate guard only)
    for entity in results:
        entity_span = (entity["lOffset"], entity["rOffset"])
        if entity_span in seen_spans:
            continue
        seen_spans.add(entity_span)

        label = f"{entity['type']}"
        for attr in ("title", "name", "function"):
            if attr in entity:
                label += f" - {attr.capitalize()}: {entity[attr]}"

        entity["entity"] = label
        entity["start"] = entity["lOffset"]
        entity["end"] = entity["rOffset"]
        entities.append(entity)

    print(f"Entities: {entities}")
    return {"text": sentence, "entities": entities}


# Create Gradio interface
def ner_app_interface():
    """Build the Gradio interface around ``extract_entities`` and launch it.

    ``format_entities_as_html`` can be wired to a ``gr.HTML`` output instead
    of ``gr.HighlightedText`` if a plain list view is preferred.
    """
    input_sentence = gr.Textbox(
        lines=5, label="Input Sentence", placeholder="Enter a sentence for NER:"
    )

    # Interface definition
    interface = gr.Interface(
        fn=extract_entities,
        inputs=input_sentence,
        outputs=[gr.HighlightedText(label="Text with mentions")],
        title="Named Entity Recognition",
        description="Enter a sentence to extract named entities using the NER model from the Impresso project.",
        examples=[
            [
                "Des chercheurs de l'Université de Cambridge ont développé une nouvelle technique de calcul quantique qui promet d'augmenter exponentiellement les vitesses de calcul."
            ],
            [
                "Le rapport complet sur ces découvertes a été publié dans la prestigieuse revue 'Nature Physics'. (Reuters)"
            ],
            ["In the year 1789, the Estates-General was convened in France."],
            [
                "The event was held at the Palace of Versailles, a symbol of French monarchy."
            ],
            [
                "At Versailles, Marie Antoinette, the Queen of France, was involved in discussions."
            ],
            [
                "Maximilien Robespierre, a leading member of the National Assembly, also participated."
            ],
            [
                "Jean-Jacques Rousseau, the famous philosopher, was a significant figure in the debate."
            ],
            [
                "Another important participant was Charles de Talleyrand, the Bishop of Autun."
            ],
            [
                "Meanwhile, across the Atlantic, George Washington, the first President of the United States, was shaping policies."
            ],
            [
                "Thomas Jefferson, the nation's Secretary of State, played a key role in drafting policies for the new American government."
            ],
        ],
        live=False,
    )

    interface.launch(share=True)


# Run the app
if __name__ == "__main__":
    ner_app_interface()