#============================================================================================ # https://huggingface.co/spaces/projectlosangeles/Orpheus-Karaoke #============================================================================================ print('=' * 70) print('Orpheus Karaoke Gradio App') print('=' * 70) print('Loading core Orpheus Karaoke modules...') import os import copy import time as reqtime import datetime from pytz import timezone print('=' * 70) print('Loading main Orpheus Karaoke modules...') os.environ['USE_FLASH_ATTENTION'] = '1' import torch torch.set_float32_matmul_precision('high') torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn torch.backends.cuda.enable_flash_sdp(True) from huggingface_hub import hf_hub_download import TMIDIX from midi_to_colab_audio import midi_to_colab_audio from x_transformer_2_3_1 import * import random from transformers import AutoModelForCausalLM, AutoTokenizer import tqdm print('=' * 70) print('Loading aux Orpheus Karaoke modules...') import matplotlib.pyplot as plt import gradio as gr import spaces print('=' * 70) print('PyTorch version:', torch.__version__) print('=' * 70) print('Done!') print('Enjoy! :)') print('=' * 70) #================================================================================== MODEL_CHECKPOINT = 'Orpheus_Music_Transformer_Karaoke_Fine_Tuned_Model_2068_steps_0.9833_loss_0.7328_acc.pth' SOUNDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2' #================================================================================== print('=' * 70) print('Instantiating Orpehus model...') device_type = 'cuda' dtype = 'bfloat16' ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) SEQ_LEN = 1668 PAD_IDX = 18819 model = TransformerWrapper(num_tokens = PAD_IDX+1, max_seq_len = SEQ_LEN, attn_layers = Decoder(dim = 2048, depth = 8, heads = 32, rotary_pos_emb = True, attn_flash = True ) ) model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX) print('=' * 70) print('Loading model checkpoint...') model_checkpoint = hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer', filename=MODEL_CHECKPOINT) model.load_state_dict(torch.load(model_checkpoint, map_location=device_type, weights_only=True)) model = torch.compile(model, mode='max-autotune') model.to(device_type) model.eval() print('=' * 70) print('Done!') print('=' * 70) print('Model will use', dtype, 'precision...') print('=' * 70) #================================================================================== print('=' * 70) print('Instantiating Karaoke Lyrics model...') model_path = "asigalov61/Karaoke-Lyrics-Qwen3-0.6B" lyr_model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype="auto", device_map="auto" ) lyr_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) print('=' * 70) print('Done!') print('=' * 70) #================================================================================== def load_midi(input_midi): raw_score = TMIDIX.midi2single_track_ms_score(input_midi) escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True) if escore_notes and escore_notes[0]: escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], sort_drums_last=True) escore_notes = TMIDIX.remove_duplicate_pitches_from_escore_notes(escore_notes) escore_notes = TMIDIX.fix_escore_notes_durations(escore_notes, min_notes_gap=0) #======================================================= # FINAL PROCESSING #======================================================= melody_chords = [] chord = [18816, 0] #======================================================= # MAIN PROCESSING CYCLE #======================================================= pe = escore_notes[0] first_chord = True for i, e in enumerate(escore_notes): delta_time = max(0, min(255, e[1] - pe[1])) if delta_time != 0: if first_chord: # Durations dur = 255 # Patches pat = 128 # Pitches ptc = 127 # Velocities # Calculating octo-velocity vel = 127 velocity = round(vel / 15)-1 #======================================================= # FINAL NOTE SEQ #======================================================= # Writing final note pat_ptc = (128 * pat) + ptc dur_vel = (8 * dur) + velocity chord.extend([pat_ptc+256, dur_vel+16768]) # 18816 first_chord = False #=============================================================================== melody_chords.append(chord) chord = [] chord.append(delta_time) #======================================================= # Durations dur = max(1, min(255, e[2])) # Patches pat = max(0, min(128, e[6])) # Pitches ptc = max(1, min(127, e[4])) # Velocities # Calculating octo-velocity vel = max(8, min(127, e[5])) velocity = round(vel / 15)-1 #======================================================= # FINAL NOTE SEQ #======================================================= # Writing final note pat_ptc = (128 * pat) + ptc dur_vel = (8 * dur) + velocity chord.extend([pat_ptc+256, dur_vel+16768]) # 18816 #===================================================================================== pe = e print('Done!') print('=' * 70) print('Score hss', len(melody_chords), 'chords') print('=' * 70) return melody_chords else: return None #================================================================================== @spaces.GPU def Generate_Karaoke(input_midi, words_generation_bias, drum_marker_pitch, generate_lyrics, model_temperature, model_sampling_top_p ): #=============================================================================== print('=' * 70) print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT))) start_time = reqtime.time() print('=' * 70) print('=' * 70) print('Requested settings:') print('=' * 70) if input_midi is not None: fn = os.path.basename(input_midi) fn1 = fn.split('.')[0] print('Input MIDI file name:', fn) print('Words generation bias', words_generation_bias) print('Drum marker pitch:', drum_marker_pitch) print('Fill-in lyrics:', generate_lyrics) print('Model temperature:', model_temperature) print('Model top k:', model_sampling_top_p) print('=' * 70) #================================================================== def generate_lyrics(chords): inp_seq = [] for i, c in enumerate(tqdm.tqdm(chords)): inp_seq.extend(c) x = torch.LongTensor(inp_seq).cuda() with ctx: out = model.generate_biased(x, 1, temperature=model_temperature, filter_logits_fn=top_p, filter_kwargs={'thres': model_sampling_top_p}, logit_bias={16767: words_generation_bias}, return_prime=False, eos_token=18818, verbose=False ) y = out.tolist()[0] if y == 16767: inp_seq.append(16767) x = torch.LongTensor(inp_seq).cuda() with ctx: out = model.generate(x, 1, temperature=model_temperature, filter_logits_fn=top_p, filter_kwargs={'thres': model_sampling_top_p}, return_prime=False, eos_token=18818, verbose=False ) y = out.tolist()[0] inp_seq.append(y) return inp_seq #================================================================== def generate_lyrics_words(words_lens_list): prompt = 'Lyrics template: ' + ' '.join(['_' * c for c in words_lens_list]) messages = [ {"role": "system", "content": "Please fill in the words in the following song lyrics template and guess song title. Thank you."}, {"role": "user", "content": prompt} ] chat_text = lyr_tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False ) model_inputs = lyr_tokenizer([chat_text], return_tensors="pt").to(lyr_model.device) num_batches = 1 generated_ids = lyr_model.generate( **model_inputs, max_new_tokens=1024, do_sample=True, temperature=model_temperature, top_p=model_sampling_top_p, num_return_sequences=num_batches, repetition_penalty=1.05 ) output_tokens = [ output_ids[len(input_ids):] for input_ids, output_ids in zip([model_inputs.input_ids] * num_batches, generated_ids) ] responses = lyr_tokenizer.batch_decode(output_tokens, skip_special_tokens=True) final_responses = [] for r in responses: final_responses.append(r.split('\n\n')[-1].strip()) title, lyrics = final_responses[0].splitlines() return title, lyrics #================================================================== if input_midi is not None: print('Loading MIDI...') chords = load_midi(input_midi.name) if chords is not None: print('Sample score chord', chords[0]) #================================================================== print('=' * 70) print('Generating Karaoke...') #================================================================== output_seq = generate_lyrics(chords) #================================================================== words_counts_list = [] pitch = 60 patch = 0 for ss in output_seq: if 256 <= ss < 16768: patch = (ss-256) // 128 pitch = (ss-256) % 128 if 16768 <= ss < 18816: dur = ((ss-16768) // 8) * 16 if pitch == 127 and patch == 128 and dur // 16 < 248: words_counts_list.append(max(1, min(15, dur // 16 // 8))) elif pitch == 127 and patch == 128 and dur // 16 >= 248: continue #================================================================== print('=' * 70) print('Done!') print('=' * 70) print('Output seq len', len(output_seq)) print('=' * 70) #=============================================================================== print('Rendering results...') print('=' * 70) #=============================================================================== def ntw(n): return ["one","two","three","four","five","six","seven","eight", "nine","ten","eleven","twelve","thirteen","fourteen","fifteen"][n-1] #=============================================================================== words = [] if generate_lyrics: print('Generating lyrics words...') gen_title, gen_lyrics = generate_lyrics_words(words_counts_list) gen_text = gen_title.title() + '\n\n' gen_text += gen_lyrics words = gen_lyrics.split(' ')[1:] print('Done!') print('=' * 70) #================================================================================== song_f = [] text_f = 'Lyrics template: ' time = 0 dur = 1 vel = 90 pitch = 60 channel = 0 patch = 0 patches = [-1] * 16 channels = [0] * 16 channels[9] = 1 widx = 0 for ss in output_seq: if 0 <= ss < 256: time += ss * 16 if 256 <= ss < 16768: patch = (ss-256) // 128 if patch < 128: if patch not in patches: if 0 in channels: cha = channels.index(0) channels[cha] = 1 else: cha = 15 patches[cha] = patch channel = patches.index(patch) else: channel = patches.index(patch) if patch == 128: channel = 9 pitch = (ss-256) % 128 if 16768 <= ss < 18816: dur = ((ss-16768) // 8) * 16 vel = (((ss-16768) % 8)+1) * 15 if pitch == 127 and patch == 128 and dur // 16 < 248: if generate_lyrics: if widx < len(words): song_f.append(['text_event', time, words[widx]]) if drum_marker_pitch > 26: song_f.append(['note', time, 128, 9, drum_marker_pitch, 127, 128]) widx += 1 else: song_f.append(['text_event', time, ntw(max(1, min(15, dur // 16 // 8)))]) if drum_marker_pitch > 26: song_f.append(['note', time, 128, 9, drum_marker_pitch, 127, 128]) text_f += ntw(max(1, min(15, dur // 16 // 8))) + ' ' elif pitch == 127 and patch == 128 and dur // 16 >= 248: continue else: song_f.append(['note', time, dur, channel, pitch, vel, patch]) #================================================================================== text_f = text_f.strip() #================================================================================== if generate_lyrics: text_f = gen_text #================================================================================== output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f) fn1 = "Orpheus-Karaoke-Composition" detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score, output_signature = 'Orpheus Karaoke', output_file_name = fn1, track_name='Project Los Angeles', list_of_MIDI_patches=patches ) new_fn = fn1+'.mid' audio = midi_to_colab_audio(new_fn, soundfont_path=SOUNDFONT_PATH, sample_rate=16000, output_for_gradio=True ) print('Done!') print('=' * 70) #======================================================== output_lyrics = text_f output_midi = str(new_fn) output_audio = (16000, audio) output_plot = TMIDIX.plot_ms_SONG(output_score, plot_title=output_midi, return_plt=True ) print('Output lyrics:', output_lyrics[:128]) print('=' * 70) #======================================================== else: return None, None, None, None print('-' * 70) print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT))) print('-' * 70) print('Req execution time:', (reqtime.time() - start_time), 'sec') return output_audio, output_plot, output_lyrics, output_midi else: return None, None, None, None #================================================================================== PDT = timezone('US/Pacific') print('=' * 70) print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT))) print('=' * 70) #================================================================================== with gr.Blocks() as demo: #================================================================================== gr.Markdown("