diff --git a/Backend/api/app.py b/Backend/api/app.py deleted file mode 100644 index 018061f..0000000 --- a/Backend/api/app.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -import logging -import threading -from dataclasses import dataclass -from flask import Flask -from flask_socketio import SocketIO -from flask_cors import CORS - -# Configure logging -logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -# Configure device -import torch -if torch.cuda.is_available(): - DEVICE = "cuda" -elif torch.backends.mps.is_available(): - DEVICE = "mps" -else: - DEVICE = "cpu" - -logger.info(f"Using device: {DEVICE}") - -# Initialize Flask app -app = Flask(__name__, static_folder='../', static_url_path='') -CORS(app) -socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120) - -# Global variables for conversation state -active_conversations = {} -user_queues = {} -processing_threads = {} - -# Model storage -@dataclass -class AppModels: - generator = None - tokenizer = None - llm = None - whisperx_model = None - whisperx_align_model = None - whisperx_align_metadata = None - last_language = None - -models = AppModels() - -def load_models(): - """Load all required models""" - from generator import load_csm_1b - import whisperx - import gc - from transformers import AutoModelForCausalLM, AutoTokenizer - global models - - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0}) - - # CSM 1B loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'}) - models.generator = load_csm_1b(device=DEVICE) - logger.info("CSM 1B model loaded successfully") - socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33}) - if DEVICE == "cuda": - torch.cuda.empty_cache() - except Exception as e: - import traceback - error_details = traceback.format_exc() - logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") - socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) - - # WhisperX loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'}) - # Use WhisperX for better transcription with timestamps - # Use compute_type based on device - compute_type = "float16" if DEVICE == "cuda" else "float32" - - # Load the WhisperX model (smaller model for faster processing) - models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type) - - logger.info("WhisperX model loaded successfully") - socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66}) - if DEVICE == "cuda": - torch.cuda.empty_cache() - except Exception as e: - import traceback - error_details = traceback.format_exc() - logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}") - socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) - - # Llama loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'}) - models.llm = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-3.2-1B", - device_map=DEVICE, - torch_dtype=torch.bfloat16 - ) - models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") - - # Configure all special tokens - models.tokenizer.pad_token = models.tokenizer.eos_token - models.tokenizer.padding_side = "left" # For causal language modeling - - # Inform the model about the pad token - if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None: - models.llm.config.pad_token_id = models.tokenizer.pad_token_id - - logger.info("Llama 3.2 model loaded successfully") - socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'}) - except Exception as e: - logger.error(f"Error loading Llama 3.2 model: {str(e)}") - socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) - -# Load models in a background thread -threading.Thread(target=load_models, daemon=True).start() - -# Import routes and socket handlers -from api.routes import register_routes -from api.socket_handlers import register_handlers - -# Register routes and socket handlers -register_routes(app) -register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE) - -# Run server if executed directly -if __name__ == '__main__': - port = int(os.environ.get('PORT', 5000)) - debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true' - logger.info(f"Starting server on port {port} (debug={debug_mode})") - socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/api/routes.py b/Backend/api/routes.py deleted file mode 100644 index af1bfce..0000000 --- a/Backend/api/routes.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import torch -import psutil -from flask import send_from_directory, jsonify, request - -def register_routes(app): - """Register HTTP routes for the application""" - - @app.route('/') - def index(): - """Serve the main application page""" - return send_from_directory(app.static_folder, 'index.html') - - @app.route('/voice-chat.js') - def serve_js(): - """Serve the JavaScript file""" - return send_from_directory(app.static_folder, 'voice-chat.js') - - @app.route('/api/status') - def system_status(): - """Return the system status""" - # Import here to avoid circular imports - from api.app import models, DEVICE - - return jsonify({ - "status": "ok", - "cuda_available": torch.cuda.is_available(), - "device": DEVICE, - "models": { - "generator": models.generator is not None, - "asr": models.whisperx_model is not None, - "llm": models.llm is not None - }, - "versions": { - "transformers": "4.49.0", # Replace with actual version - "torch": torch.__version__ - } - }) - - @app.route('/api/system_resources') - def system_resources(): - """Return system resource usage""" - # Import here to avoid circular imports - from api.app import active_conversations, DEVICE - - # Get CPU usage - cpu_percent = psutil.cpu_percent(interval=0.1) - - # Get memory usage - memory = psutil.virtual_memory() - memory_used_gb = memory.used / (1024 ** 3) - memory_total_gb = memory.total / (1024 ** 3) - memory_percent = memory.percent - - # Get GPU memory if available - gpu_memory = {} - if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - gpu_memory[f"gpu_{i}"] = { - "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3), - "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3), - "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3) - } - - return jsonify({ - "cpu_percent": cpu_percent, - "memory": { - "used_gb": memory_used_gb, - "total_gb": memory_total_gb, - "percent": memory_percent - }, - "gpu_memory": gpu_memory, - "active_sessions": len(active_conversations) - }) \ No newline at end of file diff --git a/Backend/api/socket_handlers.py b/Backend/api/socket_handlers.py deleted file mode 100644 index b09a2cf..0000000 --- a/Backend/api/socket_handlers.py +++ /dev/null @@ -1,393 +0,0 @@ -import os -import io -import base64 -import time -import threading -import queue -import tempfile -import gc -import logging -import traceback -from typing import Dict, List, Optional - -import torch -import torchaudio -import numpy as np -from flask import request -from flask_socketio import emit - -# Import conversation model -from generator import Segment - -logger = logging.getLogger(__name__) - -# Conversation data structure -class Conversation: - def __init__(self, session_id): - self.session_id = session_id - self.segments: List[Segment] = [] - self.current_speaker = 0 - self.ai_speaker_id = 1 # Default AI speaker ID - self.last_activity = time.time() - self.is_processing = False - - def add_segment(self, text, speaker, audio): - segment = Segment(text=text, speaker=speaker, audio=audio) - self.segments.append(segment) - self.last_activity = time.time() - return segment - - def get_context(self, max_segments=10): - """Return the most recent segments for context""" - return self.segments[-max_segments:] if self.segments else [] - -def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE): - """Register Socket.IO event handlers""" - # No need for global references, just use the parameters directly - - @socketio.on('connect') - def handle_connect(auth=None): - """Handle client connection""" - session_id = request.sid - logger.info(f"Client connected: {session_id}") - - # Initialize conversation data - if session_id not in active_conversations: - active_conversations[session_id] = Conversation(session_id) - user_queues[session_id] = queue.Queue() - processing_threads[session_id] = threading.Thread( - target=process_audio_queue, - args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, DEVICE), - daemon=True - ) - processing_threads[session_id].start() - - emit('connection_status', {'status': 'connected'}) - - @socketio.on('disconnect') - def handle_disconnect(reason=None): - """Handle client disconnection""" - session_id = request.sid - logger.info(f"Client disconnected: {session_id}. Reason: {reason}") - - # Cleanup - if session_id in active_conversations: - # Mark for deletion rather than immediately removing - # as the processing thread might still be accessing it - active_conversations[session_id].is_processing = False - user_queues[session_id].put(None) # Signal thread to terminate - - @socketio.on('audio_data') - def handle_audio_data(data): - """Handle incoming audio data""" - session_id = request.sid - logger.info(f"Received audio data from {session_id}") - - # Check if the models are loaded - if models.generator is None or models.whisperx_model is None or models.llm is None: - emit('error', {'message': 'Models still loading, please wait'}) - return - - # Check if we're already processing for this session - if session_id in active_conversations and active_conversations[session_id].is_processing: - emit('error', {'message': 'Still processing previous audio, please wait'}) - return - - # Add to processing queue - if session_id in user_queues: - user_queues[session_id].put(data) - else: - emit('error', {'message': 'Session not initialized, please refresh the page'}) - -def process_audio_queue(session_id, q, app, socketio, models, active_conversations, DEVICE): - """Background thread to process audio chunks for a session""" - logger.info(f"Started processing thread for session: {session_id}") - - try: - while session_id in active_conversations: - try: - # Get the next audio chunk with a timeout - data = q.get(timeout=120) - if data is None: # Termination signal - break - - # Process the audio and generate a response - process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE) - - except queue.Empty: - # Timeout, check if session is still valid - continue - except Exception as e: - logger.error(f"Error processing audio for {session_id}: {str(e)}") - # Create an app context for the socket emit - with app.app_context(): - socketio.emit('error', {'message': str(e)}, room=session_id) - finally: - logger.info(f"Ending processing thread for session: {session_id}") - # Clean up when thread is done - with app.app_context(): - if session_id in active_conversations: - del active_conversations[session_id] - if session_id in user_queues: # Use the passed-in reference - del user_queues[session_id] - -def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE): - """Process audio data and generate a response using WhisperX""" - if models.generator is None or models.whisperx_model is None or models.llm is None: - logger.warning("Models not yet loaded!") - with app.app_context(): - socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) - return - - logger.info(f"Processing audio for session {session_id}") - conversation = active_conversations[session_id] - - try: - # Set processing flag - conversation.is_processing = True - - # Process base64 audio data - audio_data = data['audio'] - speaker_id = data['speaker'] - logger.info(f"Received audio from speaker {speaker_id}") - - # Convert from base64 to WAV - try: - audio_bytes = base64.b64decode(audio_data.split(',')[1]) - logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes") - except Exception as e: - logger.error(f"Error decoding base64 audio: {str(e)}") - raise - - # Save to temporary file for processing - with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: - temp_file.write(audio_bytes) - temp_path = temp_file.name - - try: - # Notify client that transcription is starting - with app.app_context(): - socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) - - # Load audio using WhisperX - import whisperx - audio = whisperx.load_audio(temp_path) - - # Check audio length and add a warning for short clips - audio_length = len(audio) / 16000 # assuming 16kHz sample rate - if audio_length < 1.0: - logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality") - - # Transcribe using WhisperX - batch_size = 16 # adjust based on your GPU memory - logger.info("Running WhisperX transcription...") - - # Handle the warning about audio being shorter than 30s by suppressing it - import warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="audio is shorter than 30s") - result = models.whisperx_model.transcribe(audio, batch_size=batch_size) - - # Get the detected language - language_code = result["language"] - logger.info(f"Detected language: {language_code}") - - # Check if alignment model needs to be loaded or updated - if models.whisperx_align_model is None or language_code != models.last_language: - # Clean up old models if they exist - if models.whisperx_align_model is not None: - del models.whisperx_align_model - del models.whisperx_align_metadata - if DEVICE == "cuda": - gc.collect() - torch.cuda.empty_cache() - - # Load new alignment model for the detected language - logger.info(f"Loading alignment model for language: {language_code}") - models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( - language_code=language_code, device=DEVICE - ) - models.last_language = language_code - - # Align the transcript to get word-level timestamps - if result["segments"] and len(result["segments"]) > 0: - logger.info("Aligning transcript...") - result = whisperx.align( - result["segments"], - models.whisperx_align_model, - models.whisperx_align_metadata, - audio, - DEVICE, - return_char_alignments=False - ) - - # Process the segments for better output - for segment in result["segments"]: - # Round timestamps for better display - segment["start"] = round(segment["start"], 2) - segment["end"] = round(segment["end"], 2) - # Add a confidence score if not present - if "confidence" not in segment: - segment["confidence"] = 1.0 # Default confidence - - # Extract the full text from all segments - user_text = ' '.join([segment['text'] for segment in result['segments']]) - - # If no text was recognized, don't process further - if not user_text or len(user_text.strip()) == 0: - with app.app_context(): - socketio.emit('error', {'message': 'No speech detected'}, room=session_id) - return - - logger.info(f"Transcription: {user_text}") - - # Load audio for CSM input - waveform, sample_rate = torchaudio.load(temp_path) - - # Normalize to mono if needed - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) - - # Resample to the CSM sample rate if needed - if sample_rate != models.generator.sample_rate: - waveform = torchaudio.functional.resample( - waveform, - orig_freq=sample_rate, - new_freq=models.generator.sample_rate - ) - - # Add the user's message to conversation history - user_segment = conversation.add_segment( - text=user_text, - speaker=speaker_id, - audio=waveform.squeeze() - ) - - # Send transcription to client with detailed segments - with app.app_context(): - socketio.emit('transcription', { - 'text': user_text, - 'speaker': speaker_id, - 'segments': result['segments'] # Include the detailed segments with timestamps - }, room=session_id) - - # Generate AI response using Llama - with app.app_context(): - socketio.emit('processing_status', {'status': 'generating'}, room=session_id) - - # Create prompt from conversation history - conversation_history = "" - for segment in conversation.segments[-5:]: # Last 5 segments for context - role = "User" if segment.speaker == 0 else "Assistant" - conversation_history += f"{role}: {segment.text}\n" - - # Add final prompt - prompt = f"{conversation_history}Assistant: " - - # Generate response with Llama - try: - # Ensure pad token is set - if models.tokenizer.pad_token is None: - models.tokenizer.pad_token = models.tokenizer.eos_token - - input_tokens = models.tokenizer( - prompt, - return_tensors="pt", - padding=True, - return_attention_mask=True - ) - input_ids = input_tokens.input_ids.to(DEVICE) - attention_mask = input_tokens.attention_mask.to(DEVICE) - - with torch.no_grad(): - generated_ids = models.llm.generate( - input_ids, - attention_mask=attention_mask, - max_new_tokens=100, - temperature=0.7, - top_p=0.9, - do_sample=True, - pad_token_id=models.tokenizer.eos_token_id - ) - - # Decode the response - response_text = models.tokenizer.decode( - generated_ids[0][input_ids.shape[1]:], - skip_special_tokens=True - ).strip() - except Exception as e: - logger.error(f"Error generating response: {str(e)}") - logger.error(traceback.format_exc()) - response_text = "I'm sorry, I encountered an error while processing your request." - - # Synthesize speech - with app.app_context(): - socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id) - - # Start sending the audio response - socketio.emit('audio_response_start', { - 'text': response_text, - 'total_chunks': 1, - 'chunk_index': 0 - }, room=session_id) - - # Define AI speaker ID - ai_speaker_id = conversation.ai_speaker_id - - # Generate audio - audio_tensor = models.generator.generate( - text=response_text, - speaker=ai_speaker_id, - context=conversation.get_context(), - max_audio_length_ms=10_000, - temperature=0.9 - ) - - # Add AI response to conversation history - ai_segment = conversation.add_segment( - text=response_text, - speaker=ai_speaker_id, - audio=audio_tensor - ) - - # Convert audio to WAV format - with io.BytesIO() as wav_io: - torchaudio.save( - wav_io, - audio_tensor.unsqueeze(0).cpu(), - models.generator.sample_rate, - format="wav" - ) - wav_io.seek(0) - wav_data = wav_io.read() - - # Convert WAV data to base64 - audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}" - - # Send audio chunk to client - with app.app_context(): - socketio.emit('audio_response_chunk', { - 'chunk': audio_base64, - 'chunk_index': 0, - 'total_chunks': 1, - 'is_last': True - }, room=session_id) - - # Signal completion - socketio.emit('audio_response_complete', { - 'text': response_text - }, room=session_id) - - finally: - # Clean up temp file - if os.path.exists(temp_path): - os.unlink(temp_path) - - except Exception as e: - logger.error(f"Error processing audio: {str(e)}") - logger.error(traceback.format_exc()) - with app.app_context(): - socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) - finally: - # Reset processing flag - conversation.is_processing = False \ No newline at end of file diff --git a/Backend/app.py b/Backend/app.py new file mode 100644 index 0000000..091de8e --- /dev/null +++ b/Backend/app.py @@ -0,0 +1,229 @@ +import os +import io +import base64 +import time +import torch +import torchaudio +import numpy as np +from flask import Flask, render_template, request +from flask_socketio import SocketIO, emit +from transformers import AutoModelForCausalLM, AutoTokenizer +import speech_recognition as sr +from generator import load_csm_1b, Segment +from collections import deque + +app = Flask(__name__) +app.config['SECRET_KEY'] = 'your-secret-key' +socketio = SocketIO(app, cors_allowed_origins="*") + +# Select the best available device +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" +print(f"Using device: {device}") + +# Initialize CSM model for audio generation +print("Loading CSM model...") +csm_generator = load_csm_1b(device=device) + +# Initialize Llama 3.2 model for response generation +print("Loading Llama 3.2 model...") +llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources +llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id) +llm_model = AutoModelForCausalLM.from_pretrained( + llm_model_id, + torch_dtype=torch.bfloat16, + device_map=device +) + +# Initialize speech recognition +recognizer = sr.Recognizer() + +# Store conversation context +conversation_context = {} # session_id -> context + +@app.route('/') +def index(): + return render_template('index.html') + +@socketio.on('connect') +def handle_connect(): + print(f"Client connected: {request.sid}") + conversation_context[request.sid] = { + 'segments': [], + 'speakers': [0, 1], # 0 = user, 1 = bot + 'audio_buffer': deque(maxlen=10), # Store recent audio chunks + 'is_speaking': False, + 'silence_start': None + } + emit('ready', {'message': 'Connection established'}) + +@socketio.on('disconnect') +def handle_disconnect(): + print(f"Client disconnected: {request.sid}") + if request.sid in conversation_context: + del conversation_context[request.sid] + +@socketio.on('start_speaking') +def handle_start_speaking(): + if request.sid in conversation_context: + conversation_context[request.sid]['is_speaking'] = True + conversation_context[request.sid]['audio_buffer'].clear() + print(f"User {request.sid} started speaking") + +@socketio.on('audio_chunk') +def handle_audio_chunk(data): + if request.sid not in conversation_context: + return + + context = conversation_context[request.sid] + + # Decode audio data + audio_data = base64.b64decode(data['audio']) + audio_numpy = np.frombuffer(audio_data, dtype=np.float32) + audio_tensor = torch.tensor(audio_numpy) + + # Add to buffer + context['audio_buffer'].append(audio_tensor) + + # Check for silence to detect end of speech + if context['is_speaking'] and is_silence(audio_tensor): + if context['silence_start'] is None: + context['silence_start'] = time.time() + elif time.time() - context['silence_start'] > 1.0: # 1 second of silence + # Process the complete utterance + process_user_utterance(request.sid) + else: + context['silence_start'] = None + +@socketio.on('stop_speaking') +def handle_stop_speaking(): + if request.sid in conversation_context: + conversation_context[request.sid]['is_speaking'] = False + process_user_utterance(request.sid) + print(f"User {request.sid} stopped speaking") + +def is_silence(audio_tensor, threshold=0.02): + """Check if an audio chunk is silence based on amplitude threshold""" + return torch.mean(torch.abs(audio_tensor)) < threshold + +def process_user_utterance(session_id): + """Process completed user utterance, generate response and send audio back""" + context = conversation_context[session_id] + + if not context['audio_buffer']: + return + + # Combine audio chunks + full_audio = torch.cat(list(context['audio_buffer']), dim=0) + context['audio_buffer'].clear() + context['is_speaking'] = False + context['silence_start'] = None + + # Convert audio to 16kHz for speech recognition + audio_16k = torchaudio.functional.resample( + full_audio, + orig_freq=44100, # Assuming 44.1kHz from client + new_freq=16000 + ) + + # Transcribe speech + try: + # Convert to wav format for speech_recognition + audio_data = io.BytesIO() + torchaudio.save(audio_data, audio_16k.unsqueeze(0), 16000, format="wav") + audio_data.seek(0) + + with sr.AudioFile(audio_data) as source: + audio = recognizer.record(source) + user_text = recognizer.recognize_google(audio) + print(f"Transcribed: {user_text}") + + # Add to conversation segments + user_segment = Segment( + text=user_text, + speaker=0, # User is speaker 0 + audio=full_audio + ) + context['segments'].append(user_segment) + + # Generate bot response + bot_response = generate_llm_response(user_text, context['segments']) + print(f"Bot response: {bot_response}") + + # Convert to audio using CSM + bot_audio = generate_audio_response(bot_response, context['segments']) + + # Convert audio to base64 for sending over websocket + audio_bytes = io.BytesIO() + torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav") + audio_bytes.seek(0) + audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') + + # Add bot response to conversation history + bot_segment = Segment( + text=bot_response, + speaker=1, # Bot is speaker 1 + audio=bot_audio + ) + context['segments'].append(bot_segment) + + # Send transcribed text to client + emit('transcription', {'text': user_text}, room=session_id) + + # Send audio response to client + emit('audio_response', { + 'audio': audio_b64, + 'text': bot_response + }, room=session_id) + + except Exception as e: + print(f"Error processing speech: {e}") + emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id) + +def generate_llm_response(user_text, conversation_segments): + """Generate text response using Llama 3.2""" + # Format conversation history for the LLM + conversation_history = "" + for segment in conversation_segments[-5:]: # Use last 5 utterances for context + speaker_name = "User" if segment.speaker == 0 else "Assistant" + conversation_history += f"{speaker_name}: {segment.text}\n" + + # Add the current user query + conversation_history += f"User: {user_text}\nAssistant:" + + # Generate response + inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device) + output = llm_model.generate( + inputs.input_ids, + max_new_tokens=150, + temperature=0.7, + top_p=0.9, + do_sample=True + ) + + response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + return response.strip() + +def generate_audio_response(text, conversation_segments): + """Generate audio response using CSM""" + # Use the last few conversation segments as context + context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments + + # Generate audio for bot response + audio = csm_generator.generate( + text=text, + speaker=1, # Bot is speaker 1 + context=context_segments, + max_audio_length_ms=10000, # 10 seconds max + temperature=0.9, + topk=50 + ) + + return audio + +if __name__ == '__main__': + socketio.run(app, host='0.0.0.0', port=5000, debug=True) \ No newline at end of file diff --git a/Backend/api/generator.py b/Backend/generator.py similarity index 100% rename from Backend/api/generator.py rename to Backend/generator.py diff --git a/Backend/index.html b/Backend/index.html new file mode 100644 index 0000000..e1f5f94 --- /dev/null +++ b/Backend/index.html @@ -0,0 +1,212 @@ + + + + + + + Audio Conversation Bot + + + + +

Audio Conversation Bot

+
+
+ +
+
Not connected
+ + + + \ No newline at end of file diff --git a/Backend/api/models.py b/Backend/models.py similarity index 100% rename from Backend/api/models.py rename to Backend/models.py diff --git a/Backend/run_csm.py b/Backend/run_csm.py new file mode 100644 index 0000000..0062973 --- /dev/null +++ b/Backend/run_csm.py @@ -0,0 +1,117 @@ +import os +import torch +import torchaudio +from huggingface_hub import hf_hub_download +from generator import load_csm_1b, Segment +from dataclasses import dataclass + +# Disable Triton compilation +os.environ["NO_TORCH_COMPILE"] = "1" + +# Default prompts are available at https://hf.co/sesame/csm-1b +prompt_filepath_conversational_a = hf_hub_download( + repo_id="sesame/csm-1b", + filename="prompts/conversational_a.wav" +) +prompt_filepath_conversational_b = hf_hub_download( + repo_id="sesame/csm-1b", + filename="prompts/conversational_b.wav" +) + +SPEAKER_PROMPTS = { + "conversational_a": { + "text": ( + "like revising for an exam I'd have to try and like keep up the momentum because I'd " + "start really early I'd be like okay I'm gonna start revising now and then like " + "you're revising for ages and then I just like start losing steam I didn't do that " + "for the exam we had recently to be fair that was a more of a last minute scenario " + "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I " + "sort of start the day with this not like a panic but like a" + ), + "audio": prompt_filepath_conversational_a + }, + "conversational_b": { + "text": ( + "like a super Mario level. Like it's very like high detail. And like, once you get " + "into the park, it just like, everything looks like a computer game and they have all " + "these, like, you know, if, if there's like a, you know, like in a Mario game, they " + "will have like a question block. And if you like, you know, punch it, a coin will " + "come out. So like everyone, when they come into the park, they get like this little " + "bracelet and then you can go punching question blocks around." + ), + "audio": prompt_filepath_conversational_b + } +} + +def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor: + audio_tensor, sample_rate = torchaudio.load(audio_path) + audio_tensor = audio_tensor.squeeze(0) + # Resample is lazy so we can always call it + audio_tensor = torchaudio.functional.resample( + audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate + ) + return audio_tensor + +def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment: + audio_tensor = load_prompt_audio(audio_path, sample_rate) + return Segment(text=text, speaker=speaker, audio=audio_tensor) + +def main(): + # Select the best available device, skipping MPS due to float64 limitations + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + print(f"Using device: {device}") + + # Load model + generator = load_csm_1b(device) + + # Prepare prompts + prompt_a = prepare_prompt( + SPEAKER_PROMPTS["conversational_a"]["text"], + 0, + SPEAKER_PROMPTS["conversational_a"]["audio"], + generator.sample_rate + ) + + prompt_b = prepare_prompt( + SPEAKER_PROMPTS["conversational_b"]["text"], + 1, + SPEAKER_PROMPTS["conversational_b"]["audio"], + generator.sample_rate + ) + + # Generate conversation + conversation = [ + {"text": "Hey how are you doing?", "speaker_id": 0}, + {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1}, + {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0}, + {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1} + ] + + # Generate each utterance + generated_segments = [] + prompt_segments = [prompt_a, prompt_b] + + for utterance in conversation: + print(f"Generating: {utterance['text']}") + audio_tensor = generator.generate( + text=utterance['text'], + speaker=utterance['speaker_id'], + context=prompt_segments + generated_segments, + max_audio_length_ms=10_000, + ) + generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)) + + # Concatenate all generations + all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0) + torchaudio.save( + "full_conversation.wav", + all_audio.unsqueeze(0).cpu(), + generator.sample_rate + ) + print("Successfully generated full_conversation.wav") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Backend/server.py b/Backend/server.py deleted file mode 100644 index b8af6b7..0000000 --- a/Backend/server.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -CSM Voice Chat Server -A voice chat application that uses CSM 1B for voice synthesis, -WhisperX for speech recognition, and Llama 3.2 for language generation. -""" - -# Start the Flask application -from api.app import app, socketio - -if __name__ == '__main__': - import os - - port = int(os.environ.get('PORT', 5000)) - debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true' - - print(f"Starting server on port {port} (debug={debug_mode})") - print("Visit http://localhost:5000 to access the application") - - socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/api/watermarking.py b/Backend/watermarking.py similarity index 100% rename from Backend/api/watermarking.py rename to Backend/watermarking.py