Spaces:
Sleeping
Sleeping
| """Semantic Scholar source: search + per-paper citations/references/recs.""" | |
| import datetime | |
| import logging | |
| import os | |
| from ._http import _http | |
| from .base import Capability, PaperSource, RateLimitHint, paper_dict | |
logger = logging.getLogger(__name__)

# Root shared by every Semantic Scholar endpoint used in this module.
_S2_BASE_URL = "https://api.semanticscholar.org"

# Academic Graph API: keyword search, plus the per-paper root that the
# lookup / citations / references helpers append "/{id}[/...]" to.
_S2_SEARCH_URL = _S2_BASE_URL + "/graph/v1/paper/search"
_S2_PAPER_URL = _S2_BASE_URL + "/graph/v1/paper"

# Recommendations API: papers related to a single seed paper.
_S2_RECS_URL = _S2_BASE_URL + "/recommendations/v1/papers/forpaper"

# Paper-level fields requested by search, lookup and recommendation calls.
_S2_FIELDS = ",".join(
    [
        "title",
        "abstract",
        "authors",
        "year",
        "publicationDate",
        "externalIds",
        "citationCount",
        "influentialCitationCount",
        "url",
    ]
)

# Citation/reference edges request the same paper fields plus the per-edge
# metadata (contexts, intents, isInfluential) read by the citation loops.
_S2_CITATION_FIELDS = _S2_FIELDS + ",contexts,intents,isInfluential"
def _s2_headers() -> dict:
    """Build the HTTP headers sent with every Semantic Scholar request.

    Returns:
        ``{"x-api-key": <key>}`` when the ``SEMANTIC_SCHOLAR_API_KEY``
        environment variable is set to a non-empty value; otherwise an empty
        dict, so the request goes out without a key.
    """
    headers: dict = {}
    key = os.getenv("SEMANTIC_SCHOLAR_API_KEY")
    if key:
        headers["x-api-key"] = key
    return headers
| def _parse_s2_paper(paper: dict, source_label: str = "semantic_scholar") -> dict | None: | |
| """Parse a Semantic Scholar paper object into our standard dict format.""" | |
| paper_id = paper.get("paperId", "") | |
| title = (paper.get("title") or "").strip() | |
| if not paper_id or not title: | |
| return None | |
| author_names = [a.get("name", "") for a in paper.get("authors", []) if a.get("name")] | |
| pub_date = None | |
| if paper.get("publicationDate"): | |
| try: | |
| pub_date = datetime.datetime.strptime(paper["publicationDate"], "%Y-%m-%d") | |
| except ValueError: | |
| pass | |
| if pub_date is None and paper.get("year"): | |
| try: | |
| pub_date = datetime.datetime(paper["year"], 1, 1) | |
| except (ValueError, TypeError): | |
| pass | |
| ext_ids = paper.get("externalIds") or {} | |
| return paper_dict( | |
| title=title, | |
| abstract=(paper.get("abstract") or "").strip() or None, | |
| authors=", ".join(author_names) if author_names else None, | |
| publication_date=pub_date, | |
| source=source_label, | |
| source_id=paper_id, | |
| url=paper.get("url") or f"https://www.semanticscholar.org/paper/{paper_id}", | |
| doi=ext_ids.get("DOI"), | |
| topics=[], | |
| citation_count=paper.get("citationCount"), | |
| influential_citation_count=paper.get("influentialCitationCount"), | |
| ) | |
def fetch_semantic_scholar(
    query: str = "machine learning",
    max_results: int = 20,
    fields_of_study: list[str] | None = None,
) -> list[dict]:
    """Fetch papers from Semantic Scholar Academic Graph API.

    Args:
        query: Free-text search query.
        max_results: Number of results requested; capped at 100.
        fields_of_study: Optional list, sent comma-joined as ``fieldsOfStudy``.

    Returns:
        Parsed paper dicts. An empty list if the HTTP request, status check
        or JSON decoding fails. Papers that raise while parsing are logged
        and skipped; papers the parser rejects are dropped silently.
    """
    params = {"query": query, "limit": min(max_results, 100), "fields": _S2_FIELDS}
    if fields_of_study:
        params["fieldsOfStudy"] = ",".join(fields_of_study)

    try:
        response = _http.get(
            _S2_SEARCH_URL, params=params, headers=_s2_headers(), timeout=20,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        logger.warning("Semantic Scholar search failed: %s", exc)
        return []

    found: list[dict] = []
    for raw_paper in payload.get("data", []):
        try:
            parsed = _parse_s2_paper(raw_paper)
        except Exception as exc:
            logger.warning("Skipping malformed S2 paper: %s", exc)
            continue
        if parsed:
            found.append(parsed)
    return found
def fetch_s2_by_id(identifier: str) -> dict | None:
    """Fetch a single Semantic Scholar paper by S2 ID, DOI, arXiv ID, etc.

    ``identifier`` is appended verbatim to the paper endpoint path, so
    prefixed forms such as ``DOI:...`` or ``ARXIV:...`` pass through as-is.

    Returns:
        The parsed paper dict. ``None`` if the HTTP request, status check or
        JSON decoding fails, or if ``_parse_s2_paper`` rejects the payload.
    """
    endpoint = f"{_S2_PAPER_URL}/{identifier}"
    query_params = {"fields": _S2_FIELDS}
    try:
        response = _http.get(endpoint, params=query_params, headers=_s2_headers(), timeout=20)
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        logger.warning("S2 get_by_id failed for %s: %s", identifier, exc)
        return None
    return _parse_s2_paper(payload)
def fetch_s2_citations(
    paper_identifier: str,
    limit: int = 100,
) -> list[dict]:
    """Fetch papers that cite the given paper from Semantic Scholar.

    Args:
        paper_identifier: S2 paper ID, DOI (DOI:xxx), arXiv ID (ARXIV:xxx), etc.
        limit: Max citations to return (capped at 1000).

    Returns:
        One parsed dict per citing paper. Each dict also carries three
        citation-edge keys: ``_is_influential`` (bool), ``_citation_contexts``
        (list) and ``_citation_intents`` (list). Returns an empty list if the
        request fails. Edges that raise while parsing are logged and skipped,
        so one bad edge does not discard the rest.
    """
    url = f"{_S2_PAPER_URL}/{paper_identifier}/citations"
    params = {
        "fields": _S2_CITATION_FIELDS,
        "limit": min(limit, 1000),
    }
    try:
        resp = _http.get(url, params=params, headers=_s2_headers(), timeout=30)
        resp.raise_for_status()
        data = resp.json()
    except Exception as exc:
        logger.warning("S2 citations fetch failed for %s: %s", paper_identifier, exc)
        return []

    results = []
    # Tolerate a missing *or null* ``data`` array.
    for item in data.get("data") or []:
        try:
            # ``citingPaper`` may be absent or null; ``or {}`` avoids passing
            # None to the parser (``.get(key, {})`` would still return None).
            parsed = _parse_s2_paper(item.get("citingPaper") or {})
        except Exception as exc:
            logger.warning("Skipping malformed S2 citation: %s", exc)
            continue
        if not parsed:
            continue
        parsed["_is_influential"] = bool(item.get("isInfluential", False))
        # Normalise null edge metadata to empty lists for downstream consumers.
        parsed["_citation_contexts"] = item.get("contexts") or []
        parsed["_citation_intents"] = item.get("intents") or []
        results.append(parsed)
    return results
def fetch_s2_references(
    paper_identifier: str,
    limit: int = 100,
) -> list[dict]:
    """Fetch papers referenced by the given paper from Semantic Scholar.

    Args:
        paper_identifier: S2 paper ID, DOI (DOI:xxx), arXiv ID (ARXIV:xxx), etc.
        limit: Max references to return (capped at 1000).

    Returns:
        One parsed dict per referenced paper. Each dict also carries three
        citation-edge keys: ``_is_influential`` (bool), ``_citation_contexts``
        (list) and ``_citation_intents`` (list). Returns an empty list if the
        request fails. Edges that raise while parsing are logged and skipped,
        so one bad edge does not discard the rest.
    """
    url = f"{_S2_PAPER_URL}/{paper_identifier}/references"
    params = {
        "fields": _S2_CITATION_FIELDS,
        "limit": min(limit, 1000),
    }
    try:
        resp = _http.get(url, params=params, headers=_s2_headers(), timeout=30)
        resp.raise_for_status()
        data = resp.json()
    except Exception as exc:
        logger.warning("S2 references fetch failed for %s: %s", paper_identifier, exc)
        return []

    results = []
    # Tolerate a missing *or null* ``data`` array.
    for item in data.get("data") or []:
        try:
            # ``citedPaper`` may be absent or null; ``or {}`` avoids passing
            # None to the parser (``.get(key, {})`` would still return None).
            parsed = _parse_s2_paper(item.get("citedPaper") or {})
        except Exception as exc:
            logger.warning("Skipping malformed S2 reference: %s", exc)
            continue
        if not parsed:
            continue
        parsed["_is_influential"] = bool(item.get("isInfluential", False))
        # Normalise null edge metadata to empty lists for downstream consumers.
        parsed["_citation_contexts"] = item.get("contexts") or []
        parsed["_citation_intents"] = item.get("intents") or []
        results.append(parsed)
    return results
def fetch_s2_recommendations(
    paper_identifier: str,
    limit: int = 20,
) -> list[dict]:
    """Fetch recommended papers from Semantic Scholar Recommendations API.

    Uses SPECTER2 embeddings for semantic similarity.

    Args:
        paper_identifier: S2 paper ID of the seed paper.
        limit: Max recommendations to return (capped at 500).

    Returns:
        Parsed paper dicts from the response's ``recommendedPapers`` list.
        An empty list if the HTTP request, status check or JSON decoding
        fails. Candidates that raise while parsing are logged and skipped.
    """
    endpoint = f"{_S2_RECS_URL}/{paper_identifier}"
    query_params = {"fields": _S2_FIELDS, "limit": min(limit, 500)}
    try:
        response = _http.get(endpoint, params=query_params, headers=_s2_headers(), timeout=20)
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        logger.warning("S2 recommendations failed for %s: %s", paper_identifier, exc)
        return []

    recommended: list[dict] = []
    for candidate in payload.get("recommendedPapers", []):
        try:
            parsed = _parse_s2_paper(candidate)
        except Exception as exc:
            logger.warning("Skipping malformed S2 recommendation: %s", exc)
            continue
        if parsed:
            recommended.append(parsed)
    return recommended
class SemanticScholarSource(PaperSource):
    """Semantic Scholar Academic Graph + Recommendations API.

    A thin adapter over this module's ``fetch_*`` functions.
    """

    name = "semantic_scholar"

    def search(
        self,
        query: str,
        max_results: int = 20,
        filters: dict | None = None,
    ) -> list[dict]:
        """Keyword search; only ``filters["fields_of_study"]`` is forwarded."""
        fields_of_study = (filters or {}).get("fields_of_study")
        return fetch_semantic_scholar(
            query=query,
            max_results=max_results,
            fields_of_study=fields_of_study,
        )

    def get_by_id(self, identifier: str) -> dict | None:
        """Look up one paper; see ``fetch_s2_by_id`` for accepted identifiers."""
        return fetch_s2_by_id(identifier)

    def supports(self, capability: Capability) -> bool:
        """Report support for citations, references and recommendations."""
        supported = ("citations", "references", "recs")
        return capability in supported

    def rate_limit_hint(self) -> RateLimitHint:
        """Return pacing guidance; an API key permits a much shorter interval."""
        if not os.getenv("SEMANTIC_SCHOLAR_API_KEY"):
            return RateLimitHint(
                min_interval_seconds=1.0,
                daily_quota=None,
                notes="S2 public pool: ~1 req/s, frequent 429s, retries strongly recommended",
            )
        return RateLimitHint(
            min_interval_seconds=0.01,
            daily_quota=None,
            notes="S2 with API key: high throughput, observe per-endpoint limits",
        )