Spaces:
Sleeping
Sleeping
| """ | |
| Entity extraction module using Gemini AI with fallback methods | |
| """ | |
| import re | |
| import logging | |
| from typing import List, Optional | |
| import google.generativeai as genai | |
| from services.appconfig import GEMINI_API_KEY, COMMON_TECH_ENTITIES, MAX_ENTITIES | |
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class EntityExtractor:
    """Extract company/product/tool/brand names from text.

    Gemini AI is tried first when a model could be initialised. If no model
    is available, the Gemini call raises, or it returns nothing, a
    keyword/regex-based fallback is used instead.
    """

    # Gemini model used when no explicit ``model_name`` is passed.
    DEFAULT_MODEL_NAME = 'gemini-2.0-flash-exp'

    # Prompt sent to Gemini; ``{text}`` is filled in with the input text.
    _GEMINI_PROMPT = (
        "Extract company names, product names, software names, tool names, "
        "and brand names from this text.\n"
        "Only return names that would have recognizable logos "
        "(like Microsoft, Adobe, React, etc.).\n"
        "Return as a simple list, one name per line, no bullet points or numbers.\n"
        'Avoid generic terms like "cloud" or "database".\n'
        "Text: {text}\n"
    )

    # Words that are never accepted as brand names. Generic tech terms such
    # as 'cloud' or 'database' are not filtered here; the Gemini prompt asks
    # the model to avoid them instead.
    _INVALID_WORDS = frozenset({
        'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
        'by', 'from', 'up', 'about', 'into', 'through', 'during', 'before',
        'after', 'above', 'below', 'between', 'among',
        # Sentence-initial pronouns/determiners that the capitalised-word
        # fallback would otherwise mistake for proper nouns.
        'this', 'that', 'these', 'those', 'they', 'their', 'our', 'your',
        'its', 'there', 'when', 'what', 'while',
    })

    # Capitalised words of 3+ letters (likely proper nouns). The lookahead
    # skips the stem of a dotted name, e.g. "Node" in "Node.js", which
    # _DOTTED_NAME_RE captures whole.
    _CAPITALISED_WORD_RE = re.compile(r'\b[A-Z][a-zA-Z]{2,}\b(?!\.[a-zA-Z])')
    # Dotted names such as "Node.js" or "Vue.js".
    _DOTTED_NAME_RE = re.compile(r'\b[A-Z][a-zA-Z]*\.[a-zA-Z]+\b')
    # Leading list markers ("-", "*", bullet "\u2022", "1.", "2)") that Gemini
    # may emit despite being asked not to.
    _LIST_MARKER_RE = re.compile(r'^(?:[-*\u2022]+\s*|\d+[.)]\s+)')

    def __init__(self, api_key: Optional[str] = None,
                 model_name: str = DEFAULT_MODEL_NAME):
        """
        Initialize EntityExtractor.

        Args:
            api_key (str, optional): Gemini API key; falls back to
                ``GEMINI_API_KEY`` from the app config when falsy.
            model_name (str, optional): Gemini model to instantiate.
        """
        self.api_key = api_key or GEMINI_API_KEY
        self.model_name = model_name
        self.model = None
        self._setup_gemini()

    def _setup_gemini(self) -> None:
        """Configure the Gemini client; leave ``self.model`` as None on failure."""
        if not self.api_key:
            logger.warning("No Gemini API key provided, using fallback method")
            return
        try:
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel(self.model_name)
        except Exception as e:
            # Best-effort: a broken Gemini setup must not prevent the
            # fallback extractor from working.
            logger.error("Failed to initialize Gemini API: %s", e)
            self.model = None
        else:
            logger.info("Gemini API initialized successfully")

    @staticmethod
    def _dedupe(names: List[str]) -> List[str]:
        """Drop case-insensitive duplicates, keeping the first spelling and order."""
        seen = set()
        unique = []
        for name in names:
            key = name.lower()
            if key not in seen:
                seen.add(key)
                unique.append(name)
        return unique

    @staticmethod
    def _mentions(text: str, name: str) -> bool:
        """Return True if *name* occurs in *text* as a whole word, ignoring case.

        A word boundary is enforced only on ends of *name* that are word
        characters, so ".NET" still matches "ASP.NET" and "C++" matches
        "C++17", while "Go" does not match inside "Google".
        """
        if not name:
            return False
        pattern = re.escape(name)
        if re.match(r'\w', name[0]):
            pattern = r'(?<!\w)' + pattern
        if re.match(r'\w', name[-1]):
            pattern += r'(?!\w)'
        return re.search(pattern, text, re.IGNORECASE) is not None

    def _parse_gemini_lines(self, raw: str) -> List[str]:
        """Turn Gemini's one-name-per-line reply into cleaned entity names.

        Args:
            raw (str): Raw response text.

        Returns:
            List[str]: Names in reply order, with list markers and markdown
            emphasis removed. Code-fence lines, invalid entries and
            case-insensitive duplicates are dropped.
        """
        names = []
        for line in raw.splitlines():
            line = line.strip()
            if line.startswith('```'):
                continue  # markdown code fence, not a name
            # Strip bullets/numbering rather than discarding the line, then
            # any surrounding markdown emphasis such as **Name** or `Name`.
            name = self._LIST_MARKER_RE.sub('', line).strip('*`').strip()
            if self._is_valid_entity(name):
                names.append(name)
        return self._dedupe(names)

    def extract_with_gemini(self, text: str) -> List[str]:
        """
        Extract entities using Gemini AI.

        Args:
            text (str): Input text

        Returns:
            List[str]: Up to ``MAX_ENTITIES`` extracted entities

        Raises:
            RuntimeError: If no Gemini model is available.
            Exception: Any error raised by the Gemini call is logged and re-raised.
        """
        if not self.model:
            raise RuntimeError("Gemini model not available")
        prompt = self._GEMINI_PROMPT.format(text=text)
        try:
            response = self.model.generate_content(prompt)
            # Kept inside the try: the SDK may raise when accessing .text on
            # a reply that carries no text.
            raw = response.text
        except Exception as e:
            logger.error("Gemini extraction failed: %s", e)
            raise
        if not raw:
            return []
        entities = self._parse_gemini_lines(raw)
        logger.info("Gemini extracted %d entities", len(entities))
        return entities[:MAX_ENTITIES]

    def extract_with_fallback(self, text: str) -> List[str]:
        """
        Extract entities using keyword and pattern matching.

        Candidates are collected in this priority order: known tech entities
        from ``COMMON_TECH_ENTITIES`` (whole-word, case-insensitive), then
        capitalised words, then dotted names such as "Node.js".

        Args:
            text (str): Input text

        Returns:
            List[str]: Up to ``MAX_ENTITIES`` unique entities
        """
        candidates = [name for name in COMMON_TECH_ENTITIES if self._mentions(text, name)]
        candidates += [
            word for word in self._CAPITALISED_WORD_RE.findall(text)
            if self._is_valid_entity(word)
        ]
        candidates += [
            word for word in self._DOTTED_NAME_RE.findall(text)
            if self._is_valid_entity(word)
        ]
        unique_entities = self._dedupe(candidates)
        logger.info("Fallback extracted %d entities", len(unique_entities))
        return unique_entities[:MAX_ENTITIES]

    def _is_valid_entity(self, entity: str) -> bool:
        """
        Check whether *entity* is plausible as a logo/brand name.

        Args:
            entity (str): Entity name

        Returns:
            bool: True if the name is 2-50 characters long, is not a listed
            function word (case-insensitive), and contains an ASCII letter.
        """
        if not 2 <= len(entity) <= 50:
            return False
        if entity.lower() in self._INVALID_WORDS:
            return False
        return re.search(r'[a-zA-Z]', entity) is not None

    def extract_entities(self, text: str) -> List[str]:
        """
        Extract entities from text using the best available method.

        Gemini is used when a model is configured and it returns at least
        one entity. Otherwise the fallback extractor is used.

        Args:
            text (str): Input text

        Returns:
            List[str]: Extracted entities; empty for blank input
        """
        if not text or not text.strip():
            return []
        logger.info("Starting entity extraction...")
        if self.model:
            try:
                entities = self.extract_with_gemini(text)
            except Exception as e:
                logger.warning("Gemini extraction failed, using fallback: %s", e)
            else:
                if entities:
                    logger.info("Successfully extracted %d entities with Gemini", len(entities))
                    return entities
        entities = self.extract_with_fallback(text)
        logger.info("Extracted %d entities using fallback method", len(entities))
        return entities