Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| from transformers import pipeline | |
| import torchaudio | |
| import os | |
| import re | |
| import numpy as np | |
| # ----------------------------- | |
| # 1) Model loading and utility functions | |
| # ----------------------------- | |
# Device setup: use the GPU when PyTorch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Whisper model for Cantonese ASR
MODEL_NAME = "alvanlii/whisper-small-cantonese"
# Language that Whisper's decoder is conditioned on when transcribing.
language = "zh"


@st.cache_resource(show_spinner=False)
def _load_asr_pipeline(model_name, device_name, lang):
    """Build the Whisper speech-recognition pipeline once per server process.

    Streamlit re-executes the whole script on every user interaction, so an
    uncached module-level ``pipeline(...)`` call would reload the model
    weights on each rerun. ``st.cache_resource`` returns one shared instance
    per distinct set of arguments instead. ``show_spinner=False`` keeps the
    cached call from rendering a spinner element before ``main`` calls
    ``st.set_page_config``.

    Decoding is conditioned on ``lang`` and the "transcribe" task through
    ``generate_kwargs``. Setting ``model.config.forced_decoder_ids`` is the
    deprecated way of doing this, and newer ``transformers`` releases, which
    read the generation config, may ignore it.

    Args:
        model_name: Hugging Face model id of the Whisper checkpoint.
        device_name: Torch device string ("cuda" or "cpu").
        lang: Whisper language code used for decoding.

    Returns:
        An automatic-speech-recognition ``transformers`` pipeline.
    """
    return pipeline(
        task="automatic-speech-recognition",
        model=model_name,
        chunk_length_s=30,  # Adjust chunk size for memory handling
        device=device_name,
        generate_kwargs={
            "language": lang,
            "task": "transcribe",
            "no_repeat_ngram_size": 3,
            "repetition_penalty": 1.15,
            # do_sample=True makes decoding stochastic, so the same audio
            # can produce different transcripts from run to run.
            "temperature": 0.7,
            "top_p": 0.97,
            "top_k": 40,
            "max_new_tokens": 400,
            "do_sample": True,
        },
    )


asr_pipe = _load_asr_pipeline(MODEL_NAME, device, language)
def remove_repeated_phrases(text, threshold=0.9):
    """Drop sentences that nearly duplicate the sentence kept just before them.

    ``text`` is split immediately after each full-width sentence terminator
    (。！？). A sentence is discarded when its ``difflib.SequenceMatcher``
    ratio against the previously kept sentence exceeds ``threshold``. Kept
    sentences are stripped and joined with single spaces.

    Args:
        text: Transcript text to clean.
        threshold: Similarity ratio above which a sentence counts as a
            repeat of the previous kept sentence.

    Returns:
        The de-duplicated text ("" for empty input).
    """
    from difflib import SequenceMatcher

    kept = []
    for piece in re.split(r"(?<=[。！？])", text):
        sentence = piece.strip()
        # Splitting after the final terminator yields a trailing "" piece;
        # skip blank pieces so they don't add stray separators.
        if not sentence:
            continue
        if kept and SequenceMatcher(None, sentence, kept[-1]).ratio() > threshold:
            continue
        kept.append(sentence)
    return " ".join(kept)
def remove_punctuation(text):
    r"""Return ``text`` with every non-word, non-whitespace character removed.

    For ``str`` patterns ``\w`` is Unicode-aware, so letters (including CJK
    characters), digits and underscores are kept, as is whitespace, while
    ASCII and full-width punctuation are deleted.
    """
    not_word_or_space = re.compile(r'[^\w\s]')
    return not_word_or_space.sub('', text)
def transcribe_audio(audio_path, *, long_audio_threshold_s=60, chunk_s=55, step_s=50):
    """Transcribe an audio file with the Whisper pipeline ``asr_pipe``.

    The file is read with ``torchaudio.load``, and multi-channel audio is
    averaged down to mono. Audio no longer than ``long_audio_threshold_s``
    seconds is transcribed in a single call.

    Longer audio is cut into windows of ``chunk_s`` seconds that start every
    ``step_s`` seconds, so neighbouring windows overlap by
    ``chunk_s - step_s`` seconds. Each window is transcribed and the pieces
    are joined with spaces.

    Near-duplicate sentences are removed with ``remove_repeated_phrases``
    *before* punctuation is stripped, because that helper splits sentences
    on punctuation.

    Args:
        audio_path: Path of an audio file readable by ``torchaudio.load``.
        long_audio_threshold_s: Duration above which windowing is used.
        chunk_s: Window length in seconds for long audio.
        step_s: Offset in seconds between window starts for long audio.

    Returns:
        The transcript with repeated sentences and punctuation removed.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert multi-channel audio to mono by averaging over the channel dim.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    samples = waveform.squeeze(0).numpy()
    total_samples = samples.shape[0]

    if total_samples / sample_rate <= long_audio_threshold_s:
        transcript = asr_pipe({"sampling_rate": sample_rate, "raw": samples})["text"]
    else:
        chunk_size = int(sample_rate * chunk_s)
        step_size = int(sample_rate * step_s)
        pieces = []
        for start in range(0, total_samples, step_size):
            chunk = samples[start:start + chunk_size]
            pieces.append(asr_pipe({"sampling_rate": sample_rate, "raw": chunk})["text"])
            # This window already reaches the end of the audio; any later
            # window would only repeat audio already covered here and be
            # transcribed twice.
            if start + chunk_size >= total_samples:
                break
        transcript = " ".join(pieces)

    return remove_punctuation(remove_repeated_phrases(transcript))
@st.cache_resource(show_spinner=False)
def _load_sentiment_pipeline(device_name):
    """Build the sentiment text-classification pipeline once per server process.

    Streamlit re-executes the whole script on every user interaction, so an
    uncached module-level ``pipeline(...)`` call would reload the model on
    each rerun. ``st.cache_resource`` returns one shared instance per
    distinct argument instead. ``show_spinner=False`` keeps the cached call
    from rendering a spinner element before ``main`` calls
    ``st.set_page_config``.

    Args:
        device_name: Torch device string ("cuda" or "cpu").

    Returns:
        A text-classification ``transformers`` pipeline.
    """
    return pipeline(
        "text-classification",
        model="MonkeyDLLLLLLuffy/CustomModel-multilingual-sentiment-analysis-enhanced",
        device=device_name,
    )


# Load sentiment analysis model
sentiment_pipe = _load_sentiment_pipeline(device)
def rate_quality(text, chunk_size=512):
    """Rate service quality by majority vote over sentiment predictions.

    ``text`` is cut into consecutive pieces of ``chunk_size`` characters.
    Each piece is classified with ``sentiment_pipe``, and the predicted
    sentiment labels are mapped to quality ratings. The most frequent rating
    is returned; among equally frequent ratings, the one that appears first
    wins.

    Args:
        text: Transcript to rate.
        chunk_size: Number of characters per classified piece.

    Returns:
        "Very Poor", "Poor", "Neutral", "Good" or "Very Good". Returns
        "Unknown" for empty/blank text, or when unmapped labels win the vote.
    """
    from collections import Counter

    # No text means nothing to classify (and max() over no votes would fail).
    if not text or not text.strip():
        return "Unknown"

    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    # Chunks are measured in characters, not tokens, so a chunk may still
    # exceed the model's maximum sequence length; truncate instead of failing.
    results = sentiment_pipe(chunks, batch_size=4, truncation=True)
    label_map = {
        "Very Negative": "Very Poor",
        "Negative": "Poor",
        "Neutral": "Neutral",
        "Positive": "Good",
        "Very Positive": "Very Good"
    }
    ratings = [label_map.get(res["label"], "Unknown") for res in results]
    # Counter.most_common orders equal counts by first occurrence, unlike
    # iterating a set of str, whose order changes with hash randomization.
    return Counter(ratings).most_common(1)[0][0]
| # ----------------------------- | |
| # 2) Main Streamlit application | |
| # ----------------------------- | |
def main():
    """Render the Customer Service Quality Analyzer Streamlit page.

    The user uploads a WAV/MP3/FLAC file. The audio is transcribed with
    ``transcribe_audio`` and rated with ``rate_quality``. Results are kept
    in ``st.session_state`` so they persist across Streamlit reruns, and the
    analysis runs again only when a different upload is detected. The
    results can be downloaded as a plain-text report.
    """
    import tempfile

    st.set_page_config(page_title="Customer Service Analyzer", page_icon="๐๏ธ")
    # Custom CSS styling
    st.markdown("""
    <style>
    .header {
        background: linear-gradient(90deg, #4B79A1, #283E51);
        border-radius: 10px;
        padding: 1.5rem;
        text-align: center;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        margin-bottom: 1.5rem;
        color: white;
    }
    </style>
    """, unsafe_allow_html=True)
    st.markdown("""
    <div class="header">
        <h1 style='margin:0;'>๐๏ธ Customer Service Quality Analyzer</h1>
        <p>Evaluate the service quality with simple uploading!</p>
    </div>
    """, unsafe_allow_html=True)

    # Initialize session state so results survive reruns.
    for state_key in ("transcript", "quality_rating", "uploaded_file_key"):
        if state_key not in st.session_state:
            st.session_state[state_key] = ""

    # File uploader
    uploaded_file = st.file_uploader(
        "๐ค Please upload your Cantonese customer service audio file",
        type=["wav", "mp3", "flac"]
    )

    if uploaded_file is not None:
        # Use the MIME type reported for the upload so MP3/FLAC files are not
        # declared to the audio player as WAV.
        st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

        # Identify the upload by its id, falling back to name + size. The
        # name alone would skip analysis when a different file with the
        # same name is uploaded.
        file_key = getattr(uploaded_file, "file_id", None) or (
            f"{uploaded_file.name}:{uploaded_file.size}"
        )
        if st.session_state["uploaded_file_key"] != file_key:
            # Write to a unique temp file that keeps the upload's extension.
            # A fixed name in the working directory can collide between
            # sessions, and the audio backend may rely on the extension to
            # pick a decoder.
            suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp.write(uploaded_file.getbuffer())
                temp_audio_path = tmp.name
            try:
                with st.spinner('๐ Processing your audio, please wait...'):
                    transcript = transcribe_audio(temp_audio_path)
                    quality_rating = rate_quality(transcript)
            finally:
                # Remove the temporary file even when processing raises.
                if os.path.exists(temp_audio_path):
                    os.remove(temp_audio_path)
            # Store results and mark the upload as processed only after
            # success, so a failed run is retried on the next rerun.
            st.session_state["transcript"] = transcript
            st.session_state["quality_rating"] = quality_rating
            st.session_state["uploaded_file_key"] = file_key

    # Display results if available
    if st.session_state["transcript"]:
        st.write("**Transcript:**", st.session_state["transcript"])
        st.write("**Sentiment Analysis Result:**", st.session_state["quality_rating"])
        # Prepare download content
        result_text = (
            f"Transcript:\n{st.session_state['transcript']}\n\n"
            f"Sentiment Analysis Result: {st.session_state['quality_rating']}"
        )
        # Download button for the analysis report
        st.download_button(
            label="๐ฅ Download Analysis Report",
            data=result_text,
            file_name="analysis_report.txt"
        )
    elif uploaded_file is not None:
        # Processing finished but produced no text (e.g. silent audio).
        st.warning("No speech could be recognized in the uploaded audio.")

    st.markdown(
        "โIf you encounter any issues, please contact customer support: "
        "๐ง **example@hellotoby.com**"
    )


if __name__ == "__main__":
    main()