Server API and Webpage update

2025-03-29 22:27:44 -04:00
parent da6038f2b2
commit e08f7a2c1c


@@ -9,7 +9,9 @@ import io
 import whisperx
 from io import BytesIO
 from typing import List, Dict, Any, Optional
-from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
+from fastapi.responses import HTMLResponse, FileResponse
+from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from generator import load_csm_1b, Segment
@@ -17,6 +19,8 @@ import uvicorn
 import time
 import gc
 from collections import deque
+import socket
+import requests
 
 # Select device
 if torch.cuda.is_available():
@@ -45,6 +49,32 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# Define the base directory
+base_dir = os.path.dirname(os.path.abspath(__file__))
+
+# Mount a static files directory if you have any static assets like CSS or JS
+static_dir = os.path.join(base_dir, "static")
+os.makedirs(static_dir, exist_ok=True)  # Create the directory if it doesn't exist
+app.mount("/static", StaticFiles(directory=static_dir), name="static")
+
+# Define route to serve index.html as the main page
+@app.get("/", response_class=HTMLResponse)
+async def get_index():
+    try:
+        with open(os.path.join(base_dir, "index.html"), "r") as f:
+            return HTMLResponse(content=f.read())
+    except FileNotFoundError:
+        return HTMLResponse(content="<html><body><h1>Error: index.html not found</h1></body></html>")
+
+# Add a favicon endpoint (optional, but good to have)
+@app.get("/favicon.ico")
+async def get_favicon():
+    favicon_path = os.path.join(static_dir, "favicon.ico")
+    if os.path.exists(favicon_path):
+        return FileResponse(favicon_path)
+    else:
+        return HTMLResponse(status_code=204)  # No content
+
 # Connection manager to handle multiple clients
 class ConnectionManager:
     def __init__(self):
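
Note: a minimal smoke test for the new routes using FastAPI's TestClient; the server module name (`server`) is a hypothetical stand-in, since the filename isn't shown in this view:

from fastapi.testclient import TestClient
from server import app  # hypothetical module name; not shown in this commit view

client = TestClient(app)

# The index route returns the HTML page (or the inline error page, also 200)
assert client.get("/").status_code == 200

# The favicon route returns the file if present, otherwise an empty 204
assert client.get("/favicon.ico").status_code in (200, 204)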
@@ -259,6 +289,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 energy_window.clear()
                 is_silence = False
                 last_active_time = time.time()
+                print(f"Streaming started with speaker ID: {speaker_id}")
                 await websocket.send_json({
                     "type": "streaming_status",
                     "status": "started"
@@ -269,6 +300,13 @@ async def websocket_endpoint(websocket: WebSocket):
                 energy_window.append(chunk_energy)
                 avg_energy = sum(energy_window) / len(energy_window)
 
+                # Debug audio levels
+                if len(energy_window) >= 5:  # Only start printing after we have enough samples
+                    if avg_energy > SILENCE_THRESHOLD:
+                        print(f"[AUDIO] Active sound detected - Energy: {avg_energy:.6f} (threshold: {SILENCE_THRESHOLD})")
+                    else:
+                        print(f"[AUDIO] Silence detected - Energy: {avg_energy:.6f} (threshold: {SILENCE_THRESHOLD})")
+
                 # Check if audio is silent
                 current_silence = avg_energy < SILENCE_THRESHOLD
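
Note: the silence gating spread across these hunks depends on SILENCE_THRESHOLD, SILENCE_DURATION_SEC, and a per-chunk energy value that this diff never defines. A self-contained sketch of the same windowed-energy detector, with RMS energy and the constant values below as assumptions:

import time
from collections import deque

import torch

SILENCE_THRESHOLD = 0.01    # assumed value; not shown in this diff
SILENCE_DURATION_SEC = 1.0  # assumed value; not shown in this diff

energy_window = deque(maxlen=10)  # rolling window of per-chunk energies
is_silence = False
last_active_time = time.time()

def chunk_energy(chunk: torch.Tensor) -> float:
    # RMS energy of one audio chunk (assumption: the server computes something similar)
    return torch.sqrt(torch.mean(chunk ** 2)).item()

def update(chunk: torch.Tensor) -> bool:
    """Feed one chunk; return True once silence has persisted long enough to flush."""
    global is_silence, last_active_time
    energy_window.append(chunk_energy(chunk))
    avg_energy = sum(energy_window) / len(energy_window)
    current_silence = avg_energy < SILENCE_THRESHOLD
    if not is_silence and current_silence:
        is_silence, last_active_time = True, time.time()
    elif is_silence and not current_silence:
        is_silence = False
    return is_silence and (time.time() - last_active_time) >= SILENCE_DURATION_SEC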
@@ -277,33 +315,53 @@ async def websocket_endpoint(websocket: WebSocket):
                     # Transition to silence
                     is_silence = True
                     last_active_time = time.time()
+                    print("[STREAM] Transition to silence detected")
                 elif is_silence and not current_silence:
                     # User started talking again
                     is_silence = False
+                    print("[STREAM] User resumed speaking")
 
                 # Add chunk to buffer regardless of silence state
                 streaming_buffer.append(audio_chunk)
 
+                # Debug buffer size periodically
+                if len(streaming_buffer) % 10 == 0:
+                    print(f"[BUFFER] Current size: {len(streaming_buffer)} chunks, ~{len(streaming_buffer)/5:.1f} seconds")
+
                 # Check if silence has persisted long enough to consider "stopped talking"
                 silence_elapsed = time.time() - last_active_time
                 if is_silence and silence_elapsed >= SILENCE_DURATION_SEC and len(streaming_buffer) > 0:
                     # User has stopped talking - process the collected audio
+                    print(f"[STREAM] Processing audio after {silence_elapsed:.2f}s of silence")
+                    print(f"[STREAM] Processing {len(streaming_buffer)} audio chunks (~{len(streaming_buffer)/5:.1f} seconds)")
                     full_audio = torch.cat(streaming_buffer, dim=0)
 
+                    # Log audio statistics
+                    audio_duration = len(full_audio) / generator.sample_rate
+                    audio_min = torch.min(full_audio).item()
+                    audio_max = torch.max(full_audio).item()
+                    audio_mean = torch.mean(full_audio).item()
+                    print(f"[AUDIO] Processed audio - Duration: {audio_duration:.2f}s, Min: {audio_min:.4f}, Max: {audio_max:.4f}, Mean: {audio_mean:.4f}")
+
                     # Process with WhisperX speech-to-text
+                    print("[ASR] Starting transcription with WhisperX...")
                     transcribed_text = await transcribe_audio(full_audio)
 
                     # Log the transcription
-                    print(f"Transcribed text: '{transcribed_text}'")
+                    print(f"[ASR] Transcribed text: '{transcribed_text}'")
 
                     # Add to conversation context
                     if transcribed_text:
+                        print(f"[DIALOG] Adding user utterance to context: '{transcribed_text}'")
                         user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
                         context_segments.append(user_segment)
 
                         # Generate a contextual response
+                        print("[DIALOG] Generating response...")
                         response_text = await generate_response(transcribed_text, context_segments)
+                        print(f"[DIALOG] Response text: '{response_text}'")
 
                         # Send the transcribed text to client
                         await websocket.send_json({
@@ -312,12 +370,14 @@ async def websocket_endpoint(websocket: WebSocket):
                         })
 
                         # Generate audio for the response
+                        print("[TTS] Generating speech for response...")
                         audio_tensor = generator.generate(
                             text=response_text,
                             speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
                             context=context_segments,
                             max_audio_length_ms=10_000,
                         )
+                        print(f"[TTS] Generated audio length: {len(audio_tensor)/generator.sample_rate:.2f}s")
 
                         # Add response to context
                         ai_segment = Segment(
@@ -326,15 +386,18 @@ async def websocket_endpoint(websocket: WebSocket):
                             audio=audio_tensor
                         )
                         context_segments.append(ai_segment)
+                        print(f"[DIALOG] Context now has {len(context_segments)} segments")
 
                         # Convert audio to base64 and send back to client
                         audio_base64 = await encode_audio_data(audio_tensor)
+                        print("[STREAM] Sending audio response to client")
                         await websocket.send_json({
                             "type": "audio_response",
                             "text": response_text,
                             "audio": audio_base64
                         })
                     else:
+                        print("[ASR] Transcription failed or returned empty text")
                         # If transcription failed, send a generic response
                         await websocket.send_json({
                             "type": "error",
@@ -346,17 +409,20 @@ async def websocket_endpoint(websocket: WebSocket):
                     energy_window.clear()
                     is_silence = False
                     last_active_time = time.time()
+                    print("[STREAM] Buffer cleared, ready for next utterance")
 
                 # If buffer gets too large without silence, process it anyway
                 # This prevents memory issues with very long streams
                 elif len(streaming_buffer) >= 30:  # ~6 seconds of audio at 5 chunks/sec
-                    print("Buffer limit reached, processing audio")
+                    print("[BUFFER] Maximum buffer size reached, processing audio")
                     full_audio = torch.cat(streaming_buffer, dim=0)
 
                     # Process with WhisperX speech-to-text
+                    print("[ASR] Starting forced transcription of long audio...")
                     transcribed_text = await transcribe_audio(full_audio)
 
                     if transcribed_text:
+                        print(f"[ASR] Transcribed long audio: '{transcribed_text}'")
                         context_segments.append(Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio))
 
                         # Send the transcribed text to client
@@ -364,11 +430,17 @@ async def websocket_endpoint(websocket: WebSocket):
"type": "transcription", "type": "transcription",
"text": transcribed_text + " (processing continued speech...)" "text": transcribed_text + " (processing continued speech...)"
}) })
else:
print("[ASR] No transcription from long audio")
streaming_buffer = [] streaming_buffer = []
print("[BUFFER] Buffer cleared due to size limit")
except Exception as e: except Exception as e:
print(f"Error processing streaming audio: {str(e)}") print(f"[ERROR] Processing streaming audio: {str(e)}")
# Print traceback for more detailed error information
import traceback
traceback.print_exc()
await websocket.send_json({ await websocket.send_json({
"type": "error", "type": "error",
"message": f"Error processing streaming audio: {str(e)}" "message": f"Error processing streaming audio: {str(e)}"
@@ -412,6 +484,53 @@ async def websocket_endpoint(websocket: WebSocket):
             pass
     manager.disconnect(websocket)
 
+# Add this function to get the public IP address
+def get_public_ip():
+    """Get the server's public IP address using an external service"""
+    try:
+        # Try multiple services in case one is down
+        services = [
+            "https://api.ipify.org",
+            "https://ifconfig.me/ip",
+            "https://checkip.amazonaws.com",
+        ]
+        for service in services:
+            try:
+                response = requests.get(service, timeout=3)
+                if response.status_code == 200:
+                    return response.text.strip()
+            except:
+                continue
+
+        # Fallback to socket if external services fail
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        try:
+            # Doesn't need to be reachable, just used to determine interface
+            s.connect(('8.8.8.8', 1))
+            local_ip = s.getsockname()[0]
+            return local_ip
+        except:
+            return "localhost"
+        finally:
+            s.close()
+    except:
+        return "Could not determine IP address"
+
+# Update the __main__ block
 if __name__ == "__main__":
+    public_ip = get_public_ip()
+    print(f"\n{'='*50}")
+    print(f"💬 Sesame AI Voice Chat Server")
+    print(f"{'='*50}")
+    print(f"📡 Server Information:")
+    print(f"   - Public IP: http://{public_ip}:8000")
+    print(f"   - Local URL: http://localhost:8000")
+    print(f"   - WebSocket: ws://{public_ip}:8000/ws")
+    print(f"{'='*50}")
+    print(f"🌐 Connect from web browsers using: http://{public_ip}:8000")
+    print(f"🔧 Serving index.html from: {os.path.join(base_dir, 'index.html')}")
+    print(f"{'='*50}\n")
+
+    # Start the server
     uvicorn.run(app, host="0.0.0.0", port=8000)
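
Note: a hedged, receive-only client sketch for the /ws endpoint using the third-party websockets library. The messages a client must send to start streaming and deliver audio chunks aren't shown in this diff, so this only prints what the server pushes (streaming_status, transcription, audio_response, error):

import asyncio
import json

import websockets  # third-party: pip install websockets

async def listen(url: str = "ws://localhost:8000/ws") -> None:
    async with websockets.connect(url) as ws:
        # Print every JSON message the server pushes over the socket
        async for raw in ws:
            msg = json.loads(raw)
            if msg.get("type") == "audio_response":
                print(f"response text: {msg['text']} ({len(msg['audio'])} base64 chars)")
            else:
                print(msg)

asyncio.run(listen())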