Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| import streamlit as st | |
| import os | |
| import asyncio | |
| import base64 | |
| import io | |
| import threading | |
| import traceback | |
| import time # For potential delays or timestamps if needed in future | |
| import atexit # For cleanup actions on exit | |
| from dotenv import load_dotenv | |
| # --- Import main libraries --- | |
| import cv2 | |
| import pyaudio | |
| import PIL.Image | |
| import mss # For screen capture | |
| from google import genai | |
| from google.genai import types | |
# --- Configuration ---
load_dotenv()  # Load environment variables from .env file

# Audio configuration (PyAudio)
FORMAT = pyaudio.paInt16
CHANNELS = 1
SEND_SAMPLE_RATE = 16000  # Microphone capture rate for audio sent to Gemini
RECEIVE_SAMPLE_RATE = 24000  # Gemini documentation recommendation
CHUNK_SIZE = 1024  # Frames per microphone read
AUDIO_PLAYBACK_QUEUE_MAXSIZE = 50  # Buffer for audio chunks from Gemini for playback
MEDIA_OUT_QUEUE_MAXSIZE = 20  # Buffer for audio/video frames to send to Gemini

# Video configuration
VIDEO_FPS_LIMIT = 1  # Send 1 frame per second to the API
VIDEO_PREVIEW_RESIZE = (640, 480)  # Max size for the Streamlit preview (aspect ratio preserved)
VIDEO_API_RESIZE = (1024, 1024)  # Max size for images sent to API (aspect ratio preserved)

# Gemini model configuration.
# The default model name must be one the Live API currently serves; set
# GEMINI_LIVE_MODEL (environment or .env) to switch models without editing code.
# An empty value falls back to the default.
MODEL_NAME = os.environ.get("GEMINI_LIVE_MODEL") or "models/gemini-2.0-flash-live-001"
DEFAULT_VIDEO_MODE = "camera"  # Default video input mode ("camera", "screen", "none")

# System Prompt for the Medical Assistant (sent as the first message of each session)
MEDICAL_ASSISTANT_SYSTEM_PROMPT = """You are an AI Medical Assistant. Your primary function is to analyze visual information from the user's camera or screen and respond via voice.
Your responsibilities are:
1. **Visual Observation and Description:** Carefully examine the images or video feed. Describe relevant details you observe.
2. **General Information (Non-Diagnostic):** Provide general information related to what is visually presented, if applicable. You are not a diagnostic tool.
3. **Safety and Disclaimer (CRITICAL):**
* You are an AI assistant, **NOT a medical doctor or a substitute for one.**
* **DO NOT provide medical diagnoses, treatment advice, or interpret medical results (e.g., X-rays, scans, lab reports).**
* When appropriate, and always if the user seems to be seeking diagnosis or treatment, explicitly state your limitations and **strongly advise the user to consult a qualified healthcare professional.**
* If you see something that *appears* visually concerning (e.g., an unusual skin lesion, signs of injury), you may gently suggest it might be wise to have it looked at by a professional, without speculating on what it is.
4. **Tone:** Maintain a helpful, empathetic, and calm tone.
5. **Interaction:** After this initial instruction, you can make a brief acknowledgment of your role (e.g., "I'm ready to assist by looking at what you show me. Please remember to consult a doctor for medical advice."). Then, focus on responding to the user's visual input and questions.
Example of a disclaimer you might use: "As an AI assistant, I can describe what I see, but I can't provide medical advice or diagnoses. For any health concerns, it's always best to speak with a doctor or other healthcare professional."
"""
# Shared PyAudio instance.
# Streamlit re-executes this script on every rerun, so creating PyAudio at module
# level would initialise PortAudio again (and register another atexit handler)
# on every interaction. st.cache_resource keeps one instance per server process.
def cleanup_pyaudio(instance=None):
    """Terminate a PyAudio instance (the module-level ``pya`` by default).

    Registered with atexit exactly once, when the shared instance is created.
    """
    print("Terminating PyAudio instance.")
    target = pya if instance is None else instance
    target.terminate()


@st.cache_resource(show_spinner=False)  # No spinner: must not emit UI before set_page_config
def _get_shared_pyaudio():
    """Create the process-wide PyAudio instance and schedule its termination."""
    instance = pyaudio.PyAudio()
    # Bind the instance explicitly so cleanup does not depend on a later
    # rerun's globals.
    atexit.register(cleanup_pyaudio, instance)
    return instance


pya = _get_shared_pyaudio()
def init_session_state():
    """Seed ``st.session_state`` with defaults for keys that are not set yet.

    Keys that already exist keep their current values, so this is safe to call
    on every rerun of the script.
    """
    initial_values = {
        'app_initialized': True,  # Marker that the defaults were seeded for this session
        'session_active': False,  # True while a Gemini session has been started from the UI
        'audio_loop_instance': None,  # AudioVisualLoop driving the current session
        'chat_messages': [],  # List of {"role": ..., "content": ...} dicts shown in the chat pane
        'current_frame_preview': None,  # Latest preview image from camera/screen capture
        'video_mode_selection': DEFAULT_VIDEO_MODE,  # "camera", "screen" or "none"
    }
    missing_entries = {
        name: initial for name, initial in initial_values.items()
        if name not in st.session_state
    }
    for name, initial in missing_entries.items():
        st.session_state[name] = initial


init_session_state()
# Configure Streamlit page
st.set_page_config(page_title="Live Medical Assistant", layout="wide")

# Initialize Gemini client; without an API key the app cannot do anything.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    st.error("GEMINI_API_KEY not found. Please set it in your environment variables or a .env file.")
    st.stop()

# Use 'client' consistently for the genai.Client instance.
client = None  # Defined up front so AudioVisualLoop can test it for None
try:
    client = genai.Client(
        http_options={"api_version": "v1beta"},  # Required for live
        api_key=GEMINI_API_KEY,
    )
except Exception as e:
    st.error(f"Failed to initialize Gemini client: {e}")
    st.stop()

# Gemini LiveConnectConfig.
# The Live API accepts exactly ONE response modality per session; requesting
# both audio and text makes the server reject the connection. This is a voice
# assistant, so it asks for audio.
# NOTE(review): with audio-only responses the receive loop's chunk.text will
# presumably stay empty; enable output audio transcription (and read it in the
# receive loop) if a text transcript is wanted in the chat pane.
LIVE_CONNECT_CONFIG = types.LiveConnectConfig(
    response_modalities=["AUDIO"],
    speech_config=types.SpeechConfig(
        voice_config=types.VoiceConfig(
            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="Zephyr")  # Prebuilt voice used for replies
        )
    ),
)
class AudioVisualLoop:
    """Drive one Gemini Live session on a private asyncio event loop.

    Microphone audio and, depending on ``video_mode``, camera or screen frames
    are pushed into ``media_to_gemini_queue`` and forwarded to Gemini. Audio
    returned by Gemini goes through ``audio_playback_queue`` to a PyAudio output
    stream. Every task polls ``is_running``; clearing it (``signal_stop``) ends
    the session.

    Everything except ``signal_stop`` runs off the Streamlit script thread (on
    the event loop or in ``asyncio.to_thread`` workers), so problems are
    reported with ``print`` rather than ``st.error``.
    """

    def __init__(self, video_mode_setting=DEFAULT_VIDEO_MODE):
        """Create an idle loop; devices and the session open in ``run_main_loop``.

        Args:
            video_mode_setting: "camera", "screen", or any other value (e.g.
                "none") for an audio-only session.
        """
        self.video_mode = video_mode_setting
        self.gemini_session = None  # Live session, set inside run_main_loop
        self.async_event_loop = None  # Loop running run_main_loop; target for thread-safe calls
        self.audio_playback_queue = None  # asyncio.Queue for audio from Gemini
        self.media_to_gemini_queue = None  # asyncio.Queue for audio/video to Gemini
        self.is_running = True  # Flag to control all async tasks
        self.mic_stream = None  # PyAudio input stream
        self.playback_stream = None  # PyAudio output stream
        self.camera_capture = None  # OpenCV VideoCapture object

    # ----- helpers -----

    @staticmethod
    def _build_frame_payload(img):
        """Return a preview thumbnail and the base64 JPEG payload for a PIL image."""
        preview_img = img.copy()
        preview_img.thumbnail(VIDEO_PREVIEW_RESIZE)
        api_img = img.copy()
        api_img.thumbnail(VIDEO_API_RESIZE)  # Preserves aspect ratio
        image_io = io.BytesIO()
        api_img.save(image_io, format="jpeg")
        image_bytes = image_io.getvalue()
        return {
            "preview": preview_img,
            "api_data": {"mime_type": "image/jpeg", "data": base64.b64encode(image_bytes).decode()},
        }

    async def _enqueue_media(self, item):
        """Queue an outgoing chunk for Gemini, dropping the oldest one when full."""
        if self.media_to_gemini_queue.full():
            await self.media_to_gemini_queue.get()  # Make space
        await self.media_to_gemini_queue.put(item)

    @staticmethod
    async def _close_stream(stream):
        """Stop and close a PyAudio stream; close it even if stopping fails."""
        try:
            await asyncio.to_thread(stream.stop_stream)
        finally:
            await asyncio.to_thread(stream.close)

    @staticmethod
    def _append_assistant_message(text):
        """Append an assistant chat message to ``st.session_state['chat_messages']``."""
        # .get: the key may be absent if this thread has no Streamlit session context.
        messages = st.session_state.get('chat_messages', [])
        st.session_state['chat_messages'] = messages + [{"role": "assistant", "content": text}]

    # ----- user text -----

    async def send_text_input_to_gemini(self, user_text):
        """Send a typed user message to Gemini as a complete turn.

        Scheduled from the Streamlit thread with ``run_coroutine_threadsafe``;
        failures are logged.
        """
        if not user_text or not self.gemini_session or not self.is_running:
            print("Warning: Cannot send text. Session not active or no text.")
            return
        try:
            print(f"Sending text to Gemini: {user_text}")
            await self.gemini_session.send(input=user_text, end_of_turn=True)
        except Exception as e:
            print(f"Error sending text message to Gemini: {e}")
            print(f"Traceback for send_text_input_to_gemini: {traceback.format_exc()}")

    # ----- video capture -----

    def _process_camera_frame(self):
        """Read one camera frame (blocking); return preview + API payload, or None."""
        if not self.camera_capture or not self.camera_capture.isOpened():
            print("Camera not available or not open.")
            return None
        ret, frame = self.camera_capture.read()
        if not ret:
            print("Failed to read frame from camera.")
            return None
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV delivers BGR
        return self._build_frame_payload(PIL.Image.fromarray(frame_rgb))

    async def stream_camera_frames(self):
        """Capture camera frames at VIDEO_FPS_LIMIT, publish a preview, queue them for Gemini."""
        try:
            self.camera_capture = await asyncio.to_thread(cv2.VideoCapture, 0)  # 0 for default camera
            if not self.camera_capture.isOpened():
                print("Could not open camera. Please check permissions and availability.")
                self.is_running = False
                return
            while self.is_running:
                frame_data = await asyncio.to_thread(self._process_camera_frame)
                if frame_data:
                    st.session_state['current_frame_preview'] = frame_data["preview"]
                    await self._enqueue_media(frame_data["api_data"])
                await asyncio.sleep(1.0 / VIDEO_FPS_LIMIT)
        except Exception as e:
            print(f"Camera streaming error: {e}\nTraceback: {traceback.format_exc()}")
            self.is_running = False
        finally:
            if self.camera_capture is not None:
                await asyncio.to_thread(self.camera_capture.release)
                self.camera_capture = None
            print("Camera streaming task finished.")

    def _process_screen_frame(self):
        """Grab one screenshot (blocking); return preview + API payload.

        Uses monitor index 1, falling back to index 0 when mss lists fewer
        monitors. A fresh ``mss`` instance is used per call because this runs in
        ``asyncio.to_thread`` workers.
        """
        with mss.mss() as sct:
            monitor_index = 1 if len(sct.monitors) > 1 else 0
            sct_img = sct.grab(sct.monitors[monitor_index])
            img = PIL.Image.frombytes("RGB", (sct_img.width, sct_img.height), sct_img.rgb)
        return self._build_frame_payload(img)

    async def stream_screen_frames(self):
        """Capture the screen at VIDEO_FPS_LIMIT, publish a preview, queue frames for Gemini."""
        try:
            while self.is_running:
                frame_data = await asyncio.to_thread(self._process_screen_frame)
                if frame_data:
                    st.session_state['current_frame_preview'] = frame_data["preview"]
                    await self._enqueue_media(frame_data["api_data"])
                await asyncio.sleep(1.0 / VIDEO_FPS_LIMIT)
        except Exception as e:
            print(f"Screen capture error: {e}\nTraceback: {traceback.format_exc()}")
            self.is_running = False
        finally:
            print("Screen streaming task finished.")

    # ----- upstream -----

    async def stream_media_to_gemini(self):
        """Forward queued audio/video chunks to the live session until stopped."""
        try:
            while self.is_running:
                if not self.gemini_session:
                    await asyncio.sleep(0.1)
                    continue
                try:
                    # Timeout so the loop re-checks is_running at least once a second.
                    media_chunk = await asyncio.wait_for(self.media_to_gemini_queue.get(), timeout=1.0)
                    if media_chunk and self.gemini_session and self.is_running:
                        await self.gemini_session.send(input=media_chunk)
                    if media_chunk:  # None is the stop sentinel; no task_done for it
                        self.media_to_gemini_queue.task_done()
                except asyncio.TimeoutError:
                    continue
                except Exception as e:
                    if self.is_running:
                        print(f"Error in stream_media_to_gemini: {e}")
                    await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            print("stream_media_to_gemini task cancelled.")
            raise
        finally:
            print("Media streaming to Gemini task finished.")

    async def capture_microphone_audio(self):
        """Read PCM audio (FORMAT/CHANNELS/SEND_SAMPLE_RATE) from the default mic and queue it."""
        try:
            mic_info = await asyncio.to_thread(pya.get_default_input_device_info)
            self.mic_stream = await asyncio.to_thread(
                pya.open,
                format=FORMAT,
                channels=CHANNELS,
                rate=SEND_SAMPLE_RATE,
                input=True,
                input_device_index=mic_info["index"],
                frames_per_buffer=CHUNK_SIZE,
            )
            print("Microphone stream opened.")
            while self.is_running:
                try:
                    audio_data = await asyncio.to_thread(
                        self.mic_stream.read, CHUNK_SIZE, exception_on_overflow=False
                    )
                    await self._enqueue_media({"data": audio_data, "mime_type": "audio/pcm"})
                except IOError as e:
                    if e.errno == pyaudio.paInputOverflowed:  # type: ignore
                        print("Microphone Input overflowed. Skipping.")
                    else:
                        print(f"Microphone read IOError: {e}")
                        self.is_running = False
                        break
                except Exception as e:
                    print(f"Error in capture_microphone_audio: {e}")
                    await asyncio.sleep(0.01)
        except Exception as e:
            print(f"Failed to open microphone: {e}\nTraceback: {traceback.format_exc()}")
            self.is_running = False
        finally:
            if self.mic_stream:
                try:
                    await self._close_stream(self.mic_stream)
                finally:
                    self.mic_stream = None
            print("Microphone capture task finished.")

    # ----- downstream -----

    async def process_gemini_responses(self):
        """Consume Gemini turns: queue audio for playback and append text to the chat.

        A failing receive (e.g. the websocket was closed) cannot recover, so it
        ends the whole session instead of retrying in a tight loop.
        """
        try:
            while self.is_running:
                if not self.gemini_session:
                    await asyncio.sleep(0.1)
                    continue
                try:
                    async for chunk in self.gemini_session.receive():
                        if not self.is_running:
                            break
                        if audio_data := chunk.data:
                            if not self.audio_playback_queue.full():
                                self.audio_playback_queue.put_nowait(audio_data)
                            else:
                                print("Audio playback queue full, discarding data.")
                        if text_response := chunk.text:
                            self._append_assistant_message(text_response)
                except Exception as e:
                    if self.is_running:
                        print(f"Error receiving from Gemini: {e}")
                        self.is_running = False
        except asyncio.CancelledError:
            print("process_gemini_responses task cancelled.")
            raise
        finally:
            print("Gemini response processing task finished.")

    async def play_gemini_audio(self):
        """Play queued Gemini audio at RECEIVE_SAMPLE_RATE until stopped."""
        try:
            self.playback_stream = await asyncio.to_thread(
                pya.open, format=FORMAT, channels=CHANNELS, rate=RECEIVE_SAMPLE_RATE, output=True
            )
            print("Audio playback stream opened.")
            while self.is_running:
                try:
                    audio_chunk = await asyncio.wait_for(self.audio_playback_queue.get(), timeout=1.0)
                    if audio_chunk:  # None is the stop sentinel
                        await asyncio.to_thread(self.playback_stream.write, audio_chunk)
                        self.audio_playback_queue.task_done()
                except asyncio.TimeoutError:
                    continue
                except Exception as e:
                    print(f"Error playing audio: {e}")
                    await asyncio.sleep(0.01)
        except Exception as e:
            print(f"Failed to open audio playback stream: {e}\nTraceback: {traceback.format_exc()}")
            self.is_running = False
        finally:
            if self.playback_stream:
                try:
                    await self._close_stream(self.playback_stream)
                finally:
                    self.playback_stream = None
            print("Audio playback task finished.")

    # ----- lifecycle -----

    def signal_stop(self):
        """Ask every task to stop. Safe to call from the Streamlit thread.

        asyncio queues are not thread-safe, so the wake-up sentinels are pushed
        on the session's own event loop. A full queue is not an error: every
        consumer also polls ``is_running``.
        """
        print("Signal to stop AudioVisualLoop received.")
        self.is_running = False
        loop = self.async_event_loop
        if loop is None or loop.is_closed():
            return
        try:
            loop.call_soon_threadsafe(self._push_stop_sentinels)
        except RuntimeError:
            pass  # Loop closed between the check above and this call.

    def _push_stop_sentinels(self):
        """Wake queue consumers with a None sentinel (runs on the event loop)."""
        for queue in (self.media_to_gemini_queue, self.audio_playback_queue):
            if queue is None:
                continue
            try:
                queue.put_nowait(None)
            except asyncio.QueueFull:
                pass  # Consumer will notice is_running == False on its next poll.

    async def _cancel_when_stopped(self, *tasks):
        """Cancel *tasks* once ``is_running`` is cleared.

        The receive loop awaits the next server message and cannot notice the
        stop flag by itself; without this the TaskGroup, and therefore the live
        connection, would stay open after a stop.
        """
        while self.is_running:
            await asyncio.sleep(0.2)
        for task in tasks:
            if not task.done():
                task.cancel()

    async def run_main_loop(self):
        """Open the Gemini Live session and run all streaming tasks until stopped.

        Meant to be executed with ``asyncio.run`` in a background thread; it
        records that loop in ``async_event_loop`` for thread-safe calls.
        """
        self.async_event_loop = asyncio.get_running_loop()
        self.is_running = True
        # st.session_state['session_active'] is set by the caller in the Streamlit thread.
        print("AudioVisualLoop starting...")
        if client is None:
            print("Error: Gemini client is not initialized.")
            st.session_state['session_active'] = False
            return
        try:
            async with client.aio.live.connect(model=MODEL_NAME, config=LIVE_CONNECT_CONFIG) as session:
                self.gemini_session = session
                print("Gemini session established.")
                try:
                    print("Sending system prompt to Gemini...")
                    await self.gemini_session.send(input=MEDICAL_ASSISTANT_SYSTEM_PROMPT, end_of_turn=False)
                    print("System prompt sent successfully.")
                except Exception as e:
                    print(f"Failed to send system prompt to Gemini: {e}\nTraceback: {traceback.format_exc()}")
                    self.is_running = False
                    return
                self.audio_playback_queue = asyncio.Queue(maxsize=AUDIO_PLAYBACK_QUEUE_MAXSIZE)
                self.media_to_gemini_queue = asyncio.Queue(maxsize=MEDIA_OUT_QUEUE_MAXSIZE)
                async with asyncio.TaskGroup() as tg:
                    print("Creating async tasks...")
                    tg.create_task(self.stream_media_to_gemini(), name="stream_media_to_gemini")
                    tg.create_task(self.capture_microphone_audio(), name="capture_microphone_audio")
                    if self.video_mode == "camera":
                        tg.create_task(self.stream_camera_frames(), name="stream_camera_frames")
                    elif self.video_mode == "screen":
                        tg.create_task(self.stream_screen_frames(), name="stream_screen_frames")
                    receive_task = tg.create_task(
                        self.process_gemini_responses(), name="process_gemini_responses"
                    )
                    tg.create_task(self.play_gemini_audio(), name="play_gemini_audio")
                    tg.create_task(self._cancel_when_stopped(receive_task), name="stop_watcher")
                    print("All async tasks created in TaskGroup.")
                print("TaskGroup finished execution.")
        except asyncio.CancelledError:
            print("AudioVisualLoop.run_main_loop() was cancelled.")
        except ExceptionGroup as eg:
            print(f"ExceptionGroup caught in AudioVisualLoop: {eg}")
            for i, exc in enumerate(eg.exceptions):
                print(f"  Task Exception {i+1}/{len(eg.exceptions)}: {type(exc).__name__}: {exc}")
                traceback.print_exception(type(exc), exc, exc.__traceback__)
        except Exception as e:  # e.g. the connection being closed by the server
            print(f"General Exception in AudioVisualLoop: {type(e).__name__}: {e}")
            traceback.print_exception(e)  # Print full traceback for debugging
        finally:
            print("AudioVisualLoop.run_main_loop() finishing...")
            # Cleared here so the UI can detect a session that ended on its own.
            self.is_running = False
            self.gemini_session = None
            print("AudioVisualLoop finished.")
# --- Streamlit UI ---
def _attach_script_run_ctx(thread):
    """Give *thread* the current Streamlit ScriptRunContext (best effort).

    AudioVisualLoop writes frame previews and assistant messages into
    ``st.session_state`` from its background thread. Streamlit only ties such
    access to this user's session when the thread carries that session's
    ScriptRunContext. If the helpers cannot be imported, the thread starts
    without a context.
    """
    try:
        from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
    except ImportError:
        print("Warning: Streamlit ScriptRunContext helpers unavailable; background updates may not reach the UI.")
        return
    add_script_run_ctx(thread, get_script_run_ctx())


def run_streamlit_app():
    """Render the UI and start/stop the background Gemini session.

    The session runs in a daemon thread executing
    ``AudioVisualLoop.run_main_loop``. This function manipulates
    ``st.session_state`` and the loop instance stored there.
    """
    st.title("Live AI Medical Assistant")

    # run_main_loop clears is_running when it finishes, including when it dies on
    # its own (connection or device error). Reset the UI so a new session can start.
    tracked_instance = st.session_state.get('audio_loop_instance')
    if (
        st.session_state.get('session_active', False)
        and tracked_instance is not None
        and not tracked_instance.is_running
    ):
        st.session_state['session_active'] = False
        st.session_state['audio_loop_instance'] = None
        st.warning("The session has ended. Check the console logs for details, then start a new session.")

    with st.sidebar:
        st.header("Session Control")
        video_mode_options = ["camera", "screen", "none"]
        current_video_mode = st.session_state.get('video_mode_selection', DEFAULT_VIDEO_MODE)
        selected_video_mode = st.selectbox(
            "Video Source:",
            video_mode_options,
            index=video_mode_options.index(current_video_mode),
            disabled=st.session_state.get('session_active', False),
        )
        if selected_video_mode != current_video_mode:  # Selectbox is disabled while active
            st.session_state['video_mode_selection'] = selected_video_mode

        if not st.session_state.get('session_active', False):
            if st.button("๐ Start Session", type="primary", use_container_width=True):
                st.session_state['session_active'] = True  # Set active before starting thread
                st.session_state['chat_messages'] = [{
                    "role": "system",
                    "content": (
                        "Medical Assistant is activating. The AI has been instructed on its role to visually assist you. "
                        "Remember: This AI cannot provide medical diagnoses or replace consultation with a healthcare professional."
                    ),
                }]
                st.session_state['current_frame_preview'] = None
                audio_loop_instance = AudioVisualLoop(video_mode_setting=st.session_state['video_mode_selection'])
                st.session_state['audio_loop_instance'] = audio_loop_instance
                session_thread = threading.Thread(
                    target=lambda: asyncio.run(audio_loop_instance.run_main_loop()),
                    name="AudioVisualLoop",
                    daemon=True,
                )
                _attach_script_run_ctx(session_thread)
                session_thread.start()
                st.success("Session starting... Please wait for initialization.")
                time.sleep(1)
                st.rerun()
        else:
            if st.button("๐ Stop Session", type="secondary", use_container_width=True):
                if st.session_state.get('audio_loop_instance'):
                    st.session_state['audio_loop_instance'].signal_stop()
                    st.session_state['audio_loop_instance'] = None
                st.session_state['session_active'] = False  # Set inactive
                st.warning("Session stopping... Please wait.")
                time.sleep(1)
                st.rerun()

    col_video, col_chat = st.columns([2, 3])

    with col_video:
        st.subheader("Live Feed")
        video_enabled = st.session_state.get('video_mode_selection', DEFAULT_VIDEO_MODE) != "none"
        if st.session_state.get('session_active', False) and video_enabled:
            if st.session_state.get('current_frame_preview') is not None:
                st.image(st.session_state['current_frame_preview'], caption="Live Video Feed", use_column_width=True)
            else:
                st.info("Waiting for video feed...")
        elif video_enabled:
            st.info("Video feed will appear here once the session starts.")
        else:
            st.info("Video input is disabled for this session.")

    with col_chat:
        st.subheader("Chat with Assistant")
        for msg in st.session_state.get('chat_messages', []):
            with st.chat_message(msg["role"]):
                st.write(msg["content"])

        user_chat_input = st.chat_input(
            "Type your message or ask about the video...",
            key="user_chat_input_box",
            disabled=not st.session_state.get('session_active', False),
        )
        if user_chat_input:
            st.session_state['chat_messages'] = st.session_state.get('chat_messages', []) + [
                {"role": "user", "content": user_chat_input}
            ]
            loop_instance = st.session_state.get('audio_loop_instance')
            if loop_instance and loop_instance.async_event_loop and loop_instance.gemini_session:
                if loop_instance.async_event_loop.is_running():
                    asyncio.run_coroutine_threadsafe(
                        loop_instance.send_text_input_to_gemini(user_chat_input),
                        loop_instance.async_event_loop,
                    )
                    # Rerun only on success: a rerun would wipe the error messages below.
                    st.rerun()  # Show the user message and any quick AI text response
                else:
                    st.error("Session event loop is not running. Cannot send message.")
            elif not loop_instance or not st.session_state.get('session_active', False):
                st.error("Session is not active. Please start a session to send messages.")
            else:
                st.warning("Session components not fully ready. Please wait a moment.")


if __name__ == "__main__":
    run_streamlit_app()