diff --git a/.gitignore b/.gitignore
index 1170717..e06d006 100644
--- a/.gitignore
+++ b/.gitignore
@@ -134,3 +134,4 @@ dist
 .yarn/build-state.yml
 .yarn/install-state.gz
 .pnp.*
+Backend/test.py
diff --git a/Backend/index.html b/Backend/index.html
index 309364f..f4ff6a0 100644
--- a/Backend/index.html
+++ b/Backend/index.html
@@ -10,60 +10,113 @@
             max-width: 800px;
             margin: 0 auto;
             padding: 20px;
+            background-color: #f9f9f9;
         }
         .conversation {
-            border: 1px solid #ccc;
-            border-radius: 8px;
-            padding: 15px;
-            height: 300px;
+            border: 1px solid #ddd;
+            border-radius: 12px;
+            padding: 20px;
+            height: 400px;
             overflow-y: auto;
-            margin-bottom: 15px;
+            margin-bottom: 20px;
+            background-color: white;
+            box-shadow: 0 2px 10px rgba(0,0,0,0.05);
         }
         .message {
-            margin-bottom: 10px;
-            padding: 8px;
-            border-radius: 8px;
+            margin-bottom: 15px;
+            padding: 12px;
+            border-radius: 12px;
+            max-width: 80%;
+            line-height: 1.4;
         }
         .user {
             background-color: #e3f2fd;
             text-align: right;
+            margin-left: auto;
+            border-bottom-right-radius: 4px;
         }
         .ai {
             background-color: #f1f1f1;
+            margin-right: auto;
+            border-bottom-left-radius: 4px;
+        }
+        .system {
+            background-color: #f8f9fa;
+            font-style: italic;
+            text-align: center;
+            font-size: 0.9em;
+            color: #666;
+            padding: 8px;
+            margin: 10px auto;
+            max-width: 90%;
         }
         .controls {
             display: flex;
-            flex-direction: column;
-            gap: 10px;
-        }
-        .input-row {
-            display: flex;
-            gap: 10px;
-        }
-        input[type="text"] {
-            flex-grow: 1;
-            padding: 8px;
-            border-radius: 4px;
-            border: 1px solid #ccc;
+            gap: 15px;
+            justify-content: center;
+            align-items: center;
         }
         button {
-            padding: 8px 16px;
-            border-radius: 4px;
+            padding: 12px 24px;
+            border-radius: 24px;
             border: none;
             background-color: #4CAF50;
             color: white;
             cursor: pointer;
+            font-weight: bold;
+            transition: all 0.2s ease;
+            box-shadow: 0 2px 5px rgba(0,0,0,0.1);
         }
         button:hover {
             background-color: #45a049;
+            box-shadow: 0 4px 8px rgba(0,0,0,0.15);
         }
         .recording {
             background-color: #f44336;
+            animation: pulse 1.5s infinite;
+        }
+        .processing {
+            background-color: #FFA500;
         }
         select {
-            padding: 8px;
-            border-radius: 4px;
-            border: 1px solid #ccc;
+            padding: 10px;
+            border-radius: 24px;
+            border: 1px solid #ddd;
+            background-color: white;
+        }
+        .transcript {
+            font-style: italic;
+            color: #666;
+            margin-top: 5px;
+        }
+        @keyframes pulse {
+            0% { opacity: 1; }
+            50% { opacity: 0.7; }
+            100% { opacity: 1; }
+        }
+        .status-indicator {
+            display: flex;
+            align-items: center;
+            justify-content: center;
+            margin-top: 10px;
+            gap: 5px;
+        }
+        .status-dot {
+            width: 10px;
+            height: 10px;
+            border-radius: 50%;
+            background-color: #ccc;
+        }
+        .status-dot.active {
+            background-color: #4CAF50;
+        }
+        .status-text {
+            font-size: 0.9em;
+            color: #666;
+        }
+        audio {
+            width: 100%;
+            margin-top: 5px;
         }
@@ -72,30 +125,25 @@
[hunk body unrecoverable: the HTML markup was stripped during extraction, leaving only bare "+"/"-" prefixes and the text "Not connected". Judging by the hunk header and the new CSS above, the change swaps the old text-input markup (.input-row, input[type="text"]) for voice-chat controls: buttons, a speaker <select>, and a .status-indicator whose default .status-text reads "Not connected".]
\ No newline at end of file
diff --git a/Backend/server.py b/Backend/server.py
index bfdc590..b9736b5 100644
--- a/Backend/server.py
+++ b/Backend/server.py
@@ -5,6 +5,8 @@ import asyncio
 import torch
 import torchaudio
 import numpy as np
+import io
+import whisperx
 from io import BytesIO
 from typing import List, Dict, Any, Optional
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
@@ -13,6 +15,7 @@ from pydantic import BaseModel
 from generator import load_csm_1b, Segment
 import uvicorn
 import time
+import gc
 from collections import deque
 
 # Select device
@@ -25,6 +28,12 @@ print(f"Using device: {device}")
 # Initialize the model
 generator = load_csm_1b(device=device)
 
+# Initialize WhisperX for ASR
+print("Loading WhisperX model...")
+# Use a smaller model for faster response times
+asr_model = whisperx.load_model("medium", device, compute_type="float16")
+print("WhisperX model loaded!")
+
 app = FastAPI()
 
 # Add CORS middleware to allow cross-origin requests
@@ -93,6 +102,68 @@ async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
     return f"data:audio/wav;base64,{audio_base64}"
 
 
+async def transcribe_audio(audio_tensor: torch.Tensor) -> str:
+    """Transcribe audio using WhisperX"""
+    try:
+        # Save the tensor to a temporary file
+        temp_file = BytesIO()
+        torchaudio.save(temp_file, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
+        temp_file.seek(0)
+
+        # Create a temporary file on disk (WhisperX requires a file path)
+        temp_path = "temp_audio.wav"
+        with open(temp_path, "wb") as f:
+            f.write(temp_file.read())
+
+        # Load and transcribe the audio
+        audio = whisperx.load_audio(temp_path)
+        result = asr_model.transcribe(audio, batch_size=16)
+
+        # Clean up
+        os.remove(temp_path)
+
+        # Get the transcription text
+        if result["segments"] and len(result["segments"]) > 0:
+            # Combine all segments
+            transcription = " ".join([segment["text"] for segment in result["segments"]])
+            print(f"Transcription: {transcription}")
+            return transcription.strip()
+        else:
+            return ""
+    except Exception as e:
+        print(f"Error in transcription: {str(e)}")
+        return ""
+
+
+async def generate_response(text: str, conversation_history: List[Segment]) -> str:
+    """Generate a contextual response based on the transcribed text"""
+    # Simple response logic - can be replaced with a more sophisticated LLM in the future
+    responses = {
+        "hello": "Hello there! How are you doing today?",
+        "how are you": "I'm doing well, thanks for asking! How about you?",
+        "what is your name": "I'm Sesame, your voice assistant. How can I help you?",
+        "bye": "Goodbye! It was nice chatting with you.",
+        "thank you": "You're welcome! Is there anything else I can help with?",
+        "weather": "I don't have real-time weather data, but I hope it's nice where you are!",
+        "help": "I can chat with you using natural voice. Just speak normally and I'll respond.",
+    }
+
+    text_lower = text.lower()
+
+    # Check for matching keywords
+    for key, response in responses.items():
+        if key in text_lower:
+            return response
+
+    # Default responses based on text length
+    if not text:
+        return "I didn't catch that. Could you please repeat?"
+    elif len(text) < 10:
+        return "Thanks for your message. Could you elaborate a bit more?"
+    else:
+        return f"I understand you said '{text}'. That's interesting! Can you tell me more about that?"
+
+
 @app.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket):
     await manager.connect(websocket)
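Review note on `transcribe_audio`: every call writes to the same hard-coded `temp_audio.wav` in the working directory, so two concurrent websocket sessions can overwrite each other's audio, and a failure between `open()` and `os.remove()` leaks the file. A minimal sketch of a safer variant built on Python's `tempfile` module, keeping the same WhisperX calls and moving cleanup into a `finally` block (the helper name and signature are illustrative, not part of the PR):

```python
import os
import tempfile

import torch
import torchaudio
import whisperx


async def transcribe_audio_safe(audio_tensor: torch.Tensor, sample_rate: int, asr_model) -> str:
    """Sketch: a unique temp file per call instead of a shared temp_audio.wav."""
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()  # torchaudio reopens the path itself
    try:
        # Same save/load/transcribe sequence as the PR's transcribe_audio
        torchaudio.save(tmp.name, audio_tensor.unsqueeze(0).cpu(), sample_rate, format="wav")
        audio = whisperx.load_audio(tmp.name)
        result = asr_model.transcribe(audio, batch_size=16)
        return " ".join(seg["text"] for seg in result.get("segments", [])).strip()
    finally:
        os.remove(tmp.name)  # runs even if transcription raises
```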
@@ -220,30 +291,55 @@ async def websocket_endpoint(websocket: WebSocket):
                         # User has stopped talking - process the collected audio
                         full_audio = torch.cat(streaming_buffer, dim=0)
 
-                        # Process with speech-to-text (you would need to implement this)
-                        # For now, just use a placeholder text
-                        text = f"User audio from speaker {speaker_id}"
+                        # Process with WhisperX speech-to-text
+                        transcribed_text = await transcribe_audio(full_audio)
 
-                        print(f"Detected end of speech, processing {len(streaming_buffer)} chunks")
+                        # Log the transcription
+                        print(f"Transcribed text: '{transcribed_text}'")
 
                         # Add to conversation context
-                        context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio))
-
-                        # Generate response
-                        response_text = "This is a response to what you just said"
-                        audio_tensor = generator.generate(
-                            text=response_text,
-                            speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
-                            context=context_segments,
-                            max_audio_length_ms=10_000,
-                        )
-
-                        # Convert audio to base64 and send back to client
-                        audio_base64 = await encode_audio_data(audio_tensor)
-                        await websocket.send_json({
-                            "type": "audio_response",
-                            "audio": audio_base64
-                        })
+                        if transcribed_text:
+                            user_segment = Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)
+                            context_segments.append(user_segment)
+
+                            # Generate a contextual response
+                            response_text = await generate_response(transcribed_text, context_segments)
+
+                            # Send the transcribed text to client
+                            await websocket.send_json({
+                                "type": "transcription",
+                                "text": transcribed_text
+                            })
+
+                            # Generate audio for the response
+                            audio_tensor = generator.generate(
+                                text=response_text,
+                                speaker=1 if speaker_id == 0 else 0,  # Use opposite speaker
+                                context=context_segments,
+                                max_audio_length_ms=10_000,
+                            )
+
+                            # Add response to context
+                            ai_segment = Segment(
+                                text=response_text,
+                                speaker=1 if speaker_id == 0 else 0,
+                                audio=audio_tensor
+                            )
+                            context_segments.append(ai_segment)
+
+                            # Convert audio to base64 and send back to client
+                            audio_base64 = await encode_audio_data(audio_tensor)
+                            await websocket.send_json({
+                                "type": "audio_response",
+                                "text": response_text,
+                                "audio": audio_base64
+                            })
+                        else:
+                            # If transcription failed, send a generic response
+                            await websocket.send_json({
+                                "type": "error",
+                                "message": "Sorry, I couldn't understand what you said. Could you try again?"
+                            })
 
                         # Clear buffer and reset silence detection
                         streaming_buffer = []
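A related caveat: `transcribe_audio` is declared `async` but its body never awaits anything, and `generator.generate` is called synchronously, so the event loop stalls for the full duration of transcription and synthesis and other websocket traffic waits. One possible mitigation is pushing the blocking call onto a worker thread, for example (a sketch, assuming Python 3.9+ for `asyncio.to_thread`; not part of the PR):

```python
import asyncio

# Offload the blocking TTS call so the event loop stays responsive
audio_tensor = await asyncio.to_thread(
    generator.generate,
    text=response_text,
    speaker=1 if speaker_id == 0 else 0,  # same opposite-speaker choice as above
    context=context_segments,
    max_audio_length_ms=10_000,
)
```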
+ }) # Clear buffer and reset silence detection streaming_buffer = [] @@ -256,8 +352,19 @@ async def websocket_endpoint(websocket: WebSocket): elif len(streaming_buffer) >= 30: # ~6 seconds of audio at 5 chunks/sec print("Buffer limit reached, processing audio") full_audio = torch.cat(streaming_buffer, dim=0) - text = f"Continued speech from speaker {speaker_id}" - context_segments.append(Segment(text=text, speaker=speaker_id, audio=full_audio)) + + # Process with WhisperX speech-to-text + transcribed_text = await transcribe_audio(full_audio) + + if transcribed_text: + context_segments.append(Segment(text=transcribed_text, speaker=speaker_id, audio=full_audio)) + + # Send the transcribed text to client + await websocket.send_json({ + "type": "transcription", + "text": transcribed_text + " (processing continued speech...)" + }) + streaming_buffer = [] except Exception as e: @@ -269,11 +376,21 @@ async def websocket_endpoint(websocket: WebSocket): elif action == "stop_streaming": is_streaming = False - if streaming_buffer: + if streaming_buffer and len(streaming_buffer) > 5: # Only process if there's meaningful audio # Process any remaining audio in the buffer full_audio = torch.cat(streaming_buffer, dim=0) - text = f"Final streaming audio from speaker {request.get('speaker', 0)}" - context_segments.append(Segment(text=text, speaker=request.get("speaker", 0), audio=full_audio)) + + # Process with WhisperX speech-to-text + transcribed_text = await transcribe_audio(full_audio) + + if transcribed_text: + context_segments.append(Segment(text=transcribed_text, speaker=request.get("speaker", 0), audio=full_audio)) + + # Send the transcribed text to client + await websocket.send_json({ + "type": "transcription", + "text": transcribed_text + }) streaming_buffer = [] await websocket.send_json({ @@ -286,12 +403,15 @@ async def websocket_endpoint(websocket: WebSocket): print("Client disconnected") except Exception as e: print(f"Error: {str(e)}") - await websocket.send_json({ - "type": "error", - "message": str(e) - }) + try: + await websocket.send_json({ + "type": "error", + "message": str(e) + }) + except: + pass manager.disconnect(websocket) if __name__ == "__main__": - uvicorn.run(app, host="localhost", port=8000) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file