barissozudogru's picture
bundle research_papers_mcp source, drop git+install
57272d3 verified
Raw
History Blame Contribute Delete
8.42 kB
"""Semantic Scholar source: search + per-paper citations/references/recs."""
import datetime
import logging
import os
from ._http import _http
from .base import Capability, PaperSource, RateLimitHint, paper_dict
# Module-level logger used by every fetch helper below.
logger = logging.getLogger(__name__)

# Endpoint URLs, all derived from the single API host.
_S2_BASE_URL = "https://api.semanticscholar.org"
_S2_PAPER_URL = _S2_BASE_URL + "/graph/v1/paper"
_S2_SEARCH_URL = _S2_PAPER_URL + "/search"
_S2_RECS_URL = _S2_BASE_URL + "/recommendations/v1/papers/forpaper"

# Comma-separated ``fields`` value requested for plain paper objects.
_S2_FIELDS = ",".join(
    (
        "title",
        "abstract",
        "authors",
        "year",
        "publicationDate",
        "externalIds",
        "citationCount",
        "influentialCitationCount",
        "url",
    )
)

# Same paper fields plus the per-edge fields read by the citation and
# reference fetchers (``contexts``, ``intents``, ``isInfluential``).
_S2_CITATION_FIELDS = ",".join(
    (_S2_FIELDS, "contexts", "intents", "isInfluential")
)
def _s2_headers() -> dict:
    """Build the HTTP headers sent with every Semantic Scholar request.

    Returns:
        ``{"x-api-key": <key>}`` when the ``SEMANTIC_SCHOLAR_API_KEY``
        environment variable holds a non-empty value, otherwise ``{}``.
    """
    key = os.getenv("SEMANTIC_SCHOLAR_API_KEY")
    if not key:
        return {}
    return {"x-api-key": key}
def _parse_s2_paper(paper: dict, source_label: str = "semantic_scholar") -> dict | None:
"""Parse a Semantic Scholar paper object into our standard dict format."""
paper_id = paper.get("paperId", "")
title = (paper.get("title") or "").strip()
if not paper_id or not title:
return None
author_names = [a.get("name", "") for a in paper.get("authors", []) if a.get("name")]
pub_date = None
if paper.get("publicationDate"):
try:
pub_date = datetime.datetime.strptime(paper["publicationDate"], "%Y-%m-%d")
except ValueError:
pass
if pub_date is None and paper.get("year"):
try:
pub_date = datetime.datetime(paper["year"], 1, 1)
except (ValueError, TypeError):
pass
ext_ids = paper.get("externalIds") or {}
return paper_dict(
title=title,
abstract=(paper.get("abstract") or "").strip() or None,
authors=", ".join(author_names) if author_names else None,
publication_date=pub_date,
source=source_label,
source_id=paper_id,
url=paper.get("url") or f"https://www.semanticscholar.org/paper/{paper_id}",
doi=ext_ids.get("DOI"),
topics=[],
citation_count=paper.get("citationCount"),
influential_citation_count=paper.get("influentialCitationCount"),
)
def fetch_semantic_scholar(
    query: str = "machine learning",
    max_results: int = 20,
    fields_of_study: list[str] | None = None,
) -> list[dict]:
    """Run a keyword search against the Semantic Scholar paper-search endpoint.

    Args:
        query: Free-text search query.
        max_results: Requested number of results; capped at 100 before sending.
        fields_of_study: Optional field-of-study names, sent comma-joined as
            the ``fieldsOfStudy`` parameter.

    Returns:
        Parsed paper dicts. Any request/decoding failure is logged as a
        warning and yields an empty list; items that fail to parse are
        logged and skipped.
    """
    request_params = {
        "query": query,
        "limit": min(max_results, 100),
        "fields": _S2_FIELDS,
    }
    if fields_of_study:
        request_params["fieldsOfStudy"] = ",".join(fields_of_study)

    try:
        response = _http.get(
            _S2_SEARCH_URL,
            params=request_params,
            headers=_s2_headers(),
            timeout=20,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        logger.warning("Semantic Scholar search failed: %s", exc)
        return []

    found = []
    for raw in payload.get("data", []):
        try:
            entry = _parse_s2_paper(raw)
        except Exception as exc:
            logger.warning("Skipping malformed S2 paper: %s", exc)
            continue
        if entry:
            found.append(entry)
    return found
def fetch_s2_by_id(identifier: str) -> dict | None:
    """Look up one paper on Semantic Scholar.

    Args:
        identifier: Appended verbatim to the paper endpoint URL (S2 ID, DOI,
            arXiv ID, etc.).

    Returns:
        The parsed paper dict, or ``None`` when the request fails (logged as
        a warning) or the response cannot be parsed into a paper.
    """
    endpoint = f"{_S2_PAPER_URL}/{identifier}"
    query_params = {"fields": _S2_FIELDS}

    try:
        response = _http.get(
            endpoint,
            params=query_params,
            headers=_s2_headers(),
            timeout=20,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        logger.warning("S2 get_by_id failed for %s: %s", identifier, exc)
        return None

    return _parse_s2_paper(payload)
def fetch_s2_citations(
    paper_identifier: str,
    limit: int = 100,
) -> list[dict]:
    """Fetch papers that cite the given paper from Semantic Scholar.

    Args:
        paper_identifier: S2 paper ID, DOI (DOI:xxx), arXiv ID (ARXIV:xxx), etc.
        limit: Max citations to return (capped at 1000 before sending).

    Returns:
        Parsed citing-paper dicts, each extended with ``_is_influential``,
        ``_citation_contexts`` and ``_citation_intents`` taken from the
        citation edge. A failed request is logged and yields ``[]``; a
        malformed item is logged and skipped without aborting the rest.
    """
    url = f"{_S2_PAPER_URL}/{paper_identifier}/citations"
    params = {
        "fields": _S2_CITATION_FIELDS,
        "limit": min(limit, 1000),
    }
    try:
        resp = _http.get(url, params=params, headers=_s2_headers(), timeout=30)
        resp.raise_for_status()
        data = resp.json()
    except Exception as exc:
        logger.warning("S2 citations fetch failed for %s: %s", paper_identifier, exc)
        return []

    # Tolerate a non-dict body or a null ``data`` field.
    items = (data.get("data") if isinstance(data, dict) else None) or []

    results = []
    for item in items:
        # Guard each item like the search/recommendation loops do, so one bad
        # entry (e.g. a null item or null ``citingPaper``) cannot drop the
        # whole result set.
        try:
            parsed = _parse_s2_paper(item.get("citingPaper") or {})
        except Exception as exc:
            logger.warning("Skipping malformed S2 citation: %s", exc)
            continue
        if not parsed:
            continue
        # ``or`` normalizes JSON null to the same defaults as a missing key.
        parsed["_is_influential"] = item.get("isInfluential") or False
        parsed["_citation_contexts"] = item.get("contexts") or []
        parsed["_citation_intents"] = item.get("intents") or []
        results.append(parsed)
    return results
def fetch_s2_references(
    paper_identifier: str,
    limit: int = 100,
) -> list[dict]:
    """Fetch papers referenced by the given paper from Semantic Scholar.

    Args:
        paper_identifier: S2 paper ID, DOI (DOI:xxx), arXiv ID (ARXIV:xxx), etc.
        limit: Max references to return (capped at 1000 before sending).

    Returns:
        Parsed cited-paper dicts, each extended with ``_is_influential``,
        ``_citation_contexts`` and ``_citation_intents`` taken from the
        reference edge. A failed request is logged and yields ``[]``; a
        malformed item is logged and skipped without aborting the rest.
    """
    url = f"{_S2_PAPER_URL}/{paper_identifier}/references"
    params = {
        "fields": _S2_CITATION_FIELDS,
        "limit": min(limit, 1000),
    }
    try:
        resp = _http.get(url, params=params, headers=_s2_headers(), timeout=30)
        resp.raise_for_status()
        data = resp.json()
    except Exception as exc:
        logger.warning("S2 references fetch failed for %s: %s", paper_identifier, exc)
        return []

    # Tolerate a non-dict body or a null ``data`` field.
    items = (data.get("data") if isinstance(data, dict) else None) or []

    results = []
    for item in items:
        # Guard each item like the search/recommendation loops do, so one bad
        # entry (e.g. a null item or null ``citedPaper``) cannot drop the
        # whole result set.
        try:
            parsed = _parse_s2_paper(item.get("citedPaper") or {})
        except Exception as exc:
            logger.warning("Skipping malformed S2 reference: %s", exc)
            continue
        if not parsed:
            continue
        # ``or`` normalizes JSON null to the same defaults as a missing key.
        parsed["_is_influential"] = item.get("isInfluential") or False
        parsed["_citation_contexts"] = item.get("contexts") or []
        parsed["_citation_intents"] = item.get("intents") or []
        results.append(parsed)
    return results
def fetch_s2_recommendations(
    paper_identifier: str,
    limit: int = 20,
) -> list[dict]:
    """Fetch papers recommended for a given paper via the S2 Recommendations API.

    Args:
        paper_identifier: S2 paper ID, appended to the recommendations URL.
        limit: Max recommendations to return (capped at 500 before sending).

    Returns:
        Parsed paper dicts from the response's ``recommendedPapers`` list.
        A failed request is logged and yields ``[]``; items that fail to
        parse are logged and skipped.
    """
    endpoint = f"{_S2_RECS_URL}/{paper_identifier}"
    query_params = {"fields": _S2_FIELDS, "limit": min(limit, 500)}

    try:
        response = _http.get(
            endpoint,
            params=query_params,
            headers=_s2_headers(),
            timeout=20,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        logger.warning("S2 recommendations failed for %s: %s", paper_identifier, exc)
        return []

    recommended = []
    for raw in payload.get("recommendedPapers", []):
        try:
            entry = _parse_s2_paper(raw)
        except Exception as exc:
            logger.warning("Skipping malformed S2 recommendation: %s", exc)
            continue
        if entry:
            recommended.append(entry)
    return recommended
class SemanticScholarSource(PaperSource):
    """``PaperSource`` adapter over the module-level Semantic Scholar helpers."""

    name = "semantic_scholar"

    def search(
        self,
        query: str,
        max_results: int = 20,
        filters: dict | None = None,
    ) -> list[dict]:
        """Search Semantic Scholar; only ``filters["fields_of_study"]`` is used."""
        study_fields = filters.get("fields_of_study") if filters else None
        return fetch_semantic_scholar(
            query=query,
            max_results=max_results,
            fields_of_study=study_fields,
        )

    def get_by_id(self, identifier: str) -> dict | None:
        """Fetch a single paper by identifier via ``fetch_s2_by_id``."""
        return fetch_s2_by_id(identifier)

    def supports(self, capability: Capability) -> bool:
        """Return True for the "citations", "references" and "recs" capabilities."""
        supported = ("citations", "references", "recs")
        return capability in supported

    def rate_limit_hint(self) -> RateLimitHint:
        """Return pacing advice, which depends on whether an API key is set."""
        if not os.getenv("SEMANTIC_SCHOLAR_API_KEY"):
            # No key configured: shared public pool.
            return RateLimitHint(
                min_interval_seconds=1.0,
                daily_quota=None,
                notes="S2 public pool: ~1 req/s, frequent 429s, retries strongly recommended",
            )
        return RateLimitHint(
            min_interval_seconds=0.01,
            daily_quota=None,
            notes="S2 with API key: high throughput, observe per-endpoint limits",
        )