"""Research Papers MCP Server.

Exposes 11 tools (plus 3 read-only resources) for research paper discovery,
citation analysis, and trend detection. Aggregates papers from arXiv, PubMed,
Semantic Scholar, and OpenAlex into a local SQLite cache for fast analytics.

Install: pip install research-papers-mcp
Run:     research-papers-mcp
"""

import asyncio
import contextlib
import json
import logging
from collections import Counter
from collections.abc import Callable, Iterator
from typing import Any

from fastmcp import FastMCP
from sqlalchemy import func

from . import core, sources
from .db import Paper, get_session, upsert_papers

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)

# Upper bound on free-text inputs (queries, author names, topics) before they
# are handed to the local cache or external APIs.
_MAX_INPUT_LENGTH = 1000


def _valid_sources() -> set[str]:
    """Derived live from the source registry; new registrations are picked up automatically."""
    return set(sources.REGISTRY.keys())


mcp = FastMCP(
    name="research-papers",
    instructions=(
        "Research paper discovery and analysis server. "
        "Searches arXiv, PubMed, Semantic Scholar, and OpenAlex. "
        "Provides real citation data via Semantic Scholar, trend detection "
        "over the local cache, paper similarity (SPECTER2 embeddings with "
        "TF-IDF fallback), author profiles, literature reviews, and BibTeX export. "
        "Use search_papers first to populate the local cache, then "
        "use analytical tools on cached results. Use get_cache_stats "
        "to check cache size before running analytical tools. Call "
        "list_sources to discover which sources expose which capabilities."
    ),
)


def _clamp_str(value: str) -> str:
    """Truncate *value* to at most ``_MAX_INPUT_LENGTH`` characters."""
    return value[:_MAX_INPUT_LENGTH]


def _to_json(payload: Any) -> str:
    """Serialize a tool/resource result; non-JSON types (e.g. dates) fall back to ``str``."""
    return json.dumps(payload, default=str)


@contextlib.contextmanager
def _session_scope() -> Iterator[Any]:
    """Yield a session obtained from ``get_session()`` and always close it afterwards."""
    session = get_session()
    try:
        yield session
    finally:
        session.close()


def _with_session(fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Any:
    """Call ``fn(session, *args, **kwargs)`` with a fresh, auto-closed session.

    Intended to be handed to ``asyncio.to_thread`` so each tool's blocking
    database work runs off the event loop with its own short-lived session.
    """
    with _session_scope() as session:
        return fn(session, *args, **kwargs)


# NOTE(review): not referenced in this module; kept in case other modules import it.
_SOURCE_CONCEPT_PROVIDERS = {"openalex"}


def _enrich_and_upsert(papers: list[dict], query: str, max_results: int) -> dict:
    """Enrich papers with topics/scores, upsert to DB, return cached results.

    Source-supplied concepts (OpenAlex today, MeSH/categories elsewhere) keep the
    top slots; keyword-derived topics fill the tail so we never lose the
    authoritative signal. Runs entirely in a worker thread to avoid blocking the
    event loop.

    Each dict in *papers* is updated in place with ``topics`` and
    ``impact_score`` before being upserted.

    Returns:
        ``{"total_results", "new_papers_cached", "papers"}`` where ``papers`` are
        up to *max_results* local-cache matches for *query* (re-queried after the
        upsert) and ``new_papers_cached`` is whatever ``upsert_papers`` returns.
    """
    for p in papers:
        # Confidence-scored entries: OpenAlex concept assignments when present,
        # else embeddings (if the extra is installed), else keyword taxonomy.
        p["topics"] = core.enrich_topic_entries(
            existing=p.get("topics") or [],
            title=p.get("title", ""),
            abstract=p.get("abstract"),
            concepts=p.get("concepts"),
            source=p.get("source"),
        )
        p["impact_score"] = core.compute_impact_score(
            p.get("citation_count"),
            p.get("influential_citation_count"),
            p.get("publication_date"),
        )

    new_count = upsert_papers(papers)
    with _session_scope() as session:
        cached = core.search_cached(session, query, limit=max_results)
    return {"total_results": len(cached), "new_papers_cached": new_count, "papers": cached}


# ---------------------------------------------------------------------------
# Tool 1: search_papers
# ---------------------------------------------------------------------------


@mcp.tool
async def search_papers(
    query: str,
    sources_filter: list[str] | None = None,
    max_results: int = 20,
) -> str:
    """Search for research papers across arXiv, PubMed, Semantic Scholar, and OpenAlex.

    Fetches papers from external APIs and caches them locally for subsequent
    analysis. Returns deduplicated results with paper IDs that can be used
    with other tools.

    Args:
        query: Search query (keywords or natural language).
        sources_filter: Which sources to search. Options: "arxiv", "pubmed",
            "semantic_scholar", "openalex". Defaults to all four (an empty
            list is treated the same as omitting it). Call list_sources to see
            the current registry.
        max_results: Maximum results per source (default 20, max 50).
    """
    query = _clamp_str(query)
    # Don't send empty queries to external APIs.
    if not query.strip():
        return _to_json({"error": "Query must not be empty"})
    max_results = max(min(max_results, 50), 1)

    # Validate sources against the live registry.
    valid = _valid_sources()
    if sources_filter:
        invalid = [s for s in sources_filter if s not in valid]
        if invalid:
            return _to_json({
                "error": f"Invalid sources: {invalid}. Valid options: {sorted(valid)}"
            })
        # Order-preserving de-duplication so no source is queried twice.
        sources_filter = list(dict.fromkeys(sources_filter))
    else:
        # An empty list means "no filter", i.e. the documented default of all sources.
        sources_filter = None

    # Fetch from external APIs (in worker thread)
    papers = await asyncio.to_thread(
        sources.federated_search,
        query=query,
        sources=sources_filter,
        max_results=max_results,
    )

    # Enrich, cache, and re-query (in worker thread)
    result = await asyncio.to_thread(_enrich_and_upsert, papers, query, max_results)
    return _to_json(result)


# ---------------------------------------------------------------------------
# Tool 2: search_cached_papers
# ---------------------------------------------------------------------------


@mcp.tool
async def search_cached_papers(
    query: str,
    source: str | None = None,
    year_from: int | None = None,
    year_to: int | None = None,
    min_citations: int | None = None,
    max_results: int = 20,
) -> str:
    """Search the local paper cache without hitting external APIs.

    Fast local search over previously cached papers by keyword matching on
    title, abstract, and topics. Use this to re-query papers you have already
    fetched with search_papers.

    Args:
        query: Search keywords.
        source: Filter by source ("arxiv", "pubmed", "semantic_scholar", "openalex").
        year_from: Include papers published on or after this year.
        year_to: Include papers published on or before this year.
        min_citations: Minimum citation count filter.
        max_results: Maximum results to return (default 20, max 50).
    """
    query = _clamp_str(query)
    max_results = max(min(max_results, 50), 1)

    valid = _valid_sources()
    if source and source not in valid:
        return _to_json({
            "error": f"Invalid source: {source}. Valid options: {sorted(valid)}"
        })
    # An inverted range can never match; report it instead of returning nothing.
    if year_from is not None and year_to is not None and year_from > year_to:
        return _to_json({"error": "year_from must be less than or equal to year_to"})

    cached = await asyncio.to_thread(
        _with_session,
        core.search_cached,
        query,
        source=source,
        year_from=year_from,
        year_to=year_to,
        min_citations=min_citations,
        limit=max_results,
    )
    return _to_json({"total_results": len(cached), "papers": cached})


# ---------------------------------------------------------------------------
# Tool 3: get_paper_details
# ---------------------------------------------------------------------------


def _fetch_paper_dict(session: Any, paper_id: int) -> dict:
    """Return ``Paper.to_dict()`` for *paper_id*, or an error dict if it is not cached."""
    paper = session.query(Paper).filter(Paper.id == paper_id).first()
    if not paper:
        return {"error": f"Paper {paper_id} not found in local cache"}
    return paper.to_dict()


@mcp.tool
async def get_paper_details(paper_id: int) -> str:
    """Get full details for a specific paper from the local cache.

    Use the numeric ID returned by search_papers or other tools.

    Args:
        paper_id: The paper's numeric ID in the local database.
    """
    result = await asyncio.to_thread(_with_session, _fetch_paper_dict, paper_id)
    return _to_json(result)


# ---------------------------------------------------------------------------
# Tool 4: get_paper_citations
# ---------------------------------------------------------------------------


@mcp.tool
async def get_paper_citations(paper_id: int) -> str:
    """Get real citation data for a paper from Semantic Scholar.

    Returns papers that cite this paper and papers it references, fetched
    from the Semantic Scholar Academic Graph API. Includes citation context,
    intents, and influence indicators. Requires the paper to have a
    resolvable Semantic Scholar identifier (S2 paper ID, DOI, or arXiv ID).

    Args:
        paper_id: The paper's numeric ID in the local database.
    """
    result = await asyncio.to_thread(_with_session, core.get_citations, paper_id)
    return _to_json(result)


# ---------------------------------------------------------------------------
# Tool 5: get_trending_topics
# ---------------------------------------------------------------------------


@mcp.tool
async def get_trending_topics(window_days: int = 30) -> str:
    """Detect emerging and declining research topics in your cached corpus.

    Compares topic frequency in a recent window against a baseline period
    of the same length. Operates only on locally cached papers, so results
    reflect your search history rather than the full publication landscape.
    Run broad searches with search_papers first to build a representative sample.

    Args:
        window_days: Size of the analysis window in days (default 30).
    """
    window_days = max(window_days, 1)
    result = await asyncio.to_thread(_with_session, core.get_trends, window_days)
    return _to_json(result)


# ---------------------------------------------------------------------------
# Tool 6: find_similar_papers
# ---------------------------------------------------------------------------


@mcp.tool
async def find_similar_papers(paper_id: int, top_n: int = 10) -> str:
    """Find papers similar to a given paper.

    Uses the Semantic Scholar Recommendations API (SPECTER2 embeddings) when
    available for high-quality semantic similarity. Falls back to local
    TF-IDF cosine similarity when the paper lacks a Semantic Scholar
    identifier or the API is unavailable.

    Args:
        paper_id: The paper's numeric ID in the local database.
        top_n: Number of similar papers to return (default 10, max 50).
    """
    top_n = max(min(top_n, 50), 1)
    result = await asyncio.to_thread(_with_session, core.find_similar, paper_id, top_n)
    return _to_json(result)


# ---------------------------------------------------------------------------
# Tool 7: get_author_profile
# ---------------------------------------------------------------------------


@mcp.tool
async def get_author_profile(author_name: str) -> str:
    """Get a researcher's publication profile and analytics.

    Searches the local cache for papers by the given author and returns
    publication frequency, top topics, and collaborators. Use full last
    name for best results (e.g. "Vaswani" not "V").

    Args:
        author_name: Author last name or full name (e.g. "LeCun", "Yann LeCun").
    """
    # Strip once so the length check and the lookup see the same name.
    author_name = _clamp_str(author_name).strip()
    if len(author_name) < 2:
        return _to_json({"error": "Author name must be at least 2 characters"})

    result = await asyncio.to_thread(_with_session, core.get_author_profile, author_name)
    return _to_json(result)


# ---------------------------------------------------------------------------
# Tool 8: generate_literature_review
# ---------------------------------------------------------------------------


@mcp.tool
async def generate_literature_review(topic: str) -> str:
    """Generate a structured literature review for a research topic.

    Analyzes papers in the local cache matching the topic, groups them by
    subtopic, identifies consensus and debate areas, and produces a
    structured review. Run search_papers first to populate the cache with
    relevant papers.

    Args:
        topic: Research topic to review (e.g. "federated learning").
    """
    topic = _clamp_str(topic)
    result = await asyncio.to_thread(_with_session, core.generate_review, topic)
    return _to_json(result)


# ---------------------------------------------------------------------------
# Tool 9: export_bibtex
# ---------------------------------------------------------------------------


@mcp.tool
async def export_bibtex(
    paper_ids: list[int] | None = None,
    query: str | None = None,
    limit: int = 50,
) -> str:
    """Export papers as BibTeX entries for use in LaTeX documents.

    Provide specific paper IDs or a search query to select papers. Returns
    properly formatted BibTeX with citation keys, DOIs, arXiv eprint IDs,
    and abstracts.

    Args:
        paper_ids: List of paper IDs to export. Takes priority over query.
        query: Search query to find papers to export.
        limit: Maximum papers to export when using query (default 50, max 200).
    """
    if query:
        query = _clamp_str(query)
    limit = max(min(limit, 200), 1)

    result = await asyncio.to_thread(
        _with_session,
        core.export_bibtex,
        paper_ids=paper_ids,
        query=query,
        limit=limit,
    )
    return _to_json(result)


# ---------------------------------------------------------------------------
# Tool 10: get_cache_stats
# ---------------------------------------------------------------------------


def _compute_stats() -> dict:
    """Return total paper count, per-source counts, and min/max publication date.

    Dates are emitted via ``.isoformat()`` (None when the cache has no dates).
    """
    with _session_scope() as session:
        total = session.query(Paper).count()
        by_source = {
            source: count
            for source, count in (
                session.query(Paper.source, func.count(Paper.id))
                .group_by(Paper.source)
                .all()
            )
        }
        date_range = session.query(
            func.min(Paper.publication_date),
            func.max(Paper.publication_date),
        ).first()

    earliest = date_range[0] if date_range else None
    latest = date_range[1] if date_range else None
    return {
        "total_papers": total,
        "papers_by_source": by_source,
        "earliest": earliest.isoformat() if earliest else None,
        "latest": latest.isoformat() if latest else None,
    }


@mcp.tool
async def get_cache_stats() -> str:
    """Get statistics about the local paper cache.

    Returns total paper count, breakdown by source, and date range of
    cached papers. Use this to check whether the cache has enough data
    before running analytical tools like get_trending_topics.
    """
    result = await asyncio.to_thread(_compute_stats)
    return _to_json(result)


# ---------------------------------------------------------------------------
# Tool 11: list_sources
# ---------------------------------------------------------------------------

# Capability names probed via each source's ``supports()`` method.
_CAPABILITY_KEYS = ("citations", "references", "recs", "full_text")


def _describe_sources() -> dict:
    """Describe every registered source: capabilities and rate-limit hint, sorted by name."""
    items = []
    for name, src in sources.REGISTRY.items():
        hint = src.rate_limit_hint()
        items.append({
            "name": name,
            "capabilities": {cap: src.supports(cap) for cap in _CAPABILITY_KEYS},
            "rate_limit": {
                "min_interval_seconds": hint.min_interval_seconds,
                "daily_quota": hint.daily_quota,
                "notes": hint.notes,
            },
        })
    items.sort(key=lambda item: item["name"])
    return {"sources": items}


@mcp.tool
async def list_sources() -> str:
    """List registered paper sources, their capabilities, and rate-limit hints.

    Use this to discover which sources can supply citations, references,
    recommendations, or full text without reading the source code.
    Capabilities map to the `supports()` contract on each `PaperSource`.
    """
    result = await asyncio.to_thread(_describe_sources)
    return _to_json(result)


# ---------------------------------------------------------------------------
# Resources
# ---------------------------------------------------------------------------


def _compute_fields() -> dict:
    """Count, per lowercased topic name, how many cached papers carry that topic.

    Rows whose ``topics_json`` is not valid JSON are skipped. Each paper
    contributes at most once per topic, so the values are paper counts.
    Results are ordered by count, descending; ties keep first-seen order.
    """
    with _session_scope() as session:
        rows = session.query(Paper.topics_json).all()

    counts: Counter[str] = Counter()
    for (topics_json,) in rows:
        try:
            topics = json.loads(topics_json or "[]")
        except (ValueError, TypeError):
            continue
        # Order-preserving de-duplication keeps tie ordering deterministic.
        names = dict.fromkeys(name.lower() for name in core.topics_to_names(topics or []))
        counts.update(list(names))

    return {"fields": [{"name": t, "count": c} for t, c in counts.most_common()]}


@mcp.resource("papers://sources")
def get_sources() -> str:
    """Registered paper sources with capabilities and rate-limit hints."""
    return _to_json(_describe_sources())


@mcp.resource("papers://stats")
def get_stats() -> str:
    """Database statistics: total papers, by source, date range."""
    return _to_json(_compute_stats())


@mcp.resource("papers://fields")
def get_fields() -> str:
    """List all research fields/topics in the cached corpus with paper counts."""
    return json.dumps(_compute_fields())


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------


def main():
    """CLI entry point: ``refresh-concepts`` subcommand, or run the MCP server."""
    import argparse
    import os
    import sys

    # Lightweight subcommand handled before argparse so the server flags stay simple.
    if len(sys.argv) > 1 and sys.argv[1] == "refresh-concepts":
        from .topics import refresh_concepts

        count = refresh_concepts(mailto=os.getenv("OPENALEX_MAILTO") or None)
        print(f"Refreshed OpenAlex concept vocabulary: {count} concepts")
        return

    parser = argparse.ArgumentParser(description="Research Papers MCP Server")
    parser.add_argument(
        "--transport",
        choices=["stdio", "sse", "streamable-http"],
        default="stdio",
        help="Transport protocol (default: stdio)",
    )
    parser.add_argument("--host", default="127.0.0.1", help="Host for HTTP transports")
    parser.add_argument("--port", type=int, default=8080, help="Port for HTTP transports")
    args = parser.parse_args()

    if args.transport == "stdio":
        mcp.run(transport="stdio")
    elif args.transport == "sse":
        mcp.run(transport="sse", host=args.host, port=args.port)
    else:
        mcp.run(transport="streamable-http", host=args.host, port=args.port)


if __name__ == "__main__":
    main()