| from audiocraft.models import MusicGen |
| import streamlit as st |
| import os |
| import torch |
| import torchaudio |
| from io import BytesIO |
|
|
|
|
|
|
# Page metadata (browser tab icon and title). This call is kept ahead of every
# other st.* command in the script.
st.set_page_config(
    page_icon=":musical_note:",
    page_title="Music Gen",
)

# Inject the app's custom stylesheet. The encoding is given explicitly so the
# CSS is decoded the same way on every platform instead of depending on the
# locale's default encoding.
with open("style.css", encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)
|
|
@st.cache_resource
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    The function is wrapped in ``st.cache_resource`` so the loaded model
    object is shared across Streamlit reruns instead of being loaded again.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")
|
|
def generate_music_tensors(description, duration: int):
    """Generate audio for a single text prompt using the cached model.

    Args:
        description: Text prompt handed to the model as the only entry of
            ``descriptions``.
        duration: Value forwarded to ``set_generation_params(duration=...)``.

    Returns:
        Element 0 of whatever ``model.generate`` returns. With
        ``return_tokens=True`` this is presumably the waveform tensor (the
        tokens being the other element) -- TODO confirm against audiocraft.
    """
    # Echo the request to stdout for debugging.
    for label, value in (("Description:", description), ("Duration:", duration)):
        print(label, value)

    musicgen = load_model()

    # Sampling configuration: top-k sampling over 250 candidates.
    sampling_options = {
        "use_sampling": True,
        "top_k": 250,
        "duration": duration,
    }
    musicgen.set_generation_params(**sampling_options)

    generated = musicgen.generate(
        descriptions=[description],
        progress=True,
        return_tokens=True,
    )
    return generated[0]
|
|
def save_audio_to_bytes(samples: torch.Tensor, sample_rate: int = 32000) -> BytesIO:
    """Encode the first clip in ``samples`` as WAV into an in-memory buffer.

    Args:
        samples: Audio tensor with 2 or 3 dimensions. A 2-D tensor is treated
            as a single clip; for a 3-D tensor only index 0 along the first
            dimension is written.
        sample_rate: Sample rate passed to ``torchaudio.save``. Defaults to
            32000, the value this function previously hard-coded.

    Returns:
        A ``BytesIO`` containing the WAV data, rewound to position 0.

    Raises:
        ValueError: If ``samples`` does not have 2 or 3 dimensions.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if samples.dim() not in (2, 3):
        raise ValueError(
            f"samples must be a 2-D or 3-D tensor, got {samples.dim()} dimension(s)"
        )

    # Detach from autograd and move to CPU before handing off to torchaudio.
    samples = samples.detach().cpu()

    # Promote a single clip to a batch of one so indexing below is uniform.
    if samples.dim() == 2:
        samples = samples[None, ...]

    audio_buffer = BytesIO()
    torchaudio.save(audio_buffer, samples[0], sample_rate=sample_rate, format="wav")
    # Rewind so callers can read the buffer from the beginning.
    audio_buffer.seek(0)
    return audio_buffer
|
|
|
|
|
|
# HTML/CSS snippet for a background video: a fixed, full-viewport container
# placed behind the page (z-index -1) holding a muted, looping, autoplaying
# <video> that is centred with a translate(-50%, -50%) transform.
# NOTE(review): the <source> URL looks like a ScreenPal "watch" page rather
# than a direct .mp4 file -- confirm it actually serves video/mp4 content.
video_background = """
<style>
.video-container {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
overflow: hidden;
z-index: -1;
}
video {
position: absolute;
top: 50%;
left: 50%;
min-width: 100%;
min-height: 100%;
width: auto;
height: auto;
z-index: -1;
transform: translate(-50%, -50%);
background-size: cover;
}
</style>

<div class="video-container">
<video autoplay loop muted>
<source src="https://go.screenpal.com/watch/cZX2oynVXxQ" type="video/mp4">
</video>
</div>
"""


# unsafe_allow_html=True is passed so the raw <style>/<div> markup is emitted
# as HTML rather than shown as text.
st.markdown(video_background, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
def main():
    """Render the page and generate audio for the entered description.

    Shows the title, an explanation expander, a description text area and a
    duration slider (2-20 seconds, default 5). Once a non-blank description
    is present, the request is echoed as JSON, music is generated, and the
    result is offered through an audio player and a download button.

    This is the single definition of ``main``; an earlier duplicate that only
    rendered the inputs was shadowed by this one and has been removed.
    """
    st.title("Your Music")

    with st.expander("See Explanation"):
        st.write("This app uses Meta's Audiocraft Music Gen model to generate audio based on your description.")

    text_area = st.text_area("Enter description")
    time_slider = st.slider("Select time duration (seconds)", 2, 20, 5)

    # Nothing to generate until the user has typed a real prompt. The slider
    # minimum is 2, so the duration is always truthy and needs no check.
    if not text_area.strip():
        return

    st.json(
        {
            "Description": text_area,
            "Selected duration": time_slider,
        }
    )
    st.write("We will be back with your music... please enjoy doing the rest of your tasks while we come back in some time :)")

    st.subheader("Generated Music")
    music_tensors = generate_music_tensors(text_area, time_slider)

    # Encode once; the same in-memory WAV feeds both the player and the download.
    audio_buffer = save_audio_to_bytes(music_tensors)

    st.audio(audio_buffer, format="audio/wav")

    st.download_button(
        label="Download Audio",
        data=audio_buffer,
        file_name="generated_music.wav",
        mime="audio/wav",
    )
|
|
# Build the page only when this file is executed as the entry script.
if __name__ == "__main__":
    main()
|
|