# NOTE: the lines below are Hugging Face file-page residue kept for provenance.
# Spatial-aware / src /streamlit_app.py
# noumanjavaid's picture
# Update src/streamlit_app.py
# 76a264c verified
# Raw
# History Blame
# 25.3 kB
# -*- coding: utf-8 -*-
import streamlit as st
import os
import asyncio
import base64
import io
import threading
import traceback
import time # For potential delays or timestamps if needed in future
import atexit # For cleanup actions on exit
from dotenv import load_dotenv
# --- Import main libraries ---
import cv2
import pyaudio
import PIL.Image
import mss # For screen capture
from google import genai
from google.genai import types
# --- Configuration ---
load_dotenv()  # Load environment variables from a .env file (must run before os.environ reads below)
# Audio configuration
FORMAT = pyaudio.paInt16  # 16-bit PCM samples
CHANNELS = 1  # Mono
SEND_SAMPLE_RATE = 16000  # Microphone capture rate for audio sent to Gemini
RECEIVE_SAMPLE_RATE = 24000  # Gemini documentation recommendation
CHUNK_SIZE = 1024  # Frames per microphone read / PyAudio buffer
AUDIO_PLAYBACK_QUEUE_MAXSIZE = 50  # Buffer for audio chunks from Gemini for playback
MEDIA_OUT_QUEUE_MAXSIZE = 20  # Buffer for audio/video frames to send to Gemini
# Video configuration
VIDEO_FPS_LIMIT = 1  # Send 1 frame per second to the API
VIDEO_PREVIEW_RESIZE = (640, 480)  # Max size for the Streamlit preview (aspect ratio preserved)
VIDEO_API_RESIZE = (1024, 1024)  # Max size for images sent to API (aspect ratio preserved)
# Gemini model configuration.
# Live API model names change between releases, so allow overriding the model
# without a code change (e.g. GEMINI_LIVE_MODEL=... in .env). An unset or empty
# variable falls back to the original default.
MODEL_NAME = os.environ.get("GEMINI_LIVE_MODEL") or "models/gemini-2.0-flash-live-001"
DEFAULT_VIDEO_MODE = "camera"  # Default video input mode ("camera", "screen", "none")
# System Prompt for the Medical Assistant
MEDICAL_ASSISTANT_SYSTEM_PROMPT = """You are an AI Medical Assistant. Your primary function is to analyze visual information from the user's camera or screen and respond via voice.
Your responsibilities are:
1. **Visual Observation and Description:** Carefully examine the images or video feed. Describe relevant details you observe.
2. **General Information (Non-Diagnostic):** Provide general information related to what is visually presented, if applicable. You are not a diagnostic tool.
3. **Safety and Disclaimer (CRITICAL):**
* You are an AI assistant, **NOT a medical doctor or a substitute for one.**
* **DO NOT provide medical diagnoses, treatment advice, or interpret medical results (e.g., X-rays, scans, lab reports).**
* When appropriate, and always if the user seems to be seeking diagnosis or treatment, explicitly state your limitations and **strongly advise the user to consult a qualified healthcare professional.**
* If you see something that *appears* visually concerning (e.g., an unusual skin lesion, signs of injury), you may gently suggest it might be wise to have it looked at by a professional, without speculating on what it is.
4. **Tone:** Maintain a helpful, empathetic, and calm tone.
5. **Interaction:** After this initial instruction, you can make a brief acknowledgment of your role (e.g., "I'm ready to assist by looking at what you show me. Please remember to consult a doctor for medical advice."). Then, focus on responding to the user's visual input and questions.
Example of a disclaimer you might use: "As an AI assistant, I can describe what I see, but I can't provide medical advice or diagnoses. For any health concerns, it's always best to speak with a doctor or other healthcare professional."
"""
# Initialize PyAudio once per server process.
# Streamlit re-executes this whole script on every rerun, so a plain module-level
# ``pyaudio.PyAudio()`` would re-initialize PortAudio (and register yet another
# atexit hook) on every widget interaction. ``st.cache_resource`` returns the
# same instance on every rerun instead.
def cleanup_pyaudio(instance=None):
    """Terminate a PyAudio instance; registered with ``atexit``.

    Args:
        instance: The ``pyaudio.PyAudio`` object to terminate. When omitted,
            the module-level ``pya`` is terminated (the original behaviour).
    """
    print("Terminating PyAudio instance.")
    (pya if instance is None else instance).terminate()


@st.cache_resource(show_spinner=False)  # no spinner element: this runs before st.set_page_config
def _get_pyaudio_instance():
    """Create the shared PyAudio instance and register its cleanup exactly once."""
    instance = pyaudio.PyAudio()
    atexit.register(cleanup_pyaudio, instance)
    return instance


pya = _get_pyaudio_instance()
# Initialize Streamlit session state variables
def init_session_state():
    """Seed ``st.session_state`` with defaults for any key that is not set yet.

    Keys that already exist keep their current value, so this is safe to call
    on every script rerun.
    """
    defaults = dict(
        app_initialized=True,  # Marker that this session has been seeded
        session_active=False,
        audio_loop_instance=None,
        chat_messages=[],
        current_frame_preview=None,
        video_mode_selection=DEFAULT_VIDEO_MODE,
    )
    missing_keys = [name for name in defaults if name not in st.session_state]
    for name in missing_keys:
        st.session_state[name] = defaults[name]


init_session_state()
# Configure Streamlit page
st.set_page_config(layout="wide", page_title="Live Medical Assistant")


def _create_gemini_client(api_key):
    """Build the genai client, reporting failure in the UI and halting the run.

    Returns the client, or None when construction failed and ``st.stop()``
    did not end the script run.
    """
    try:
        return genai.Client(
            api_key=api_key,
            http_options={"api_version": "v1beta"},  # Required for live
        )
    except Exception as client_error:
        st.error(f"Failed to initialize Gemini client: {client_error}")
        st.stop()
    return None


# Without an API key nothing else in the app can work, so halt early.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    st.error("GEMINI_API_KEY not found. Please set it in your environment variables or a .env file.")
    st.stop()

# The single genai.Client instance used throughout the module.
client = _create_gemini_client(GEMINI_API_KEY)
# Gemini LiveConnectConfig
# - The Live API accepts only ONE response modality per session; requesting
#   both audio and text makes the server reject/close the connection. This
#   assistant answers by voice, so request audio only.
# - When the installed google-genai SDK supports it, also request a transcript
#   of the spoken reply so the chat can still show the assistant's words.
# - The medical-assistant prompt is sent as a real system instruction instead
#   of being injected as the user's opening turn.
_live_config_extras = {}
if "output_audio_transcription" in getattr(types.LiveConnectConfig, "model_fields", {}):
    _live_config_extras["output_audio_transcription"] = types.AudioTranscriptionConfig()

LIVE_CONNECT_CONFIG = types.LiveConnectConfig(
    response_modalities=["AUDIO"],
    speech_config=types.SpeechConfig(
        voice_config=types.VoiceConfig(
            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="Zephyr")
        )
    ),
    system_instruction=types.Content(parts=[types.Part(text=MEDICAL_ASSISTANT_SYSTEM_PROMPT)]),
    **_live_config_extras,
)
class AudioVisualLoop:
    """Drive one live Gemini session: stream mic audio and video frames up, play audio replies.

    The Streamlit thread creates the instance and runs ``run_main_loop()``
    under ``asyncio.run`` on a separate daemon thread. That thread has no
    Streamlit ScriptRunContext, so ``st.session_state`` writes and ``st.error``
    calls made from here never reach the user's session. Results are therefore
    published on the instance instead -- ``latest_frame_preview`` and
    ``drain_messages()`` -- for the UI thread to collect on its next rerun.
    """

    # Consecutive send/receive failures tolerated before the session is shut
    # down; a persistent failure usually means the websocket has gone away.
    MAX_CONSECUTIVE_IO_ERRORS = 5
    # Seconds between checks of ``is_running`` by the stop watcher task.
    STOP_POLL_INTERVAL = 0.2

    def __init__(self, video_mode_setting=DEFAULT_VIDEO_MODE):
        self.video_mode = video_mode_setting  # "camera", "screen" or "none"
        self.gemini_session = None
        self.async_event_loop = None  # Loop running run_main_loop(); target for thread-safe calls
        self.audio_playback_queue = None  # asyncio.Queue of audio chunks received from Gemini
        self.media_to_gemini_queue = None  # asyncio.Queue of audio/video payloads for Gemini
        self.is_running = True  # Cooperative stop flag polled by every task
        self.mic_stream = None  # PyAudio input stream
        self.playback_stream = None  # PyAudio output stream
        self.camera_capture = None  # OpenCV VideoCapture object
        # Published for the Streamlit thread (see class docstring).
        self.latest_frame_preview = None  # Most recent preview image (PIL)
        self._pending_messages = []  # Chat messages not yet collected by the UI
        self._messages_lock = threading.Lock()

    # --- Hand-off to the Streamlit thread ---------------------------------

    def _post_message(self, role, content):
        """Queue a chat message for the UI thread to display (thread-safe)."""
        with self._messages_lock:
            self._pending_messages.append({"role": role, "content": content})

    def drain_messages(self):
        """Return and clear the chat messages posted since the previous call (thread-safe)."""
        with self._messages_lock:
            drained, self._pending_messages = self._pending_messages, []
        return drained

    def _stop_with_error(self, message):
        """Log *message*, surface it in the chat and ask every task to stop."""
        print(message)
        self._post_message("system", message)
        self.is_running = False

    @staticmethod
    def _put_dropping_oldest(queue, item):
        """Enqueue *item* without blocking, discarding the oldest entry if *queue* is full."""
        if queue.full():
            queue.get_nowait()
            queue.task_done()  # Keep unfinished-task accounting balanced
        queue.put_nowait(item)

    # --- Outgoing text ----------------------------------------------------

    async def send_text_input_to_gemini(self, user_text):
        """Send one user text turn to the live session (scheduled from the UI thread)."""
        if not user_text or not self.gemini_session or not self.is_running:
            print("Warning: Cannot send text. Session not active or no text.")
            return
        try:
            print(f"Sending text to Gemini: {user_text}")
            await self.gemini_session.send(input=user_text, end_of_turn=True)
        except Exception as e:
            print(f"Traceback for send_text_input_to_gemini: {traceback.format_exc()}")
            # st.error() would be invisible from this thread; report through the chat.
            self._post_message("system", f"Error sending text message to Gemini: {e}")

    # --- Video capture ----------------------------------------------------

    @staticmethod
    def _encode_frame(img):
        """Return the preview thumbnail and base64 JPEG API payload for one PIL image."""
        preview_img = img.copy()
        preview_img.thumbnail(VIDEO_PREVIEW_RESIZE)
        api_img = img.copy()
        api_img.thumbnail(VIDEO_API_RESIZE)  # thumbnail() preserves aspect ratio
        image_io = io.BytesIO()
        api_img.save(image_io, format="jpeg")
        image_bytes = image_io.getvalue()
        return {
            "preview": preview_img,
            "api_data": {"mime_type": "image/jpeg", "data": base64.b64encode(image_bytes).decode()},
        }

    def _process_camera_frame(self):
        """Grab one camera frame (blocking); return encoded frame data or None on failure."""
        if not self.camera_capture or not self.camera_capture.isOpened():
            print("Camera not available or not open.")
            return None
        ret, frame = self.camera_capture.read()
        if not ret:
            print("Failed to read frame from camera.")
            return None
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV delivers BGR
        return self._encode_frame(PIL.Image.fromarray(frame_rgb))

    def _process_screen_frame(self):
        """Grab one screenshot (blocking) of monitor 1, or monitor 0 if it is the only entry."""
        with mss.mss() as sct:
            monitor_index = 1
            if len(sct.monitors) <= monitor_index:
                monitor_index = 0
            sct_img = sct.grab(sct.monitors[monitor_index])
            img = PIL.Image.frombytes("RGB", (sct_img.width, sct_img.height), sct_img.rgb)
        return self._encode_frame(img)

    async def _stream_frames(self, grab_frame):
        """Publish a frame from *grab_frame* about VIDEO_FPS_LIMIT times per second until stopped."""
        while self.is_running:
            frame_data = await asyncio.to_thread(grab_frame)
            if frame_data:
                self.latest_frame_preview = frame_data["preview"]
                self._put_dropping_oldest(self.media_to_gemini_queue, frame_data["api_data"])
            await asyncio.sleep(1.0 / VIDEO_FPS_LIMIT)

    async def stream_camera_frames(self):
        """Open the default camera and stream its frames; stops the session if it cannot open."""
        try:
            self.camera_capture = await asyncio.to_thread(cv2.VideoCapture, 0)  # 0 = default camera
            if not self.camera_capture.isOpened():
                self._stop_with_error("Could not open camera. Please check permissions and availability.")
                return
            await self._stream_frames(self._process_camera_frame)
        except Exception as e:
            print(f"Camera streaming error: {e}\nTraceback: {traceback.format_exc()}")
            self._stop_with_error(f"Camera streaming error: {e}")
        finally:
            if self.camera_capture is not None:
                await asyncio.to_thread(self.camera_capture.release)
                self.camera_capture = None
            print("Camera streaming task finished.")

    async def stream_screen_frames(self):
        """Stream screenshots of the primary monitor; stops the session on capture errors."""
        try:
            await self._stream_frames(self._process_screen_frame)
        except Exception as e:
            print(f"Screen capture error: {e}\nTraceback: {traceback.format_exc()}")
            self._stop_with_error(f"Screen capture error: {e}")
        finally:
            print("Screen streaming task finished.")

    # --- Media upload -----------------------------------------------------

    async def stream_media_to_gemini(self):
        """Forward queued audio/video payloads to the live session."""
        consecutive_errors = 0
        try:
            while self.is_running:
                if not self.gemini_session:
                    await asyncio.sleep(0.1)
                    continue
                try:
                    media_chunk = await asyncio.wait_for(self.media_to_gemini_queue.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    continue
                try:
                    # None is the wake-up sentinel from signal_stop(); nothing to send.
                    if media_chunk and self.gemini_session and self.is_running:
                        await self.gemini_session.send(input=media_chunk)
                        consecutive_errors = 0
                except Exception as e:
                    if self.is_running:
                        consecutive_errors += 1
                        print(f"Error in stream_media_to_gemini: {e}")
                        # Without a limit, a closed socket would make this loop spin forever.
                        if consecutive_errors >= self.MAX_CONSECUTIVE_IO_ERRORS:
                            self._stop_with_error(f"Repeated errors sending media to Gemini; stopping session: {e}")
                    await asyncio.sleep(0.1)
                finally:
                    self.media_to_gemini_queue.task_done()
        except asyncio.CancelledError:
            print("stream_media_to_gemini task cancelled.")
            raise
        finally:
            print("Media streaming to Gemini task finished.")

    async def capture_microphone_audio(self):
        """Read PCM chunks from the default input device and queue them for Gemini."""
        try:
            mic_info = await asyncio.to_thread(pya.get_default_input_device_info)
            self.mic_stream = await asyncio.to_thread(
                pya.open,
                format=FORMAT,
                channels=CHANNELS,
                rate=SEND_SAMPLE_RATE,
                input=True,
                input_device_index=mic_info["index"],
                frames_per_buffer=CHUNK_SIZE,
            )
            print("Microphone stream opened.")
            while self.is_running:
                try:
                    audio_data = await asyncio.to_thread(self.mic_stream.read, CHUNK_SIZE, exception_on_overflow=False)
                    self._put_dropping_oldest(self.media_to_gemini_queue, {"data": audio_data, "mime_type": "audio/pcm"})
                except IOError as e:
                    if e.errno == pyaudio.paInputOverflowed:  # type: ignore
                        print("Microphone Input overflowed. Skipping.")
                    else:
                        self._stop_with_error(f"Microphone read IOError: {e}")
                        break
                except Exception as e:
                    print(f"Error in capture_microphone_audio: {e}")
                    await asyncio.sleep(0.01)
        except Exception as e:
            print(f"Failed to open microphone: {e}\nTraceback: {traceback.format_exc()}")
            self._stop_with_error(f"Failed to open microphone: {e}. Please check permissions.")
        finally:
            if self.mic_stream:
                await asyncio.to_thread(self.mic_stream.stop_stream)
                await asyncio.to_thread(self.mic_stream.close)
                self.mic_stream = None
            print("Microphone capture task finished.")

    # --- Responses from Gemini --------------------------------------------

    @staticmethod
    def _output_transcript(chunk):
        """Return the output-audio transcript text carried by *chunk*, or None.

        Uses getattr so SDK versions without transcription support yield None.
        """
        server_content = getattr(chunk, "server_content", None)
        transcription = getattr(server_content, "output_transcription", None)
        return getattr(transcription, "text", None)

    def _flush_turn_text(self, parts):
        """Post the text collected for one model turn as a single assistant message."""
        text = "".join(parts).strip()
        if text:
            self._post_message("assistant", text)

    async def process_gemini_responses(self):
        """Receive server messages: queue audio for playback and post each turn's text to the chat."""
        consecutive_errors = 0
        try:
            while self.is_running:
                if not self.gemini_session:
                    await asyncio.sleep(0.1)
                    continue
                turn_text = []
                try:
                    # receive() yields the messages of one model turn and then ends.
                    async for chunk in self.gemini_session.receive():
                        consecutive_errors = 0
                        if not self.is_running:
                            break
                        if audio_data := chunk.data:
                            if not self.audio_playback_queue.full():
                                self.audio_playback_queue.put_nowait(audio_data)
                            else:
                                print("Audio playback queue full, discarding data.")
                        if text_response := chunk.text:
                            turn_text.append(text_response)
                        if transcript := self._output_transcript(chunk):
                            turn_text.append(transcript)
                except Exception as e:
                    # Catch-all on purpose: google.genai.types has no ``generation_types``,
                    # so an except clause naming StopCandidateException would itself raise
                    # AttributeError whenever any receive error arrived.
                    if not self.is_running:
                        break
                    consecutive_errors += 1
                    print(f"Error receiving from Gemini: {e}")
                    if consecutive_errors >= self.MAX_CONSECUTIVE_IO_ERRORS:
                        self._stop_with_error(f"Lost connection to Gemini: {e}")
                    else:
                        await asyncio.sleep(0.1)
                finally:
                    self._flush_turn_text(turn_text)
        except asyncio.CancelledError:
            print("process_gemini_responses task cancelled.")
            raise
        finally:
            print("Gemini response processing task finished.")

    async def play_gemini_audio(self):
        """Play queued audio chunks from Gemini on the default output device."""
        try:
            self.playback_stream = await asyncio.to_thread(
                pya.open, format=FORMAT, channels=CHANNELS, rate=RECEIVE_SAMPLE_RATE, output=True
            )
            print("Audio playback stream opened.")
            while self.is_running:
                try:
                    audio_chunk = await asyncio.wait_for(self.audio_playback_queue.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    continue
                try:
                    if audio_chunk:  # None is the wake-up sentinel from signal_stop()
                        await asyncio.to_thread(self.playback_stream.write, audio_chunk)
                except Exception as e:
                    print(f"Error playing audio: {e}")
                    await asyncio.sleep(0.01)
                finally:
                    self.audio_playback_queue.task_done()
        except Exception as e:
            print(f"Failed to open audio playback stream: {e}\nTraceback: {traceback.format_exc()}")
            self._stop_with_error(f"Failed to open audio playback stream: {e}")
        finally:
            if self.playback_stream:
                await asyncio.to_thread(self.playback_stream.stop_stream)
                await asyncio.to_thread(self.playback_stream.close)
                self.playback_stream = None
            print("Audio playback task finished.")

    # --- Lifecycle --------------------------------------------------------

    def signal_stop(self):
        """Ask the session to shut down; safe to call from any thread.

        Clears ``is_running`` (polled by every task) and, if the event loop is
        up, schedules ``None`` wake-up sentinels onto the queues *on that loop*.
        ``asyncio.Queue`` is not thread-safe, so it must not be touched directly
        from the Streamlit thread.
        """
        print("Signal to stop AudioVisualLoop received.")
        self.is_running = False
        loop = self.async_event_loop
        if loop is None or loop.is_closed():
            return
        try:
            loop.call_soon_threadsafe(self._wake_consumers)
        except RuntimeError:
            pass  # Loop closed between the check and the call; nothing left to wake.

    def _wake_consumers(self):
        """Put a None sentinel on each queue that has room; runs on the event loop thread.

        A full queue is skipped: its consumers wake on their own within the
        1 s get() timeout and then see ``is_running`` is False.
        """
        for queue in (self.media_to_gemini_queue, self.audio_playback_queue):
            if queue is not None and not queue.full():
                queue.put_nowait(None)

    async def _cancel_when_stopped(self, *tasks):
        """Cancel *tasks* once ``is_running`` goes False.

        ``session.receive()`` only yields when the server sends something, so the
        response task could otherwise keep the TaskGroup (and websocket) open
        indefinitely after a stop request.
        """
        while self.is_running:
            await asyncio.sleep(self.STOP_POLL_INTERVAL)
        for task in tasks:
            task.cancel()

    async def run_main_loop(self):
        """Connect to the Live API and run all streaming tasks until stopped or failed.

        Meant to be the coroutine given to ``asyncio.run`` on the worker thread.
        Errors are logged and surfaced via ``drain_messages()``; nothing is raised.
        """
        self.async_event_loop = asyncio.get_running_loop()
        # Do not force is_running back to True: signal_stop() may already have been
        # called between thread start and now, and that request must be honoured.
        if not self.is_running:
            print("AudioVisualLoop was stopped before it started.")
            return
        print("AudioVisualLoop starting...")
        if client is None:
            print("Error: Gemini client is not initialized.")
            self._stop_with_error("Critical Error: Gemini client failed to initialize. Cannot start session.")
            return
        try:
            async with client.aio.live.connect(model=MODEL_NAME, config=LIVE_CONNECT_CONFIG) as session:
                self.gemini_session = session
                print("Gemini session established.")
                if getattr(LIVE_CONNECT_CONFIG, "system_instruction", None) is None:
                    # No system instruction in the connect config: prime the model with
                    # the prompt as an opening (unfinished) turn instead.
                    try:
                        print("Sending system prompt to Gemini...")
                        await session.send(input=MEDICAL_ASSISTANT_SYSTEM_PROMPT, end_of_turn=False)
                        print("System prompt sent successfully.")
                    except Exception as e:
                        print(f"Failed to send system prompt to Gemini: {e}\nTraceback: {traceback.format_exc()}")
                        self._stop_with_error(f"Failed to send system prompt to Gemini: {e}")
                        return
                self.audio_playback_queue = asyncio.Queue(maxsize=AUDIO_PLAYBACK_QUEUE_MAXSIZE)
                self.media_to_gemini_queue = asyncio.Queue(maxsize=MEDIA_OUT_QUEUE_MAXSIZE)
                async with asyncio.TaskGroup() as tg:
                    print("Creating async tasks...")
                    tg.create_task(self.stream_media_to_gemini(), name="stream_media_to_gemini")
                    tg.create_task(self.capture_microphone_audio(), name="capture_microphone_audio")
                    if self.video_mode == "camera":
                        tg.create_task(self.stream_camera_frames(), name="stream_camera_frames")
                    elif self.video_mode == "screen":
                        tg.create_task(self.stream_screen_frames(), name="stream_screen_frames")
                    receiver = tg.create_task(self.process_gemini_responses(), name="process_gemini_responses")
                    tg.create_task(self.play_gemini_audio(), name="play_gemini_audio")
                    tg.create_task(self._cancel_when_stopped(receiver), name="stop_watcher")
                    print("All async tasks created in TaskGroup.")
                print("TaskGroup finished execution.")
        except asyncio.CancelledError:
            print("AudioVisualLoop.run_main_loop() was cancelled.")
        except ExceptionGroup as eg:
            print(f"ExceptionGroup caught in AudioVisualLoop: {eg}")
            for i, exc in enumerate(eg.exceptions):
                print(f" Task Exception {i+1}/{len(eg.exceptions)}: {type(exc).__name__}: {exc}")
                traceback.print_exception(type(exc), exc, exc.__traceback__)
            self._post_message("system", f"Session error: {eg.exceptions[0]}")
        except Exception as e:  # e.g. the websocket closing during connect
            print(f"General Exception in AudioVisualLoop: {type(e).__name__}: {e}")
            traceback.print_exception(e)
            self._post_message("system", f"Session error: {type(e).__name__}: {e}")
        finally:
            print("AudioVisualLoop.run_main_loop() finishing...")
            self.is_running = False  # Also tells the UI that the session has ended
            self.gemini_session = None
            print("AudioVisualLoop finished.")
# --- Streamlit UI ---
def _sync_from_audio_loop():
    """Copy results published by the background AudioVisualLoop into session state.

    The loop runs on its own thread without a Streamlit ScriptRunContext, so it
    cannot update ``st.session_state`` itself. This pulls its output on every
    rerun. Attributes are read with ``getattr`` so a loop object that does not
    publish them is simply skipped.
    """
    loop_instance = st.session_state.get('audio_loop_instance')
    if loop_instance is None:
        return
    drain = getattr(loop_instance, "drain_messages", None)
    new_messages = drain() if callable(drain) else []
    if new_messages:
        st.session_state['chat_messages'] = st.session_state.get('chat_messages', []) + new_messages
    preview = getattr(loop_instance, "latest_frame_preview", None)
    if preview is not None:
        st.session_state['current_frame_preview'] = preview
    # The loop clears is_running when it ends on its own (error, lost connection).
    if st.session_state.get('session_active', False) and not getattr(loop_instance, "is_running", True):
        st.session_state['session_active'] = False
        st.session_state['audio_loop_instance'] = None
        st.session_state['chat_messages'] = st.session_state.get('chat_messages', []) + [
            {"role": "system", "content": "Session ended."}
        ]


def run_streamlit_app():
    """Render the page: session controls in the sidebar, live preview and chat in the main area."""
    # Pull frames/messages the background session produced since the last rerun.
    _sync_from_audio_loop()
    st.title("Live AI Medical Assistant")
    with st.sidebar:
        st.header("Session Control")
        video_mode_options = ["camera", "screen", "none"]
        current_video_mode = st.session_state.get('video_mode_selection', DEFAULT_VIDEO_MODE)
        selected_video_mode = st.selectbox(
            "Video Source:",
            video_mode_options,
            index=video_mode_options.index(current_video_mode),
            disabled=st.session_state.get('session_active', False),
        )
        if selected_video_mode != current_video_mode:  # Only changeable while no session is active
            st.session_state['video_mode_selection'] = selected_video_mode
        if not st.session_state.get('session_active', False):
            if st.button("🚀 Start Session", type="primary", use_container_width=True):
                st.session_state['session_active'] = True  # Set active before starting thread
                st.session_state['chat_messages'] = [{
                    "role": "system",
                    "content": (
                        "Medical Assistant is activating. The AI has been instructed on its role to visually assist you. "
                        "Remember: This AI cannot provide medical diagnoses or replace consultation with a healthcare professional."
                    ),
                }]
                st.session_state['current_frame_preview'] = None
                audio_loop_instance = AudioVisualLoop(video_mode_setting=st.session_state['video_mode_selection'])
                st.session_state['audio_loop_instance'] = audio_loop_instance
                # The session runs its own asyncio loop on a daemon thread.
                threading.Thread(target=lambda: asyncio.run(audio_loop_instance.run_main_loop()), daemon=True).start()
                st.success("Session starting... Please wait for initialization.")
                time.sleep(1)
                st.rerun()
        else:
            if st.button("🛑 Stop Session", type="secondary", use_container_width=True):
                if st.session_state.get('audio_loop_instance'):
                    st.session_state['audio_loop_instance'].signal_stop()
                st.session_state['audio_loop_instance'] = None
                st.session_state['session_active'] = False  # Set inactive
                st.warning("Session stopping... Please wait.")
                time.sleep(1)
                st.rerun()
    col_video, col_chat = st.columns([2, 3])
    with col_video:
        st.subheader("Live Feed")
        video_mode = st.session_state.get('video_mode_selection', DEFAULT_VIDEO_MODE)
        if st.session_state.get('session_active', False) and video_mode != "none":
            if st.session_state.get('current_frame_preview') is not None:
                st.image(st.session_state['current_frame_preview'], caption="Live Video Feed", use_column_width=True)
            else:
                st.info("Waiting for video feed...")
        elif video_mode != "none":
            st.info("Video feed will appear here once the session starts.")
        else:
            st.info("Video input is disabled for this session.")
    with col_chat:
        st.subheader("Chat with Assistant")
        for msg in st.session_state.get('chat_messages', []):
            with st.chat_message(msg["role"]):
                st.write(msg["content"])
        user_chat_input = st.chat_input(
            "Type your message or ask about the video...",
            key="user_chat_input_box",
            disabled=not st.session_state.get('session_active', False),
        )
        if user_chat_input:
            # Record the user message, hand the text to the session's event loop, then rerun.
            st.session_state['chat_messages'] = st.session_state.get('chat_messages', []) + [{"role": "user", "content": user_chat_input}]
            loop_instance = st.session_state.get('audio_loop_instance')
            if loop_instance and loop_instance.async_event_loop and loop_instance.gemini_session:
                if loop_instance.async_event_loop.is_running():
                    # The session's event loop lives on another thread.
                    asyncio.run_coroutine_threadsafe(
                        loop_instance.send_text_input_to_gemini(user_chat_input),
                        loop_instance.async_event_loop,
                    )
                else:
                    st.error("Session event loop is not running. Cannot send message.")
            elif not loop_instance or not st.session_state.get('session_active', False):
                st.error("Session is not active. Please start a session to send messages.")
            else:
                st.warning("Session components not fully ready. Please wait a moment.")
            st.rerun()  # Rerun to show the user message


if __name__ == "__main__":
    run_streamlit_app()