Spaces:
Build error
Build error
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |
| from difflib import get_close_matches | |
| from typing import Optional, Dict, Any | |
| import json | |
| import io | |
| from datasets import load_dataset # Import the datasets library | |
# -------------------------------------------------
# Configuration
# -------------------------------------------------
# Timing profile of each supported insulin type, all values in hours:
# "onset" (only shown in the report), "duration" (how long the dose stays in
# effect) and "peak_time" (when its activity is highest).
INSULIN_TYPES = {
    "Rapid-Acting": dict(onset=0.25, duration=4, peak_time=1.0),
    "Long-Acting": dict(onset=2, duration=24, peak_time=8),
}

# Default basal schedule: "HH:MM-HH:MM" time-of-day interval -> rate
# (reported to the user in units/hour).
DEFAULT_BASAL_RATES = dict(
    [
        ("00:00-06:00", 0.8),
        ("06:00-12:00", 1.0),
        ("12:00-18:00", 0.9),
        ("18:00-24:00", 0.7),
    ]
)
# -------------------------------------------------
# Load Food Data from Hugging Face Dataset
# -------------------------------------------------
def _default_food_data():
    """Return the one-row fallback food table used when the dataset is unusable."""
    return pd.DataFrame({
        'food_category': ['starch'],
        'food_subcategory': ['bread'],
        'food_name': ['white bread'],  # lowercase, like normalised dataset names
        'serving_description': ['servingsize'],
        'serving_amount': [29],
        'serving_unit': ['g'],
        'carbohydrate_grams': [15],
        'notes': ['default'],
    })


def _normalize_food_frame(food_data):
    """Normalise column names and food names of a freshly loaded food table.

    Column names are lower-cased with spaces removed and columns whose name
    starts with "unnamed" are dropped. ``food_name`` values are stripped and
    lower-cased (matching is case-insensitive); rows whose name is missing,
    non-text or blank are dropped, because a single NaN name makes
    ``difflib.get_close_matches`` raise and would break every later lookup.
    """
    food_data = food_data.copy()
    food_data.columns = [str(col).lower().replace(' ', '') for col in food_data.columns]
    food_data = food_data.loc[:, ~food_data.columns.str.contains('^unnamed')]
    if 'food_name' in food_data.columns:
        names = food_data['food_name'].map(
            lambda value: value.strip().lower() if isinstance(value, str) else ''
        )
        keep = names != ''
        food_data = food_data[keep].assign(food_name=names[keep]).reset_index(drop=True)
    return food_data


def load_food_data(dataset_name="Anupam251272/food_nutrition"):
    """Load the food/carbohydrate table from a Hugging Face dataset.

    Args:
        dataset_name: dataset id passed to ``datasets.load_dataset``; its
            ``train`` split is used.

    Returns:
        pandas.DataFrame with normalised column names and lower-cased
        ``food_name`` values. Never raises: on any loading error, or when the
        dataset has no ``food_name`` column, a one-row default table
        ("white bread", 15 g carbohydrate per 29 g) is returned instead.
    """
    try:
        dataset = load_dataset(dataset_name)
        food_data = _normalize_food_frame(dataset['train'].to_pandas())
    except Exception as e:  # network/auth/schema problems: degrade to the fallback table
        print(f"Error loading Hugging Face Dataset: {e}")
        return _default_food_data()

    if 'food_name' in food_data.columns:
        print("Unique Food Names in Dataset:")
        print(food_data['food_name'].unique())
    else:
        print("Warning: 'food_name' column not found in dataset.")
        food_data = _default_food_data()
    # Print first 5 rows to check columns and values
    print("First 5 rows of loaded data from Hugging Face Dataset:")
    print(food_data.head())
    return food_data
| # ------------------------------------------------- | |
| # Load Food Classification Model | |
| # ------------------------------------------------- | |
| try: | |
| processor = AutoImageProcessor.from_pretrained("therealcyberlord/vit-indian-food") | |
| model = AutoModelForImageClassification.from_pretrained( | |
| "therealcyberlord/vit-indian-food", | |
| torch_dtype=torch.float16, | |
| device_map="cpu", #This model will only use CPU! | |
| low_cpu_mem_usage=True # Force low memory usage, no matter the device | |
| ) | |
| model_loaded = True #Flag for error handling in other defs | |
| except Exception as e: | |
| print(f"Model Load Error", str(e)) | |
| model_loaded = False | |
| processor = None | |
| model = None | |
def classify_food(image):
    """Classify a food photo with the model loaded at import time.

    Args:
        image: a PIL image (what ``gr.Image(type="pil")`` delivers), a numpy
            array, or ``None`` when nothing was uploaded.

    Returns:
        str: the predicted label, lower-cased; ``"unknown food"`` when the
        predicted index has no label; ``"Unknown"`` when no prediction could
        be made (no model, no image, or an error during inference).
    """
    print("classify_food function called")
    try:
        if not model_loaded:
            print("Model not loaded, returning 'Unknown'")
            return "Unknown"
        if image is None:
            print("No image supplied, returning 'Unknown'")
            return "Unknown"
        print(f"Image type: {type(image)}")
        if isinstance(image, np.ndarray):
            print("Image is a numpy array, converting to PIL Image")
            image = Image.fromarray(image)
        print(f"Image mode: {image.mode}")
        # The processor normalises with per-channel RGB statistics, so RGBA,
        # greyscale or palette images are converted first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        # Cast to the model's weight dtype/device; a mismatch (e.g. float32
        # input into float16 weights) raises inside the model.
        inputs["pixel_values"] = inputs["pixel_values"].to(device=model.device, dtype=model.dtype)
        with torch.no_grad():
            outputs = model(**inputs)
        predicted_idx = int(torch.argmax(outputs.logits, dim=-1).item())
        food_name = model.config.id2label.get(predicted_idx, "Unknown Food")
        print(f"Predicted food name: {food_name}")
        return food_name.lower()  # the food table's names are lower-case
    except Exception as e:
        print(f"Classify food error: {e}")
        return "Unknown"
# -------------------------------------------------
# Nutrition lookup (local Hugging Face dataset; USDA API integration removed)
# -------------------------------------------------
def get_food_nutrition(food_name: str, food_data, portion_size: float = 1.0) -> Optional[Dict[str, Any]]:
    """Look up carbohydrate information for a food by fuzzy name match.

    Args:
        food_name: name to search for; compared case-insensitively after
            stripping surrounding whitespace.
        food_data: DataFrame with a ``food_name`` column. Optional columns
            ``carbohydrate_grams``, ``serving_amount``, ``serving_unit``,
            ``food_category``, ``food_subcategory`` and ``notes`` fill in the
            result (missing/NaN values get defaults).
        portion_size: multiplier applied to the per-serving carbohydrates.

    Returns:
        A dict for the closest match (``difflib`` similarity >= 0.5) with keys
        ``matched_food``, ``category``, ``subcategory``, ``base_carbs``,
        ``adjusted_carbs``, ``serving_size``, ``portion_multiplier`` and
        ``notes``; ``None`` when nothing matches or the lookup fails.
    """
    try:
        query = str(food_name).strip().lower()
        print(f"Searching for: {query}")
        if 'food_name' not in food_data.columns:
            print("Warning: 'food_name' column not found in food data")
            return None

        # Normalise names once. Non-text names (NaN/None) become None and are
        # left out of the candidates: difflib raises TypeError on them, which
        # previously aborted the lookup for every query.
        names = food_data['food_name'].map(
            lambda value: value.strip().lower() if isinstance(value, str) else None
        )
        candidates = names.dropna().unique().tolist()
        matches = get_close_matches(query, candidates, n=1, cutoff=0.5)
        if not matches:
            print(f"No match found in CSV for {food_name}")
            print(f"No nutrition information found for {food_name} in the local database.")
            return None

        row = food_data[(names == matches[0]).to_numpy()].iloc[0]
        print(f"Matched row from CSV: {row}")

        def optional_value(column, default):
            # Series.get returns None for a missing column.
            value = row.get(column)
            return default if value is None or pd.isna(value) else value

        carb_col, amount_col, unit_col = 'carbohydrate_grams', 'serving_amount', 'serving_unit'

        raw_carbs = row.get(carb_col)
        if raw_carbs is None or pd.isna(raw_carbs):
            print(f"Warning: '{carb_col}' is missing or NaN in CSV")
            base_carbs = 0.0
        else:
            try:
                base_carbs = float(raw_carbs)
            except (TypeError, ValueError):
                print(f"Warning: '{carb_col}' is not a valid number in CSV")
                base_carbs = 0.0

        amount, unit = row.get(amount_col), row.get(unit_col)
        if amount is None or unit is None or pd.isna(amount) or pd.isna(unit):
            serving_size = "Unknown"
            print(f"Warning: '{amount_col}' or '{unit_col}' is missing in CSV")
        else:
            serving_size = f"{amount} {unit}"

        return {
            'matched_food': row['food_name'],
            'category': optional_value('food_category', 'Unknown'),
            'subcategory': optional_value('food_subcategory', 'Unknown'),
            'base_carbs': base_carbs,
            'adjusted_carbs': base_carbs * portion_size,
            'serving_size': serving_size,
            'portion_multiplier': portion_size,
            'notes': optional_value('notes', ''),
        }
    except Exception as e:
        print(f"Error in get_food_nutrition: {e}")
        return None
# -------------------------------------------------
# Insulin and Glucose Calculations
# -------------------------------------------------
def _clock_to_hours(clock):
    """Convert an ``"HH"`` or ``"HH:MM"`` string to fractional hours (e.g. "06:30" -> 6.5)."""
    hours_text, _, minutes_text = clock.strip().partition(":")
    return int(hours_text) + (int(minutes_text) / 60 if minutes_text else 0)


def get_basal_rate(current_time_hour, basal_rates):
    """Return the basal rate whose time-of-day interval contains a given hour.

    Args:
        current_time_hour: hour of day; values outside ``[0, 24)`` wrap modulo 24.
        basal_rates: mapping of ``"HH:MM-HH:MM"`` interval keys to rates. The
            start is inclusive and the end exclusive ("24:00" is allowed as an
            end). An interval whose start lies after its end, such as
            ``"22:00-06:00"``, wraps past midnight.

    Returns:
        The rate of the first matching interval, or ``0`` when none matches.
        Malformed interval keys are reported and skipped.
    """
    hour = current_time_hour % 24
    for interval, rate in basal_rates.items():
        try:
            # Split on "-" before ":"; splitting on ":" first kept only the
            # start hour, so no "HH:MM-HH:MM" key could ever match.
            start_text, end_text = interval.split("-")
            start, end = _clock_to_hours(start_text), _clock_to_hours(end_text)
        except (AttributeError, ValueError):
            print(f"Warning: Invalid interval format: {interval}. Skipping.")
            continue
        if start <= end:
            in_interval = start <= hour < end
        else:  # wraps past midnight
            in_interval = hour >= start or hour < end
        if in_interval:
            return rate
    return 0  # Default if no matching interval
def insulin_activity(t, insulin_type, bolus_dose, bolus_duration=0, insulin_types=None):
    """Return the activity of a single insulin dose ``t`` hours after it was given.

    Standard bolus: the activity rises as ``dose * (t / peak) * exp(1 - t / peak)``
    (reaching ``dose`` at ``peak_time``), then decays as
    ``dose * exp((peak - t) / (duration - peak))`` and is 0 from ``duration`` on.

    Extended bolus (``bolus_duration > 0``): the dose is released at a constant
    ``dose / bolus_duration`` while ``t <= bolus_duration`` and ``t < duration``,
    and is 0 otherwise.

    Args:
        t: hours since the dose. Negative values (dose not given yet) give 0.
        insulin_type: key into the insulin profile table.
        bolus_dose: size of the dose.
        bolus_duration: extended-bolus length in hours; 0 or ``None`` means
            a standard bolus.
        insulin_types: optional profile table mapping type -> dict with
            ``peak_time`` and ``duration``; defaults to ``INSULIN_TYPES``.

    Returns:
        The activity value, or 0 for an unknown insulin type.
    """
    profiles = INSULIN_TYPES if insulin_types is None else insulin_types
    insulin_data = profiles.get(insulin_type)
    if not insulin_data or t < 0:
        # Unknown type, or a dose in the future: the rising branch below would
        # otherwise return a negative activity for t < 0.
        return 0
    peak_time = insulin_data['peak_time']  # hours until maximum activity
    duration = insulin_data['duration']    # hours the dose stays in effect

    if bolus_duration and bolus_duration > 0:
        # Extended bolus: linear release over bolus_duration hours.
        if t <= bolus_duration and t < duration:
            return bolus_dose / bolus_duration
        return 0

    if t < peak_time:
        return (bolus_dose * t / peak_time) * np.exp(1 - t / peak_time)  # rising activity
    if t < duration:
        return bolus_dose * np.exp((peak_time - t) / (duration - peak_time))  # decaying activity
    return 0
def calculate_active_insulin(insulin_history, current_time):
    """Add up the activity of every recorded dose at ``current_time``.

    Args:
        insulin_history: iterable of ``(dose_time, dose_amount, insulin_type,
            bolus_duration)`` tuples.
        current_time: time, in the same unit as ``dose_time``, at which the
            doses are evaluated.

    Returns:
        Sum of ``insulin_activity`` over all doses (``0`` when there are none).
    """
    return sum(
        insulin_activity(current_time - given_at, kind, units, spread)
        for given_at, units, kind, spread in insulin_history
    )
def calculate_insulin_needs(carbs, glucose_current, glucose_target, tdd, weight,
                            insulin_type="Rapid-Acting", override_correction_dose=None,
                            insulin_types=None):
    """Calculate bolus and basal insulin figures for a meal (Type 1 diabetes).

    Formulas:
        ICR = (450 if weight <= 45 else 500) / tdd   -- grams of carbs per unit
        ISF = 1700 / tdd                              -- mg/dL per unit
        correction_dose = (glucose_current - glucose_target) / ISF,
                          replaced by override_correction_dose when given
        carb_dose = carbs / ICR
        total_bolus = max(0, carb_dose + correction_dose)
        basal_dose = weight * 0.5

    Args:
        carbs: carbohydrates in grams.
        glucose_current, glucose_target: glucose values in mg/dL.
        tdd: total daily insulin dose in units; must be > 0.
        weight: body weight in kg.
        insulin_type: key into the insulin profile table.
        override_correction_dose: if not ``None``, used instead of the
            computed correction dose.
        insulin_types: optional profile table (type -> dict with ``onset``,
            ``duration``, ``peak_time``); defaults to ``INSULIN_TYPES``.

    Returns:
        dict of rounded results plus the insulin profile, or
        ``{'error': message}`` for an invalid TDD or insulin type.
    """
    profiles = INSULIN_TYPES if insulin_types is None else insulin_types
    if tdd is None or tdd <= 0:
        return {'error': 'Total Daily Dose (TDD) must be greater than 0'}
    insulin_data = profiles.get(insulin_type)
    if not insulin_data:
        return {'error': "Invalid insulin type. Choose from: " + ", ".join(profiles.keys())}

    icr = (450 if weight <= 45 else 500) / tdd
    isf = 1700 / tdd

    if override_correction_dose is not None:
        correction_dose = override_correction_dose
    else:
        # Negative when below target, which reduces the bolus.
        correction_dose = (glucose_current - glucose_target) / isf

    carb_dose = carbs / icr
    total_bolus = max(0, carb_dose + correction_dose)
    # NOTE(review): basal is 0.5 x body weight here, while the UI labels it
    # units/day; confirm this is the intended formula.
    basal_dose = weight * 0.5

    return {
        'icr': round(icr, 2),
        'isf': round(isf, 2),
        'correction_dose': round(correction_dose, 2),
        'carb_dose': round(carb_dose, 2),
        'total_bolus': round(total_bolus, 2),
        'basal_dose': round(basal_dose, 2),
        'insulin_type': insulin_type,
        'insulin_onset': insulin_data['onset'],
        'insulin_duration': insulin_data['duration'],
        'peak_time': insulin_data['peak_time'],
    }
def create_detailed_report(nutrition_info, insulin_info, current_basal_rate):
    """Format the food/carbohydrate and insulin results as two text reports.

    Args:
        nutrition_info: dict as returned by ``get_food_nutrition``.
        insulin_info: dict as returned by ``calculate_insulin_needs`` (the
            success case, without an ``'error'`` key).
        current_basal_rate: rate shown on the "Current Basal Rate" line.

    Returns:
        tuple[str, str]: ``(carb_details, insulin_details)``.
    """
    def num(value):
        # Round floats for display, so e.g. 0.1 + 0.2 prints as 0.3 instead of
        # 0.30000000000000004; anything else is shown unchanged.
        return round(value, 2) if isinstance(value, float) else value

    carb_details = f"""
FOOD DETAILS:
-------------
Detected Food: {nutrition_info['matched_food']}
Category: {nutrition_info['category']}
Subcategory: {nutrition_info['subcategory']}
CARBOHYDRATE INFORMATION:
------------------------
Standard Serving Size: {nutrition_info['serving_size']}
Carbs per Serving: {num(nutrition_info['base_carbs'])}g
Portion Multiplier: {num(nutrition_info['portion_multiplier'])}x
Total Carbs: {num(nutrition_info['adjusted_carbs'])}g
Notes: {nutrition_info['notes']}
"""
    insulin_details = f"""
INSULIN CALCULATIONS:
--------------------
ICR (Insulin to Carb Ratio): 1:{num(insulin_info['icr'])}
ISF (Insulin Sensitivity Factor): 1:{num(insulin_info['isf'])}
Insulin Type: {insulin_info['insulin_type']}
Onset: {insulin_info['insulin_onset']} hours
Duration: {insulin_info['insulin_duration']} hours
Peak Time: {insulin_info['peak_time']} hours
RECOMMENDED DOSES:
-----------------
Correction Dose: {num(insulin_info['correction_dose'])} units
Carb Dose: {num(insulin_info['carb_dose'])} units
Total Bolus: {num(insulin_info['total_bolus'])} units
Daily Basal: {num(insulin_info['basal_dose'])} units
Current Basal Rate: {num(current_basal_rate)} units/hour
"""
    return carb_details, insulin_details
# -------------------------------------------------
# Main Dashboard Function
# -------------------------------------------------
def _find_nutrition(food_name, food_data, portion_size):
    """Look up ``food_name``, then each of its words; return None if nothing matches."""
    nutrition_info = get_food_nutrition(food_name, food_data, portion_size)
    if nutrition_info:
        return nutrition_info
    # Try the individual words (e.g. a generic term) if the full name is not found.
    for term in food_name.split():
        nutrition_info = get_food_nutrition(term, food_data, portion_size)
        if nutrition_info:
            return nutrition_info
    return None


def _parse_basal_rates(basal_rates_input):
    """Parse the basal-rate JSON text into a dict; fall back to DEFAULT_BASAL_RATES."""
    try:
        parsed = json.loads(basal_rates_input)
    except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
        parsed = None
    # Anything other than an object (e.g. a JSON list) cannot be used as a schedule.
    if not isinstance(parsed, dict):
        print("Basal rates JSON invalid, using default")
        return DEFAULT_BASAL_RATES
    return parsed


def _simulate_glucose(initial_glucose, hours, carbs, insulin_history, basal_rates,
                      stress_level, sleep_hours, exercise_duration, exercise_intensity,
                      glucose_range=(40, 400)):
    """Step the heuristic glucose model once per hour; return the predicted levels."""
    floor, ceiling = glucose_range
    # These per-hour effects do not depend on t.
    stress_effect = stress_level * 2
    sleep_effect = abs(8 - sleep_hours) * 5
    # NOTE(review): exercise is added, i.e. it raises the prediction; confirm
    # that this sign is intended.
    exercise_effect = (exercise_duration / 60) * exercise_intensity * 2

    levels = []
    glucose = initial_glucose
    for t in hours:
        # Gaussian carb bump centred at 1.5 h.
        carb_effect = carbs * 0.1 * np.exp(-(t - 1.5) ** 2 / 2)
        # Activity of all recorded doses at t. This already IS the insulin
        # curve; it used to be passed through insulin_activity() a second time,
        # applying the curve twice (and dividing extended boluses twice).
        insulin_effect = calculate_active_insulin(insulin_history, t)
        basal_effect = get_basal_rate(t, basal_rates)  # units per hour
        glucose = (glucose + carb_effect - insulin_effect +
                   stress_effect + sleep_effect + exercise_effect - basal_effect)
        # The lower clamp used to be 70 mg/dL, which made predicted
        # hypoglycaemia (< 70) impossible to show.
        glucose = max(floor, min(ceiling, glucose))
        levels.append(glucose)
    return levels


def _plot_glucose(hours, glucose_levels, target_glucose):
    """Build the prediction figure: glucose curve, target line and 70-180 mg/dL band."""
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(hours, glucose_levels, 'b-', label='Predicted Glucose')
    ax.axhline(y=target_glucose, color='g', linestyle='--', label='Target')
    ax.fill_between(hours, [70] * len(hours), [180] * len(hours),
                    alpha=0.1, color='g', label='Target Range')
    ax.set_ylabel('Glucose (mg/dL)')
    ax.set_xlabel('Hours')
    ax.set_title('Predicted Blood Glucose Over Time')
    ax.legend()
    ax.grid(True)
    # Remove the figure from pyplot's global registry so figures do not pile
    # up across requests; the Figure object is still returned for rendering.
    plt.close(fig)
    return fig


def diabetes_dashboard(initial_glucose, food_image, stress_level, sleep_hours, time_hours,
                       weight, tdd, target_glucose, exercise_duration, exercise_intensity, portion_size, insulin_type,
                       override_correction_dose, extended_bolus_duration, basal_rates_input):
    """Run the full pipeline for one Gradio request.

    Classifies the food photo, looks up its carbohydrates, computes insulin
    doses and simulates glucose once per hour for ``time_hours`` hours.

    Returns:
        5-tuple matching the Gradio outputs: ``(carb_details, insulin_details,
        basal_dose, total_bolus, figure)``. On failure the first element is a
        message and the remaining elements are ``None`` (the second may carry
        a short status text).
    """
    try:
        # 1. Food classification. classify_food reports "Unknown" (or
        #    "unknown food") when it could not predict; fuzzy-matching that
        #    text would pick an arbitrary food and produce a bogus dose.
        food_name = classify_food(food_image)
        print(f"Classified food name: {food_name}")
        if food_name.strip().lower() in ("unknown", "unknown food"):
            return (
                "Could not identify the food in the image (no image, model unavailable, or classification failed)",
                "No insulin calculations available",
                None,
                None,
                None
            )

        # 2. Carb lookup in the dataset
        food_data = load_food_data()
        nutrition_info = _find_nutrition(food_name, food_data, portion_size)
        if not nutrition_info:
            return (
                f"Could not find nutrition information for: {food_name} in the local database",
                "No insulin calculations available",
                None,
                None,
                None
            )

        # 3. Insulin calculations
        basal_rates_dict = _parse_basal_rates(basal_rates_input)
        insulin_info = calculate_insulin_needs(
            nutrition_info['adjusted_carbs'],
            initial_glucose,
            target_glucose,
            tdd,
            weight,
            insulin_type,
            override_correction_dose
        )
        if 'error' in insulin_info:
            return insulin_info['error'], None, None, None, None

        # 4. Reports. NOTE(review): the "current" basal rate is always looked up at hour 12.
        current_basal_rate = get_basal_rate(12, basal_rates_dict)
        carb_details, insulin_details = create_detailed_report(nutrition_info, insulin_info, current_basal_rate)

        # 5. Glucose prediction: the bolus is given at t=0.
        bolus_duration = extended_bolus_duration or 0  # an empty number field may arrive as None
        hours = list(range(int(time_hours)))
        insulin_history = [(0, insulin_info['total_bolus'], insulin_info['insulin_type'], bolus_duration)]
        glucose_levels = _simulate_glucose(
            initial_glucose, hours, nutrition_info['adjusted_carbs'], insulin_history,
            basal_rates_dict, stress_level, sleep_hours, exercise_duration, exercise_intensity,
        )

        # 6. Visualization
        fig = _plot_glucose(hours, glucose_levels, target_glucose)
        return (
            carb_details,
            insulin_details,
            insulin_info['basal_dose'],
            insulin_info['total_bolus'],
            fig
        )
    except Exception as e:
        return f"Error: {str(e)}", None, None, None, None
# -------------------------------------------------
# Gradio Interface Setup
# -------------------------------------------------
with gr.Blocks() as app:  # Blocks API: the layout is built manually below
    gr.Markdown("# Type 1 Diabetes Management Dashboard")

    with gr.Tab("Glucose & Meal"):
        with gr.Row():
            initial_glucose = gr.Number(label="Current Blood Glucose (mg/dL)", value=120)
            food_image = gr.Image(label="Food Image", type="pil")  # handler receives a PIL image
        with gr.Row():
            portion_size = gr.Slider(0.1, 3, step=0.1, label="Portion Size Multiplier", value=1.0)

    with gr.Tab("Insulin"):
        with gr.Column():
            insulin_type = gr.Dropdown(choices=list(INSULIN_TYPES.keys()), label="Insulin Type", value="Rapid-Acting")
            override_correction_dose = gr.Number(label="Override Correction Dose (Units)", value=None)
            extended_bolus_duration = gr.Number(label="Extended Bolus Duration (Hours)", value=0)

    with gr.Tab("Basal Settings"):
        with gr.Column():
            # Derived from DEFAULT_BASAL_RATES so the pre-filled text cannot drift
            # from the fallback used when parsing fails. json.dumps renders it as
            # '{"00:00-06:00": 0.8, "06:00-12:00": 1.0, ...}', the same text as
            # the previous hard-coded literal.
            basal_rates_input = gr.Textbox(label="Basal Rates (JSON)", lines=3,
                                           value=json.dumps(DEFAULT_BASAL_RATES))

    with gr.Tab("Other Factors"):
        with gr.Accordion("Factors affecting Glucose levels", open=False):  # collapsed by default
            weight = gr.Number(label="Weight (kg)", value=70)
            tdd = gr.Number(label="Total Daily Dose (TDD) of insulin (units)", value=40)
            target_glucose = gr.Number(label="Target Blood Glucose (mg/dL)", value=100)
            stress_level = gr.Slider(1, 10, step=1, label="Stress Level (1-10)", value=1)
            sleep_hours = gr.Number(label="Sleep Hours", value=7)
            exercise_duration = gr.Number(label="Exercise Duration (minutes)", value=0)
            exercise_intensity = gr.Slider(1, 10, step=1, label="Exercise Intensity (1-10)", value=1)

    with gr.Row():
        time_hours = gr.Slider(1, 24, step=1, label="Prediction Time (hours)", value=6)

    with gr.Row():
        calculate_button = gr.Button("Calculate")

    with gr.Column():
        carb_details_output = gr.Textbox(label="Carbohydrate Details", lines=5)
        insulin_details_output = gr.Textbox(label="Insulin Calculation Details", lines=5)
        basal_dose_output = gr.Number(label="Basal Insulin Dose (units/day)")
        bolus_dose_output = gr.Number(label="Bolus Insulin Dose (units)")
        glucose_plot_output = gr.Plot(label="Glucose Prediction")

    # Input order must match diabetes_dashboard's positional parameters.
    calculate_button.click(
        fn=diabetes_dashboard,
        inputs=[
            initial_glucose,
            food_image,
            stress_level,
            sleep_hours,
            time_hours,
            weight,
            tdd,
            target_glucose,
            exercise_duration,
            exercise_intensity,
            portion_size,
            insulin_type,
            override_correction_dose,
            extended_bolus_duration,
            basal_rates_input,
        ],
        outputs=[
            carb_details_output,
            insulin_details_output,
            basal_dose_output,
            bolus_dose_output,
            glucose_plot_output,
        ],
    )
def _launch_app():
    """Start the Gradio server for ``app`` with ``share=True``."""
    app.launch(share=True)


if __name__ == "__main__":
    _launch_app()