Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| ECG AI7 - ECG Interpretation using Llama 3.2 11B Vision | |
| Gradio interface for Hugging Face Spaces | |
| """ | |
| import torch | |
| from transformers import MllamaForConditionalGeneration, AutoProcessor, TextStreamer, BitsAndBytesConfig | |
| from PIL import Image | |
| import gradio as gr | |
# Model configuration
MODEL_ID = "convaiinnovations/ECG-Instruct-Llama-3.2-11B-Vision"

# Start-up banner: one line per message, same text as before.
print(
    f"Loading model: {MODEL_ID}",
    "Loading in 4-bit mode to fit in free tier memory (16GB)...",
    "This may take a few minutes on first load...",
    sep="\n",
)

# 4-bit NF4 quantization with nested ("double") quantization and a float16
# compute dtype. BitsAndBytesConfig is used instead of passing the older
# load_in_4bit flag directly to from_pretrained.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    load_in_4bit=True,
)

# Load weights (quantized) and the matching processor once at import time;
# both are used as module-level globals by the functions below.
# NOTE(review): bitsandbytes 4-bit loading is expected to require a CUDA GPU —
# confirm the Space is on GPU hardware; on CPU-only hardware this call may be
# what produces the Space's "Runtime error" status.
model = MllamaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

print("Model loaded successfully in 4-bit mode!")
| # Helper functions | |
| def _strip_assistant_prefix_safe(s: str) -> str: | |
| """Safely strip assistant prefix from generated text""" | |
| s = s.lstrip() | |
| # Only remove a leading role block if it literally starts the text | |
| for prefix in ("user", "assistant", "User", "Assistant"): | |
| if s.startswith(prefix): | |
| idx = s.find("\n\n") | |
| if idx != -1: | |
| return s[idx+2:].lstrip() | |
| idx = s.find("\n") | |
| if idx != -1: | |
| return s[idx+1:].lstrip() | |
| return s | |
def generate_full_report(image_path: str, query: str, *,
                         max_new_tokens: int = 1600,
                         do_stream: bool = False,
                         temperature: float = 0.0) -> str:
    """
    Generate an ECG interpretation report for a single image.

    Args:
        image_path: local path to the ECG image
        query: instruction string for the model
        max_new_tokens: maximum number of tokens to generate
        do_stream: whether to stream output to stdout while generating
            (terminal use)
        temperature: sampling temperature; 0.0 selects greedy decoding

    Returns:
        The decoded model answer only; the prompt is not included.

    Raises:
        ValueError: if ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature!r}")

    # Context manager closes the file handle promptly; convert() returns an
    # independent in-memory image.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Build single user turn: image placeholder + text instruction
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": query},
        ]}
    ]
    # The rendered chat template already contains the begin-of-text token, so
    # the processor must not prepend a second one.
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=input_text,
        images=image,
        add_special_tokens=False,
        return_tensors="pt",
    )
    # Move inputs to same device as model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    # Setup streamer if requested
    streamer = TextStreamer(processor.tokenizer, skip_prompt=True) if do_stream else None

    # Only pass sampling knobs when actually sampling; under greedy decoding
    # transformers ignores temperature/top_p and warns about them.
    do_sample = temperature > 0
    sampling_kwargs = {"temperature": temperature, "top_p": 1.0} if do_sample else {}

    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=do_sample,
            **sampling_kwargs,
        )

    # generate() returns prompt + continuation; decode only the new tokens so
    # the instruction text does not leak into the report.
    generated_ids = out_ids[:, prompt_len:]
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Defensive: drop a stray role header if the model emits one as text.
    return _strip_assistant_prefix_safe(text)
def translate_to_farsi(english_text: str, max_new_tokens: int = 1600) -> str:
    """Translate English text to Persian (Farsi) using the already-loaded model.

    Args:
        english_text: text to translate.
        max_new_tokens: cap on generated tokens. NOTE(review): a Persian
            translation may need more tokens than the English source with
            this tokenizer — confirm the default is enough for long reports.

    Returns:
        The translation only (prompt excluded), or "" when the input is empty
        or whitespace-only.
    """
    # Nothing to translate: skip an expensive generation call.
    if not (english_text or "").strip():
        return ""

    # The Persian instruction reads: "Only translate the following text into
    # fluent Persian and return only the translation:"
    msgs = [
        {"role": "user", "content": [
            {"type": "text",
             "text": "فقط متن زیر را به فارسی روان ترجمه کن و فقط ترجمه را برگردان:\n\n" + english_text}
        ]}
    ]
    prompt = processor.apply_chat_template(msgs, add_generation_prompt=True)
    # The template already renders the begin-of-text token; avoid a duplicate.
    inputs = processor(text=prompt, add_special_tokens=False, return_tensors="pt")
    # Move to device
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    # Greedy decoding; temperature/top_p are irrelevant without sampling.
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # Decode only the continuation so the English source and the instruction
    # are not echoed back in the "translation".
    ans = processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0]
    return _strip_assistant_prefix_safe(ans)
| # Gradio interface function | |
| def analyze_ecg_gradio(image, text_instruction="", language="Farsi"): | |
| """ | |
| Main function for Gradio interface | |
| Args: | |
| image: uploaded ECG image filepath (string path) | |
| text_instruction: optional clinical note / context | |
| language: output language (English or Farsi) | |
| Returns: | |
| Full AI-generated ECG interpretation report | |
| """ | |
| try: | |
| print(f"Received image: {image}") | |
| print(f"Text instruction: {text_instruction}") | |
| print(f"Language: {language}") | |
| # Build query | |
| query = "You are an expert cardiologist. " | |
| if text_instruction: | |
| query += f"Patient info: {text_instruction}. " | |
| query += "Write an in-depth diagnosis report from this ECG data, including the final diagnosis." | |
| # Generate report in English | |
| print("Generating report in English...") | |
| report = generate_full_report(image, query, max_new_tokens=1600, do_stream=False) | |
| # Translate to Farsi if requested | |
| if language == "Farsi": | |
| print("Translating to Farsi...") | |
| report = translate_to_farsi(report, max_new_tokens=1600) | |
| print("Report generated successfully!") | |
| return report | |
| except Exception as e: | |
| import traceback | |
| error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" | |
| print(error_msg) | |
| return error_msg | |
# Gradio front end, kept minimal because the Space is primarily used as an
# API backend around analyze_ecg_gradio (image path, note, language -> text).
demo = gr.Interface(
    title="ECG AI7 Backend",
    description="ECG interpretation API powered by Llama 3.2 11B Vision",
    fn=analyze_ecg_gradio,
    inputs=[
        gr.Image(label="ECG Image", type="filepath"),
        gr.Textbox(label="Clinical Note (Optional)", lines=2),
        gr.Dropdown(label="Language", choices=["English", "Farsi"], value="Farsi"),
    ],
    outputs=gr.Textbox(label="ECG Report", lines=15),
    flagging_mode="never",
)

# Serve the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()