diff --git a/Backend/server.py b/Backend/server.py
index b9736b5..97b346b 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -9,7 +9,9 @@ import io
 import whisperx
 from io import BytesIO
 from typing import List, Dict, Any, Optional
-from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Response
+from fastapi.responses import HTMLResponse, FileResponse
+from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from generator import load_csm_1b, Segment
@@ -17,6 +19,8 @@ import uvicorn
 import time
 import gc
 from collections import deque
+import socket
+import requests
 
 # Select device
 if torch.cuda.is_available():
@@ -45,6 +49,32 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# Define the base directory
+base_dir = os.path.dirname(os.path.abspath(__file__))
+
+# Mount a static files directory for assets such as CSS and JS
+static_dir = os.path.join(base_dir, "static")
+os.makedirs(static_dir, exist_ok=True)  # Create the directory if it doesn't exist
+app.mount("/static", StaticFiles(directory=static_dir), name="static")
+
+# Serve index.html as the main page
+@app.get("/", response_class=HTMLResponse)
+async def get_index():
+    try:
+        with open(os.path.join(base_dir, "index.html"), "r", encoding="utf-8") as f:
+            return HTMLResponse(content=f.read())
+    except FileNotFoundError:
+        return HTMLResponse(content="<h1>Error: index.html not found</h1>", status_code=404)
+
+# Favicon endpoint (optional, but avoids 404 noise in the logs)
+@app.get("/favicon.ico")
+async def get_favicon():
+    favicon_path = os.path.join(static_dir, "favicon.ico")
+    if os.path.exists(favicon_path):
+        return FileResponse(favicon_path)
+    else:
+        return Response(status_code=204)  # No content
+
 # Connection manager to handle multiple clients
 class ConnectionManager:
     def __init__(self):
@@ -259,6 +289,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 energy_window.clear()
                 is_silence = False
                 last_active_time = time.time()
+                print(f"[STREAM] Streaming started with speaker ID: {speaker_id}")
                 await websocket.send_json({
                     "type": "streaming_status",
                     "status": "started"
@@ -269,6 +300,13 @@ async def websocket_endpoint(websocket: WebSocket):
                 energy_window.append(chunk_energy)
                 avg_energy = sum(energy_window) / len(energy_window)
 
+                # Debug audio levels
+                if len(energy_window) >= 5:  # Only start printing after we have enough samples
+                    if avg_energy > SILENCE_THRESHOLD:
+                        print(f"[AUDIO] Active sound detected - Energy: {avg_energy:.6f} (threshold: {SILENCE_THRESHOLD})")
+                    else:
+                        print(f"[AUDIO] Silence detected - Energy: {avg_energy:.6f} (threshold: {SILENCE_THRESHOLD})")
+
                 # Check if audio is silent
                 current_silence = avg_energy < SILENCE_THRESHOLD
 
@@ -277,33 +315,53 @@ async def websocket_endpoint(websocket: WebSocket):
                     # Transition to silence
                     is_silence = True
                     last_active_time = time.time()
+                    print("[STREAM] Transition to silence detected")
                 elif is_silence and not current_silence:
                     # User started talking again
                     is_silence = False
+                    print("[STREAM] User resumed speaking")
 
                 # Add chunk to buffer regardless of silence state
                 streaming_buffer.append(audio_chunk)
 
+                # Debug buffer size periodically
+                if len(streaming_buffer) % 10 == 0:
+                    print(f"[BUFFER] Current size: {len(streaming_buffer)} chunks, ~{len(streaming_buffer)/5:.1f} seconds")
+
                 # Check if silence has persisted long enough to consider "stopped talking"
                 silence_elapsed = time.time() - last_active_time
 
                 if is_silence and silence_elapsed >= SILENCE_DURATION_SEC and len(streaming_buffer) > 0:
                     # User has stopped talking - process the collected audio
+                    print(f"[STREAM] Processing audio after {silence_elapsed:.2f}s of silence")
+                    print(f"[STREAM] Processing {len(streaming_buffer)} audio chunks (~{len(streaming_buffer)/5:.1f} seconds)")
+
                     full_audio = torch.cat(streaming_buffer, dim=0)
 
+                    # Log audio statistics
+                    audio_duration = len(full_audio) / generator.sample_rate
+                    audio_min = torch.min(full_audio).item()
+                    audio_max = torch.max(full_audio).item()
+                    audio_mean = torch.mean(full_audio).item()
+                    print(f"[AUDIO] Processed audio - Duration: {audio_duration:.2f}s, Min: {audio_min:.4f}, Max: {audio_max:.4f}, Mean: {audio_mean:.4f}")
+
                     # Process with WhisperX speech-to-text
+                    print("[ASR] Starting transcription with WhisperX...")
                     transcribed_text = await transcribe_audio(full_audio)
 
                     # Log the transcription
-                    print(f"Transcribed text: '{transcribed_text}'")
+                    print(f"[ASR] Transcribed text: '{transcribed_text}'")
 
                     # Add to conversation context
                     if transcribed_text:
+                        print(f"[DIALOG] Adding user utterance to context: '{transcribed_text}'")
                         user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
                         context_segments.append(user_segment)
 
                         # Generate a contextual response
+                        print("[DIALOG] Generating response...")
                         response_text = await generate_response(transcribed_text, context_segments)
+                        print(f"[DIALOG] Response text: '{response_text}'")
 
                         # Send the transcribed text to client
                         await websocket.send_json({
@@ -312,12 +370,14 @@ async def websocket_endpoint(websocket: WebSocket):
                         })
 
                         # Generate audio for the response
+                        print("[TTS] Generating speech for response...")
                         audio_tensor = generator.generate(
                             text=response_text,
                             speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
                             context=context_segments,
                             max_audio_length_ms=10_000,
                         )
+                        print(f"[TTS] Generated audio length: {len(audio_tensor)/generator.sample_rate:.2f}s")
 
                         # Add response to context
                         ai_segment = Segment(
@@ -326,15 +386,18 @@ async def websocket_endpoint(websocket: WebSocket):
                             audio=audio_tensor
                         )
                         context_segments.append(ai_segment)
+                        print(f"[DIALOG] Context now has {len(context_segments)} segments")
 
                         # Convert audio to base64 and send back to client
                         audio_base64 = await encode_audio_data(audio_tensor)
+                        print("[STREAM] Sending audio response to client")
                         await websocket.send_json({
                             "type": "audio_response",
                             "text": response_text,
                             "audio": audio_base64
                         })
                     else:
+                        print("[ASR] Transcription failed or returned empty text")
                         # If transcription failed, send a generic response
                         await websocket.send_json({
                             "type": "error",
@@ -346,17 +409,20 @@ async def websocket_endpoint(websocket: WebSocket):
                     energy_window.clear()
                     is_silence = False
                     last_active_time = time.time()
+                    print("[STREAM] Buffer cleared, ready for next utterance")
 
                 # If buffer gets too large without silence, process it anyway
                 # This prevents memory issues with very long streams
                 elif len(streaming_buffer) >= 30:  # ~6 seconds of audio at 5 chunks/sec
-                    print("Buffer limit reached, processing audio")
+                    print("[BUFFER] Maximum buffer size reached, processing audio")
                     full_audio = torch.cat(streaming_buffer, dim=0)
 
                     # Process with WhisperX speech-to-text
+                    print("[ASR] Starting forced transcription of long audio...")
                     transcribed_text = await transcribe_audio(full_audio)
 
                     if transcribed_text:
+                        print(f"[ASR] Transcribed long audio: '{transcribed_text}'")
                         context_segments.append(Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio))
 
                         # Send the transcribed text to client
@@ -364,11 +430,17 @@ async def websocket_endpoint(websocket: WebSocket):
                             "type": "transcription",
                             "text": transcribed_text + " (processing continued speech...)"
                         })
+                    else:
+                        print("[ASR] No transcription from long audio")
 
                     streaming_buffer = []
+                    print("[BUFFER] Buffer cleared due to size limit")
 
             except Exception as e:
-                print(f"Error processing streaming audio: {str(e)}")
+                print(f"[ERROR] Processing streaming audio: {str(e)}")
+                # Print traceback for more detailed error information
+                import traceback
+                traceback.print_exc()
                 await websocket.send_json({
                     "type": "error",
                     "message": f"Error processing streaming audio: {str(e)}"
@@ -412,6 +484,53 @@ async def websocket_endpoint(websocket: WebSocket):
             pass
     manager.disconnect(websocket)
 
+# Determine the server's public IP address for the startup banner
+def get_public_ip():
+    """Get the server's public IP address using an external service"""
+    try:
+        # Try multiple services in case one is down
+        services = [
+            "https://api.ipify.org",
+            "https://ifconfig.me/ip",
+            "https://checkip.amazonaws.com",
+        ]
+
+        for service in services:
+            try:
+                response = requests.get(service, timeout=3)
+                if response.status_code == 200:
+                    return response.text.strip()
+            except requests.RequestException:
+                continue
+
+        # Fall back to a local interface address if the external services fail
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        try:
+            # The address doesn't need to be reachable; it only selects the interface
+            s.connect(('8.8.8.8', 1))
+            local_ip = s.getsockname()[0]
+            return local_ip
+        except OSError:
+            return "localhost"
+        finally:
+            s.close()
+    except Exception:
+        return "Could not determine IP address"
 
+# Print connection info before starting the server
 if __name__ == "__main__":
+    public_ip = get_public_ip()
+    print(f"\n{'='*50}")
+    print("💬 Sesame AI Voice Chat Server")
+    print(f"{'='*50}")
+    print("📡 Server Information:")
+    print(f"   - Public IP: http://{public_ip}:8000")
+    print("   - Local URL: http://localhost:8000")
+    print(f"   - WebSocket: ws://{public_ip}:8000/ws")
+    print(f"{'='*50}")
+    print(f"🌐 Connect from web browsers using: http://{public_ip}:8000")
+    print(f"🔧 Serving index.html from: {os.path.join(base_dir, 'index.html')}")
+    print(f"{'='*50}\n")
+
+    # Start the server
     uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
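
Two short sketches follow for reviewers. First, the logging added above instruments the server's energy-based endpointing: per-chunk energies are averaged over a sliding window, the average is compared against `SILENCE_THRESHOLD`, and an utterance is finalized once silence has persisted for `SILENCE_DURATION_SEC` (or the buffer reaches 30 chunks). A minimal standalone sketch of that technique, assuming RMS chunk energy and placeholder constants; the real values and the actual energy computation live in the unchanged parts of `server.py`:

```python
import time
from collections import deque

import torch

# Placeholder values -- the real constants are defined elsewhere in server.py.
SILENCE_THRESHOLD = 0.01    # rolling energy below this counts as silence
SILENCE_DURATION_SEC = 1.0  # silence must persist this long to end an utterance
WINDOW_SIZE = 10            # number of chunks in the rolling average


def chunk_energy(chunk: torch.Tensor) -> float:
    """RMS energy of one chunk (an assumption; server.py may compute it differently)."""
    return torch.sqrt(torch.mean(chunk ** 2)).item()


def utterances(chunks):
    """Yield concatenated audio each time the rolling average says the speaker stopped."""
    energy_window = deque(maxlen=WINDOW_SIZE)
    buffer = []
    is_silence = False
    last_active = time.time()

    for chunk in chunks:
        energy_window.append(chunk_energy(chunk))
        avg_energy = sum(energy_window) / len(energy_window)
        currently_silent = avg_energy < SILENCE_THRESHOLD

        if not is_silence and currently_silent:
            # Transition into silence: start the countdown
            is_silence = True
            last_active = time.time()
        elif is_silence and not currently_silent:
            # Speaker resumed before the countdown expired
            is_silence = False

        buffer.append(chunk)  # chunks are buffered in both states

        if is_silence and time.time() - last_active >= SILENCE_DURATION_SEC and buffer:
            yield torch.cat(buffer, dim=0)  # complete utterance, ready for ASR
            buffer = []
            energy_window.clear()
            is_silence = False
```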
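Second, every server-to-client message this change touches is JSON with a `type` of `streaming_status`, `transcription`, `audio_response`, or `error`, and response audio travels base64-encoded in the `audio` field. Below is a hypothetical listen-only client built on the third-party `websockets` package; the URI, the output filename, and the container format of the decoded audio (produced by `encode_audio_data()`) are assumptions, and the client-to-server messages that actually start a stream are outside this diff:

```python
import asyncio
import base64
import json

import websockets  # third-party: pip install websockets


async def listen(uri: str = "ws://localhost:8000/ws") -> None:
    async with websockets.connect(uri) as ws:
        async for raw in ws:
            msg = json.loads(raw)
            kind = msg.get("type")

            if kind == "streaming_status":
                print(f"[status] {msg['status']}")
            elif kind == "transcription":
                print(f"[you] {msg['text']}")
            elif kind == "audio_response":
                print(f"[ai] {msg['text']}")
                # The audio field is base64; its container format depends on
                # encode_audio_data() in server.py, so the extension is a guess.
                with open("response_audio.bin", "wb") as f:
                    f.write(base64.b64decode(msg["audio"]))
            elif kind == "error":
                print(f"[error] {msg['message']}")


if __name__ == "__main__":
    asyncio.run(listen())
```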