Spaces:
Sleeping
Sleeping
surbi karki committed on
Update recommender_core.py
Browse files- recommender_core.py +176 -166
recommender_core.py
CHANGED
|
@@ -1,166 +1,176 @@
|
|
| 1 |
-
import pandas as pd
|
| 2 |
-
import joblib
|
| 3 |
-
import os
|
| 4 |
-
import time
|
| 5 |
-
from sqlalchemy import create_engine
|
| 6 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 7 |
-
from urllib.parse import quote_plus
|
| 8 |
-
from text_utils import TextProcessor
|
| 9 |
-
from functools import lru_cache
|
| 10 |
-
|
| 11 |
-
# --- CONFIGURATION ---
|
| 12 |
-
# For cloud deployment (HF/Production), use DATABASE_URL.
|
| 13 |
-
# Fallback to local construction if not present.
|
| 14 |
-
DATABASE_URL = os.getenv("DATABASE_URL")
|
| 15 |
-
if not DATABASE_URL:
|
| 16 |
-
DB_USER = os.getenv("DB_USER", "postgres")
|
| 17 |
-
DB_PASSWORD = quote_plus(os.getenv("DB_PASSWORD", "subisu"))
|
| 18 |
-
DB_HOST = os.getenv("DB_HOST", "localhost")
|
| 19 |
-
DB_PORT = os.getenv("DB_PORT", "5432")
|
| 20 |
-
DB_NAME = os.getenv("DB_NAME", "ppd_project_db")
|
| 21 |
-
DB_URI = f'postgresql+psycopg2://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
| 22 |
-
else:
|
| 23 |
-
# Ensure URL is compatible with SQLAlchemy if it starts with postgres://
|
| 24 |
-
if DATABASE_URL.startswith("postgres://"):
|
| 25 |
-
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql+psycopg2://", 1)
|
| 26 |
-
elif "postgresql://" in DATABASE_URL and "+psycopg2" not in DATABASE_URL:
|
| 27 |
-
DATABASE_URL = DATABASE_URL.replace("postgresql://", "postgresql+psycopg2://", 1)
|
| 28 |
-
DB_URI = DATABASE_URL
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class RecommenderCore:
|
| 32 |
-
def __init__(self):
|
| 33 |
-
self.engine = create_engine(DB_URI)
|
| 34 |
-
self.vectorizer = None
|
| 35 |
-
self.tfidf_matrix = None
|
| 36 |
-
self.df = None
|
| 37 |
-
self.load_model()
|
| 38 |
-
|
| 39 |
-
def load_model(self):
|
| 40 |
-
try:
|
| 41 |
-
if os.path.exists('vectorizer.pkl') and os.path.exists('tfidf_matrix.pkl'):
|
| 42 |
-
self.vectorizer = joblib.load('vectorizer.pkl')
|
| 43 |
-
self.tfidf_matrix = joblib.load('tfidf_matrix.pkl')
|
| 44 |
-
print("💾 Model Loaded into Memory.")
|
| 45 |
-
|
| 46 |
-
self.df = pd.read_sql("SELECT * FROM articles WHERE status = 'Approved' ORDER BY article_id", self.engine)
|
| 47 |
-
self.df = self.df.reset_index(drop=True)
|
| 48 |
-
print(f"📚 Indexed {len(self.df)} articles.")
|
| 49 |
-
except Exception as e:
|
| 50 |
-
print(f"Load Error: {e}")
|
| 51 |
-
|
| 52 |
-
@lru_cache(maxsize=128)
|
| 53 |
-
def recommend_articles(self, symptoms_text, crisis_level, top_n=5):
|
| 54 |
-
"""Modular requirement: Main entry point with caching."""
|
| 55 |
-
if self.df is None or self.vectorizer is None:
|
| 56 |
-
return []
|
| 57 |
-
|
| 58 |
-
# 1. Preprocess user query
|
| 59 |
-
query_raw = symptoms_text
|
| 60 |
-
query_norm = TextProcessor.normalize(symptoms_text)
|
| 61 |
-
query_phased = TextProcessor.detect_phrases(query_norm)
|
| 62 |
-
|
| 63 |
-
# 2. Filter by Crisis Level (Safety First)
|
| 64 |
-
risk_map = {
|
| 65 |
-
"High": ["High", "Critical", "Moderate", "All"],
|
| 66 |
-
"Moderate": ["Moderate", "Low", "All"],
|
| 67 |
-
"Low": ["Low", "All"]
|
| 68 |
-
}
|
| 69 |
-
allowed = risk_map.get(crisis_level, ["All"])
|
| 70 |
-
|
| 71 |
-
# Determine the filtered subset
|
| 72 |
-
mask = self.df['risk_level'].apply(
|
| 73 |
-
lambda x: any(level.strip() in allowed for level in str(x).split(','))
|
| 74 |
-
)
|
| 75 |
-
filtered_df = self.df[mask].copy()
|
| 76 |
-
|
| 77 |
-
if filtered_df.empty: return []
|
| 78 |
-
|
| 79 |
-
# 3. Primary ML Scoring (Cosine Similarity)
|
| 80 |
-
user_vec = self.vectorizer.transform([query_phased])
|
| 81 |
-
all_cos_scores = cosine_similarity(user_vec, self.tfidf_matrix).flatten()
|
| 82 |
-
|
| 83 |
-
# 4. Final Ranking
|
| 84 |
-
# Correctly align scores using the original dataframe's index
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
score
|
| 146 |
-
|
| 147 |
-
#
|
| 148 |
-
if row['format_type'] == '
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import joblib
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from sqlalchemy import create_engine
|
| 6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 7 |
+
from urllib.parse import quote_plus
|
| 8 |
+
from text_utils import TextProcessor
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
|
| 11 |
+
# --- CONFIGURATION ---
|
| 12 |
+
# For cloud deployment (HF/Production), use DATABASE_URL.
|
| 13 |
+
# Fallback to local construction if not present.
|
| 14 |
+
def normalize_db_url(url):
    """Return *url* with a SQLAlchemy/psycopg2-compatible PostgreSQL scheme.

    Hosting providers commonly hand out ``postgres://`` or bare
    ``postgresql://`` URLs. Only the scheme prefix is inspected and rewritten
    to ``postgresql+psycopg2://``; the rest of the URL (credentials, host,
    query string) is never touched. URLs with any other scheme, including
    ones that already name a driver, are returned unchanged.
    """
    for legacy_scheme in ("postgres://", "postgresql://"):
        if url.startswith(legacy_scheme):
            return "postgresql+psycopg2://" + url[len(legacy_scheme):]
    return url


# DATABASE_URL (cloud deployment) takes precedence; otherwise the URI is
# assembled from the individual DB_* variables with local-development defaults.
DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    DB_USER = os.getenv("DB_USER", "postgres")
    # NOTE(review): a literal fallback password lives in source control here;
    # set DB_PASSWORD explicitly in every shared or deployed environment.
    DB_PASSWORD = quote_plus(os.getenv("DB_PASSWORD", "subisu"))
    DB_HOST = os.getenv("DB_HOST", "localhost")
    DB_PORT = os.getenv("DB_PORT", "5432")
    DB_NAME = os.getenv("DB_NAME", "ppd_project_db")
    # The user name is escaped like the password so characters such as '@' or
    # ':' cannot break the URI; DB_USER itself keeps its raw value.
    DB_URI = (
        f'postgresql+psycopg2://{quote_plus(DB_USER)}:{DB_PASSWORD}'
        f'@{DB_HOST}:{DB_PORT}/{DB_NAME}'
    )
else:
    DATABASE_URL = normalize_db_url(DATABASE_URL)
    DB_URI = DATABASE_URL
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class RecommenderCore:
    """TF-IDF based article recommender over the approved rows of ``articles``.

    The instance holds a fitted vectorizer and document matrix loaded from
    ``vectorizer.pkl`` / ``tfidf_matrix.pkl`` plus a DataFrame of approved
    articles. Scores are aligned to articles by position: row ``i`` of the
    DataFrame is scored with entry ``i`` of the similarity vector.
    """

    # Number of distinct (symptoms_text, crisis_level, top_n) queries memoised
    # per instance.
    _CACHE_SIZE = 128
    # When fewer results than this are found, the live PubMed fallback is tried.
    _MIN_RESULTS = 3

    def __init__(self):
        """Create the database engine and load the model and article index."""
        self.engine = create_engine(DB_URI)
        self.vectorizer = None
        self.tfidf_matrix = None
        self.df = None
        self.load_model()

    def load_model(self):
        """Load the pickled model (if present) and the approved articles.

        Best-effort: any failure is printed and leaves the affected attributes
        at their previous values, in which case ``recommend_articles`` returns
        an empty list. Memoised recommendations are always discarded, because
        they may have been computed from data that has just been replaced.
        """
        try:
            if os.path.exists('vectorizer.pkl') and os.path.exists('tfidf_matrix.pkl'):
                self.vectorizer = joblib.load('vectorizer.pkl')
                self.tfidf_matrix = joblib.load('tfidf_matrix.pkl')
                print("💾 Model Loaded into Memory.")

            self.df = pd.read_sql("SELECT * FROM articles WHERE status = 'Approved' ORDER BY article_id", self.engine)
            self.df = self.df.reset_index(drop=True)
            print(f"📚 Indexed {len(self.df)} articles.")
        except Exception as e:
            print(f"Load Error: {e}")
        finally:
            self.clear_cache()

    def clear_cache(self):
        """Discard every memoised recommendation held by this instance."""
        cached = self.__dict__.get('_recommend_cache')
        if cached is not None:
            cached.cache_clear()

    def recommend_articles(self, symptoms_text, crisis_level, top_n=5):
        """Return up to ``top_n`` recommended articles as a list of dicts.

        Results are memoised per instance, keyed on the three arguments (which
        must therefore be hashable). Each call returns fresh list and dict
        objects, so callers may mutate the result without corrupting the
        cache. The "model not loaded" case is answered with ``[]`` and is
        deliberately not memoised.
        """
        if self.df is None or self.vectorizer is None or self.tfidf_matrix is None:
            return []

        # The cache lives on the instance (not on the class-level function) so
        # it can be cleared on reload and does not pin the instance globally.
        cached = self.__dict__.get('_recommend_cache')
        if cached is None:
            cached = lru_cache(maxsize=self._CACHE_SIZE)(self._compute_recommendations)
            self._recommend_cache = cached

        results = cached(symptoms_text, crisis_level, top_n)
        return [dict(item) for item in results]

    def _compute_recommendations(self, symptoms_text, crisis_level, top_n):
        """Filter, score, rank and format articles for one query (uncached)."""
        # 1. Preprocess user query
        query_raw = symptoms_text
        query_norm = TextProcessor.normalize(symptoms_text)
        query_phased = TextProcessor.detect_phrases(query_norm)

        # 2. Filter by crisis level: an article qualifies when any of its
        #    comma-separated risk levels is allowed for the user's level.
        risk_map = {
            "High": ["High", "Critical", "Moderate", "All"],
            "Moderate": ["Moderate", "Low", "All"],
            "Low": ["Low", "All"]
        }
        allowed = risk_map.get(crisis_level, ["All"])
        mask = self.df['risk_level'].apply(
            lambda x: any(level.strip() in allowed for level in str(x).split(','))
        ).astype(bool)
        filtered_df = self.df[mask].copy()

        if filtered_df.empty:
            return []

        # 3. Primary ML scoring (cosine similarity against every document)
        user_vec = self.vectorizer.transform([query_phased])
        all_cos_scores = cosine_similarity(user_vec, self.tfidf_matrix).flatten()

        # 4. Align scores by positional index. Articles beyond the end of the
        #    score vector (matrix and table out of sync) score 0.0 instead of
        #    raising IndexError.
        n_scores = len(all_cos_scores)
        filtered_df['cosine_score'] = [
            all_cos_scores[i] if i < n_scores else 0.0 for i in filtered_df.index
        ]

        ranked_results = self.apply_ranking(filtered_df, query_raw)
        final_list = ranked_results.head(top_n).to_dict('records')

        # 5. Live fallback: too few local hits, so fetch fresh content.
        if len(final_list) < self._MIN_RESULTS:
            try:
                from ingestion_service import IngestionService
                service = IngestionService()
                live_arts = service.fetch_from_pubmed(query_raw, limit=self._MIN_RESULTS)
                for art in live_arts:
                    if len(final_list) >= top_n:
                        break
                    final_list.append({
                        "article_id": -1,
                        "title": art['title'],
                        "category": "Live Fallback",
                        "format_type": "pubmed",
                        "external_url": art['url'],
                        "content": art['content'],
                        "risk_level": "All"
                    })
                # Persist what was fetched so later queries can find it locally.
                if live_arts:
                    service.store_articles(live_arts)
            except Exception as e:
                # Best-effort: a failed fallback must not fail the request.
                print(f"Fallback error: {e}")

        for item in final_list:
            item['access_type'] = 'External Link' if item.get('format_type') == 'pubmed' else 'Direct Text'
            if item.get('created_at'):
                item['created_at'] = str(item['created_at'])

        return final_list

    def apply_ranking(self, df, raw_query):
        """Add a ``final_score`` column to ``df`` and return it sorted descending.

        The score starts from ``cosine_score`` and applies three adjustments:
        a multiplicative boost for ``format_type == 'text'`` rows, an additive
        boost per query token that appears as a whole word in the normalized
        title, and a linearly decaying recency boost for ``pubmed`` rows.
        ``df`` is modified in place (the new column) as well as returned.
        """
        SOURCE_WEIGHT = 1.15  # 15% boost for contributor articles
        EXACT_MATCH_BOOST = 0.2
        RECENCY_MAX_BOOST = 0.1
        RECENCY_WINDOW_DAYS = 365

        tokens = TextProcessor.normalize(raw_query).split()
        now = pd.Timestamp.now()

        def calculate_hybrid_score(row):
            score = row['cosine_score']

            # A. Source weighting (trusted contributors)
            if row['format_type'] == 'text':
                score *= SOURCE_WEIGHT

            # B. Exact overlap boost: compare whole words, not substrings, so
            #    a short token such as "in" does not match inside "pain".
            title = row['title']
            title_words = set(TextProcessor.normalize(title).split()) if isinstance(title, str) else set()
            matches = sum(1 for t in tokens if t in title_words)
            score += matches * EXACT_MATCH_BOOST

            # C. Recency boost (PubMed only): RECENCY_MAX_BOOST for a brand-new
            #    article, decaying linearly to 0 over RECENCY_WINDOW_DAYS.
            if row['format_type'] == 'pubmed':
                created = pd.to_datetime(row.get('created_at'), errors='coerce')
                if pd.notna(created):
                    # Compare like with like: aware timestamps need an aware "now".
                    reference = now if created.tzinfo is None else pd.Timestamp.now(tz=created.tzinfo)
                    # Clamp so future-dated rows cannot exceed the maximum boost.
                    age_days = min(max((reference - created).days, 0), RECENCY_WINDOW_DAYS)
                    score += RECENCY_MAX_BOOST * (1 - age_days / RECENCY_WINDOW_DAYS)

            return score

        # A list built from records also works for an empty frame, where
        # DataFrame.apply(axis=1) would not yield an assignable column.
        df['final_score'] = [calculate_hybrid_score(row) for row in df.to_dict('records')]
        return df.sort_values(by='final_score', ascending=False)

    def get_article_by_id(self, article_id):
        """Return the approved article with ``article_id`` as a dict, or ``None``."""
        if self.df is None:
            return None
        article = self.df[self.df['article_id'] == article_id]
        return article.iloc[0].to_dict() if not article.empty else None
|
| 174 |
+
|
| 175 |
+
# Singleton instance to be used by main.py
|
| 176 |
+
# Built at import time: the constructor creates the SQLAlchemy engine and calls
# load_model(), so importing this module reads the model pickles and queries
# the articles table. load_model() prints and swallows its own errors.
recommender = RecommenderCore()
|