Lots of Client Testing Changes
@@ -12,6 +12,8 @@ from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from generator import load_csm_1b, Segment
import uvicorn
import time
from collections import deque

# Select device
if torch.cuda.is_available():
@@ -48,6 +50,9 @@ class ConnectionManager:

manager = ConnectionManager()

# Silence detection parameters
SILENCE_THRESHOLD = 0.01  # Adjust based on your audio normalization
SILENCE_DURATION_SEC = 1.0  # How long silence must persist to be considered "stopped talking"

# Helper function to convert audio data
async def decode_audio_data(audio_data: str) -> torch.Tensor:
@@ -92,6 +97,13 @@ async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    context_segments = []  # Store conversation context
    streaming_buffer = []  # Buffer for streaming audio chunks
    is_streaming = False

    # Variables for silence detection
    last_active_time = time.time()
    is_silence = False
    energy_window = deque(maxlen=10)  # For tracking recent audio energy

    try:
        while True:
@@ -160,6 +172,114 @@ async def websocket_endpoint(websocket: WebSocket):
                    "type": "context_updated",
                    "message": "Context cleared"
                })

            elif action == "stream_audio":
                try:
                    speaker_id = request.get("speaker", 0)
                    audio_data = request.get("audio", "")

                    # Convert received audio to tensor
                    audio_chunk = await decode_audio_data(audio_data)

                    # Start streaming mode if not already started
                    if not is_streaming:
                        is_streaming = True
                        streaming_buffer = []
                        energy_window.clear()
                        is_silence = False
                        last_active_time = time.time()
                        await websocket.send_json({
                            "type": "streaming_status",
                            "status": "started"
                        })

                    # Calculate audio energy for silence detection
                    chunk_energy = torch.mean(torch.abs(audio_chunk)).item()
                    energy_window.append(chunk_energy)
                    avg_energy = sum(energy_window) / len(energy_window)

                    # Check if audio is silent
                    current_silence = avg_energy < SILENCE_THRESHOLD

                    # Track silence transitions
                    if not is_silence and current_silence:
                        # Transition to silence
                        is_silence = True
                        last_active_time = time.time()
                    elif is_silence and not current_silence:
                        # User started talking again
                        is_silence = False

                    # Add chunk to buffer regardless of silence state
                    streaming_buffer.append(audio_chunk)

                    # Check if silence has persisted long enough to be considered "stopped talking"
                    silence_elapsed = time.time() - last_active_time

                    if is_silence and silence_elapsed >= SILENCE_DURATION_SEC and len(streaming_buffer) > 0:
                        # User has stopped talking - process the collected audio
                        full_audio = torch.cat(streaming_buffer, dim=0)

                        # Process with speech-to-text (you would need to implement this)
                        # For now, just use a placeholder text
                        text = f"User audio from speaker {speaker_id}"

                        print(f"Detected end of speech, processing {len(streaming_buffer)} chunks")

                        # Add to conversation context
                        context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))

                        # Generate response
                        response_text = "This is a response to what you just said"
                        audio_tensor = generator.generate(
                            text=response_text,
                            speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
                            context=context_segments,
                            max_audio_length_ms=10_000,
                        )

                        # Convert audio to base64 and send back to client
                        audio_base64 = await encode_audio_data(audio_tensor)
                        await websocket.send_json({
                            "type": "audio_response",
                            "audio": audio_base64
                        })

                        # Clear buffer and reset silence detection
                        streaming_buffer = []
                        energy_window.clear()
                        is_silence = False
                        last_active_time = time.time()

                    # If the buffer gets too large without silence, process it anyway.
                    # This prevents memory issues with very long streams.
                    elif len(streaming_buffer) >= 30:  # ~6 seconds of audio at 5 chunks/sec
                        print("Buffer limit reached, processing audio")
                        full_audio = torch.cat(streaming_buffer, dim=0)
                        text = f"Continued speech from speaker {speaker_id}"
                        context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
                        streaming_buffer = []

                except Exception as e:
                    print(f"Error processing streaming audio: {str(e)}")
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Error processing streaming audio: {str(e)}"
                    })

            elif action == "stop_streaming":
                is_streaming = False
                if streaming_buffer:
                    # Process any remaining audio in the buffer
                    full_audio = torch.cat(streaming_buffer, dim=0)
                    text = f"Final streaming audio from speaker {request.get('speaker', 0)}"
                    context_segments.append(Segment(text=text, speaker=request.get("speaker", 0), audio=full_audio))

                streaming_buffer = []
                await websocket.send_json({
                    "type": "streaming_status",
                    "status": "stopped"
                })

    except WebSocketDisconnect:
        manager.disconnect(websocket)
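
Note: the `stream_audio` branch above stubs out transcription with a placeholder string. A minimal sketch of what a real hook might look like, assuming the `openai-whisper` package and assuming decoded chunks are 1-D float32 tensors at 24 kHz (the helper name, model choice, and sample rate are all assumptions, not part of this commit):

# Hypothetical transcription hook, not part of this commit.
# Assumes openai-whisper is installed and that decoded chunks are
# 1-D float32 tensors at 24 kHz (check generator.sample_rate).
import torch
import torchaudio
import whisper

stt_model = whisper.load_model("base")

def transcribe(audio: torch.Tensor, sample_rate: int = 24_000) -> str:
    # Whisper expects 16 kHz mono float32 audio.
    resampled = torchaudio.functional.resample(audio, sample_rate, 16_000)
    result = stt_model.transcribe(resampled.cpu().numpy())
    return result["text"].strip()

With something like this wired in, the placeholder `text = f"User audio from speaker {speaker_id}"` could become `text = transcribe(full_audio)`.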
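Since this commit is about client testing, the protocol the endpoint speaks is worth spelling out: the client sends JSON with an `action` of `stream_audio`, `stop_streaming`, or `clear_context`, and receives `streaming_status`, `audio_response`, `context_updated`, or `error` messages. A sketch of a minimal test client using the `websockets` package; the URL, port, and audio payload encoding are assumptions, since the route decorator and the body of `decode_audio_data` are outside this diff:

# Hypothetical test client, not part of this commit. Assumes the
# endpoint is mounted at ws://localhost:8000/ws and that
# decode_audio_data() accepts base64-encoded raw float32 samples;
# adjust the encoding to match your helper.
import asyncio
import base64
import json

import numpy as np
import websockets

def make_chunks(n: int = 10) -> list[str]:
    # 0.2 s blocks of silence at an assumed 24 kHz sample rate;
    # sending only silence should trip the silence detector.
    silent = np.zeros(4800, dtype=np.float32)
    return [base64.b64encode(silent.tobytes()).decode("ascii")] * n

async def main():
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        for chunk in make_chunks():
            await ws.send(json.dumps({
                "action": "stream_audio",
                "speaker": 0,
                "audio": chunk,
            }))

        await ws.send(json.dumps({"action": "stop_streaming"}))

        # Drain server messages (streaming_status, audio_response,
        # or error) until the endpoint confirms the stream stopped.
        while True:
            msg = json.loads(await ws.recv())
            print(msg["type"])
            if msg["type"] == "streaming_status" and msg.get("status") == "stopped":
                break

asyncio.run(main())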