Lots of Client Testing Changes

2025-03-29 21:39:11 -04:00
parent b738423272
commit 131c8c9e78
2 changed files with 404 additions and 1 deletion


@@ -12,6 +12,8 @@ from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from generator import load_csm_1b, Segment
import uvicorn
import time
from collections import deque

# Select device
if torch.cuda.is_available():
@@ -48,6 +50,9 @@ class ConnectionManager:
manager = ConnectionManager()

# Silence detection parameters
SILENCE_THRESHOLD = 0.01  # Adjust based on your audio normalization
SILENCE_DURATION_SEC = 1.0  # How long silence must persist to be considered "stopped talking"
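# For float audio normalized to [-1, 1], a mean absolute amplitude of 0.01
# is roughly -40 dBFS, comfortably below typical speech energy.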

# Helper function to convert audio data
async def decode_audio_data(audio_data: str) -> torch.Tensor:
@@ -92,6 +97,13 @@ async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    context_segments = []  # Store conversation context
    streaming_buffer = []  # Buffer for streaming audio chunks
    is_streaming = False

    # Variables for silence detection
    last_active_time = time.time()
    is_silence = False
    energy_window = deque(maxlen=10)  # For tracking recent audio energy

    try:
        while True:
@@ -160,6 +172,114 @@ async def websocket_endpoint(websocket: WebSocket):
"type": "context_updated",
"message": "Context cleared"
})
elif action == "stream_audio":
try:
speaker_id = request.get("speaker", 0)
audio_data = request.get("audio", "")
# Convert received audio to tensor
audio_chunk = await decode_audio_data(audio_data)
# Start streaming mode if not already started
if not is_streaming:
is_streaming = True
streaming_buffer = []
energy_window.clear()
is_silence = False
last_active_time = time.time()
await websocket.send_json({
"type": "streaming_status",
"status": "started"
})
# Calculate audio energy for silence detection
chunk_energy = torch.mean(torch.abs(audio_chunk)).item()
energy_window.append(chunk_energy)
avg_energy = sum(energy_window) / len(energy_window)
# Check if audio is silent
current_silence = avg_energy < SILENCE_THRESHOLD
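                    # energy_window averages up to the last 10 chunks, so one quiet
                    # (or loud) chunk cannot flip the silence state by itself.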

                    # Track silence transitions
                    if not is_silence and current_silence:
                        # Transition to silence
                        is_silence = True
                        last_active_time = time.time()
                    elif is_silence and not current_silence:
                        # User started talking again
                        is_silence = False

                    # Add chunk to buffer regardless of silence state
                    streaming_buffer.append(audio_chunk)

                    # Check if silence has persisted long enough to count as "stopped talking"
                    silence_elapsed = time.time() - last_active_time
                    if is_silence and silence_elapsed >= SILENCE_DURATION_SEC and len(streaming_buffer) > 0:
                        # User has stopped talking - process the collected audio
                        full_audio = torch.cat(streaming_buffer, dim=0)

                        # Process with speech-to-text (you would need to implement this)
                        # For now, just use a placeholder text
                        text = f"User audio from speaker {speaker_id}"
print(f"Detected end of speech, processing {len(streaming_buffer)} chunks")
# Add to conversation context
context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
# Generate response
response_text = "This is a response to what you just said"
audio_tensor = generator.generate(
text=response_text,
speaker=1 if speaker_id == 0 else 0, # Use opposite speaker
context=context_segments,
max_audio_length_ms=10_000,
)
# Convert audio to base64 and send back to client
audio_base64 = await encode_audio_data(audio_tensor)
await websocket.send_json({
"type": "audio_response",
"audio": audio_base64
})
# Clear buffer and reset silence detection
streaming_buffer = []
energy_window.clear()
is_silence = False
last_active_time = time.time()
# If buffer gets too large without silence, process it anyway
# This prevents memory issues with very long streams
elif len(streaming_buffer) >= 30: # ~6 seconds of audio at 5 chunks/sec
print("Buffer limit reached, processing audio")
full_audio = torch.cat(streaming_buffer, dim=0)
text = f"Continued speech from speaker {speaker_id}"
context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
streaming_buffer = []
except Exception as e:
print(f"Error processing streaming audio: {str(e)}")
await websocket.send_json({
"type": "error",
"message": f"Error processing streaming audio: {str(e)}"
})
elif action == "stop_streaming":
is_streaming = False
if streaming_buffer:
# Process any remaining audio in the buffer
full_audio = torch.cat(streaming_buffer, dim=0)
text = f"Final streaming audio from speaker {request.get('speaker', 0)}"
context_segments.append(Segment(text=text, speaker=request.get("speaker", 0), audio=full_audio))
streaming_buffer = []
await websocket.send_json({
"type": "streaming_status",
"status": "stopped"
})
    except WebSocketDisconnect:
        manager.disconnect(websocket)
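
For hand-testing the new streaming protocol, a minimal client sketch follows. It is an illustration, not part of this commit: the endpoint path ws://localhost:8000/ws, the 24 kHz sample rate, and base64-encoded float32 PCM are assumptions, since decode_audio_data and the route registration sit outside this diff.

# Hypothetical test client for the streaming endpoint (assumptions noted above).
import asyncio
import base64
import json

import numpy as np
import websockets  # third-party: pip install websockets


async def main():
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # Stream ~2 s of near-silent audio in 0.2 s chunks so the server's
        # silence detector fires (SILENCE_DURATION_SEC = 1.0).
        chunk = np.zeros(4800, dtype=np.float32)  # 0.2 s at an assumed 24 kHz
        for _ in range(10):
            await ws.send(json.dumps({
                "action": "stream_audio",
                "speaker": 0,
                "audio": base64.b64encode(chunk.tobytes()).decode("ascii"),
            }))
            await asyncio.sleep(0.2)
        await ws.send(json.dumps({"action": "stop_streaming"}))

        # Print server messages until the stream is acknowledged as stopped.
        while True:
            msg = json.loads(await ws.recv())
            print(msg.get("type"), msg.get("status", ""))
            if msg.get("type") == "streaming_status" and msg.get("status") == "stopped":
                break


asyncio.run(main())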