trysem commited on
Commit
4d2241c
·
verified ·
1 Parent(s): 439306b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -34
app.py CHANGED
@@ -1,12 +1,10 @@
1
  import subprocess
2
  import sys
3
  import os
4
- import shutil
5
  from huggingface_hub import hf_hub_download
6
 
7
- repo_id = "Praha-Labs/PrahaTTS-ML"
8
-
9
- # --- 1. PRE-FLIGHT: INSTALLS AND MEMORY PATCHING ---
10
  def pre_flight_setup():
11
  try:
12
  import chatterbox
@@ -16,49 +14,24 @@ def pre_flight_setup():
16
  subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "pkuseg==0.0.25"])
17
  subprocess.check_call([sys.executable, "-m", "pip", "install", "chatterbox-tts>=0.1.7"])
18
 
19
- print("Downloading and applying Indic patch...")
20
- config_path = hf_hub_download(repo_id=repo_id, filename="config_indic.py")
21
- shutil.copy(config_path, "config_indic.py")
22
-
23
- # CRITICAL FIX: Purge chatterbox from Python's memory cache!
24
- # This forces the library to completely reload *after* our patch is applied.
25
- modules_to_remove = [mod for mod in sys.modules if mod.startswith("chatterbox")]
26
- for mod in modules_to_remove:
27
- del sys.modules[mod]
28
-
29
- import config_indic
30
- if hasattr(config_indic, 'apply_config'):
31
- try:
32
- config_indic.apply_config()
33
- except TypeError:
34
- pass
35
-
36
- # Run the setup immediately
37
  pre_flight_setup()
38
 
39
- # --- 2. NOW WE IMPORT CHATTERBOX (CLEAN RELOAD) ---
40
  import gradio as gr
41
  import torch
 
42
  import torchaudio as ta
43
  from peft import PeftModel
44
- import tempfile
45
 
46
  from chatterbox.tts import ChatterboxTTS
47
  from chatterbox.models.tokenizers import EnTokenizer
48
 
49
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
50
 
51
  def load_model():
52
  print(f"Loading base Chatterbox model on {device}...")
53
  model = ChatterboxTTS.from_pretrained(device=device)
54
-
55
- import config_indic
56
- # Sometimes the config requires the model object itself to resize layers
57
- if hasattr(config_indic, 'apply_config'):
58
- try:
59
- config_indic.apply_config(model)
60
- except TypeError:
61
- pass
62
 
63
  print("Applying custom Indic tokenizer...")
64
  try:
@@ -67,10 +40,32 @@ def load_model():
67
  except Exception as e:
68
  print(f"Error during tokenizer inject: {e}")
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  print("Loading LoRA adapter weights...")
71
  try:
72
- # We wrap the ENTIRE model now, ensuring any saved embedding layers are swapped
73
- model = PeftModel.from_pretrained(model, repo_id)
 
 
74
  print("LoRA adapter loaded successfully.")
75
  except Exception as e:
76
  print(f"Failed to load PEFT adapter: {e}")
 
1
  import subprocess
2
  import sys
3
  import os
4
+ import tempfile
5
  from huggingface_hub import hf_hub_download
6
 
7
+ # --- 1. PRE-FLIGHT: BYPASS BUILD ISOLATION ---
 
 
8
  def pre_flight_setup():
9
  try:
10
  import chatterbox
 
14
  subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-build-isolation", "pkuseg==0.0.25"])
15
  subprocess.check_call([sys.executable, "-m", "pip", "install", "chatterbox-tts>=0.1.7"])
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  pre_flight_setup()
18
 
19
+ # --- 2. MAIN APPLICATION ---
20
  import gradio as gr
21
  import torch
22
+ import torch.nn as nn
23
  import torchaudio as ta
24
  from peft import PeftModel
 
25
 
26
  from chatterbox.tts import ChatterboxTTS
27
  from chatterbox.models.tokenizers import EnTokenizer
28
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
+ repo_id = "Praha-Labs/PrahaTTS-ML"
31
 
32
  def load_model():
33
  print(f"Loading base Chatterbox model on {device}...")
34
  model = ChatterboxTTS.from_pretrained(device=device)
 
 
 
 
 
 
 
 
35
 
36
  print("Applying custom Indic tokenizer...")
37
  try:
 
40
  except Exception as e:
41
  print(f"Error during tokenizer inject: {e}")
42
 
43
+ # --- CRITICAL FIX: MANUALLY RESIZE PYTORCH EMBEDDINGS ---
44
+ # We must resize the base model's vocabulary layers to match the new
45
+ # Malayalam vocab size (2573) before loading the adapter weights.
46
+ vocab_size = 2573
47
+ print(f"Resizing base model embeddings to handle vocab size of {vocab_size}...")
48
+
49
+ target_layer = model.t3 if hasattr(model, 't3') else model
50
+
51
+ if hasattr(target_layer, 'text_emb'):
52
+ embed_dim = target_layer.text_emb.embedding_dim
53
+ target_layer.text_emb = nn.Embedding(vocab_size, embed_dim)
54
+
55
+ if hasattr(target_layer, 'text_head'):
56
+ in_features = target_layer.text_head.in_features
57
+ has_bias = target_layer.text_head.bias is not None
58
+ target_layer.text_head = nn.Linear(in_features, vocab_size, bias=has_bias)
59
+
60
+ # Send resized layers to the correct device
61
+ target_layer.to(device)
62
+
63
  print("Loading LoRA adapter weights...")
64
  try:
65
+ if hasattr(model, 't3'):
66
+ model.t3 = PeftModel.from_pretrained(model.t3, repo_id)
67
+ else:
68
+ model = PeftModel.from_pretrained(model, repo_id)
69
  print("LoRA adapter loaded successfully.")
70
  except Exception as e:
71
  print(f"Failed to load PEFT adapter: {e}")