Spaces:
Sleeping
Sleeping
| """SQLite database layer for paper caching. | |
| Provides a lightweight persistence layer using SQLAlchemy with SQLite. | |
| Papers are cached locally so repeated queries are fast and the corpus | |
| grows over time for trend detection and citation analysis. | |
| """ | |
| import json | |
| import os | |
| import threading | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from sqlalchemy import ( | |
| Column, | |
| DateTime, | |
| Float, | |
| Integer, | |
| String, | |
| Text, | |
| UniqueConstraint, | |
| create_engine, | |
| event, | |
| tuple_, | |
| ) | |
| from sqlalchemy.orm import Session, declarative_base, sessionmaker | |
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()

# Directory that holds the SQLite cache file. The RESEARCH_MCP_CACHE_DIR
# environment variable overrides the default location under the user's home.
_CACHE_DIR = Path(
    os.getenv(
        "RESEARCH_MCP_CACHE_DIR",
        os.path.expanduser("~/.research-papers-mcp"),
    )
)
class Paper(Base):
    """ORM model for one cached paper.

    A row is unique per ``(source, source_id)`` pair. Topics are persisted as
    JSON text in ``topics_json`` and read/written through the ``topics``
    property.
    """

    __tablename__ = "papers"
    __table_args__ = (
        UniqueConstraint("source", "source_id", name="uq_source_paper"),
    )

    id = Column(Integer, primary_key=True)
    title = Column(String, nullable=False)
    abstract = Column(Text)
    authors = Column(String)
    publication_date = Column(DateTime)
    source = Column(String, nullable=False)
    source_id = Column(String, nullable=False)
    url = Column(String)
    doi = Column(String)
    topics_json = Column(Text, default="[]")
    citation_count = Column(Integer)
    influential_citation_count = Column(Integer)
    impact_score = Column(Float)
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )

    def _decode_topics(self, convert):
        """Parse ``topics_json`` and pass the result through *convert*.

        Returns an empty list when the stored text is not valid JSON.
        """
        try:
            raw = json.loads(self.topics_json or "[]")
        except (json.JSONDecodeError, TypeError):
            return []
        return convert(raw)

    # ``topics`` must be a property: ``to_dict`` reads ``self.topics`` as a
    # value and callers assign ``paper.topics = [...]``. Without the
    # decorators the second ``def topics`` silently replaced the first and
    # assignments never reached ``topics_json``.
    @property
    def topics(self):
        """Topic names only, normalised from either on-disk schema shape."""
        from .topics.schema import to_names

        return self._decode_topics(to_names)

    @topics.setter
    def topics(self, value):
        """Store *value* as JSON text; ``None`` or empty becomes ``[]``."""
        self.topics_json = json.dumps(value or [])

    # NOTE(review): kept as a plain method, as in the visible source. If
    # callers access ``paper.topic_entries`` without calling it, this needs
    # ``@property`` as well -- confirm against call sites.
    def topic_entries(self):
        """Full topic entries: ``[{"name", "confidence", "source"}, ...]``."""
        from .topics.schema import to_entries

        return self._decode_topics(to_entries)

    def to_dict(self, compact: bool = False):
        """Return a plain-dict view of the paper with ISO-formatted dates.

        Args:
            compact: When true, omit ``abstract``, ``source_id``, ``doi``,
                ``influential_citation_count`` and ``updated_at``.
        """
        d = {
            "id": self.id,
            "title": self.title,
            "authors": self.authors,
            "publication_date": (
                self.publication_date.isoformat()
                if self.publication_date
                else None
            ),
            "source": self.source,
            "url": self.url,
            "topics": self.topics,
            "citation_count": self.citation_count,
            "impact_score": self.impact_score,
        }
        if not compact:
            d["abstract"] = self.abstract
            d["source_id"] = self.source_id
            d["doi"] = self.doi
            d["influential_citation_count"] = self.influential_citation_count
            d["updated_at"] = (
                self.updated_at.isoformat() if self.updated_at else None
            )
        return d
# Lazily created, process-wide singletons (see get_engine / get_session).
_engine = None
_SessionLocal = None

# Guards first-time creation of the singletons above. It is reentrant so that
# an initializer holding the lock can call another initializer that takes the
# same lock without deadlocking on itself.
_init_lock = threading.RLock()
def _enable_wal(dbapi_conn, connection_record):
    """Switch a newly opened DBAPI connection to WAL journal mode.

    Registered as a ``"connect"`` event listener on the engine.

    Args:
        dbapi_conn: The raw DBAPI connection that was just opened.
        connection_record: Supplied by the event hook; unused here.
    """
    cursor = dbapi_conn.cursor()
    try:
        cursor.execute("PRAGMA journal_mode=WAL")
    finally:
        # Release the cursor even when the PRAGMA fails.
        cursor.close()
def get_engine():
    """Return the shared SQLAlchemy engine, creating it on first call.

    First-time setup creates the cache directory, opens ``papers.db`` inside
    it, registers ``_enable_wal`` for every new connection and creates any
    missing tables.

    Returns:
        The module-wide engine instance.
    """
    global _engine
    if _engine is not None:
        return _engine

    with _init_lock:
        # Re-check under the lock: another thread may have finished setup
        # while this one was waiting.
        if _engine is None:
            _CACHE_DIR.mkdir(parents=True, exist_ok=True)
            db_path = _CACHE_DIR / "papers.db"
            engine = create_engine(
                f"sqlite:///{db_path}",
                connect_args={"check_same_thread": False},
            )
            event.listen(engine, "connect", _enable_wal)
            try:
                Base.metadata.create_all(engine)
            except Exception:
                # Do not cache an engine whose schema setup failed.
                engine.dispose()
                raise
            # Publish only once fully configured, so callers on the unlocked
            # fast path never see an engine without the WAL hook or tables.
            _engine = engine
    return _engine
def get_session() -> Session:
    """Return a new ORM session bound to the shared engine.

    The session factory is created once and reused; each call returns a
    fresh ``Session`` that the caller is responsible for closing.
    """
    global _SessionLocal
    if _SessionLocal is None:
        # Resolve the engine BEFORE taking _init_lock. get_engine() acquires
        # that lock itself, so calling it while holding the lock would
        # self-deadlock whenever the lock is not reentrant.
        engine = get_engine()
        with _init_lock:
            if _SessionLocal is None:
                _SessionLocal = sessionmaker(bind=engine)
    return _SessionLocal()
# Keys looked up per SELECT. Each key binds two parameters, and SQLite caps
# the number of bound parameters per statement, so large batches are split.
_LOOKUP_CHUNK_SIZE = 400


def _refresh_cached_paper(existing, incoming: dict) -> bool:
    """Copy newer metrics from *incoming* onto *existing*.

    Citation metrics overwrite whenever the incoming value is not ``None``;
    the abstract is only filled in when the stored one is empty.

    Returns:
        True if at least one field was written.
    """
    changed = False
    for field in ("citation_count", "influential_citation_count", "impact_score"):
        value = incoming.get(field)
        if value is not None:
            setattr(existing, field, value)
            changed = True
    if incoming.get("abstract") and not existing.abstract:
        existing.abstract = incoming["abstract"]
        changed = True
    if changed:
        existing.updated_at = datetime.now(timezone.utc)
    return changed


def upsert_papers(papers: list[dict]) -> int:
    """Insert new papers and refresh metrics on ones already cached.

    Papers are matched on ``(source, source_id)``. A match, whether already
    in the database or seen earlier in the same batch, is updated in place
    instead of inserted again.

    Args:
        papers: Dicts with required keys ``title``, ``source`` and
            ``source_id``; the remaining column names and ``topics`` are
            optional.

    Returns:
        Number of newly inserted papers.

    Raises:
        KeyError: If a required key is missing from an entry.
        Exception: Any database error, re-raised after the transaction is
            rolled back.
    """
    if not papers:
        return 0

    session = get_session()
    new_count = 0
    try:
        # Batch-fetch existing rows to avoid one query per paper. Keys are
        # de-duplicated (order preserved) and fetched in chunks.
        keys = list(dict.fromkeys((p["source"], p["source_id"]) for p in papers))
        existing_map = {}
        for start in range(0, len(keys), _LOOKUP_CHUNK_SIZE):
            chunk = keys[start:start + _LOOKUP_CHUNK_SIZE]
            rows = (
                session.query(Paper)
                .filter(tuple_(Paper.source, Paper.source_id).in_(chunk))
                .all()
            )
            for row in rows:
                existing_map[(row.source, row.source_id)] = row

        for p in papers:
            key = (p["source"], p["source_id"])
            existing = existing_map.get(key)
            if existing is not None:
                _refresh_cached_paper(existing, p)
                continue

            paper = Paper(
                title=p["title"],
                abstract=p.get("abstract"),
                authors=p.get("authors"),
                publication_date=p.get("publication_date"),
                source=p["source"],
                source_id=p["source_id"],
                url=p.get("url"),
                doi=p.get("doi"),
                citation_count=p.get("citation_count"),
                influential_citation_count=p.get("influential_citation_count"),
                impact_score=p.get("impact_score"),
            )
            paper.topics = p.get("topics", [])
            session.add(paper)
            # Remember the pending row so a repeat of the same key later in
            # this batch updates it instead of violating the unique
            # constraint at commit time.
            existing_map[key] = paper
            new_count += 1

        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
    return new_count