Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import asyncio | |
| import time | |
| import traceback | |
| import json | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from llama_cpp import Llama | |
| from contextlib import asynccontextmanager | |
| from huggingface_hub import hf_hub_download | |
# Configure logging.
# Module-level logger shared by every function and class in this file.
logger = logging.getLogger(__name__)
# Root configuration: emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)
# ---------- CPU optimizations ----------
def optimize_for_cpu():
    """Apply CPU-specific optimizations (optional, best effort).

    Sets OpenMP / Intel-OpenMP tuning environment variables for this process
    and tries to raise the process priority through ``psutil``. Not having
    ``psutil`` installed, or lacking the privilege to change priority, is
    logged at DEBUG level and otherwise ignored.
    """
    log = logging.getLogger(__name__)
    # os.cpu_count() returns None when the count cannot be determined;
    # str(None) would put the invalid value "None" into OMP_NUM_THREADS.
    os.environ['OMP_NUM_THREADS'] = str(os.cpu_count() or 1)
    os.environ['KMP_BLOCKTIME'] = '1'
    os.environ['KMP_AFFINITY'] = 'granularity=fine,compact,1,0'
    # NOTE(review): these variables only help if the OpenMP runtime reads them
    # after this point; llama_cpp is imported earlier in this file — confirm.
    try:
        import psutil
    except ImportError:
        log.debug("psutil not installed; skipping process priority change")
        return
    try:
        # A negative nice value usually needs elevated privileges.
        psutil.Process().nice(-5)
    except Exception as exc:  # best effort: never fail startup over priority
        log.debug(f"Could not raise process priority: {exc}")
    else:
        log.debug("Set process to higher priority")


optimize_for_cpu()
# ---------- Queue management ----------
class QueueStatus:
    """Admission control limiting how many requests use the model at once.

    A caller that finds every slot busy either gets an immediate "queued"
    answer (``acquire()``, the default) or waits in FIFO order until a slot
    is handed to it (``acquire(wait=True)``). Every successful ``acquire``
    must be paired with exactly one ``release``.
    """

    def __init__(self, max_concurrent: int = 1):
        self.max_concurrent = max_concurrent
        # Slots currently held, including one just handed to a waiter.
        self.active_tasks = 0
        # Futures of callers blocked in acquire(wait=True), oldest first.
        self.pending_queue = []
        self._lock = asyncio.Lock()

    async def acquire(self, wait: bool = False):
        """Try to take a processing slot.

        Args:
            wait: When False (default), return at once if no slot is free.
                When True, wait in FIFO order until ``release`` hands over a
                slot.

        Returns:
            ``(True, 0)`` if a slot was free immediately.
            ``(False, position)`` if ``wait`` is False and all slots are busy.
            Nothing is enqueued in that case, so no slot is ever reserved
            for a caller that is not waiting.
            ``(True, position)`` once a waiting caller has been handed a slot.
            ``position`` is 1-based.

        Raises:
            asyncio.CancelledError: if a waiting caller is cancelled. Its
                queue entry is removed, or its slot is passed on if it had
                already been handed one.
        """
        async with self._lock:
            if self.active_tasks < self.max_concurrent:
                self.active_tasks += 1
                return True, 0
            position = len(self.pending_queue) + 1
            if not wait:
                return False, position
            future = asyncio.get_running_loop().create_future()
            self.pending_queue.append(future)
        try:
            await future
        except asyncio.CancelledError:
            if future.done() and not future.cancelled():
                # A slot was handed over just before cancellation: pass it on
                # so it is not leaked.
                await self.release()
            else:
                async with self._lock:
                    if future in self.pending_queue:
                        self.pending_queue.remove(future)
            raise
        return True, position

    async def release(self):
        """Give back a slot, transferring it to the oldest live waiter if any."""
        async with self._lock:
            while self.pending_queue:
                future = self.pending_queue.pop(0)
                if not future.done():
                    # Ownership moves to the waiter; active_tasks is unchanged.
                    future.set_result(True)
                    return
            # No waiter took the slot; never let the counter go negative.
            self.active_tasks = max(0, self.active_tasks - 1)

    def get_status(self):
        """Return slot usage: active holders, queued waiters, and the limit."""
        return {
            "active": self.active_tasks,
            "queued": len(self.pending_queue),
            "max_concurrent": self.max_concurrent
        }


queue_status = QueueStatus(max_concurrent=1)
# ---------- The model class with local GGUF model ----------
class MixtralFreeModel:
    """Async wrapper around a local llama.cpp GGUF model.

    Despite the class name, the weights loaded are Ministral 3B Instruct
    (``Ministral-3-3B-Instruct-2512-Q4_K_M.gguf``). Every generation goes
    through ``_generate_completion``, which runs the blocking llama.cpp call
    in a worker thread via ``asyncio.to_thread``.
    """

    # Returned by generate_flow_diagram when the model output does not look
    # like a Mermaid diagram. The closing fence is required for markdown
    # renderers to treat it as a code block.
    _FLOW_FALLBACK = "\n".join([
        "```mermaid",
        "flowchart TD",
        "A[Start] --> B[Follow the steps above]",
        "B --> C[Complete task]",
        "C --> D[End]",
        "```",
    ])
    # Returned by generate_flow_diagram when generation itself raised.
    _FLOW_ERROR_FALLBACK = "\n".join([
        "```mermaid",
        "flowchart TD",
        "A[Start] --> B[Error generating diagram]",
        "B --> C[Try again]",
        "C --> D[End]",
        "```",
    ])

    def __init__(self, model_path: str = None):
        """Locate (or download) the GGUF weights and load them with llama.cpp.

        Resolution order for the model file:
        1. ``model_path``, or else the ``GGUF_MODEL_PATH`` environment
           variable. Used only if the file exists; otherwise a warning is
           logged.
        2. The baked-in ``/app/models/Ministral-3-3B-Instruct-2512-Q4_K_M.gguf``.
        3. A download from the Hugging Face Hub.

        Raises:
            Any exception from ``Llama(...)``; it is logged and re-raised.
            Download errors from ``hf_hub_download`` propagate unchanged.
        """
        self.model_name = "ministral-3.3b"
        self.max_tokens = 512
        # NOTE(review): not read anywhere in this class; _generate_completion
        # defaults to temperature 0.3 instead.
        self.temperature = 0.7
        if model_path is None:
            model_path = os.environ.get("GGUF_MODEL_PATH", None)
        if model_path and os.path.exists(model_path):
            gguf_file = model_path
            logger.info(f"Using provided model path: {gguf_file}")
        else:
            if model_path:
                # An explicit but missing path is usually a config mistake;
                # say so instead of silently falling back.
                logger.warning(f"Model path {model_path!r} not found; falling back to the default model")
            local_path = "/app/models/Ministral-3-3B-Instruct-2512-Q4_K_M.gguf"
            if os.path.exists(local_path):
                gguf_file = local_path
                logger.info(f"Using local model file: {local_path}")
            else:
                logger.info("Downloading Ministral-3.3B model from Hugging Face Hub...")
                gguf_file = hf_hub_download(
                    repo_id="mistralai/Ministral-3-3B-Instruct-2512-GGUF",
                    filename="Ministral-3-3B-Instruct-2512-Q4_K_M.gguf",
                )
                logger.info(f"Downloaded model to: {gguf_file}")
        logger.info(f"Loading GGUF model from {gguf_file}...")
        start_time = time.time()
        try:
            self.llm = Llama(
                model_path=gguf_file,
                n_ctx=4096,
                n_batch=512,
                n_gpu_layers=0,  # CPU only
                n_threads=os.cpu_count(),
                n_threads_batch=os.cpu_count(),
                use_mlock=True,
                use_mmap=True,
                low_vram=False,
                verbose=False,
                seed=42,
            )
            load_time = time.time() - start_time
            logger.info(f"GGUF model loaded successfully in {load_time:.2f}s")
        except Exception as e:
            logger.error(f"Failed to load GGUF model: {e}")
            raise

    async def warm_up(self) -> None:
        """Run one tiny completion so the first real request is not cold.

        Failures are logged as warnings and not raised.
        """
        logger.info("Warming up model with test inference...")
        start_time = time.time()
        try:
            await self._generate_completion("Hello", max_tokens=10, temperature=0.1)
            warm_up_time = time.time() - start_time
            logger.info(f"Model warm-up completed in {warm_up_time:.2f}s")
        except Exception as e:
            logger.warning(f"Model warm-up failed: {e}")

    async def _generate_completion(self, prompt: str, max_tokens: int = None, temperature: float = None) -> str:
        """Run one non-streaming completion without blocking the event loop.

        Args:
            prompt: Full prompt text, already wrapped in [INST] markers.
            max_tokens: Generation cap; defaults to ``self.max_tokens``.
            temperature: Sampling temperature; defaults to 0.3 (not
                ``self.temperature``).

        Returns:
            The first choice's text with surrounding whitespace stripped.
        """
        if max_tokens is None:
            max_tokens = self.max_tokens
        if temperature is None:
            temperature = 0.3

        def _blocking():
            # Executed in a worker thread by asyncio.to_thread.
            start = time.time()
            response = self.llm.create_completion(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=0.95,
                stop=[],
                echo=False,
                stream=False
            )
            elapsed = time.time() - start
            logger.debug(f"Blocking completion took {elapsed:.2f}s")
            return response['choices'][0]['text'].strip()

        return await asyncio.to_thread(_blocking)

    async def generate_response(self, question: str, context: str = "") -> str:
        """Answer ``question`` using ``context`` as conversation history.

        If the question mentions a guide or tutorial, the model is asked for a
        JSON object. The first flat JSON object with
        ``"action": "generate_guide"`` that actually parses is returned;
        otherwise a fixed fallback JSON string is returned.

        Returns:
            The model text, a JSON string for guide requests, or an apology
            message if generation raised.
        """
        is_guide_request = any(phrase in question.lower() for phrase in
                               ["guide", "create a guide", "make a guide", "step by step", "tutorial"])
        if is_guide_request:
            system_prompt = (
                "You are an assistant that creates structured guides.\n"
                "When asked to create a guide, you MUST respond with ONLY a valid JSON object.\n"
                "Do not include any additional text, explanations, markdown, or code fences.\n"
                'The JSON object must contain the keys "action" and "summary".\n'
                "Format:\n"
                '{"action": "generate_guide", "summary": "Brief summary of the task"}\n'
                "Conversation context:\n"
                f"{context}\n"
                "Now produce the JSON object for the user's request:"
            )
        else:
            system_prompt = (
                "You are a helpful, accurate, and context-aware assistant. Use the conversation history below to provide a relevant and useful answer to the question.\n"
                "IMPORTANT:\n"
                "- Answer in the same language as the question\n"
                "- Be concise but comprehensive\n"
                "- Use the conversation context when relevant\n"
                "- If the context doesn't contain relevant information, use your general knowledge\n"
                "Conversation history:\n"
                f"{context}\n"
                "Provide a helpful response"
            )
        prompt = f"<s>[INST] {system_prompt}\n\nNow handle this user request: {question} [/INST]"
        try:
            response_text = await self._generate_completion(prompt, max_tokens=512)
            if is_guide_request:
                import re
                match = re.search(r'\{[^{}]*"action"\s*:\s*"generate_guide"[^{}]*\}', response_text, re.DOTALL)
                if match:
                    candidate = match.group(0)
                    # The regex only finds a brace-delimited span; make sure
                    # it is really JSON before handing it to the client.
                    try:
                        json.loads(candidate)
                    except ValueError:
                        pass
                    else:
                        return candidate
                logger.warning("Model did not return valid JSON for guide request. Using fallback.")
                return json.dumps({
                    "action": "generate_guide",
                    "summary": "Create a guide based on the conversation.",
                    "sections": ["Overview", "Prerequisites", "Step-by-Step Instructions", "Tools & Assets", "Flow"]
                })
            return response_text
        except Exception as e:
            logger.error(f"Error in generation: {str(e)}")
            return "I apologize, but I'm having trouble responding right now."

    def clean_question(self, question: str) -> str:
        """Strip a leading bot-trigger prefix (e.g. "!bot", "@bot") from a question.

        Matching is case-insensitive and tries longer prefixes first, so
        "!ai_search" is not mistaken for "!ai". Separator characters
        (space, comma, "!", ":", "@") after the prefix are removed too.
        Empty or whitespace-only input is returned unchanged; otherwise the
        result is stripped.
        """
        prefixes = ['!bot', '!ai', '@bot', 'bot,', '!ai_search']
        if not question or not question.strip():
            return question
        original_question = question.strip()
        question_lower = original_question.lower()
        for prefix in sorted(prefixes, key=len, reverse=True):
            if question_lower.startswith(prefix.lower()):
                return original_question[len(prefix):].lstrip(' ,!:@')
        return original_question

    async def compress_input(self, text: str, max_tokens: int = 500) -> str:
        """Summarize ``text`` with the model when it is long.

        The length check counts whitespace-separated words, while the
        summary cap is in tokens. Text with fewer than ``max_tokens`` words
        is returned unchanged.
        """
        if len(text.split()) < max_tokens:
            return text
        logger.info(f"Compressing input of {len(text.split())} words...")
        start = time.time()
        prompt = f"<s>[INST] Summarize the following text into a concise, structured form (bullet points or key-value pairs) keeping all essential details. Use at most {max_tokens} tokens.\n\nText:\n{text}\n\nSummary: [/INST]"
        summary = await self._generate_completion(prompt, max_tokens=max_tokens, temperature=0.5)
        elapsed = time.time() - start
        logger.info(f"Compression completed in {elapsed:.2f}s")
        return summary

    async def generate_efficient_section(self, section_type: str, context: str, max_tokens: int = 300) -> str:
        """Ask the model for a terse, structured draft of one guide section."""
        logger.info(f"Generating efficient representation for '{section_type}'...")
        start = time.time()
        system = f"You are an expert task guide writer. Generate content for the section \"{section_type}\" in an efficient language format.\nUse a structured format like:\n- Key point 1: details\n- Key point 2: details\nOr use JSON if appropriate. Keep it concise and use at most {max_tokens} tokens."
        prompt = f"<s>[INST] {system}\n\nContext: {context}\nGenerate the efficient language for {section_type} section. [/INST]"
        efficient = await self._generate_completion(prompt, max_tokens=max_tokens)
        elapsed = time.time() - start
        logger.info(f"Efficient section generation took {elapsed:.2f}s")
        return efficient

    async def expand_efficient_to_natural(self, efficient_text: str, section_type: str, max_tokens: int = 300) -> str:
        """Turn a terse section draft into a short markdown section.

        ``efficient_text`` appears once in the prompt. Repeating it would
        only consume the 4096-token context window.
        """
        logger.info(f"Expanding efficient language to natural text for section '{section_type}'...")
        start = time.time()
        system = (
            "You are an expert task guide writer.\n"
            f'Expand the efficient language into a **short but helpful** section titled "{section_type}".\n'
            "STRICT RULES:\n"
            "- Maximum 120 words total.\n"
            "- Use markdown subheadings (###) and bullet points.\n"
            "- No long paragraphs – break into 3-5 bullet points or short phrases.\n"
            "- Skip introductions, conclusions, and fluff.\n"
            "- Keep the tone professional and clear."
        )
        prompt = f"<s>[INST] {system}\n\nEfficient language:\n{efficient_text}\n\nWrite the full {section_type} section now. [/INST]"
        expanded = await self._generate_completion(prompt, max_tokens=max_tokens)
        elapsed = time.time() - start
        logger.info(f"Expansion took {elapsed:.2f}s")
        return expanded

    @staticmethod
    def _normalize_mermaid(response: str):
        """Return ``response`` as a fenced Mermaid block, or None if it isn't one.

        A response that is already fenced with a mermaid fence is returned
        as-is. Otherwise any partial fences are removed, for example an
        opening fence without its closing one after truncation. The text is
        re-wrapped if it mentions "flowchart" or "graph".
        """
        text = response.strip()
        if text.startswith("```mermaid") and text.endswith("```"):
            return text
        body = text
        for fence in ("```mermaid", "```"):
            if body.startswith(fence):
                body = body[len(fence):]
                break
        if body.endswith("```"):
            body = body[:-3]
        body = body.strip()
        if "flowchart" in body or "graph" in body:
            return f"```mermaid\n{body}\n```"
        return None

    async def generate_flow_diagram(self, context: str) -> str:
        """Generate a fenced Mermaid flowchart for ``context``.

        Never raises. Returns a fixed fallback diagram if the output is not
        recognisable as Mermaid, or if generation fails.
        """
        prompt = (
            "[INST] You are an expert at creating Mermaid flowcharts for task guides.\n"
            "STRICT RULES:\n"
            "- Output ONLY a Mermaid diagram\n"
            "- MUST be inside a markdown code block with ```mermaid\n"
            '- Use "flowchart TD"\n'
            "- No explanations, no extra text\n"
            "Context:\n"
            f"{context}\n"
            "Example format:\n"
            "```mermaid\n"
            "flowchart TD\n"
            "A[Start] --> B[Step 1]\n"
            "B --> C{Decision}\n"
            "C -->|Yes| D[Step 2]\n"
            "C -->|No| E[Step 3]\n"
            "D --> F[End]\n"
            "E --> F\n"
            "```\n"
            "Now generate the diagram. [/INST]"
        )
        try:
            response = await self._generate_completion(prompt, max_tokens=512, temperature=0.2)
        except Exception as e:
            logger.error(f"Flow diagram generation failed: {e}")
            return self._FLOW_ERROR_FALLBACK
        diagram = self._normalize_mermaid(response)
        if diagram is None:
            logger.warning("Invalid Mermaid output, using fallback diagram.")
            return self._FLOW_FALLBACK
        return diagram

    async def generate_section(self, section_type: str, context: str, compress_input: bool = True) -> str:
        """Produce one guide section: optional compression, then draft and expand.

        A "flow" section (case-insensitive) is delegated to
        ``generate_flow_diagram``. Context over 1500 words is summarized
        first when ``compress_input`` is True.
        """
        total_start = time.time()
        if section_type.lower() == "flow":
            return await self.generate_flow_diagram(context)
        logger.info(f"Starting section generation for '{section_type}' (compress_input={compress_input})")
        if compress_input and len(context.split()) > 1500:
            context = await self.compress_input(context, max_tokens=1000)
        else:
            logger.info(f"Input context size OK: {len(context.split())} words")
        efficient = await self.generate_efficient_section(section_type, context)
        expanded = await self.expand_efficient_to_natural(efficient, section_type)
        total_time = time.time() - total_start
        logger.info(f"Total section generation time: {total_time:.2f}s")
        return expanded
# ---------- Global model variable ----------
# Set by lifespan() at startup; stays None if loading failed, which the
# request handlers check for.
model = None


# ---------- Lifespan context manager ----------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load and warm up the model on startup and drop it on shutdown.

    Startup failures are logged with their traceback and leave ``model`` as
    None, so the app still starts. The shutdown steps run in ``finally`` so
    they execute even if the app exits with an error.
    """
    global model
    try:
        logger.info("Starting lifespan startup...")
        start_total = time.time()
        model = MixtralFreeModel()
        await model.warm_up()
        total_time = time.time() - start_total
        logger.info(f"Model initialized and warmed up successfully in {total_time:.2f}s")
    except Exception as e:
        logger.exception(f"Failed to initialize model: {e}")
        model = None
    try:
        yield
    finally:
        logger.info("Shutting down, releasing model resources.")
        model = None
        logger.info("Shutdown complete.")
# ---------- FastAPI app ----------
# The application object; model loading is driven by the lifespan hook.
app = FastAPI(
    lifespan=lifespan,
    version="1.0",
    title="Free AI Response API",
    description="Uses local GGUF model with queue management",
)

# Cross-origin access for browser clients.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive — confirm this is intended for this deployment.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Request/Response models
from typing import Optional


class ChatRequest(BaseModel):
    # The user's question, possibly starting with a bot-trigger prefix.
    question: str
    # Conversation history inserted into the prompt.
    context: str = ""


class ChatResponse(BaseModel):
    response: str


class GenerateSectionRequest(BaseModel):
    section_type: str
    context: str = ""  # legacy, optional
    # New field (skip efficient phase). Declared Optional so that both an
    # omitted field and an explicit JSON null validate; a plain `str` field
    # rejects null under pydantic v2.
    compressed_context: Optional[str] = None
    # Whether very long `context` may be summarized before generation.
    compress_input: bool = True


class GenerateSectionResponse(BaseModel):
    content: str


class CompressQueryRequest(BaseModel):
    prompt: str


class CompressQueryResponse(BaseModel):
    compressed: str
# ---------- Endpoints ----------
@app.get("/")
async def root():
    """Landing endpoint describing the POST routes this API offers."""
    return {"message": "Free AI Response API is running. Use POST /chat, POST /generate-section, or POST /compress-query."}
async def get_queue_status():
    """Report the shared queue's current occupancy.

    Returns:
        The dict produced by ``queue_status.get_status()``.
    """
    # NOTE(review): no route decorator is visible for this handler, so as
    # written it is not served over HTTP; register it if that is intended.
    snapshot = queue_status.get_status()
    return snapshot
@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer ``request.question`` with the local model.

    Returns:
        ``ChatResponse`` on success. If no processing slot is free, returns
        ``{"status": "queued", "queue_position": N}`` instead; this handler
        then does not process the question.

    Raises:
        HTTPException: 503 when the model is not loaded, 500 on any other
            failure.
    """
    queue_start = time.time()
    can_process, queue_position = await queue_status.acquire()
    queue_wait = time.time() - queue_start
    if not can_process:
        logger.info(f"Request queued at position {queue_position}")
        return {"status": "queued", "queue_position": queue_position}
    logger.info(f"Request started processing after queue wait {queue_wait:.3f}s")
    req_start = time.time()
    try:
        if model is None:
            raise HTTPException(status_code=503, detail="Model not available")
        cleaned_question = model.clean_question(request.question)
        response_text = await model.generate_response(cleaned_question, request.context)
        total_time = time.time() - req_start
        logger.info(f"Chat request completed in {total_time:.2f}s (queue wait {queue_wait:.3f}s)")
        return ChatResponse(response=response_text)
    except HTTPException:
        # Deliberate HTTP errors (the 503 above) must reach the client as-is
        # instead of being rewritten into a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        # Every path that acquired a slot gives it back, including errors.
        await queue_status.release()
@app.post("/generate-section")
async def generate_section_endpoint(request: GenerateSectionRequest):
    """Generate one guide section, or a Mermaid diagram for "Flow".

    For normal sections, ``compressed_context`` (when non-empty) is used
    directly as the efficient draft. Otherwise a draft is generated from
    ``context``, which is summarized first if it exceeds 1500 words and
    ``compress_input`` is set. The draft is then expanded into markdown.

    Returns:
        ``GenerateSectionResponse``, or ``{"status": "queued", ...}`` when no
        processing slot is free (the request is not processed here).

    Raises:
        HTTPException: 503 when the model is not loaded, 500 on any other
            failure.
    """
    queue_start = time.time()
    can_process, queue_position = await queue_status.acquire()
    queue_wait = time.time() - queue_start
    if not can_process:
        return {"status": "queued", "queue_position": queue_position}
    logger.info(f"Section generation started after queue wait {queue_wait:.3f}s")
    try:
        if model is None:
            raise HTTPException(status_code=503, detail="Model not available")
        # SPECIAL CASE: Flow section -> generate Mermaid diagram.
        if request.section_type.lower() == "flow":
            # Prefer compressed_context as the diagram's context when given.
            context = request.compressed_context or request.context
            diagram = await model.generate_flow_diagram(context)
            total_time = time.time() - queue_start
            logger.info(f"Flow diagram generated in {total_time:.2f}s")
            return GenerateSectionResponse(content=diagram)
        # Normal sections: use compressed_context if provided, else efficient+expand.
        if request.compressed_context:
            efficient_repr = request.compressed_context
            logger.info(f"Using provided compressed context for section '{request.section_type}'")
        else:
            context_to_use = request.context
            if request.compress_input and len(context_to_use.split()) > 1500:
                logger.info("Input context large, compressing...")
                context_to_use = await model.compress_input(context_to_use, max_tokens=1000)
            efficient_repr = await model.generate_efficient_section(request.section_type, context_to_use)
        expanded = await model.expand_efficient_to_natural(efficient_repr, request.section_type)
        total_time = time.time() - queue_start
        logger.info(f"Generate-section request completed in {total_time:.2f}s (queue wait {queue_wait:.3f}s)")
        return GenerateSectionResponse(content=expanded)
    except HTTPException:
        # Keep deliberate HTTP errors (the 503 above) instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"Error generating section: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        await queue_status.release()
@app.post("/compress-query")
async def compress_query_endpoint(request: CompressQueryRequest):
    """Compress a user prompt into a dense bullet / key-value representation.

    Returns:
        ``CompressQueryResponse``, or ``{"status": "queued", ...}`` when no
        processing slot is free (the request is not processed here).

    Raises:
        HTTPException: 503 when the model is not loaded, 500 on any other
            failure.
    """
    queue_start = time.time()
    can_process, queue_position = await queue_status.acquire()
    queue_wait = time.time() - queue_start
    if not can_process:
        return {"status": "queued", "queue_position": queue_position}
    logger.info(f"Compress-query started after queue wait {queue_wait:.3f}s")
    try:
        if model is None:
            raise HTTPException(status_code=503, detail="Model not available")
        # Reuse the efficient-section generator with an instruction-style
        # context to compress the user prompt.
        compressed = await model.generate_efficient_section(
            section_type="QueryCompression",
            context=f"User request: {request.prompt}\nProduce a dense, efficient representation (bullet points or key-value pairs) of the user's intent, steps, and requirements. Keep under 300 tokens."
        )
        total_time = time.time() - queue_start
        logger.info(f"Compress-query completed in {total_time:.2f}s")
        return CompressQueryResponse(compressed=compressed)
    except HTTPException:
        # Keep deliberate HTTP errors (the 503 above) instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"Error compressing query: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail="Internal server error") from e
    finally:
        await queue_status.release()
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)