| import asyncio |
| import time |
| import uuid |
| import os |
| import json |
| from typing import Dict, List, Optional, Union, Any |
| from fastapi import FastAPI, HTTPException, Depends, Request, status, Body |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse, StreamingResponse |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| from pydantic import BaseModel, Field, EmailStr |
| from slowapi import Limiter, _rate_limit_exceeded_handler |
| from slowapi.util import get_remote_address |
| from slowapi.errors import RateLimitExceeded |
| import uvicorn |
|
|
| from db_helper import MongoDBHelper |
| from deepinfra_client import DeepInfraClient |
| from hf_utils import HuggingFaceSpaceHelper |
|
|
| |
hf_helper = HuggingFaceSpaceHelper()

# Packages installed at runtime when the app runs inside a Hugging Face Space.
# NOTE(review): fastapi, pydantic, slowapi, uvicorn and the local helper
# modules are imported at the top of this file, before this install step runs,
# so installing them here cannot make those imports succeed. Confirm whether
# this step is still needed, or whether the imports should move below it.
_SPACE_DEPENDENCIES = [
    "pymongo",
    "python-dotenv",
    "fastapi",
    "uvicorn",
    "slowapi",
    "fake-useragent",
    "requests-ip-rotator",
    "pydantic[email]",
]

if hf_helper.is_in_space:
    hf_helper.install_dependencies(_SPACE_DEPENDENCIES)
|
|
| |
app = FastAPI(
    title="PyScoutAI API",
    description="An OpenAI-compatible API that provides access to DeepInfra models with enhanced features",
    version="1.0.0",
)

# Rate limiting keyed on the client's remote address (slowapi). The limiter is
# registered on app.state, and RateLimitExceeded is routed to slowapi's handler.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette echo the caller's Origin for credentialed requests. Any website can
# then make credentialed cross-origin calls. Keys are read from headers or the
# query string, so credentials may not be needed; confirm before tightening.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# auto_error=False: a missing or non-Bearer Authorization header yields None
# rather than an automatic error, so get_api_key can try other key sources.
security = HTTPBearer(auto_error=False)
|
|
| |
# MongoDB backs API-key validation, rate limiting and usage logging. If the
# connection cannot be established, db stays None; code elsewhere in this file
# checks db before using it and falls back to its "no database" behaviour.
db: Optional[MongoDBHelper] = None
try:
    db = MongoDBHelper(hf_helper.get_mongodb_uri())
except Exception as e:
    print(f"Warning: MongoDB connection failed: {e}")
    print("API key authentication will not work!")
|
|
| |
class Message(BaseModel):
    """One chat message in an OpenAI-style ``messages`` array."""

    # Free-form role string (e.g. "system", "user", "assistant"); not validated here.
    role: str
    # Text content; create_chat_completion drops messages whose content is None.
    content: Optional[str] = None
    # Optional author name; accepted but not forwarded upstream by create_chat_completion.
    name: Optional[str] = None
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible subset)."""

    # Model identifier, passed to the DeepInfra client unchanged.
    model: str
    messages: List[Message]
    # Sampling parameters below are forwarded to the upstream client as-is.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    # Accepted for OpenAI compatibility but NOT forwarded upstream.
    n: Optional[int] = 1
    # When true, the endpoint answers with a text/event-stream of SSE chunks.
    stream: Optional[bool] = False
    # None is passed through to the upstream client unchanged.
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    # Accepted for OpenAI compatibility but NOT forwarded upstream.
    user: Optional[str] = None
|
|
class CompletionRequest(BaseModel):
    """Request body for POST /v1/completions (OpenAI-compatible subset)."""

    # Model identifier, passed to the DeepInfra client unchanged.
    model: str
    # A single prompt or a list; create_completion forwards only the first list element.
    prompt: Union[str, List[str]]
    # Sampling parameters below are forwarded to the upstream client as-is.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    # Accepted for OpenAI compatibility but NOT forwarded upstream.
    n: Optional[int] = 1
    # When true, the endpoint answers with a text/event-stream of SSE chunks.
    stream: Optional[bool] = False
    # None is passed through to the upstream client unchanged.
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    # Accepted for OpenAI compatibility but NOT forwarded upstream.
    user: Optional[str] = None
|
|
class UserCreate(BaseModel):
    """User registration payload: email, display name, optional organization."""

    # NOTE(review): not referenced by any endpoint in this part of the file.
    # Validated as an email address by pydantic (needs the pydantic[email] extra).
    email: EmailStr
    name: str
    organization: Optional[str] = None
|
|
class APIKeyCreate(BaseModel):
    """Request body for POST /v1/api_keys."""

    # Human-readable label passed to db.generate_api_key.
    name: str = "Default API Key"
    # Owner of the new key. NOTE(review): caller-supplied and not authenticated.
    user_id: str
|
|
class APIKeyResponse(BaseModel):
    """Response model for POST /v1/api_keys."""

    # The newly generated API key.
    key: str
    # Label stored for the key.
    name: str
    # Creation time, produced by create_api_key via datetime.isoformat().
    created_at: str
|
|
| |
# Upstream DeepInfra clients keyed by the caller's API key. Filled lazily by
# get_client and walked by the shutdown hook.
clients: Dict[str, DeepInfraClient] = dict()
|
|
| |
async def get_api_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[str]:
    """Extract the caller's API key from the request, if one was supplied.

    Sources are tried in order, and the first non-empty value wins:

    1. ``Authorization: Bearer <key>`` as parsed by the ``HTTPBearer`` dependency;
    2. a raw ``Authorization`` header starting with ``"Bearer "`` (fallback);
    3. the ``x-api-key`` header;
    4. the ``api_key`` query-string parameter.

    Args:
        request: Incoming request whose headers and query string are inspected.
        credentials: Result of the ``HTTPBearer(auto_error=False)`` dependency,
            or None when it found no bearer credentials.

    Returns:
        The key string, or None when no source provided one. Validation is left
        to ``get_user_info``.
    """
    if credentials:
        return credentials.credentials

    bearer_prefix = "Bearer "
    auth = request.headers.get("Authorization", "")
    if auth.startswith(bearer_prefix):
        # Strip only the leading scheme. str.replace would also delete any
        # "Bearer " text appearing inside the token itself.
        token = auth[len(bearer_prefix):]
        if token:
            return token

    # An empty header value is treated as absent so the next source is tried.
    header_key = request.headers.get("x-api-key")
    if header_key:
        return header_key

    # NOTE(review): keys sent in the query string tend to end up in access and
    # proxy logs; consider deprecating this source.
    query_key = request.query_params.get("api_key")
    if query_key:
        return query_key

    return None
|
|
| |
async def get_user_info(api_key: Optional[str] = Depends(get_api_key)) -> Dict[str, Any]:
    """Authenticate the caller and return its key record (endpoint dependency).

    Args:
        api_key: Key extracted by ``get_api_key``, or None if none was supplied.

    Returns:
        The record returned by ``db.validate_api_key``. When no database is
        configured, returns the placeholder ``{"user_id": "development", "key": api_key}``.

    Raises:
        HTTPException: 401 when the key is missing, lacks the ``PyScoutAI-``
            prefix, or is rejected by the database. 429 when
            ``db.check_rate_limit`` disallows the request.
    """
    def _unauthorized(detail: str) -> HTTPException:
        # All 401s advertise the Bearer scheme.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not api_key:
        raise _unauthorized("API key is required")

    if not db:
        # NOTE(review): without a database, ANY non-empty key is accepted
        # (fail-open). Confirm this development fallback is never reachable
        # in production.
        return {"user_id": "development", "key": api_key}

    if not api_key.startswith("PyScoutAI-"):
        raise _unauthorized("Invalid API key format")

    key_record = db.validate_api_key(api_key)
    if not key_record:
        raise _unauthorized("Invalid API key")

    verdict = db.check_rate_limit(api_key)
    if not verdict["allowed"]:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=verdict["reason"],
        )

    return key_record
|
|
| |
def get_client(api_key: str) -> DeepInfraClient:
    """Return the DeepInfraClient cached for ``api_key``, creating it on first use.

    Every new client is built with random user agents, proxy rotation and IP
    rotation enabled.

    NOTE(review): the ``clients`` cache is never pruned. Without a database,
    get_user_info accepts any key, so arbitrary keys can grow this cache.
    """
    try:
        return clients[api_key]
    except KeyError:
        pass

    fresh_client = DeepInfraClient(
        use_random_user_agent=True,
        use_proxy_rotation=True,
        use_ip_rotation=True,
    )
    clients[api_key] = fresh_client
    return fresh_client
|
|
@app.get("/")
async def root():
    """Describe the service.

    The response contains a welcome message, the docs path, the runtime
    environment and the main endpoints. It is merged with whatever
    ``hf_helper.get_hf_metadata()`` returns; metadata keys win on collision.
    """
    hf_metadata = hf_helper.get_hf_metadata()
    environment = "Hugging Face Space" if hf_helper.is_in_space else "Local"

    payload = {
        "message": "Welcome to PyScoutAI API",
        "documentation": "/docs",
        "environment": environment,
        "endpoints": [
            "/v1/models",
            "/v1/chat/completions",
            "/v1/completions",
        ],
    }
    payload.update(hf_metadata)
    return payload
|
|
@app.get("/v1/models")
@limiter.limit("20/minute")
async def list_models(
    request: Request,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """Return the upstream model list for the authenticated caller.

    ``request`` is unused in the body but is required by ``@limiter.limit``.
    The blocking ``client.models.list`` call runs on a worker thread. When a
    database is configured, the call is logged with 0 tokens. Any failure is
    reported as HTTP 500.
    """
    caller_key = user_info["key"]
    upstream = get_client(caller_key)

    try:
        catalog = await asyncio.to_thread(upstream.models.list)
        if db:
            db.log_api_usage(caller_key, "/v1/models", 0)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error listing models: {str(e)}")

    return catalog
|
|
@app.post("/v1/chat/completions")
@limiter.limit("60/minute")
async def create_chat_completion(
    request: Request,
    body: ChatCompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """OpenAI-compatible chat completion endpoint backed by DeepInfra.

    Args:
        request: Unused in the body; slowapi's ``@limiter.limit`` decorator
            requires the endpoint to accept the incoming ``Request``.
        body: Validated chat parameters. Only ``role``/``content`` of each
            message are forwarded; ``n`` and ``user`` are ignored.
        user_info: Caller record from ``get_user_info``. ``user_info["key"]``
            selects the cached upstream client and attributes usage logs.

    Returns:
        For non-streaming requests, the upstream response object. When
        ``body.stream`` is true, a ``text/event-stream`` ``StreamingResponse``
        that emits ``data: <json>`` events and ends with ``data: [DONE]``.

    Raises:
        HTTPException: 500 if the non-streaming upstream call (or its usage
            logging) fails. Streaming failures are reported in-band as an SSE
            ``error`` event, because the response status has already been sent.
    """
    api_key = user_info["key"]
    client = get_client(api_key)

    # Only role and content are sent upstream. Messages without content are
    # dropped, and the optional ``name`` field is not forwarded.
    messages = [
        {"role": msg.role, "content": msg.content}
        for msg in body.messages
        if msg.content is not None
    ]

    kwargs = {
        "model": body.model,
        "temperature": body.temperature,
        "max_tokens": body.max_tokens,
        "stream": body.stream,
        "top_p": body.top_p,
        "presence_penalty": body.presence_penalty,
        "frequency_penalty": body.frequency_penalty,
    }

    if body.stream:
        async def generate_stream():
            total_tokens = 0
            try:
                response_stream = await asyncio.to_thread(
                    client.chat.create,
                    messages=messages,
                    **kwargs
                )
                chunks = iter(response_stream)
                end_of_stream = object()
                while True:
                    # The upstream stream is a blocking iterator. Pull each chunk
                    # on a worker thread so waiting on the network does not stall
                    # the event loop (and with it every other in-flight request).
                    chunk = await asyncio.to_thread(next, chunks, end_of_stream)
                    if chunk is end_of_stream:
                        break
                    if 'usage' in chunk and chunk['usage']:
                        total_tokens += chunk['usage'].get('total_tokens', 0)
                    yield f"data: {json.dumps(chunk)}\n\n"
            except Exception as e:
                # The 200 status has already been sent once streaming begins, so
                # an HTTPException is no longer possible; report the error in-band.
                error_event = {
                    "error": {"message": f"Error generating chat completion: {str(e)}"}
                }
                yield f"data: {json.dumps(error_event)}\n\n"
                return
            finally:
                # Also runs if the generator is cancelled or closed early
                # (e.g. client disconnect), so consumed usage is still recorded.
                if db is not None:
                    db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)

            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream"
        )

    try:
        response = await asyncio.to_thread(
            client.chat.create,
            messages=messages,
            **kwargs
        )

        # Guard against a present-but-null usage field, as the streaming path
        # does; otherwise a successful generation would be turned into a 500.
        if db is not None and 'usage' in response and response['usage']:
            total_tokens = response['usage'].get('total_tokens', 0)
            db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error generating chat completion: {str(e)}"
        ) from e

    return response
|
|
@app.post("/v1/completions")
@limiter.limit("60/minute")
async def create_completion(
    request: Request,
    body: CompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """OpenAI-compatible text completion endpoint backed by DeepInfra.

    Args:
        request: Unused in the body; slowapi's ``@limiter.limit`` decorator
            requires the endpoint to accept the incoming ``Request``.
        body: Validated completion parameters. When ``prompt`` is a list, only
            its first element is sent upstream. ``n`` and ``user`` are ignored.
        user_info: Caller record from ``get_user_info``. ``user_info["key"]``
            selects the cached upstream client and attributes usage logs.

    Returns:
        For non-streaming requests, the upstream response object. When
        ``body.stream`` is true, a ``text/event-stream`` ``StreamingResponse``
        that emits ``data: <json>`` events and ends with ``data: [DONE]``.

    Raises:
        HTTPException:
            - 400 if ``prompt`` is an empty list.
            - 500 if the non-streaming upstream call (or its usage logging)
              fails.

            Streaming failures are reported in-band as an SSE ``error`` event,
            because the response status has already been sent.
    """
    prompt = body.prompt
    if isinstance(prompt, list):
        if not prompt:
            # Reject up front. Otherwise prompt[0] would raise IndexError, which
            # the generic handler below would turn into an opaque 500.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="prompt must be a non-empty string or list of strings"
            )
        # Batched prompts are not supported; only the first one is forwarded.
        prompt = prompt[0]

    api_key = user_info["key"]
    client = get_client(api_key)

    kwargs = {
        "model": body.model,
        "temperature": body.temperature,
        "max_tokens": body.max_tokens,
        "stream": body.stream,
        "top_p": body.top_p,
        "presence_penalty": body.presence_penalty,
        "frequency_penalty": body.frequency_penalty,
    }

    if body.stream:
        async def generate_stream():
            total_tokens = 0
            try:
                response_stream = await asyncio.to_thread(
                    client.completions.create,
                    prompt=prompt,
                    **kwargs
                )
                chunks = iter(response_stream)
                end_of_stream = object()
                while True:
                    # The upstream stream is a blocking iterator. Pull each chunk
                    # on a worker thread so waiting on the network does not stall
                    # the event loop (and with it every other in-flight request).
                    chunk = await asyncio.to_thread(next, chunks, end_of_stream)
                    if chunk is end_of_stream:
                        break
                    if 'usage' in chunk and chunk['usage']:
                        total_tokens += chunk['usage'].get('total_tokens', 0)
                    yield f"data: {json.dumps(chunk)}\n\n"
            except Exception as e:
                # The 200 status has already been sent once streaming begins, so
                # an HTTPException is no longer possible; report the error in-band.
                error_event = {
                    "error": {"message": f"Error generating completion: {str(e)}"}
                }
                yield f"data: {json.dumps(error_event)}\n\n"
                return
            finally:
                # Also runs if the generator is cancelled or closed early
                # (e.g. client disconnect), so consumed usage is still recorded.
                if db is not None:
                    db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)

            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream"
        )

    try:
        response = await asyncio.to_thread(
            client.completions.create,
            prompt=prompt,
            **kwargs
        )

        # Guard against a present-but-null usage field, as the streaming path
        # does; otherwise a successful generation would be turned into a 500.
        if db is not None and 'usage' in response and response['usage']:
            total_tokens = response['usage'].get('total_tokens', 0)
            db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error generating completion: {str(e)}"
        ) from e

    return response
|
|
@app.get("/health")
async def health_check():
    """Report API, database and environment status.

    Never raises. A database failure is reported in the ``database`` field of
    the body rather than through the HTTP status.
    """
    report = {"api": "ok"}

    if not db:
        report["database"] = "not configured"
    else:
        try:
            # Any cheap read proves the connection is usable.
            db.api_keys_collection.find_one({})
        except Exception as e:
            report["database"] = f"error: {str(e)}"
        else:
            report["database"] = "ok"

    if not hf_helper.is_in_space:
        report["environment"] = "Local"
    else:
        report.update(
            environment="Hugging Face Space",
            space_name=hf_helper.space_name,
        )

    return report
|
|
| |
@app.post("/v1/api_keys", response_model=APIKeyResponse)
async def create_api_key(body: APIKeyCreate):
    """Generate a new API key for ``body.user_id`` and return its details.

    Raises:
        HTTPException: 500 when no database is configured, or when key
            generation or the follow-up lookup fails.
    """
    # NOTE(review): this endpoint has no authentication and no rate limit, so
    # any caller can mint a key for any user_id. Confirm this is intended.
    if not db:
        raise HTTPException(status_code=500, detail="Database not configured")

    try:
        new_key = db.generate_api_key(body.user_id, body.name)
        # Read the stored record back to report its name and creation time.
        record = db.validate_api_key(new_key)
        label = record["name"]
        created_iso = record["created_at"].isoformat()
        return {"key": new_key, "name": label, "created_at": created_iso}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating API key: {str(e)}")
|
|
@app.get("/v1/api_keys")
async def list_api_keys(user_id: str):
    """List the API-key records that ``db.get_user_api_keys`` returns for ``user_id``.

    Non-null ``created_at`` / ``last_used`` timestamps are converted to
    ISO-8601 strings. Absent or null timestamps are left untouched.

    Raises:
        HTTPException: 500 when no database is configured.
    """
    # NOTE(review): unauthenticated. It returns whatever db.get_user_api_keys
    # yields (possibly raw key values) for any user_id given in the query string.
    if db is None:
        raise HTTPException(status_code=500, detail="Database not configured")

    keys = db.get_user_api_keys(user_id)
    for key in keys:
        for field in ("created_at", "last_used"):
            # Treat both timestamps alike. A null created_at used to raise
            # AttributeError and turn the whole listing into a 500, while
            # last_used was already guarded.
            if key.get(field):
                key[field] = key[field].isoformat()

    return {"keys": keys}
|
|
@app.post("/v1/api_keys/revoke")
async def revoke_api_key(api_key: str):
    """Revoke ``api_key`` via ``db.revoke_api_key``.

    Raises:
        HTTPException: 500 when no database is configured; 404 when the
            database reports that nothing was revoked.
    """
    # NOTE(review): no authentication. Because api_key is a plain scalar
    # parameter, FastAPI reads it from the query string, where it may be logged.
    if not db:
        raise HTTPException(status_code=500, detail="Database not configured")

    if db.revoke_api_key(api_key):
        return {"message": "API key revoked successfully"}

    raise HTTPException(status_code=404, detail="API key not found")
|
|
| |
@app.on_event("shutdown")
async def cleanup_clients():
    """Shut down the IP rotators held by cached DeepInfra clients.

    Best-effort: a failure for one client is reported, and the remaining
    clients are still processed.
    """
    for client in clients.values():
        try:
            rotator = getattr(client, "ip_rotator", None)
            if rotator:
                rotator.shutdown()
        except Exception as e:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt/SystemExit still propagate. Report the failure
            # instead of silently discarding it. The key is not printed, since
            # it is a credential.
            print(f"Warning: failed to shut down IP rotator: {e}")
|
|
if __name__ == "__main__":
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))
    # Auto-reload is a local-development convenience; it is disabled in a Space.
    reload_enabled = not hf_helper.is_in_space

    print(f"Starting PyScoutAI API on http://{host}:{port}")
    print(f"Environment: {'Hugging Face Space' if hf_helper.is_in_space else 'Local'}")

    # uvicorn can only reload an app given as a "module:attribute" import
    # string. Given an app object with reload=True, it logs a warning and exits.
    # The module name comes from this file's name; the script's directory is on
    # sys.path when run directly, so that module is importable.
    if reload_enabled:
        app_target = f"{os.path.splitext(os.path.basename(__file__))[0]}:app"
    else:
        app_target = app

    uvicorn.run(
        app_target,
        host=host,
        port=port,
        reload=reload_enabled
    )
|
|