{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "370abc32", "metadata": { "id": "370abc32" }, "outputs": [], "source": [ "!unzip \"/content/songs.zip\" -d \"/content/uploaded_midis/\"" ] }, { "cell_type": "markdown", "id": "06f323a8", "metadata": { "id": "06f323a8" }, "source": [ "# Colab MIDI Continuation (Variant E 100M)\n", "\n", "This notebook is self-contained for Colab and targets the Pulse88 Variant E checkpoint family.\n", "\n", "It downloads model files and tokenizer from Hugging Face, then continues one MIDI file or a batch of MIDI files.\n", "\n", "What you provide:\n", "- a Hugging Face repo id or URL (defaults to https://huggingface.co/Chickaboo/Pulse88-E-85M-Alpha)\n", "- HF_TOKEN (or HUGGINGFACE_HUB_TOKEN) when using a private repo\n", "- INPUT_MIDI_SOURCE set to either:\n", " - a single MIDI file path (.mid or .midi), or\n", " - a folder containing MIDI files\n", "\n", "Input examples:\n", "- single file: /content/uploaded_midis/song.mid\n", "- folder: /content/uploaded_midis\n", "\n", "What this notebook does:\n", "1. Downloads the model bundle from Hugging Face (supports public and private repos).\n", "2. Loads Variant E (GDN plus sparse attention) and tokenizer.\n", "3. Resolves either single-file or folder input into a sorted MIDI list.\n", "4. 
# Environment bootstrap cell: install any missing Python packages into the
# running kernel, enable TF32 fast paths when a GPU is present, and create the
# working directory tree used by the rest of the notebook.
from pathlib import Path
import importlib.util
import subprocess
import sys

# Import name -> pip requirement. The notebook cannot run without these.
REQUIRED_MODULES = {
    "pretty_midi": "pretty_midi>=0.2.10",
    "matplotlib": "matplotlib>=3.7.0",
    "safetensors": "safetensors>=0.4.0",
    "huggingface_hub": "huggingface_hub>=0.24.0",
}
# Import name -> pip requirement. Installed on a best-effort basis only.
OPTIONAL_MODULES = {
    "fla": "flash-linear-attention",
}


def _missing_specs(modules):
    """Return the pip specs from `modules` whose import name cannot be resolved."""
    return [
        spec
        for module_name, spec in modules.items()
        if importlib.util.find_spec(module_name) is None
    ]


def _pip_install(specs, *, check):
    """Run `pip install` for `specs` with the kernel's interpreter.

    `check=True` raises `subprocess.CalledProcessError` on failure; `check=False`
    lets the caller continue. Import finder caches are invalidated afterwards so
    that a following `find_spec` call can see freshly installed packages.
    """
    result = subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "--quiet",
            "--disable-pip-version-check",
            *specs,
        ],
        check=check,
    )
    # Without this, find_spec() may still answer from a stale directory listing
    # and report a package as missing right after it was installed.
    importlib.invalidate_caches()
    return result


missing_required = _missing_specs(REQUIRED_MODULES)
if missing_required:
    print(f"Installing required package(s): {missing_required}")
    _pip_install(missing_required, check=True)

missing_optional = _missing_specs(OPTIONAL_MODULES)
if missing_optional:
    print(f"Attempting optional package install for GDN kernels: {missing_optional}")
    _pip_install(missing_optional, check=False)
    still_missing_optional = [
        module_name
        for module_name in OPTIONAL_MODULES
        if importlib.util.find_spec(module_name) is None
    ]
    if still_missing_optional:
        print(
            "Optional GDN kernel package is still unavailable. "
            "The notebook will fall back to an approximation if strict GDN kernels cannot be loaded."
        )

import torch

if torch.cuda.is_available():
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception as exc:
        # TF32 is only a speed-up; report the problem but never block the run.
        print(f"Could not enable TF32 fast paths (continuing without them): {exc}")

PROJECT_ROOT = Path("/content/pulse88_colab_100m")
ASSET_DIR = PROJECT_ROOT / "assets"
MODEL_DIR = PROJECT_ROOT / "models"
TOKENIZER_DIR = PROJECT_ROOT / "tokenizer"
OUTPUT_DIR = PROJECT_ROOT / "outputs"
SEED_DIR = PROJECT_ROOT / "seed"
for path in [PROJECT_ROOT, ASSET_DIR, MODEL_DIR, TOKENIZER_DIR, OUTPUT_DIR, SEED_DIR]:
    path.mkdir(parents=True, exist_ok=True)

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Project root: {PROJECT_ROOT}")
\"\n", " f\"Current value: {INPUT_MIDI_SOURCE}\"\n", " )\n", "\n", "if INPUT_MIDI_SOURCE.is_file():\n", " if INPUT_MIDI_SOURCE.suffix.lower() not in MIDI_EXTENSIONS:\n", " raise ValueError(f\"INPUT_MIDI_SOURCE is not a MIDI file: {INPUT_MIDI_SOURCE}\")\n", " MIDI_INPUT_FILES = [INPUT_MIDI_SOURCE]\n", " input_mode = \"file\"\n", "elif INPUT_MIDI_SOURCE.is_dir():\n", " input_mode = \"folder\"\n", " if RECURSIVE_SEARCH:\n", " candidates = [\n", " path\n", " for path in INPUT_MIDI_SOURCE.rglob(\"*\")\n", " if path.is_file() and path.suffix.lower() in MIDI_EXTENSIONS\n", " ]\n", " MIDI_INPUT_FILES = sorted(\n", " candidates,\n", " key=lambda p: str(p.relative_to(INPUT_MIDI_SOURCE)).lower(),\n", " )\n", " else:\n", " candidates = [\n", " path\n", " for path in INPUT_MIDI_SOURCE.iterdir()\n", " if path.is_file() and path.suffix.lower() in MIDI_EXTENSIONS\n", " ]\n", " MIDI_INPUT_FILES = sorted(candidates, key=lambda p: p.name.lower())\n", "else:\n", " raise ValueError(f\"Unsupported input path type: {INPUT_MIDI_SOURCE}\")\n", "\n", "if not MIDI_INPUT_FILES:\n", " raise FileNotFoundError(\n", " f\"No MIDI files found at {INPUT_MIDI_SOURCE}. Expected .mid or .midi files.\"\n", " )\n", "\n", "print(f\"Input mode: {input_mode}\")\n", "print(f\"Input source: {INPUT_MIDI_SOURCE}\")\n", "print(f\"Found {len(MIDI_INPUT_FILES)} MIDI file(s).\")\n", "for idx, path in enumerate(MIDI_INPUT_FILES[:10], start=1):\n", " print(f\" {idx:02d}. {path.name}\")\n", "if len(MIDI_INPUT_FILES) > 10:\n", " print(f\" ... 
({len(MIDI_INPUT_FILES) - 10} more)\")" ] }, { "cell_type": "code", "execution_count": null, "id": "e7073ba0", "metadata": { "id": "e7073ba0" }, "outputs": [], "source": [ "import os\n", "import shutil\n", "from getpass import getpass\n", "from pathlib import Path\n", "\n", "import torch\n", "from huggingface_hub import snapshot_download\n", "\n", "\n", "def _normalize_hf_repo_id(value: str) -> str:\n", " raw = str(value or \"\").strip()\n", " if not raw:\n", " return \"\"\n", "\n", " if \"huggingface.co/\" in raw:\n", " raw = raw.split(\"huggingface.co/\", 1)[1]\n", "\n", " raw = raw.split(\"?\", 1)[0].split(\"#\", 1)[0].strip(\"/\")\n", " parts = [p for p in raw.split(\"/\") if p]\n", " if not parts:\n", " return \"\"\n", "\n", " if parts[0] in {\"models\", \"datasets\", \"spaces\"} and len(parts) >= 3:\n", " return \"/\".join(parts[1:3])\n", " if len(parts) >= 2:\n", " return \"/\".join(parts[:2])\n", " return raw\n", "\n", "\n", "def _normalize_hf_checkpoint_subdir(value: str) -> str:\n", " raw = str(value or \"\").strip()\n", " if not raw:\n", " return \"\"\n", "\n", " # Accept full Hugging Face tree URLs, e.g.\n", " # https://huggingface.co///tree/main/step-500\n", " if \"huggingface.co/\" in raw and \"/tree/\" in raw:\n", " suffix = raw.split(\"/tree/\", 1)[1]\n", " parts = [p for p in suffix.split(\"/\") if p]\n", " if len(parts) >= 2:\n", " raw = \"/\".join(parts[1:])\n", "\n", " raw = raw.split(\"?\", 1)[0].split(\"#\", 1)[0].strip(\"/\")\n", " if not raw:\n", " return \"\"\n", "\n", " # If someone pastes just the final path token from a tree URL context,\n", " # this still works (e.g., \"step-500\").\n", " return str(raw)\n", "\n", "\n", "HF_REPO_ID_RAW = os.environ.get(\n", " \"HF_REPO_ID\",\n", " \"https://huggingface.co/Chickaboo/Pulse88-E-85M-Alpha\",\n", ").strip()\n", "HF_REPO_ID = _normalize_hf_repo_id(HF_REPO_ID_RAW)\n", "HF_REVISION = os.environ.get(\"HF_REVISION\", \"\").strip()\n", "HF_CHECKPOINT_DIR_RAW = os.environ.get(\"HF_CHECKPOINT_DIR\", 
\"step-10500\").strip()\n", "HF_CHECKPOINT_DIR = _normalize_hf_checkpoint_subdir(HF_CHECKPOINT_DIR_RAW)\n", "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\").strip() or os.environ.get(\"HUGGINGFACE_HUB_TOKEN\", \"\").strip()\n", "if not HF_TOKEN:\n", " try:\n", " HF_TOKEN = getpass(\"Enter HF token for private repo (leave blank for public): \" ).strip()\n", " except Exception:\n", " HF_TOKEN = \"\"\n", "\n", "if not HF_REPO_ID:\n", " raise ValueError(\"Set HF_REPO_ID to a valid Hugging Face repo id or URL.\")\n", "\n", "allow_patterns = [\"custom_tokenizer.json\", \"tokenizer.json\"]\n", "if HF_CHECKPOINT_DIR:\n", " allow_patterns.append(f\"{HF_CHECKPOINT_DIR}/**\")\n", "else:\n", " allow_patterns.append(\"**\")\n", "\n", "snapshot_kwargs = {\n", " \"repo_id\": HF_REPO_ID,\n", " \"local_dir\": str(ASSET_DIR),\n", " \"local_dir_use_symlinks\": False,\n", " \"allow_patterns\": allow_patterns,\n", "}\n", "if HF_REVISION:\n", " snapshot_kwargs[\"revision\"] = HF_REVISION\n", "if HF_TOKEN:\n", " snapshot_kwargs[\"token\"] = HF_TOKEN\n", "\n", "print(f\"HF repo: {HF_REPO_ID}\")\n", "print(f\"HF token provided: {'yes' if HF_TOKEN else 'no'}\")\n", "if HF_REVISION:\n", " print(f\"HF revision: {HF_REVISION}\")\n", "print(f\"HF checkpoint dir override: {HF_CHECKPOINT_DIR or '[auto]'}\")\n", "\n", "try:\n", " ASSET_ROOT = Path(snapshot_download(**snapshot_kwargs))\n", "except Exception as exc:\n", " if not HF_TOKEN:\n", " raise RuntimeError(\n", " \"Hugging Face download failed. 
If the repo is private, set HF_TOKEN (or HUGGINGFACE_HUB_TOKEN) and rerun this cell.\"\n", " ) from exc\n", " raise\n", "\n", "print(f\"Downloaded assets to: {ASSET_ROOT}\")\n", "\n", "MODEL_SEARCH_ROOT = ASSET_ROOT / HF_CHECKPOINT_DIR if HF_CHECKPOINT_DIR else ASSET_ROOT\n", "if HF_CHECKPOINT_DIR and not MODEL_SEARCH_ROOT.exists():\n", " raise FileNotFoundError(\n", " f\"Configured HF_CHECKPOINT_DIR was not found in snapshot: {HF_CHECKPOINT_DIR}\"\n", " )\n", "\n", "\n", "def _bundle_score(model_path: Path, state_path: Path | None, tokenizer_path: Path | None) -> int:\n", " score = 0\n", " model_name = model_path.name.lower()\n", " if \"latest\" in model_name:\n", " score += 120\n", " elif \"best\" in model_name:\n", " score += 100\n", " elif \"checkpoint\" in model_name or \"epoch\" in model_name:\n", " score += 60\n", " elif \"model\" in model_name:\n", " score += 30\n", " score += min(25, int(model_path.stat().st_size // 1_000_000))\n", "\n", " if state_path is not None:\n", " state_name = state_path.name.lower()\n", " if \"latest\" in state_name:\n", " score += 80\n", " elif \"best\" in state_name:\n", " score += 70\n", " elif \"state\" in state_name:\n", " score += 40\n", "\n", " if tokenizer_path is not None:\n", " score += 60 if tokenizer_path.name == \"custom_tokenizer.json\" else 30\n", "\n", " return score\n", "\n", "\n", "def _select_best_bundle(model_root: Path, tokenizer_root: Path) -> tuple[dict, list[dict]]:\n", " model_paths = sorted(path for path in model_root.rglob(\"*.safetensors\") if path.is_file())\n", " if not model_paths:\n", " raise FileNotFoundError(\n", " f\"No .safetensors files were found under model search root: {model_root}\"\n", " )\n", "\n", " bundles = []\n", " for model_path in model_paths:\n", " parent = model_path.parent\n", " stem = model_path.stem\n", "\n", " state_candidates = [\n", " parent / f\"{stem}_state.pt\",\n", " parent / f\"{stem.replace('_model', '')}_state.pt\" if stem.endswith(\"_model\") else parent / 
f\"{stem}_state.pt\",\n", " parent / \"latest_state.pt\",\n", " parent / \"best_state.pt\",\n", " ]\n", " state_candidates.extend(sorted(parent.glob(\"*_state.pt\")))\n", " state_path = next((candidate for candidate in state_candidates if candidate.exists()), None)\n", "\n", " tokenizer_candidates = [\n", " tokenizer_root / \"custom_tokenizer.json\",\n", " tokenizer_root / \"tokenizer.json\",\n", " parent / \"custom_tokenizer.json\",\n", " parent / \"tokenizer.json\",\n", " parent.parent / \"custom_tokenizer.json\",\n", " parent.parent / \"tokenizer.json\",\n", " ]\n", " tokenizer_path = next((candidate for candidate in tokenizer_candidates if candidate.exists()), None)\n", " if tokenizer_path is None:\n", " continue\n", "\n", " try:\n", " state_payload = torch.load(state_path, map_location=\"cpu\") if state_path is not None else {}\n", " except Exception:\n", " state_payload = {}\n", " if not isinstance(state_payload, dict):\n", " state_payload = {}\n", "\n", " model_payload = dict(state_payload.get(\"model_config\") or {})\n", " data_payload = dict(state_payload.get(\"data_config\") or {})\n", "\n", " score = _bundle_score(model_path, state_path, tokenizer_path)\n", " if str(data_payload.get(\"tokenization_strategy\", \"\")).strip().lower() == \"custom_delta\":\n", " score += 40\n", " if int(data_payload.get(\"vocab_size\", 0) or 0) == 374:\n", " score += 10\n", " if int(model_payload.get(\"d_model\", 0) or 0) > 0:\n", " score += 10\n", " if int(model_payload.get(\"n_layers\", 0) or 0) > 0:\n", " score += 10\n", "\n", " bundles.append(\n", " {\n", " \"score\": int(score),\n", " \"model_path\": model_path,\n", " \"state_path\": state_path,\n", " \"tokenizer_path\": tokenizer_path,\n", " \"summary\": {\n", " \"tokenization_strategy\": str(data_payload.get(\"tokenization_strategy\", \"\")),\n", " \"d_model\": int(model_payload.get(\"d_model\", 0) or 0),\n", " \"n_layers\": int(model_payload.get(\"n_layers\", 0) or 0),\n", " \"vocab_size\": 
int(data_payload.get(\"vocab_size\", 0) or 0),\n", " },\n", " }\n", " )\n", "\n", " if not bundles:\n", " raise FileNotFoundError(\n", " \"Could not find a complete model/state/tokenizer bundle in the Hugging Face snapshot.\"\n", " )\n", "\n", " bundles.sort(key=lambda item: item[\"score\"], reverse=True)\n", " return bundles[0], bundles[: min(5, len(bundles))]\n", "\n", "\n", "selected, top_candidates = _select_best_bundle(MODEL_SEARCH_ROOT, ASSET_ROOT)\n", "selected_model_path = selected[\"model_path\"]\n", "selected_state_path = selected[\"state_path\"]\n", "selected_tokenizer_path = selected[\"tokenizer_path\"]\n", "\n", "if selected_state_path is None:\n", " raise FileNotFoundError(\n", " \"Could not find a matching sidecar state file for the selected model bundle.\"\n", " )\n", "\n", "shutil.copy2(selected_model_path, MODEL_DIR / \"latest.safetensors\")\n", "shutil.copy2(selected_state_path, MODEL_DIR / \"latest_state.pt\")\n", "if selected_tokenizer_path.name == \"custom_tokenizer.json\":\n", " shutil.copy2(selected_tokenizer_path, TOKENIZER_DIR / \"custom_tokenizer.json\")\n", "else:\n", " shutil.copy2(selected_tokenizer_path, TOKENIZER_DIR / \"tokenizer.json\")\n", "\n", "print(\"Selected asset bundle:\")\n", "print(f\" model: {selected_model_path}\")\n", "print(f\" state: {selected_state_path}\")\n", "print(f\" tokenizer: {selected_tokenizer_path}\")\n", "print(f\" score: {selected['score']}\")\n", "print(f\" summary: {selected['summary']}\")\n", "\n", "if len(top_candidates) > 1:\n", " print(\"Top candidate scores:\")\n", " for item in top_candidates:\n", " model_rel = item[\"model_path\"].relative_to(ASSET_ROOT)\n", " state_rel = item[\"state_path\"].relative_to(ASSET_ROOT) if item[\"state_path\"] else \"missing\"\n", " tok_rel = item[\"tokenizer_path\"].relative_to(ASSET_ROOT) if item[\"tokenizer_path\"] else \"missing\"\n", " print(f\" score={item['score']:3d} model={model_rel} state={state_rel} tokenizer={tok_rel}\")\n", "\n", "print(\"Available 
model files:\")\n", "for path in sorted(MODEL_SEARCH_ROOT.rglob(\"*.safetensors\")):\n", " print(f\" {path.relative_to(ASSET_ROOT)}\")\n", "print(\"Available tokenizer files:\")\n", "for path in sorted(list(ASSET_ROOT.rglob(\"custom_tokenizer.json\")) + list(ASSET_ROOT.rglob(\"tokenizer.json\"))):\n", " print(f\" {path.relative_to(ASSET_ROOT)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "3117799a", "metadata": { "id": "3117799a" }, "outputs": [], "source": [ "import json\n", "import math\n", "import wave\n", "import warnings\n", "from dataclasses import dataclass\n", "from pathlib import Path\n", "from typing import Any, Dict, List, Optional, Sequence, Tuple\n", "\n", "import numpy as np\n", "import pretty_midi\n", "\n", "try:\n", " import matplotlib.pyplot as plt\n", "except Exception as exc: # pragma: no cover - optional dependency\n", " plt = None # type: ignore\n", " warnings.warn(f\"matplotlib import failed. Visualization will be disabled. Details: {exc}\")\n", "\n", "\n", "def midi_duration(midi_path: str | Path) -> float:\n", " \"\"\"Return MIDI duration in seconds.\"\"\"\n", "\n", " midi = pretty_midi.PrettyMIDI(str(midi_path))\n", " return float(midi.get_end_time())\n", "\n", "\n", "def render_midi_audio(\n", " midi_path: str | Path,\n", " wav_path: str | Path,\n", " *,\n", " sample_rate: int = 22050,\n", ") -> Path:\n", " \"\"\"Render a MIDI file to a small WAV preview using pretty_midi's synth.\"\"\"\n", "\n", " midi = pretty_midi.PrettyMIDI(str(midi_path))\n", " try:\n", " waveform = midi.synthesize(fs=int(sample_rate))\n", " except Exception:\n", " duration = max(0.5, float(midi.get_end_time()))\n", " waveform = np.zeros(int(duration * int(sample_rate)), dtype=np.float32)\n", "\n", " waveform = np.asarray(waveform, dtype=np.float32)\n", " if waveform.ndim == 2:\n", " waveform = waveform.mean(axis=1)\n", " if waveform.size == 0:\n", " waveform = np.zeros(int(sample_rate * 0.5), dtype=np.float32)\n", "\n", " peak = 
float(np.max(np.abs(waveform))) if waveform.size else 0.0\n", " if peak > 0:\n", " waveform = 0.95 * waveform / peak\n", "\n", " pcm16 = np.clip(waveform, -1.0, 1.0)\n", " pcm16 = (pcm16 * 32767.0).astype(np.int16)\n", "\n", " out_path = Path(wav_path)\n", " out_path.parent.mkdir(parents=True, exist_ok=True)\n", " with wave.open(str(out_path), \"wb\") as wav_file:\n", " wav_file.setnchannels(1)\n", " wav_file.setsampwidth(2)\n", " wav_file.setframerate(int(sample_rate))\n", " wav_file.writeframes(pcm16.tobytes())\n", " return out_path\n", "\n", "\n", "def _extract_note_events(midi_path: str | Path) -> List[Tuple[float, float, int, int]]:\n", " \"\"\"Extract note events `(start, end, pitch, velocity)` for piano range.\"\"\"\n", "\n", " midi = pretty_midi.PrettyMIDI(str(midi_path))\n", " events: List[Tuple[float, float, int, int]] = []\n", " for inst in midi.instruments:\n", " if inst.is_drum:\n", " continue\n", " for n in inst.notes:\n", " if 21 <= n.pitch <= 108:\n", " events.append((n.start, n.end, n.pitch, n.velocity))\n", " return events\n", "\n", "\n", "def visualize_pianoroll(\n", " midi_path: str | Path,\n", " title: str = \"\",\n", " save_path: Optional[str | Path] = None,\n", ") -> None:\n", " \"\"\"Render a pianoroll plot for one MIDI file.\"\"\"\n", "\n", " if plt is None: # pragma: no cover - optional dependency\n", " warnings.warn(\"visualize_pianoroll called but matplotlib is not installed; skipping visualization\")\n", " return\n", "\n", " events = _extract_note_events(midi_path)\n", " cmap = plt.get_cmap(\"viridis\")\n", "\n", " fig, ax = plt.subplots(figsize=(14, 5))\n", " for start, end, pitch, velocity in events:\n", " color = cmap(velocity / 127.0)\n", " ax.hlines(y=pitch, xmin=start, xmax=end, linewidth=2.0, color=color)\n", "\n", " ax.set_xlabel(\"Time (s)\")\n", " ax.set_ylabel(\"MIDI Pitch\")\n", " ax.set_ylim(21, 108)\n", " ax.set_title(title or f\"Piano Roll: {Path(midi_path).name}\")\n", " ax.grid(alpha=0.2)\n", " fig.tight_layout()\n", 
"\n", " if save_path is None:\n", " plt.show()\n", " else:\n", " Path(save_path).parent.mkdir(parents=True, exist_ok=True)\n", " fig.savefig(save_path, dpi=150)\n", " plt.close(fig)\n", "\n", "\n", "def compare_pianorolls(\n", " seed_path: str | Path,\n", " continuation_path: str | Path,\n", " save_path: Optional[str | Path] = None,\n", ") -> None:\n", " \"\"\"Render side-by-side timeline comparison for seed and continuation.\"\"\"\n", "\n", " if plt is None: # pragma: no cover - optional dependency\n", " warnings.warn(\"compare_pianorolls called but matplotlib is not installed; skipping visualization\")\n", " return\n", "\n", " seed_events = _extract_note_events(seed_path)\n", " cont_events = _extract_note_events(continuation_path)\n", " seed_cmap = plt.get_cmap(\"Blues\")\n", " cont_cmap = plt.get_cmap(\"Oranges\")\n", "\n", " seed_dur = midi_duration(seed_path)\n", "\n", " fig, ax = plt.subplots(figsize=(16, 5))\n", "\n", " for start, end, pitch, velocity in seed_events:\n", " color = seed_cmap(0.3 + 0.7 * velocity / 127.0)\n", " ax.hlines(y=pitch, xmin=start, xmax=end, linewidth=2.0, color=color)\n", "\n", " for start, end, pitch, velocity in cont_events:\n", " start += seed_dur\n", " end += seed_dur\n", " color = cont_cmap(0.3 + 0.7 * velocity / 127.0)\n", " ax.hlines(y=pitch, xmin=start, xmax=end, linewidth=2.0, color=color)\n", "\n", " ax.axvline(seed_dur, linestyle=\"--\", linewidth=2, color=\"black\", alpha=0.7)\n", " ax.set_xlabel(\"Time (s)\")\n", " ax.set_ylabel(\"MIDI Pitch\")\n", " ax.set_ylim(21, 108)\n", " ax.set_title(\"Seed | Continuation\")\n", " ax.grid(alpha=0.2)\n", " fig.tight_layout()\n", "\n", " if save_path is None:\n", " plt.show()\n", " else:\n", " Path(save_path).parent.mkdir(parents=True, exist_ok=True)\n", " fig.savefig(save_path, dpi=150)\n", " plt.close(fig)\n", "\n", "\n", "@dataclass(frozen=True)\n", "class _TokenSpec:\n", " delta_start: int = 0\n", " delta_end: int = 127\n", " pitch_start: int = 128\n", " pitch_end: int = 215\n", 
" duration_start: int = 216\n", " duration_end: int = 343\n", " velocity_start: int = 344\n", " velocity_end: int = 359\n", " pad_id: int = 360\n", " bos_id: int = 361\n", " eos_id: int = 362\n", " density_start: int = 363\n", " density_end: int = 366\n", " voices_start: int = 367\n", " voices_end: int = 370\n", " register_start: int = 371\n", " register_end: int = 373\n", " event_size: int = 4\n", "\n", " @property\n", " def vocab_size(self) -> int:\n", " return 374\n", "\n", "\n", "class CustomDeltaTokenizer:\n", " \"\"\"Unified frozen quad tokenizer with structural prefix context.\"\"\"\n", "\n", " def __init__(\n", " self,\n", " *,\n", " default_velocity: int = 88,\n", " include_special_tokens: bool = False,\n", " include_structural_meta_tokens: bool = True,\n", " prepend_start_token: bool = True,\n", " density_quartiles: Optional[Tuple[float, float, float]] = None,\n", " ) -> None:\n", " self.spec = _TokenSpec()\n", " self.default_velocity = int(max(1, min(127, default_velocity)))\n", " self.include_special_tokens = bool(include_special_tokens)\n", " self.include_structural_meta_tokens = bool(include_structural_meta_tokens)\n", " self.prepend_start_token = bool(prepend_start_token)\n", "\n", " self._density_labels = (\"v_low\", \"low\", \"med\", \"high\")\n", " self._voices_labels = (\"mono\", \"poly_small\", \"poly_med\", \"poly_large\")\n", " self._register_labels = (\"bass\", \"mid\", \"treble\")\n", " self._density_quartiles = self._sanitize_density_quartiles(density_quartiles)\n", "\n", " self._density_token_to_label: Dict[int, str] = {\n", " int(self.spec.density_start + i): label\n", " for i, label in enumerate(self._density_labels)\n", " }\n", " self._voices_token_to_label: Dict[int, str] = {\n", " int(self.spec.voices_start + i): label\n", " for i, label in enumerate(self._voices_labels)\n", " }\n", " self._register_token_to_label: Dict[int, str] = {\n", " int(self.spec.register_start + i): label\n", " for i, label in 
enumerate(self._register_labels)\n", " }\n", "\n", " self._delta_min_positive_seconds = 1e-4\n", " self._delta_max_seconds = 8.0\n", " self._duration_min_seconds = 1.0 / 64.0\n", " self._duration_max_seconds = 8.0\n", " self._velocity_bin_count = int(self.spec.velocity_end - self.spec.velocity_start + 1)\n", "\n", " self._delta_edges = np.logspace(\n", " math.log10(self._delta_min_positive_seconds),\n", " math.log10(self._delta_max_seconds),\n", " num=128,\n", " ).astype(np.float64)\n", " self._duration_edges = np.logspace(\n", " math.log10(self._duration_min_seconds),\n", " math.log10(self._duration_max_seconds),\n", " num=129,\n", " ).astype(np.float64)\n", " self._delta_bins = np.concatenate(\n", " [\n", " np.asarray([0.0], dtype=np.float64),\n", " np.sqrt(self._delta_edges[:-1] * self._delta_edges[1:]),\n", " ],\n", " axis=0,\n", " )\n", " self._duration_bins = np.sqrt(self._duration_edges[:-1] * self._duration_edges[1:])\n", "\n", " @property\n", " def vocab_size(self) -> int:\n", " return int(self.spec.vocab_size)\n", "\n", " @property\n", " def event_size(self) -> int:\n", " return int(self.spec.event_size)\n", "\n", " @property\n", " def pad_id(self) -> int:\n", " return self.spec.pad_id\n", "\n", " @property\n", " def bos_id(self) -> int:\n", " return self.spec.bos_id\n", "\n", " @property\n", " def eos_id(self) -> int:\n", " return self.spec.eos_id\n", "\n", " @staticmethod\n", " def _sanitize_density_quartiles(\n", " quartiles: Optional[Tuple[float, float, float]],\n", " ) -> Tuple[float, float, float]:\n", " if quartiles is None:\n", " values = [1.0, 2.5, 5.0]\n", " else:\n", " raw = [float(v) for v in quartiles]\n", " if len(raw) != 3:\n", " raise ValueError(\"density_quartiles must contain exactly three values\")\n", " values = sorted(max(1e-4, float(v)) for v in raw)\n", "\n", " q1 = float(values[0])\n", " q2 = float(max(values[1], q1 + 1e-4))\n", " q3 = float(max(values[2], q2 + 1e-4))\n", " return (q1, q2, q3)\n", "\n", " def save(self, path: str) 
-> None:\n", " out_path = Path(path)\n", " out_path.parent.mkdir(parents=True, exist_ok=True)\n", " payload = {\n", " \"type\": \"CustomDeltaTokenizer\",\n", " \"version\": 3,\n", " \"spec_version\": 3,\n", " \"frozen\": True,\n", " \"vocab_size\": int(self.vocab_size),\n", " \"default_velocity\": int(self.default_velocity),\n", " \"event_size\": int(self.event_size),\n", " \"include_special_tokens\": bool(self.include_special_tokens),\n", " \"include_structural_meta_tokens\": bool(self.include_structural_meta_tokens),\n", " \"prepend_start_token\": bool(self.prepend_start_token),\n", " \"density_quartiles\": [float(v) for v in self._density_quartiles],\n", " \"token_ids\": {\n", " \"delta\": [self.spec.delta_start, self.spec.delta_end],\n", " \"pitch\": [self.spec.pitch_start, self.spec.pitch_end],\n", " \"duration\": [self.spec.duration_start, self.spec.duration_end],\n", " \"velocity\": [self.spec.velocity_start, self.spec.velocity_end],\n", " \"density\": [self.spec.density_start, self.spec.density_end],\n", " \"voices\": [self.spec.voices_start, self.spec.voices_end],\n", " \"register\": [self.spec.register_start, self.spec.register_end],\n", " \"pad\": self.spec.pad_id,\n", " \"bos\": self.spec.bos_id,\n", " \"eos\": self.spec.eos_id,\n", " },\n", " }\n", " out_path.write_text(json.dumps(payload, indent=2), encoding=\"utf-8\")\n", "\n", " @classmethod\n", " def load(cls, path: str) -> \"CustomDeltaTokenizer\":\n", " in_path = Path(path)\n", " if not in_path.exists():\n", " raise FileNotFoundError(f\"Tokenizer file not found: {in_path}\")\n", "\n", " payload = json.loads(in_path.read_text(encoding=\"utf-8\"))\n", " if str(payload.get(\"type\", \"\")) != \"CustomDeltaTokenizer\":\n", " raise ValueError(\"Unsupported tokenizer payload. 
Expected type='CustomDeltaTokenizer'.\")\n", "\n", " token_ids = payload.get(\"token_ids\")\n", " token_ids = token_ids if isinstance(token_ids, dict) else {}\n", " legacy_no_meta = (\n", " int(payload.get(\"vocab_size\", 0)) <= 171\n", " and \"density\" not in token_ids\n", " and \"voices\" not in token_ids\n", " and \"register\" not in token_ids\n", " )\n", "\n", " quartiles_payload = payload.get(\"density_quartiles\")\n", " quartiles: Optional[Tuple[float, float, float]] = None\n", " if isinstance(quartiles_payload, (list, tuple)) and len(quartiles_payload) == 3:\n", " quartiles = (\n", " float(quartiles_payload[0]),\n", " float(quartiles_payload[1]),\n", " float(quartiles_payload[2]),\n", " )\n", "\n", " return cls(\n", " default_velocity=int(payload.get(\"default_velocity\", 88)),\n", " include_special_tokens=bool(payload.get(\"include_special_tokens\", False)),\n", " include_structural_meta_tokens=bool(\n", " payload.get(\"include_structural_meta_tokens\", not legacy_no_meta)\n", " ),\n", " prepend_start_token=bool(payload.get(\"prepend_start_token\", not legacy_no_meta)),\n", " density_quartiles=quartiles,\n", " )\n", "\n", " @staticmethod\n", " def _estimate_piece_density(\n", " events: Sequence[Tuple[float, int, float, int]],\n", " ) -> float:\n", " if not events:\n", " return 0.0\n", " starts = [float(ev[0]) for ev in events]\n", " ends = [float(ev[0] + max(1e-4, float(ev[2]))) for ev in events]\n", " span = float(max(1e-3, max(ends) - min(starts)))\n", " return float(len(events) / span)\n", "\n", " @staticmethod\n", " def _estimate_polyphony(\n", " events: Sequence[Tuple[float, int, float, int]],\n", " ) -> Tuple[float, float]:\n", " if not events:\n", " return (1.0, 1.0)\n", "\n", " boundaries: List[Tuple[float, int]] = []\n", " for onset, _, duration, _ in events:\n", " start = float(max(0.0, onset))\n", " end = float(max(start + 1e-4, start + float(duration)))\n", " boundaries.append((start, +1))\n", " boundaries.append((end, -1))\n", "\n", " 
boundaries.sort(key=lambda item: (item[0], -item[1]))\n", " active = 0\n", " max_active = 0\n", " weighted_active = 0.0\n", " total_time = 0.0\n", " last_t = float(boundaries[0][0])\n", "\n", " for t, delta in boundaries:\n", " t_f = float(t)\n", " dt = float(max(0.0, t_f - last_t))\n", " if dt > 0.0:\n", " weighted_active += float(max(0, active)) * dt\n", " total_time += dt\n", " active += int(delta)\n", " max_active = max(max_active, active)\n", " last_t = t_f\n", "\n", " mean_active = float(weighted_active / total_time) if total_time > 0.0 else float(max_active)\n", " return (float(mean_active), float(max(1, max_active)))\n", "\n", " def _density_token(self, density: float) -> int:\n", " q1, q2, q3 = self._density_quartiles\n", " if float(density) <= q1:\n", " idx = 0\n", " elif float(density) <= q2:\n", " idx = 1\n", " elif float(density) <= q3:\n", " idx = 2\n", " else:\n", " idx = 3\n", " return int(self.spec.density_start + idx)\n", "\n", " def _voices_token(self, mean_polyphony: float, peak_polyphony: float) -> int:\n", " mean_v = float(max(1.0, mean_polyphony))\n", " peak_v = float(max(1.0, peak_polyphony))\n", " if peak_v <= 1.05 and mean_v < 1.20:\n", " idx = 0\n", " elif mean_v < 2.00 and peak_v <= 3.00:\n", " idx = 1\n", " elif mean_v < 3.50 and peak_v <= 6.00:\n", " idx = 2\n", " else:\n", " idx = 3\n", " return int(self.spec.voices_start + idx)\n", "\n", " def _register_token(self, median_pitch: float) -> int:\n", " pitch = float(median_pitch)\n", " if pitch < 48.0:\n", " idx = 0\n", " elif pitch <= 72.0:\n", " idx = 1\n", " else:\n", " idx = 2\n", " return int(self.spec.register_start + idx)\n", "\n", " def _derive_structural_meta_tokens(\n", " self,\n", " events: Sequence[Tuple[float, int, float, int]],\n", " ) -> Tuple[int, int, int]:\n", " if not events:\n", " return (\n", " int(self.spec.density_start),\n", " int(self.spec.voices_start),\n", " int(self.spec.register_start + 1),\n", " )\n", "\n", " density = 
self._estimate_piece_density(events)\n", " mean_polyphony, peak_polyphony = self._estimate_polyphony(events)\n", " pitches = np.asarray([float(ev[1]) for ev in events], dtype=np.float64)\n", " median_pitch = float(np.median(pitches)) if int(pitches.size) > 0 else 60.0\n", "\n", " return (\n", " int(self._density_token(density)),\n", " int(self._voices_token(mean_polyphony, peak_polyphony)),\n", " int(self._register_token(median_pitch)),\n", " )\n", "\n", " def _note_events(self, midi_path: Path) -> List[Tuple[float, int, float, int]]:\n", " midi = pretty_midi.PrettyMIDI(str(midi_path))\n", " events: List[Tuple[float, int, float, int]] = []\n", " for inst in midi.instruments:\n", " if inst.is_drum:\n", " continue\n", " for note in inst.notes:\n", " onset = float(max(0.0, note.start))\n", " duration = float(max(1e-4, note.end - note.start))\n", " pitch = int(note.pitch)\n", " velocity = int(max(0, min(127, int(note.velocity))))\n", " if pitch < 21 or pitch > 108:\n", " continue\n", " events.append((onset, pitch, duration, velocity))\n", " events.sort(key=lambda x: (x[0], x[1], x[2], x[3]))\n", " return events\n", "\n", " def _quantize_delta(self, delta_seconds: float) -> int:\n", " clamped = float(max(0.0, min(self._delta_max_seconds, float(delta_seconds))))\n", " if clamped <= 0.0:\n", " idx = 0\n", " else:\n", " positive = float(max(self._delta_min_positive_seconds, clamped))\n", " pos_idx = int(np.searchsorted(self._delta_edges, positive, side=\"right\") - 1)\n", " pos_idx = max(0, min(126, pos_idx))\n", " idx = 1 + int(pos_idx)\n", " return int(self.spec.delta_start + idx)\n", "\n", " def _quantize_duration(self, duration_seconds: float) -> int:\n", " clamped = float(max(self._duration_min_seconds, min(self._duration_max_seconds, float(duration_seconds))))\n", " idx = int(np.searchsorted(self._duration_edges, clamped, side=\"right\") - 1)\n", " idx = max(0, min(127, idx))\n", " return int(self.spec.duration_start + idx)\n", "\n", " def _quantize_pitch(self, 
pitch: int) -> int:\n", " pitch_i = int(max(21, min(108, pitch)))\n", " return int(self.spec.pitch_start + (pitch_i - 21))\n", "\n", " def _quantize_velocity(self, velocity: int) -> int:\n", " vel = int(max(0, min(127, int(velocity))))\n", " bin_idx = int((float(vel) / 128.0) * float(self._velocity_bin_count))\n", " bin_idx = max(0, min(self._velocity_bin_count - 1, bin_idx))\n", " return int(self.spec.velocity_start + bin_idx)\n", "\n", " def _dequantize_delta(self, token_id: int) -> float:\n", " idx = int(token_id) - self.spec.delta_start\n", " idx = max(0, min(int(self._delta_bins.shape[0]) - 1, idx))\n", " return float(self._delta_bins[idx])\n", "\n", " def _dequantize_duration(self, token_id: int) -> float:\n", " idx = int(token_id) - self.spec.duration_start\n", " idx = max(0, min(int(self._duration_bins.shape[0]) - 1, idx))\n", " return float(self._duration_bins[idx])\n", "\n", " def _dequantize_pitch(self, token_id: int) -> int:\n", " idx = int(token_id) - self.spec.pitch_start\n", " idx = max(0, min(87, idx))\n", " return int(21 + idx)\n", "\n", " def _dequantize_velocity(self, token_id: int) -> int:\n", " idx = int(token_id) - self.spec.velocity_start\n", " idx = max(0, min(self._velocity_bin_count - 1, idx))\n", " center = (float(idx) + 0.5) * (128.0 / float(self._velocity_bin_count))\n", " return int(max(0, min(127, int(center))))\n", "\n", " def _encode_event_tuples(\n", " self,\n", " events: Sequence[Tuple[float, int, float, int]],\n", " ) -> Tuple[List[int], List[float], List[float]]:\n", " event_list = list(events)\n", " token_ids: List[int] = []\n", " onset_times: List[float] = []\n", " durations: List[float] = []\n", " prev_onset = 0.0\n", "\n", " if self.include_structural_meta_tokens:\n", " density_tok, voices_tok, register_tok = self._derive_structural_meta_tokens(event_list)\n", " token_ids.extend([density_tok, voices_tok, register_tok])\n", " onset_times.extend([0.0, 0.0, 0.0])\n", " durations.extend([1e-4, 1e-4, 1e-4])\n", "\n", " if 
def decode_token_id_events(self, token_id: int) -> List[str]:
    """Describe a single token id as a list holding one event string.

    Ranged token families are checked in a fixed order (delta, pitch,
    duration, velocity, density, voices, register), followed by the PAD, BOS
    and EOS ids. The first match wins.

    Returns:
        A one-element list such as ``["Delta_0.250000"]`` or ``["BOS_None"]``,
        or an empty list when the id belongs to no known family.
    """
    token = int(token_id)
    spec = self.spec

    # (range start attribute, range end attribute, renderer) in lookup order.
    ranged_families = (
        ("delta_start", "delta_end",
         lambda t: f"Delta_{self._dequantize_delta(t):.6f}"),
        ("pitch_start", "pitch_end",
         lambda t: f"Pitch_{self._dequantize_pitch(t)}"),
        ("duration_start", "duration_end",
         lambda t: f"Duration_{self._dequantize_duration(t):.6f}"),
        ("velocity_start", "velocity_end",
         lambda t: f"Velocity_{self._dequantize_velocity(t)}"),
        ("density_start", "density_end",
         lambda t: f"Density_{self._density_token_to_label.get(t, 'v_low')}"),
        ("voices_start", "voices_end",
         lambda t: f"Voices_{self._voices_token_to_label.get(t, 'mono')}"),
        ("register_start", "register_end",
         lambda t: f"Register_{self._register_token_to_label.get(t, 'mid')}"),
    )
    for low_attr, high_attr, render in ranged_families:
        if getattr(spec, low_attr) <= token <= getattr(spec, high_attr):
            return [render(token)]

    special_ids = (
        ("pad_id", "PAD_None"),
        ("bos_id", "BOS_None"),
        ("eos_id", "EOS_None"),
    )
    for id_attr, label in special_ids:
        if token == getattr(spec, id_attr):
            return [label]
    return []
def verify_roundtrip(self, midi_path: Path) -> bool:
    """Check that a MIDI file can be encoded and decoded again.

    This is a best-effort probe: any exception raised while encoding or
    decoding is swallowed and reported as failure.

    Returns:
        True when both steps succeed and encoding produced at least one
        token; False otherwise.
    """
    try:
        token_ids = self.encode(Path(midi_path))
        self.decode(token_ids)
        produced_tokens = len(token_ids) > 0
    except Exception:
        return False
    return produced_tokens
class RotaryEmbedding(nn.Module):
    """Rotary position embedding with a lazily built cos/sin lookup table.

    The table is kept in two non-persistent buffers and rebuilt whenever a
    longer range, a different device, or a different dtype is requested.
    """

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        """Validate ``dim`` (positive and even) and create empty caches."""
        super().__init__()
        dim_is_valid = dim > 0 and dim % 2 == 0
        if not dim_is_valid:
            raise ValueError("RoPE dimension must be positive and even")
        self.dim = int(dim)
        self.base = float(base)
        self._seq_len_cached = 0
        for buffer_name in ("_cos_cached", "_sin_cached"):
            self.register_buffer(buffer_name, torch.empty(0), persistent=False)

    def _build_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        """Fill the cos/sin tables for positions ``0 .. seq_len - 1``."""
        exponents = (
            torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
            / float(self.dim)
        )
        inv_freq = 1.0 / (self.base ** exponents)
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        # Each frequency appears twice so the table spans the full head dim.
        table = torch.cat([angles, angles], dim=-1)
        self._cos_cached = torch.cos(table).to(dtype=dtype)
        self._sin_cached = torch.sin(table).to(dtype=dtype)
        self._seq_len_cached = int(seq_len)

    def _get_cos_sin(
        self,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
        offset: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the cos/sin rows for ``seq_len`` positions starting at ``offset``.

        Negative offsets are treated as zero.
        """
        required = int(seq_len + max(0, int(offset)))
        cache_is_stale = any(
            (
                self._cos_cached.numel() == 0,
                self._sin_cached.numel() == 0,
                self._seq_len_cached < required,
                self._cos_cached.device != device,
                self._cos_cached.dtype != dtype,
            )
        )
        if cache_is_stale:
            self._build_cache(required, device=device, dtype=dtype)

        first = int(max(0, offset))
        window = slice(first, first + int(seq_len))
        return self._cos_cached[window], self._sin_cached[window]

    def apply(self, q: torch.Tensor, k: torch.Tensor, offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k`` by position; sequence length is read from ``q.shape[-2]``."""
        seq_len = int(q.shape[-2])
        cos, sin = self._get_cos_sin(
            seq_len=seq_len, device=q.device, dtype=q.dtype, offset=int(offset)
        )
        table_shape = (1, 1, seq_len, self.dim)
        cos = cos.view(*table_shape)
        sin = sin.view(*table_shape)

        def rotate(states: torch.Tensor) -> torch.Tensor:
            return (states * cos) + (_rotate_half(states) * sin)

        return rotate(q), rotate(k)
def forward(self, x: torch.Tensor, position_offset: int = 0) -> torch.Tensor:
    """Run causal grouped-query self-attention over ``x``.

    Args:
        x: Input whose shape unpacks as (batch, seq_len, d_model).
        position_offset: Position of the first token, forwarded to the rotary
            embedding; negative values are treated as zero.

    Returns:
        A tensor viewed back to (batch, seq_len, d_model).

    Raises:
        ValueError: If the last dimension of ``x`` is not ``self.d_model``.
    """
    batch_size, seq_len, dim = x.shape
    if dim != self.d_model:
        raise ValueError(f"last dim must be {self.d_model}, got {dim}")

    def split_heads(projected: torch.Tensor, head_count: int) -> torch.Tensor:
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return projected.view(batch_size, seq_len, head_count, self.head_dim).transpose(1, 2)

    q = split_heads(self.q_proj(x), self.num_heads)
    k = split_heads(self.k_proj(x), self.num_kv_heads)
    v = split_heads(self.v_proj(x), self.num_kv_heads)

    if self.group_size > 1:
        # Repeat each key/value head so every query head has a partner.
        k, v = (t.repeat_interleave(self.group_size, dim=1) for t in (k, v))

    q, k = self.rope.apply(q, k, offset=int(max(0, position_offset)))

    # The output-dropout probability doubles as attention dropout in training.
    attention_dropout = self.out_dropout.p if self.training else 0.0
    attended = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=attention_dropout,
        is_causal=True,
    )
    merged = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    return self.out_dropout(self.out_proj(merged))
repetition_penalty,\n", " )\n", " return adjusted\n", "\n", "\n", "def _apply_topk_topp_filter(\n", " logits: torch.Tensor,\n", " top_k: Optional[int],\n", " top_p: Optional[float],\n", " min_tokens_to_keep: int,\n", ") -> torch.Tensor:\n", " batch_size, vocab_size = logits.shape\n", " keep_k = vocab_size\n", " if top_k is not None and top_k > 0:\n", " keep_k = min(max(int(top_k), int(min_tokens_to_keep)), vocab_size)\n", "\n", " topk_indices = torch.topk(logits, k=keep_k, dim=-1).indices\n", " candidate_mask = torch.zeros_like(logits, dtype=torch.bool)\n", " candidate_mask.scatter_(dim=-1, index=topk_indices, value=True)\n", "\n", " if top_p is not None and 0.0 < top_p < 1.0:\n", " sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)\n", " sorted_probs = torch.softmax(sorted_logits, dim=-1)\n", " cumulative_probs = torch.cumsum(sorted_probs, dim=-1)\n", " remove_mask = cumulative_probs > float(top_p)\n", " if min_tokens_to_keep > 0:\n", " remove_mask[..., : int(min_tokens_to_keep)] = False\n", " nucleus_keep = ~remove_mask\n", " nucleus_mask = torch.zeros_like(candidate_mask)\n", " nucleus_mask.scatter_(dim=-1, index=sorted_indices, src=nucleus_keep)\n", " candidate_mask = candidate_mask & nucleus_mask\n", "\n", " counts = candidate_mask.sum(dim=-1)\n", " if bool((counts < int(min_tokens_to_keep)).any()):\n", " top_idx = torch.topk(logits, k=int(min_tokens_to_keep), dim=-1).indices\n", " for row_idx in range(batch_size):\n", " if int(counts[row_idx].item()) < int(min_tokens_to_keep):\n", " candidate_mask[row_idx, top_idx[row_idx]] = True\n", "\n", " return logits.masked_fill(~candidate_mask, float(\"-inf\"))\n", "\n", "\n", "def sample_next_token(\n", " logits: torch.Tensor,\n", " context_tokens: torch.Tensor,\n", " temperature: float,\n", " top_p: Optional[float],\n", " top_k: Optional[int],\n", " repetition_penalty: float,\n", " recent_window: int,\n", " min_tokens_to_keep: int,\n", " top1_cap: Optional[float] = None,\n", ") -> 
Tuple[torch.Tensor, SamplingDiagnostics]:\n", " del top1_cap\n", " temperature = max(float(temperature), 0.1)\n", " min_tokens_to_keep = max(1, int(min_tokens_to_keep))\n", "\n", " penalized = _apply_repetition_penalty(\n", " logits=logits,\n", " context_tokens=context_tokens,\n", " repetition_penalty=float(repetition_penalty),\n", " recent_window=int(recent_window),\n", " )\n", " scaled = penalized / temperature\n", " raw_probs = torch.softmax(scaled, dim=-1)\n", " raw_top1_prob = raw_probs.max(dim=-1).values\n", "\n", " filtered_logits = _apply_topk_topp_filter(\n", " logits=scaled,\n", " top_k=top_k,\n", " top_p=top_p,\n", " min_tokens_to_keep=min_tokens_to_keep,\n", " )\n", " candidate_mask = torch.isfinite(filtered_logits)\n", " probs = torch.softmax(filtered_logits, dim=-1)\n", " probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)\n", "\n", " diagnostics = SamplingDiagnostics(\n", " raw_top1_prob=raw_top1_prob,\n", " final_top1_prob=probs.max(dim=-1).values,\n", " candidate_count=candidate_mask.sum(dim=-1),\n", " )\n", " next_token = torch.multinomial(probs, num_samples=1)\n", " return next_token, diagnostics\n", "\n", "\n", "GDN_AVAILABLE = False\n", "_GatedDeltaNet = None\n", "\n", "\n", "def _try_import_fla() -> bool:\n", " global GDN_AVAILABLE, _GatedDeltaNet\n", " try:\n", " fla_layers = importlib.import_module(\"fla.layers\")\n", " _GDN = getattr(fla_layers, \"GatedDeltaNet\")\n", " _GatedDeltaNet = _GDN\n", " GDN_AVAILABLE = True\n", " return True\n", " except Exception:\n", " _GatedDeltaNet = None\n", " GDN_AVAILABLE = False\n", " return False\n", "\n", "\n", "_try_import_fla()\n", "\n", "\n", "class _GatedDeltaFallback(nn.Module):\n", " def __init__(self, d_model: int, dropout: float = 0.1) -> None:\n", " super().__init__()\n", " self.mix = nn.Linear(d_model, d_model * 2, bias=False)\n", " self.out = nn.Linear(d_model, d_model, bias=False)\n", " self.dropout = nn.Dropout(float(dropout))\n", "\n", " def forward(self, x: torch.Tensor) -> 
torch.Tensor:\n", " u, g = self.mix(x).chunk(2, dim=-1)\n", " y = F.silu(u) * torch.sigmoid(g)\n", " return self.dropout(self.out(y))\n", "\n", "\n", "class GatedDeltaNetBlock(nn.Module):\n", " def __init__(\n", " self,\n", " d_model: int,\n", " inner_dim: int = 320,\n", " num_heads: int = 5,\n", " dropout: float = 0.1,\n", " ) -> None:\n", " super().__init__()\n", " if inner_dim % num_heads != 0:\n", " raise ValueError(\"inner_dim must be divisible by num_heads\")\n", " self.d_model = int(d_model)\n", " self.inner_dim = int(inner_dim)\n", " self.num_heads = int(num_heads)\n", " self.head_dim = self.inner_dim // self.num_heads\n", "\n", " self.in_proj = nn.Identity() if self.inner_dim == self.d_model else nn.Linear(self.d_model, self.inner_dim, bias=False)\n", " self.out_proj = nn.Identity() if self.inner_dim == self.d_model else nn.Linear(self.inner_dim, self.d_model, bias=False)\n", "\n", " if GDN_AVAILABLE and _GatedDeltaNet is not None:\n", " self.core = _GatedDeltaNet(\n", " hidden_size=self.inner_dim,\n", " num_heads=self.num_heads,\n", " head_dim=self.head_dim,\n", " mode=\"chunk\",\n", " use_short_conv=True,\n", " )\n", " self.using_fallback = False\n", " else:\n", " self.core = _GatedDeltaFallback(self.inner_dim, dropout=dropout)\n", " self.using_fallback = True\n", " warnings.warn(\n", " \"flash-linear-attention GatedDeltaNet is unavailable; using fallback approximation block for GDN-based variants. 
Install `flash-linear-attention` for true GDN behavior.\"\n", " )\n", "\n", " self.post_dropout = nn.Dropout(float(dropout))\n", "\n", " def _run_core(self, x: torch.Tensor) -> torch.Tensor:\n", " if self.using_fallback:\n", " return self.core(x)\n", " out = self.core(x)\n", " return out[0] if isinstance(out, tuple) else out\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " if x.ndim != 3:\n", " raise ValueError(f\"GatedDeltaNetBlock expects (batch, seq, d_model), got {tuple(x.shape)}\")\n", " if int(x.shape[-1]) != self.d_model:\n", " raise ValueError(f\"Expected feature dim {self.d_model}, got {int(x.shape[-1])}\")\n", "\n", " y = self.in_proj(x)\n", " y = self._run_core(y)\n", " y = self.out_proj(y)\n", " y = self.post_dropout(y)\n", " return y\n", "\n", "\n", "def _resolve_divisible_heads(\n", " width: int,\n", " requested_heads: int,\n", " *,\n", " require_even_head_dim: bool = False,\n", ") -> int:\n", " w = int(max(1, width))\n", " heads = max(1, min(int(requested_heads), w))\n", " while heads > 1:\n", " if (w % heads) == 0:\n", " head_dim = w // heads\n", " if not require_even_head_dim or (head_dim % 2 == 0):\n", " return int(heads)\n", " heads -= 1\n", "\n", " if require_even_head_dim and (w % 2) != 0:\n", " return 1\n", " return 1\n", "\n", "\n", "def _resolve_hybrid_dims(d_model: int, gdn_ratio: float) -> Tuple[int, int]:\n", " d = int(max(4, d_model))\n", " ratio = float(min(0.9, max(0.1, gdn_ratio)))\n", " gdn_dim = int(round(float(d) * ratio))\n", " gdn_dim = max(1, min(d - 1, gdn_dim))\n", " gqa_dim = int(d - gdn_dim)\n", "\n", " if (gqa_dim % 2) != 0:\n", " if gdn_dim > 1:\n", " gdn_dim -= 1\n", " gqa_dim += 1\n", " else:\n", " gdn_dim += 1\n", " gqa_dim -= 1\n", "\n", " gdn_dim = max(1, gdn_dim)\n", " gqa_dim = max(2, gqa_dim)\n", " if (gqa_dim % 2) != 0:\n", " gqa_dim -= 1\n", " gdn_dim += 1\n", "\n", " return int(gdn_dim), int(gqa_dim)\n", "\n", "\n", "class ContinuousTimeEncoding(nn.Module):\n", " def __init__(self, d_model: 
def forward(self, onset_times: torch.Tensor) -> torch.Tensor:
    """Encode absolute onset times as a learned projection of sinusoids.

    Args:
        onset_times: 2-D tensor shaped (batch, seq_len).

    Returns:
        The projected sin/cos features multiplied by
        ``sigmoid(self.output_scale)``; one feature vector per input time.

    Raises:
        ValueError: If ``onset_times`` is not 2-D.
    """
    if onset_times.ndim != 2:
        raise ValueError(f"onset_times must be shaped (batch, seq_len), got {tuple(onset_times.shape)}")

    # Limit times to [0, max_time_seconds] using ReLU arithmetic only.
    non_negative = torch.relu(onset_times.to(dtype=torch.float32))
    ceiling = float(self.max_time_seconds)
    overshoot = torch.relu(non_negative - ceiling)
    bounded = non_negative - overshoot

    # One phase per (time, timescale) pair.
    phases = bounded.unsqueeze(-1) / self.timescales.reshape(1, 1, -1)
    sinusoids = torch.cat([phases.sin(), phases.cos()], dim=-1)

    gate = torch.sigmoid(self.output_scale)
    return self.projection(sinusoids) * gate
class _VariantEBlock(nn.Module):
    """One hybrid layer: a GDN branch and a GQA branch run in parallel and are fused.

    The normalised input is projected into two narrower streams whose widths
    come from `_resolve_hybrid_dims`. Each stream gets its own pre-norm
    residual sub-block; the two results are concatenated, passed through the
    `fuse` MLP, and added back to the layer input.
    """

    def __init__(self, cfg: VariantEConfig) -> None:
        # NOTE(review): submodules are created in a fixed order; reordering
        # would change parameter registration order and the order of random
        # initialisation draws.
        super().__init__()
        d = int(cfg.d_model)

        # Split the model width between the two branches.
        self.gdn_dim, self.gqa_dim = _resolve_hybrid_dims(
            d_model=d,
            gdn_ratio=float(cfg.gdn_path_ratio),
        )

        self.norm_in = nn.LayerNorm(d)
        self.gdn_in_proj = nn.Linear(d, self.gdn_dim, bias=False)
        self.gqa_in_proj = nn.Linear(d, self.gqa_dim, bias=False)

        # The GDN inner width is never smaller than the branch width itself.
        gdn_inner_dim = max(int(self.gdn_dim), int(cfg.gdn_inner_dim))
        gdn_heads = _resolve_divisible_heads(
            width=int(gdn_inner_dim),
            requested_heads=int(cfg.gdn_num_heads),
            require_even_head_dim=False,
        )

        self.norm_gdn = nn.LayerNorm(self.gdn_dim)
        self.gdn = GatedDeltaNetBlock(
            d_model=int(self.gdn_dim),
            inner_dim=int(gdn_inner_dim),
            num_heads=int(gdn_heads),
            dropout=float(cfg.dropout),
        )

        # Attention heads must leave an even head dimension (GQABlock rejects odd ones).
        gqa_heads = _resolve_divisible_heads(
            width=int(self.gqa_dim),
            requested_heads=int(cfg.gqa_num_heads),
            require_even_head_dim=True,
        )
        # `gqa_groups` is used as a divisor of the head count: it is lowered
        # until it divides `gqa_heads`, and the number of key/value heads is
        # gqa_heads // gqa_groups.
        gqa_groups = max(1, min(int(cfg.gqa_groups), int(gqa_heads)))
        while gqa_groups > 1 and (gqa_heads % gqa_groups) != 0:
            gqa_groups -= 1
        kv_heads = max(1, int(gqa_heads) // int(gqa_groups))

        self.norm_gqa = nn.LayerNorm(self.gqa_dim)
        self.gqa = GQABlock(
            d_model=int(self.gqa_dim),
            num_heads=int(gqa_heads),
            num_kv_heads=int(kv_heads),
            dropout=float(cfg.attention_dropout),
        )

        # Fusion MLP: maps the concatenated branch outputs back to d_model.
        self.fuse = nn.Sequential(
            nn.LayerNorm(int(self.gdn_dim + self.gqa_dim)),
            nn.Linear(int(self.gdn_dim + self.gqa_dim), d, bias=False),
            nn.GELU(),
            nn.Dropout(float(cfg.dropout)),
            nn.Linear(d, d, bias=False),
        )

    def forward(self, x: torch.Tensor, position_offset: int) -> torch.Tensor:
        """Apply both branches to `x` and return `x` plus the fused result.

        `position_offset` is passed to the attention branch only; negative
        values are treated as zero.
        """
        h = self.norm_in(x)

        # GDN branch: projection followed by a pre-norm residual update.
        gdn_state = self.gdn_in_proj(h)
        gdn_state = gdn_state + self.gdn(self.norm_gdn(gdn_state))

        # GQA branch: same pattern, with positions supplied to the attention block.
        gqa_state = self.gqa_in_proj(h)
        gqa_state = gqa_state + self.gqa(
            self.norm_gqa(gqa_state), position_offset=int(max(0, position_offset))
        )

        fused = self.fuse(torch.cat([gdn_state, gqa_state], dim=-1))
        return x + fused
"\n", " def _reset_parameters(self) -> None:\n", " std = float(max(1e-6, self.config.embedding_init_std))\n", " nn.init.normal_(self.token_embedding.weight, mean=0.0, std=std)\n", " nn.init.normal_(self.position_embedding.weight, mean=0.0, std=std)\n", " if self.lm_head.weight.data_ptr() != self.token_embedding.weight.data_ptr():\n", " nn.init.normal_(self.lm_head.weight, mean=0.0, std=std)\n", "\n", " @staticmethod\n", " def _unwrap(model: Any) -> Any:\n", " if isinstance(model, torch.nn.DataParallel):\n", " return model.module\n", " return model\n", "\n", " def _prepare_generation_device(self) -> torch.device:\n", " current_device = next(self.parameters()).device\n", " if torch.cuda.is_available() and current_device.type == \"cuda\":\n", " target_device = torch.device(\"cuda:0\")\n", " else:\n", " target_device = current_device\n", "\n", " if current_device != target_device:\n", " self.to(target_device)\n", " return target_device\n", "\n", " @staticmethod\n", " def _to_seed_tensor(\n", " seed_tokens: Sequence[int] | torch.Tensor,\n", " *,\n", " device: torch.device,\n", " ) -> torch.Tensor:\n", " if isinstance(seed_tokens, torch.Tensor):\n", " if seed_tokens.ndim == 1:\n", " seed = seed_tokens.unsqueeze(0)\n", " elif seed_tokens.ndim == 2 and int(seed_tokens.shape[0]) == 1:\n", " seed = seed_tokens\n", " else:\n", " raise ValueError(\"seed tensor must be shape (seq,) or (1, seq)\")\n", " return seed.to(device=device, dtype=torch.long)\n", " vals = [int(t) for t in seed_tokens]\n", " if not vals:\n", " raise ValueError(\"seed_tokens cannot be empty\")\n", " return torch.tensor(vals, dtype=torch.long, device=device).unsqueeze(0)\n", "\n", " @staticmethod\n", " def _triplet_slot(index: int) -> int:\n", " return int(index % 4)\n", "\n", " @staticmethod\n", " def _allowed_ids_for_slot(slot: int, vocab_size: int) -> torch.Tensor:\n", " if slot == 0:\n", " return torch.arange(0, 128, dtype=torch.long)\n", " if slot == 1:\n", " return torch.arange(128, 216, 
dtype=torch.long)\n", " if slot == 2:\n", " return torch.arange(216, 344, dtype=torch.long)\n", " if slot == 3:\n", " return torch.arange(344, 360, dtype=torch.long)\n", " return torch.arange(0, vocab_size, dtype=torch.long)\n", "\n", " def _mask_logits_to_triplet_slot(self, logits: torch.Tensor, slot: int) -> torch.Tensor:\n", " mask = torch.full_like(logits, fill_value=-float(\"inf\"))\n", " allowed = self._allowed_ids_for_slot(slot, logits.shape[-1]).to(logits.device)\n", " mask[:, allowed] = logits[:, allowed]\n", " return mask\n", "\n", " @staticmethod\n", " def _delta_from_token_events(token_id: int, token_id_to_events: Any, default_step: float) -> float:\n", " if callable(token_id_to_events):\n", " try:\n", " events = token_id_to_events(int(token_id))\n", " if isinstance(events, str):\n", " events = [events]\n", " if isinstance(events, (list, tuple)):\n", " for ev in events:\n", " text = str(ev)\n", " if text.startswith(\"Delta_\"):\n", " return float(max(1e-4, float(text.split(\"_\", 1)[1])))\n", " except Exception:\n", " pass\n", " if 0 <= int(token_id) <= 127:\n", " if int(token_id) == 0:\n", " return 0.0\n", " bins = torch.logspace(\n", " math.log10(1e-4),\n", " math.log10(8.0),\n", " steps=127,\n", " dtype=torch.float32,\n", " )\n", " idx = max(0, min(126, int(token_id) - 1))\n", " return float(max(1e-4, bins[idx].item()))\n", " return float(max(1e-4, default_step))\n", "\n", " def forward(\n", " self,\n", " token_ids: torch.Tensor,\n", " onset_times: torch.Tensor,\n", " durations: Optional[torch.Tensor] = None,\n", " memory: Optional[Any] = None,\n", " return_memory: bool = False,\n", " position_offset: int = 0,\n", " ) -> Tuple[torch.Tensor, Optional[Any]] | torch.Tensor:\n", " del memory\n", "\n", " if token_ids.ndim != 2 or onset_times.ndim != 2:\n", " raise ValueError(\"token_ids and onset_times must be rank-2\")\n", " if token_ids.shape != onset_times.shape:\n", " raise ValueError(\"token_ids and onset_times must have same shape\")\n", " if 
durations is not None and durations.shape != token_ids.shape:\n", " raise ValueError(\"durations must match token_ids shape\")\n", "\n", " bsz, seq_len = token_ids.shape\n", " positions = torch.arange(\n", " int(max(0, position_offset)),\n", " int(max(0, position_offset)) + int(seq_len),\n", " device=token_ids.device,\n", " )\n", " positions = torch.clamp(positions, max=self.max_sequence_length - 1)\n", " positions = positions.unsqueeze(0).expand(bsz, -1)\n", "\n", " x = self.token_embedding(token_ids) + self.position_embedding(positions)\n", " if self.time_encoding is not None:\n", " x = x + self.time_encoding(onset_times)\n", " x = self.dropout(x)\n", "\n", " for layer in self.layers:\n", " x = layer(x, position_offset=int(max(0, position_offset)))\n", "\n", " logits = self.lm_head(self.final_norm(x)) * float(self.output_logit_scale)\n", " return (logits, None) if return_memory else logits\n", "\n", " @torch.no_grad()\n", " def generate(\n", " self,\n", " seed_tokens: Sequence[int] | torch.Tensor,\n", " max_new_tokens: int,\n", " temperature: float = 0.95,\n", " top_p: float = 0.92,\n", " top_k: int = 32,\n", " repetition_penalty: float = 1.15,\n", " repetition_window: int = 64,\n", " min_tokens_to_keep: int = 1,\n", " seed_onset_times: Sequence[float] | torch.Tensor | None = None,\n", " step_seconds: float = 0.1,\n", " token_id_to_events: Any = None,\n", " ) -> List[int]:\n", " self.eval()\n", " device = self._prepare_generation_device()\n", "\n", " tokens = self._to_seed_tensor(seed_tokens, device=device)\n", " if seed_onset_times is None:\n", " onsets = (\n", " torch.arange(tokens.shape[1], device=device, dtype=torch.float32)\n", " * float(max(1e-4, step_seconds))\n", " ).unsqueeze(0)\n", " else:\n", " if isinstance(seed_onset_times, torch.Tensor):\n", " on = seed_onset_times\n", " if on.ndim == 1:\n", " on = on.unsqueeze(0)\n", " onsets = on.to(device=device, dtype=torch.float32)\n", " else:\n", " onsets = torch.tensor([float(v) for v in seed_onset_times], 
@torch.no_grad()
def generate(
    self,
    seed_tokens: Sequence[int] | torch.Tensor,
    max_new_tokens: int,
    temperature: float = 0.95,
    top_p: float = 0.92,
    top_k: int = 32,
    repetition_penalty: float = 1.15,
    repetition_window: int = 64,
    min_tokens_to_keep: int = 1,
    seed_onset_times: Sequence[float] | torch.Tensor | None = None,
    step_seconds: float = 0.1,
    token_id_to_events: Any = None,
) -> List[int]:
    """Autoregressively extend ``seed_tokens`` by ``max_new_tokens`` tokens.

    Each step re-runs ``forward`` on the most recent ``max_sequence_length``
    tokens, restricts the last-position logits to the ids allowed for the next
    slot, and samples one token with ``sample_next_token``.

    Args:
        seed_tokens: Seed token ids; converted by ``self._to_seed_tensor``.
        max_new_tokens: Number of sampling steps.
        temperature, top_p, top_k, repetition_penalty: Forwarded unchanged to
            ``sample_next_token``.
        repetition_window: Forwarded as ``recent_window``.
        min_tokens_to_keep: Forwarded after being raised to at least 4.
        seed_onset_times: Per-token onset times matching the seed's shape. When
            omitted, onsets are ``index * max(1e-4, step_seconds)``.
        step_seconds: Fallback step used for synthetic seed onsets and passed
            to ``_delta_from_token_events`` as its default.
        token_id_to_events: Optional callable passed to
            ``_delta_from_token_events`` to decode a token id into events.

    Returns:
        The seed followed by the generated ids, taken from the first row of the
        token tensor.

    Raises:
        ValueError: If the onset tensor's shape differs from the token tensor's.

    Side effects:
        Puts the module in eval mode and stores sampling diagnostics in
        ``self.last_generation_stats``.
    """
    self.eval()
    device = self._prepare_generation_device()

    # NOTE(review): tokens is indexed as [batch, time] below; presumably
    # _to_seed_tensor returns shape (1, seed_len) -- confirm against the helper.
    tokens = self._to_seed_tensor(seed_tokens, device=device)
    if seed_onset_times is None:
        # No timing supplied: place seed tokens on a uniform grid.
        onsets = (
            torch.arange(tokens.shape[1], device=device, dtype=torch.float32)
            * float(max(1e-4, step_seconds))
        ).unsqueeze(0)
    else:
        if isinstance(seed_onset_times, torch.Tensor):
            on = seed_onset_times
            if on.ndim == 1:
                on = on.unsqueeze(0)
            onsets = on.to(device=device, dtype=torch.float32)
        else:
            onsets = torch.tensor([float(v) for v in seed_onset_times], dtype=torch.float32, device=device).unsqueeze(0)

    if onsets.shape != tokens.shape:
        raise ValueError("seed_onset_times shape must match seed token shape")

    # Per-step sampling diagnostics, averaged into last_generation_stats below.
    final_top1_probs: List[float] = []
    raw_top1_probs: List[float] = []
    candidate_counts: List[int] = []

    for _ in range(int(max_new_tokens)):
        # Sliding window: only the most recent max_sequence_length tokens are fed
        # to the model; context_offset is how many earlier tokens were dropped.
        context_tokens = tokens[:, -self.max_sequence_length:]
        context_onsets = onsets[:, -self.max_sequence_length:]
        context_offset = max(0, int(tokens.shape[1] - context_tokens.shape[1]))

        logits, _ = self.forward(
            token_ids=context_tokens,
            onset_times=context_onsets,
            memory=None,
            return_memory=True,
            position_offset=context_offset,
        )

        # The slot of the token about to be sampled is derived from the total
        # sequence length so far, not from the windowed context.
        next_slot = self._triplet_slot(int(tokens.shape[1]))
        masked_logits = self._mask_logits_to_triplet_slot(logits[:, -1, :], next_slot)
        next_token, diagnostics = sample_next_token(
            logits=masked_logits,
            context_tokens=context_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            recent_window=repetition_window,
            # Caller values below 4 are overridden here.
            min_tokens_to_keep=max(4, min_tokens_to_keep),
        )

        final_top1_probs.extend([float(v) for v in diagnostics.final_top1_prob.detach().cpu().tolist()])
        raw_top1_probs.extend([float(v) for v in diagnostics.raw_top1_prob.detach().cpu().tolist()])
        candidate_counts.extend([int(v) for v in diagnostics.candidate_count.detach().cpu().tolist()])

        tokens = torch.cat([tokens, next_token], dim=1)
        # Slot of the token just appended (same index that next_slot was computed for).
        slot = self._triplet_slot(int(tokens.shape[1] - 1))
        # Only the first sampled element is inspected for timing.
        tok = int(next_token.view(-1)[0].item())
        delta = float(max(1e-4, step_seconds))
        if slot == 0:
            delta = self._delta_from_token_events(tok, token_id_to_events, step_seconds)
        # Only slot-0 tokens advance the onset clock; other slots repeat the
        # previous onset.
        next_onset = onsets[:, -1:] + (delta if slot == 0 else 0.0)
        onsets = torch.cat([onsets, next_onset], dim=1)

    # max(1, len(...)) guards the division when no steps were run.
    self.last_generation_stats = {
        "steps": int(max_new_tokens),
        "mean_final_top1_prob": float(sum(final_top1_probs) / max(1, len(final_top1_probs))),
        "mean_raw_top1_prob": float(sum(raw_top1_probs) / max(1, len(raw_top1_probs))),
        "mean_candidate_count": float(sum(candidate_counts) / max(1, len(candidate_counts))),
    }
    return [int(t) for t in tokens[0].tolist()]
def _filtered_config_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return the entries of ``payload`` whose keys are ``VariantEConfig`` fields.

    Unknown keys are dropped so the result can be splatted into
    ``VariantEConfig(**...)`` without raising on unexpected arguments.
    """
    allowed = {field.name for field in fields(VariantEConfig)}
    return {key: value for key, value in dict(payload).items() if key in allowed}


def _decode_metadata_object(metadata: Dict[str, Any], key: str, source: Any) -> Dict[str, Any]:
    """Decode one JSON-encoded config entry from safetensors header metadata.

    A missing, blank, or falsy entry (for example JSON ``null``) yields ``{}``.
    Any other value that is not a JSON object raises ``RuntimeError``, so a
    malformed header fails here rather than later at ``payload.get(...)``.
    """
    raw = metadata.get(key, "{}")
    decoded = json.loads(raw) if isinstance(raw, str) and raw.strip() else {}
    decoded = decoded or {}
    if not isinstance(decoded, dict):
        raise RuntimeError(
            f"Unsupported {key} metadata in {source}: expected a JSON object, "
            f"got {type(decoded).__name__}"
        )
    return decoded


def _load_checkpoint_payload() -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Load the ``(model_config, data_config)`` dictionaries for the checkpoint.

    ``MODEL_DIR / "latest_state.pt"`` is preferred when it exists; otherwise the
    configs are read from the JSON strings stored in the header metadata of
    ``MODEL_DIR / "latest.safetensors"``. Missing entries become empty dicts.

    Raises:
        RuntimeError: If the state file does not contain a dict, or a metadata
            entry decodes to something other than a JSON object.
        FileNotFoundError: If neither file exists.
    """
    state_path = MODEL_DIR / "latest_state.pt"
    weights_path = MODEL_DIR / "latest.safetensors"

    if state_path.exists():
        # SECURITY NOTE(review): torch.load unpickles the file, and this file is
        # downloaded from a remote repo. On PyTorch versions where weights_only
        # defaults to False this can execute arbitrary code; only point the
        # notebook at repos you trust, or consider weights_only=True.
        state = torch.load(state_path, map_location="cpu")
        if not isinstance(state, dict):
            raise RuntimeError(f"Unsupported checkpoint state: {state_path}")
        return dict(state.get("model_config") or {}), dict(state.get("data_config") or {})

    if weights_path.exists():
        with safetensors_safe_open(str(weights_path), framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}
        model_payload = _decode_metadata_object(metadata, "model_config", weights_path)
        data_payload = _decode_metadata_object(metadata, "data_config", weights_path)
        return model_payload, data_payload

    raise FileNotFoundError("No checkpoint state or safetensors weights found in the model directory.")
isinstance(tok_emb, torch.Tensor) and tok_emb.ndim == 2:\n", " inferred[\"vocab_size\"] = int(tok_emb.shape[0])\n", " inferred[\"d_model\"] = int(tok_emb.shape[1])\n", "\n", " pos_emb = state_dict.get(\"position_embedding.weight\")\n", " if isinstance(pos_emb, torch.Tensor) and pos_emb.ndim == 2:\n", " inferred[\"max_sequence_length\"] = int(pos_emb.shape[0])\n", "\n", " if \"time_encoding.output_scale\" in state_dict or \"time_encoding.projection.weight\" in state_dict:\n", " inferred[\"use_continuous_time\"] = True\n", "\n", " gdn_in = state_dict.get(\"layers.0.gdn_in_proj.weight\")\n", " if isinstance(gdn_in, torch.Tensor) and gdn_in.ndim == 2:\n", " inferred.setdefault(\"d_model\", int(gdn_in.shape[1]))\n", " if int(gdn_in.shape[1]) > 0:\n", " inferred[\"gdn_path_ratio\"] = float(gdn_in.shape[0] / gdn_in.shape[1])\n", "\n", " gdn_core_q = state_dict.get(\"layers.0.gdn.core.q_proj.weight\")\n", " if isinstance(gdn_core_q, torch.Tensor) and gdn_core_q.ndim == 2:\n", " inferred[\"gdn_inner_dim\"] = int(gdn_core_q.shape[0])\n", "\n", " gdn_a_log = state_dict.get(\"layers.0.gdn.core.A_log\")\n", " if isinstance(gdn_a_log, torch.Tensor):\n", " if gdn_a_log.ndim == 2:\n", " inferred[\"gdn_num_heads\"] = int(gdn_a_log.shape[0])\n", " inferred[\"gdn_inner_dim\"] = int(gdn_a_log.shape[0] * gdn_a_log.shape[1])\n", " elif gdn_a_log.ndim == 1:\n", " inferred[\"gdn_num_heads\"] = int(max(1, gdn_a_log.shape[0]))\n", "\n", " gqa_q = state_dict.get(\"layers.0.gqa.q_proj.weight\")\n", " gqa_k = state_dict.get(\"layers.0.gqa.k_proj.weight\")\n", " if (\n", " isinstance(gqa_q, torch.Tensor)\n", " and isinstance(gqa_k, torch.Tensor)\n", " and gqa_q.ndim == 2\n", " and gqa_k.ndim == 2\n", " and int(gqa_k.shape[0]) > 0\n", " ):\n", " inferred[\"gqa_groups\"] = int(max(1, int(gqa_q.shape[0] // gqa_k.shape[0])))\n", "\n", " layer_ids: List[int] = []\n", " attn_layer_ids: List[int] = []\n", " for key in state_dict.keys():\n", " if not key.startswith(\"layers.\"):\n", " continue\n", " 
parts = key.split(\".\")\n", " if len(parts) < 3 or not parts[1].isdigit():\n", " continue\n", " layer_id = int(parts[1])\n", " layer_ids.append(layer_id)\n", " if \".gqa.\" in key:\n", " attn_layer_ids.append(layer_id)\n", "\n", " if layer_ids:\n", " inferred[\"n_layers\"] = int(max(layer_ids) + 1)\n", "\n", " attn_unique = sorted(set(attn_layer_ids))\n", " if len(attn_unique) >= 2:\n", " diffs = [int(b - a) for a, b in zip(attn_unique[:-1], attn_unique[1:]) if int(b - a) > 0]\n", " if diffs:\n", " inferred[\"attention_every_n_layers\"] = int(max(1, min(diffs)))\n", "\n", " return inferred\n", "\n", "\n", "if \"MIDI_INPUT_FILES\" not in globals() or not MIDI_INPUT_FILES:\n", " raise RuntimeError(\"Run Cell 3 first so MIDI_INPUT_FILES is populated.\")\n", "\n", "model_payload, data_payload = _load_checkpoint_payload()\n", "weights_path = MODEL_DIR / \"latest.safetensors\"\n", "if not weights_path.exists():\n", " raise FileNotFoundError(f\"Model weights not found: {weights_path}\")\n", "\n", "state_dict = safetensors_load_file(str(weights_path), device=\"cpu\")\n", "state_dict = _strip_dataparallel_prefix(state_dict)\n", "inferred_payload = _infer_config_from_state_dict(state_dict)\n", "tokenizer_path = TOKENIZER_DIR / \"custom_tokenizer.json\"\n", "if not tokenizer_path.exists():\n", " tokenizer_path = TOKENIZER_DIR / \"tokenizer.json\"\n", "if not tokenizer_path.exists():\n", " raise FileNotFoundError(\"No tokenizer JSON found in the tokenizer directory.\")\n", "\n", "tokenizer = CustomDeltaTokenizer.load(str(tokenizer_path))\n", "cfg_payload = {**_filtered_config_payload(model_payload), **inferred_payload}\n", "if inferred_payload:\n", " print(f\"Inferred model config from weights: {inferred_payload}\")\n", "config = VariantEConfig(**cfg_payload)\n", "checkpoint_vocab_size = int(inferred_payload.get(\"vocab_size\", int(tokenizer.vocab_size)))\n", "if int(tokenizer.vocab_size) != checkpoint_vocab_size:\n", " raise RuntimeError(\n", " \"Tokenizer/checkpoint 
# --- Load checkpoint, tokenizer and model; derive sampling settings ----------
# This section relies on globals created by earlier cells (MIDI_INPUT_FILES,
# MODEL_DIR, TOKENIZER_DIR, OUTPUT_DIR, GDN_AVAILABLE) and on definitions made
# earlier in this cell.
if "MIDI_INPUT_FILES" not in globals() or not MIDI_INPUT_FILES:
    raise RuntimeError("Run Cell 3 first so MIDI_INPUT_FILES is populated.")

model_payload, data_payload = _load_checkpoint_payload()
weights_path = MODEL_DIR / "latest.safetensors"
if not weights_path.exists():
    raise FileNotFoundError(f"Model weights not found: {weights_path}")

state_dict = safetensors_load_file(str(weights_path), device="cpu")
state_dict = _strip_dataparallel_prefix(state_dict)
inferred_payload = _infer_config_from_state_dict(state_dict)
# Prefer the custom tokenizer file; fall back to the generic file name.
tokenizer_path = TOKENIZER_DIR / "custom_tokenizer.json"
if not tokenizer_path.exists():
    tokenizer_path = TOKENIZER_DIR / "tokenizer.json"
if not tokenizer_path.exists():
    raise FileNotFoundError("No tokenizer JSON found in the tokenizer directory.")

tokenizer = CustomDeltaTokenizer.load(str(tokenizer_path))
# Values inferred from the weights come last in the merge, so they override
# whatever the checkpoint metadata declared.
cfg_payload = {**_filtered_config_payload(model_payload), **inferred_payload}
if inferred_payload:
    print(f"Inferred model config from weights: {inferred_payload}")
config = VariantEConfig(**cfg_payload)
# If the weights did not reveal a vocab size, the tokenizer's size is used, which
# makes the mismatch check below pass trivially.
checkpoint_vocab_size = int(inferred_payload.get("vocab_size", int(tokenizer.vocab_size)))
if int(tokenizer.vocab_size) != checkpoint_vocab_size:
    raise RuntimeError(
        "Tokenizer/checkpoint mismatch: "
        f"tokenizer vocab={int(tokenizer.vocab_size)} vs checkpoint vocab={checkpoint_vocab_size}. "
        "Use the tokenizer bundled with this checkpoint repo."
    )
# Rebuild the config with the vocab size pinned to the checkpoint's value.
config = VariantEConfig(**{**config.__dict__, "vocab_size": checkpoint_vocab_size})
model = VariantEModel(config)
try:
    # strict=False so key mismatches are returned and reported below rather than
    # raised immediately.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
except RuntimeError as exc:
    # NOTE(review): presumably a tensor shape mismatch; when the GDN kernels are
    # unavailable the error is replaced with install guidance -- confirm.
    if not GDN_AVAILABLE:
        raise RuntimeError(
            "Checkpoint/model mismatch detected and flash-linear-attention is unavailable. "
            "This Variant E checkpoint needs real GatedDeltaNet kernels. "
            "Install flash-linear-attention in Colab and rerun from Cell 2."
        ) from exc
    raise

# Any missing or unexpected parameter name is treated as fatal; the first eight
# names of each kind are included in the error message.
if missing_keys or unexpected_keys:
    missing_preview = ", ".join(missing_keys[:8]) if missing_keys else "none"
    unexpected_preview = ", ".join(unexpected_keys[:8]) if unexpected_keys else "none"
    if not GDN_AVAILABLE:
        raise RuntimeError(
            "Checkpoint/model mismatch detected and flash-linear-attention is unavailable. "
            "This Variant E checkpoint needs real GatedDeltaNet kernels. "
            "Install flash-linear-attention in Colab and rerun from Cell 2. "
            f"missing={len(missing_keys)} ({missing_preview}) | "
            f"unexpected={len(unexpected_keys)} ({unexpected_preview})"
        )
    raise RuntimeError(
        "Checkpoint/model mismatch detected. "
        f"missing={len(missing_keys)} ({missing_preview}) | "
        f"unexpected={len(unexpected_keys)} ({unexpected_preview})"
    )

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(DEVICE)
model.eval()

# Defaults come from the checkpoint's data config; missing or falsy values fall
# back to 512 seed tokens.
trained_seed_length = max(1, int(data_payload.get("seed_length", 512) or 512))
trained_continuation_length = int(data_payload.get("continuation_length", 0) or 0)
if trained_continuation_length <= 0:
    # No explicit continuation length: use max_sequence_length minus the seed
    # length when available, otherwise 2048.
    max_sequence_length = int(data_payload.get("max_sequence_length", 0) or 0)
    trained_continuation_length = max(1, max_sequence_length - trained_seed_length) if max_sequence_length > 0 else 2048

# Sampling settings can be overridden through IBP_* environment variables; each
# value is clamped to the bound shown.
SEED_LENGTH = max(4, int(os.environ.get("IBP_SEED_LENGTH", str(trained_seed_length))))
MAX_NEW_TOKENS = max(1, int(os.environ.get("IBP_MAX_NEW_TOKENS", str(trained_continuation_length))))
TEMPERATURE = max(0.1, float(os.environ.get("IBP_TEMPERATURE", "0.90")))
TOP_P = min(1.0, max(0.0, float(os.environ.get("IBP_TOP_P", "0.95"))))
TOP_K = max(8, int(os.environ.get("IBP_TOP_K", "64")))
REPETITION_PENALTY = max(1.0, float(os.environ.get("IBP_REPETITION_PENALTY", "1.10")))
REPETITION_WINDOW = max(16, int(os.environ.get("IBP_REPETITION_WINDOW", "64")))
MIN_TOKENS_TO_KEEP = max(3, int(os.environ.get("IBP_MIN_TOKENS_TO_KEEP", "4")))
STEP_SECONDS = float(data_payload.get("time_feature_fallback_step_seconds", 0.1) or 0.1)

BATCH_OUTPUT_DIR = OUTPUT_DIR / "batch_continuations"
BATCH_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Generation device: {DEVICE}")
print(f"Tokenizer: {tokenizer_path.name} | vocab={tokenizer.vocab_size}")
print(f"GDN kernel mode: {'flash-linear-attention' if GDN_AVAILABLE else 'fallback'}")
print(f"Songs queued: {len(MIDI_INPUT_FILES)}")
print(f"Seed length: {SEED_LENGTH} | Max new tokens: {MAX_NEW_TOKENS}")
"def _safe_stem(path: Path) -> str:\n", " stem = re.sub(r\"[^A-Za-z0-9._-]+\", \"_\", path.stem).strip(\"._\")\n", " return stem or \"song\"\n", "\n", "\n", "def _trim_seed(token_ids: List[int], onset_times: List[float], seed_length: int, event_size: int) -> Tuple[List[int], List[float]]:\n", " take = min(len(token_ids), int(seed_length))\n", " aligned = take - (take % event_size)\n", " take = aligned if aligned > 0 else min(len(token_ids), event_size)\n", " return token_ids[-take:], onset_times[-take:]\n", "\n", "\n", "BATCH_RESULTS: List[Dict[str, Any]] = []\n", "total_files = len(MIDI_INPUT_FILES)\n", "batch_start = time.time()\n", "\n", "for index, midi_path in enumerate(MIDI_INPUT_FILES, start=1):\n", " song_start = time.time()\n", " safe_name = _safe_stem(midi_path)\n", " prefix = f\"{index:03d}_{safe_name}\"\n", "\n", " out_midi = BATCH_OUTPUT_DIR / f\"{prefix}_continuation.mid\"\n", " seed_audio = BATCH_OUTPUT_DIR / f\"{prefix}_seed.wav\"\n", " out_audio = BATCH_OUTPUT_DIR / f\"{prefix}_continuation.wav\"\n", " compare_png = BATCH_OUTPUT_DIR / f\"{prefix}_comparison.png\"\n", "\n", " print(f\"\\n[{index}/{total_files}] Starting: {midi_path.name}\")\n", " try:\n", " seed_tokens, seed_onsets, _ = tokenizer.encode_with_time_features(midi_path)\n", " if not seed_tokens:\n", " raise RuntimeError(\"Tokenizer produced no seed tokens for this MIDI file.\")\n", "\n", " seed_tokens, seed_onsets = _trim_seed(\n", " seed_tokens,\n", " seed_onsets,\n", " seed_length=SEED_LENGTH,\n", " event_size=max(1, int(getattr(tokenizer, \"event_size\", 4) or 4)),\n", " )\n", "\n", " with torch.inference_mode():\n", " generated_tokens = model.generate(\n", " seed_tokens=seed_tokens,\n", " max_new_tokens=MAX_NEW_TOKENS,\n", " temperature=TEMPERATURE,\n", " top_p=TOP_P,\n", " top_k=TOP_K,\n", " repetition_penalty=REPETITION_PENALTY,\n", " repetition_window=REPETITION_WINDOW,\n", " min_tokens_to_keep=MIN_TOKENS_TO_KEEP,\n", " seed_onset_times=seed_onsets,\n", " 
# --- Batch continuation: one generation pass per queued MIDI file ------------
# One record per input file, in processing order; "status" is "ok" or "error".
BATCH_RESULTS: List[Dict[str, Any]] = []
total_files = len(MIDI_INPUT_FILES)
batch_start = time.time()

for index, midi_path in enumerate(MIDI_INPUT_FILES, start=1):
    song_start = time.time()
    safe_name = _safe_stem(midi_path)
    # The zero-padded index keeps outputs in input order and avoids collisions
    # when two files sanitize to the same stem.
    prefix = f"{index:03d}_{safe_name}"

    out_midi = BATCH_OUTPUT_DIR / f"{prefix}_continuation.mid"
    seed_audio = BATCH_OUTPUT_DIR / f"{prefix}_seed.wav"
    out_audio = BATCH_OUTPUT_DIR / f"{prefix}_continuation.wav"
    compare_png = BATCH_OUTPUT_DIR / f"{prefix}_comparison.png"

    print(f"\n[{index}/{total_files}] Starting: {midi_path.name}")
    try:
        seed_tokens, seed_onsets, _ = tokenizer.encode_with_time_features(midi_path)
        if not seed_tokens:
            raise RuntimeError("Tokenizer produced no seed tokens for this MIDI file.")

        # Keep only the tail of the song, cut to whole events; event size
        # defaults to 4 when the tokenizer does not expose one.
        seed_tokens, seed_onsets = _trim_seed(
            seed_tokens,
            seed_onsets,
            seed_length=SEED_LENGTH,
            event_size=max(1, int(getattr(tokenizer, "event_size", 4) or 4)),
        )

        with torch.inference_mode():
            generated_tokens = model.generate(
                seed_tokens=seed_tokens,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                top_k=TOP_K,
                repetition_penalty=REPETITION_PENALTY,
                repetition_window=REPETITION_WINDOW,
                min_tokens_to_keep=MIN_TOKENS_TO_KEEP,
                seed_onset_times=seed_onsets,
                step_seconds=STEP_SECONDS,
                token_id_to_events=tokenizer.decode_token_id_events,
            )

        # generated_tokens contains the trimmed seed followed by the new tokens,
        # so the output MIDI includes the seed portion.
        tokenizer.decode(generated_tokens, out_midi)
        # NOTE(review): the "seed" audio and the comparison plot use the full
        # input MIDI, not the trimmed seed that was actually fed to the model.
        render_midi_audio(midi_path, seed_audio)
        render_midi_audio(out_midi, out_audio)
        compare_pianorolls(midi_path, out_midi, save_path=compare_png)

        song_seconds = float(time.time() - song_start)
        stats = dict(getattr(model, "last_generation_stats", {}))

        print(f"[{index}/{total_files}] Finished: {midi_path.name} ({song_seconds:.1f}s)")
        print(f" output midi: {out_midi}")
        print(f" seed audio: {seed_audio}")
        print(f" continuation audio: {out_audio}")
        print(f" comparison png: {compare_png}")
        print(f" generation stats: {stats}")

        # display/Audio/Image are compared against None here; presumably an
        # earlier cell sets them to None when IPython is unavailable -- confirm.
        if display is not None and Audio is not None and Image is not None:
            display(Audio(filename=str(seed_audio)))
            display(Audio(filename=str(out_audio)))
            display(Image(filename=str(compare_png)))
        else:
            print(" IPython display unavailable; skipping inline audio/image display.")

        BATCH_RESULTS.append(
            {
                "index": index,
                "input_midi": str(midi_path),
                "output_midi": str(out_midi),
                "seed_audio": str(seed_audio),
                "continuation_audio": str(out_audio),
                "comparison_png": str(compare_png),
                "status": "ok",
                "elapsed_seconds": round(song_seconds, 3),
                "generated_tokens": int(len(generated_tokens)),
                "stats": stats,
            }
        )

    except Exception as exc:
        # Deliberately broad: one bad file is recorded as an error and the batch
        # moves on to the next file.
        song_seconds = float(time.time() - song_start)
        print(f"[{index}/{total_files}] FAILED: {midi_path.name} ({song_seconds:.1f}s)")
        print(f" error: {exc}")
        BATCH_RESULTS.append(
            {
                "index": index,
                "input_midi": str(midi_path),
                "status": "error",
                "elapsed_seconds": round(song_seconds, 3),
                "error": str(exc),
            }
        )


# --- Batch summary ------------------------------------------------------------
success_count = sum(1 for item in BATCH_RESULTS if item.get("status") == "ok")
fail_count = len(BATCH_RESULTS) - success_count
elapsed_total = float(time.time() - batch_start)

print("\nBatch continuation complete.")
print(f"Succeeded: {success_count} | Failed: {fail_count} | Total: {len(BATCH_RESULTS)}")
print(f"Elapsed: {elapsed_total:.1f}s")
print(f"Outputs: {BATCH_OUTPUT_DIR}")

if fail_count:
    print("Failed files:")
    for item in BATCH_RESULTS:
        if item.get("status") != "ok":
            print(f" - {Path(item['input_midi']).name}: {item.get('error', 'unknown error')}")
# --- Write a JSON manifest of the batch results and optionally zip outputs ----
# Requires BATCH_RESULTS (built by the continuation cell) and OUTPUT_DIR.
if "BATCH_RESULTS" not in globals():
    raise RuntimeError("Run Cell 6 first to generate continuations.")

batch_dir = OUTPUT_DIR / "batch_continuations"
batch_dir.mkdir(parents=True, exist_ok=True)

# The manifest is the BATCH_RESULTS list serialized as-is (one object per song).
manifest_path = batch_dir / "batch_manifest.json"
manifest_path.write_text(json.dumps(BATCH_RESULTS, indent=2), encoding="utf-8")

ok_items = [item for item in BATCH_RESULTS if item.get("status") == "ok"]
err_items = [item for item in BATCH_RESULTS if item.get("status") != "ok"]

print(f"Manifest written: {manifest_path}")
print(f"Successful songs: {len(ok_items)}")
print(f"Failed songs: {len(err_items)}")

if err_items:
    print("Failures:")
    for item in err_items:
        print(f" - {Path(item['input_midi']).name}: {item.get('error', 'unknown error')}")

# Set to True to bundle everything under batch_dir into a single archive.
ZIP_OUTPUTS = False
if ZIP_OUTPUTS:
    # The archive is written next to batch_dir (not inside it), so it is never
    # picked up by the rglob below.
    zip_path = OUTPUT_DIR / "batch_continuations.zip"
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for file_path in sorted(batch_dir.rglob("*")):
            if file_path.is_file():
                # Store paths relative to batch_dir so the archive has no
                # absolute prefixes.
                zf.write(file_path, arcname=str(file_path.relative_to(batch_dir)))
    print(f"Created zip: {zip_path}")