"""arXiv source: Atom feed search + id_list lookup."""

import datetime
import logging
import re
import time
from typing import Any

import feedparser
import requests

from ._http import _http
from .base import Capability, PaperSource, RateLimitHint, paper_dict

logger = logging.getLogger(__name__)

_ARXIV_BASE_URL = "https://export.arxiv.org/api/query"

# Per-request timeout (seconds) for calls to the arXiv API.
_ARXIV_TIMEOUT_SECONDS = 15

# Minimum pause (seconds) between consecutive arXiv API requests.
_ARXIV_MIN_INTERVAL_SECONDS = 3.0

# Captures the identifier after "abs/" and drops only a trailing version
# suffix ("v1", "v12", ...).  Old-style identifiers may themselves contain
# the letter "v" (e.g. "solv-int/9901001v1"), so the match must not stop at
# the first "v" it encounters.
_ARXIV_ID_RE = re.compile(r"abs/(.+?)(?:v\d+)?$")

# Human-readable default topics attached to results of a category search.
_ARXIV_CATEGORY_TOPICS: dict[str, list[str]] = {
    "cs.AI": ["artificial intelligence"],
    "cs.LG": ["machine learning"],
    "cs.CL": ["natural language processing"],
    "cs.CV": ["computer vision"],
    "cs.RO": ["robotics"],
    "cs.CR": ["cybersecurity"],
    "cs.DC": ["distributed systems"],
    "cs.NE": ["neural networks"],
    "stat.ML": ["machine learning", "statistics"],
    "q-bio": ["bioinformatics"],
}


def _parse_arxiv_date(date_str: str) -> datetime.datetime | None:
    """Parse an Atom timestamp into a naive datetime expressed in UTC.

    Two layouts are accepted: a literal ``Z`` suffix and an explicit numeric
    UTC offset.  Timestamps carrying an offset are converted to UTC before
    the tzinfo is stripped, so the returned naive value always means UTC.

    Returns:
        The parsed datetime, or ``None`` when ``date_str`` is empty or
        matches neither layout.
    """
    if not date_str:
        return None
    for fmt in ("%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S%z"):
        try:
            parsed = datetime.datetime.strptime(date_str, fmt)
        except ValueError:
            continue
        if parsed.tzinfo is not None:
            # Normalise to UTC first; merely dropping tzinfo would keep the
            # local wall-clock time and silently shift the instant.
            parsed = parsed.astimezone(datetime.timezone.utc)
        return parsed.replace(tzinfo=None)
    return None


def _entry_to_paper(entry: Any, default_topics: list[str] | None = None) -> dict | None:
    """Convert a feedparser entry into the standard paper dict.

    Args:
        entry: One item of a parsed feed's ``entries``.
        default_topics: Topics placed ahead of the entry's own category
            terms; duplicates are removed while preserving order.

    Returns:
        The value built by ``paper_dict``, or ``None`` when the entry is
        malformed (the problem is logged and the entry is skipped).
    """
    try:
        # Skip blank author names so the joined string has no empty slots.
        authors = ", ".join(
            name for a in entry.get("authors", []) if (name := a.get("name"))
        )

        id_match = _ARXIV_ID_RE.search(entry.id)
        # NOTE(review): when the id does not contain "abs/", the raw id is
        # used as-is, which also ends up inside the DOI below.
        arxiv_id = id_match.group(1) if id_match else entry.id

        entry_cats = [term for t in entry.get("tags", []) if (term := t.get("term"))]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        combined_topics = list(dict.fromkeys([*(default_topics or []), *entry_cats]))

        # A missing summary yields abstract=None rather than rejecting the entry.
        abstract = (entry.get("summary") or "").replace("\n", " ").strip() or None

        return paper_dict(
            title=entry.title.replace("\n", " ").strip(),
            abstract=abstract,
            authors=authors,
            publication_date=_parse_arxiv_date(entry.get("published", "")),
            source="arxiv",
            source_id=arxiv_id,
            url=entry.link,
            doi=f"10.48550/arXiv.{arxiv_id}",
            topics=combined_topics,
        )
    except Exception as exc:
        # Deliberate best-effort: one bad entry must not abort the whole feed.
        logger.warning("Skipping malformed arXiv entry: %s", exc)
        return None


def fetch_arxiv(
    query: str = "",
    category: str | None = None,
    max_results: int = 20,
) -> list[dict]:
    """Fetch papers from arXiv API. Searches all categories when query is provided.

    Args:
        query: Free-text search terms; may be empty.
        category: Optional arXiv category (e.g. ``"cs.LG"``) to restrict to.
            When both ``query`` and ``category`` are empty, ``cs.AI`` is used.
        max_results: Maximum number of entries requested from the API.

    Returns:
        Paper dicts sorted by submission date (newest first) as requested
        from the API; an empty list when the HTTP request fails.
    """
    if query and category:
        search_query = f"(cat:{category}) AND all:{query}"
    elif query:
        search_query = f"all:{query}"
    elif category:
        search_query = f"cat:{category}"
    else:
        search_query = "cat:cs.AI"

    params = {
        "search_query": search_query,
        "sortBy": "submittedDate",
        "sortOrder": "descending",
        "max_results": max_results,
    }
    try:
        resp = _http.get(_ARXIV_BASE_URL, params=params, timeout=_ARXIV_TIMEOUT_SECONDS)
        resp.raise_for_status()
    except requests.RequestException as exc:
        logger.warning("arXiv fetch failed: %s", exc)
        return []

    feed = feedparser.parse(resp.text)
    # Unknown categories fall back to the raw category string as the topic.
    default_topics = _ARXIV_CATEGORY_TOPICS.get(category, [category] if category else [])
    converted = (
        _entry_to_paper(entry, default_topics=default_topics) for entry in feed.entries
    )
    return [paper for paper in converted if paper]


def fetch_arxiv_multi(
    query: str = "",
    categories: list[str] | None = None,
    max_per_category: int = 10,
) -> list[dict]:
    """Fetch from multiple arXiv categories, deduplicating.

    Args:
        query: Free-text search terms applied within every category.
        categories: Categories to query in order; defaults to
            ``["cs.AI", "cs.LG", "cs.CL"]`` when empty or ``None``.
        max_per_category: Maximum number of entries requested per category.

    Returns:
        Paper dicts in first-seen order, keeping one paper per ``source_id``.
    """
    cats = categories or ["cs.AI", "cs.LG", "cs.CL"]
    seen: set = set()
    results: list[dict] = []
    for index, cat in enumerate(cats):
        if index:
            # arXiv rate limit: pause only between consecutive requests, so
            # no time is wasted sleeping after the final category.
            time.sleep(_ARXIV_MIN_INTERVAL_SECONDS)
        for paper in fetch_arxiv(query=query, category=cat, max_results=max_per_category):
            if paper["source_id"] not in seen:
                seen.add(paper["source_id"])
                results.append(paper)
    return results


def fetch_arxiv_by_id(arxiv_id: str) -> dict | None:
    """Fetch a single arXiv paper by its identifier (e.g. '2301.12345').

    Returns:
        The paper dict, or ``None`` when the identifier is empty, the HTTP
        request fails, the feed has no entries, or the entry is malformed.
    """
    if not arxiv_id:
        # Nothing to look up; avoid a pointless network round trip.
        return None

    params = {"id_list": arxiv_id, "max_results": 1}
    try:
        resp = _http.get(_ARXIV_BASE_URL, params=params, timeout=_ARXIV_TIMEOUT_SECONDS)
        resp.raise_for_status()
    except requests.RequestException as exc:
        logger.warning("arXiv get_by_id failed for %s: %s", arxiv_id, exc)
        return None

    feed = feedparser.parse(resp.text)
    if not feed.entries:
        return None
    return _entry_to_paper(feed.entries[0])


class ArxivSource(PaperSource):
    """arXiv preprint server. Search + id lookup only."""

    name = "arxiv"

    def search(
        self,
        query: str,
        max_results: int = 20,
        filters: dict | None = None,
    ) -> list[dict]:
        """Search arXiv; ``filters["category"]`` optionally restricts the category."""
        filters = filters or {}
        return fetch_arxiv(
            query=query,
            category=filters.get("category"),
            max_results=max_results,
        )

    def get_by_id(self, identifier: str) -> dict | None:
        """Look up a single paper by its arXiv identifier."""
        return fetch_arxiv_by_id(identifier)

    def supports(self, capability: Capability) -> bool:
        """Report that no optional capability is supported by this source."""
        # arXiv has no first-party citation graph or recommendations API.
        return False

    def rate_limit_hint(self) -> RateLimitHint:
        """Describe the request pacing this source expects callers to honour."""
        return RateLimitHint(
            min_interval_seconds=_ARXIV_MIN_INTERVAL_SECONDS,
            daily_quota=None,
            notes="arXiv API guidance: max 1 request per 3 seconds; use OAI-PMH for bulk",
        )