matthewliu0302 committed on
Commit
bfc097c
·
1 Parent(s): c7a5b5f

update chute_config install

Browse files
Files changed (2) hide show
  1. chute_config.yml +1 -1
  2. vocence_local_wrapper.py +133 -0
chute_config.yml CHANGED
@@ -4,7 +4,7 @@
4
  Image:
5
  from_base: parachutes/base-python:3.12.9
6
  run_command:
7
- - pip install torch torchaudio transformers accelerate huggingface_hub pyyaml soundfile
8
  set_workdir: /app
9
 
10
  NodeSelector:
 
4
  Image:
5
  from_base: parachutes/base-python:3.12.9
6
  run_command:
7
+ - pip install torch torchaudio transformers accelerate huggingface_hub pyyaml soundfile snac
8
  set_workdir: /app
9
 
10
  NodeSelector:
vocence_local_wrapper.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Local Vocence wrapper for testing miner.py without Chutes.
4
+
5
+ Run:
6
+ python vocence_local_wrapper.py
7
+
8
+ Then call:
9
+ GET http://127.0.0.1:8000/health
10
+ POST http://127.0.0.1:8000/speak
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import io
15
+ import wave
16
+ from pathlib import Path
17
+ from typing import Optional
18
+
19
+ import numpy as np
20
+ import uvicorn
21
+ from fastapi import FastAPI, HTTPException, status
22
+ from fastapi.responses import Response
23
+ from pydantic import BaseModel, Field
24
+ from yaml import safe_load
25
+
26
+ from miner import Miner
27
+
28
# Request-size limits enforced by the VocenceSpeakRequest field validators.
VOCENCE_MAX_INSTRUCTION_LEN = 600
VOCENCE_MAX_TEXT_LEN = 2000

# Generated audio longer than this many seconds is rejected by /speak.
VOCENCE_MAX_AUDIO_SECONDS = 30
31
+
32
+
33
class VocenceSpeakRequest(BaseModel):
    # Body of POST /speak. Both fields are required and must be non-empty;
    # the upper bounds come from the VOCENCE_MAX_* module constants.
    instruction: str = Field(
        ...,
        min_length=1,
        max_length=VOCENCE_MAX_INSTRUCTION_LEN,
    )
    text: str = Field(
        ...,
        min_length=1,
        max_length=VOCENCE_MAX_TEXT_LEN,
    )
36
+
37
+
38
class VocenceHealthResponse(BaseModel):
    # Payload returned by GET /health.

    # pydantic v2 reserves the "model_" attribute prefix for its own API and
    # warns when a field such as `model_loaded` uses it; the field name is
    # part of the response contract, so lift the restriction instead of
    # renaming the field.
    model_config = {"protected_namespaces": ()}

    # Startup outcome: "unknown", "healthy" or "startup_failed: <error>".
    status: str
    # True when a TTS engine instance is held in app.state.
    model_loaded: bool
    # Values read from vocence_config.yaml (or defaults) once startup succeeds.
    sample_rate: Optional[int] = None
    adapter: Optional[str] = None
    # Directory the wrapper was started from, as a string.
    repo_path: str
44
+
45
+
46
def waveform_to_wav_bytes(waveform: np.ndarray, sample_rate: int) -> bytes:
    """Encode a mono waveform as a 16-bit PCM WAV file held in memory.

    Args:
        waveform: One-dimensional array (or array-like) of samples. ``int16``
            input is written unchanged. Any other dtype is treated as
            floating-point audio: NaN becomes silence, values are clipped to
            ``[-1.0, 1.0]`` and then scaled by 32767.
        sample_rate: Frame rate stored in the WAV header.

    Returns:
        The complete WAV file (header followed by frames) as bytes.

    Raises:
        ValueError: If ``waveform`` is not one-dimensional.
    """
    # Accept lists/tuples as well as arrays; a no-op for ndarray input.
    samples = np.asarray(waveform)
    if samples.ndim != 1:
        raise ValueError("waveform must be 1D mono")

    if samples.dtype == np.int16:
        pcm = samples
    else:
        scaled = samples.astype(np.float32)
        # NaN passes straight through np.clip and has no defined int16
        # value, so replace it with silence before converting.
        scaled = np.nan_to_num(scaled, nan=0.0)
        scaled = np.clip(scaled, -1.0, 1.0)
        pcm = (scaled * 32767.0).astype(np.int16)

    buf = io.BytesIO()
    with wave.open(buf, "wb") as wav:
        wav.setnchannels(1)  # mono
        wav.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        wav.setframerate(sample_rate)
        wav.writeframes(pcm.tobytes())
    return buf.getvalue()
64
+
65
+
66
# Directory that contains this file; handed to Miner and searched for
# vocence_config.yaml during startup.
repo_path = Path(__file__).resolve().parent

app = FastAPI(
    title="Vocence Local Wrapper",
    version="0.1.0",
)
68
+
69
+
70
@app.on_event("startup")
async def startup_event() -> None:
    """Load the TTS engine and read display metadata when the app starts.

    Populates ``app.state.status``, ``app.state.sample_rate``,
    ``app.state.adapter`` and ``app.state.tts_engine``. Any failure is logged
    with its traceback and recorded in ``app.state.status`` as
    ``startup_failed: <error>``; the exception is not re-raised, so the
    handler completes and ``/health`` can report the problem.
    """
    import logging  # local import keeps this handler self-contained

    app.state.status = "unknown"
    app.state.sample_rate = None
    app.state.adapter = None
    app.state.tts_engine = None

    try:
        app.state.tts_engine = Miner(repo_path)
        app.state.tts_engine.warmup()

        cfg = {}
        vocence_yaml = repo_path / "vocence_config.yaml"
        if vocence_yaml.exists():
            with vocence_yaml.open("r", encoding="utf-8") as f:
                loaded = safe_load(f)
            # An empty file or a non-mapping document carries no settings.
            if isinstance(loaded, dict):
                cfg = loaded

        # A bare "generation:" / "runtime:" key parses as None; treat it the
        # same as a missing section instead of failing the whole startup.
        generation = cfg.get("generation") or {}
        runtime = cfg.get("runtime") or {}
        app.state.sample_rate = int(generation.get("sample_rate", 24000))
        app.state.adapter = str(runtime.get("adapter", "unknown"))

        app.state.status = "healthy"
    except Exception as exc:
        # Keep the full traceback somewhere visible; the status string only
        # carries the exception message.
        logging.getLogger(__name__).exception("Vocence wrapper startup failed")
        app.state.status = f"startup_failed: {exc}"
        app.state.tts_engine = None
95
+
96
+
97
@app.get("/health")
async def health() -> dict:
    # Snapshot of what startup recorded on app.state. Every attribute is read
    # with a fallback so the endpoint also answers if startup never ran.
    state = app.state
    engine = getattr(state, "tts_engine", None)
    report = VocenceHealthResponse(
        status=getattr(state, "status", "unknown"),
        model_loaded=engine is not None,
        sample_rate=getattr(state, "sample_rate", None),
        adapter=getattr(state, "adapter", None),
        repo_path=str(repo_path),
    )
    return report.model_dump()
106
+
107
+
108
@app.post("/speak", response_class=Response)
async def speak(args: VocenceSpeakRequest):
    # Synthesize speech for the request and return it as an audio/wav body.
    # Responds 503 when no engine is loaded and 400 when the generated audio
    # is empty, not 1-D, or outside the allowed duration.
    tts = getattr(app.state, "tts_engine", None)
    if tts is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="TTS engine not loaded",
        )

    # NOTE(review): generate_wav is called without await, so the event loop
    # is blocked while it runs -- confirm that is acceptable for local use.
    generated, rate = tts.generate_wav(instruction=args.instruction, text=args.text)
    samples = np.asarray(generated)
    if samples.ndim != 1 or samples.size == 0:
        raise HTTPException(status_code=400, detail="invalid waveform")

    seconds = float(samples.shape[0]) / float(rate)
    too_long = seconds > VOCENCE_MAX_AUDIO_SECONDS
    if seconds <= 0 or too_long:
        raise HTTPException(status_code=400, detail="invalid duration")

    wav_bytes = waveform_to_wav_bytes(samples, rate)
    return Response(content=wav_bytes, media_type="audio/wav")
130
+
131
+
132
if __name__ == "__main__":
    # Serve on the loopback interface only, with auto-reload disabled.
    uvicorn.run(
        "vocence_local_wrapper:app",
        host="127.0.0.1",
        port=8000,
        reload=False,
    )