# Space2 / app.py
# Uploaded by Hakim18 ("Upload 10 files"), revision 11133f2, 7.75 kB.
import os
import random
import torch
import pandas as pd
from flask import Flask, render_template, request, jsonify
from sentence_transformers import SentenceTransformer, util
import gradio as gr
import uvicorn
import nest_asyncio
from fastapi import FastAPI
from fastapi.middleware.wsgi import WSGIMiddleware
# ==============================
# CONFIG
# ==============================
# Directory containing this file. Every on-disk resource is resolved against
# it so the app behaves the same whatever the process working directory is
# (the Flask template/static folders are anchored to it as well).
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
# Q/A dataset (CSV) and the on-disk cache of the question embeddings.
CSV_DATA = os.path.join(BASE_DIR, "dataset_2026.csv")
EMB_FILE = os.path.join(BASE_DIR, "embeddings_questions.pt")
# Maximum number of alternative questions suggested on a medium-confidence match.
TOP_K_RECOMMANDATIONS = 5
# ==============================
# FLASK APP
# ==============================
# Flask application for the HTML front end; its templates and static assets
# are looked up in folders located next to this file.
_templates_dir = os.path.join(BASE_DIR, "templates")
_static_dir = os.path.join(BASE_DIR, "static")
app = Flask(__name__, template_folder=_templates_dir, static_folder=_static_dir)
# ==============================
# MODEL
# ==============================
def _load_model():
    """Return the sentence-embedding model, loaded on CPU.

    The primary model is tried first; if loading it raises for any reason,
    the failure is printed and the multilingual MiniLM fallback is loaded
    instead. An error while loading the fallback propagates to the caller.
    """
    try:
        # NOTE(review): trust_remote_code=True presumably lets the model
        # repository run its own code at load time - confirm this is intended.
        primary = SentenceTransformer(
            "OrdalieTech/Solon-embeddings-mini-beta-1.1",
            device="cpu",
            trust_remote_code=True,
        )
        print("✓ Modèle principal chargé")
        return primary
    except Exception as e:
        print("⚠️ Modèle principal échoué:", e)
        fallback = SentenceTransformer(
            "paraphrase-multilingual-MiniLM-L12-v2",
            device="cpu",
        )
        print("✓ Modèle fallback chargé")
        return fallback


print("🔄 Chargement du modèle...")
model = _load_model()
# ==============================
# GLOBAL CACHE (IMPORTANT FIX)
# ==============================
# Module-level cache slots: `df` is assigned by load_data() and `embeddings`
# by load_embeddings(); both start out empty.
df = embeddings = None
# ==============================
# DATA LOADING (ROBUST FIX)
# ==============================
def load_data():
    """Return the Q/A dataset, reading it from disk only on the first call.

    The CSV at ``CSV_DATA`` is parsed with automatic separator detection and
    its column names are stripped of surrounding whitespace. The result is
    stored in the module-level ``df`` and reused by later calls, so the file
    is not re-parsed on every request and the rows stay aligned with the
    cached embeddings.

    If the file does not exist, a three-row starter dataset is created,
    written to ``CSV_DATA`` and returned.

    Returns:
        pandas.DataFrame: the loaded (or freshly created) dataset.
    """
    global df
    if df is not None:
        return df
    try:
        # sep=None + engine="python" lets pandas sniff the delimiter.
        loaded = pd.read_csv(CSV_DATA, sep=None, engine="python")
    except FileNotFoundError:
        print("❌ Dataset introuvable → création...")
        df = pd.DataFrame({
            "Question": ["Bonjour", "Comment ça va?", "Qu'est-ce que c'est?"],
            "Response": [
                "Bonjour! Comment puis-je vous aider?",
                "Je vais bien, merci!",
                "C'est une application Q/A"
            ],
            "Intent": ["salutation", "conversation", "information"]
        })
        df.to_csv(CSV_DATA, index=False)
        return df
    # Normalise headers such as " Intent " before anything looks columns up.
    loaded.columns = loaded.columns.str.strip()
    print(f"✓ Données chargées: {len(loaded)} lignes")
    print("📌 Colonnes:", loaded.columns.tolist())
    # Publish to the cache only once the frame is fully prepared.
    df = loaded
    return df
# ==============================
# EMBEDDINGS (CACHE FIX)
# ==============================
def load_embeddings():
    """Return the embeddings of every dataset question, computing them if needed.

    Lookup order:
      1. the in-memory ``embeddings`` cache;
      2. the tensor saved in ``EMB_FILE``, used only if it is a 2-D tensor
         with one row per dataset question and the width the loaded model
         reports;
      3. a fresh encoding of the question column, which is then saved to
         ``EMB_FILE``.

    The dataset is loaded through ``load_data()`` when it has not been loaded
    yet, and the question column is resolved with ``get_column`` so header
    case does not matter.

    Returns:
        torch.Tensor: normalised question embeddings, one row per question.
    """
    global embeddings, df
    if embeddings is not None:
        return embeddings
    if df is None:
        df = load_data()
    questions = df[get_column(df, "Question")].astype(str).tolist()
    if os.path.exists(EMB_FILE):
        print("📂 Chargement embeddings...")
        # NOTE(security): depending on the installed torch version,
        # torch.load may unpickle arbitrary objects - only load a file that
        # this application wrote itself.
        cached = torch.load(EMB_FILE, map_location="cpu")
        expected_dim = model.get_sentence_embedding_dimension()
        usable = (
            isinstance(cached, torch.Tensor)
            and cached.dim() == 2
            and cached.shape[0] == len(questions)
            and (expected_dim is None or cached.shape[1] == expected_dim)
        )
        if usable:
            embeddings = cached
            return embeddings
        # A shape mismatch means the file was built from another dataset or
        # another model; using it would return wrong rows or fail at scoring.
        # (Edits that keep the row count cannot be detected this way.)
        print("⚠️ Embeddings obsolètes → régénération...")
    print("🔨 Création embeddings...")
    embeddings = model.encode(
        questions,
        convert_to_tensor=True,
        normalize_embeddings=True,
        show_progress_bar=True
    )
    torch.save(embeddings, EMB_FILE)
    print("✓ Embeddings sauvegardés")
    return embeddings
# ==============================
# UTILS
# ==============================
def enrich_message(text):
    """Return *text* preceded by a randomly picked opener and a single space."""
    opener = random.choice((
        "Bonne question 🙂",
        "Voici la réponse :",
        "Intéressant !",
        "D'après mes données :",
        "Réponse :",
        "🤖",
    ))
    return "{} {}".format(opener, text)
def get_column(df, name):
    """Return the actual column label of *df* that matches *name*.

    Matching ignores letter case and surrounding whitespace, so "Intent",
    "intent" and " Intent " all resolve to the same column. Labels are
    compared through ``str()`` so non-string labels cannot break the lookup.

    Args:
        df: object exposing a ``columns`` iterable (a pandas DataFrame).
        name: wanted column name.

    Returns:
        The first matching label exactly as stored in ``df.columns``.

    Raises:
        KeyError: if no column matches *name*.
    """
    wanted = str(name).strip().casefold()
    for col in df.columns:
        if str(col).strip().casefold() == wanted:
            return col
    raise KeyError(f"Column '{name}' not found. Available: {df.columns.tolist()}")
# ==============================
# CORE LOGIC (FIXED)
# ==============================
def _result(response, confidence, matched, intent, recs=()):
    """Assemble the payload shape shared by every outcome of process_question."""
    return {
        "response": response,
        "confidence": confidence,
        "matched": matched,
        "intent": intent,
        "recs": list(recs),
    }


def process_question(question):
    """Answer *question* by semantic similarity against the Q/A dataset.

    The question is embedded and compared (cosine similarity) with the
    embeddings of every dataset question. The best score, expressed as an
    integer percentage, selects the outcome:

      * below 40: "not found" message, no match reported;
      * 40 to 79: the best match is reported together with up to
        ``TOP_K_RECOMMANDATIONS`` other close questions as suggestions;
      * 80 and above: the stored response, decorated by ``enrich_message``.

    Args:
        question: user input. Anything that is not a non-blank string is
            rejected with the "invalid question" payload.

    Returns:
        dict with keys ``response``, ``confidence``, ``matched``, ``intent``
        and ``recs``. Values taken from the dataset are converted to ``str``
        so the payload is always JSON-serialisable. Any exception raised
        while answering is printed and turned into an "Erreur technique."
        payload instead of propagating.
    """
    global df, embeddings
    # Non-string input (e.g. a JSON number) has no .strip(); reject it here
    # rather than letting an AttributeError escape.
    if not isinstance(question, str) or not question.strip():
        return _result("Veuillez poser une question valide.", 0, "—", "Invalid")
    try:
        df = load_data()
        embeddings = load_embeddings()
        q_col = get_column(df, "Question")
        r_col = get_column(df, "Response")
        i_col = get_column(df, "Intent")
        query_vec = model.encode(
            question,
            convert_to_tensor=True,
            normalize_embeddings=True
        )
        scores = util.pytorch_cos_sim(query_vec, embeddings)[0]
        best_idx = torch.argmax(scores).item()
        confidence = int(scores[best_idx].item() * 100)

        # Low confidence: nothing close enough.
        if confidence < 40:
            return _result(
                "Désolé, je n'ai pas trouvé de réponse.",
                confidence, "—", "Not found",
            )

        matched = str(df[q_col].iloc[best_idx])
        intent = str(df[i_col].iloc[best_idx])

        # Medium confidence: offer the nearest other questions.
        if confidence < 80:
            # Ask for one extra hit because the best match itself is dropped.
            k = min(TOP_K_RECOMMANDATIONS + 1, len(scores))
            nearest = torch.topk(scores, k).indices.tolist()
            recs = [
                str(df[q_col].iloc[i])
                for i in nearest
                if i != best_idx
            ][:TOP_K_RECOMMANDATIONS]
            return _result(
                "Je ne suis pas sûr. Voulez-vous dire :",
                confidence, matched, intent, recs,
            )

        # High confidence: return the stored answer.
        answer = df[r_col].iloc[best_idx]
        return _result(enrich_message(answer), confidence, matched, intent)
    except Exception as e:
        # Top-level boundary for the chat endpoints: report, never raise.
        print("❌ Erreur:", e)
        return _result("Erreur technique.", 0, "—", "Error")
# ==============================
# FLASK ROUTES
# ==============================
@app.route("/")
def index():
    """Serve the front-end page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route("/ask", methods=["POST"])
def ask():
    """JSON endpoint: answer the ``question`` field of the posted body.

    A missing, malformed or non-object JSON body is treated as an empty
    question, so the client gets the regular "invalid question" payload
    instead of a generic server error.

    Returns:
        A JSON response with the keys ``response``, ``confidence``,
        ``matched``, ``intent`` and ``recs``. The same shape is used on
        unexpected failures so the front end can always read every key.
    """
    try:
        # silent=True returns None instead of raising on a bad or missing body.
        payload = request.get_json(silent=True)
        question = payload.get("question", "") if isinstance(payload, dict) else ""
        return jsonify(process_question(question))
    except Exception as e:
        print(e)
        return jsonify({
            "response": "Erreur serveur",
            "confidence": 0,
            "matched": "—",
            "intent": "Error",
            "recs": [],
        })
# ==============================
# GRADIO
# ==============================
def gradio_chat(message, history):
    """Chat callback for Gradio: return the reply text for *message*.

    *history* is accepted but not used.
    """
    result = process_question(message)
    return result["response"]
# Gradio chat UI; each user message is answered by gradio_chat.
_chat_options = {
    "title": "AskLaQ Assistant",
    "description": "Posez vos questions",
}
iface = gr.ChatInterface(fn=gradio_chat, **_chat_options)
# ==============================
# FASTAPI WRAPPER
# ==============================
# ASGI application combining both front ends.
# Mounts are matched in registration order and a mount at "/" matches every
# path, so the Gradio UI has to be registered at /chat BEFORE the Flask
# catch-all; otherwise /chat requests are handed to Flask and never reach
# Gradio.
fastapi_app = FastAPI(title="AskLaQ API")
fastapi_app = gr.mount_gradio_app(fastapi_app, iface, path="/chat")
fastapi_app.mount("/", WSGIMiddleware(app))
# ==============================
# MAIN
# ==============================
if __name__ == "__main__":
    # NOTE(review): nest_asyncio is presumably applied so the server can also
    # be started from an already-running event loop (e.g. a notebook) - confirm.
    nest_asyncio.apply()
    banner = "=" * 60
    print(banner)
    print("🚀 ASKLAQ SYSTEM (ROBUST VERSION)")
    print(banner)
    # Listen on all interfaces, port 7860.
    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860, log_level="info")