Demo Update 10
This commit is contained in:
@@ -93,15 +93,49 @@ def load_speech_models():
|
|||||||
# Load Whisper model for speech recognition
|
# Load Whisper model for speech recognition
|
||||||
try:
|
try:
|
||||||
logger.info(f"Loading speech recognition model on {device}...")
|
logger.info(f"Loading speech recognition model on {device}...")
|
||||||
speech_recognizer = pipeline("automatic-speech-recognition",
|
|
||||||
|
# Try with newer API first
|
||||||
|
try:
|
||||||
|
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||||
|
|
||||||
|
model_id = "openai/whisper-small"
|
||||||
|
|
||||||
|
# Load model and processor
|
||||||
|
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
||||||
|
device_map=device,
|
||||||
|
)
|
||||||
|
processor = AutoProcessor.from_pretrained(model_id)
|
||||||
|
|
||||||
|
# Create pipeline with specific parameters
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
"automatic-speech-recognition",
|
||||||
|
model=model,
|
||||||
|
tokenizer=processor.tokenizer,
|
||||||
|
feature_extractor=processor.feature_extractor,
|
||||||
|
max_new_tokens=128,
|
||||||
|
chunk_length_s=30,
|
||||||
|
batch_size=16,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as api_error:
|
||||||
|
logger.warning(f"Newer API loading failed: {api_error}, trying simpler approach")
|
||||||
|
|
||||||
|
# Fallback to simpler API
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
"automatic-speech-recognition",
|
||||||
model="openai/whisper-small",
|
model="openai/whisper-small",
|
||||||
device=device)
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Speech recognition model loaded successfully")
|
logger.info("Speech recognition model loaded successfully")
|
||||||
|
return generator, speech_recognizer
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading speech recognition model: {e}")
|
logger.error(f"Error loading speech recognition model: {e}")
|
||||||
speech_recognizer = None
|
return generator, None
|
||||||
|
|
||||||
return generator, speech_recognizer
|
|
||||||
|
|
||||||
# Unpack both models
|
# Unpack both models
|
||||||
generator, speech_recognizer = load_speech_models()
|
generator, speech_recognizer = load_speech_models()
|
||||||
@@ -308,8 +342,27 @@ def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str:
|
|||||||
temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
|
temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
|
||||||
torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
|
torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
|
||||||
|
|
||||||
# Perform speech recognition - using input_features instead of inputs
|
# Perform speech recognition - handle the warning differently
|
||||||
result = speech_recognizer(temp_path, input_features=None) # input_features=None forces use of the correct parameter name
|
# Just pass the path without any additional parameters
|
||||||
|
try:
|
||||||
|
# First try - use default parameters
|
||||||
|
result = speech_recognizer(temp_path)
|
||||||
|
transcription = result["text"]
|
||||||
|
except Exception as whisper_error:
|
||||||
|
logger.warning(f"First transcription attempt failed: {whisper_error}")
|
||||||
|
# Try with explicit parameters for older versions of transformers
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
# Load audio as numpy array
|
||||||
|
audio_np, sr = sf.read(temp_path)
|
||||||
|
if sr != 16000:
|
||||||
|
# Whisper expects 16kHz audio
|
||||||
|
from scipy import signal
|
||||||
|
audio_np = signal.resample(audio_np, int(len(audio_np) * 16000 / sr))
|
||||||
|
|
||||||
|
# Try with numpy array directly
|
||||||
|
result = speech_recognizer(audio_np)
|
||||||
transcription = result["text"]
|
transcription = result["text"]
|
||||||
|
|
||||||
# Clean up temp file
|
# Clean up temp file
|
||||||
@@ -320,6 +373,7 @@ def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str:
|
|||||||
if not transcription or transcription.isspace():
|
if not transcription or transcription.isspace():
|
||||||
return "I didn't detect any speech. Could you please try again?"
|
return "I didn't detect any speech. Could you please try again?"
|
||||||
|
|
||||||
|
logger.info(f"Transcription successful: '{transcription}'")
|
||||||
return transcription
|
return transcription
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user