Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,30 +3,43 @@ import torch
|
|
| 3 |
import torchaudio
|
| 4 |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
| 5 |
|
|
|
|
| 6 |
model_id = "Wiam/baby-cry-classification-finetuned-babycry-v4"
|
| 7 |
model = AutoModelForAudioClassification.from_pretrained(model_id)
|
| 8 |
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
|
| 9 |
|
| 10 |
def classify_baby_cry(audio_file):
|
|
|
|
| 11 |
waveform, sample_rate = torchaudio.load(audio_file)
|
| 12 |
-
|
| 13 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
if waveform.shape[0] > 1:
|
| 15 |
waveform = waveform.mean(dim=0, keepdim=True)
|
| 16 |
-
|
|
|
|
| 17 |
inputs = feature_extractor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt")
|
|
|
|
|
|
|
| 18 |
with torch.no_grad():
|
| 19 |
outputs = model(**inputs)
|
| 20 |
probs = torch.nn.functional.softmax(outputs.logits[0], dim=0)
|
| 21 |
|
|
|
|
| 22 |
labels = model.config.id2label
|
| 23 |
results = {labels[i]: float(probs[i]) for i in range(len(labels))}
|
| 24 |
return results
|
| 25 |
|
|
|
|
| 26 |
gr.Interface(
|
| 27 |
fn=classify_baby_cry,
|
| 28 |
inputs=gr.Audio(type="filepath"),
|
| 29 |
outputs=gr.Label(num_top_classes=3),
|
| 30 |
title="讝讬讛讜讬 讘讻讬 转讬谞讜拽讜转",
|
| 31 |
-
description="讛诪注专讻转 诪讗讝讬谞讛 诇拽讜讘抓 拽讜诇 讜诪讞讝讬专讛
|
| 32 |
).launch()
|
|
|
|
| 3 |
import torchaudio
|
| 4 |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
| 5 |
|
| 6 |
+
# 讛讙讚专转 讛诪讜讚诇
|
| 7 |
model_id = "Wiam/baby-cry-classification-finetuned-babycry-v4"
|
| 8 |
model = AutoModelForAudioClassification.from_pretrained(model_id)
|
| 9 |
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
|
| 10 |
|
| 11 |
def classify_baby_cry(audio_file):
|
| 12 |
+
# 讟注谉 讗转 拽讜讘抓 讛住讗讜谞讚
|
| 13 |
waveform, sample_rate = torchaudio.load(audio_file)
|
| 14 |
+
|
| 15 |
+
# 讛诪专讛 诇志16kHz 讻驻讬 砖讛诪讜讚诇 讚讜专砖
|
| 16 |
+
if sample_rate != 16000:
|
| 17 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
|
| 18 |
+
waveform = resampler(waveform)
|
| 19 |
+
sample_rate = 16000
|
| 20 |
+
|
| 21 |
+
# 讗诐 讛讗讜讚讬讜 讘住讟专讗讜, 谞讗讞讚 诇注专讜抓 讗讞讚
|
| 22 |
if waveform.shape[0] > 1:
|
| 23 |
waveform = waveform.mean(dim=0, keepdim=True)
|
| 24 |
+
|
| 25 |
+
# 讛诪专转 讛讗讜讚讬讜 诇讻谞讬住转 诪讜讚诇
|
| 26 |
inputs = feature_extractor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt")
|
| 27 |
+
|
| 28 |
+
# 讞讬讝讜讬
|
| 29 |
with torch.no_grad():
|
| 30 |
outputs = model(**inputs)
|
| 31 |
probs = torch.nn.functional.softmax(outputs.logits[0], dim=0)
|
| 32 |
|
| 33 |
+
# 诪讬驻讜讬 转讜爪讗讜转
|
| 34 |
labels = model.config.id2label
|
| 35 |
results = {labels[i]: float(probs[i]) for i in range(len(labels))}
|
| 36 |
return results
|
| 37 |
|
| 38 |
+
# 诪诪砖拽 讙专驻讬拽讛
|
| 39 |
gr.Interface(
|
| 40 |
fn=classify_baby_cry,
|
| 41 |
inputs=gr.Audio(type="filepath"),
|
| 42 |
outputs=gr.Label(num_top_classes=3),
|
| 43 |
title="讝讬讛讜讬 讘讻讬 转讬谞讜拽讜转",
|
| 44 |
+
description="讛诪注专讻转 诪讗讝讬谞讛 诇拽讜讘抓 拽讜诇 砖诇 转讬谞讜拽 讜诪讞讝讬专讛 讗转 住讜讙 讛讘讻讬 诇驻讬 讛诪讜讚诇"
|
| 45 |
).launch()
|