"""Image -> mood -> music Gradio app.

Pipeline: classify the scene in an uploaded image with a fine-tuned ResNet18,
derive "mood" words from the photo's EXIF time/GPS plus historical weather
from Open-Meteo, fuse both into a text prompt, and synthesize a short clip
with MusicGen.
"""
import json
import os
import tempfile
import time
from datetime import datetime

import exifread
import gradio as gr
import numpy as np
import pytz
import requests
import soundfile as sf
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from transformers import pipeline as hf_pipeline

# ========================================
# Setup Models
# ========================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load classifier
model = models.resnet18(weights=None)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 4)
model.load_state_dict(torch.load("location_classifier.pth", map_location=device))
model = model.to(device)
model.eval()

# Update these with your actual class names in the correct order
class_labels = ["cafe", "gym", "library", "outdoor"]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load MusicGen
print("Loading MusicGen model...")
musicgen = hf_pipeline(
    "text-to-audio",
    model="facebook/musicgen-small",
    device=0 if torch.cuda.is_available() else -1,
    dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
print("MusicGen loaded!")

# Enable HEIC support
try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
    print("HEIC support enabled")
except ImportError:
    print("HEIC support not available")

# ========================================
# Constants
# ========================================
DEFAULT_LAT = 40.4433
DEFAULT_LON = -79.9436
TIMEZONE = "America/New_York"

# Weather code -> human-readable condition, as returned in the "weathercode"
# hourly series. Built once instead of on every get_weather() call.
_WEATHER_CODE_TO_CONDITION = {
    0: "Clear",
    1: "Mainly Clear",
    2: "Partly Cloudy",
    3: "Overcast",
    45: "Fog",
    48: "Depositing Rime Fog",
    51: "Light Drizzle",
    53: "Moderate Drizzle",
    55: "Dense Drizzle",
    61: "Slight Rain",
    63: "Moderate Rain",
    65: "Heavy Rain",
    71: "Slight Snow",
    73: "Moderate Snow",
    75: "Heavy Snow",
    80: "Slight Rain Showers",
    81: "Moderate Rain Showers",
    82: "Violent Rain Showers",
    95: "Thunderstorm",
    96: "Thunderstorm with Slight Hail",
    99: "Thunderstorm with Heavy Hail",
}

# Condition -> mood words used during the day.
_DAY_WEATHER_MOODS = {
    "Clear": ["sunny", "bright"],
    "Mainly Clear": ["sunny", "light"],
    "Partly Cloudy": ["partly cloudy"],
    "Overcast": ["cloudy", "overcast"],
    "Fog": ["foggy", "hazy"],
    "Depositing Rime Fog": ["foggy", "hazy"],
    "Light Drizzle": ["rainy", "drizzle"],
    "Moderate Drizzle": ["rainy", "drizzle"],
    "Dense Drizzle": ["rainy", "drizzle"],
    "Slight Rain": ["rainy", "light rain"],
    "Moderate Rain": ["rainy", "moderate rain"],
    "Heavy Rain": ["rainy", "heavy rain"],
    "Slight Rain Showers": ["rainy", "showers"],
    "Moderate Rain Showers": ["rainy", "showers"],
    "Violent Rain Showers": ["rainy", "heavy showers"],
    "Slight Snow": ["snowy", "light snow"],
    "Moderate Snow": ["snowy", "moderate snow"],
    "Heavy Snow": ["snowy", "heavy snow"],
    "Thunderstorm": ["stormy", "thunder"],
    "Thunderstorm with Slight Hail": ["stormy", "hail"],
    "Thunderstorm with Heavy Hail": ["stormy", "hail"],
}

# Night-time moods differ from the daytime ones only where "sunny"/"hazy"
# would make no sense after dark.
_NIGHT_WEATHER_MOODS = {
    **_DAY_WEATHER_MOODS,
    "Clear": ["clear", "starry"],
    "Mainly Clear": ["clear"],
    "Fog": ["foggy", "mysterious"],
    "Depositing Rime Fog": ["foggy", "mysterious"],
}


# ========================================
# Helper Functions
# ========================================
def _dms_to_degrees(values):
    """Convert an EXIF (degrees, minutes, seconds) ratio triple to decimal degrees."""
    return (
        float(values[0].num) / values[0].den
        + float(values[1].num) / (values[1].den * 60)
        + float(values[2].num) / (values[2].den * 3600)
    )


def extract_metadata(image_path):
    """Read GPS position and capture time from the EXIF tags of *image_path*.

    Missing or unparsable GPS tags fall back to DEFAULT_LAT / DEFAULT_LON;
    a missing or unparsable 'EXIF DateTimeOriginal' falls back to the current
    time in TIMEZONE. EXIF timestamps carry no zone, so they are interpreted
    as TIMEZONE local time.

    Returns:
        dict with keys "lat", "lon", "time" (timezone-aware datetime) and
        "used_fallback" (True if any of the fallbacks above was taken).
    """
    with open(image_path, "rb") as f:
        tags = exifread.process_file(f)

    eastern = pytz.timezone(TIMEZONE)
    used_fallback = False

    # Try to extract GPS
    try:
        lat = _dms_to_degrees(tags["GPS GPSLatitude"].values)
        lon = _dms_to_degrees(tags["GPS GPSLongitude"].values)
        if str(tags.get("GPS GPSLatitudeRef")) == "S":
            lat = -lat
        if str(tags.get("GPS GPSLongitudeRef")) == "W":
            lon = -lon
    except Exception:
        # Best effort: absent tags (KeyError), malformed ratios
        # (IndexError / AttributeError / ZeroDivisionError) all mean "no
        # usable GPS". Unlike a bare ``except:`` this lets KeyboardInterrupt
        # and SystemExit propagate.
        lat = DEFAULT_LAT
        lon = DEFAULT_LON
        used_fallback = True

    # Try to extract timestamp
    timestamp = None
    if "EXIF DateTimeOriginal" in tags:
        try:
            naive = datetime.strptime(
                str(tags["EXIF DateTimeOriginal"]), "%Y:%m:%d %H:%M:%S"
            )
            timestamp = eastern.localize(naive)
        except ValueError:
            timestamp = None
    if timestamp is None:
        timestamp = datetime.now(eastern)
        used_fallback = True

    return {"lat": lat, "lon": lon, "time": timestamp, "used_fallback": used_fallback}


def get_weather(lat=None, lon=None, timestamp=None):
    """Look up the weather condition and temperature for a place and hour.

    Queries the Open-Meteo archive endpoint for the calendar day of
    *timestamp* and picks the entry at ``timestamp.hour`` (or the last entry
    if the series is shorter than that).

    Args:
        lat, lon: coordinates; if either is None both default to
            DEFAULT_LAT / DEFAULT_LON.
        timestamp: datetime; defaults to now in TIMEZONE. Any tzinfo is
            dropped and the wall-clock date/hour is used as-is.

    Returns:
        ``(condition, temp)`` where *condition* is a string from the weather
        code table ("Unknown" for unmapped codes) and *temp* is the value
        from the ``temperature_2m`` series, or ``(None, None)`` when the
        request fails or the response has no "hourly" section.
    """
    if lat is None or lon is None:
        lat = DEFAULT_LAT
        lon = DEFAULT_LON
    if timestamp is None:
        timestamp = datetime.now(pytz.timezone(TIMEZONE))
    if timestamp.tzinfo is not None:
        timestamp = timestamp.replace(tzinfo=None)

    date_str = timestamp.strftime("%Y-%m-%d")
    hour = timestamp.hour

    url = "https://archive-api.open-meteo.com/v1/archive"
    params = {
        "latitude": lat,
        "longitude": lon,
        "start_date": date_str,
        "end_date": date_str,
        "hourly": "temperature_2m,weathercode",
        "timezone": TIMEZONE
    }

    try:
        response = requests.get(url, params=params, timeout=5)
        data = response.json()
        if "hourly" not in data:
            return None, None
        temps = data["hourly"]["temperature_2m"]
        codes = data["hourly"]["weathercode"]
        temp = temps[hour] if hour < len(temps) else temps[-1]
        code = codes[hour] if hour < len(codes) else codes[-1]
        condition = _WEATHER_CODE_TO_CONDITION.get(code, "Unknown")
        return condition, temp
    except Exception as e:
        # Network errors, non-JSON bodies and empty series all degrade to
        # "no weather" rather than failing the whole request.
        print(f"Weather API error: {e}")
        return None, None


def classify_mood(metadata):
    """Turn photo metadata into a list of mood words.

    Args:
        metadata: dict as produced by extract_metadata(); the keys "time",
            "lat" and "lon" are read with ``.get`` so any may be absent.

    Returns:
        list of str: a time-of-day word (if a time is present), then
        weather-derived words, then "cold" (< 10) or "warm" (> 25) based on
        the reported temperature. If no weather is available the list ends
        with "unknown weather".
    """
    moods = []

    # Get time of day first
    time_of_day = None
    if metadata.get("time"):
        hour = metadata["time"].hour
        if 5 <= hour < 12:
            moods.append("morning")
            time_of_day = "day"
        elif 12 <= hour < 18:
            moods.append("afternoon")
            time_of_day = "day"
        elif 18 <= hour < 22:
            moods.append("evening")
            time_of_day = "night"  # Evening uses nighttime weather moods
        else:
            moods.append("night")
            time_of_day = "night"

    lat, lon = metadata.get("lat"), metadata.get("lon")
    timestamp = metadata.get("time")
    weather, temp = get_weather(lat, lon, timestamp)

    if weather:
        # Adjust weather moods based on time of day; an unknown time of day
        # uses the daytime table.
        if time_of_day == "night":
            weather_map = _NIGHT_WEATHER_MOODS
        else:
            weather_map = _DAY_WEATHER_MOODS
        moods.extend(weather_map.get(weather, [weather.lower()]))
        if temp is not None:
            if temp < 10:
                moods.append("cold")
            elif temp > 25:
                moods.append("warm")
    else:
        moods.append("unknown weather")

    return moods


def fusion(label, moods):
    """Combine the scene *label* and *moods* list into a MusicGen text prompt."""
    mood_str = ", ".join(moods)
    return f"Ambient background music for a {label} setting with {mood_str} atmosphere"


def generate_music(prompt, duration=5):
    """Generate audio for *prompt* with MusicGen and write it to a WAV file.

    Args:
        prompt: text description passed to the text-to-audio pipeline.
        duration: scales ``max_new_tokens`` (``duration * 50``); presumably
            seconds at ~50 tokens per second -- TODO confirm against the
            MusicGen model card.

    Returns:
        str: path of the written ``music_<timestamp>.wav`` file in the
        current working directory.
    """
    result = musicgen(
        prompt,
        forward_params={
            "do_sample": True,
            "max_new_tokens": int(duration * 50),
            "guidance_scale": 3.0
        }
    )

    # The pipeline result is handled as either a list of dicts or one dict.
    first = result[0] if isinstance(result, list) else result
    audio_array = first["audio"]
    sample_rate = first["sampling_rate"]

    if isinstance(audio_array, torch.Tensor):
        audio_np = audio_array.cpu().numpy()
    else:
        audio_np = np.array(audio_array)

    audio_np = np.squeeze(audio_np)
    if audio_np.ndim != 1:
        # Keep a single channel: treat the shorter axis as the channel axis.
        audio_np = audio_np[0] if audio_np.shape[0] < audio_np.shape[1] else audio_np[:, 0]

    # Only scale down when samples exceed [-1, 1]; quieter audio is untouched.
    max_val = max(abs(audio_np.max()), abs(audio_np.min()))
    if max_val > 1.0:
        audio_np = audio_np / max_val

    # Nanosecond resolution so two generations in the same second cannot
    # overwrite each other's output.
    output_path = f"music_{time.time_ns()}.wav"
    sf.write(output_path, audio_np, sample_rate)
    return output_path


def _save_temp_jpeg(image, exif_bytes=None):
    """Save *image* to a unique temporary JPEG and return its path.

    If *exif_bytes* is given it is embedded in the file so that
    extract_metadata() can read the original tags. Should embedding fail
    (e.g. Pillow rejects the EXIF payload), the image is saved without it.
    The caller owns the file and must delete it.
    """
    fd, path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        if isinstance(exif_bytes, bytes) and exif_bytes:
            try:
                image.save(path, "JPEG", exif=exif_bytes)
                return path
            except Exception as e:
                print(f"Could not embed EXIF, saving without it: {e}")
        image.save(path, "JPEG")
        return path
    except Exception:
        _remove_quietly(path)
        raise


def _remove_quietly(path):
    """Delete *path* if it exists, ignoring filesystem errors."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass


# ========================================
# Main Pipeline
# ========================================
def classify_and_generate_music(image):
    """Run the full image -> setting -> moods -> prompt -> music pipeline.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input.

    Returns:
        5-tuple matching the Gradio outputs: predicted-setting text,
        detected-moods text, fusion-prompt text, path to the generated WAV
        (or None if generation failed), and an info note that is non-empty
        when fallback time/location was used. If the image cannot be
        prepared, the first element carries the error message and the rest
        are empty.

    Every stage after image preparation is best-effort: a failure is logged
    and replaced by a placeholder so later stages still run.
    """
    image_path = None
    try:
        needs_rgb = image.mode != 'RGB'
        # Pillow only writes EXIF into a JPEG when it is passed to save();
        # without this the re-saved file has no tags and metadata extraction
        # always falls back to the defaults. Only effective if the PIL image
        # handed over by Gradio still carries its EXIF bytes in ``info``.
        exif_bytes = image.info.get("exif")
        if needs_rgb:
            image = image.convert("RGB")
        # Unique per-request file instead of a shared "temp.jpg".
        image_path = _save_temp_jpeg(image, exif_bytes)
    except Exception as e:
        return (
            f"Error processing image: {str(e)}",
            "",
            "",
            None,
            ""
        )

    # Classification
    try:
        model_input = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(model_input)
        pred_idx = int(logits.argmax(dim=1).item())
        pred_class = class_labels[pred_idx] if pred_idx < len(class_labels) else "unknown"
    except Exception as e:
        print(f"Classification error: {e}")
        pred_class = "unknown"

    # Metadata + Mood
    try:
        metadata = extract_metadata(image_path)
        moods = classify_mood(metadata)
        used_fallback = metadata.get("used_fallback", False)
    except Exception as e:
        print(f"Metadata/mood error: {e}")
        moods = ["unknown weather", "unknown time"]
        used_fallback = True
    finally:
        # The temp JPEG is only needed for EXIF parsing.
        _remove_quietly(image_path)

    # Fusion
    try:
        prompt = fusion(pred_class, moods)
    except Exception as e:
        print(f"Fusion error: {e}")
        prompt = f"Music for a {pred_class} with moods: {', '.join(moods)}"

    # Music Generation
    try:
        audio_output = generate_music(prompt)
    except Exception as e:
        print(f"Music generation error: {e}")
        audio_output = None

    # Info note
    info_note = ""
    if used_fallback:
        info_note = "No EXIF metadata found. Using current Pittsburgh time and CMU location for context."

    return (
        f"Predicted setting: {pred_class}",
        f"Moods detected: {', '.join(moods)}",
        f"Fusion prompt: {prompt}",
        audio_output,
        info_note
    )


# ========================================
# Gradio Interface
# ========================================
demo = gr.Interface(
    fn=classify_and_generate_music,
    inputs=gr.Image(type="pil", label="Upload image (JPEG, PNG, HEIC supported)"),
    outputs=[
        gr.Textbox(label="Predicted Setting"),
        gr.Textbox(label="Detected Moods"),
        gr.Textbox(label="Fusion Prompt"),
        gr.Audio(type="filepath", label="Generated Music (2-5 min on CPU)"),
        gr.Textbox(label="Info", show_label=False)
    ],
    title="Image → Mood → Music Generator",
    description=(
        "Upload an image and this app will:\n"
        "1. Classify the scene (outdoor, gym, library, cafe)\n"
        "2. Extract EXIF metadata or use current Pittsburgh time/CMU location\n"
        "3. Query historical weather data\n"
        "4. Generate a contextual music prompt\n"
        "5. Create ambient music using MusicGen\n\n"
        "**Supported formats:** JPEG, PNG, HEIC\n\n"
        "Music generation takes 2-5 minutes on free CPU tier. Try the examples below!"
    ),
    examples=[
        ["examples/example1.jpg"],
        ["examples/example2.jpg"],
        ["examples/example3.jpg"]
    ],
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch()