"""Semantic Scholar source: search + per-paper citations/references/recs.""" import datetime import logging import os from ._http import _http from .base import Capability, PaperSource, RateLimitHint, paper_dict logger = logging.getLogger(__name__) _S2_BASE_URL = "https://api.semanticscholar.org" _S2_SEARCH_URL = f"{_S2_BASE_URL}/graph/v1/paper/search" _S2_PAPER_URL = f"{_S2_BASE_URL}/graph/v1/paper" _S2_RECS_URL = f"{_S2_BASE_URL}/recommendations/v1/papers/forpaper" _S2_FIELDS = ( "title,abstract,authors,year,publicationDate," "externalIds,citationCount,influentialCitationCount,url" ) _S2_CITATION_FIELDS = ( "title,abstract,authors,year,publicationDate," "externalIds,citationCount,influentialCitationCount,url,contexts,intents,isInfluential" ) def _s2_headers() -> dict: api_key = os.getenv("SEMANTIC_SCHOLAR_API_KEY") return {"x-api-key": api_key} if api_key else {} def _parse_s2_paper(paper: dict, source_label: str = "semantic_scholar") -> dict | None: """Parse a Semantic Scholar paper object into our standard dict format.""" paper_id = paper.get("paperId", "") title = (paper.get("title") or "").strip() if not paper_id or not title: return None author_names = [a.get("name", "") for a in paper.get("authors", []) if a.get("name")] pub_date = None if paper.get("publicationDate"): try: pub_date = datetime.datetime.strptime(paper["publicationDate"], "%Y-%m-%d") except ValueError: pass if pub_date is None and paper.get("year"): try: pub_date = datetime.datetime(paper["year"], 1, 1) except (ValueError, TypeError): pass ext_ids = paper.get("externalIds") or {} return paper_dict( title=title, abstract=(paper.get("abstract") or "").strip() or None, authors=", ".join(author_names) if author_names else None, publication_date=pub_date, source=source_label, source_id=paper_id, url=paper.get("url") or f"https://www.semanticscholar.org/paper/{paper_id}", doi=ext_ids.get("DOI"), topics=[], citation_count=paper.get("citationCount"), influential_citation_count=paper.get("influentialCitationCount"), ) def fetch_semantic_scholar( query: str = "machine learning", max_results: int = 20, fields_of_study: list[str] | None = None, ) -> list[dict]: """Fetch papers from Semantic Scholar Academic Graph API.""" params = { "query": query, "limit": min(max_results, 100), "fields": _S2_FIELDS, } if fields_of_study: params["fieldsOfStudy"] = ",".join(fields_of_study) try: resp = _http.get( _S2_SEARCH_URL, params=params, headers=_s2_headers(), timeout=20, ) resp.raise_for_status() data = resp.json() except Exception as exc: logger.warning("Semantic Scholar search failed: %s", exc) return [] papers = [] for paper in data.get("data", []): try: parsed = _parse_s2_paper(paper) if parsed: papers.append(parsed) except Exception as exc: logger.warning("Skipping malformed S2 paper: %s", exc) return papers def fetch_s2_by_id(identifier: str) -> dict | None: """Fetch a single Semantic Scholar paper by S2 ID, DOI, arXiv ID, etc.""" url = f"{_S2_PAPER_URL}/{identifier}" params = {"fields": _S2_FIELDS} try: resp = _http.get(url, params=params, headers=_s2_headers(), timeout=20) resp.raise_for_status() data = resp.json() except Exception as exc: logger.warning("S2 get_by_id failed for %s: %s", identifier, exc) return None return _parse_s2_paper(data) def fetch_s2_citations( paper_identifier: str, limit: int = 100, ) -> list[dict]: """Fetch papers that cite the given paper from Semantic Scholar. Args: paper_identifier: S2 paper ID, DOI (DOI:xxx), arXiv ID (ARXIV:xxx), etc. limit: Max citations to return (max 1000). """ url = f"{_S2_PAPER_URL}/{paper_identifier}/citations" params = { "fields": _S2_CITATION_FIELDS, "limit": min(limit, 1000), } try: resp = _http.get(url, params=params, headers=_s2_headers(), timeout=30) resp.raise_for_status() data = resp.json() except Exception as exc: logger.warning("S2 citations fetch failed for %s: %s", paper_identifier, exc) return [] results = [] for item in data.get("data", []): citing = item.get("citingPaper", {}) parsed = _parse_s2_paper(citing) if parsed: parsed["_is_influential"] = item.get("isInfluential", False) parsed["_citation_contexts"] = item.get("contexts", []) parsed["_citation_intents"] = item.get("intents", []) results.append(parsed) return results def fetch_s2_references( paper_identifier: str, limit: int = 100, ) -> list[dict]: """Fetch papers referenced by the given paper from Semantic Scholar. Args: paper_identifier: S2 paper ID, DOI (DOI:xxx), arXiv ID (ARXIV:xxx), etc. limit: Max references to return (max 1000). """ url = f"{_S2_PAPER_URL}/{paper_identifier}/references" params = { "fields": _S2_CITATION_FIELDS, "limit": min(limit, 1000), } try: resp = _http.get(url, params=params, headers=_s2_headers(), timeout=30) resp.raise_for_status() data = resp.json() except Exception as exc: logger.warning("S2 references fetch failed for %s: %s", paper_identifier, exc) return [] results = [] for item in data.get("data", []): cited = item.get("citedPaper", {}) parsed = _parse_s2_paper(cited) if parsed: parsed["_is_influential"] = item.get("isInfluential", False) parsed["_citation_contexts"] = item.get("contexts", []) parsed["_citation_intents"] = item.get("intents", []) results.append(parsed) return results def fetch_s2_recommendations( paper_identifier: str, limit: int = 20, ) -> list[dict]: """Fetch recommended papers from Semantic Scholar Recommendations API. Uses SPECTER2 embeddings for semantic similarity. Args: paper_identifier: S2 paper ID. limit: Max recommendations to return (max 500). """ url = f"{_S2_RECS_URL}/{paper_identifier}" params = { "fields": _S2_FIELDS, "limit": min(limit, 500), } try: resp = _http.get(url, params=params, headers=_s2_headers(), timeout=20) resp.raise_for_status() data = resp.json() except Exception as exc: logger.warning("S2 recommendations failed for %s: %s", paper_identifier, exc) return [] results = [] for paper in data.get("recommendedPapers", []): try: parsed = _parse_s2_paper(paper) if parsed: results.append(parsed) except Exception as exc: logger.warning("Skipping malformed S2 recommendation: %s", exc) return results class SemanticScholarSource(PaperSource): """Semantic Scholar Academic Graph + Recommendations API.""" name = "semantic_scholar" def search( self, query: str, max_results: int = 20, filters: dict | None = None, ) -> list[dict]: filters = filters or {} return fetch_semantic_scholar( query=query, max_results=max_results, fields_of_study=filters.get("fields_of_study"), ) def get_by_id(self, identifier: str) -> dict | None: return fetch_s2_by_id(identifier) def supports(self, capability: Capability) -> bool: return capability in ("citations", "references", "recs") def rate_limit_hint(self) -> RateLimitHint: if os.getenv("SEMANTIC_SCHOLAR_API_KEY"): return RateLimitHint( min_interval_seconds=0.01, daily_quota=None, notes="S2 with API key: high throughput, observe per-endpoint limits", ) return RateLimitHint( min_interval_seconds=1.0, daily_quota=None, notes="S2 public pool: ~1 req/s, frequent 429s, retries strongly recommended", )