import torch
import torchaudio
import gradio as gr
import spaces  # Enables ZeroGPU on Hugging Face
from demucs import pretrained
from demucs.apply import apply_model
from pyharp import *
from audiotools import AudioSignal

# Available Demucs models
DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]

# Maps the label shown in the UI to the index of the stem in the model output.
# The indices assume the source order drums, bass, other, vocals.
# NOTE(review): confirm against `model.sources` if other models are added.
STEM_CHOICES = {
    "Vocals": 3,
    "Drums": 0,
    "Bass": 1,
    "Other": 2,
    "Instrumental (No Vocals)": "instrumental",
}


@spaces.GPU
def separate_stem(audio_file_path: str, model_name: str, stem_choice: str) -> AudioSignal:
    """Separate an audio file into the chosen stem using a Demucs model.

    Mono input is duplicated to two channels for the model and the result is
    averaged back to one channel. Audio whose sample rate differs from the
    model's is resampled for separation and resampled back afterwards, so the
    returned signal has the same sample rate as the input file.

    Args:
        audio_file_path: Path of the audio file to separate.
        model_name: Name of a pretrained Demucs model (see ``DEMUX_MODELS``).
        stem_choice: One of the keys of ``STEM_CHOICES``.

    Returns:
        The selected stem as a float32 ``AudioSignal``.

    Raises:
        KeyError: If ``stem_choice`` is not a key of ``STEM_CHOICES``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load Demucs model
    model = pretrained.get_model(model_name)
    model.to(device)
    model.eval()

    # Load the audio file
    waveform, sr = torchaudio.load(audio_file_path)

    # Demucs is fed two channels here, so duplicate a mono channel.
    is_mono = waveform.shape[0] == 1
    if is_mono:
        waveform = waveform.repeat(2, 1)

    # The model must see audio at its own sample rate; feeding any other rate
    # degrades the separation.
    model_sr = model.samplerate
    if sr != model_sr:
        waveform = torchaudio.functional.resample(waveform, sr, model_sr)

    # Apply Demucs model. `device` must be passed explicitly: without it
    # apply_model runs on the device of the input tensor (CPU here) and moves
    # the model there, so the GPU would never be used.
    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.unsqueeze(0),
            overlap=0.2,
            shifts=1,
            split=True,
            device=device,
        )

    # stems shape: (stems, channels, samples) after dropping the batch axis
    stems = stems_batch[0]
    print(f"Model '{model_name}' extracted stems shape: {stems.shape}")

    if stem_choice == "Instrumental (No Vocals)":
        stem = stems[0] + stems[1] + stems[2]  # Drums + Bass + Other
    else:
        stem_index = STEM_CHOICES[stem_choice]
        stem = stems[stem_index]

    # Hand the audio back at the sample rate of the input file.
    if sr != model_sr:
        stem = torchaudio.functional.resample(stem, model_sr, sr)

    # Convert back to mono if the input was originally mono
    if is_mono:
        stem = stem.mean(dim=0, keepdim=True)

    # Convert to AudioSignal with float32 dtype
    stem_signal = AudioSignal(stem.cpu().numpy().astype('float32'), sample_rate=sr)
    return stem_signal


def process_fn_stem(audio_file_path: str, demucs_model: str, stem_choice: str):
    """PyHARP process function: separate one stem and save it as a .wav file.

    Args:
        audio_file_path: Path of the input audio file.
        demucs_model: Name of the Demucs model selected in the UI.
        stem_choice: Stem label selected in the UI (a key of ``STEM_CHOICES``).

    Returns:
        A tuple of the value returned by ``save_audio`` for the written stem
        and an empty ``LabelList``.
    """
    stem_signal = separate_stem(audio_file_path, model_name=demucs_model, stem_choice=stem_choice)
    # The file name is derived from the label, e.g. "Vocals" -> "vocals.wav".
    stem_path = save_audio(stem_signal, f"{stem_choice.lower().replace(' ', '_')}.wav")
    return stem_path, LabelList(labels=[])


# Define the model card
model_card = ModelCard(
    name="Demucs Stem Separator",
    description="Uses Demucs to separate a music track into a selected stem.",
    author="Alexandre Défossez, Nicolas Usunier, Léon Bottou, Francis Bach",
    tags=["demucs", "source-separation", "pyharp", "stems"]
)

# Build Gradio interface with dropdowns for model and stem selection
with gr.Blocks() as demo:
    gr.LoginButton()
    dropdown_model = gr.Dropdown(
        label="Select Demucs Model",
        choices=DEMUX_MODELS,
        value="mdx_extra_q"
    )
    dropdown_stem = gr.Dropdown(
        label="Select Stem to Separate",
        choices=list(STEM_CHOICES.keys()),
        value="Vocals"
    )
    app = build_endpoint(
        model_card=model_card,
        components=[dropdown_model, dropdown_stem],
        process_fn=process_fn_stem
    )

demo.queue()
demo.launch(show_error=True)