"""Generate groove MIDI files from an input MIDI file with a pretrained seq2seq model."""

from pathlib import Path
import random
from string import ascii_letters

from miditok import PerTok, TokSequence
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Directory that receives the generated .mid files (created on demand).
generated_path = Path("generated")

midi_tokenizer = PerTok(params="tokenizer2.json")
_ = midi_tokenizer._create_base_vocabulary()  # workaround, otherwise the preprocessing will fail

# Define which model we want, download right tokenizer
checkpoint = "JannikAhlers/groove_midi_2"
t5_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)


def generate_groove(midi_filename: str, count: int = 1) -> list[str]:
    """Sample ``count`` MIDI files from the model, conditioned on an input MIDI file.

    The input file is tokenized with the MIDI tokenizer, its first 512 tokens
    are joined into a space-separated string, and that string is fed to the
    seq2seq model. Each sampled output is converted back to MIDI and written
    to ``generated_path`` under a random 16-letter file name.

    Args:
        midi_filename: Path of the MIDI file used as the model input.
        count: Number of files to generate. Values <= 0 generate nothing.

    Returns:
        The paths (as strings) of the written MIDI files, in generation order.
    """
    midi_tokens = midi_tokenizer(midi_filename)[0]
    # limit length to 512, because the tokenizer can't handle longer inputs
    tokens_string = " ".join(midi_tokens.tokens[:512])
    inputs = t5_tokenizer(tokens_string, return_tensors="pt").input_ids

    # The output directory is not guaranteed to exist; create it so that
    # writing the MIDI files below does not fail on a fresh checkout.
    generated_path.mkdir(parents=True, exist_ok=True)

    out_filenames = []
    for _ in range(count):
        outputs = model.generate(
            inputs,
            max_new_tokens=1000,
            do_sample=True,
            top_k=30,
            top_p=0.95,
        )
        decoded = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_tokens = decoded.split(" ")
        generated_tokens.pop()  # delete the last event because it might be chopped off mid-token
        generated_seq = TokSequence(tokens=generated_tokens)

        # save file
        # NOTE(review): programs=[(10, True)] presumably means
        # (program, is_drum) -- confirm against the miditok documentation.
        generated_miditok = midi_tokenizer([generated_seq], programs=[(10, True)])
        # Build the random stem outside the f-string: reusing the same quote
        # character inside an f-string is a SyntaxError before Python 3.12.
        random_stem = "".join(random.choices(ascii_letters, k=16))
        out_path = generated_path / f"{random_stem}.mid"
        generated_miditok.dump_midi(out_path)
        out_filenames.append(str(out_path))

    return out_filenames