"""GGUF model pipeline built on llama.cpp, plus a static text fallback."""

import logging
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GGUFModelPipeline:
    """Load a GGUF model through llama.cpp and generate text with it."""

    # Section headings a summary must contain before generate_full_summary
    # stops asking the model to continue.
    REQUIRED_SECTIONS = (
        "Clinical Assessment",
        "Key Trends & Changes",
        "Plan & Suggested Actions",
        "Direct Guidance for Physician",
    )

    # Class-level defaults; __init__ overrides both per instance.
    n_ctx = 2048
    _model_lock = threading.Lock()

    # Chat/control tokens that may leak from templates, compiled once.
    # NOTE(review): sentence-boundary tokens such as <s> / </s> are not
    # stripped here; add them if the model in use emits them.
    _SPECIAL_TOKEN_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"<\|assistant\|>",
            r"<\|user\|>",
            r"<\|system\|>",
            r"<\|end\|>",
            r"<\|endoftext\|>",
        )
    )

    def __init__(self, model_path_or_repo, filename=None, cache_dir=None,
                 timeout=300, n_ctx=2048):
        """Download (if needed) and initialize the model.

        Args:
            model_path_or_repo: Local model path, or a Hugging Face repo id
                when ``filename`` is given.
            filename: File to download from the repo; ``None`` means
                ``model_path_or_repo`` is already a local path.
            cache_dir: Download cache directory. Defaults to ``$HF_HOME`` or
                ``/tmp/huggingface``.
            timeout: Stored as ``self.timeout``.
            n_ctx: Context window size passed to llama.cpp and used to cap
                ``max_tokens`` during generation.

        Raises:
            RuntimeError: If the download or the model initialization fails.
            FileNotFoundError: If a local model path does not exist.
        """
        # Resolve cache dir for Spaces (default to /tmp/huggingface)
        cache_dir = cache_dir or os.environ.get("HF_HOME", "/tmp/huggingface")
        os.makedirs(cache_dir, exist_ok=True)

        # NOTE(review): self.timeout is stored but generate() currently uses
        # a fixed 500 s timeout; confirm which value is intended.
        self.timeout = timeout
        self.n_ctx = n_ctx
        # Serializes access to self.model across worker threads.
        self._model_lock = threading.Lock()

        # If filename is provided, treat model_path_or_repo as HuggingFace repo_id
        if filename is not None:
            try:
                logger.info(f"Downloading model from {model_path_or_repo}/{filename}")
                local_path = hf_hub_download(
                    repo_id=model_path_or_repo,
                    filename=filename,
                    cache_dir=cache_dir,
                    resume_download=True,
                    local_files_only=False,
                )
                logger.info(f"Model downloaded successfully to {local_path}")
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise RuntimeError(f"Model download failed: {str(e)}") from e
        else:
            local_path = model_path_or_repo
            if not os.path.exists(local_path):
                raise FileNotFoundError(f"Model path does not exist: {local_path}")

        # Check file size to ensure it's reasonable
        file_size = os.path.getsize(local_path) / (1024 * 1024)  # MB
        logger.info(f"Model file size: {file_size:.2f} MB")
        if file_size > 5000:  # 5GB limit
            logger.warning(
                f"Model file is very large ({file_size:.2f} MB), may cause memory issues"
            )

        load_start = time.time()
        # Performance tuning and CPU-friendly defaults for Spaces
        try:
            cpu_count = os.cpu_count() or 2
            n_threads = min(cpu_count, 4)  # Limit threads for HF Spaces (max 4 cores)
            n_batch = int(os.environ.get("GGUF_N_BATCH", "32"))  # Reduced batch size for memory

            # Memory-optimized settings for Hugging Face Spaces
            self.model = Llama(
                model_path=local_path,
                n_ctx=self.n_ctx,
                n_threads=n_threads,
                n_batch=n_batch,
                n_gpu_layers=0,  # CPU-only on Spaces by default
                logits_all=False,
                embedding=False,
                use_mmap=True,
                use_mlock=False,
                seed=0,
                verbose=False,  # Reduce logging
            )
        except Exception as e:
            logger.error(f"Failed to initialize GGUF model: {e}")
            raise RuntimeError(f"Failed to initialize GGUF model via llama.cpp: {e}") from e

        load_time = time.time() - load_start
        logger.info(
            f"[GGUF] Model initialized in {load_time:.2f}s from {local_path} "
            f"(threads={n_threads}, batch={n_batch})"
        )

    def _strip_special_tokens(self, text: str) -> str:
        """Remove leaked chat/control tokens from ``text`` and trim whitespace."""
        for pattern in self._SPECIAL_TOKEN_PATTERNS:
            text = pattern.sub("", text)
        return text.strip()

    def _generate_with_timeout(self, prompt, max_tokens=512, temperature=0.5,
                               top_p=0.95, timeout=500):
        """Run one model call in a worker thread, giving up after ``timeout``.

        Args:
            prompt: Prompt text passed to the model.
            max_tokens: Requested completion length; reduced when it would not
                fit in the context window next to the prompt.
            temperature: Sampling temperature forwarded to the model.
            top_p: Nucleus sampling value forwarded to the model.
            timeout: Seconds to wait for the worker before giving up.

        Returns:
            The raw completion object returned by the model call.

        Raises:
            ValueError: If the prompt alone fills the context window.
            TimeoutError: If no result arrives within ``timeout`` seconds.
            RuntimeError: If the model call itself raises.
        """
        # Calculate allowed max tokens based on context window
        prompt_tokens = len(prompt.split())  # Rough estimate
        n_ctx = self.n_ctx
        allowed_max_tokens = n_ctx - prompt_tokens
        if allowed_max_tokens <= 0:
            raise ValueError(
                f"Prompt too long ({prompt_tokens} tokens) for context window ({n_ctx})"
            )
        if max_tokens > allowed_max_tokens:
            logger.warning(f"Reducing max_tokens from {max_tokens} to {allowed_max_tokens}")
            max_tokens = allowed_max_tokens

        # Set when the caller stops waiting, so a worker that was still queued
        # behind the lock can skip its (now pointless) model call.
        abandoned = threading.Event()

        def _generate():
            # A timed-out call cannot be interrupted and may still be running
            # when the next call starts; the lock keeps them from overlapping.
            with self._model_lock:
                if abandoned.is_set():
                    return None
                try:
                    # The 'stop' parameter is deliberately omitted to prevent
                    # token-matching loops.
                    return self.model(
                        prompt,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        echo=False,  # Do not return the prompt
                        stream=False,  # Get a single, complete response
                    )
                except Exception as e:
                    raise RuntimeError(f"Direct model call failed: {str(e)}") from e

        # Not used as a context manager: its __exit__ waits for the worker,
        # which would delay the TimeoutError until generation had finished.
        executor = ThreadPoolExecutor(max_workers=1)
        future = executor.submit(_generate)
        try:
            return future.result(timeout=timeout)
        except FutureTimeoutError:
            abandoned.set()
            future.cancel()  # only effective if the task has not started yet
            raise TimeoutError(f"Generation timed out after {timeout} seconds") from None
        finally:
            executor.shutdown(wait=False)

    def generate(self, prompt, max_tokens=512, temperature=0.5, top_p=0.95):
        """Generate a completion for ``prompt`` and return the cleaned text.

        Raises:
            TimeoutError: If generation exceeds the 500 second limit.
            RuntimeError: For any other failure during generation.
        """
        t0 = time.time()
        try:
            output = self._generate_with_timeout(
                prompt, max_tokens, temperature, top_p, timeout=500
            )
            dt = time.time() - t0

            # Extract and clean the text
            text = output["choices"][0]["text"].strip()
            text = self._strip_special_tokens(text)

            approx_words = len(text.split())
            logger.info(f"[GGUF] generate: {dt:.2f}s, ~{approx_words} words")
            return text
        except TimeoutError as e:
            logger.error(f"Generation timed out: {e}")
            raise
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise RuntimeError(f"Text generation failed: {str(e)}") from e

    def generate_full_summary(self, prompt, max_tokens=512, max_loops=2):
        """Generate a summary, asking the model to continue if sections are missing.

        Up to ``max_loops`` generations are concatenated; the loop stops early
        once every heading in ``REQUIRED_SECTIONS`` appears in the output.

        Returns:
            The accumulated summary text. If a later generation fails, the
            partial text gathered so far is returned instead.

        Raises:
            RuntimeError: If generation fails before any text was produced.
        """
        full_output = ""
        current_prompt = prompt
        total_start = time.time()

        try:
            for loop_idx in range(max_loops):
                loop_start = time.time()
                output = self.generate(current_prompt, max_tokens=max_tokens)

                # Remove prompt from output if repeated
                if output.startswith(prompt):
                    output = output[len(prompt):].strip()
                full_output += output

                loop_time = time.time() - loop_start
                logger.info(
                    f"[GGUF] loop {loop_idx+1}/{max_loops}: {loop_time:.2f}s, "
                    f"cumulative {time.time()-total_start:.2f}s, "
                    f"length={len(full_output)} chars"
                )

                # Only continue if required sections are missing
                if all(section in full_output for section in self.REQUIRED_SECTIONS):
                    break

                # Prepare the next prompt to continue
                current_prompt = (
                    prompt + "\n" + full_output
                    + "\nContinue the summary in markdown format:"
                )

            total_time = time.time() - total_start
            logger.info(f"[GGUF] generate_full_summary total: {total_time:.2f}s")
            return full_output.strip()
        except Exception as e:
            logger.error(f"Full summary generation failed: {e}")
            # Return partial output if available
            if full_output.strip():
                logger.warning("Returning partial summary due to generation error")
                return full_output.strip()
            raise RuntimeError(f"Summary generation failed: {str(e)}") from e


# Fallback function for when GGUF model fails
def create_fallback_pipeline():
    """Create a simple text-based fallback when GGUF model fails"""

    class FallbackPipeline:
        """Pipeline stand-in that returns a fixed four-section summary."""

        def __init__(self):
            self.name = "fallback_text"

        def generate(self, prompt, **kwargs):
            """Return the fixed template; ``prompt`` and ``kwargs`` are ignored."""
            # Simple template-based response
            sections = [
                "## Clinical Assessment\nBased on the provided information, this appears to be a medical case requiring clinical review.",
                "## Key Trends & Changes\nPlease review the patient data for any significant changes or trends.",
                "## Plan & Suggested Actions\nConsider consulting with a healthcare provider for proper medical assessment.",
                "## Direct Guidance for Physician\nThis summary was generated using a fallback method. Please review all patient data thoroughly.",
            ]
            return "\n\n".join(sections)

        def generate_full_summary(self, prompt, **kwargs):
            """Return the same fixed template as ``generate``."""
            return self.generate(prompt, **kwargs)

    return FallbackPipeline()