"""
Reranker Adapter — supports BGE-Reranker-v2-m3 AND Jina-Reranker-v3

Auto-detects which model to load based on the RERANKER_MODEL setting:
  - "BAAI/bge-reranker-v2-m3"  → FlagReranker (pointwise cross-encoder)
  - "jinaai/jina-reranker-v3"  → Jina v3 listwise reranker

Jina v3 advantages over BGE for this project:
  - Listwise: sees ALL docs at once → better cross-doc comparison
  - 131K context window → reads full Jina-extracted articles (not just 512 chars)
  - +9.6% better on English news (BEIR 61.94 vs 56.51)
  - Better Arabic ranking (78.69 nDCG)
  - Same size (0.6B), same memory, same cost (free, self-hosted)

Thread-safe lazy loading — the model loads once, on the first rerank call.
"""

import logging
import threading
from typing import List, Dict, Any, Optional

from src.core.config import settings
from src.core.ports.reranker_port import RerankerPort

logger = logging.getLogger(__name__)

# ── Patch transformers compatibility issue ────────────────────────────────────
# Some dependants still look up `is_torch_fx_available`; provide a stub when the
# installed transformers no longer has it. Best-effort: never block the import.
try:
    import transformers.utils.import_utils as _tui

    if not hasattr(_tui, "is_torch_fx_available"):
        _tui.is_torch_fx_available = lambda: False
except Exception as _patch_err:
    logger.debug(f"transformers compatibility patch skipped: {_patch_err}")

# ── Try FlagEmbedding (for BGE) ───────────────────────────────────────────────
try:
    from FlagEmbedding import FlagReranker
    HAS_FLAG_RERANKER = True
except ImportError:
    HAS_FLAG_RERANKER = False

# ── Try sentence-transformers CrossEncoder (BGE fallback) ────────────────────
try:
    from sentence_transformers import CrossEncoder
    HAS_CROSS_ENCODER = True
except ImportError:
    HAS_CROSS_ENCODER = False

# ── Try transformers (for Jina v3) ────────────────────────────────────────────
try:
    import torch
    from transformers import AutoModel
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    logger.warning("transformers/torch not available — Jina v3 reranker disabled.")


# ═══════════════════════════════════════════════════════════════════════════════
# SHARED HELPERS
# ═══════════════════════════════════════════════════════════════════════════════

def _by_vector_score(docs: List[Dict[str, Any]], top_n: int) -> List[Dict[str, Any]]:
    """Return the first ``top_n`` docs ordered by their "score" key, descending.

    A missing or ``None`` score counts as 0 so the fallback itself cannot raise.
    """
    return sorted(docs, key=lambda d: d.get("score") or 0, reverse=True)[:top_n]


def _select_scorable(docs: List[Dict[str, Any]], max_chars: int):
    """Split ``docs`` into (docs with text, their truncated text).

    Docs whose "content" is missing, not a string, or blank after stripping are
    left out. The two returned lists are index-aligned.
    """
    kept: List[Dict[str, Any]] = []
    texts: List[str] = []
    for doc in docs:
        content = doc.get("content")
        if not isinstance(content, str):
            continue
        content = content.strip()
        if content:
            kept.append(doc)
            texts.append(content[:max_chars])
    return kept, texts


def _to_float_list(raw: Any) -> List[float]:
    """Normalise a backend's score output to a flat list of Python floats.

    Accepts a bare number, a list/tuple, or anything exposing ``tolist()``
    (array- or tensor-like objects, including their scalar types).
    """
    to_list = getattr(raw, "tolist", None)
    if callable(to_list):
        raw = to_list()
    if isinstance(raw, (int, float)):
        return [float(raw)]
    flat: List[float] = []
    for item in raw:
        # Tolerate one-element rows such as [[0.1], [0.9]].
        if isinstance(item, (list, tuple)) and len(item) == 1:
            item = item[0]
        flat.append(float(item))
    return flat


# ═══════════════════════════════════════════════════════════════════════════════
# JINA V3 RERANKER
# ═══════════════════════════════════════════════════════════════════════════════

class JinaV3Reranker:
    """
    Jina-Reranker-v3 self-hosted reranker.

    Wraps the model returned by ``AutoModel.from_pretrained(...,
    trust_remote_code=True)`` and calls its ``.rerank()`` method, which scores
    all documents for a query in one call (listwise).

    Scores are the ``relevance_score`` values reported by ``.rerank()``,
    returned in the caller's original document order.

    The model is loaded lazily on first use; a failed load is remembered and
    not retried.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self._model = None
        self._lock = threading.Lock()   # guards the one-time load
        self._load_failed = False       # set once; prevents repeated load attempts
        self._device = "cpu"            # informational only, see _load()

    def _load(self):
        """Load the model once. On failure set ``_load_failed`` instead of raising."""
        if self._model is not None or self._load_failed:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished loading.
            if self._model is not None or self._load_failed:
                return
            if not HAS_TRANSFORMERS:
                logger.error("transformers not installed — cannot load Jina v3")
                self._load_failed = True
                return
            try:
                logger.info(f"Loading Jina v3 reranker: {self.model_name}")
                # NOTE(review): _device is only detected and logged here; the
                # model is never explicitly moved with .to(). Confirm whether the
                # remote code handles device placement itself.
                self._device = "cuda" if torch.cuda.is_available() else "cpu"

                # Jina v3 uses AutoModel (NOT AutoModelForSequenceClassification);
                # the remote code supplies the .rerank() method used below.
                self._model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    dtype="auto",
                )
                self._model.eval()
                logger.info(
                    f"✅ Jina v3 reranker loaded on {self._device} "
                    f"(model={self.model_name})"
                )
            except Exception as e:
                logger.error(f"Failed to load Jina v3 reranker: {e}", exc_info=True)
                self._load_failed = True

    def _score_strict(self, query: str, docs: List[str]) -> List[float]:
        """Score ``docs`` against ``query``, raising on any failure.

        Raises:
            RuntimeError: if the model is not loaded.
            Exception: whatever ``.rerank()`` raises, or ``KeyError`` /
                ``IndexError`` if its result does not have the expected shape.
        """
        if self._model is None:
            raise RuntimeError("Jina v3 model is not loaded")

        # .rerank() returns a list of dicts:
        #   [{"document": str, "relevance_score": float, "index": int}, ...]
        # sorted by relevance; "index" maps each entry back to its input slot.
        results = self._model.rerank(query, docs)

        scores = [0.0] * len(docs)
        for entry in results:
            scores[entry["index"]] = float(entry["relevance_score"])
        return scores

    def compute_scores(
        self,
        query: str,
        docs: List[str],
        max_length: int = 1024,
    ) -> List[float]:
        """
        Score all (query, doc) pairs using Jina v3's built-in .rerank() method.

        Never raises. Returns one score per doc, in original doc order:
          - ``[]`` for empty input,
          - ``0.5`` for every doc when the model could not be loaded,
          - ``0.0`` for every doc when inference fails (the error is logged).

        ``max_length`` is accepted for interface compatibility but is not
        forwarded to the model.
        """
        if not docs:
            return []

        self._load()
        if self._model is None:
            return [0.5] * len(docs)

        try:
            return self._score_strict(query, docs)
        except Exception as e:
            logger.error(f"Jina v3 rerank() failed: {e}")
            return [0.0] * len(docs)

    @property
    def is_loaded(self) -> bool:
        """True once the underlying model object exists."""
        return self._model is not None


# ═══════════════════════════════════════════════════════════════════════════════
# UNIFIED RERANKER ADAPTER
# ═══════════════════════════════════════════════════════════════════════════════

class BgeRerankerAdapter(RerankerPort):
    """
    Unified reranker adapter — auto-selects BGE or Jina v3 based on config.

    RERANKER_MODEL=jinaai/jina-reranker-v3  → Jina v3 (recommended)
    RERANKER_MODEL=BAAI/bge-reranker-v2-m3  → BGE (legacy)

    When JINA_RERANKER_ENABLED and a real JINA_API_KEY are set, the Jina cloud
    API is tried first and the self-hosted model is the fallback.

    Both self-hosted models are free, ~0.6B parameters, ~1.2GB on disk.
    """

    # Max content chars to send to the reranker.
    # Jina v3: 1024 tokens ≈ 4096 chars — reads much more than BGE's 512 chars.
    MAX_CONTENT_CHARS_JINA = 4096
    MAX_CONTENT_CHARS_BGE = 512

    def __init__(self):
        self.model_name = settings.RERANKER_MODEL
        self._is_jina_v3 = "jina-reranker-v3" in self.model_name.lower()
        self._lock = threading.Lock()   # guards the one-time BGE load
        self._load_failed = False       # BGE load failed; do not retry
        self._bge_model = None
        self._use_flag = False          # True when _bge_model is a FlagReranker

        # Jina API reranker (takes priority over self-hosted when configured).
        self._jina_api = None
        jina_key = getattr(settings, "JINA_API_KEY", "")
        api_enabled = getattr(settings, "JINA_RERANKER_ENABLED", False)
        if api_enabled and jina_key and jina_key != "your-jina-api-key-here":
            try:
                from src.infrastructure.adapters.jina_reranker_adapter import JinaRerankerAPIAdapter

                self._jina_api = JinaRerankerAPIAdapter(
                    api_key=jina_key,
                    model=getattr(settings, "JINA_RERANKER_MODEL", "jina-reranker-v3"),
                    timeout=getattr(settings, "JINA_RERANKER_TIMEOUT", 5.0),
                )
                logger.info("Reranker configured: Jina API (cloud, fast)")
            except Exception as e:
                logger.warning(f"Jina API reranker init failed: {e}")

        # Always create the (lazy, cheap) self-hosted wrapper for a Jina v3 model
        # name, even when the API is configured: if the API is unavailable the
        # fallback must be Jina v3, not the BGE loaders with a Jina checkpoint.
        self._jina = JinaV3Reranker(self.model_name) if self._is_jina_v3 else None

        if not self._jina_api:
            if self._is_jina_v3:
                logger.info(f"Reranker configured: Jina v3 self-hosted ({self.model_name})")
            else:
                logger.info(f"Reranker configured: BGE ({self.model_name})")

    def _load_bge(self):
        """Lazy-load the BGE reranker (thread-safe). Sets ``_load_failed`` on error."""
        if self._bge_model is not None or self._load_failed:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished loading.
            if self._bge_model is not None or self._load_failed:
                return
            logger.info(f"Loading BGE reranker: {self.model_name}")
            try:
                if HAS_FLAG_RERANKER and "bge-reranker" in self.model_name.lower():
                    # Patch XLMRobertaTokenizer for older transformers versions:
                    # copy any of these methods it lacks from PreTrainedTokenizer.
                    try:
                        from transformers import XLMRobertaTokenizer, PreTrainedTokenizer

                        for method_name in (
                            "prepare_for_model",
                            "build_inputs_with_special_tokens",
                            "create_token_type_ids_from_sequences",
                            "get_special_tokens_mask",
                            "convert_tokens_to_string",
                        ):
                            if hasattr(XLMRobertaTokenizer, method_name):
                                continue
                            base_method = getattr(PreTrainedTokenizer, method_name, None)
                            if base_method:
                                setattr(XLMRobertaTokenizer, method_name, base_method)
                    except Exception as patch_err:
                        logger.debug(f"Tokenizer patch skipped: {patch_err}")

                    self._bge_model = FlagReranker(
                        self.model_name,
                        use_fp16=True,
                        normalize=True,
                        trust_remote_code=True,
                    )
                    self._use_flag = True
                    logger.info("✅ BGE loaded via FlagReranker (fp16, multilingual)")
                elif HAS_CROSS_ENCODER:
                    self._bge_model = CrossEncoder(self.model_name)
                    self._use_flag = False
                    logger.info("✅ BGE loaded via CrossEncoder (fallback)")
                else:
                    logger.error("No BGE backend available (FlagEmbedding or sentence-transformers required)")
                    self._load_failed = True
            except Exception as e:
                logger.error(f"Failed to load BGE reranker '{self.model_name}': {e}", exc_info=True)
                self._load_failed = True

    # ── Public interface ──────────────────────────────────────────────────────

    def rerank(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int = 5,
    ) -> List[Dict[str, Any]]:
        """
        Rerank documents by relevance to query.

        Priority: Jina API (cloud) > Jina v3 self-hosted > BGE

        Jina v3 path: uses article content up to MAX_CONTENT_CHARS_JINA chars.
        BGE path: uses the first MAX_CONTENT_CHARS_BGE chars only.

        If the Jina API reports itself unavailable, or raises, the call falls
        through to the self-hosted model selected by RERANKER_MODEL.

        Args:
            query: the search query.
            docs: dicts with a "content" string and, for the fallback ordering,
                an optional numeric "score". Scored docs are mutated in place:
                a "rerank_score" key is added.
            top_n: maximum number of docs to return.

        Returns:
            At most ``top_n`` docs, best first. Empty list for empty input.
        """
        if not docs:
            return []

        if self._jina_api:
            try:
                if self._jina_api.is_available():
                    return self._jina_api.rerank(query, docs, top_n)
            except Exception as e:
                logger.warning(f"Jina API reranker failed: {e} — falling back to self-hosted reranker")

        if self._is_jina_v3 and self._jina:
            return self._rerank_jina(query, docs, top_n)
        return self._rerank_bge(query, docs, top_n)

    def _rerank_jina(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Rerank using self-hosted Jina v3.

        Falls back to vector-score ordering when the model cannot be loaded or
        inference fails. Returns ``[]`` when no doc has usable content.
        """
        self._jina._load()
        if self._jina._load_failed or not self._jina.is_loaded:
            logger.warning("Jina v3 unavailable — falling back to vector score ordering")
            return _by_vector_score(docs, top_n)

        valid_docs, doc_texts = _select_scorable(docs, self.MAX_CONTENT_CHARS_JINA)
        if not doc_texts:
            return []

        try:
            # Use the raising variant: compute_scores() hides failures behind
            # all-zero scores, which would make the fallback below unreachable.
            scores = self._jina._score_strict(query, doc_texts)
            for doc, score in zip(valid_docs, scores):
                doc["rerank_score"] = score
            valid_docs.sort(key=lambda x: x["rerank_score"], reverse=True)
            logger.info(
                f"[Reranker] Jina v3: {len(valid_docs)} docs → top {top_n} "
                f"(max_score={valid_docs[0]['rerank_score']:.3f})"
            )
            return valid_docs[:top_n]
        except Exception as e:
            logger.error(f"Jina v3 reranking failed: {e} — falling back to vector score")
            return _by_vector_score(docs, top_n)

    def _rerank_bge(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Rerank using BGE — reads the first MAX_CONTENT_CHARS_BGE chars only.

        Falls back to vector-score ordering when no backend could be loaded or
        scoring fails. Returns ``[]`` when no doc has usable content.
        """
        if self._bge_model is None:
            self._load_bge()
        if self._bge_model is None:
            logger.warning("BGE unavailable — falling back to vector score ordering")
            return _by_vector_score(docs, top_n)

        valid_docs, texts = _select_scorable(docs, self.MAX_CONTENT_CHARS_BGE)
        if not valid_docs:
            return []
        pairs = [[query, text] for text in texts]

        try:
            if self._use_flag:
                raw_scores = self._bge_model.compute_score(pairs, batch_size=64)
            else:
                raw_scores = self._bge_model.predict(pairs)

            # Backends may hand back a bare scalar for a single pair.
            scores = _to_float_list(raw_scores)
            # Validate before mutating any doc so a failure leaves docs untouched.
            if len(scores) != len(valid_docs):
                raise ValueError(
                    f"expected {len(valid_docs)} scores, got {len(scores)}"
                )

            for doc, score in zip(valid_docs, scores):
                doc["rerank_score"] = score

            valid_docs.sort(key=lambda x: x["rerank_score"], reverse=True)
            logger.info(
                f"[Reranker] BGE: {len(valid_docs)} docs → top {top_n} "
                f"(max_score={valid_docs[0]['rerank_score']:.3f})"
            )
            return valid_docs[:top_n]
        except Exception as e:
            logger.error(f"BGE reranking failed: {e} — falling back to vector score")
            return _by_vector_score(docs, top_n)

    @property
    def model_type(self) -> str:
        """"jina_v3" when RERANKER_MODEL names jina-reranker-v3, otherwise "bge"."""
        return "jina_v3" if self._is_jina_v3 else "bge"