From dd85e8a9abac0d890d2bfb6d1da126b97c4fad3e Mon Sep 17 00:00:00 2001 From: Surya Vemulapalli Date: Sun, 30 Mar 2025 01:22:25 -0400 Subject: [PATCH 01/30] Made a button for the calls --- React/src/app/page.tsx | 13 ++++++++++--- React/src/pages/api/databaseStorage.ts | 11 +++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 React/src/pages/api/databaseStorage.ts diff --git a/React/src/app/page.tsx b/React/src/app/page.tsx index bdd8179..18d5c53 100644 --- a/React/src/app/page.tsx +++ b/React/src/app/page.tsx @@ -1,9 +1,6 @@ "use client"; import { useState } from "react"; import { auth0 } from "../lib/auth0"; -import { NextApiRequest, NextApiResponse } from "next"; - - export default async function Home() { @@ -89,6 +86,16 @@ export default async function Home() { +
[10 added JSX lines: the button for the calls]
); } diff --git a/React/src/pages/api/databaseStorage.ts b/React/src/pages/api/databaseStorage.ts new file mode 100644 index 0000000..a751f86 --- /dev/null +++ b/React/src/pages/api/databaseStorage.ts @@ -0,0 +1,11 @@ +import { NextApiRequest, NextApiResponse } from "next"; +import { MongoClient } from "mongodb"; + +export default function handler(req: NextApiRequest, res: NextApiResponse){ + if(req.method === 'POST') + const { codeword, contacts } = req.body; + + try{ + + } +} \ No newline at end of file From 46be33b10ab249d813560fb967fe228ac4ccdfb7 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 01:28:07 -0400 Subject: [PATCH 02/30] Complete Refactor --- Backend/README.md | 201 +--- Backend/api/app.py | 22 + Backend/api/routes.py | 29 + Backend/api/socket_handlers.py | 32 + Backend/config.py | 13 + Backend/index.html | 492 -------- Backend/models.py | 203 ---- Backend/requirements.txt | 25 +- Backend/run_csm.py | 117 -- Backend/server.py | 929 +------------- Backend/setup.py | 13 - Backend/src/audio/processor.py | 28 + Backend/src/audio/streaming.py | 35 + Backend/{ => src/llm}/generator.py | 16 +- Backend/src/llm/tokenizer.py | 14 + Backend/src/models/audio_model.py | 28 + Backend/src/models/conversation.py | 23 + Backend/src/services/transcription_service.py | 25 + Backend/src/services/tts_service.py | 24 + Backend/src/utils/config.py | 23 + Backend/src/utils/logger.py | 14 + Backend/static/css/styles.css | 105 ++ Backend/static/index.html | 31 + Backend/static/js/client.js | 131 ++ Backend/templates/index.html | 31 + Backend/test.py | 50 - Backend/voice-chat.js | 1071 ----------------- Backend/watermarking.py | 79 -- 28 files changed, 723 insertions(+), 3081 deletions(-) create mode 100644 Backend/api/app.py create mode 100644 Backend/api/routes.py create mode 100644 Backend/api/socket_handlers.py create mode 100644 Backend/config.py delete mode 100644 Backend/index.html delete mode 100644 Backend/models.py delete mode 100644 Backend/run_csm.py delete mode 100644 Backend/setup.py create mode 100644 Backend/src/audio/processor.py create mode 100644 Backend/src/audio/streaming.py rename Backend/{ => src/llm}/generator.py (90%) create mode 100644 Backend/src/llm/tokenizer.py create mode 100644 Backend/src/models/audio_model.py create mode 100644 Backend/src/models/conversation.py create mode 100644 Backend/src/services/transcription_service.py create mode 100644 Backend/src/services/tts_service.py create mode 100644 Backend/src/utils/config.py create mode 100644 Backend/src/utils/logger.py create mode 100644 Backend/static/css/styles.css create mode 100644 Backend/static/index.html create mode 100644 Backend/static/js/client.js create mode 100644 Backend/templates/index.html delete mode 100644 Backend/test.py delete mode 100644 Backend/voice-chat.js delete mode 100644 Backend/watermarking.py diff --git a/Backend/README.md b/Backend/README.md index 44cab4d..8438073 100644 --- a/Backend/README.md +++ b/Backend/README.md @@ -1,154 +1,71 @@ -# CSM +# csm-conversation-bot -**2025/03/13** - We are releasing the 1B CSM variant. The checkpoint is [hosted on Hugging Face](https://huggingface.co/sesame/csm_1b). +## Overview +The CSM Conversation Bot is an application that utilizes advanced audio processing and language model technologies to facilitate real-time voice conversations with an AI assistant. The bot processes audio streams, converts spoken input into text, generates responses using the Llama 3.2 model, and converts the text back into audio for seamless interaction. 
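For orientation, the round trip described above comes down to three calls against the services this patch adds. The sketch below mirrors `handle_audio_input()` in the new `Backend/server.py`; instantiating the services outside the Flask/Socket.IO app like this is illustrative only, not how the patch wires them up.

```python
# Sketch of one conversational turn, following handle_audio_input() in the new
# Backend/server.py. The service classes come from this patch; standalone use
# outside the Flask app is an assumption for illustration.
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService

transcriber = TranscriptionService()
tts = TextToSpeechService()

def one_turn(audio_chunk, speaker_id=0):
    """Audio in, audio out: transcribe, generate a reply with Llama 3.2, synthesize."""
    text = transcriber.transcribe(audio_chunk)        # speech -> text
    reply = tts.generate_response(text, speaker_id)   # text -> reply text (Llama 3.2)
    return tts.text_to_speech(reply, speaker_id)      # reply text -> audio
```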
---- - -CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes. - -A fine-tuned variant of CSM powers the [interactive voice demo](https://www.sesame.com/voicedemo) shown in our [blog post](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). - -A hosted [Hugging Face space](https://huggingface.co/spaces/sesame/csm-1b) is also available for testing audio generation. - -## Requirements - -* A CUDA-compatible GPU -* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions -* Similarly, Python 3.10 is recommended, but newer versions may be fine -* For some audio operations, `ffmpeg` may be required -* Access to the following Hugging Face models: - * [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) - * [CSM-1B](https://huggingface.co/sesame/csm-1b) - -### Setup - -```bash -git clone git@github.com:SesameAILabs/csm.git -cd csm -python3.10 -m venv .venv -source .venv/bin/activate -pip install -r requirements.txt - -# Disable lazy compilation in Mimi -export NO_TORCH_COMPILE=1 - -# You will need access to CSM-1B and Llama-3.2-1B -huggingface-cli login +## Project Structure +``` +csm-conversation-bot +├── api +│ ├── app.py # Main entry point for the API +│ ├── routes.py # Defines API routes +│ └── socket_handlers.py # Manages Socket.IO events +├── src +│ ├── audio +│ │ ├── processor.py # Audio processing functions +│ │ └── streaming.py # Audio streaming management +│ ├── llm +│ │ ├── generator.py # Response generation using Llama 3.2 +│ │ └── tokenizer.py # Text tokenization functions +│ ├── models +│ │ ├── audio_model.py # Audio processing model +│ │ └── conversation.py # Conversation state management +│ ├── services +│ │ ├── transcription_service.py # Audio to text conversion +│ │ └── tts_service.py # Text to speech conversion +│ └── utils +│ ├── config.py # Configuration settings +│ └── logger.py # Logging utilities +├── static +│ ├── css +│ │ └── styles.css # CSS styles for the web interface +│ ├── js +│ │ └── client.js # Client-side JavaScript +│ └── index.html # Main HTML file for the web interface +├── templates +│ └── index.html # Template for rendering the main HTML page +├── config.py # Main configuration settings +├── requirements.txt # Python dependencies +├── server.py # Entry point for running the application +└── README.md # Documentation for the project ``` -### Windows Setup +## Installation +1. Clone the repository: + ``` + git clone https://github.com/yourusername/csm-conversation-bot.git + cd csm-conversation-bot + ``` -The `triton` package cannot be installed in Windows. Instead use `pip install triton-windows`. +2. Install the required dependencies: + ``` + pip install -r requirements.txt + ``` -## Quickstart - -This script will generate a conversation between 2 characters, using a prompt for each character. - -```bash -python run_csm.py -``` +3. Configure the application settings in `config.py` as needed. ## Usage +1. Start the server: + ``` + python server.py + ``` -If you want to write your own applications with CSM, the following examples show basic usage. +2. Open your web browser and navigate to `http://localhost:5000` to access the application. -#### Generate a sentence +3. Use the interface to start a conversation with the AI assistant. 
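Step 3 can also be driven headlessly over Socket.IO: the new `Backend/server.py` listens for `audio_input` events carrying `{'audio', 'speaker'}` and replies with an `audio_response` event. The sketch below uses `python-socketio` from `requirements.txt`; the base64 WAV payload encoding is an assumption, since the patch does not pin down the wire format.

```python
# Minimal Socket.IO client sketch (python-socketio is listed in requirements.txt).
# Event names match the new Backend/server.py; the base64 WAV payload format is
# an assumption, not something this patch specifies.
import base64
import socketio

sio = socketio.Client()

@sio.on('audio_response')
def on_audio_response(data):
    # Server replies with {'audio': <synthesized speech>}; assume base64 like the request.
    with open('reply.wav', 'wb') as f:
        f.write(base64.b64decode(data['audio']))
    sio.disconnect()

sio.connect('http://localhost:5000')

with open('question.wav', 'rb') as f:
    payload = base64.b64encode(f.read()).decode('utf-8')

sio.emit('audio_input', {'audio': payload, 'speaker': 0})
sio.wait()
```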
-This will use a random speaker identity, as no prompt or context is provided. +## Contributing +Contributions are welcome! Please submit a pull request or open an issue for any enhancements or bug fixes. -```python -from generator import load_csm_1b -import torchaudio -import torch - -if torch.backends.mps.is_available(): - device = "mps" -elif torch.cuda.is_available(): - device = "cuda" -else: - device = "cpu" - -generator = load_csm_1b(device=device) - -audio = generator.generate( - text="Hello from Sesame.", - speaker=0, - context=[], - max_audio_length_ms=10_000, -) - -torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) -``` - -#### Generate with context - -CSM sounds best when provided with context. You can prompt or provide context to the model using a `Segment` for each speaker's utterance. - -NOTE: The following example is instructional and the audio files do not exist. It is intended as an example for using context with CSM. - -```python -from generator import Segment - -speakers = [0, 1, 0, 0] -transcripts = [ - "Hey how are you doing.", - "Pretty good, pretty good.", - "I'm great.", - "So happy to be speaking to you.", -] -audio_paths = [ - "utterance_0.wav", - "utterance_1.wav", - "utterance_2.wav", - "utterance_3.wav", -] - -def load_audio(audio_path): - audio_tensor, sample_rate = torchaudio.load(audio_path) - audio_tensor = torchaudio.functional.resample( - audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate - ) - return audio_tensor - -segments = [ - Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path)) - for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths) -] -audio = generator.generate( - text="Me too, this is some cool stuff huh?", - speaker=1, - context=segments, - max_audio_length_ms=10_000, -) - -torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) -``` - -## FAQ - -**Does this model come with any voices?** - -The model open-sourced here is a base generation model. It is capable of producing a variety of voices, but it has not been fine-tuned on any specific voice. - -**Can I converse with the model?** - -CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. We suggest using a separate LLM for text generation. - -**Does it support other languages?** - -The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well. - -## Misuse and abuse ⚠️ - -This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following: - -- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent. -- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls. -- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes. - -By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology. - ---- - -## Authors -Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team. 
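Besides the Socket.IO path, the patch adds two plain HTTP routes in `Backend/api/routes.py` (`/transcribe` and `/generate-response`). Assuming the `api` blueprint is registered at the application root and the server runs on the default host and port from `server.py`, they could be exercised roughly as follows, using `requests` from `requirements.txt`.

```python
# Sketch of calling the new REST endpoints from Backend/api/routes.py.
# Assumes the 'api' blueprint is registered without a URL prefix.
import requests

BASE = 'http://localhost:5000'

# POST /transcribe expects a multipart file field named 'audio'
with open('utterance.wav', 'rb') as f:
    r = requests.post(f'{BASE}/transcribe', files={'audio': f})
print(r.json()['transcription'])

# POST /generate-response expects JSON with an 'input' field
r = requests.post(f'{BASE}/generate-response', json={'input': 'Hello there'})
reply = r.json()
print(reply['response'])  # text reply; reply['audio'] carries the synthesized speech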
+## License +This project is licensed under the MIT License. See the LICENSE file for more details. \ No newline at end of file diff --git a/Backend/api/app.py b/Backend/api/app.py new file mode 100644 index 0000000..d0f2c05 --- /dev/null +++ b/Backend/api/app.py @@ -0,0 +1,22 @@ +from flask import Flask +from flask_socketio import SocketIO +from src.utils.config import Config +from src.utils.logger import setup_logger +from api.routes import setup_routes +from api.socket_handlers import setup_socket_handlers + +def create_app(): + app = Flask(__name__) + app.config.from_object(Config) + + setup_logger(app) + setup_routes(app) + setup_socket_handlers(app) + + return app + +app = create_app() +socketio = SocketIO(app) + +if __name__ == "__main__": + socketio.run(app, host='0.0.0.0', port=5000) \ No newline at end of file diff --git a/Backend/api/routes.py b/Backend/api/routes.py new file mode 100644 index 0000000..4ec8a7c --- /dev/null +++ b/Backend/api/routes.py @@ -0,0 +1,29 @@ +from flask import Blueprint, request, jsonify +from src.services.transcription_service import TranscriptionService +from src.services.tts_service import TextToSpeechService + +api = Blueprint('api', __name__) + +transcription_service = TranscriptionService() +tts_service = TextToSpeechService() + +@api.route('/transcribe', methods=['POST']) +def transcribe_audio(): + audio_data = request.files.get('audio') + if not audio_data: + return jsonify({'error': 'No audio file provided'}), 400 + + text = transcription_service.transcribe(audio_data) + return jsonify({'transcription': text}) + +@api.route('/generate-response', methods=['POST']) +def generate_response(): + data = request.json + user_input = data.get('input') + if not user_input: + return jsonify({'error': 'No input provided'}), 400 + + response_text = tts_service.generate_response(user_input) + audio_data = tts_service.text_to_speech(response_text) + + return jsonify({'response': response_text, 'audio': audio_data}) \ No newline at end of file diff --git a/Backend/api/socket_handlers.py b/Backend/api/socket_handlers.py new file mode 100644 index 0000000..f80ba96 --- /dev/null +++ b/Backend/api/socket_handlers.py @@ -0,0 +1,32 @@ +from flask import request +from flask_socketio import SocketIO, emit +from src.audio.processor import process_audio +from src.services.transcription_service import TranscriptionService +from src.services.tts_service import TextToSpeechService +from src.llm.generator import load_csm_1b + +socketio = SocketIO() + +transcription_service = TranscriptionService() +tts_service = TextToSpeechService() +generator = load_csm_1b() + +@socketio.on('audio_stream') +def handle_audio_stream(data): + audio_data = data['audio'] + speaker_id = data['speaker'] + + # Process the incoming audio + processed_audio = process_audio(audio_data) + + # Transcribe the audio to text + transcription = transcription_service.transcribe(processed_audio) + + # Generate a response using the LLM + response_text = generator.generate(text=transcription, speaker=speaker_id) + + # Convert the response text back to audio + response_audio = tts_service.convert_text_to_speech(response_text) + + # Emit the response audio back to the client + emit('audio_response', {'audio': response_audio, 'speaker': speaker_id}) \ No newline at end of file diff --git a/Backend/config.py b/Backend/config.py new file mode 100644 index 0000000..f23a0b5 --- /dev/null +++ b/Backend/config.py @@ -0,0 +1,13 @@ +from pathlib import Path + +class Config: + def __init__(self): + self.MODEL_PATH = 
Path("path/to/your/model") + self.AUDIO_MODEL_PATH = Path("path/to/your/audio/model") + self.WATERMARK_KEY = "your_watermark_key" + self.SOCKETIO_CORS = "*" + self.API_KEY = "your_api_key" + self.DEBUG = True + self.LOGGING_LEVEL = "INFO" + self.TTS_SERVICE_URL = "http://localhost:5001/tts" + self.TRANSCRIPTION_SERVICE_URL = "http://localhost:5002/transcribe" \ No newline at end of file diff --git a/Backend/index.html b/Backend/index.html deleted file mode 100644 index 5ea925c..0000000 --- a/Backend/index.html +++ /dev/null @@ -1,492 +0,0 @@ - - - - - - Sesame AI Voice Chat - - - - - - -
[Backend/index.html, 492 deleted lines: the old single-page "Sesame AI Voice Chat" UI (conversation log, audio visualizer, voice settings with silence threshold slider, conversation controls, auto-play and visualizer toggles, connection status, and the "Powered by Sesame AI | WhisperX for speech recognition" footer) plus its inline styles and scripts]
- - - - - \ No newline at end of file diff --git a/Backend/models.py b/Backend/models.py deleted file mode 100644 index e9508e7..0000000 --- a/Backend/models.py +++ /dev/null @@ -1,203 +0,0 @@ -from dataclasses import dataclass - -import torch -import torch.nn as nn -import torchtune -from huggingface_hub import PyTorchModelHubMixin -from torchtune.models import llama3_2 - - -def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder: - return llama3_2.llama3_2( - vocab_size=128_256, - num_layers=16, - num_heads=32, - num_kv_heads=8, - embed_dim=2048, - max_seq_len=2048, - intermediate_dim=8192, - attn_dropout=0.0, - norm_eps=1e-5, - rope_base=500_000, - scale_factor=32, - ) - - -def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder: - return llama3_2.llama3_2( - vocab_size=128_256, - num_layers=4, - num_heads=8, - num_kv_heads=2, - embed_dim=1024, - max_seq_len=2048, - intermediate_dim=8192, - attn_dropout=0.0, - norm_eps=1e-5, - rope_base=500_000, - scale_factor=32, - ) - - -FLAVORS = { - "llama-1B": llama3_2_1B, - "llama-100M": llama3_2_100M, -} - - -def _prepare_transformer(model): - embed_dim = model.tok_embeddings.embedding_dim - model.tok_embeddings = nn.Identity() - model.output = nn.Identity() - return model, embed_dim - - -def _create_causal_mask(seq_len: int, device: torch.device): - return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)) - - -def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor): - """ - Args: - mask: (max_seq_len, max_seq_len) - input_pos: (batch_size, seq_len) - - Returns: - (batch_size, seq_len, max_seq_len) - """ - r = mask[input_pos, :] - return r - - -def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization - q = torch.empty_like(probs).exponential_(1) - return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) - - -def sample_topk(logits: torch.Tensor, topk: int, temperature: float): - logits = logits / temperature - - filter_value: float = -float("Inf") - indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None] - scores_processed = logits.masked_fill(indices_to_remove, filter_value) - scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1) - probs = torch.nn.functional.softmax(scores_processed, dim=-1) - - sample_token = _multinomial_sample_one_no_sync(probs) - return sample_token - - -@dataclass -class ModelArgs: - backbone_flavor: str - decoder_flavor: str - text_vocab_size: int - audio_vocab_size: int - audio_num_codebooks: int - - -class Model( - nn.Module, - PyTorchModelHubMixin, - repo_url="https://github.com/SesameAILabs/csm", - pipeline_tag="text-to-speech", - license="apache-2.0", -): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - - self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]()) - self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]()) - - self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim) - self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim) - - self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False) - self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False) - self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)) - - def setup_caches(self, max_batch_size: int) -> torch.Tensor: - """Setup KV caches and return 
a causal mask.""" - dtype = next(self.parameters()).dtype - device = next(self.parameters()).device - - with device: - self.backbone.setup_caches(max_batch_size, dtype) - self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks) - - self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device)) - self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device)) - - def generate_frame( - self, - tokens: torch.Tensor, - tokens_mask: torch.Tensor, - input_pos: torch.Tensor, - temperature: float, - topk: int, - ) -> torch.Tensor: - """ - Args: - tokens: (batch_size, seq_len, audio_num_codebooks+1) - tokens_mask: (batch_size, seq_len, audio_num_codebooks+1) - input_pos: (batch_size, seq_len) positions for each token - mask: (batch_size, seq_len, max_seq_len - - Returns: - (batch_size, audio_num_codebooks) sampled tokens - """ - dtype = next(self.parameters()).dtype - b, s, _ = tokens.size() - - assert self.backbone.caches_are_enabled(), "backbone caches are not enabled" - curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos) - embeds = self._embed_tokens(tokens) - masked_embeds = embeds * tokens_mask.unsqueeze(-1) - h = masked_embeds.sum(dim=2) - h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype) - - last_h = h[:, -1, :] - c0_logits = self.codebook0_head(last_h) - c0_sample = sample_topk(c0_logits, topk, temperature) - c0_embed = self._embed_audio(0, c0_sample) - - curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1) - curr_sample = c0_sample.clone() - curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1) - - # Decoder caches must be reset every frame. 
- self.decoder.reset_caches() - for i in range(1, self.config.audio_num_codebooks): - curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos) - decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to( - dtype=dtype - ) - ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1]) - ci_sample = sample_topk(ci_logits, topk, temperature) - ci_embed = self._embed_audio(i, ci_sample) - - curr_h = ci_embed - curr_sample = torch.cat([curr_sample, ci_sample], dim=1) - curr_pos = curr_pos[:, -1:] + 1 - - return curr_sample - - def reset_caches(self): - self.backbone.reset_caches() - self.decoder.reset_caches() - - def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor: - return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size) - - def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor: - text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2) - - audio_tokens = tokens[:, :, :-1] + ( - self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device) - ) - audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape( - tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1 - ) - - return torch.cat([audio_embeds, text_embeds], dim=-2) diff --git a/Backend/requirements.txt b/Backend/requirements.txt index ba8a04f..ef7beab 100644 --- a/Backend/requirements.txt +++ b/Backend/requirements.txt @@ -1,9 +1,16 @@ -torch==2.4.0 -torchaudio==2.4.0 -tokenizers==0.21.0 -transformers==4.49.0 -huggingface_hub==0.28.1 -moshi==0.2.2 -torchtune==0.4.0 -torchao==0.9.0 -silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master \ No newline at end of file +Flask==2.2.2 +Flask-SocketIO==5.3.2 +torch>=2.0.0 +torchaudio>=2.0.0 +transformers>=4.30.0 +huggingface-hub>=0.14.0 +python-dotenv==0.19.2 +numpy>=1.21.6 +scipy>=1.7.3 +soundfile==0.10.3.post1 +requests==2.28.1 +pydub==0.25.1 +python-socketio==5.7.2 +eventlet==0.33.3 +whisper>=20230314 +ffmpeg-python>=0.2.0 \ No newline at end of file diff --git a/Backend/run_csm.py b/Backend/run_csm.py deleted file mode 100644 index 0062973..0000000 --- a/Backend/run_csm.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import torch -import torchaudio -from huggingface_hub import hf_hub_download -from generator import load_csm_1b, Segment -from dataclasses import dataclass - -# Disable Triton compilation -os.environ["NO_TORCH_COMPILE"] = "1" - -# Default prompts are available at https://hf.co/sesame/csm-1b -prompt_filepath_conversational_a = hf_hub_download( - repo_id="sesame/csm-1b", - filename="prompts/conversational_a.wav" -) -prompt_filepath_conversational_b = hf_hub_download( - repo_id="sesame/csm-1b", - filename="prompts/conversational_b.wav" -) - -SPEAKER_PROMPTS = { - "conversational_a": { - "text": ( - "like revising for an exam I'd have to try and like keep up the momentum because I'd " - "start really early I'd be like okay I'm gonna start revising now and then like " - "you're revising for ages and then I just like start losing steam I didn't do that " - "for the exam we had recently to be fair that was a more of a last minute scenario " - "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I " - "sort of start the day with this not like a panic but like a" - ), - "audio": prompt_filepath_conversational_a - }, - "conversational_b": { - "text": ( - "like a super Mario level. Like it's very like high detail. 
And like, once you get " - "into the park, it just like, everything looks like a computer game and they have all " - "these, like, you know, if, if there's like a, you know, like in a Mario game, they " - "will have like a question block. And if you like, you know, punch it, a coin will " - "come out. So like everyone, when they come into the park, they get like this little " - "bracelet and then you can go punching question blocks around." - ), - "audio": prompt_filepath_conversational_b - } -} - -def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor: - audio_tensor, sample_rate = torchaudio.load(audio_path) - audio_tensor = audio_tensor.squeeze(0) - # Resample is lazy so we can always call it - audio_tensor = torchaudio.functional.resample( - audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate - ) - return audio_tensor - -def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment: - audio_tensor = load_prompt_audio(audio_path, sample_rate) - return Segment(text=text, speaker=speaker, audio=audio_tensor) - -def main(): - # Select the best available device, skipping MPS due to float64 limitations - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - print(f"Using device: {device}") - - # Load model - generator = load_csm_1b(device) - - # Prepare prompts - prompt_a = prepare_prompt( - SPEAKER_PROMPTS["conversational_a"]["text"], - 0, - SPEAKER_PROMPTS["conversational_a"]["audio"], - generator.sample_rate - ) - - prompt_b = prepare_prompt( - SPEAKER_PROMPTS["conversational_b"]["text"], - 1, - SPEAKER_PROMPTS["conversational_b"]["audio"], - generator.sample_rate - ) - - # Generate conversation - conversation = [ - {"text": "Hey how are you doing?", "speaker_id": 0}, - {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1}, - {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0}, - {"text": "Me too! 
This is some cool stuff, isn't it?", "speaker_id": 1} - ] - - # Generate each utterance - generated_segments = [] - prompt_segments = [prompt_a, prompt_b] - - for utterance in conversation: - print(f"Generating: {utterance['text']}") - audio_tensor = generator.generate( - text=utterance['text'], - speaker=utterance['speaker_id'], - context=prompt_segments + generated_segments, - max_audio_length_ms=10_000, - ) - generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)) - - # Concatenate all generations - all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0) - torchaudio.save( - "full_conversation.wav", - all_audio.unsqueeze(0).cpu(), - generator.sample_rate - ) - print("Successfully generated full_conversation.wav") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/Backend/server.py b/Backend/server.py index 2cf721e..2069b29 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -1,904 +1,53 @@ import os -import base64 -import json -import time -import math -import gc import logging -import numpy as np import torch -import torchaudio +import eventlet +import base64 +import tempfile from io import BytesIO -from typing import List, Dict, Any, Optional -from flask import Flask, request, send_from_directory, Response -from flask_cors import CORS -from flask_socketio import SocketIO, emit, disconnect -from generator import load_csm_1b, Segment -from collections import deque -from threading import Lock -from transformers import pipeline -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +from flask import Flask, render_template, request, jsonify +from flask_socketio import SocketIO, emit +import whisper +import torchaudio +from src.models.conversation import Segment +from src.services.tts_service import load_csm_1b +from src.llm.generator import generate_llm_response +from transformers import AutoTokenizer, AutoModelForCausalLM +from src.audio.streaming import AudioStreamer +from src.services.transcription_service import TranscriptionService +from src.services.tts_service import TextToSpeechService # Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger("sesame-server") +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) -# Determine best compute device -if torch.backends.mps.is_available(): - device = "mps" -elif torch.cuda.is_available(): - try: - # Test CUDA functionality - torch.rand(10, device="cuda") - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.backends.cudnn.benchmark = True - device = "cuda" - logger.info("CUDA is fully functional") - except Exception as e: - logger.warning(f"CUDA available but not working correctly: {e}") - device = "cpu" -else: - device = "cpu" - logger.info("Using CPU") +app = Flask(__name__, static_folder='static', template_folder='templates') +app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'your-secret-key') +socketio = SocketIO(app) -# Constants and Configuration -SILENCE_THRESHOLD = 0.01 -SILENCE_DURATION_SEC = 0.75 -MAX_BUFFER_SIZE = 30 # Maximum chunks to buffer before processing -CHUNK_SIZE_MS = 500 # Size of audio chunks when streaming responses +# Initialize services +transcription_service = TranscriptionService() +tts_service = TextToSpeechService() +audio_streamer = AudioStreamer() -# 
Define the base directory and static files directory -base_dir = os.path.dirname(os.path.abspath(__file__)) -static_dir = os.path.join(base_dir, "static") -os.makedirs(static_dir, exist_ok=True) - -# Define a simple energy-based speech detector -class SpeechDetector: - def __init__(self): - self.min_speech_energy = 0.01 - self.speech_window = 0.2 # seconds +@socketio.on('audio_input') +def handle_audio_input(data): + audio_chunk = data['audio'] + speaker_id = data['speaker'] - def detect_speech(self, audio_tensor, sample_rate): - # Calculate frame size based on window size - frame_size = int(sample_rate * self.speech_window) - - # If audio is shorter than frame size, use the entire audio - if audio_tensor.shape[0] < frame_size: - frames = [audio_tensor] - else: - # Split audio into frames - frames = [audio_tensor[i:i+frame_size] for i in range(0, len(audio_tensor), frame_size)] - - # Calculate energy per frame - energies = [torch.mean(frame**2).item() for frame in frames] - - # Determine if there's speech based on energy threshold - has_speech = any(e > self.min_speech_energy for e in energies) - - return has_speech + # Process audio and convert to text + text = transcription_service.transcribe(audio_chunk) + logging.info(f"Transcribed text: {text}") -speech_detector = SpeechDetector() -logger.info("Initialized simple speech detector") + # Generate response using Llama 3.2 + response_text = tts_service.generate_response(text, speaker_id) + logging.info(f"Generated response: {response_text}") -# Model Loading Functions -def load_speech_models(): - """Load speech generation and recognition models""" - # Load CSM (existing code) - generator = load_csm_1b(device=device) + # Convert response text to audio + audio_response = tts_service.text_to_speech(response_text, speaker_id) - # Load Whisper model for speech recognition - try: - logger.info(f"Loading speech recognition model on {device}...") - - # Try with newer API first - try: - from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline - - model_id = "openai/whisper-small" - - # Load model and processor - model = AutoModelForSpeechSeq2Seq.from_pretrained( - model_id, - torch_dtype=torch.float16 if device == "cuda" else torch.float32, - device_map=device, - ) - processor = AutoProcessor.from_pretrained(model_id) - - # Create pipeline with specific parameters - speech_recognizer = pipeline( - "automatic-speech-recognition", - model=model, - tokenizer=processor.tokenizer, - feature_extractor=processor.feature_extractor, - max_new_tokens=128, - chunk_length_s=30, - batch_size=16, - device=device, - ) - - except Exception as api_error: - logger.warning(f"Newer API loading failed: {api_error}, trying simpler approach") - - # Fallback to simpler API - speech_recognizer = pipeline( - "automatic-speech-recognition", - model="openai/whisper-small", - device=device - ) - - logger.info("Speech recognition model loaded successfully") - return generator, speech_recognizer - - except Exception as e: - logger.error(f"Error loading speech recognition model: {e}") - return generator, None + # Stream audio response back to client + socketio.emit('audio_response', {'audio': audio_response}) -# Unpack both models -generator, speech_recognizer = load_speech_models() - -# Initialize Llama 3.2 model for conversation responses -def load_llm_model(): - """Load Llama 3.2 model for generating text responses""" - try: - logger.info("Loading Llama 3.2 model for conversational responses...") - model_id = "meta-llama/Llama-3.2-1B-Instruct" - tokenizer = 
AutoTokenizer.from_pretrained(model_id) - - # Determine compute device for LLM - llm_device = "cpu" # Default to CPU for LLM - - # Use CUDA if available and there's enough VRAM - if device == "cuda" and torch.cuda.is_available(): - try: - free_mem = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0) - # If we have at least 2GB free, use CUDA for LLM - if free_mem > 2 * 1024 * 1024 * 1024: - llm_device = "cuda" - except: - pass - - logger.info(f"Using {llm_device} for Llama 3.2 model") - - # Load the model with lower precision for efficiency - model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype=torch.float16 if llm_device == "cuda" else torch.float32, - device_map=llm_device - ) - - # Create a pipeline for easier inference - llm = pipeline( - "text-generation", - model=model, - tokenizer=tokenizer, - max_length=512, - do_sample=True, - temperature=0.7, - top_p=0.9, - repetition_penalty=1.1 - ) - - logger.info("Llama 3.2 model loaded successfully") - return llm - except Exception as e: - logger.error(f"Error loading Llama 3.2 model: {e}") - return None - -# Load the LLM model -llm = load_llm_model() - -# Set up Flask and Socket.IO -app = Flask(__name__) -CORS(app) -socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet') - -# Socket connection management -thread_lock = Lock() -active_clients = {} # Map client_id to client context - -# Audio Utility Functions -def decode_audio_data(audio_data: str) -> torch.Tensor: - """Decode base64 audio data to a torch tensor with improved error handling""" - try: - # Skip empty audio data - if not audio_data or len(audio_data) < 100: - logger.warning("Empty or too short audio data received") - return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence - - # Extract the actual base64 content - if ',' in audio_data: - audio_data = audio_data.split(',')[1] - - # Decode base64 audio data - try: - binary_data = base64.b64decode(audio_data) - logger.debug(f"Decoded base64 data: {len(binary_data)} bytes") - - # Check if we have enough data for a valid WAV - if len(binary_data) < 44: # WAV header is 44 bytes - logger.warning("Data too small to be a valid WAV file") - return torch.zeros(generator.sample_rate // 2) - except Exception as e: - logger.error(f"Base64 decoding error: {e}") - return torch.zeros(generator.sample_rate // 2) - - # Multiple approaches to handle audio data - audio_tensor = None - sample_rate = None - - # Approach 1: Direct loading with torchaudio - try: - with BytesIO(binary_data) as temp_file: - temp_file.seek(0) - audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav") - logger.debug(f"Loaded audio: shape={audio_tensor.shape}, rate={sample_rate}Hz") - - # Validate tensor - if audio_tensor.numel() == 0 or torch.isnan(audio_tensor).any(): - raise ValueError("Invalid audio tensor") - except Exception as e: - logger.warning(f"Direct loading failed: {e}") - - # Approach 2: Using wave module and numpy - try: - temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav") - with open(temp_path, 'wb') as f: - f.write(binary_data) - - import wave - with wave.open(temp_path, 'rb') as wf: - n_channels = wf.getnchannels() - sample_width = wf.getsampwidth() - sample_rate = wf.getframerate() - n_frames = wf.getnframes() - frames = wf.readframes(n_frames) - - # Convert to numpy array - if sample_width == 2: # 16-bit audio - data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0 - elif sample_width == 1: # 8-bit audio - data = 
np.frombuffer(frames, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0 - else: - raise ValueError(f"Unsupported sample width: {sample_width}") - - # Convert to mono if needed - if n_channels > 1: - data = data.reshape(-1, n_channels) - data = data.mean(axis=1) - - # Convert to torch tensor - audio_tensor = torch.from_numpy(data) - logger.info(f"Loaded audio using wave: shape={audio_tensor.shape}") - - # Clean up temp file - if os.path.exists(temp_path): - os.remove(temp_path) - - except Exception as e2: - logger.error(f"All audio loading methods failed: {e2}") - return torch.zeros(generator.sample_rate // 2) - - # Format corrections - if audio_tensor is None: - return torch.zeros(generator.sample_rate // 2) - - # Ensure audio is mono - if len(audio_tensor.shape) > 1 and audio_tensor.shape[0] > 1: - audio_tensor = torch.mean(audio_tensor, dim=0) - - # Ensure 1D tensor - audio_tensor = audio_tensor.squeeze() - - # Resample if needed - if sample_rate != generator.sample_rate: - try: - logger.debug(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz") - resampler = torchaudio.transforms.Resample( - orig_freq=sample_rate, - new_freq=generator.sample_rate - ) - audio_tensor = resampler(audio_tensor) - except Exception as e: - logger.warning(f"Resampling error: {e}") - - # Normalize audio to avoid issues - if torch.abs(audio_tensor).max() > 0: - audio_tensor = audio_tensor / torch.abs(audio_tensor).max() - - return audio_tensor - except Exception as e: - logger.error(f"Unhandled error in decode_audio_data: {e}") - return torch.zeros(generator.sample_rate // 2) - -def encode_audio_data(audio_tensor: torch.Tensor) -> str: - """Encode torch tensor audio to base64 string""" - try: - buf = BytesIO() - torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav") - buf.seek(0) - audio_base64 = base64.b64encode(buf.read()).decode('utf-8') - return f"data:audio/wav;base64,{audio_base64}" - except Exception as e: - logger.error(f"Error encoding audio: {e}") - # Return a minimal silent audio file - silence = torch.zeros(generator.sample_rate // 2).unsqueeze(0) - buf = BytesIO() - torchaudio.save(buf, silence, generator.sample_rate, format="wav") - buf.seek(0) - return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}" - -def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str: - """Process speech with speech recognition""" - if not speech_recognizer: - # Fallback to basic detection if model failed to load - return detect_speech_energy(audio_tensor) - - try: - # Save audio to temp file for Whisper - temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav") - torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate) - - # Perform speech recognition - handle the warning differently - # Just pass the path without any additional parameters - try: - # First try - use default parameters - result = speech_recognizer(temp_path) - transcription = result["text"] - except Exception as whisper_error: - logger.warning(f"First transcription attempt failed: {whisper_error}") - # Try with explicit parameters for older versions of transformers - import numpy as np - import soundfile as sf - - # Load audio as numpy array - audio_np, sr = sf.read(temp_path) - if sr != 16000: - # Whisper expects 16kHz audio - from scipy import signal - audio_np = signal.resample(audio_np, int(len(audio_np) * 16000 / sr)) - - # Try with numpy array directly - result = speech_recognizer(audio_np) - transcription = result["text"] - - # Clean up 
temp file - if os.path.exists(temp_path): - os.remove(temp_path) - - # Return empty string if no speech detected - if not transcription or transcription.isspace(): - return "I didn't detect any speech. Could you please try again?" - - logger.info(f"Transcription successful: '{transcription}'") - return transcription - - except Exception as e: - logger.error(f"Speech recognition error: {e}") - return "Sorry, I couldn't understand what you said. Could you try again?" - -def detect_speech_energy(audio_tensor: torch.Tensor) -> str: - """Basic speech detection based on audio energy levels""" - # Calculate audio energy - energy = torch.mean(torch.abs(audio_tensor)).item() - - logger.debug(f"Audio energy detected: {energy:.6f}") - - # Generate response based on energy level - if energy > 0.1: # Louder speech - return "I heard you speaking clearly. How can I help you today?" - elif energy > 0.05: # Moderate speech - return "I heard you say something. Could you please repeat that?" - elif energy > 0.02: # Soft speech - return "I detected some speech, but it was quite soft. Could you speak up a bit?" - else: # Very soft or no speech - return "I didn't detect any speech. Could you please try again?" - -def generate_response(text: str, conversation_history: List[Segment]) -> str: - """Generate a contextual response based on the transcribed text using Llama 3.2""" - # If LLM is not available, use simple responses - if llm is None: - return generate_simple_response(text) - - try: - # Create a conversational prompt based on history - # Format: recent conversation turns (up to 4) + current user input - history_str = "" - - # Add up to 4 recent conversation turns (excluding the current one) - recent_segments = [ - seg for seg in conversation_history[-8:] - if seg.text and not seg.text.isspace() - ] - - for i, segment in enumerate(recent_segments): - speaker_name = "User" if segment.speaker == 0 else "Assistant" - history_str += f"{speaker_name}: {segment.text}\n" - - # Construct the prompt for Llama 3.2 - prompt = f"""<|system|> -You are Sesame, a helpful, friendly and concise voice assistant. -Keep your responses conversational, natural, and to the point. -Respond to the user's latest message in the context of the conversation. -<|end|> - -{history_str} -User: {text} -Assistant:""" - - logger.debug(f"LLM Prompt: {prompt}") - - # Generate response with the LLM - result = llm( - prompt, - max_new_tokens=150, - do_sample=True, - temperature=0.7, - top_p=0.9, - repetition_penalty=1.1 - ) - - # Extract the generated text - response = result[0]["generated_text"] - - # Extract just the Assistant's response (after the prompt) - response = response.split("Assistant:")[-1].strip() - - # Clean up and ensure it's not too long for TTS - response = response.split("\n")[0].strip() - if len(response) > 200: - response = response[:197] + "..." - - logger.info(f"LLM response: {response}") - return response - - except Exception as e: - logger.error(f"Error generating LLM response: {e}") - # Fall back to simple responses - return generate_simple_response(text) - -def generate_simple_response(text: str) -> str: - """Generate a simple rule-based response as fallback""" - responses = { - "hello": "Hello there! How can I help you today?", - "hi": "Hi there! What can I do for you?", - "how are you": "I'm doing well, thanks for asking! How about you?", - "what is your name": "I'm Sesame, your voice assistant. How can I help you?", - "who are you": "I'm Sesame, an AI voice assistant. I'm here to chat with you!", - "bye": "Goodbye! 
It was nice chatting with you.", - "thank you": "You're welcome! Is there anything else I can help with?", - "weather": "I don't have real-time weather data, but I hope it's nice where you are!", - "help": "I can chat with you using natural voice. Just speak normally and I'll respond.", - "what can you do": "I can have a conversation with you, answer questions, and provide assistance with various topics.", - } - - text_lower = text.lower() - - # Check for matching keywords - for key, response in responses.items(): - if key in text_lower: - return response - - # Default responses based on text length - if not text: - return "I didn't catch that. Could you please repeat?" - elif len(text) < 10: - return "Thanks for your message. Could you elaborate a bit more?" - else: - return f"I heard you say something about that. Can you tell me more?" - -# Flask Routes -@app.route('/') -def index(): - return send_from_directory(base_dir, 'index.html') - -@app.route('/favicon.ico') -def favicon(): - if os.path.exists(os.path.join(static_dir, 'favicon.ico')): - return send_from_directory(static_dir, 'favicon.ico') - return Response(status=204) - -@app.route('/voice-chat.js') -def voice_chat_js(): - return send_from_directory(base_dir, 'voice-chat.js') - -@app.route('/static/') -def serve_static(path): - return send_from_directory(static_dir, path) - -# Socket.IO Event Handlers -@socketio.on('connect') -def handle_connect(): - client_id = request.sid - logger.info(f"Client connected: {client_id}") - - # Initialize client context - active_clients[client_id] = { - 'context_segments': [], - 'streaming_buffer': [], - 'is_streaming': False, - 'is_silence': False, - 'last_active_time': time.time(), - 'energy_window': deque(maxlen=10) - } - - emit('status', {'type': 'connected', 'message': 'Connected to server'}) - -@socketio.on('disconnect') -def handle_disconnect(): - client_id = request.sid - if client_id in active_clients: - del active_clients[client_id] - logger.info(f"Client disconnected: {client_id}") - -@socketio.on('generate') -def handle_generate(data): - client_id = request.sid - if client_id not in active_clients: - emit('error', {'message': 'Client not registered'}) - return - - try: - text = data.get('text', '') - speaker_id = data.get('speaker', 0) - - logger.info(f"Generating audio for: '{text}' with speaker {speaker_id}") - - # Generate audio response - audio_tensor = generator.generate( - text=text, - speaker=speaker_id, - context=active_clients[client_id]['context_segments'], - max_audio_length_ms=10_000, - ) - - # Add to conversation context - active_clients[client_id]['context_segments'].append( - Segment(text=text, speaker=speaker_id, audio=audio_tensor) - ) - - # Convert audio to base64 and send back to client - audio_base64 = encode_audio_data(audio_tensor) - emit('audio_response', { - 'type': 'audio_response', - 'audio': audio_base64, - 'text': text - }) - - except Exception as e: - logger.error(f"Error generating audio: {e}") - emit('error', { - 'type': 'error', - 'message': f"Error generating audio: {str(e)}" - }) - -@socketio.on('add_to_context') -def handle_add_to_context(data): - client_id = request.sid - if client_id not in active_clients: - emit('error', {'message': 'Client not registered'}) - return - - try: - text = data.get('text', '') - speaker_id = data.get('speaker', 0) - audio_data = data.get('audio', '') - - # Convert received audio to tensor - audio_tensor = decode_audio_data(audio_data) - - # Add to conversation context - 
active_clients[client_id]['context_segments'].append( - Segment(text=text, speaker=speaker_id, audio=audio_tensor) - ) - - emit('context_updated', { - 'type': 'context_updated', - 'message': 'Audio added to context' - }) - - except Exception as e: - logger.error(f"Error adding to context: {e}") - emit('error', { - 'type': 'error', - 'message': f"Error processing audio: {str(e)}" - }) - -@socketio.on('clear_context') -def handle_clear_context(): - client_id = request.sid - if client_id in active_clients: - active_clients[client_id]['context_segments'] = [] - - emit('context_updated', { - 'type': 'context_updated', - 'message': 'Context cleared' - }) - -@socketio.on('stream_audio') -def handle_stream_audio(data): - client_id = request.sid - if client_id not in active_clients: - emit('error', {'message': 'Client not registered'}) - return - - client = active_clients[client_id] - - try: - speaker_id = data.get('speaker', 0) - audio_data = data.get('audio', '') - - # Skip if no audio data (might be just a connection test) - if not audio_data: - logger.debug("Empty audio data received, ignoring") - return - - # Convert received audio to tensor - audio_chunk = decode_audio_data(audio_data) - - # Start streaming mode if not already started - if not client['is_streaming']: - client['is_streaming'] = True - client['streaming_buffer'] = [] - client['energy_window'].clear() - client['is_silence'] = False - client['last_active_time'] = time.time() - logger.info(f"[{client_id[:8]}] Streaming started with speaker ID: {speaker_id}") - emit('streaming_status', { - 'type': 'streaming_status', - 'status': 'started' - }) - - # Calculate audio energy for silence detection - chunk_energy = torch.mean(torch.abs(audio_chunk)).item() - client['energy_window'].append(chunk_energy) - avg_energy = sum(client['energy_window']) / len(client['energy_window']) - - # Check if audio is silent - current_silence = avg_energy < SILENCE_THRESHOLD - - # Track silence transition - if not client['is_silence'] and current_silence: - # Transition to silence - client['is_silence'] = True - client['last_active_time'] = time.time() - elif client['is_silence'] and not current_silence: - # User started talking again - client['is_silence'] = False - - # Add chunk to buffer regardless of silence state - client['streaming_buffer'].append(audio_chunk) - - # Check if silence has persisted long enough to consider "stopped talking" - silence_elapsed = time.time() - client['last_active_time'] - - if client['is_silence'] and silence_elapsed >= SILENCE_DURATION_SEC and len(client['streaming_buffer']) > 0: - # User has stopped talking - process the collected audio - logger.info(f"[{client_id[:8]}] Processing audio after {silence_elapsed:.2f}s of silence") - process_complete_utterance(client_id, client, speaker_id) - - # If buffer gets too large without silence, process it anyway - elif len(client['streaming_buffer']) >= MAX_BUFFER_SIZE: - logger.info(f"[{client_id[:8]}] Processing long audio segment without silence") - process_complete_utterance(client_id, client, speaker_id, is_incomplete=True) - - # Keep half of the buffer for context (sliding window approach) - half_point = len(client['streaming_buffer']) // 2 - client['streaming_buffer'] = client['streaming_buffer'][half_point:] - - except Exception as e: - import traceback - traceback.print_exc() - logger.error(f"Error processing streaming audio: {e}") - emit('error', { - 'type': 'error', - 'message': f"Error processing streaming audio: {str(e)}" - }) - -def 
process_complete_utterance(client_id, client, speaker_id, is_incomplete=False): - """Process a complete utterance (after silence or buffer limit)""" - try: - # Combine audio chunks - full_audio = torch.cat(client['streaming_buffer'], dim=0) - - # Process audio to generate a response (using speech recognition) - generated_text = process_speech(full_audio, client_id) - - # Add suffix for incomplete utterances - if is_incomplete: - generated_text += " (processing continued speech...)" - - # Log the generated text - logger.info(f"[{client_id[:8]}] Generated text: '{generated_text}'") - - # Handle the result - if generated_text: - # Add user message to context - user_segment = Segment(text=generated_text, speaker=speaker_id, audio=full_audio) - client['context_segments'].append(user_segment) - - # Send the text to client - emit('transcription', { - 'type': 'transcription', - 'text': generated_text - }, room=client_id) - - # Only generate a response if this is a complete utterance - if not is_incomplete: - # Generate a contextual response - response_text = generate_response(generated_text, client['context_segments']) - logger.info(f"[{client_id[:8]}] Generating response: '{response_text}'") - - # Let the client know we're processing - emit('processing_status', { - 'type': 'processing_status', - 'status': 'generating_audio', - 'message': 'Generating audio response...' - }, room=client_id) - - # Generate audio for the response - try: - # Use a different speaker than the user - ai_speaker_id = 1 if speaker_id == 0 else 0 - - # Generate the full response - audio_tensor = generator.generate( - text=response_text, - speaker=ai_speaker_id, - context=client['context_segments'], - max_audio_length_ms=10_000, - ) - - # Add response to context - ai_segment = Segment( - text=response_text, - speaker=ai_speaker_id, - audio=audio_tensor - ) - client['context_segments'].append(ai_segment) - - # CHANGE HERE: Use the streaming function instead of sending all at once - # Check if the audio is short enough to send at once or if it should be streamed - if audio_tensor.size(0) < generator.sample_rate * 2: # Less than 2 seconds - # For short responses, just send in one go for better responsiveness - audio_base64 = encode_audio_data(audio_tensor) - emit('audio_response', { - 'type': 'audio_response', - 'text': response_text, - 'audio': audio_base64 - }, room=client_id) - logger.info(f"[{client_id[:8]}] Short audio response sent in one piece") - else: - # For longer responses, use streaming - logger.info(f"[{client_id[:8]}] Using streaming for audio response") - # Start a new thread for streaming to avoid blocking the main thread - import threading - stream_thread = threading.Thread( - target=stream_audio_to_client, - args=(client_id, audio_tensor, response_text, ai_speaker_id) - ) - stream_thread.start() - - except Exception as e: - logger.error(f"Error generating audio response: {e}") - emit('error', { - 'type': 'error', - 'message': "Sorry, there was an error generating the audio response." - }, room=client_id) - else: - # If processing failed, send a notification - emit('error', { - 'type': 'error', - 'message': "Sorry, I couldn't understand what you said. Could you try again?" 
- }, room=client_id) - - # Only clear buffer for complete utterances - if not is_incomplete: - # Reset state - client['streaming_buffer'] = [] - client['energy_window'].clear() - client['is_silence'] = False - client['last_active_time'] = time.time() - - except Exception as e: - logger.error(f"Error processing utterance: {e}") - emit('error', { - 'type': 'error', - 'message': f"Error processing audio: {str(e)}" - }, room=client_id) - -@socketio.on('stop_streaming') -def handle_stop_streaming(data): - client_id = request.sid - if client_id not in active_clients: - return - - client = active_clients[client_id] - client['is_streaming'] = False - - if client['streaming_buffer'] and len(client['streaming_buffer']) > 5: - # Process any remaining audio in the buffer - logger.info(f"[{client_id[:8]}] Processing final audio buffer on stop") - process_complete_utterance(client_id, client, data.get("speaker", 0)) - - client['streaming_buffer'] = [] - emit('streaming_status', { - 'type': 'streaming_status', - 'status': 'stopped' - }) - -def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size_ms=CHUNK_SIZE_MS): - """Stream audio to client in chunks to simulate real-time generation""" - try: - if client_id not in active_clients: - logger.warning(f"Client {client_id} not found for streaming") - return - - # Calculate chunk size in samples - chunk_size = int(generator.sample_rate * chunk_size_ms / 1000) - total_chunks = math.ceil(audio_tensor.size(0) / chunk_size) - - logger.info(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each") - - # Send initial response with text but no audio yet - socketio.emit('audio_response_start', { - 'type': 'audio_response_start', - 'text': text, - 'total_chunks': total_chunks - }, room=client_id) - - # Stream each chunk - for i in range(total_chunks): - start_idx = i * chunk_size - end_idx = min(start_idx + chunk_size, audio_tensor.size(0)) - - # Extract chunk - chunk = audio_tensor[start_idx:end_idx] - - # Encode chunk - chunk_base64 = encode_audio_data(chunk) - - # Send chunk - socketio.emit('audio_response_chunk', { - 'type': 'audio_response_chunk', - 'chunk_index': i, - 'total_chunks': total_chunks, - 'audio': chunk_base64, - 'is_last': i == total_chunks - 1 - }, room=client_id) - - # Brief pause between chunks to simulate streaming - time.sleep(0.1) - - # Send completion message - socketio.emit('audio_response_complete', { - 'type': 'audio_response_complete', - 'text': text - }, room=client_id) - - logger.info(f"Audio streaming complete: {total_chunks} chunks sent") - - except Exception as e: - logger.error(f"Error streaming audio to client: {e}") - import traceback - traceback.print_exc() - -# Main server start -if __name__ == "__main__": - print(f"\n{'='*60}") - print(f"🔊 Sesame AI Voice Chat Server") - print(f"{'='*60}") - print(f"📡 Server Information:") - print(f" - Local URL: http://localhost:5000") - print(f" - Network URL: http://:5000") - print(f"{'='*60}") - print(f"🌐 Device: {device.upper()}") - print(f"🧠 Models: Sesame CSM (TTS only)") - print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}") - print(f"{'='*60}") - print(f"Ready to receive connections! 
Press Ctrl+C to stop the server.\n") - - socketio.run(app, host="0.0.0.0", port=5000, debug=False) \ No newline at end of file +if __name__ == '__main__': + socketio.run(app, host='0.0.0.0', port=5000) \ No newline at end of file diff --git a/Backend/setup.py b/Backend/setup.py deleted file mode 100644 index 8eddb95..0000000 --- a/Backend/setup.py +++ /dev/null @@ -1,13 +0,0 @@ -from setuptools import setup, find_packages -import os - -# Read requirements from requirements.txt -with open('requirements.txt') as f: - requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')] - -setup( - name='csm', - version='0.1.0', - packages=find_packages(), - install_requires=requirements, -) diff --git a/Backend/src/audio/processor.py b/Backend/src/audio/processor.py new file mode 100644 index 0000000..40d636e --- /dev/null +++ b/Backend/src/audio/processor.py @@ -0,0 +1,28 @@ +from scipy.io import wavfile +import numpy as np +import torchaudio + +def load_audio(file_path): + sample_rate, audio_data = wavfile.read(file_path) + return sample_rate, audio_data + +def normalize_audio(audio_data): + audio_data = audio_data.astype(np.float32) + max_val = np.max(np.abs(audio_data)) + if max_val > 0: + audio_data /= max_val + return audio_data + +def reduce_noise(audio_data, noise_factor=0.1): + noise = np.random.randn(len(audio_data)) + noisy_audio = audio_data + noise_factor * noise + return noisy_audio + +def save_audio(file_path, sample_rate, audio_data): + torchaudio.save(file_path, torch.tensor(audio_data).unsqueeze(0), sample_rate) + +def process_audio(file_path, output_path): + sample_rate, audio_data = load_audio(file_path) + normalized_audio = normalize_audio(audio_data) + denoised_audio = reduce_noise(normalized_audio) + save_audio(output_path, sample_rate, denoised_audio) \ No newline at end of file diff --git a/Backend/src/audio/streaming.py b/Backend/src/audio/streaming.py new file mode 100644 index 0000000..19ee4cb --- /dev/null +++ b/Backend/src/audio/streaming.py @@ -0,0 +1,35 @@ +from flask import Blueprint, request +from flask_socketio import SocketIO, emit +from src.audio.processor import process_audio +from src.services.transcription_service import TranscriptionService +from src.services.tts_service import TextToSpeechService + +streaming_bp = Blueprint('streaming', __name__) +socketio = SocketIO() + +transcription_service = TranscriptionService() +tts_service = TextToSpeechService() + +@socketio.on('audio_stream') +def handle_audio_stream(data): + audio_chunk = data['audio'] + speaker_id = data['speaker'] + + # Process the audio chunk + processed_audio = process_audio(audio_chunk) + + # Transcribe the audio to text + transcription = transcription_service.transcribe(processed_audio) + + # Generate a response using the LLM + response_text = generate_response(transcription, speaker_id) + + # Convert the response text back to audio + response_audio = tts_service.convert_text_to_speech(response_text, speaker_id) + + # Emit the response audio back to the client + emit('audio_response', {'audio': response_audio}) + +def generate_response(transcription, speaker_id): + # Placeholder for the actual response generation logic + return f"Response to: {transcription}" \ No newline at end of file diff --git a/Backend/generator.py b/Backend/src/llm/generator.py similarity index 90% rename from Backend/generator.py rename to Backend/src/llm/generator.py index 7bc3634..ce4297c 100644 --- a/Backend/generator.py +++ b/Backend/src/llm/generator.py @@ -15,14 +15,10 @@ from 
watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark class Segment: speaker: int text: str - # (num_samples,), sample_rate = 24_000 audio: torch.Tensor def load_llama3_tokenizer(): - """ - https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992 - """ tokenizer_name = "meta-llama/Llama-3.2-1B" tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) bos = tokenizer.bos_token @@ -78,10 +74,8 @@ class Generator: frame_tokens = [] frame_masks = [] - # (K, T) audio = audio.to(self.device) audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0] - # add EOS frame eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device) audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1) @@ -96,10 +90,6 @@ class Generator: return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0) def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Returns: - (seq_len, 33), (seq_len, 33) - """ text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker) audio_tokens, audio_masks = self._tokenize_audio(segment.audio) @@ -146,7 +136,7 @@ class Generator: for _ in range(max_generation_len): sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk) if torch.all(sample == 0): - break # eos + break samples.append(sample) @@ -158,10 +148,6 @@ class Generator: audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0) - # This applies an imperceptible watermark to identify audio as AI-generated. - # Watermarking ensures transparency, dissuades misuse, and enables traceability. - # Please be a responsible AI citizen and keep the watermarking in place. - # If using CSM 1B in another application, use your own private key and keep it secret. 
audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK) audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate) diff --git a/Backend/src/llm/tokenizer.py b/Backend/src/llm/tokenizer.py new file mode 100644 index 0000000..0a05bcd --- /dev/null +++ b/Backend/src/llm/tokenizer.py @@ -0,0 +1,14 @@ +from transformers import AutoTokenizer + +def load_llama3_tokenizer(): + tokenizer_name = "meta-llama/Llama-3.2-1B" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + return tokenizer + +def tokenize_text(text: str, tokenizer) -> list: + tokens = tokenizer.encode(text, return_tensors='pt') + return tokens + +def decode_tokens(tokens: list, tokenizer) -> str: + text = tokenizer.decode(tokens, skip_special_tokens=True) + return text \ No newline at end of file diff --git a/Backend/src/models/audio_model.py b/Backend/src/models/audio_model.py new file mode 100644 index 0000000..726bec4 --- /dev/null +++ b/Backend/src/models/audio_model.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +import torch + +@dataclass +class AudioModel: + model: torch.nn.Module + sample_rate: int + + def __post_init__(self): + self.model.eval() + + def process_audio(self, audio_tensor: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + processed_audio = self.model(audio_tensor) + return processed_audio + + def resample_audio(self, audio_tensor: torch.Tensor, target_sample_rate: int) -> torch.Tensor: + if self.sample_rate != target_sample_rate: + resampled_audio = torchaudio.functional.resample(audio_tensor, orig_freq=self.sample_rate, new_freq=target_sample_rate) + return resampled_audio + return audio_tensor + + def save_model(self, path: str): + torch.save(self.model.state_dict(), path) + + def load_model(self, path: str): + self.model.load_state_dict(torch.load(path)) + self.model.eval() \ No newline at end of file diff --git a/Backend/src/models/conversation.py b/Backend/src/models/conversation.py new file mode 100644 index 0000000..4558958 --- /dev/null +++ b/Backend/src/models/conversation.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass, field +from typing import List, Optional + +@dataclass +class Conversation: + context: List[str] = field(default_factory=list) + current_speaker: Optional[int] = None + + def add_message(self, message: str, speaker: int): + self.context.append(f"Speaker {speaker}: {message}") + self.current_speaker = speaker + + def get_context(self) -> List[str]: + return self.context + + def clear_context(self): + self.context.clear() + self.current_speaker = None + + def get_last_message(self) -> Optional[str]: + if self.context: + return self.context[-1] + return None \ No newline at end of file diff --git a/Backend/src/services/transcription_service.py b/Backend/src/services/transcription_service.py new file mode 100644 index 0000000..06f8dd1 --- /dev/null +++ b/Backend/src/services/transcription_service.py @@ -0,0 +1,25 @@ +from typing import List +import torchaudio +import torch +from generator import load_csm_1b, Segment + +class TranscriptionService: + def __init__(self, model_device: str = "cpu"): + self.generator = load_csm_1b(device=model_device) + + def transcribe_audio(self, audio_path: str) -> str: + audio_tensor, sample_rate = torchaudio.load(audio_path) + audio_tensor = self._resample_audio(audio_tensor, sample_rate) + transcription = self.generator.generate_transcription(audio_tensor) + return transcription + + def _resample_audio(self, audio_tensor: torch.Tensor, 
orig_freq: int) -> torch.Tensor: + target_sample_rate = self.generator.sample_rate + if orig_freq != target_sample_rate: + audio_tensor = torchaudio.functional.resample(audio_tensor.squeeze(0), orig_freq=orig_freq, new_freq=target_sample_rate) + return audio_tensor + + def transcribe_audio_stream(self, audio_chunks: List[torch.Tensor]) -> str: + combined_audio = torch.cat(audio_chunks, dim=1) + transcription = self.generator.generate_transcription(combined_audio) + return transcription \ No newline at end of file diff --git a/Backend/src/services/tts_service.py b/Backend/src/services/tts_service.py new file mode 100644 index 0000000..64fab04 --- /dev/null +++ b/Backend/src/services/tts_service.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +import torch +import torchaudio +from huggingface_hub import hf_hub_download +from src.llm.generator import load_csm_1b + +@dataclass +class TextToSpeechService: + generator: any + + def __init__(self, device: str = "cuda"): + self.generator = load_csm_1b(device=device) + + def text_to_speech(self, text: str, speaker: int = 0) -> torch.Tensor: + audio = self.generator.generate( + text=text, + speaker=speaker, + context=[], + max_audio_length_ms=10000, + ) + return audio + + def save_audio(self, audio: torch.Tensor, file_path: str): + torchaudio.save(file_path, audio.unsqueeze(0).cpu(), self.generator.sample_rate) \ No newline at end of file diff --git a/Backend/src/utils/config.py b/Backend/src/utils/config.py new file mode 100644 index 0000000..2206481 --- /dev/null +++ b/Backend/src/utils/config.py @@ -0,0 +1,23 @@ +# filepath: /csm-conversation-bot/csm-conversation-bot/src/utils/config.py + +import os + +class Config: + # General configuration + DEBUG = os.getenv('DEBUG', 'False') == 'True' + SECRET_KEY = os.getenv('SECRET_KEY', 'your_secret_key_here') + + # API configuration + API_URL = os.getenv('API_URL', 'http://localhost:5000') + + # Model configuration + LLM_MODEL_PATH = os.getenv('LLM_MODEL_PATH', 'path/to/llm/model') + AUDIO_MODEL_PATH = os.getenv('AUDIO_MODEL_PATH', 'path/to/audio/model') + + # Socket.IO configuration + SOCKETIO_MESSAGE_QUEUE = os.getenv('SOCKETIO_MESSAGE_QUEUE', 'redis://localhost:6379/0') + + # Logging configuration + LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO') + + # Other configurations can be added as needed \ No newline at end of file diff --git a/Backend/src/utils/logger.py b/Backend/src/utils/logger.py new file mode 100644 index 0000000..93e8966 --- /dev/null +++ b/Backend/src/utils/logger.py @@ -0,0 +1,14 @@ +import logging + +def setup_logger(name, log_file, level=logging.INFO): + handler = logging.FileHandler(log_file) + handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + + logger = logging.getLogger(name) + logger.setLevel(level) + logger.addHandler(handler) + + return logger + +# Example usage: +# logger = setup_logger('my_logger', 'app.log') \ No newline at end of file diff --git a/Backend/static/css/styles.css b/Backend/static/css/styles.css new file mode 100644 index 0000000..4e2d752 --- /dev/null +++ b/Backend/static/css/styles.css @@ -0,0 +1,105 @@ +body { + font-family: 'Arial', sans-serif; + background-color: #f4f4f4; + color: #333; + margin: 0; + padding: 0; +} + +header { + background: #4c84ff; + color: #fff; + padding: 10px 0; + text-align: center; +} + +h1 { + margin: 0; + font-size: 2.5rem; +} + +.container { + width: 80%; + margin: auto; + overflow: hidden; +} + +.conversation { + background: #fff; + padding: 20px; + border-radius: 5px; + 
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1); + max-height: 400px; + overflow-y: auto; +} + +.message { + padding: 10px; + margin: 10px 0; + border-radius: 5px; +} + +.user { + background: #e3f2fd; + text-align: right; +} + +.ai { + background: #f1f1f1; + text-align: left; +} + +.controls { + display: flex; + justify-content: space-between; + margin-top: 20px; +} + +button { + padding: 10px 15px; + border: none; + border-radius: 5px; + cursor: pointer; + transition: background 0.3s; +} + +button:hover { + background: #3367d6; + color: #fff; +} + +.visualizer-container { + height: 150px; + background: #000; + border-radius: 5px; + margin-top: 20px; +} + +.visualizer-label { + color: rgba(255, 255, 255, 0.7); + text-align: center; + padding: 10px; +} + +.status-indicator { + display: flex; + align-items: center; + margin-top: 10px; +} + +.status-dot { + width: 12px; + height: 12px; + border-radius: 50%; + background-color: #ccc; + margin-right: 10px; +} + +.status-dot.active { + background-color: #4CAF50; +} + +.status-text { + font-size: 0.9em; + color: #666; +} \ No newline at end of file diff --git a/Backend/static/index.html b/Backend/static/index.html new file mode 100644 index 0000000..4922f17 --- /dev/null +++ b/Backend/static/index.html @@ -0,0 +1,31 @@ + + + + + + CSM Conversation Bot + + + + + +
+<body>
+    <header>
+        <h1>CSM Conversation Bot</h1>
+    </header>
+    <div class="container">
+        <p>Talk to the AI and get responses in real-time!</p>
+        <div id="conversation" class="conversation"></div>
+        <div class="controls">
+            <button id="streamButton">Start Conversation</button>
+            <button id="clearButton">Clear Conversation</button>
+            <select id="speakerSelect">
+                <option value="0">Speaker 0</option>
+                <option value="1">Speaker 1</option>
+            </select>
+        </div>
+        <div class="status-indicator">
+            <span id="statusDot" class="status-dot"></span>
+            <span id="statusText" class="status-text">Disconnected</span>
+        </div>
+        <footer>Powered by CSM and Llama 3.2</footer>
+    </div>
+ + \ No newline at end of file diff --git a/Backend/static/js/client.js b/Backend/static/js/client.js new file mode 100644 index 0000000..ec4037f --- /dev/null +++ b/Backend/static/js/client.js @@ -0,0 +1,131 @@ +// This file contains the client-side JavaScript code that handles audio streaming and communication with the server. + +const SERVER_URL = window.location.hostname === 'localhost' ? + 'http://localhost:5000' : window.location.origin; + +const elements = { + conversation: document.getElementById('conversation'), + streamButton: document.getElementById('streamButton'), + clearButton: document.getElementById('clearButton'), + speakerSelection: document.getElementById('speakerSelect'), + statusDot: document.getElementById('statusDot'), + statusText: document.getElementById('statusText'), +}; + +const state = { + socket: null, + isStreaming: false, + currentSpeaker: 0, +}; + +// Initialize the application +function initializeApp() { + setupSocketConnection(); + setupEventListeners(); +} + +// Setup Socket.IO connection +function setupSocketConnection() { + state.socket = io(SERVER_URL); + + state.socket.on('connect', () => { + updateConnectionStatus(true); + }); + + state.socket.on('disconnect', () => { + updateConnectionStatus(false); + }); + + state.socket.on('audio_response', handleAudioResponse); + state.socket.on('transcription', handleTranscription); +} + +// Setup event listeners +function setupEventListeners() { + elements.streamButton.addEventListener('click', toggleStreaming); + elements.clearButton.addEventListener('click', clearConversation); + elements.speakerSelection.addEventListener('change', (event) => { + state.currentSpeaker = event.target.value; + }); +} + +// Update connection status UI +function updateConnectionStatus(isConnected) { + elements.statusDot.classList.toggle('active', isConnected); + elements.statusText.textContent = isConnected ? 
'Connected' : 'Disconnected'; +} + +// Toggle streaming state +function toggleStreaming() { + if (state.isStreaming) { + stopStreaming(); + } else { + startStreaming(); + } +} + +// Start streaming audio to the server +function startStreaming() { + if (state.isStreaming) return; + + navigator.mediaDevices.getUserMedia({ audio: true }) + .then(stream => { + const mediaRecorder = new MediaRecorder(stream); + mediaRecorder.start(); + + mediaRecorder.ondataavailable = (event) => { + if (event.data.size > 0) { + sendAudioChunk(event.data); + } + }; + + mediaRecorder.onstop = () => { + state.isStreaming = false; + elements.streamButton.innerHTML = 'Start Conversation'; + }; + + state.isStreaming = true; + elements.streamButton.innerHTML = 'Stop Conversation'; + }) + .catch(err => { + console.error('Error accessing microphone:', err); + }); +} + +// Stop streaming audio +function stopStreaming() { + if (!state.isStreaming) return; + + // Logic to stop the media recorder would go here +} + +// Send audio chunk to server +function sendAudioChunk(audioData) { + const reader = new FileReader(); + reader.onloadend = () => { + const arrayBuffer = reader.result; + state.socket.emit('audio_chunk', { audio: arrayBuffer, speaker: state.currentSpeaker }); + }; + reader.readAsArrayBuffer(audioData); +} + +// Handle audio response from server +function handleAudioResponse(data) { + const audioElement = new Audio(URL.createObjectURL(new Blob([data.audio]))); + audioElement.play(); +} + +// Handle transcription response from server +function handleTranscription(data) { + const messageElement = document.createElement('div'); + messageElement.textContent = `AI: ${data.transcription}`; + elements.conversation.appendChild(messageElement); +} + +// Clear conversation history +function clearConversation() { + elements.conversation.innerHTML = ''; +} + +// Initialize the application when DOM is fully loaded +document.addEventListener('DOMContentLoaded', initializeApp); \ No newline at end of file diff --git a/Backend/templates/index.html b/Backend/templates/index.html new file mode 100644 index 0000000..514b946 --- /dev/null +++ b/Backend/templates/index.html @@ -0,0 +1,31 @@ + + + + + + CSM Conversation Bot + + + + + +
+<body>
+    <header>
+        <h1>CSM Conversation Bot</h1>
+    </header>
+    <div class="container">
+        <p>Talk to the AI and get responses in real-time!</p>
+        <div id="conversation" class="conversation"></div>
+        <div class="controls">
+            <button id="streamButton">Start Conversation</button>
+            <button id="clearButton">Clear Conversation</button>
+            <select id="speakerSelect">
+                <option value="0">Speaker 0</option>
+                <option value="1">Speaker 1</option>
+            </select>
+        </div>
+        <div class="status-indicator">
+            <span id="statusDot" class="status-dot"></span>
+            <span id="statusText" class="status-text">Not connected</span>
+        </div>
+        <footer>Powered by CSM and Llama 3.2</footer>
+    </div>
+ + \ No newline at end of file diff --git a/Backend/test.py b/Backend/test.py deleted file mode 100644 index 34735b1..0000000 --- a/Backend/test.py +++ /dev/null @@ -1,50 +0,0 @@ -import os -import torch -import torchaudio -from huggingface_hub import hf_hub_download -from generator import load_csm_1b, Segment -from dataclasses import dataclass - -if torch.backends.mps.is_available(): - device = "mps" -elif torch.cuda.is_available(): - device = "cuda" -else: - device = "cpu" - -generator = load_csm_1b(device=device) - -speakers = [0, 1, 0, 0] -transcripts = [ - "Hey how are you doing.", - "Pretty good, pretty good.", - "I'm great.", - "So happy to be speaking to you.", -] -audio_paths = [ - "utterance_0.wav", - "utterance_1.wav", - "utterance_2.wav", - "utterance_3.wav", -] - -def load_audio(audio_path): - audio_tensor, sample_rate = torchaudio.load(audio_path) - audio_tensor = torchaudio.functional.resample( - audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate - ) - return audio_tensor - -segments = [ - Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path)) - for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths) -] - -audio = generator.generate( - text="Me too, this is some cool stuff huh?", - speaker=1, - context=segments, - max_audio_length_ms=10_000, -) - -torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) \ No newline at end of file diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js deleted file mode 100644 index 89ec71a..0000000 --- a/Backend/voice-chat.js +++ /dev/null @@ -1,1071 +0,0 @@ -/** - * Sesame AI Voice Chat Client - * - * A web client that connects to a Sesame AI voice chat server and enables - * real-time voice conversation with an AI assistant. - */ - -// Configuration constants -const SERVER_URL = window.location.hostname === 'localhost' ? - 'http://localhost:5000' : window.location.origin; -const ENERGY_WINDOW_SIZE = 15; -const CLIENT_SILENCE_DURATION_MS = 750; - -// DOM elements -const elements = { - conversation: null, - streamButton: null, - clearButton: null, - thresholdSlider: null, - thresholdValue: null, - visualizerCanvas: null, - visualizerLabel: null, - volumeLevel: null, - statusDot: null, - statusText: null, - speakerSelection: null, - autoPlayResponses: null, - showVisualizer: null -}; - -// Application state -const state = { - socket: null, - audioContext: null, - analyser: null, - microphone: null, - streamProcessor: null, - isStreaming: false, - isSpeaking: false, - silenceThreshold: 0.01, - energyWindow: [], - silenceTimer: null, - volumeUpdateInterval: null, - visualizerAnimationFrame: null, - currentSpeaker: 0 -}; - -// Visualizer variables -let canvasContext = null; -let visualizerBufferLength = 0; -let visualizerDataArray = null; - -// New state variables to track incremental audio streaming -const streamingAudio = { - messageElement: null, - audioElement: null, - chunks: [], - totalChunks: 0, - receivedChunks: 0, - text: '', - mediaSource: null, - sourceBuffer: null, - audioContext: null, - complete: false -}; - -// Initialize the application -function initializeApp() { - // Initialize the UI elements - initializeUIElements(); - - // Initialize socket.io connection - setupSocketConnection(); - - // Setup event listeners - setupEventListeners(); - - // Initialize visualizer - setupVisualizer(); - - // Show welcome message - addSystemMessage('Welcome to Sesame AI Voice Chat! 
Click "Start Conversation" to begin.'); -} - -// Initialize UI elements -function initializeUIElements() { - // Store references to UI elements - elements.conversation = document.getElementById('conversation'); - elements.streamButton = document.getElementById('streamButton'); - elements.clearButton = document.getElementById('clearButton'); - elements.thresholdSlider = document.getElementById('thresholdSlider'); - elements.thresholdValue = document.getElementById('thresholdValue'); - elements.visualizerCanvas = document.getElementById('audioVisualizer'); - elements.visualizerLabel = document.getElementById('visualizerLabel'); - elements.volumeLevel = document.getElementById('volumeLevel'); - elements.statusDot = document.getElementById('statusDot'); - elements.statusText = document.getElementById('statusText'); - elements.speakerSelection = document.getElementById('speakerSelect'); // Changed to match HTML - elements.autoPlayResponses = document.getElementById('autoPlayResponses'); - elements.showVisualizer = document.getElementById('showVisualizer'); -} - -// Setup Socket.IO connection -function setupSocketConnection() { - state.socket = io(SERVER_URL); - - // Connection events - state.socket.on('connect', () => { - console.log('Connected to server'); - updateConnectionStatus(true); - }); - - state.socket.on('disconnect', () => { - console.log('Disconnected from server'); - updateConnectionStatus(false); - - // Stop streaming if active - if (state.isStreaming) { - stopStreaming(false); - } - }); - - state.socket.on('error', (data) => { - console.error('Socket error:', data.message); - addSystemMessage(`Error: ${data.message}`); - }); - - // Register message handlers - state.socket.on('audio_response', handleAudioResponse); - state.socket.on('transcription', handleTranscription); - state.socket.on('context_updated', handleContextUpdate); - state.socket.on('streaming_status', handleStreamingStatus); - - // New event handlers for incremental audio streaming - state.socket.on('audio_response_start', handleAudioResponseStart); - state.socket.on('audio_response_chunk', handleAudioResponseChunk); - state.socket.on('audio_response_complete', handleAudioResponseComplete); - state.socket.on('processing_status', handleProcessingStatus); -} - -// Setup event listeners -function setupEventListeners() { - // Stream button - elements.streamButton.addEventListener('click', toggleStreaming); - - // Clear button - elements.clearButton.addEventListener('click', clearConversation); - - // Threshold slider - elements.thresholdSlider.addEventListener('input', updateThreshold); - - // Speaker selection - elements.speakerSelection.addEventListener('change', () => { - state.currentSpeaker = parseInt(elements.speakerSelection.value, 10); - }); - - // Visualizer toggle - elements.showVisualizer.addEventListener('change', toggleVisualizerVisibility); -} - -// Setup audio visualizer -function setupVisualizer() { - if (!elements.visualizerCanvas) return; - - canvasContext = elements.visualizerCanvas.getContext('2d'); - - // Set canvas dimensions - elements.visualizerCanvas.width = elements.visualizerCanvas.offsetWidth; - elements.visualizerCanvas.height = elements.visualizerCanvas.offsetHeight; - - // Initialize the visualizer - drawVisualizer(); -} - -// Update connection status UI -function updateConnectionStatus(isConnected) { - elements.statusDot.classList.toggle('active', isConnected); - elements.statusText.textContent = isConnected ? 
'Connected' : 'Disconnected'; -} - -// Toggle streaming state -function toggleStreaming() { - if (state.isStreaming) { - stopStreaming(true); - } else { - startStreaming(); - } -} - -// Start streaming audio to the server -function startStreaming() { - if (state.isStreaming) return; - - // Request microphone access - navigator.mediaDevices.getUserMedia({ audio: true, video: false }) - .then(stream => { - // Show processing state while setting up - elements.streamButton.innerHTML = ' Initializing...'; - - // Create audio context - state.audioContext = new (window.AudioContext || window.webkitAudioContext)(); - - // Create microphone source - state.microphone = state.audioContext.createMediaStreamSource(stream); - - // Create analyser for visualizer - state.analyser = state.audioContext.createAnalyser(); - state.analyser.fftSize = 256; - visualizerBufferLength = state.analyser.frequencyBinCount; - visualizerDataArray = new Uint8Array(visualizerBufferLength); - - // Connect microphone to analyser - state.microphone.connect(state.analyser); - - // Create script processor for audio processing - const bufferSize = 4096; - state.streamProcessor = state.audioContext.createScriptProcessor(bufferSize, 1, 1); - - // Set up audio processing callback - state.streamProcessor.onaudioprocess = handleAudioProcess; - - // Connect the processors - state.analyser.connect(state.streamProcessor); - state.streamProcessor.connect(state.audioContext.destination); - - // Update UI - state.isStreaming = true; - elements.streamButton.innerHTML = ' Listening...'; - elements.streamButton.classList.add('recording'); - - // Initialize energy window - state.energyWindow = []; - - // Start volume meter updates - state.volumeUpdateInterval = setInterval(updateVolumeMeter, 100); - - // Start visualizer if enabled - if (elements.showVisualizer.checked && !state.visualizerAnimationFrame) { - drawVisualizer(); - } - - // Show starting message - addSystemMessage('Listening... Speak clearly into your microphone.'); - - // Notify the server that we're starting - state.socket.emit('stream_audio', { - audio: '', - speaker: state.currentSpeaker - }); - }) - .catch(err => { - console.error('Error accessing microphone:', err); - addSystemMessage(`Error: ${err.message}. 
Please make sure your microphone is connected and you've granted permission.`); - elements.streamButton.innerHTML = ' Start Conversation'; - }); -} - -// Stop streaming audio -function stopStreaming(notifyServer = true) { - if (!state.isStreaming) return; - - // Update UI first - elements.streamButton.innerHTML = ' Start Conversation'; - elements.streamButton.classList.remove('recording'); - elements.streamButton.classList.remove('processing'); - - // Stop volume meter updates - if (state.volumeUpdateInterval) { - clearInterval(state.volumeUpdateInterval); - state.volumeUpdateInterval = null; - } - - // Stop all audio processing - if (state.streamProcessor) { - state.streamProcessor.disconnect(); - state.streamProcessor = null; - } - - if (state.analyser) { - state.analyser.disconnect(); - } - - if (state.microphone) { - state.microphone.disconnect(); - } - - // Close audio context - if (state.audioContext && state.audioContext.state !== 'closed') { - state.audioContext.close().catch(err => console.warn('Error closing audio context:', err)); - } - - // Cleanup animation frames - if (state.visualizerAnimationFrame) { - cancelAnimationFrame(state.visualizerAnimationFrame); - state.visualizerAnimationFrame = null; - } - - // Reset state - state.isStreaming = false; - state.isSpeaking = false; - - // Notify the server - if (notifyServer && state.socket && state.socket.connected) { - state.socket.emit('stop_streaming', { - speaker: state.currentSpeaker - }); - } - - // Show message - addSystemMessage('Conversation paused. Click "Start Conversation" to resume.'); -} - -// Handle audio processing -function handleAudioProcess(event) { - const inputData = event.inputBuffer.getChannelData(0); - - // Calculate audio energy (volume level) - const energy = calculateAudioEnergy(inputData); - - // Update energy window for averaging - updateEnergyWindow(energy); - - // Calculate average energy - const avgEnergy = calculateAverageEnergy(); - - // Determine if audio is silent - const isSilent = avgEnergy < state.silenceThreshold; - - // Debug logging only if significant changes in audio patterns - if (Math.random() < 0.05) { // Log only 5% of frames to avoid console spam - console.log(`Audio: len=${inputData.length}, energy=${energy.toFixed(4)}, avg=${avgEnergy.toFixed(4)}, silent=${isSilent}`); - } - - // Handle speech state based on silence - handleSpeechState(isSilent); - - // Only send audio chunk if we detect speech - if (!isSilent) { - // Create a resampled version at 24kHz for the server - // Most WebRTC audio is 48kHz, but we want 24kHz for the model - const resampledData = downsampleBuffer(inputData, state.audioContext.sampleRate, 24000); - - // Send the audio chunk to the server - sendAudioChunk(resampledData, state.currentSpeaker); - } -} - -// Cleanup audio resources when done -function cleanupAudioResources() { - // Stop all audio processing - if (state.streamProcessor) { - state.streamProcessor.disconnect(); - state.streamProcessor = null; - } - - if (state.analyser) { - state.analyser.disconnect(); - state.analyser = null; - } - - if (state.microphone) { - state.microphone.disconnect(); - state.microphone = null; - } - - // Close audio context - if (state.audioContext && state.audioContext.state !== 'closed') { - state.audioContext.close().catch(err => console.warn('Error closing audio context:', err)); - } - - // Cancel all timers and animation frames - if (state.volumeUpdateInterval) { - clearInterval(state.volumeUpdateInterval); - state.volumeUpdateInterval = null; - } - - if 
(state.visualizerAnimationFrame) { - cancelAnimationFrame(state.visualizerAnimationFrame); - state.visualizerAnimationFrame = null; - } - - if (state.silenceTimer) { - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } -} - -// Clear conversation history -function clearConversation() { - if (elements.conversation) { - elements.conversation.innerHTML = ''; - addSystemMessage('Conversation cleared.'); - - // Notify server to clear context - if (state.socket && state.socket.connected) { - state.socket.emit('clear_context'); - } - } -} - -// Calculate audio energy (volume) -function calculateAudioEnergy(buffer) { - let sum = 0; - for (let i = 0; i < buffer.length; i++) { - sum += buffer[i] * buffer[i]; - } - return Math.sqrt(sum / buffer.length); -} - -// Update energy window for averaging -function updateEnergyWindow(energy) { - state.energyWindow.push(energy); - if (state.energyWindow.length > ENERGY_WINDOW_SIZE) { - state.energyWindow.shift(); - } -} - -// Calculate average energy from window -function calculateAverageEnergy() { - if (state.energyWindow.length === 0) return 0; - - const sum = state.energyWindow.reduce((a, b) => a + b, 0); - return sum / state.energyWindow.length; -} - -// Update the threshold from the slider -function updateThreshold() { - state.silenceThreshold = parseFloat(elements.thresholdSlider.value); - elements.thresholdValue.textContent = state.silenceThreshold.toFixed(3); -} - -// Update the volume meter display -function updateVolumeMeter() { - if (!state.isStreaming || !state.energyWindow.length) return; - - const avgEnergy = calculateAverageEnergy(); - - // Scale energy to percentage (0-100) - // Typically, energy values will be very small (e.g., 0.001 to 0.1) - // So we multiply by a factor to make it more visible - const scaleFactor = 1000; - const percentage = Math.min(100, Math.max(0, avgEnergy * scaleFactor)); - - // Update volume meter width - elements.volumeLevel.style.width = `${percentage}%`; - - // Change color based on level - if (percentage > 70) { - elements.volumeLevel.style.backgroundColor = '#ff5252'; - } else if (percentage > 30) { - elements.volumeLevel.style.backgroundColor = '#4CAF50'; - } else { - elements.volumeLevel.style.backgroundColor = '#4c84ff'; - } -} - -// Handle speech/silence state transitions -function handleSpeechState(isSilent) { - if (state.isSpeaking && isSilent) { - // Transition from speaking to silence - if (!state.silenceTimer) { - state.silenceTimer = setTimeout(() => { - // Only consider it a real silence after a certain duration - // This prevents detecting brief pauses as the end of speech - state.isSpeaking = false; - state.silenceTimer = null; - }, CLIENT_SILENCE_DURATION_MS); - } - } else if (state.silenceTimer && !isSilent) { - // User started speaking again, cancel the silence timer - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } - - // Update speaking state for non-silent audio - if (!isSilent) { - state.isSpeaking = true; - } -} - -// Send audio chunk to server -function sendAudioChunk(audioData, speaker) { - if (!state.socket || !state.socket.connected) { - console.warn('Socket not connected'); - return; - } - - console.log(`Preparing audio chunk: length=${audioData.length}, speaker=${speaker}`); - - // Check for NaN or invalid values - let hasInvalidValues = false; - for (let i = 0; i < audioData.length; i++) { - if (isNaN(audioData[i]) || !isFinite(audioData[i])) { - hasInvalidValues = true; - console.warn(`Invalid audio value at index ${i}: ${audioData[i]}`); - break; - 
} - } - - if (hasInvalidValues) { - console.warn('Audio data contains invalid values. Creating silent audio.'); - audioData = new Float32Array(audioData.length).fill(0); - } - - try { - // Create WAV blob - const wavData = createWavBlob(audioData, 24000); - console.log(`WAV blob created: ${wavData.size} bytes`); - - const reader = new FileReader(); - - reader.onloadend = function() { - try { - // Get base64 data - const base64data = reader.result; - console.log(`Base64 data created: ${base64data.length} bytes`); - - // Send to server - state.socket.emit('stream_audio', { - audio: base64data, - speaker: speaker - }); - console.log('Audio chunk sent to server'); - } catch (err) { - console.error('Error preparing audio data:', err); - } - }; - - reader.onerror = function() { - console.error('Error reading audio data as base64'); - }; - - reader.readAsDataURL(wavData); - } catch (err) { - console.error('Error creating WAV data:', err); - } -} - -// Create WAV blob from audio data with improved error handling -function createWavBlob(audioData, sampleRate) { - // Validate input - if (!audioData || audioData.length === 0) { - console.warn('Empty audio data provided to createWavBlob'); - audioData = new Float32Array(1024).fill(0); // Create 1024 samples of silence - } - - // Function to convert Float32Array to Int16Array for WAV format - function floatTo16BitPCM(output, offset, input) { - for (let i = 0; i < input.length; i++, offset += 2) { - // Ensure values are in -1 to 1 range - const s = Math.max(-1, Math.min(1, input[i])); - // Convert to 16-bit PCM - output.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true); - } - } - - // Create WAV header - function writeString(view, offset, string) { - for (let i = 0; i < string.length; i++) { - view.setUint8(offset + i, string.charCodeAt(i)); - } - } - - try { - // Create WAV file with header - careful with buffer sizes - const buffer = new ArrayBuffer(44 + audioData.length * 2); - const view = new DataView(buffer); - - // RIFF identifier - writeString(view, 0, 'RIFF'); - - // File length (will be filled later) - view.setUint32(4, 36 + audioData.length * 2, true); - - // WAVE identifier - writeString(view, 8, 'WAVE'); - - // fmt chunk identifier - writeString(view, 12, 'fmt '); - - // fmt chunk length - view.setUint32(16, 16, true); - - // Sample format (1 is PCM) - view.setUint16(20, 1, true); - - // Mono channel - view.setUint16(22, 1, true); - - // Sample rate - view.setUint32(24, sampleRate, true); - - // Byte rate (sample rate * block align) - view.setUint32(28, sampleRate * 2, true); - - // Block align (channels * bytes per sample) - view.setUint16(32, 2, true); - - // Bits per sample - view.setUint16(34, 16, true); - - // data chunk identifier - writeString(view, 36, 'data'); - - // data chunk length - view.setUint32(40, audioData.length * 2, true); - - // Write the PCM samples - floatTo16BitPCM(view, 44, audioData); - - // Create and return blob - return new Blob([view], { type: 'audio/wav' }); - } catch (err) { - console.error('Error in createWavBlob:', err); - - // Create a minimal valid WAV file with silence as fallback - const fallbackSamples = new Float32Array(1024).fill(0); - const fallbackBuffer = new ArrayBuffer(44 + fallbackSamples.length * 2); - const fallbackView = new DataView(fallbackBuffer); - - writeString(fallbackView, 0, 'RIFF'); - fallbackView.setUint32(4, 36 + fallbackSamples.length * 2, true); - writeString(fallbackView, 8, 'WAVE'); - writeString(fallbackView, 12, 'fmt '); - fallbackView.setUint32(16, 16, true); - 
fallbackView.setUint16(20, 1, true); - fallbackView.setUint16(22, 1, true); - fallbackView.setUint32(24, sampleRate, true); - fallbackView.setUint32(28, sampleRate * 2, true); - fallbackView.setUint16(32, 2, true); - fallbackView.setUint16(34, 16, true); - writeString(fallbackView, 36, 'data'); - fallbackView.setUint32(40, fallbackSamples.length * 2, true); - floatTo16BitPCM(fallbackView, 44, fallbackSamples); - - return new Blob([fallbackView], { type: 'audio/wav' }); - } -} - -// Draw audio visualizer -function drawVisualizer() { - if (!canvasContext) { - return; - } - - state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer); - - // Skip drawing if visualizer is hidden - if (!elements.showVisualizer.checked) { - if (elements.visualizerCanvas.style.opacity !== '0') { - elements.visualizerCanvas.style.opacity = '0'; - } - return; - } else if (elements.visualizerCanvas.style.opacity !== '1') { - elements.visualizerCanvas.style.opacity = '1'; - } - - // Get frequency data if available - if (state.isStreaming && state.analyser) { - try { - state.analyser.getByteFrequencyData(visualizerDataArray); - } catch (e) { - console.warn('Error getting frequency data:', e); - } - } else { - // Fade out when not streaming - for (let i = 0; i < visualizerDataArray.length; i++) { - visualizerDataArray[i] = Math.max(0, visualizerDataArray[i] - 5); - } - } - - // Clear canvas - canvasContext.fillStyle = 'rgb(0, 0, 0)'; - canvasContext.fillRect(0, 0, elements.visualizerCanvas.width, elements.visualizerCanvas.height); - - // Draw gradient bars - const width = elements.visualizerCanvas.width; - const height = elements.visualizerCanvas.height; - const barCount = Math.min(visualizerBufferLength, 64); - const barWidth = width / barCount - 1; - - for (let i = 0; i < barCount; i++) { - const index = Math.floor(i * visualizerBufferLength / barCount); - const value = visualizerDataArray[index]; - - // Use logarithmic scale for better audio visualization - // This makes low values more visible while still maintaining full range - const logFactor = 20; - const scaledValue = Math.log(1 + (value / 255) * logFactor) / Math.log(1 + logFactor); - const barHeight = scaledValue * height; - - // Position bars - const x = i * (barWidth + 1); - const y = height - barHeight; - - // Create color gradient based on frequency and amplitude - const hue = i / barCount * 360; // Full color spectrum - const saturation = 80 + (value / 255 * 20); // Higher values more saturated - const lightness = 40 + (value / 255 * 20); // Dynamic brightness based on amplitude - - // Draw main bar - canvasContext.fillStyle = `hsl(${hue}, ${saturation}%, ${lightness}%)`; - canvasContext.fillRect(x, y, barWidth, barHeight); - - // Add reflection effect - if (barHeight > 5) { - const gradient = canvasContext.createLinearGradient( - x, y, - x, y + barHeight * 0.5 - ); - gradient.addColorStop(0, `hsla(${hue}, ${saturation}%, ${lightness + 20}%, 0.4)`); - gradient.addColorStop(1, `hsla(${hue}, ${saturation}%, ${lightness}%, 0)`); - canvasContext.fillStyle = gradient; - canvasContext.fillRect(x, y, barWidth, barHeight * 0.5); - - // Add highlight on top of the bar for better 3D effect - canvasContext.fillStyle = `hsla(${hue}, ${saturation - 20}%, ${lightness + 30}%, 0.7)`; - canvasContext.fillRect(x, y, barWidth, 2); - } - } - - // Show/hide the label - elements.visualizerLabel.style.opacity = (state.isStreaming) ? 
'0' : '0.7'; -} - -// Toggle visualizer visibility -function toggleVisualizerVisibility() { - const isVisible = elements.showVisualizer.checked; - elements.visualizerCanvas.style.opacity = isVisible ? '1' : '0'; - - if (isVisible && state.isStreaming && !state.visualizerAnimationFrame) { - drawVisualizer(); - } -} - -// Handle audio response from server -function handleAudioResponse(data) { - console.log('Received audio response'); - - // Create message container - const messageElement = document.createElement('div'); - messageElement.className = 'message ai'; - - // Add text content if available - if (data.text) { - const textElement = document.createElement('p'); - textElement.textContent = data.text; - messageElement.appendChild(textElement); - } - - // Create and configure audio element - const audioElement = document.createElement('audio'); - audioElement.controls = true; - audioElement.className = 'audio-player'; - - // Set audio source - const audioSource = document.createElement('source'); - audioSource.src = data.audio; - audioSource.type = 'audio/wav'; - - // Add fallback text - audioElement.textContent = 'Your browser does not support the audio element.'; - - // Assemble audio element - audioElement.appendChild(audioSource); - messageElement.appendChild(audioElement); - - // Add timestamp - const timeElement = document.createElement('span'); - timeElement.className = 'message-time'; - timeElement.textContent = new Date().toLocaleTimeString(); - messageElement.appendChild(timeElement); - - // Add to conversation - elements.conversation.appendChild(messageElement); - - // Auto-scroll to bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - // Auto-play if enabled - if (elements.autoPlayResponses.checked) { - audioElement.play() - .catch(err => { - console.warn('Auto-play failed:', err); - addSystemMessage('Auto-play failed. 
Please click play to hear the response.'); - }); - } - - // Re-enable stream button after processing is complete - if (state.isStreaming) { - elements.streamButton.innerHTML = ' Listening...'; - elements.streamButton.classList.add('recording'); - elements.streamButton.classList.remove('processing'); - } -} - -// Handle transcription response from server -function handleTranscription(data) { - console.log('Received transcription:', data.text); - - // Create message element - const messageElement = document.createElement('div'); - messageElement.className = 'message user'; - - // Add text content - const textElement = document.createElement('p'); - textElement.textContent = data.text; - messageElement.appendChild(textElement); - - // Add timestamp - const timeElement = document.createElement('span'); - timeElement.className = 'message-time'; - timeElement.textContent = new Date().toLocaleTimeString(); - messageElement.appendChild(timeElement); - - // Add to conversation - elements.conversation.appendChild(messageElement); - - // Auto-scroll to bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; -} - -// Handle context update from server -function handleContextUpdate(data) { - console.log('Context updated:', data.message); -} - -// Handle streaming status updates from server -function handleStreamingStatus(data) { - console.log('Streaming status:', data.status); - - if (data.status === 'stopped') { - // Reset UI if needed - if (state.isStreaming) { - stopStreaming(false); // Don't send to server since this came from server - } - } -} - -// Add a system message to the conversation -function addSystemMessage(message) { - const messageElement = document.createElement('div'); - messageElement.className = 'message system'; - messageElement.textContent = message; - elements.conversation.appendChild(messageElement); - - // Auto-scroll to bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; -} - -// Downsample audio buffer to target sample rate -function downsampleBuffer(buffer, originalSampleRate, targetSampleRate) { - if (originalSampleRate === targetSampleRate) { - return buffer; - } - - const ratio = originalSampleRate / targetSampleRate; - const newLength = Math.round(buffer.length / ratio); - const result = new Float32Array(newLength); - - for (let i = 0; i < newLength; i++) { - const pos = Math.round(i * ratio); - result[i] = buffer[pos]; - } - - return result; -} - -// Handle processing status updates -function handleProcessingStatus(data) { - console.log('Processing status update:', data); - - // Show processing status in UI - if (data.status === 'generating_audio') { - elements.streamButton.innerHTML = ' Processing...'; - elements.streamButton.classList.add('processing'); - elements.streamButton.classList.remove('recording'); - - // Show message to user - addSystemMessage(data.message || 'Processing your request...'); - } -} - -// Handle the start of an audio streaming response -function handleAudioResponseStart(data) { - console.log('Audio response starting:', data); - - // Reset streaming audio state - streamingAudio.chunks = []; - streamingAudio.totalChunks = data.total_chunks; - streamingAudio.receivedChunks = 0; - streamingAudio.text = data.text; - streamingAudio.complete = false; - - // Create message container now, so we can update it as chunks arrive - const messageElement = document.createElement('div'); - messageElement.className = 'message ai processing'; - - // Add text content if available - if (data.text) { - const 
textElement = document.createElement('p'); - textElement.textContent = data.text; - messageElement.appendChild(textElement); - } - - // Create audio element (will be populated as chunks arrive) - const audioElement = document.createElement('audio'); - audioElement.controls = true; - audioElement.className = 'audio-player'; - audioElement.textContent = 'Audio is being generated...'; - messageElement.appendChild(audioElement); - - // Add timestamp - const timeElement = document.createElement('span'); - timeElement.className = 'message-time'; - timeElement.textContent = new Date().toLocaleTimeString(); - messageElement.appendChild(timeElement); - - // Add loading indicator - const loadingElement = document.createElement('div'); - loadingElement.className = 'loading-indicator'; - loadingElement.innerHTML = '
Generating audio response...'; - messageElement.appendChild(loadingElement); - - // Add to conversation - elements.conversation.appendChild(messageElement); - - // Auto-scroll to bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - // Store elements for later updates - streamingAudio.messageElement = messageElement; - streamingAudio.audioElement = audioElement; -} - -// Handle an incoming audio chunk -function handleAudioResponseChunk(data) { - console.log(`Received audio chunk ${data.chunk_index + 1}/${data.total_chunks}`); - - // Store the chunk - streamingAudio.chunks[data.chunk_index] = data.audio; - streamingAudio.receivedChunks++; - - // Update progress in the UI - if (streamingAudio.messageElement) { - const loadingElement = streamingAudio.messageElement.querySelector('.loading-indicator span'); - if (loadingElement) { - loadingElement.textContent = `Generating audio response... ${Math.round((streamingAudio.receivedChunks / data.total_chunks) * 100)}%`; - } - } - - // If this is the first chunk, start playing it immediately for faster response - if (data.chunk_index === 0 && streamingAudio.audioElement && elements.autoPlayResponses && elements.autoPlayResponses.checked) { - try { - streamingAudio.audioElement.src = data.audio; - streamingAudio.audioElement.play().catch(err => console.warn('Auto-play failed:', err)); - } catch (e) { - console.error('Error playing first chunk:', e); - } - } - - // If this is the last chunk or we've received all chunks, finalize the audio - if (data.is_last || streamingAudio.receivedChunks >= data.total_chunks) { - finalizeStreamingAudio(); - } -} - -// Handle completion of audio streaming -function handleAudioResponseComplete(data) { - console.log('Audio response complete:', data); - streamingAudio.complete = true; - - // Make sure we finalize the audio even if some chunks were missed - finalizeStreamingAudio(); - - // Update UI to normal state - if (state.isStreaming) { - elements.streamButton.innerHTML = ' Listening...'; - elements.streamButton.classList.add('recording'); - elements.streamButton.classList.remove('processing'); - } -} - -// Finalize streaming audio by combining chunks and updating the UI -function finalizeStreamingAudio() { - if (!streamingAudio.messageElement || streamingAudio.chunks.length === 0) { - return; - } - - try { - // For more sophisticated audio streaming, you would need to properly concatenate - // the WAV files, but for now we'll use the last chunk as the complete audio - // since it should contain the entire response due to how the server is implementing it - const lastChunkIndex = streamingAudio.chunks.length - 1; - const audioData = streamingAudio.chunks[lastChunkIndex] || streamingAudio.chunks[0]; - - // Update the audio element with the complete audio - if (streamingAudio.audioElement) { - streamingAudio.audioElement.src = audioData; - - // Auto-play if enabled and not already playing - if (elements.autoPlayResponses && elements.autoPlayResponses.checked && - streamingAudio.audioElement.paused) { - streamingAudio.audioElement.play() - .catch(err => { - console.warn('Auto-play failed:', err); - addSystemMessage('Auto-play failed. 
Please click play to hear the response.'); - }); - } - } - - // Remove loading indicator and processing class - if (streamingAudio.messageElement) { - const loadingElement = streamingAudio.messageElement.querySelector('.loading-indicator'); - if (loadingElement) { - streamingAudio.messageElement.removeChild(loadingElement); - } - streamingAudio.messageElement.classList.remove('processing'); - } - - console.log('Audio response finalized and ready for playback'); - } catch (e) { - console.error('Error finalizing streaming audio:', e); - } - - // Reset streaming audio state - streamingAudio.chunks = []; - streamingAudio.totalChunks = 0; - streamingAudio.receivedChunks = 0; - streamingAudio.messageElement = null; - streamingAudio.audioElement = null; -} - -// Add CSS styles for new UI elements -document.addEventListener('DOMContentLoaded', function() { - // Add styles for processing state - const style = document.createElement('style'); - style.textContent = ` - .message.processing { - opacity: 0.8; - } - - .loading-indicator { - display: flex; - align-items: center; - margin-top: 8px; - font-size: 0.9em; - color: #666; - } - - .loading-spinner { - width: 16px; - height: 16px; - border: 2px solid #ddd; - border-top: 2px solid var(--primary-color); - border-radius: 50%; - margin-right: 8px; - animation: spin 1s linear infinite; - } - - @keyframes spin { - 0% { transform: rotate(0deg); } - 100% { transform: rotate(360deg); } - } - `; - document.head.appendChild(style); -}); - -// Initialize the application when DOM is fully loaded -document.addEventListener('DOMContentLoaded', initializeApp); - diff --git a/Backend/watermarking.py b/Backend/watermarking.py deleted file mode 100644 index 093962f..0000000 --- a/Backend/watermarking.py +++ /dev/null @@ -1,79 +0,0 @@ -import argparse - -import silentcipher -import torch -import torchaudio - -# This watermark key is public, it is not secure. -# If using CSM 1B in another application, use a new private key and keep it secret. 
-CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201] - - -def cli_check_audio() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--audio_path", type=str, required=True) - args = parser.parse_args() - - check_audio_from_file(args.audio_path) - - -def load_watermarker(device: str = "cuda") -> silentcipher.server.Model: - model = silentcipher.get_model( - model_type="44.1k", - device=device, - ) - return model - - -@torch.inference_mode() -def watermark( - watermarker: silentcipher.server.Model, - audio_array: torch.Tensor, - sample_rate: int, - watermark_key: list[int], -) -> tuple[torch.Tensor, int]: - audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100) - encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36) - - output_sample_rate = min(44100, sample_rate) - encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate) - return encoded, output_sample_rate - - -@torch.inference_mode() -def verify( - watermarker: silentcipher.server.Model, - watermarked_audio: torch.Tensor, - sample_rate: int, - watermark_key: list[int], -) -> bool: - watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100) - result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True) - - is_watermarked = result["status"] - if is_watermarked: - is_csm_watermarked = result["messages"][0] == watermark_key - else: - is_csm_watermarked = False - - return is_watermarked and is_csm_watermarked - - -def check_audio_from_file(audio_path: str) -> None: - watermarker = load_watermarker(device="cuda") - - audio_array, sample_rate = load_audio(audio_path) - is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK) - - outcome = "Watermarked" if is_watermarked else "Not watermarked" - print(f"{outcome}: {audio_path}") - - -def load_audio(audio_path: str) -> tuple[torch.Tensor, int]: - audio_array, sample_rate = torchaudio.load(audio_path) - audio_array = audio_array.mean(dim=0) - return audio_array, int(sample_rate) - - -if __name__ == "__main__": - cli_check_audio() From 9ad22b4fe48463d7933df19d4ae6c577273540f8 Mon Sep 17 00:00:00 2001 From: Surya Vemulapalli Date: Sun, 30 Mar 2025 01:29:35 -0400 Subject: [PATCH 03/30] Finished styling the call button --- React/src/app/page.tsx | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/React/src/app/page.tsx b/React/src/app/page.tsx index 0927025..1dcc93b 100644 --- a/React/src/app/page.tsx +++ b/React/src/app/page.tsx @@ -69,16 +69,6 @@ export default async function Home() { - -

- - - -

); } @@ -146,7 +136,11 @@ export default async function Home() { - +

From a0ee0685dc090fc322b5f0535db17c8ae8a74be6 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 01:30:14 -0400 Subject: [PATCH 04/30] Demo Update 11 --- Backend/src/models/conversation.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/Backend/src/models/conversation.py b/Backend/src/models/conversation.py index 4558958..25d1a70 100644 --- a/Backend/src/models/conversation.py +++ b/Backend/src/models/conversation.py @@ -1,23 +1,51 @@ from dataclasses import dataclass, field from typing import List, Optional +import torch + +@dataclass +class Segment: + speaker: int + text: str + # (num_samples,), sample_rate = 24_000 + audio: Optional[torch.Tensor] = None + + def __post_init__(self): + # Ensure audio is a tensor if provided + if self.audio is not None and not isinstance(self.audio, torch.Tensor): + self.audio = torch.tensor(self.audio, dtype=torch.float32) @dataclass class Conversation: context: List[str] = field(default_factory=list) + segments: List[Segment] = field(default_factory=list) current_speaker: Optional[int] = None def add_message(self, message: str, speaker: int): self.context.append(f"Speaker {speaker}: {message}") self.current_speaker = speaker + def add_segment(self, segment: Segment): + self.segments.append(segment) + self.context.append(f"Speaker {segment.speaker}: {segment.text}") + self.current_speaker = segment.speaker + def get_context(self) -> List[str]: return self.context + + def get_segments(self) -> List[Segment]: + return self.segments def clear_context(self): self.context.clear() + self.segments.clear() self.current_speaker = None def get_last_message(self) -> Optional[str]: if self.context: return self.context[-1] + return None + + def get_last_segment(self) -> Optional[Segment]: + if self.segments: + return self.segments[-1] return None \ No newline at end of file From df1595cd10eff16de3bd04b62b4e1de3e9d6ee79 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 01:46:11 -0400 Subject: [PATCH 05/30] Complete Refactor 2 --- Backend/.gitignore | 46 - Backend/README.md | 71 -- Backend/api/app.py | 22 - Backend/api/routes.py | 29 - Backend/api/socket_handlers.py | 32 - Backend/config.py | 13 - Backend/{src/llm => }/generator.py | 16 +- Backend/index.html | 711 +++++++++++ Backend/models.py | 203 ++++ Backend/requirements.txt | 25 +- Backend/run_csm.py | 117 ++ Backend/server.py | 451 ++++++- Backend/setup.py | 13 + Backend/src/audio/processor.py | 28 - Backend/src/audio/streaming.py | 35 - Backend/src/llm/tokenizer.py | 14 - Backend/src/models/audio_model.py | 28 - Backend/src/models/conversation.py | 51 - Backend/src/services/transcription_service.py | 25 - Backend/src/services/tts_service.py | 24 - Backend/src/utils/config.py | 23 - Backend/src/utils/logger.py | 14 - Backend/static/css/styles.css | 105 -- Backend/static/index.html | 31 - Backend/static/js/client.js | 131 -- Backend/templates/index.html | 31 - Backend/voice-chat.js | 1071 +++++++++++++++++ Backend/watermarking.py | 79 ++ 28 files changed, 2630 insertions(+), 809 deletions(-) delete mode 100644 Backend/.gitignore delete mode 100644 Backend/README.md delete mode 100644 Backend/api/app.py delete mode 100644 Backend/api/routes.py delete mode 100644 Backend/api/socket_handlers.py delete mode 100644 Backend/config.py rename Backend/{src/llm => }/generator.py (90%) create mode 100644 Backend/index.html create mode 100644 Backend/models.py create mode 100644 Backend/run_csm.py create mode 100644 Backend/setup.py delete mode 100644 
Backend/src/audio/processor.py delete mode 100644 Backend/src/audio/streaming.py delete mode 100644 Backend/src/llm/tokenizer.py delete mode 100644 Backend/src/models/audio_model.py delete mode 100644 Backend/src/models/conversation.py delete mode 100644 Backend/src/services/transcription_service.py delete mode 100644 Backend/src/services/tts_service.py delete mode 100644 Backend/src/utils/config.py delete mode 100644 Backend/src/utils/logger.py delete mode 100644 Backend/static/css/styles.css delete mode 100644 Backend/static/index.html delete mode 100644 Backend/static/js/client.js delete mode 100644 Backend/templates/index.html create mode 100644 Backend/voice-chat.js create mode 100644 Backend/watermarking.py diff --git a/Backend/.gitignore b/Backend/.gitignore deleted file mode 100644 index 4b7fc9d..0000000 --- a/Backend/.gitignore +++ /dev/null @@ -1,46 +0,0 @@ -# Python -__pycache__/ -*.py[cod] -*$py.class -*.so -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg - -# Virtual Environment -.env -.venv -env/ -venv/ -ENV/ - -# IDE -.idea/ -.vscode/ -*.swp -*.swo - -# Project specific -.python-version -*.wav -output_*/ -basic_audio.wav -full_conversation.wav -context_audio.wav - -# Model files -*.pt -*.ckpt \ No newline at end of file diff --git a/Backend/README.md b/Backend/README.md deleted file mode 100644 index 8438073..0000000 --- a/Backend/README.md +++ /dev/null @@ -1,71 +0,0 @@ -# csm-conversation-bot - -## Overview -The CSM Conversation Bot is an application that utilizes advanced audio processing and language model technologies to facilitate real-time voice conversations with an AI assistant. The bot processes audio streams, converts spoken input into text, generates responses using the Llama 3.2 model, and converts the text back into audio for seamless interaction. - -## Project Structure -``` -csm-conversation-bot -├── api -│ ├── app.py # Main entry point for the API -│ ├── routes.py # Defines API routes -│ └── socket_handlers.py # Manages Socket.IO events -├── src -│ ├── audio -│ │ ├── processor.py # Audio processing functions -│ │ └── streaming.py # Audio streaming management -│ ├── llm -│ │ ├── generator.py # Response generation using Llama 3.2 -│ │ └── tokenizer.py # Text tokenization functions -│ ├── models -│ │ ├── audio_model.py # Audio processing model -│ │ └── conversation.py # Conversation state management -│ ├── services -│ │ ├── transcription_service.py # Audio to text conversion -│ │ └── tts_service.py # Text to speech conversion -│ └── utils -│ ├── config.py # Configuration settings -│ └── logger.py # Logging utilities -├── static -│ ├── css -│ │ └── styles.css # CSS styles for the web interface -│ ├── js -│ │ └── client.js # Client-side JavaScript -│ └── index.html # Main HTML file for the web interface -├── templates -│ └── index.html # Template for rendering the main HTML page -├── config.py # Main configuration settings -├── requirements.txt # Python dependencies -├── server.py # Entry point for running the application -└── README.md # Documentation for the project -``` - -## Installation -1. Clone the repository: - ``` - git clone https://github.com/yourusername/csm-conversation-bot.git - cd csm-conversation-bot - ``` - -2. Install the required dependencies: - ``` - pip install -r requirements.txt - ``` - -3. Configure the application settings in `config.py` as needed. - -## Usage -1. Start the server: - ``` - python server.py - ``` - -2. 
Open your web browser and navigate to `http://localhost:5000` to access the application. - -3. Use the interface to start a conversation with the AI assistant. - -## Contributing -Contributions are welcome! Please submit a pull request or open an issue for any enhancements or bug fixes. - -## License -This project is licensed under the MIT License. See the LICENSE file for more details. \ No newline at end of file diff --git a/Backend/api/app.py b/Backend/api/app.py deleted file mode 100644 index d0f2c05..0000000 --- a/Backend/api/app.py +++ /dev/null @@ -1,22 +0,0 @@ -from flask import Flask -from flask_socketio import SocketIO -from src.utils.config import Config -from src.utils.logger import setup_logger -from api.routes import setup_routes -from api.socket_handlers import setup_socket_handlers - -def create_app(): - app = Flask(__name__) - app.config.from_object(Config) - - setup_logger(app) - setup_routes(app) - setup_socket_handlers(app) - - return app - -app = create_app() -socketio = SocketIO(app) - -if __name__ == "__main__": - socketio.run(app, host='0.0.0.0', port=5000) \ No newline at end of file diff --git a/Backend/api/routes.py b/Backend/api/routes.py deleted file mode 100644 index 4ec8a7c..0000000 --- a/Backend/api/routes.py +++ /dev/null @@ -1,29 +0,0 @@ -from flask import Blueprint, request, jsonify -from src.services.transcription_service import TranscriptionService -from src.services.tts_service import TextToSpeechService - -api = Blueprint('api', __name__) - -transcription_service = TranscriptionService() -tts_service = TextToSpeechService() - -@api.route('/transcribe', methods=['POST']) -def transcribe_audio(): - audio_data = request.files.get('audio') - if not audio_data: - return jsonify({'error': 'No audio file provided'}), 400 - - text = transcription_service.transcribe(audio_data) - return jsonify({'transcription': text}) - -@api.route('/generate-response', methods=['POST']) -def generate_response(): - data = request.json - user_input = data.get('input') - if not user_input: - return jsonify({'error': 'No input provided'}), 400 - - response_text = tts_service.generate_response(user_input) - audio_data = tts_service.text_to_speech(response_text) - - return jsonify({'response': response_text, 'audio': audio_data}) \ No newline at end of file diff --git a/Backend/api/socket_handlers.py b/Backend/api/socket_handlers.py deleted file mode 100644 index f80ba96..0000000 --- a/Backend/api/socket_handlers.py +++ /dev/null @@ -1,32 +0,0 @@ -from flask import request -from flask_socketio import SocketIO, emit -from src.audio.processor import process_audio -from src.services.transcription_service import TranscriptionService -from src.services.tts_service import TextToSpeechService -from src.llm.generator import load_csm_1b - -socketio = SocketIO() - -transcription_service = TranscriptionService() -tts_service = TextToSpeechService() -generator = load_csm_1b() - -@socketio.on('audio_stream') -def handle_audio_stream(data): - audio_data = data['audio'] - speaker_id = data['speaker'] - - # Process the incoming audio - processed_audio = process_audio(audio_data) - - # Transcribe the audio to text - transcription = transcription_service.transcribe(processed_audio) - - # Generate a response using the LLM - response_text = generator.generate(text=transcription, speaker=speaker_id) - - # Convert the response text back to audio - response_audio = tts_service.convert_text_to_speech(response_text) - - # Emit the response audio back to the client - emit('audio_response', {'audio': 
response_audio, 'speaker': speaker_id}) \ No newline at end of file diff --git a/Backend/config.py b/Backend/config.py deleted file mode 100644 index f23a0b5..0000000 --- a/Backend/config.py +++ /dev/null @@ -1,13 +0,0 @@ -from pathlib import Path - -class Config: - def __init__(self): - self.MODEL_PATH = Path("path/to/your/model") - self.AUDIO_MODEL_PATH = Path("path/to/your/audio/model") - self.WATERMARK_KEY = "your_watermark_key" - self.SOCKETIO_CORS = "*" - self.API_KEY = "your_api_key" - self.DEBUG = True - self.LOGGING_LEVEL = "INFO" - self.TTS_SERVICE_URL = "http://localhost:5001/tts" - self.TRANSCRIPTION_SERVICE_URL = "http://localhost:5002/transcribe" \ No newline at end of file diff --git a/Backend/src/llm/generator.py b/Backend/generator.py similarity index 90% rename from Backend/src/llm/generator.py rename to Backend/generator.py index ce4297c..7bc3634 100644 --- a/Backend/src/llm/generator.py +++ b/Backend/generator.py @@ -15,10 +15,14 @@ from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark class Segment: speaker: int text: str + # (num_samples,), sample_rate = 24_000 audio: torch.Tensor def load_llama3_tokenizer(): + """ + https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992 + """ tokenizer_name = "meta-llama/Llama-3.2-1B" tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) bos = tokenizer.bos_token @@ -74,8 +78,10 @@ class Generator: frame_tokens = [] frame_masks = [] + # (K, T) audio = audio.to(self.device) audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0] + # add EOS frame eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device) audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1) @@ -90,6 +96,10 @@ class Generator: return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0) def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns: + (seq_len, 33), (seq_len, 33) + """ text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker) audio_tokens, audio_masks = self._tokenize_audio(segment.audio) @@ -136,7 +146,7 @@ class Generator: for _ in range(max_generation_len): sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk) if torch.all(sample == 0): - break + break # eos samples.append(sample) @@ -148,6 +158,10 @@ class Generator: audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0) + # This applies an imperceptible watermark to identify audio as AI-generated. + # Watermarking ensures transparency, dissuades misuse, and enables traceability. + # Please be a responsible AI citizen and keep the watermarking in place. + # If using CSM 1B in another application, use your own private key and keep it secret. audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK) audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate) diff --git a/Backend/index.html b/Backend/index.html new file mode 100644 index 0000000..6169390 --- /dev/null +++ b/Backend/index.html @@ -0,0 +1,711 @@ + + + + + + CSM Voice Chat + + + + + + +

+ [Markup of the new Backend/index.html is not preserved in this view; only the page's visible text survives: a "CSM Voice Chat" header with the tagline "Talk naturally with the AI using your voice", a "Conversation" panel with a "Disconnected" status, a "Controls" panel ("Click the button below to start and stop recording."), a "Settings" panel, and the footer "Powered by CSM 1B & Llama 3.2 | Whisper for speech recognition".]
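The page above is only a thin front end over the Socket.IO events exposed by server.py later in this patch ('audio_chunk' in; 'transcription', 'audio_response_chunk', and 'audio_response_complete' out). As a rough reference for that event flow, here is a minimal headless client sketch — an illustration only, not part of the patch — assuming the python-socketio client package, a server from this patch on localhost:5000, and placeholder file names:

```python
# Minimal headless client sketch (illustration only, not part of this patch).
# Assumes: pip install "python-socketio[client]" and the server from this patch on localhost:5000.
import base64
import socketio

sio = socketio.Client()

@sio.on('transcription')
def on_transcription(data):
    # Server echoes back what Whisper heard: {'text': ..., 'speaker': ...}
    print(f"You said: {data['text']} (speaker {data['speaker']})")

@sio.on('audio_response_chunk')
def on_audio_chunk(data):
    # 'chunk' is a base64 data URL ("data:audio/wav;base64,...")
    wav_bytes = base64.b64decode(data['chunk'].split(',')[1])
    with open('reply.wav', 'wb') as f:   # hypothetical output path
        f.write(wav_bytes)

@sio.on('audio_response_complete')
def on_complete(data):
    print(f"Assistant: {data['text']}")
    sio.disconnect()

sio.connect('http://localhost:5000')
with open('question.wav', 'rb') as f:    # hypothetical mono WAV recording of the user's question
    payload = 'data:audio/wav;base64,' + base64.b64encode(f.read()).decode('utf-8')
sio.emit('audio_chunk', {'audio': payload, 'speaker': 0})
sio.wait()
```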
+ + + + \ No newline at end of file diff --git a/Backend/models.py b/Backend/models.py new file mode 100644 index 0000000..e9508e7 --- /dev/null +++ b/Backend/models.py @@ -0,0 +1,203 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torchtune +from huggingface_hub import PyTorchModelHubMixin +from torchtune.models import llama3_2 + + +def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder: + return llama3_2.llama3_2( + vocab_size=128_256, + num_layers=16, + num_heads=32, + num_kv_heads=8, + embed_dim=2048, + max_seq_len=2048, + intermediate_dim=8192, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500_000, + scale_factor=32, + ) + + +def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder: + return llama3_2.llama3_2( + vocab_size=128_256, + num_layers=4, + num_heads=8, + num_kv_heads=2, + embed_dim=1024, + max_seq_len=2048, + intermediate_dim=8192, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500_000, + scale_factor=32, + ) + + +FLAVORS = { + "llama-1B": llama3_2_1B, + "llama-100M": llama3_2_100M, +} + + +def _prepare_transformer(model): + embed_dim = model.tok_embeddings.embedding_dim + model.tok_embeddings = nn.Identity() + model.output = nn.Identity() + return model, embed_dim + + +def _create_causal_mask(seq_len: int, device: torch.device): + return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)) + + +def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor): + """ + Args: + mask: (max_seq_len, max_seq_len) + input_pos: (batch_size, seq_len) + + Returns: + (batch_size, seq_len, max_seq_len) + """ + r = mask[input_pos, :] + return r + + +def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs).exponential_(1) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def sample_topk(logits: torch.Tensor, topk: int, temperature: float): + logits = logits / temperature + + filter_value: float = -float("Inf") + indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None] + scores_processed = logits.masked_fill(indices_to_remove, filter_value) + scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1) + probs = torch.nn.functional.softmax(scores_processed, dim=-1) + + sample_token = _multinomial_sample_one_no_sync(probs) + return sample_token + + +@dataclass +class ModelArgs: + backbone_flavor: str + decoder_flavor: str + text_vocab_size: int + audio_vocab_size: int + audio_num_codebooks: int + + +class Model( + nn.Module, + PyTorchModelHubMixin, + repo_url="https://github.com/SesameAILabs/csm", + pipeline_tag="text-to-speech", + license="apache-2.0", +): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + + self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]()) + self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]()) + + self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim) + self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim) + + self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False) + self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False) + self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)) + + def setup_caches(self, max_batch_size: int) -> torch.Tensor: + """Setup KV caches and return a 
causal mask.""" + dtype = next(self.parameters()).dtype + device = next(self.parameters()).device + + with device: + self.backbone.setup_caches(max_batch_size, dtype) + self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks) + + self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device)) + self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device)) + + def generate_frame( + self, + tokens: torch.Tensor, + tokens_mask: torch.Tensor, + input_pos: torch.Tensor, + temperature: float, + topk: int, + ) -> torch.Tensor: + """ + Args: + tokens: (batch_size, seq_len, audio_num_codebooks+1) + tokens_mask: (batch_size, seq_len, audio_num_codebooks+1) + input_pos: (batch_size, seq_len) positions for each token + mask: (batch_size, seq_len, max_seq_len + + Returns: + (batch_size, audio_num_codebooks) sampled tokens + """ + dtype = next(self.parameters()).dtype + b, s, _ = tokens.size() + + assert self.backbone.caches_are_enabled(), "backbone caches are not enabled" + curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos) + embeds = self._embed_tokens(tokens) + masked_embeds = embeds * tokens_mask.unsqueeze(-1) + h = masked_embeds.sum(dim=2) + h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype) + + last_h = h[:, -1, :] + c0_logits = self.codebook0_head(last_h) + c0_sample = sample_topk(c0_logits, topk, temperature) + c0_embed = self._embed_audio(0, c0_sample) + + curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1) + curr_sample = c0_sample.clone() + curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1) + + # Decoder caches must be reset every frame. 
+ self.decoder.reset_caches() + for i in range(1, self.config.audio_num_codebooks): + curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos) + decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to( + dtype=dtype + ) + ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1]) + ci_sample = sample_topk(ci_logits, topk, temperature) + ci_embed = self._embed_audio(i, ci_sample) + + curr_h = ci_embed + curr_sample = torch.cat([curr_sample, ci_sample], dim=1) + curr_pos = curr_pos[:, -1:] + 1 + + return curr_sample + + def reset_caches(self): + self.backbone.reset_caches() + self.decoder.reset_caches() + + def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor: + return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size) + + def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2) + + audio_tokens = tokens[:, :, :-1] + ( + self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device) + ) + audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape( + tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1 + ) + + return torch.cat([audio_embeds, text_embeds], dim=-2) diff --git a/Backend/requirements.txt b/Backend/requirements.txt index ef7beab..ba8a04f 100644 --- a/Backend/requirements.txt +++ b/Backend/requirements.txt @@ -1,16 +1,9 @@ -Flask==2.2.2 -Flask-SocketIO==5.3.2 -torch>=2.0.0 -torchaudio>=2.0.0 -transformers>=4.30.0 -huggingface-hub>=0.14.0 -python-dotenv==0.19.2 -numpy>=1.21.6 -scipy>=1.7.3 -soundfile==0.10.3.post1 -requests==2.28.1 -pydub==0.25.1 -python-socketio==5.7.2 -eventlet==0.33.3 -whisper>=20230314 -ffmpeg-python>=0.2.0 \ No newline at end of file +torch==2.4.0 +torchaudio==2.4.0 +tokenizers==0.21.0 +transformers==4.49.0 +huggingface_hub==0.28.1 +moshi==0.2.2 +torchtune==0.4.0 +torchao==0.9.0 +silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master \ No newline at end of file diff --git a/Backend/run_csm.py b/Backend/run_csm.py new file mode 100644 index 0000000..0062973 --- /dev/null +++ b/Backend/run_csm.py @@ -0,0 +1,117 @@ +import os +import torch +import torchaudio +from huggingface_hub import hf_hub_download +from generator import load_csm_1b, Segment +from dataclasses import dataclass + +# Disable Triton compilation +os.environ["NO_TORCH_COMPILE"] = "1" + +# Default prompts are available at https://hf.co/sesame/csm-1b +prompt_filepath_conversational_a = hf_hub_download( + repo_id="sesame/csm-1b", + filename="prompts/conversational_a.wav" +) +prompt_filepath_conversational_b = hf_hub_download( + repo_id="sesame/csm-1b", + filename="prompts/conversational_b.wav" +) + +SPEAKER_PROMPTS = { + "conversational_a": { + "text": ( + "like revising for an exam I'd have to try and like keep up the momentum because I'd " + "start really early I'd be like okay I'm gonna start revising now and then like " + "you're revising for ages and then I just like start losing steam I didn't do that " + "for the exam we had recently to be fair that was a more of a last minute scenario " + "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I " + "sort of start the day with this not like a panic but like a" + ), + "audio": prompt_filepath_conversational_a + }, + "conversational_b": { + "text": ( + "like a super Mario level. Like it's very like high detail. 
And like, once you get " + "into the park, it just like, everything looks like a computer game and they have all " + "these, like, you know, if, if there's like a, you know, like in a Mario game, they " + "will have like a question block. And if you like, you know, punch it, a coin will " + "come out. So like everyone, when they come into the park, they get like this little " + "bracelet and then you can go punching question blocks around." + ), + "audio": prompt_filepath_conversational_b + } +} + +def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor: + audio_tensor, sample_rate = torchaudio.load(audio_path) + audio_tensor = audio_tensor.squeeze(0) + # Resample is lazy so we can always call it + audio_tensor = torchaudio.functional.resample( + audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate + ) + return audio_tensor + +def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment: + audio_tensor = load_prompt_audio(audio_path, sample_rate) + return Segment(text=text, speaker=speaker, audio=audio_tensor) + +def main(): + # Select the best available device, skipping MPS due to float64 limitations + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + print(f"Using device: {device}") + + # Load model + generator = load_csm_1b(device) + + # Prepare prompts + prompt_a = prepare_prompt( + SPEAKER_PROMPTS["conversational_a"]["text"], + 0, + SPEAKER_PROMPTS["conversational_a"]["audio"], + generator.sample_rate + ) + + prompt_b = prepare_prompt( + SPEAKER_PROMPTS["conversational_b"]["text"], + 1, + SPEAKER_PROMPTS["conversational_b"]["audio"], + generator.sample_rate + ) + + # Generate conversation + conversation = [ + {"text": "Hey how are you doing?", "speaker_id": 0}, + {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1}, + {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0}, + {"text": "Me too! 
This is some cool stuff, isn't it?", "speaker_id": 1} + ] + + # Generate each utterance + generated_segments = [] + prompt_segments = [prompt_a, prompt_b] + + for utterance in conversation: + print(f"Generating: {utterance['text']}") + audio_tensor = generator.generate( + text=utterance['text'], + speaker=utterance['speaker_id'], + context=prompt_segments + generated_segments, + max_audio_length_ms=10_000, + ) + generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)) + + # Concatenate all generations + all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0) + torchaudio.save( + "full_conversation.wav", + all_audio.unsqueeze(0).cpu(), + generator.sample_rate + ) + print("Successfully generated full_conversation.wav") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Backend/server.py b/Backend/server.py index 2069b29..ef9fbda 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -1,53 +1,426 @@ import os -import logging -import torch -import eventlet +import io import base64 +import time +import json +import uuid +import logging +import threading +import queue import tempfile -from io import BytesIO -from flask import Flask, render_template, request, jsonify -from flask_socketio import SocketIO, emit -import whisper +from typing import Dict, List, Optional, Tuple + +import torch import torchaudio -from src.models.conversation import Segment -from src.services.tts_service import load_csm_1b -from src.llm.generator import generate_llm_response -from transformers import AutoTokenizer, AutoModelForCausalLM -from src.audio.streaming import AudioStreamer -from src.services.transcription_service import TranscriptionService -from src.services.tts_service import TextToSpeechService +import numpy as np +from flask import Flask, request, jsonify, send_from_directory +from flask_socketio import SocketIO, emit +from flask_cors import CORS +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline + +from generator import load_csm_1b, Segment +from dataclasses import dataclass # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) -app = Flask(__name__, static_folder='static', template_folder='templates') -app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'your-secret-key') -socketio = SocketIO(app) +# Initialize Flask app +app = Flask(__name__, static_folder='.') +CORS(app) +socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120) -# Initialize services -transcription_service = TranscriptionService() -tts_service = TextToSpeechService() -audio_streamer = AudioStreamer() +# Configure device +if torch.cuda.is_available(): + DEVICE = "cuda" +elif torch.backends.mps.is_available(): + DEVICE = "mps" +else: + DEVICE = "cpu" -@socketio.on('audio_input') -def handle_audio_input(data): - audio_chunk = data['audio'] - speaker_id = data['speaker'] +logger.info(f"Using device: {DEVICE}") + +# Global variables +active_conversations = {} +user_queues = {} +processing_threads = {} + +# Load models +@dataclass +class AppModels: + generator = None + tokenizer = None + llm = None + asr = None + +models = AppModels() + +def load_models(): + """Load all required models""" + global models - # Process audio and convert to text - text = transcription_service.transcribe(audio_chunk) - 
logging.info(f"Transcribed text: {text}") - - # Generate response using Llama 3.2 - response_text = tts_service.generate_response(text, speaker_id) - logging.info(f"Generated response: {response_text}") - - # Convert response text to audio - audio_response = tts_service.text_to_speech(response_text, speaker_id) + logger.info("Loading CSM 1B model...") + models.generator = load_csm_1b(device=DEVICE) - # Stream audio response back to client - socketio.emit('audio_response', {'audio': audio_response}) + logger.info("Loading ASR pipeline...") + models.asr = pipeline( + "automatic-speech-recognition", + model="openai/whisper-small", + device=DEVICE + ) + + logger.info("Loading Llama 3.2 model...") + models.llm = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B", + device_map=DEVICE, + torch_dtype=torch.bfloat16 + ) + models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") +# Load models in a background thread +threading.Thread(target=load_models, daemon=True).start() + +# Conversation data structure +class Conversation: + def __init__(self, session_id): + self.session_id = session_id + self.segments: List[Segment] = [] + self.current_speaker = 0 + self.last_activity = time.time() + self.is_processing = False + + def add_segment(self, text, speaker, audio): + segment = Segment(text=text, speaker=speaker, audio=audio) + self.segments.append(segment) + self.last_activity = time.time() + return segment + + def get_context(self, max_segments=10): + """Return the most recent segments for context""" + return self.segments[-max_segments:] if self.segments else [] + +# Routes +@app.route('/') +def index(): + return send_from_directory('.', 'index.html') + +@app.route('/api/health') +def health_check(): + return jsonify({ + "status": "ok", + "cuda_available": torch.cuda.is_available(), + "models_loaded": models.generator is not None and models.llm is not None + }) + +# Socket event handlers +@socketio.on('connect') +def handle_connect(): + session_id = request.sid + logger.info(f"Client connected: {session_id}") + + # Initialize conversation data + if session_id not in active_conversations: + active_conversations[session_id] = Conversation(session_id) + user_queues[session_id] = queue.Queue() + processing_threads[session_id] = threading.Thread( + target=process_audio_queue, + args=(session_id, user_queues[session_id]), + daemon=True + ) + processing_threads[session_id].start() + + emit('connection_status', {'status': 'connected'}) + +@socketio.on('disconnect') +def handle_disconnect(): + session_id = request.sid + logger.info(f"Client disconnected: {session_id}") + + # Cleanup + if session_id in active_conversations: + # Mark for deletion rather than immediately removing + # as the processing thread might still be accessing it + active_conversations[session_id].is_processing = False + user_queues[session_id].put(None) # Signal thread to terminate + +@socketio.on('start_stream') +def handle_start_stream(): + session_id = request.sid + logger.info(f"Starting stream for client: {session_id}") + emit('streaming_status', {'status': 'active'}) + +@socketio.on('stop_stream') +def handle_stop_stream(): + session_id = request.sid + logger.info(f"Stopping stream for client: {session_id}") + emit('streaming_status', {'status': 'inactive'}) + +@socketio.on('clear_context') +def handle_clear_context(): + session_id = request.sid + logger.info(f"Clearing context for client: {session_id}") + + if session_id in active_conversations: + active_conversations[session_id].segments = [] + 
emit('context_updated', {'status': 'cleared'}) + +@socketio.on('audio_chunk') +def handle_audio_chunk(data): + session_id = request.sid + audio_data = data.get('audio', '') + speaker_id = int(data.get('speaker', 0)) + + if not audio_data or not session_id in active_conversations: + return + + # Update the current speaker + active_conversations[session_id].current_speaker = speaker_id + + # Queue audio for processing + user_queues[session_id].put({ + 'audio': audio_data, + 'speaker': speaker_id + }) + + emit('processing_status', {'status': 'transcribing'}) + +def process_audio_queue(session_id, q): + """Background thread to process audio chunks for a session""" + logger.info(f"Started processing thread for session: {session_id}") + + try: + while session_id in active_conversations: + try: + # Get the next audio chunk with a timeout + data = q.get(timeout=120) + if data is None: # Termination signal + break + + # Process the audio and generate a response + process_audio_and_respond(session_id, data) + + except queue.Empty: + # Timeout, check if session is still valid + continue + except Exception as e: + logger.error(f"Error processing audio for {session_id}: {str(e)}") + socketio.emit('error', {'message': str(e)}, room=session_id) + finally: + logger.info(f"Ending processing thread for session: {session_id}") + # Clean up when thread is done + with app.app_context(): + if session_id in active_conversations: + del active_conversations[session_id] + if session_id in user_queues: + del user_queues[session_id] + +def process_audio_and_respond(session_id, data): + """Process audio data and generate a response""" + if models.generator is None or models.asr is None or models.llm is None: + socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) + return + + conversation = active_conversations[session_id] + + try: + # Set processing flag + conversation.is_processing = True + + # Process base64 audio data + audio_data = data['audio'] + speaker_id = data['speaker'] + + # Convert from base64 to WAV + audio_bytes = base64.b64decode(audio_data.split(',')[1]) + + # Save to temporary file for processing + with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: + temp_file.write(audio_bytes) + temp_path = temp_file.name + + try: + # Load audio file + waveform, sample_rate = torchaudio.load(temp_path) + + # Normalize to mono if needed + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + # Resample to the CSM sample rate if needed + if sample_rate != models.generator.sample_rate: + waveform = torchaudio.functional.resample( + waveform, + orig_freq=sample_rate, + new_freq=models.generator.sample_rate + ) + + # Transcribe audio + socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) + + # Use the ASR pipeline to transcribe + transcription_result = models.asr( + {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate}, + return_timestamps=False + ) + user_text = transcription_result['text'].strip() + + # If no text was recognized, don't process further + if not user_text: + socketio.emit('error', {'message': 'No speech detected'}, room=session_id) + return + + # Add the user's message to conversation history + user_segment = conversation.add_segment( + text=user_text, + speaker=speaker_id, + audio=waveform.squeeze() + ) + + # Send transcription to client + socketio.emit('transcription', { + 'text': user_text, + 'speaker': speaker_id + }, room=session_id) + + # 
Generate AI response using Llama + socketio.emit('processing_status', {'status': 'generating'}, room=session_id) + + # Create prompt from conversation history + conversation_history = "" + for segment in conversation.segments[-5:]: # Last 5 segments for context + role = "User" if segment.speaker == 0 else "Assistant" + conversation_history += f"{role}: {segment.text}\n" + + # Add final prompt + prompt = f"{conversation_history}Assistant: " + + # Generate response with Llama + input_ids = models.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE) + + with torch.no_grad(): + generated_ids = models.llm.generate( + input_ids, + max_new_tokens=100, + temperature=0.7, + top_p=0.9, + do_sample=True, + pad_token_id=models.tokenizer.eos_token_id + ) + + # Decode the response + response_text = models.tokenizer.decode( + generated_ids[0][input_ids.shape[1]:], + skip_special_tokens=True + ).strip() + + # Synthesize speech + socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id) + + # Generate audio with CSM + ai_speaker_id = 1 # Use speaker 1 for AI responses + + # Start sending the audio response + socketio.emit('audio_response_start', { + 'text': response_text, + 'total_chunks': 1, + 'chunk_index': 0 + }, room=session_id) + + # Generate audio + audio_tensor = models.generator.generate( + text=response_text, + speaker=ai_speaker_id, + context=conversation.get_context(), + max_audio_length_ms=10_000, + temperature=0.9 + ) + + # Add AI response to conversation history + ai_segment = conversation.add_segment( + text=response_text, + speaker=ai_speaker_id, + audio=audio_tensor + ) + + # Convert audio to WAV format + with io.BytesIO() as wav_io: + torchaudio.save( + wav_io, + audio_tensor.unsqueeze(0).cpu(), + models.generator.sample_rate, + format="wav" + ) + wav_io.seek(0) + wav_data = wav_io.read() + + # Convert WAV data to base64 + audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}" + + # Send audio chunk to client + socketio.emit('audio_response_chunk', { + 'chunk': audio_base64, + 'chunk_index': 0, + 'total_chunks': 1, + 'is_last': True + }, room=session_id) + + # Signal completion + socketio.emit('audio_response_complete', { + 'text': response_text + }, room=session_id) + + finally: + # Clean up temp file + if os.path.exists(temp_path): + os.unlink(temp_path) + + except Exception as e: + logger.error(f"Error processing audio: {str(e)}") + socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) + finally: + # Reset processing flag + conversation.is_processing = False + +# Error handler +@socketio.on_error() +def error_handler(e): + logger.error(f"SocketIO error: {str(e)}") + +# Periodic cleanup of inactive sessions +def cleanup_inactive_sessions(): + """Remove sessions that have been inactive for too long""" + current_time = time.time() + inactive_timeout = 3600 # 1 hour + + for session_id in list(active_conversations.keys()): + conversation = active_conversations[session_id] + if (current_time - conversation.last_activity > inactive_timeout and + not conversation.is_processing): + + logger.info(f"Cleaning up inactive session: {session_id}") + + # Signal processing thread to terminate + if session_id in user_queues: + user_queues[session_id].put(None) + + # Remove from active conversations + del active_conversations[session_id] + +# Start cleanup thread +def start_cleanup_thread(): + while True: + try: + cleanup_inactive_sessions() + except Exception as e: + logger.error(f"Error in cleanup thread: {str(e)}") + 
time.sleep(300) # Run every 5 minutes + +cleanup_thread = threading.Thread(target=start_cleanup_thread, daemon=True) +cleanup_thread.start() + +# Start the server if __name__ == '__main__': - socketio.run(app, host='0.0.0.0', port=5000) \ No newline at end of file + port = int(os.environ.get('PORT', 5000)) + logger.info(f"Starting server on port {port}") + socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/setup.py b/Backend/setup.py new file mode 100644 index 0000000..8eddb95 --- /dev/null +++ b/Backend/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages +import os + +# Read requirements from requirements.txt +with open('requirements.txt') as f: + requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')] + +setup( + name='csm', + version='0.1.0', + packages=find_packages(), + install_requires=requirements, +) diff --git a/Backend/src/audio/processor.py b/Backend/src/audio/processor.py deleted file mode 100644 index 40d636e..0000000 --- a/Backend/src/audio/processor.py +++ /dev/null @@ -1,28 +0,0 @@ -from scipy.io import wavfile -import numpy as np -import torchaudio - -def load_audio(file_path): - sample_rate, audio_data = wavfile.read(file_path) - return sample_rate, audio_data - -def normalize_audio(audio_data): - audio_data = audio_data.astype(np.float32) - max_val = np.max(np.abs(audio_data)) - if max_val > 0: - audio_data /= max_val - return audio_data - -def reduce_noise(audio_data, noise_factor=0.1): - noise = np.random.randn(len(audio_data)) - noisy_audio = audio_data + noise_factor * noise - return noisy_audio - -def save_audio(file_path, sample_rate, audio_data): - torchaudio.save(file_path, torch.tensor(audio_data).unsqueeze(0), sample_rate) - -def process_audio(file_path, output_path): - sample_rate, audio_data = load_audio(file_path) - normalized_audio = normalize_audio(audio_data) - denoised_audio = reduce_noise(normalized_audio) - save_audio(output_path, sample_rate, denoised_audio) \ No newline at end of file diff --git a/Backend/src/audio/streaming.py b/Backend/src/audio/streaming.py deleted file mode 100644 index 19ee4cb..0000000 --- a/Backend/src/audio/streaming.py +++ /dev/null @@ -1,35 +0,0 @@ -from flask import Blueprint, request -from flask_socketio import SocketIO, emit -from src.audio.processor import process_audio -from src.services.transcription_service import TranscriptionService -from src.services.tts_service import TextToSpeechService - -streaming_bp = Blueprint('streaming', __name__) -socketio = SocketIO() - -transcription_service = TranscriptionService() -tts_service = TextToSpeechService() - -@socketio.on('audio_stream') -def handle_audio_stream(data): - audio_chunk = data['audio'] - speaker_id = data['speaker'] - - # Process the audio chunk - processed_audio = process_audio(audio_chunk) - - # Transcribe the audio to text - transcription = transcription_service.transcribe(processed_audio) - - # Generate a response using the LLM - response_text = generate_response(transcription, speaker_id) - - # Convert the response text back to audio - response_audio = tts_service.convert_text_to_speech(response_text, speaker_id) - - # Emit the response audio back to the client - emit('audio_response', {'audio': response_audio}) - -def generate_response(transcription, speaker_id): - # Placeholder for the actual response generation logic - return f"Response to: {transcription}" \ No newline at end of file diff --git 
a/Backend/src/llm/tokenizer.py b/Backend/src/llm/tokenizer.py deleted file mode 100644 index 0a05bcd..0000000 --- a/Backend/src/llm/tokenizer.py +++ /dev/null @@ -1,14 +0,0 @@ -from transformers import AutoTokenizer - -def load_llama3_tokenizer(): - tokenizer_name = "meta-llama/Llama-3.2-1B" - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - return tokenizer - -def tokenize_text(text: str, tokenizer) -> list: - tokens = tokenizer.encode(text, return_tensors='pt') - return tokens - -def decode_tokens(tokens: list, tokenizer) -> str: - text = tokenizer.decode(tokens, skip_special_tokens=True) - return text \ No newline at end of file diff --git a/Backend/src/models/audio_model.py b/Backend/src/models/audio_model.py deleted file mode 100644 index 726bec4..0000000 --- a/Backend/src/models/audio_model.py +++ /dev/null @@ -1,28 +0,0 @@ -from dataclasses import dataclass -import torch - -@dataclass -class AudioModel: - model: torch.nn.Module - sample_rate: int - - def __post_init__(self): - self.model.eval() - - def process_audio(self, audio_tensor: torch.Tensor) -> torch.Tensor: - with torch.no_grad(): - processed_audio = self.model(audio_tensor) - return processed_audio - - def resample_audio(self, audio_tensor: torch.Tensor, target_sample_rate: int) -> torch.Tensor: - if self.sample_rate != target_sample_rate: - resampled_audio = torchaudio.functional.resample(audio_tensor, orig_freq=self.sample_rate, new_freq=target_sample_rate) - return resampled_audio - return audio_tensor - - def save_model(self, path: str): - torch.save(self.model.state_dict(), path) - - def load_model(self, path: str): - self.model.load_state_dict(torch.load(path)) - self.model.eval() \ No newline at end of file diff --git a/Backend/src/models/conversation.py b/Backend/src/models/conversation.py deleted file mode 100644 index 25d1a70..0000000 --- a/Backend/src/models/conversation.py +++ /dev/null @@ -1,51 +0,0 @@ -from dataclasses import dataclass, field -from typing import List, Optional -import torch - -@dataclass -class Segment: - speaker: int - text: str - # (num_samples,), sample_rate = 24_000 - audio: Optional[torch.Tensor] = None - - def __post_init__(self): - # Ensure audio is a tensor if provided - if self.audio is not None and not isinstance(self.audio, torch.Tensor): - self.audio = torch.tensor(self.audio, dtype=torch.float32) - -@dataclass -class Conversation: - context: List[str] = field(default_factory=list) - segments: List[Segment] = field(default_factory=list) - current_speaker: Optional[int] = None - - def add_message(self, message: str, speaker: int): - self.context.append(f"Speaker {speaker}: {message}") - self.current_speaker = speaker - - def add_segment(self, segment: Segment): - self.segments.append(segment) - self.context.append(f"Speaker {segment.speaker}: {segment.text}") - self.current_speaker = segment.speaker - - def get_context(self) -> List[str]: - return self.context - - def get_segments(self) -> List[Segment]: - return self.segments - - def clear_context(self): - self.context.clear() - self.segments.clear() - self.current_speaker = None - - def get_last_message(self) -> Optional[str]: - if self.context: - return self.context[-1] - return None - - def get_last_segment(self) -> Optional[Segment]: - if self.segments: - return self.segments[-1] - return None \ No newline at end of file diff --git a/Backend/src/services/transcription_service.py b/Backend/src/services/transcription_service.py deleted file mode 100644 index 06f8dd1..0000000 --- 
a/Backend/src/services/transcription_service.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import List -import torchaudio -import torch -from generator import load_csm_1b, Segment - -class TranscriptionService: - def __init__(self, model_device: str = "cpu"): - self.generator = load_csm_1b(device=model_device) - - def transcribe_audio(self, audio_path: str) -> str: - audio_tensor, sample_rate = torchaudio.load(audio_path) - audio_tensor = self._resample_audio(audio_tensor, sample_rate) - transcription = self.generator.generate_transcription(audio_tensor) - return transcription - - def _resample_audio(self, audio_tensor: torch.Tensor, orig_freq: int) -> torch.Tensor: - target_sample_rate = self.generator.sample_rate - if orig_freq != target_sample_rate: - audio_tensor = torchaudio.functional.resample(audio_tensor.squeeze(0), orig_freq=orig_freq, new_freq=target_sample_rate) - return audio_tensor - - def transcribe_audio_stream(self, audio_chunks: List[torch.Tensor]) -> str: - combined_audio = torch.cat(audio_chunks, dim=1) - transcription = self.generator.generate_transcription(combined_audio) - return transcription \ No newline at end of file diff --git a/Backend/src/services/tts_service.py b/Backend/src/services/tts_service.py deleted file mode 100644 index 64fab04..0000000 --- a/Backend/src/services/tts_service.py +++ /dev/null @@ -1,24 +0,0 @@ -from dataclasses import dataclass -import torch -import torchaudio -from huggingface_hub import hf_hub_download -from src.llm.generator import load_csm_1b - -@dataclass -class TextToSpeechService: - generator: any - - def __init__(self, device: str = "cuda"): - self.generator = load_csm_1b(device=device) - - def text_to_speech(self, text: str, speaker: int = 0) -> torch.Tensor: - audio = self.generator.generate( - text=text, - speaker=speaker, - context=[], - max_audio_length_ms=10000, - ) - return audio - - def save_audio(self, audio: torch.Tensor, file_path: str): - torchaudio.save(file_path, audio.unsqueeze(0).cpu(), self.generator.sample_rate) \ No newline at end of file diff --git a/Backend/src/utils/config.py b/Backend/src/utils/config.py deleted file mode 100644 index 2206481..0000000 --- a/Backend/src/utils/config.py +++ /dev/null @@ -1,23 +0,0 @@ -# filepath: /csm-conversation-bot/csm-conversation-bot/src/utils/config.py - -import os - -class Config: - # General configuration - DEBUG = os.getenv('DEBUG', 'False') == 'True' - SECRET_KEY = os.getenv('SECRET_KEY', 'your_secret_key_here') - - # API configuration - API_URL = os.getenv('API_URL', 'http://localhost:5000') - - # Model configuration - LLM_MODEL_PATH = os.getenv('LLM_MODEL_PATH', 'path/to/llm/model') - AUDIO_MODEL_PATH = os.getenv('AUDIO_MODEL_PATH', 'path/to/audio/model') - - # Socket.IO configuration - SOCKETIO_MESSAGE_QUEUE = os.getenv('SOCKETIO_MESSAGE_QUEUE', 'redis://localhost:6379/0') - - # Logging configuration - LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO') - - # Other configurations can be added as needed \ No newline at end of file diff --git a/Backend/src/utils/logger.py b/Backend/src/utils/logger.py deleted file mode 100644 index 93e8966..0000000 --- a/Backend/src/utils/logger.py +++ /dev/null @@ -1,14 +0,0 @@ -import logging - -def setup_logger(name, log_file, level=logging.INFO): - handler = logging.FileHandler(log_file) - handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) - - logger = logging.getLogger(name) - logger.setLevel(level) - logger.addHandler(handler) - - return logger - -# Example usage: -# logger = 
setup_logger('my_logger', 'app.log') \ No newline at end of file diff --git a/Backend/static/css/styles.css b/Backend/static/css/styles.css deleted file mode 100644 index 4e2d752..0000000 --- a/Backend/static/css/styles.css +++ /dev/null @@ -1,105 +0,0 @@ -body { - font-family: 'Arial', sans-serif; - background-color: #f4f4f4; - color: #333; - margin: 0; - padding: 0; -} - -header { - background: #4c84ff; - color: #fff; - padding: 10px 0; - text-align: center; -} - -h1 { - margin: 0; - font-size: 2.5rem; -} - -.container { - width: 80%; - margin: auto; - overflow: hidden; -} - -.conversation { - background: #fff; - padding: 20px; - border-radius: 5px; - box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1); - max-height: 400px; - overflow-y: auto; -} - -.message { - padding: 10px; - margin: 10px 0; - border-radius: 5px; -} - -.user { - background: #e3f2fd; - text-align: right; -} - -.ai { - background: #f1f1f1; - text-align: left; -} - -.controls { - display: flex; - justify-content: space-between; - margin-top: 20px; -} - -button { - padding: 10px 15px; - border: none; - border-radius: 5px; - cursor: pointer; - transition: background 0.3s; -} - -button:hover { - background: #3367d6; - color: #fff; -} - -.visualizer-container { - height: 150px; - background: #000; - border-radius: 5px; - margin-top: 20px; -} - -.visualizer-label { - color: rgba(255, 255, 255, 0.7); - text-align: center; - padding: 10px; -} - -.status-indicator { - display: flex; - align-items: center; - margin-top: 10px; -} - -.status-dot { - width: 12px; - height: 12px; - border-radius: 50%; - background-color: #ccc; - margin-right: 10px; -} - -.status-dot.active { - background-color: #4CAF50; -} - -.status-text { - font-size: 0.9em; - color: #666; -} \ No newline at end of file diff --git a/Backend/static/index.html b/Backend/static/index.html deleted file mode 100644 index 4922f17..0000000 --- a/Backend/static/index.html +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - CSM Conversation Bot - - - - - -
- [Markup of the deleted Backend/static/index.html is not preserved in this view; only its visible text survives: "CSM Conversation Bot", "Talk to the AI and get responses in real-time!", a "Disconnected" status indicator, and the footer "Powered by CSM and Llama 3.2".]
- - \ No newline at end of file diff --git a/Backend/static/js/client.js b/Backend/static/js/client.js deleted file mode 100644 index ec4037f..0000000 --- a/Backend/static/js/client.js +++ /dev/null @@ -1,131 +0,0 @@ -// This file contains the client-side JavaScript code that handles audio streaming and communication with the server. - -const SERVER_URL = window.location.hostname === 'localhost' ? - 'http://localhost:5000' : window.location.origin; - -const elements = { - conversation: document.getElementById('conversation'), - streamButton: document.getElementById('streamButton'), - clearButton: document.getElementById('clearButton'), - speakerSelection: document.getElementById('speakerSelect'), - statusDot: document.getElementById('statusDot'), - statusText: document.getElementById('statusText'), -}; - -const state = { - socket: null, - isStreaming: false, - currentSpeaker: 0, -}; - -// Initialize the application -function initializeApp() { - setupSocketConnection(); - setupEventListeners(); -} - -// Setup Socket.IO connection -function setupSocketConnection() { - state.socket = io(SERVER_URL); - - state.socket.on('connect', () => { - updateConnectionStatus(true); - }); - - state.socket.on('disconnect', () => { - updateConnectionStatus(false); - }); - - state.socket.on('audio_response', handleAudioResponse); - state.socket.on('transcription', handleTranscription); -} - -// Setup event listeners -function setupEventListeners() { - elements.streamButton.addEventListener('click', toggleStreaming); - elements.clearButton.addEventListener('click', clearConversation); - elements.speakerSelection.addEventListener('change', (event) => { - state.currentSpeaker = event.target.value; - }); -} - -// Update connection status UI -function updateConnectionStatus(isConnected) { - elements.statusDot.classList.toggle('active', isConnected); - elements.statusText.textContent = isConnected ? 
'Connected' : 'Disconnected'; -} - -// Toggle streaming state -function toggleStreaming() { - if (state.isStreaming) { - stopStreaming(); - } else { - startStreaming(); - } -} - -// Start streaming audio to the server -function startStreaming() { - if (state.isStreaming) return; - - navigator.mediaDevices.getUserMedia({ audio: true }) - .then(stream => { - const mediaRecorder = new MediaRecorder(stream); - mediaRecorder.start(); - - mediaRecorder.ondataavailable = (event) => { - if (event.data.size > 0) { - sendAudioChunk(event.data); - } - }; - - mediaRecorder.onstop = () => { - state.isStreaming = false; - elements.streamButton.innerHTML = 'Start Conversation'; - }; - - state.isStreaming = true; - elements.streamButton.innerHTML = 'Stop Conversation'; - }) - .catch(err => { - console.error('Error accessing microphone:', err); - }); -} - -// Stop streaming audio -function stopStreaming() { - if (!state.isStreaming) return; - - // Logic to stop the media recorder would go here -} - -// Send audio chunk to server -function sendAudioChunk(audioData) { - const reader = new FileReader(); - reader.onloadend = () => { - const arrayBuffer = reader.result; - state.socket.emit('audio_chunk', { audio: arrayBuffer, speaker: state.currentSpeaker }); - }; - reader.readAsArrayBuffer(audioData); -} - -// Handle audio response from server -function handleAudioResponse(data) { - const audioElement = new Audio(URL.createObjectURL(new Blob([data.audio]))); - audioElement.play(); -} - -// Handle transcription response from server -function handleTranscription(data) { - const messageElement = document.createElement('div'); - messageElement.textContent = `AI: ${data.transcription}`; - elements.conversation.appendChild(messageElement); -} - -// Clear conversation history -function clearConversation() { - elements.conversation.innerHTML = ''; -} - -// Initialize the application when DOM is fully loaded -document.addEventListener('DOMContentLoaded', initializeApp); \ No newline at end of file diff --git a/Backend/templates/index.html b/Backend/templates/index.html deleted file mode 100644 index 514b946..0000000 --- a/Backend/templates/index.html +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - CSM Conversation Bot - - - - - -
- [Markup of the deleted Backend/templates/index.html is not preserved in this view; only its visible text survives: "CSM Conversation Bot", "Talk to the AI and get responses in real-time!", a "Not connected" status indicator, and the footer "Powered by CSM and Llama 3.2".]
- - \ No newline at end of file diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js new file mode 100644 index 0000000..89ec71a --- /dev/null +++ b/Backend/voice-chat.js @@ -0,0 +1,1071 @@ +/** + * Sesame AI Voice Chat Client + * + * A web client that connects to a Sesame AI voice chat server and enables + * real-time voice conversation with an AI assistant. + */ + +// Configuration constants +const SERVER_URL = window.location.hostname === 'localhost' ? + 'http://localhost:5000' : window.location.origin; +const ENERGY_WINDOW_SIZE = 15; +const CLIENT_SILENCE_DURATION_MS = 750; + +// DOM elements +const elements = { + conversation: null, + streamButton: null, + clearButton: null, + thresholdSlider: null, + thresholdValue: null, + visualizerCanvas: null, + visualizerLabel: null, + volumeLevel: null, + statusDot: null, + statusText: null, + speakerSelection: null, + autoPlayResponses: null, + showVisualizer: null +}; + +// Application state +const state = { + socket: null, + audioContext: null, + analyser: null, + microphone: null, + streamProcessor: null, + isStreaming: false, + isSpeaking: false, + silenceThreshold: 0.01, + energyWindow: [], + silenceTimer: null, + volumeUpdateInterval: null, + visualizerAnimationFrame: null, + currentSpeaker: 0 +}; + +// Visualizer variables +let canvasContext = null; +let visualizerBufferLength = 0; +let visualizerDataArray = null; + +// New state variables to track incremental audio streaming +const streamingAudio = { + messageElement: null, + audioElement: null, + chunks: [], + totalChunks: 0, + receivedChunks: 0, + text: '', + mediaSource: null, + sourceBuffer: null, + audioContext: null, + complete: false +}; + +// Initialize the application +function initializeApp() { + // Initialize the UI elements + initializeUIElements(); + + // Initialize socket.io connection + setupSocketConnection(); + + // Setup event listeners + setupEventListeners(); + + // Initialize visualizer + setupVisualizer(); + + // Show welcome message + addSystemMessage('Welcome to Sesame AI Voice Chat! 
Click "Start Conversation" to begin.'); +} + +// Initialize UI elements +function initializeUIElements() { + // Store references to UI elements + elements.conversation = document.getElementById('conversation'); + elements.streamButton = document.getElementById('streamButton'); + elements.clearButton = document.getElementById('clearButton'); + elements.thresholdSlider = document.getElementById('thresholdSlider'); + elements.thresholdValue = document.getElementById('thresholdValue'); + elements.visualizerCanvas = document.getElementById('audioVisualizer'); + elements.visualizerLabel = document.getElementById('visualizerLabel'); + elements.volumeLevel = document.getElementById('volumeLevel'); + elements.statusDot = document.getElementById('statusDot'); + elements.statusText = document.getElementById('statusText'); + elements.speakerSelection = document.getElementById('speakerSelect'); // Changed to match HTML + elements.autoPlayResponses = document.getElementById('autoPlayResponses'); + elements.showVisualizer = document.getElementById('showVisualizer'); +} + +// Setup Socket.IO connection +function setupSocketConnection() { + state.socket = io(SERVER_URL); + + // Connection events + state.socket.on('connect', () => { + console.log('Connected to server'); + updateConnectionStatus(true); + }); + + state.socket.on('disconnect', () => { + console.log('Disconnected from server'); + updateConnectionStatus(false); + + // Stop streaming if active + if (state.isStreaming) { + stopStreaming(false); + } + }); + + state.socket.on('error', (data) => { + console.error('Socket error:', data.message); + addSystemMessage(`Error: ${data.message}`); + }); + + // Register message handlers + state.socket.on('audio_response', handleAudioResponse); + state.socket.on('transcription', handleTranscription); + state.socket.on('context_updated', handleContextUpdate); + state.socket.on('streaming_status', handleStreamingStatus); + + // New event handlers for incremental audio streaming + state.socket.on('audio_response_start', handleAudioResponseStart); + state.socket.on('audio_response_chunk', handleAudioResponseChunk); + state.socket.on('audio_response_complete', handleAudioResponseComplete); + state.socket.on('processing_status', handleProcessingStatus); +} + +// Setup event listeners +function setupEventListeners() { + // Stream button + elements.streamButton.addEventListener('click', toggleStreaming); + + // Clear button + elements.clearButton.addEventListener('click', clearConversation); + + // Threshold slider + elements.thresholdSlider.addEventListener('input', updateThreshold); + + // Speaker selection + elements.speakerSelection.addEventListener('change', () => { + state.currentSpeaker = parseInt(elements.speakerSelection.value, 10); + }); + + // Visualizer toggle + elements.showVisualizer.addEventListener('change', toggleVisualizerVisibility); +} + +// Setup audio visualizer +function setupVisualizer() { + if (!elements.visualizerCanvas) return; + + canvasContext = elements.visualizerCanvas.getContext('2d'); + + // Set canvas dimensions + elements.visualizerCanvas.width = elements.visualizerCanvas.offsetWidth; + elements.visualizerCanvas.height = elements.visualizerCanvas.offsetHeight; + + // Initialize the visualizer + drawVisualizer(); +} + +// Update connection status UI +function updateConnectionStatus(isConnected) { + elements.statusDot.classList.toggle('active', isConnected); + elements.statusText.textContent = isConnected ? 
'Connected' : 'Disconnected'; +} + +// Toggle streaming state +function toggleStreaming() { + if (state.isStreaming) { + stopStreaming(true); + } else { + startStreaming(); + } +} + +// Start streaming audio to the server +function startStreaming() { + if (state.isStreaming) return; + + // Request microphone access + navigator.mediaDevices.getUserMedia({ audio: true, video: false }) + .then(stream => { + // Show processing state while setting up + elements.streamButton.innerHTML = ' Initializing...'; + + // Create audio context + state.audioContext = new (window.AudioContext || window.webkitAudioContext)(); + + // Create microphone source + state.microphone = state.audioContext.createMediaStreamSource(stream); + + // Create analyser for visualizer + state.analyser = state.audioContext.createAnalyser(); + state.analyser.fftSize = 256; + visualizerBufferLength = state.analyser.frequencyBinCount; + visualizerDataArray = new Uint8Array(visualizerBufferLength); + + // Connect microphone to analyser + state.microphone.connect(state.analyser); + + // Create script processor for audio processing + const bufferSize = 4096; + state.streamProcessor = state.audioContext.createScriptProcessor(bufferSize, 1, 1); + + // Set up audio processing callback + state.streamProcessor.onaudioprocess = handleAudioProcess; + + // Connect the processors + state.analyser.connect(state.streamProcessor); + state.streamProcessor.connect(state.audioContext.destination); + + // Update UI + state.isStreaming = true; + elements.streamButton.innerHTML = ' Listening...'; + elements.streamButton.classList.add('recording'); + + // Initialize energy window + state.energyWindow = []; + + // Start volume meter updates + state.volumeUpdateInterval = setInterval(updateVolumeMeter, 100); + + // Start visualizer if enabled + if (elements.showVisualizer.checked && !state.visualizerAnimationFrame) { + drawVisualizer(); + } + + // Show starting message + addSystemMessage('Listening... Speak clearly into your microphone.'); + + // Notify the server that we're starting + state.socket.emit('stream_audio', { + audio: '', + speaker: state.currentSpeaker + }); + }) + .catch(err => { + console.error('Error accessing microphone:', err); + addSystemMessage(`Error: ${err.message}. 
Please make sure your microphone is connected and you've granted permission.`); + elements.streamButton.innerHTML = ' Start Conversation'; + }); +} + +// Stop streaming audio +function stopStreaming(notifyServer = true) { + if (!state.isStreaming) return; + + // Update UI first + elements.streamButton.innerHTML = ' Start Conversation'; + elements.streamButton.classList.remove('recording'); + elements.streamButton.classList.remove('processing'); + + // Stop volume meter updates + if (state.volumeUpdateInterval) { + clearInterval(state.volumeUpdateInterval); + state.volumeUpdateInterval = null; + } + + // Stop all audio processing + if (state.streamProcessor) { + state.streamProcessor.disconnect(); + state.streamProcessor = null; + } + + if (state.analyser) { + state.analyser.disconnect(); + } + + if (state.microphone) { + state.microphone.disconnect(); + } + + // Close audio context + if (state.audioContext && state.audioContext.state !== 'closed') { + state.audioContext.close().catch(err => console.warn('Error closing audio context:', err)); + } + + // Cleanup animation frames + if (state.visualizerAnimationFrame) { + cancelAnimationFrame(state.visualizerAnimationFrame); + state.visualizerAnimationFrame = null; + } + + // Reset state + state.isStreaming = false; + state.isSpeaking = false; + + // Notify the server + if (notifyServer && state.socket && state.socket.connected) { + state.socket.emit('stop_streaming', { + speaker: state.currentSpeaker + }); + } + + // Show message + addSystemMessage('Conversation paused. Click "Start Conversation" to resume.'); +} + +// Handle audio processing +function handleAudioProcess(event) { + const inputData = event.inputBuffer.getChannelData(0); + + // Calculate audio energy (volume level) + const energy = calculateAudioEnergy(inputData); + + // Update energy window for averaging + updateEnergyWindow(energy); + + // Calculate average energy + const avgEnergy = calculateAverageEnergy(); + + // Determine if audio is silent + const isSilent = avgEnergy < state.silenceThreshold; + + // Debug logging only if significant changes in audio patterns + if (Math.random() < 0.05) { // Log only 5% of frames to avoid console spam + console.log(`Audio: len=${inputData.length}, energy=${energy.toFixed(4)}, avg=${avgEnergy.toFixed(4)}, silent=${isSilent}`); + } + + // Handle speech state based on silence + handleSpeechState(isSilent); + + // Only send audio chunk if we detect speech + if (!isSilent) { + // Create a resampled version at 24kHz for the server + // Most WebRTC audio is 48kHz, but we want 24kHz for the model + const resampledData = downsampleBuffer(inputData, state.audioContext.sampleRate, 24000); + + // Send the audio chunk to the server + sendAudioChunk(resampledData, state.currentSpeaker); + } +} + +// Cleanup audio resources when done +function cleanupAudioResources() { + // Stop all audio processing + if (state.streamProcessor) { + state.streamProcessor.disconnect(); + state.streamProcessor = null; + } + + if (state.analyser) { + state.analyser.disconnect(); + state.analyser = null; + } + + if (state.microphone) { + state.microphone.disconnect(); + state.microphone = null; + } + + // Close audio context + if (state.audioContext && state.audioContext.state !== 'closed') { + state.audioContext.close().catch(err => console.warn('Error closing audio context:', err)); + } + + // Cancel all timers and animation frames + if (state.volumeUpdateInterval) { + clearInterval(state.volumeUpdateInterval); + state.volumeUpdateInterval = null; + } + + if 
(state.visualizerAnimationFrame) { + cancelAnimationFrame(state.visualizerAnimationFrame); + state.visualizerAnimationFrame = null; + } + + if (state.silenceTimer) { + clearTimeout(state.silenceTimer); + state.silenceTimer = null; + } +} + +// Clear conversation history +function clearConversation() { + if (elements.conversation) { + elements.conversation.innerHTML = ''; + addSystemMessage('Conversation cleared.'); + + // Notify server to clear context + if (state.socket && state.socket.connected) { + state.socket.emit('clear_context'); + } + } +} + +// Calculate audio energy (volume) +function calculateAudioEnergy(buffer) { + let sum = 0; + for (let i = 0; i < buffer.length; i++) { + sum += buffer[i] * buffer[i]; + } + return Math.sqrt(sum / buffer.length); +} + +// Update energy window for averaging +function updateEnergyWindow(energy) { + state.energyWindow.push(energy); + if (state.energyWindow.length > ENERGY_WINDOW_SIZE) { + state.energyWindow.shift(); + } +} + +// Calculate average energy from window +function calculateAverageEnergy() { + if (state.energyWindow.length === 0) return 0; + + const sum = state.energyWindow.reduce((a, b) => a + b, 0); + return sum / state.energyWindow.length; +} + +// Update the threshold from the slider +function updateThreshold() { + state.silenceThreshold = parseFloat(elements.thresholdSlider.value); + elements.thresholdValue.textContent = state.silenceThreshold.toFixed(3); +} + +// Update the volume meter display +function updateVolumeMeter() { + if (!state.isStreaming || !state.energyWindow.length) return; + + const avgEnergy = calculateAverageEnergy(); + + // Scale energy to percentage (0-100) + // Typically, energy values will be very small (e.g., 0.001 to 0.1) + // So we multiply by a factor to make it more visible + const scaleFactor = 1000; + const percentage = Math.min(100, Math.max(0, avgEnergy * scaleFactor)); + + // Update volume meter width + elements.volumeLevel.style.width = `${percentage}%`; + + // Change color based on level + if (percentage > 70) { + elements.volumeLevel.style.backgroundColor = '#ff5252'; + } else if (percentage > 30) { + elements.volumeLevel.style.backgroundColor = '#4CAF50'; + } else { + elements.volumeLevel.style.backgroundColor = '#4c84ff'; + } +} + +// Handle speech/silence state transitions +function handleSpeechState(isSilent) { + if (state.isSpeaking && isSilent) { + // Transition from speaking to silence + if (!state.silenceTimer) { + state.silenceTimer = setTimeout(() => { + // Only consider it a real silence after a certain duration + // This prevents detecting brief pauses as the end of speech + state.isSpeaking = false; + state.silenceTimer = null; + }, CLIENT_SILENCE_DURATION_MS); + } + } else if (state.silenceTimer && !isSilent) { + // User started speaking again, cancel the silence timer + clearTimeout(state.silenceTimer); + state.silenceTimer = null; + } + + // Update speaking state for non-silent audio + if (!isSilent) { + state.isSpeaking = true; + } +} + +// Send audio chunk to server +function sendAudioChunk(audioData, speaker) { + if (!state.socket || !state.socket.connected) { + console.warn('Socket not connected'); + return; + } + + console.log(`Preparing audio chunk: length=${audioData.length}, speaker=${speaker}`); + + // Check for NaN or invalid values + let hasInvalidValues = false; + for (let i = 0; i < audioData.length; i++) { + if (isNaN(audioData[i]) || !isFinite(audioData[i])) { + hasInvalidValues = true; + console.warn(`Invalid audio value at index ${i}: ${audioData[i]}`); + break; + 
} + } + + if (hasInvalidValues) { + console.warn('Audio data contains invalid values. Creating silent audio.'); + audioData = new Float32Array(audioData.length).fill(0); + } + + try { + // Create WAV blob + const wavData = createWavBlob(audioData, 24000); + console.log(`WAV blob created: ${wavData.size} bytes`); + + const reader = new FileReader(); + + reader.onloadend = function() { + try { + // Get base64 data + const base64data = reader.result; + console.log(`Base64 data created: ${base64data.length} bytes`); + + // Send to server + state.socket.emit('stream_audio', { + audio: base64data, + speaker: speaker + }); + console.log('Audio chunk sent to server'); + } catch (err) { + console.error('Error preparing audio data:', err); + } + }; + + reader.onerror = function() { + console.error('Error reading audio data as base64'); + }; + + reader.readAsDataURL(wavData); + } catch (err) { + console.error('Error creating WAV data:', err); + } +} + +// Create WAV blob from audio data with improved error handling +function createWavBlob(audioData, sampleRate) { + // Validate input + if (!audioData || audioData.length === 0) { + console.warn('Empty audio data provided to createWavBlob'); + audioData = new Float32Array(1024).fill(0); // Create 1024 samples of silence + } + + // Function to convert Float32Array to Int16Array for WAV format + function floatTo16BitPCM(output, offset, input) { + for (let i = 0; i < input.length; i++, offset += 2) { + // Ensure values are in -1 to 1 range + const s = Math.max(-1, Math.min(1, input[i])); + // Convert to 16-bit PCM + output.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true); + } + } + + // Create WAV header + function writeString(view, offset, string) { + for (let i = 0; i < string.length; i++) { + view.setUint8(offset + i, string.charCodeAt(i)); + } + } + + try { + // Create WAV file with header - careful with buffer sizes + const buffer = new ArrayBuffer(44 + audioData.length * 2); + const view = new DataView(buffer); + + // RIFF identifier + writeString(view, 0, 'RIFF'); + + // File length (will be filled later) + view.setUint32(4, 36 + audioData.length * 2, true); + + // WAVE identifier + writeString(view, 8, 'WAVE'); + + // fmt chunk identifier + writeString(view, 12, 'fmt '); + + // fmt chunk length + view.setUint32(16, 16, true); + + // Sample format (1 is PCM) + view.setUint16(20, 1, true); + + // Mono channel + view.setUint16(22, 1, true); + + // Sample rate + view.setUint32(24, sampleRate, true); + + // Byte rate (sample rate * block align) + view.setUint32(28, sampleRate * 2, true); + + // Block align (channels * bytes per sample) + view.setUint16(32, 2, true); + + // Bits per sample + view.setUint16(34, 16, true); + + // data chunk identifier + writeString(view, 36, 'data'); + + // data chunk length + view.setUint32(40, audioData.length * 2, true); + + // Write the PCM samples + floatTo16BitPCM(view, 44, audioData); + + // Create and return blob + return new Blob([view], { type: 'audio/wav' }); + } catch (err) { + console.error('Error in createWavBlob:', err); + + // Create a minimal valid WAV file with silence as fallback + const fallbackSamples = new Float32Array(1024).fill(0); + const fallbackBuffer = new ArrayBuffer(44 + fallbackSamples.length * 2); + const fallbackView = new DataView(fallbackBuffer); + + writeString(fallbackView, 0, 'RIFF'); + fallbackView.setUint32(4, 36 + fallbackSamples.length * 2, true); + writeString(fallbackView, 8, 'WAVE'); + writeString(fallbackView, 12, 'fmt '); + fallbackView.setUint32(16, 16, true); + 
fallbackView.setUint16(20, 1, true); + fallbackView.setUint16(22, 1, true); + fallbackView.setUint32(24, sampleRate, true); + fallbackView.setUint32(28, sampleRate * 2, true); + fallbackView.setUint16(32, 2, true); + fallbackView.setUint16(34, 16, true); + writeString(fallbackView, 36, 'data'); + fallbackView.setUint32(40, fallbackSamples.length * 2, true); + floatTo16BitPCM(fallbackView, 44, fallbackSamples); + + return new Blob([fallbackView], { type: 'audio/wav' }); + } +} + +// Draw audio visualizer +function drawVisualizer() { + if (!canvasContext) { + return; + } + + state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer); + + // Skip drawing if visualizer is hidden + if (!elements.showVisualizer.checked) { + if (elements.visualizerCanvas.style.opacity !== '0') { + elements.visualizerCanvas.style.opacity = '0'; + } + return; + } else if (elements.visualizerCanvas.style.opacity !== '1') { + elements.visualizerCanvas.style.opacity = '1'; + } + + // Get frequency data if available + if (state.isStreaming && state.analyser) { + try { + state.analyser.getByteFrequencyData(visualizerDataArray); + } catch (e) { + console.warn('Error getting frequency data:', e); + } + } else { + // Fade out when not streaming + for (let i = 0; i < visualizerDataArray.length; i++) { + visualizerDataArray[i] = Math.max(0, visualizerDataArray[i] - 5); + } + } + + // Clear canvas + canvasContext.fillStyle = 'rgb(0, 0, 0)'; + canvasContext.fillRect(0, 0, elements.visualizerCanvas.width, elements.visualizerCanvas.height); + + // Draw gradient bars + const width = elements.visualizerCanvas.width; + const height = elements.visualizerCanvas.height; + const barCount = Math.min(visualizerBufferLength, 64); + const barWidth = width / barCount - 1; + + for (let i = 0; i < barCount; i++) { + const index = Math.floor(i * visualizerBufferLength / barCount); + const value = visualizerDataArray[index]; + + // Use logarithmic scale for better audio visualization + // This makes low values more visible while still maintaining full range + const logFactor = 20; + const scaledValue = Math.log(1 + (value / 255) * logFactor) / Math.log(1 + logFactor); + const barHeight = scaledValue * height; + + // Position bars + const x = i * (barWidth + 1); + const y = height - barHeight; + + // Create color gradient based on frequency and amplitude + const hue = i / barCount * 360; // Full color spectrum + const saturation = 80 + (value / 255 * 20); // Higher values more saturated + const lightness = 40 + (value / 255 * 20); // Dynamic brightness based on amplitude + + // Draw main bar + canvasContext.fillStyle = `hsl(${hue}, ${saturation}%, ${lightness}%)`; + canvasContext.fillRect(x, y, barWidth, barHeight); + + // Add reflection effect + if (barHeight > 5) { + const gradient = canvasContext.createLinearGradient( + x, y, + x, y + barHeight * 0.5 + ); + gradient.addColorStop(0, `hsla(${hue}, ${saturation}%, ${lightness + 20}%, 0.4)`); + gradient.addColorStop(1, `hsla(${hue}, ${saturation}%, ${lightness}%, 0)`); + canvasContext.fillStyle = gradient; + canvasContext.fillRect(x, y, barWidth, barHeight * 0.5); + + // Add highlight on top of the bar for better 3D effect + canvasContext.fillStyle = `hsla(${hue}, ${saturation - 20}%, ${lightness + 30}%, 0.7)`; + canvasContext.fillRect(x, y, barWidth, 2); + } + } + + // Show/hide the label + elements.visualizerLabel.style.opacity = (state.isStreaming) ? 
'0' : '0.7'; +} + +// Toggle visualizer visibility +function toggleVisualizerVisibility() { + const isVisible = elements.showVisualizer.checked; + elements.visualizerCanvas.style.opacity = isVisible ? '1' : '0'; + + if (isVisible && state.isStreaming && !state.visualizerAnimationFrame) { + drawVisualizer(); + } +} + +// Handle audio response from server +function handleAudioResponse(data) { + console.log('Received audio response'); + + // Create message container + const messageElement = document.createElement('div'); + messageElement.className = 'message ai'; + + // Add text content if available + if (data.text) { + const textElement = document.createElement('p'); + textElement.textContent = data.text; + messageElement.appendChild(textElement); + } + + // Create and configure audio element + const audioElement = document.createElement('audio'); + audioElement.controls = true; + audioElement.className = 'audio-player'; + + // Set audio source + const audioSource = document.createElement('source'); + audioSource.src = data.audio; + audioSource.type = 'audio/wav'; + + // Add fallback text + audioElement.textContent = 'Your browser does not support the audio element.'; + + // Assemble audio element + audioElement.appendChild(audioSource); + messageElement.appendChild(audioElement); + + // Add timestamp + const timeElement = document.createElement('span'); + timeElement.className = 'message-time'; + timeElement.textContent = new Date().toLocaleTimeString(); + messageElement.appendChild(timeElement); + + // Add to conversation + elements.conversation.appendChild(messageElement); + + // Auto-scroll to bottom + elements.conversation.scrollTop = elements.conversation.scrollHeight; + + // Auto-play if enabled + if (elements.autoPlayResponses.checked) { + audioElement.play() + .catch(err => { + console.warn('Auto-play failed:', err); + addSystemMessage('Auto-play failed. 
Please click play to hear the response.'); + }); + } + + // Re-enable stream button after processing is complete + if (state.isStreaming) { + elements.streamButton.innerHTML = ' Listening...'; + elements.streamButton.classList.add('recording'); + elements.streamButton.classList.remove('processing'); + } +} + +// Handle transcription response from server +function handleTranscription(data) { + console.log('Received transcription:', data.text); + + // Create message element + const messageElement = document.createElement('div'); + messageElement.className = 'message user'; + + // Add text content + const textElement = document.createElement('p'); + textElement.textContent = data.text; + messageElement.appendChild(textElement); + + // Add timestamp + const timeElement = document.createElement('span'); + timeElement.className = 'message-time'; + timeElement.textContent = new Date().toLocaleTimeString(); + messageElement.appendChild(timeElement); + + // Add to conversation + elements.conversation.appendChild(messageElement); + + // Auto-scroll to bottom + elements.conversation.scrollTop = elements.conversation.scrollHeight; +} + +// Handle context update from server +function handleContextUpdate(data) { + console.log('Context updated:', data.message); +} + +// Handle streaming status updates from server +function handleStreamingStatus(data) { + console.log('Streaming status:', data.status); + + if (data.status === 'stopped') { + // Reset UI if needed + if (state.isStreaming) { + stopStreaming(false); // Don't send to server since this came from server + } + } +} + +// Add a system message to the conversation +function addSystemMessage(message) { + const messageElement = document.createElement('div'); + messageElement.className = 'message system'; + messageElement.textContent = message; + elements.conversation.appendChild(messageElement); + + // Auto-scroll to bottom + elements.conversation.scrollTop = elements.conversation.scrollHeight; +} + +// Downsample audio buffer to target sample rate +function downsampleBuffer(buffer, originalSampleRate, targetSampleRate) { + if (originalSampleRate === targetSampleRate) { + return buffer; + } + + const ratio = originalSampleRate / targetSampleRate; + const newLength = Math.round(buffer.length / ratio); + const result = new Float32Array(newLength); + + for (let i = 0; i < newLength; i++) { + const pos = Math.round(i * ratio); + result[i] = buffer[pos]; + } + + return result; +} + +// Handle processing status updates +function handleProcessingStatus(data) { + console.log('Processing status update:', data); + + // Show processing status in UI + if (data.status === 'generating_audio') { + elements.streamButton.innerHTML = ' Processing...'; + elements.streamButton.classList.add('processing'); + elements.streamButton.classList.remove('recording'); + + // Show message to user + addSystemMessage(data.message || 'Processing your request...'); + } +} + +// Handle the start of an audio streaming response +function handleAudioResponseStart(data) { + console.log('Audio response starting:', data); + + // Reset streaming audio state + streamingAudio.chunks = []; + streamingAudio.totalChunks = data.total_chunks; + streamingAudio.receivedChunks = 0; + streamingAudio.text = data.text; + streamingAudio.complete = false; + + // Create message container now, so we can update it as chunks arrive + const messageElement = document.createElement('div'); + messageElement.className = 'message ai processing'; + + // Add text content if available + if (data.text) { + const 
textElement = document.createElement('p'); + textElement.textContent = data.text; + messageElement.appendChild(textElement); + } + + // Create audio element (will be populated as chunks arrive) + const audioElement = document.createElement('audio'); + audioElement.controls = true; + audioElement.className = 'audio-player'; + audioElement.textContent = 'Audio is being generated...'; + messageElement.appendChild(audioElement); + + // Add timestamp + const timeElement = document.createElement('span'); + timeElement.className = 'message-time'; + timeElement.textContent = new Date().toLocaleTimeString(); + messageElement.appendChild(timeElement); + + // Add loading indicator + const loadingElement = document.createElement('div'); + loadingElement.className = 'loading-indicator'; + loadingElement.innerHTML = '
Generating audio response...'; + messageElement.appendChild(loadingElement); + + // Add to conversation + elements.conversation.appendChild(messageElement); + + // Auto-scroll to bottom + elements.conversation.scrollTop = elements.conversation.scrollHeight; + + // Store elements for later updates + streamingAudio.messageElement = messageElement; + streamingAudio.audioElement = audioElement; +} + +// Handle an incoming audio chunk +function handleAudioResponseChunk(data) { + console.log(`Received audio chunk ${data.chunk_index + 1}/${data.total_chunks}`); + + // Store the chunk + streamingAudio.chunks[data.chunk_index] = data.audio; + streamingAudio.receivedChunks++; + + // Update progress in the UI + if (streamingAudio.messageElement) { + const loadingElement = streamingAudio.messageElement.querySelector('.loading-indicator span'); + if (loadingElement) { + loadingElement.textContent = `Generating audio response... ${Math.round((streamingAudio.receivedChunks / data.total_chunks) * 100)}%`; + } + } + + // If this is the first chunk, start playing it immediately for faster response + if (data.chunk_index === 0 && streamingAudio.audioElement && elements.autoPlayResponses && elements.autoPlayResponses.checked) { + try { + streamingAudio.audioElement.src = data.audio; + streamingAudio.audioElement.play().catch(err => console.warn('Auto-play failed:', err)); + } catch (e) { + console.error('Error playing first chunk:', e); + } + } + + // If this is the last chunk or we've received all chunks, finalize the audio + if (data.is_last || streamingAudio.receivedChunks >= data.total_chunks) { + finalizeStreamingAudio(); + } +} + +// Handle completion of audio streaming +function handleAudioResponseComplete(data) { + console.log('Audio response complete:', data); + streamingAudio.complete = true; + + // Make sure we finalize the audio even if some chunks were missed + finalizeStreamingAudio(); + + // Update UI to normal state + if (state.isStreaming) { + elements.streamButton.innerHTML = ' Listening...'; + elements.streamButton.classList.add('recording'); + elements.streamButton.classList.remove('processing'); + } +} + +// Finalize streaming audio by combining chunks and updating the UI +function finalizeStreamingAudio() { + if (!streamingAudio.messageElement || streamingAudio.chunks.length === 0) { + return; + } + + try { + // For more sophisticated audio streaming, you would need to properly concatenate + // the WAV files, but for now we'll use the last chunk as the complete audio + // since it should contain the entire response due to how the server is implementing it + const lastChunkIndex = streamingAudio.chunks.length - 1; + const audioData = streamingAudio.chunks[lastChunkIndex] || streamingAudio.chunks[0]; + + // Update the audio element with the complete audio + if (streamingAudio.audioElement) { + streamingAudio.audioElement.src = audioData; + + // Auto-play if enabled and not already playing + if (elements.autoPlayResponses && elements.autoPlayResponses.checked && + streamingAudio.audioElement.paused) { + streamingAudio.audioElement.play() + .catch(err => { + console.warn('Auto-play failed:', err); + addSystemMessage('Auto-play failed. 
Please click play to hear the response.'); + }); + } + } + + // Remove loading indicator and processing class + if (streamingAudio.messageElement) { + const loadingElement = streamingAudio.messageElement.querySelector('.loading-indicator'); + if (loadingElement) { + streamingAudio.messageElement.removeChild(loadingElement); + } + streamingAudio.messageElement.classList.remove('processing'); + } + + console.log('Audio response finalized and ready for playback'); + } catch (e) { + console.error('Error finalizing streaming audio:', e); + } + + // Reset streaming audio state + streamingAudio.chunks = []; + streamingAudio.totalChunks = 0; + streamingAudio.receivedChunks = 0; + streamingAudio.messageElement = null; + streamingAudio.audioElement = null; +} + +// Add CSS styles for new UI elements +document.addEventListener('DOMContentLoaded', function() { + // Add styles for processing state + const style = document.createElement('style'); + style.textContent = ` + .message.processing { + opacity: 0.8; + } + + .loading-indicator { + display: flex; + align-items: center; + margin-top: 8px; + font-size: 0.9em; + color: #666; + } + + .loading-spinner { + width: 16px; + height: 16px; + border: 2px solid #ddd; + border-top: 2px solid var(--primary-color); + border-radius: 50%; + margin-right: 8px; + animation: spin 1s linear infinite; + } + + @keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } + } + `; + document.head.appendChild(style); +}); + +// Initialize the application when DOM is fully loaded +document.addEventListener('DOMContentLoaded', initializeApp); + diff --git a/Backend/watermarking.py b/Backend/watermarking.py new file mode 100644 index 0000000..093962f --- /dev/null +++ b/Backend/watermarking.py @@ -0,0 +1,79 @@ +import argparse + +import silentcipher +import torch +import torchaudio + +# This watermark key is public, it is not secure. +# If using CSM 1B in another application, use a new private key and keep it secret. 
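For reference, the watermarking helpers defined in this file (`load_watermarker`, `watermark`, `verify`, and the public `CSM_1B_GH_WATERMARK` key) can be exercised with a short round-trip sketch. This is a hypothetical usage example, not part of the patch: it assumes `silentcipher`, `torch`, and `torchaudio` are installed, a CUDA device is available, the script runs from the `Backend/` directory so the module imports as `watermarking`, and `speech.wav` is a placeholder input clip.

```python
# Hypothetical usage sketch for Backend/watermarking.py (not part of the patch).
import torchaudio

from watermarking import (
    CSM_1B_GH_WATERMARK,
    load_watermarker,
    verify,
    watermark,
)

# Load the silentcipher model once; it expects a CUDA device by default.
watermarker = load_watermarker(device="cuda")

# "speech.wav" is a placeholder path; the helpers expect a mono 1-D tensor.
audio, sample_rate = torchaudio.load("speech.wav")
audio = audio.mean(dim=0)

# Embed the public CSM watermark key, then confirm it decodes back out.
marked_audio, marked_rate = watermark(watermarker, audio, sample_rate, CSM_1B_GH_WATERMARK)
assert verify(watermarker, marked_audio, marked_rate, CSM_1B_GH_WATERMARK)

# Save the watermarked clip (torchaudio.save expects a [channels, frames] tensor).
torchaudio.save("speech_watermarked.wav", marked_audio.unsqueeze(0), marked_rate)
```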
+CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201] + + +def cli_check_audio() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--audio_path", type=str, required=True) + args = parser.parse_args() + + check_audio_from_file(args.audio_path) + + +def load_watermarker(device: str = "cuda") -> silentcipher.server.Model: + model = silentcipher.get_model( + model_type="44.1k", + device=device, + ) + return model + + +@torch.inference_mode() +def watermark( + watermarker: silentcipher.server.Model, + audio_array: torch.Tensor, + sample_rate: int, + watermark_key: list[int], +) -> tuple[torch.Tensor, int]: + audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100) + encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36) + + output_sample_rate = min(44100, sample_rate) + encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate) + return encoded, output_sample_rate + + +@torch.inference_mode() +def verify( + watermarker: silentcipher.server.Model, + watermarked_audio: torch.Tensor, + sample_rate: int, + watermark_key: list[int], +) -> bool: + watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100) + result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True) + + is_watermarked = result["status"] + if is_watermarked: + is_csm_watermarked = result["messages"][0] == watermark_key + else: + is_csm_watermarked = False + + return is_watermarked and is_csm_watermarked + + +def check_audio_from_file(audio_path: str) -> None: + watermarker = load_watermarker(device="cuda") + + audio_array, sample_rate = load_audio(audio_path) + is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK) + + outcome = "Watermarked" if is_watermarked else "Not watermarked" + print(f"{outcome}: {audio_path}") + + +def load_audio(audio_path: str) -> tuple[torch.Tensor, int]: + audio_array, sample_rate = torchaudio.load(audio_path) + audio_array = audio_array.mean(dim=0) + return audio_array, int(sample_rate) + + +if __name__ == "__main__": + cli_check_audio() From afddf08d3ef6052b6dc5707f19a657ecc02b8f9d Mon Sep 17 00:00:00 2001 From: BGV <26331505+bgv2@users.noreply.github.com> Date: Sun, 30 Mar 2025 01:46:31 -0400 Subject: [PATCH 06/30] db progress --- React/src/pages/api/databaseStorage.ts | 40 +++++++++++++++++++++----- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/React/src/pages/api/databaseStorage.ts b/React/src/pages/api/databaseStorage.ts index a751f86..226625e 100644 --- a/React/src/pages/api/databaseStorage.ts +++ b/React/src/pages/api/databaseStorage.ts @@ -1,11 +1,37 @@ import { NextApiRequest, NextApiResponse } from "next"; -import { MongoClient } from "mongodb"; +import mongoose from "mongoose"; -export default function handler(req: NextApiRequest, res: NextApiResponse){ - if(req.method === 'POST') - const { codeword, contacts } = req.body; +const uri = process.env.MONGODB_URI || "mongodb://localhost:27017/mydatabase"; +const clientOptions = { serverApi: { version: "1" as const, strict: true, deprecationErrors: true } }; - try{ - - } +// Create a reusable connection function +async function connectToDatabase() { + if (mongoose.connection.readyState === 0) { + // Only connect if not already connected + await mongoose.connect(uri, clientOptions); + console.log("Connected to MongoDB!"); + } +} + +export default async function handler(req: 
NextApiRequest, res: NextApiResponse) { + try { + // Ensure the database is connected + await connectToDatabase(); + + if (req.method === 'POST') { + const { codeword, contacts } = req.body; + + // Perform database operations here + console.log("Codeword:", codeword); + console.log("Contacts:", contacts); + + res.status(200).json({ success: true, message: "Data saved successfully!" }); + } else { + res.setHeader('Allow', ['POST']); + res.status(405).end(`Method ${req.method} Not Allowed`); + } + } catch (error) { + console.error("Error:", error); + res.status(500).json({ success: false, error: "Internal Server Error" }); + } } \ No newline at end of file From 58920480d987548d49e65a91e5df3473080568c3 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 02:03:57 -0400 Subject: [PATCH 07/30] Demo Update 13 --- Backend/index.html | 529 ++++-------------------- Backend/voice-chat.js | 906 +++++++++++++++--------------------------- 2 files changed, 391 insertions(+), 1044 deletions(-) diff --git a/Backend/index.html b/Backend/index.html index 6169390..01bd5f7 100644 --- a/Backend/index.html +++ b/Backend/index.html @@ -168,6 +168,10 @@ animation: pulse 1.5s infinite; } + button.processing { + background-color: #ffa000; + } + @keyframes pulse { 0% { opacity: 1; } 50% { opacity: 0.7; } @@ -193,6 +197,63 @@ background-color: var(--success-color); } + /* Audio visualizer styles */ + .visualizer-container { + margin-top: 15px; + position: relative; + width: 100%; + height: 100px; + background-color: #000; + border-radius: 8px; + overflow: hidden; + } + + #audioVisualizer { + width: 100%; + height: 100%; + transition: opacity 0.3s; + } + + #visualizerLabel { + position: absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + color: rgba(255, 255, 255, 0.7); + font-size: 0.9em; + pointer-events: none; + transition: opacity 0.3s; + } + + .volume-meter { + height: 8px; + width: 100%; + background-color: #eee; + border-radius: 4px; + margin-top: 8px; + overflow: hidden; + } + + #volumeLevel { + height: 100%; + width: 0%; + background-color: var(--primary-color); + border-radius: 4px; + transition: width 0.1s ease, background-color 0.2s; + } + + .settings-toggles { + display: flex; + flex-direction: column; + gap: 12px; + } + + .toggle-switch { + display: flex; + align-items: center; + gap: 8px; + } + footer { margin-top: 30px; text-align: center; @@ -233,6 +294,15 @@ Clear + + +
+            [index.html additions: audio visualizer canvas with "Start speaking to see audio visualization" label, and a volume meter]

@@ -242,6 +312,10 @@

+            [index.html additions: settings toggle switches (auto-play responses, show visualizer)]

@@ -258,454 +336,7 @@

Powered by CSM 1B & Llama 3.2 | Whisper for speech recognition

- + + \ No newline at end of file diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js index 89ec71a..93bd434 100644 --- a/Backend/voice-chat.js +++ b/Backend/voice-chat.js @@ -1,7 +1,7 @@ /** - * Sesame AI Voice Chat Client + * CSM AI Voice Chat Client * - * A web client that connects to a Sesame AI voice chat server and enables + * A web client that connects to a CSM AI voice chat server and enables * real-time voice conversation with an AI assistant. */ @@ -13,19 +13,19 @@ const CLIENT_SILENCE_DURATION_MS = 750; // DOM elements const elements = { - conversation: null, - streamButton: null, - clearButton: null, - thresholdSlider: null, - thresholdValue: null, - visualizerCanvas: null, - visualizerLabel: null, - volumeLevel: null, - statusDot: null, - statusText: null, - speakerSelection: null, - autoPlayResponses: null, - showVisualizer: null + conversation: document.getElementById('conversation'), + streamButton: document.getElementById('streamButton'), + clearButton: document.getElementById('clearButton'), + thresholdSlider: document.getElementById('thresholdSlider'), + thresholdValue: document.getElementById('thresholdValue'), + visualizerCanvas: document.getElementById('audioVisualizer'), + visualizerLabel: document.getElementById('visualizerLabel'), + volumeLevel: document.getElementById('volumeLevel'), + statusDot: document.getElementById('statusDot'), + statusText: document.getElementById('statusText'), + speakerSelection: document.getElementById('speakerSelect'), + autoPlayResponses: document.getElementById('autoPlayResponses'), + showVisualizer: document.getElementById('showVisualizer') }; // Application state @@ -50,7 +50,7 @@ let canvasContext = null; let visualizerBufferLength = 0; let visualizerDataArray = null; -// New state variables to track incremental audio streaming +// Audio streaming state const streamingAudio = { messageElement: null, audioElement: null, @@ -58,9 +58,6 @@ const streamingAudio = { totalChunks: 0, receivedChunks: 0, text: '', - mediaSource: null, - sourceBuffer: null, - audioContext: null, complete: false }; @@ -79,25 +76,15 @@ function initializeApp() { setupVisualizer(); // Show welcome message - addSystemMessage('Welcome to Sesame AI Voice Chat! Click "Start Conversation" to begin.'); + addSystemMessage('Welcome to CSM Voice Chat! 
Click "Start Conversation" to begin.'); } // Initialize UI elements function initializeUIElements() { - // Store references to UI elements - elements.conversation = document.getElementById('conversation'); - elements.streamButton = document.getElementById('streamButton'); - elements.clearButton = document.getElementById('clearButton'); - elements.thresholdSlider = document.getElementById('thresholdSlider'); - elements.thresholdValue = document.getElementById('thresholdValue'); - elements.visualizerCanvas = document.getElementById('audioVisualizer'); - elements.visualizerLabel = document.getElementById('visualizerLabel'); - elements.volumeLevel = document.getElementById('volumeLevel'); - elements.statusDot = document.getElementById('statusDot'); - elements.statusText = document.getElementById('statusText'); - elements.speakerSelection = document.getElementById('speakerSelect'); // Changed to match HTML - elements.autoPlayResponses = document.getElementById('autoPlayResponses'); - elements.showVisualizer = document.getElementById('showVisualizer'); + // Update threshold display + if (elements.thresholdValue) { + elements.thresholdValue.textContent = state.silenceThreshold.toFixed(3); + } } // Setup Socket.IO connection @@ -106,36 +93,31 @@ function setupSocketConnection() { // Connection events state.socket.on('connect', () => { - console.log('Connected to server'); updateConnectionStatus(true); + addSystemMessage('Connected to server.'); }); state.socket.on('disconnect', () => { - console.log('Disconnected from server'); updateConnectionStatus(false); - - // Stop streaming if active - if (state.isStreaming) { - stopStreaming(false); - } + addSystemMessage('Disconnected from server.'); + stopStreaming(false); }); state.socket.on('error', (data) => { - console.error('Socket error:', data.message); addSystemMessage(`Error: ${data.message}`); + console.error('Server error:', data.message); }); // Register message handlers - state.socket.on('audio_response', handleAudioResponse); state.socket.on('transcription', handleTranscription); state.socket.on('context_updated', handleContextUpdate); state.socket.on('streaming_status', handleStreamingStatus); + state.socket.on('processing_status', handleProcessingStatus); - // New event handlers for incremental audio streaming + // Handlers for incremental audio streaming state.socket.on('audio_response_start', handleAudioResponseStart); state.socket.on('audio_response_chunk', handleAudioResponseChunk); state.socket.on('audio_response_complete', handleAudioResponseComplete); - state.socket.on('processing_status', handleProcessingStatus); } // Setup event listeners @@ -147,15 +129,19 @@ function setupEventListeners() { elements.clearButton.addEventListener('click', clearConversation); // Threshold slider - elements.thresholdSlider.addEventListener('input', updateThreshold); + if (elements.thresholdSlider) { + elements.thresholdSlider.addEventListener('input', updateThreshold); + } // Speaker selection elements.speakerSelection.addEventListener('change', () => { - state.currentSpeaker = parseInt(elements.speakerSelection.value, 10); + state.currentSpeaker = parseInt(elements.speakerSelection.value); }); // Visualizer toggle - elements.showVisualizer.addEventListener('change', toggleVisualizerVisibility); + if (elements.showVisualizer) { + elements.showVisualizer.addEventListener('change', toggleVisualizerVisibility); + } } // Setup audio visualizer @@ -168,20 +154,28 @@ function setupVisualizer() { elements.visualizerCanvas.width = 
elements.visualizerCanvas.offsetWidth; elements.visualizerCanvas.height = elements.visualizerCanvas.offsetHeight; - // Initialize the visualizer + // Initialize visualization data array + visualizerDataArray = new Uint8Array(128); + + // Start the visualizer animation drawVisualizer(); } // Update connection status UI function updateConnectionStatus(isConnected) { - elements.statusDot.classList.toggle('active', isConnected); - elements.statusText.textContent = isConnected ? 'Connected' : 'Disconnected'; + if (isConnected) { + elements.statusDot.classList.add('active'); + elements.statusText.textContent = 'Connected'; + } else { + elements.statusDot.classList.remove('active'); + elements.statusText.textContent = 'Disconnected'; + } } // Toggle streaming state function toggleStreaming() { if (state.isStreaming) { - stopStreaming(true); + stopStreaming(); } else { startStreaming(); } @@ -189,213 +183,132 @@ function toggleStreaming() { // Start streaming audio to the server function startStreaming() { - if (state.isStreaming) return; + if (!state.socket || !state.socket.connected) { + addSystemMessage('Not connected to server. Please refresh the page.'); + return; + } // Request microphone access navigator.mediaDevices.getUserMedia({ audio: true, video: false }) .then(stream => { - // Show processing state while setting up - elements.streamButton.innerHTML = ' Initializing...'; + state.isStreaming = true; + elements.streamButton.classList.add('recording'); + elements.streamButton.innerHTML = ' Stop Recording'; - // Create audio context - state.audioContext = new (window.AudioContext || window.webkitAudioContext)(); - - // Create microphone source + // Initialize Web Audio API + state.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 }); state.microphone = state.audioContext.createMediaStreamSource(stream); - - // Create analyser for visualizer state.analyser = state.audioContext.createAnalyser(); - state.analyser.fftSize = 256; + state.analyser.fftSize = 2048; + + // Setup analyzer for visualizer visualizerBufferLength = state.analyser.frequencyBinCount; visualizerDataArray = new Uint8Array(visualizerBufferLength); - // Connect microphone to analyser state.microphone.connect(state.analyser); - // Create script processor for audio processing - const bufferSize = 4096; - state.streamProcessor = state.audioContext.createScriptProcessor(bufferSize, 1, 1); + // Create processor node for audio data + const processorNode = state.audioContext.createScriptProcessor(4096, 1, 1); + processorNode.onaudioprocess = handleAudioProcess; + state.analyser.connect(processorNode); + processorNode.connect(state.audioContext.destination); + state.streamProcessor = processorNode; - // Set up audio processing callback - state.streamProcessor.onaudioprocess = handleAudioProcess; - - // Connect the processors - state.analyser.connect(state.streamProcessor); - state.streamProcessor.connect(state.audioContext.destination); - - // Update UI - state.isStreaming = true; - elements.streamButton.innerHTML = ' Listening...'; - elements.streamButton.classList.add('recording'); - - // Initialize energy window + state.silenceTimer = null; state.energyWindow = []; + state.isSpeaking = false; + + // Notify server + state.socket.emit('start_stream'); // Start volume meter updates state.volumeUpdateInterval = setInterval(updateVolumeMeter, 100); - // Start visualizer if enabled - if (elements.showVisualizer.checked && !state.visualizerAnimationFrame) { - drawVisualizer(); + // Make sure 
visualizer is visible if enabled + if (elements.showVisualizer && elements.showVisualizer.checked) { + elements.visualizerLabel.style.opacity = '0'; } - // Show starting message - addSystemMessage('Listening... Speak clearly into your microphone.'); - - // Notify the server that we're starting - state.socket.emit('stream_audio', { - audio: '', - speaker: state.currentSpeaker - }); + addSystemMessage('Recording started. Speak now...'); }) - .catch(err => { - console.error('Error accessing microphone:', err); - addSystemMessage(`Error: ${err.message}. Please make sure your microphone is connected and you've granted permission.`); - elements.streamButton.innerHTML = ' Start Conversation'; + .catch(error => { + console.error('Error accessing microphone:', error); + addSystemMessage('Could not access microphone. Please check permissions.'); }); } // Stop streaming audio function stopStreaming(notifyServer = true) { - if (!state.isStreaming) return; - - // Update UI first - elements.streamButton.innerHTML = ' Start Conversation'; - elements.streamButton.classList.remove('recording'); - elements.streamButton.classList.remove('processing'); - - // Stop volume meter updates - if (state.volumeUpdateInterval) { - clearInterval(state.volumeUpdateInterval); - state.volumeUpdateInterval = null; + if (state.isStreaming) { + state.isStreaming = false; + elements.streamButton.classList.remove('recording'); + elements.streamButton.classList.remove('processing'); + elements.streamButton.innerHTML = ' Start Conversation'; + + // Clean up audio resources + if (state.streamProcessor) { + state.streamProcessor.disconnect(); + state.streamProcessor = null; + } + + if (state.analyser) { + state.analyser.disconnect(); + state.analyser = null; + } + + if (state.microphone) { + state.microphone.disconnect(); + state.microphone = null; + } + + if (state.audioContext) { + state.audioContext.close().catch(err => console.warn('Error closing audio context:', err)); + state.audioContext = null; + } + + // Clear any pending silence timer + if (state.silenceTimer) { + clearTimeout(state.silenceTimer); + state.silenceTimer = null; + } + + // Clear volume meter updates + if (state.volumeUpdateInterval) { + clearInterval(state.volumeUpdateInterval); + state.volumeUpdateInterval = null; + + // Reset volume meter + if (elements.volumeLevel) { + elements.volumeLevel.style.width = '0%'; + } + } + + // Show visualizer label + if (elements.visualizerLabel) { + elements.visualizerLabel.style.opacity = '0.7'; + } + + // Notify server if needed + if (notifyServer && state.socket && state.socket.connected) { + state.socket.emit('stop_stream'); + } + + addSystemMessage('Recording stopped.'); } - - // Stop all audio processing - if (state.streamProcessor) { - state.streamProcessor.disconnect(); - state.streamProcessor = null; - } - - if (state.analyser) { - state.analyser.disconnect(); - } - - if (state.microphone) { - state.microphone.disconnect(); - } - - // Close audio context - if (state.audioContext && state.audioContext.state !== 'closed') { - state.audioContext.close().catch(err => console.warn('Error closing audio context:', err)); - } - - // Cleanup animation frames - if (state.visualizerAnimationFrame) { - cancelAnimationFrame(state.visualizerAnimationFrame); - state.visualizerAnimationFrame = null; - } - - // Reset state - state.isStreaming = false; - state.isSpeaking = false; - - // Notify the server - if (notifyServer && state.socket && state.socket.connected) { - state.socket.emit('stop_streaming', { - speaker: 
state.currentSpeaker - }); - } - - // Show message - addSystemMessage('Conversation paused. Click "Start Conversation" to resume.'); } // Handle audio processing function handleAudioProcess(event) { + if (!state.isStreaming) return; + const inputData = event.inputBuffer.getChannelData(0); - - // Calculate audio energy (volume level) const energy = calculateAudioEnergy(inputData); - - // Update energy window for averaging updateEnergyWindow(energy); - // Calculate average energy - const avgEnergy = calculateAverageEnergy(); + const averageEnergy = calculateAverageEnergy(); + const isSilent = averageEnergy < state.silenceThreshold; - // Determine if audio is silent - const isSilent = avgEnergy < state.silenceThreshold; - - // Debug logging only if significant changes in audio patterns - if (Math.random() < 0.05) { // Log only 5% of frames to avoid console spam - console.log(`Audio: len=${inputData.length}, energy=${energy.toFixed(4)}, avg=${avgEnergy.toFixed(4)}, silent=${isSilent}`); - } - - // Handle speech state based on silence handleSpeechState(isSilent); - - // Only send audio chunk if we detect speech - if (!isSilent) { - // Create a resampled version at 24kHz for the server - // Most WebRTC audio is 48kHz, but we want 24kHz for the model - const resampledData = downsampleBuffer(inputData, state.audioContext.sampleRate, 24000); - - // Send the audio chunk to the server - sendAudioChunk(resampledData, state.currentSpeaker); - } -} - -// Cleanup audio resources when done -function cleanupAudioResources() { - // Stop all audio processing - if (state.streamProcessor) { - state.streamProcessor.disconnect(); - state.streamProcessor = null; - } - - if (state.analyser) { - state.analyser.disconnect(); - state.analyser = null; - } - - if (state.microphone) { - state.microphone.disconnect(); - state.microphone = null; - } - - // Close audio context - if (state.audioContext && state.audioContext.state !== 'closed') { - state.audioContext.close().catch(err => console.warn('Error closing audio context:', err)); - } - - // Cancel all timers and animation frames - if (state.volumeUpdateInterval) { - clearInterval(state.volumeUpdateInterval); - state.volumeUpdateInterval = null; - } - - if (state.visualizerAnimationFrame) { - cancelAnimationFrame(state.visualizerAnimationFrame); - state.visualizerAnimationFrame = null; - } - - if (state.silenceTimer) { - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } -} - -// Clear conversation history -function clearConversation() { - if (elements.conversation) { - elements.conversation.innerHTML = ''; - addSystemMessage('Conversation cleared.'); - - // Notify server to clear context - if (state.socket && state.socket.connected) { - state.socket.emit('clear_context'); - } - } } // Calculate audio energy (volume) @@ -419,7 +332,7 @@ function updateEnergyWindow(energy) { function calculateAverageEnergy() { if (state.energyWindow.length === 0) return 0; - const sum = state.energyWindow.reduce((a, b) => a + b, 0); + const sum = state.energyWindow.reduce((acc, val) => acc + val, 0); return sum / state.energyWindow.length; } @@ -431,13 +344,12 @@ function updateThreshold() { // Update the volume meter display function updateVolumeMeter() { - if (!state.isStreaming || !state.energyWindow.length) return; + if (!state.isStreaming || !state.energyWindow.length || !elements.volumeLevel) return; const avgEnergy = calculateAverageEnergy(); // Scale energy to percentage (0-100) - // Typically, energy values will be very small (e.g., 0.001 to 0.1) - // So we 
multiply by a factor to make it more visible + // Energy values are typically very small (e.g., 0.001 to 0.1) const scaleFactor = 1000; const percentage = Math.min(100, Math.max(0, avgEnergy * scaleFactor)); @@ -456,197 +368,134 @@ function updateVolumeMeter() { // Handle speech/silence state transitions function handleSpeechState(isSilent) { - if (state.isSpeaking && isSilent) { - // Transition from speaking to silence - if (!state.silenceTimer) { - state.silenceTimer = setTimeout(() => { - // Only consider it a real silence after a certain duration - // This prevents detecting brief pauses as the end of speech - state.isSpeaking = false; + if (state.isSpeaking) { + if (isSilent) { + // User was speaking but now is silent + if (!state.silenceTimer) { + state.silenceTimer = setTimeout(() => { + // Silence lasted long enough, consider speech done + if (state.isSpeaking) { + state.isSpeaking = false; + + // Get the current audio data and send it + const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max + state.analyser.getFloatTimeDomainData(audioBuffer); + + // Create WAV blob + const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate); + + // Convert to base64 + const reader = new FileReader(); + reader.onloadend = function() { + sendAudioChunk(reader.result, state.currentSpeaker); + }; + reader.readAsDataURL(wavBlob); + + // Update button state + elements.streamButton.classList.add('processing'); + elements.streamButton.innerHTML = ' Processing...'; + + addSystemMessage('Processing your message...'); + } + }, CLIENT_SILENCE_DURATION_MS); + } + } else { + // User is still speaking, reset silence timer + if (state.silenceTimer) { + clearTimeout(state.silenceTimer); state.silenceTimer = null; - }, CLIENT_SILENCE_DURATION_MS); + } + } + } else { + if (!isSilent) { + // User started speaking + state.isSpeaking = true; + if (state.silenceTimer) { + clearTimeout(state.silenceTimer); + state.silenceTimer = null; + } } - } else if (state.silenceTimer && !isSilent) { - // User started speaking again, cancel the silence timer - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } - - // Update speaking state for non-silent audio - if (!isSilent) { - state.isSpeaking = true; } } // Send audio chunk to server function sendAudioChunk(audioData, speaker) { - if (!state.socket || !state.socket.connected) { - console.warn('Socket not connected'); - return; - } - - console.log(`Preparing audio chunk: length=${audioData.length}, speaker=${speaker}`); - - // Check for NaN or invalid values - let hasInvalidValues = false; - for (let i = 0; i < audioData.length; i++) { - if (isNaN(audioData[i]) || !isFinite(audioData[i])) { - hasInvalidValues = true; - console.warn(`Invalid audio value at index ${i}: ${audioData[i]}`); - break; - } - } - - if (hasInvalidValues) { - console.warn('Audio data contains invalid values. 
Creating silent audio.'); - audioData = new Float32Array(audioData.length).fill(0); - } - - try { - // Create WAV blob - const wavData = createWavBlob(audioData, 24000); - console.log(`WAV blob created: ${wavData.size} bytes`); - - const reader = new FileReader(); - - reader.onloadend = function() { - try { - // Get base64 data - const base64data = reader.result; - console.log(`Base64 data created: ${base64data.length} bytes`); - - // Send to server - state.socket.emit('stream_audio', { - audio: base64data, - speaker: speaker - }); - console.log('Audio chunk sent to server'); - } catch (err) { - console.error('Error preparing audio data:', err); - } - }; - - reader.onerror = function() { - console.error('Error reading audio data as base64'); - }; - - reader.readAsDataURL(wavData); - } catch (err) { - console.error('Error creating WAV data:', err); + if (state.socket && state.socket.connected) { + state.socket.emit('audio_chunk', { + audio: audioData, + speaker: speaker + }); } } -// Create WAV blob from audio data with improved error handling +// Create WAV blob from audio data function createWavBlob(audioData, sampleRate) { - // Validate input - if (!audioData || audioData.length === 0) { - console.warn('Empty audio data provided to createWavBlob'); - audioData = new Float32Array(1024).fill(0); // Create 1024 samples of silence + const numChannels = 1; + const bitsPerSample = 16; + const bytesPerSample = bitsPerSample / 8; + + // Create buffer for WAV file + const buffer = new ArrayBuffer(44 + audioData.length * bytesPerSample); + const view = new DataView(buffer); + + // Write WAV header + // "RIFF" chunk descriptor + writeString(view, 0, 'RIFF'); + view.setUint32(4, 36 + audioData.length * bytesPerSample, true); + writeString(view, 8, 'WAVE'); + + // "fmt " sub-chunk + writeString(view, 12, 'fmt '); + view.setUint32(16, 16, true); // subchunk1size + view.setUint16(20, 1, true); // audio format (PCM) + view.setUint16(22, numChannels, true); + view.setUint32(24, sampleRate, true); + view.setUint32(28, sampleRate * numChannels * bytesPerSample, true); // byte rate + view.setUint16(32, numChannels * bytesPerSample, true); // block align + view.setUint16(34, bitsPerSample, true); + + // "data" sub-chunk + writeString(view, 36, 'data'); + view.setUint32(40, audioData.length * bytesPerSample, true); + + // Write audio data + const audioDataStart = 44; + for (let i = 0; i < audioData.length; i++) { + const sample = Math.max(-1, Math.min(1, audioData[i])); + const value = sample < 0 ? sample * 0x8000 : sample * 0x7FFF; + view.setInt16(audioDataStart + i * bytesPerSample, value, true); } - // Function to convert Float32Array to Int16Array for WAV format - function floatTo16BitPCM(output, offset, input) { - for (let i = 0; i < input.length; i++, offset += 2) { - // Ensure values are in -1 to 1 range - const s = Math.max(-1, Math.min(1, input[i])); - // Convert to 16-bit PCM - output.setInt16(offset, s < 0 ? 
s * 0x8000 : s * 0x7FFF, true); - } + return new Blob([buffer], { type: 'audio/wav' }); +} + +// Helper function to write strings to DataView +function writeString(view, offset, string) { + for (let i = 0; i < string.length; i++) { + view.setUint8(offset + i, string.charCodeAt(i)); } - - // Create WAV header - function writeString(view, offset, string) { - for (let i = 0; i < string.length; i++) { - view.setUint8(offset + i, string.charCodeAt(i)); - } - } - - try { - // Create WAV file with header - careful with buffer sizes - const buffer = new ArrayBuffer(44 + audioData.length * 2); - const view = new DataView(buffer); - - // RIFF identifier - writeString(view, 0, 'RIFF'); - - // File length (will be filled later) - view.setUint32(4, 36 + audioData.length * 2, true); - - // WAVE identifier - writeString(view, 8, 'WAVE'); - - // fmt chunk identifier - writeString(view, 12, 'fmt '); - - // fmt chunk length - view.setUint32(16, 16, true); - - // Sample format (1 is PCM) - view.setUint16(20, 1, true); - - // Mono channel - view.setUint16(22, 1, true); - - // Sample rate - view.setUint32(24, sampleRate, true); - - // Byte rate (sample rate * block align) - view.setUint32(28, sampleRate * 2, true); - - // Block align (channels * bytes per sample) - view.setUint16(32, 2, true); - - // Bits per sample - view.setUint16(34, 16, true); - - // data chunk identifier - writeString(view, 36, 'data'); - - // data chunk length - view.setUint32(40, audioData.length * 2, true); - - // Write the PCM samples - floatTo16BitPCM(view, 44, audioData); - - // Create and return blob - return new Blob([view], { type: 'audio/wav' }); - } catch (err) { - console.error('Error in createWavBlob:', err); - - // Create a minimal valid WAV file with silence as fallback - const fallbackSamples = new Float32Array(1024).fill(0); - const fallbackBuffer = new ArrayBuffer(44 + fallbackSamples.length * 2); - const fallbackView = new DataView(fallbackBuffer); - - writeString(fallbackView, 0, 'RIFF'); - fallbackView.setUint32(4, 36 + fallbackSamples.length * 2, true); - writeString(fallbackView, 8, 'WAVE'); - writeString(fallbackView, 12, 'fmt '); - fallbackView.setUint32(16, 16, true); - fallbackView.setUint16(20, 1, true); - fallbackView.setUint16(22, 1, true); - fallbackView.setUint32(24, sampleRate, true); - fallbackView.setUint32(28, sampleRate * 2, true); - fallbackView.setUint16(32, 2, true); - fallbackView.setUint16(34, 16, true); - writeString(fallbackView, 36, 'data'); - fallbackView.setUint32(40, fallbackSamples.length * 2, true); - floatTo16BitPCM(fallbackView, 44, fallbackSamples); - - return new Blob([fallbackView], { type: 'audio/wav' }); +} + +// Clear conversation history +function clearConversation() { + elements.conversation.innerHTML = ''; + if (state.socket && state.socket.connected) { + state.socket.emit('clear_context'); } + addSystemMessage('Conversation cleared.'); } // Draw audio visualizer function drawVisualizer() { - if (!canvasContext) { + if (!canvasContext || !elements.visualizerCanvas) { + state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer); return; } state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer); - // Skip drawing if visualizer is hidden - if (!elements.showVisualizer.checked) { + // Skip drawing if visualizer is hidden or not enabled + if (elements.showVisualizer && !elements.showVisualizer.checked) { if (elements.visualizerCanvas.style.opacity !== '0') { elements.visualizerCanvas.style.opacity = '0'; } @@ -684,7 +533,6 @@ function drawVisualizer() { 
const value = visualizerDataArray[index]; // Use logarithmic scale for better audio visualization - // This makes low values more visible while still maintaining full range const logFactor = 20; const scaledValue = Math.log(1 + (value / 255) * logFactor) / Math.log(1 + logFactor); const barHeight = scaledValue * height; @@ -696,13 +544,13 @@ function drawVisualizer() { // Create color gradient based on frequency and amplitude const hue = i / barCount * 360; // Full color spectrum const saturation = 80 + (value / 255 * 20); // Higher values more saturated - const lightness = 40 + (value / 255 * 20); // Dynamic brightness based on amplitude + const lightness = 40 + (value / 255 * 20); // Dynamic brightness // Draw main bar canvasContext.fillStyle = `hsl(${hue}, ${saturation}%, ${lightness}%)`; canvasContext.fillRect(x, y, barWidth, barHeight); - // Add reflection effect + // Add highlight effect if (barHeight > 5) { const gradient = canvasContext.createLinearGradient( x, y, @@ -713,255 +561,123 @@ function drawVisualizer() { canvasContext.fillStyle = gradient; canvasContext.fillRect(x, y, barWidth, barHeight * 0.5); - // Add highlight on top of the bar for better 3D effect + // Add highlight on top of the bar canvasContext.fillStyle = `hsla(${hue}, ${saturation - 20}%, ${lightness + 30}%, 0.7)`; canvasContext.fillRect(x, y, barWidth, 2); } } - - // Show/hide the label - elements.visualizerLabel.style.opacity = (state.isStreaming) ? '0' : '0.7'; } // Toggle visualizer visibility function toggleVisualizerVisibility() { const isVisible = elements.showVisualizer.checked; elements.visualizerCanvas.style.opacity = isVisible ? '1' : '0'; - - if (isVisible && state.isStreaming && !state.visualizerAnimationFrame) { - drawVisualizer(); - } -} - -// Handle audio response from server -function handleAudioResponse(data) { - console.log('Received audio response'); - - // Create message container - const messageElement = document.createElement('div'); - messageElement.className = 'message ai'; - - // Add text content if available - if (data.text) { - const textElement = document.createElement('p'); - textElement.textContent = data.text; - messageElement.appendChild(textElement); - } - - // Create and configure audio element - const audioElement = document.createElement('audio'); - audioElement.controls = true; - audioElement.className = 'audio-player'; - - // Set audio source - const audioSource = document.createElement('source'); - audioSource.src = data.audio; - audioSource.type = 'audio/wav'; - - // Add fallback text - audioElement.textContent = 'Your browser does not support the audio element.'; - - // Assemble audio element - audioElement.appendChild(audioSource); - messageElement.appendChild(audioElement); - - // Add timestamp - const timeElement = document.createElement('span'); - timeElement.className = 'message-time'; - timeElement.textContent = new Date().toLocaleTimeString(); - messageElement.appendChild(timeElement); - - // Add to conversation - elements.conversation.appendChild(messageElement); - - // Auto-scroll to bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - // Auto-play if enabled - if (elements.autoPlayResponses.checked) { - audioElement.play() - .catch(err => { - console.warn('Auto-play failed:', err); - addSystemMessage('Auto-play failed. 
Please click play to hear the response.'); - }); - } - - // Re-enable stream button after processing is complete - if (state.isStreaming) { - elements.streamButton.innerHTML = ' Listening...'; - elements.streamButton.classList.add('recording'); - elements.streamButton.classList.remove('processing'); - } } // Handle transcription response from server function handleTranscription(data) { - console.log('Received transcription:', data.text); - - // Create message element - const messageElement = document.createElement('div'); - messageElement.className = 'message user'; - - // Add text content - const textElement = document.createElement('p'); - textElement.textContent = data.text; - messageElement.appendChild(textElement); - - // Add timestamp - const timeElement = document.createElement('span'); - timeElement.className = 'message-time'; - timeElement.textContent = new Date().toLocaleTimeString(); - messageElement.appendChild(timeElement); - - // Add to conversation - elements.conversation.appendChild(messageElement); - - // Auto-scroll to bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; + const speaker = data.speaker === 0 ? 'user' : 'ai'; + addMessage(data.text, speaker); } // Handle context update from server function handleContextUpdate(data) { - console.log('Context updated:', data.message); + if (data.status === 'cleared') { + elements.conversation.innerHTML = ''; + addSystemMessage('Conversation context cleared.'); + } } // Handle streaming status updates from server function handleStreamingStatus(data) { - console.log('Streaming status:', data.status); - - if (data.status === 'stopped') { - // Reset UI if needed - if (state.isStreaming) { - stopStreaming(false); // Don't send to server since this came from server - } + if (data.status === 'active') { + console.log('Server acknowledged streaming is active'); + } else if (data.status === 'inactive') { + console.log('Server acknowledged streaming is inactive'); } } -// Add a system message to the conversation -function addSystemMessage(message) { - const messageElement = document.createElement('div'); - messageElement.className = 'message system'; - messageElement.textContent = message; - elements.conversation.appendChild(messageElement); - - // Auto-scroll to bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; -} - -// Downsample audio buffer to target sample rate -function downsampleBuffer(buffer, originalSampleRate, targetSampleRate) { - if (originalSampleRate === targetSampleRate) { - return buffer; - } - - const ratio = originalSampleRate / targetSampleRate; - const newLength = Math.round(buffer.length / ratio); - const result = new Float32Array(newLength); - - for (let i = 0; i < newLength; i++) { - const pos = Math.round(i * ratio); - result[i] = buffer[pos]; - } - - return result; -} - // Handle processing status updates function handleProcessingStatus(data) { - console.log('Processing status update:', data); - - // Show processing status in UI - if (data.status === 'generating_audio') { - elements.streamButton.innerHTML = ' Processing...'; - elements.streamButton.classList.add('processing'); - elements.streamButton.classList.remove('recording'); - - // Show message to user - addSystemMessage(data.message || 'Processing your request...'); + switch (data.status) { + case 'transcribing': + addSystemMessage('Transcribing your message...'); + break; + case 'generating': + addSystemMessage('Generating response...'); + break; + case 'synthesizing': + 
addSystemMessage('Synthesizing voice...'); + break; } } // Handle the start of an audio streaming response function handleAudioResponseStart(data) { - console.log('Audio response starting:', data); + console.log(`Expecting ${data.total_chunks} audio chunks`); - // Reset streaming audio state + // Reset streaming state streamingAudio.chunks = []; streamingAudio.totalChunks = data.total_chunks; streamingAudio.receivedChunks = 0; streamingAudio.text = data.text; streamingAudio.complete = false; - - // Create message container now, so we can update it as chunks arrive - const messageElement = document.createElement('div'); - messageElement.className = 'message ai processing'; - - // Add text content if available - if (data.text) { - const textElement = document.createElement('p'); - textElement.textContent = data.text; - messageElement.appendChild(textElement); - } - - // Create audio element (will be populated as chunks arrive) - const audioElement = document.createElement('audio'); - audioElement.controls = true; - audioElement.className = 'audio-player'; - audioElement.textContent = 'Audio is being generated...'; - messageElement.appendChild(audioElement); - - // Add timestamp - const timeElement = document.createElement('span'); - timeElement.className = 'message-time'; - timeElement.textContent = new Date().toLocaleTimeString(); - messageElement.appendChild(timeElement); - - // Add loading indicator - const loadingElement = document.createElement('div'); - loadingElement.className = 'loading-indicator'; - loadingElement.innerHTML = '
Generating audio response...'; - messageElement.appendChild(loadingElement); - - // Add to conversation - elements.conversation.appendChild(messageElement); - - // Auto-scroll to bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - // Store elements for later updates - streamingAudio.messageElement = messageElement; - streamingAudio.audioElement = audioElement; } // Handle an incoming audio chunk function handleAudioResponseChunk(data) { - console.log(`Received audio chunk ${data.chunk_index + 1}/${data.total_chunks}`); + // Create or update audio element for playback + const audioElement = document.createElement('audio'); + if (elements.autoPlayResponses.checked) { + audioElement.autoplay = true; + } + audioElement.controls = true; + audioElement.className = 'audio-player'; + audioElement.src = data.chunk; // Store the chunk - streamingAudio.chunks[data.chunk_index] = data.audio; + streamingAudio.chunks[data.chunk_index] = data.chunk; streamingAudio.receivedChunks++; - // Update progress in the UI - if (streamingAudio.messageElement) { - const loadingElement = streamingAudio.messageElement.querySelector('.loading-indicator span'); - if (loadingElement) { - loadingElement.textContent = `Generating audio response... ${Math.round((streamingAudio.receivedChunks / data.total_chunks) * 100)}%`; + // Add to the conversation + const messages = elements.conversation.querySelectorAll('.message.ai'); + if (messages.length > 0) { + const lastAiMessage = messages[messages.length - 1]; + + // Replace existing audio player if there is one + const existingPlayer = lastAiMessage.querySelector('.audio-player'); + if (existingPlayer) { + lastAiMessage.replaceChild(audioElement, existingPlayer); + } else { + lastAiMessage.appendChild(audioElement); } + } else { + // Create a new message for the AI response + const aiMessage = document.createElement('div'); + aiMessage.className = 'message ai'; + + if (streamingAudio.text) { + const textElement = document.createElement('p'); + textElement.textContent = streamingAudio.text; + aiMessage.appendChild(textElement); + } + + aiMessage.appendChild(audioElement); + elements.conversation.appendChild(aiMessage); } - // If this is the first chunk, start playing it immediately for faster response - if (data.chunk_index === 0 && streamingAudio.audioElement && elements.autoPlayResponses && elements.autoPlayResponses.checked) { - try { - streamingAudio.audioElement.src = data.audio; - streamingAudio.audioElement.play().catch(err => console.warn('Auto-play failed:', err)); - } catch (e) { - console.error('Error playing first chunk:', e); - } - } + // Auto-scroll + elements.conversation.scrollTop = elements.conversation.scrollHeight; - // If this is the last chunk or we've received all chunks, finalize the audio - if (data.is_last || streamingAudio.receivedChunks >= data.total_chunks) { - finalizeStreamingAudio(); + // If this is the last chunk or we've received all expected chunks + if (data.is_last || streamingAudio.receivedChunks >= streamingAudio.totalChunks) { + streamingAudio.complete = true; + + // Reset stream button if we're still streaming + if (state.isStreaming) { + elements.streamButton.classList.remove('processing'); + elements.streamButton.innerHTML = ' Listening...'; + } } } From d2f5df4e154885e0bed778865d61765e99dafe38 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 02:06:42 -0400 Subject: [PATCH 08/30] Demo Fixes 3 --- Backend/server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Backend/server.py 
b/Backend/server.py index ef9fbda..cb135f6 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -105,6 +105,10 @@ class Conversation: def index(): return send_from_directory('.', 'index.html') +@app.route('/voice-chat.js') +def voice_chat_js(): + return send_from_directory('.', 'voice-chat.js') + @app.route('/api/health') def health_check(): return jsonify({ From 215b420579bff393cfe8f5418a5d1f6b8843380c Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 02:11:51 -0400 Subject: [PATCH 09/30] Demo Update 14 --- Backend/server.py | 6 +++--- Backend/voice-chat.js | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index cb135f6..78254e4 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -119,7 +119,7 @@ def health_check(): # Socket event handlers @socketio.on('connect') -def handle_connect(): +def handle_connect(auth=None): session_id = request.sid logger.info(f"Client connected: {session_id}") @@ -137,9 +137,9 @@ def handle_connect(): emit('connection_status', {'status': 'connected'}) @socketio.on('disconnect') -def handle_disconnect(): +def handle_disconnect(reason=None): session_id = request.sid - logger.info(f"Client disconnected: {session_id}") + logger.info(f"Client disconnected: {session_id}. Reason: {reason}") # Cleanup if session_id in active_conversations: diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js index 93bd434..2c9e949 100644 --- a/Backend/voice-chat.js +++ b/Backend/voice-chat.js @@ -574,6 +574,41 @@ function toggleVisualizerVisibility() { elements.visualizerCanvas.style.opacity = isVisible ? '1' : '0'; } +// Add a message to the conversation +function addMessage(text, type) { + if (!elements.conversation) return; + + const messageDiv = document.createElement('div'); + messageDiv.className = `message ${type}`; + + const textElement = document.createElement('p'); + textElement.textContent = text; + messageDiv.appendChild(textElement); + + elements.conversation.appendChild(messageDiv); + + // Auto-scroll to the bottom + elements.conversation.scrollTop = elements.conversation.scrollHeight; + + return messageDiv; +} + +// Add a system message to the conversation +function addSystemMessage(text) { + if (!elements.conversation) return; + + const messageDiv = document.createElement('div'); + messageDiv.className = 'message system'; + messageDiv.textContent = text; + + elements.conversation.appendChild(messageDiv); + + // Auto-scroll to the bottom + elements.conversation.scrollTop = elements.conversation.scrollHeight; + + return messageDiv; +} + // Handle transcription response from server function handleTranscription(data) { const speaker = data.speaker === 0 ? 
'user' : 'ai'; From 9c38d30932405f9edcd76d3dbb37f1b190f59fed Mon Sep 17 00:00:00 2001 From: BGV <26331505+bgv2@users.noreply.github.com> Date: Sun, 30 Mar 2025 02:12:44 -0400 Subject: [PATCH 10/30] changed message --- React/src/app/call/page.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/React/src/app/call/page.tsx b/React/src/app/call/page.tsx index 5a76e36..1597927 100644 --- a/React/src/app/call/page.tsx +++ b/React/src/app/call/page.tsx @@ -21,7 +21,7 @@ const CallPage = () => { "Content-Type": "application/json", }, body: JSON.stringify({ - message: `yo i need help`, + message: `John Smith needs help.`, }), }); From 7cc033e74b5f217f744e90bda8484898f32212ee Mon Sep 17 00:00:00 2001 From: BGV <26331505+bgv2@users.noreply.github.com> Date: Sun, 30 Mar 2025 02:18:23 -0400 Subject: [PATCH 11/30] stuff --- React/src/pages/api/databaseStorage.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/React/src/pages/api/databaseStorage.ts b/React/src/pages/api/databaseStorage.ts index 226625e..5ccc7a8 100644 --- a/React/src/pages/api/databaseStorage.ts +++ b/React/src/pages/api/databaseStorage.ts @@ -18,10 +18,13 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) // Ensure the database is connected await connectToDatabase(); + if (req.method === 'POST') { - const { codeword, contacts } = req.body; + const { email, codeword, contacts } = req.body; // Perform database operations here + + console.log("Codeword:", codeword); console.log("Contacts:", contacts); From 60db42f98e428a1eb513c4fe5625f8fbd73ef95c Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 02:22:08 -0400 Subject: [PATCH 12/30] Demo Update 16 --- Backend/server.py | 75 +++++++++++++++++++++++++------------------ Backend/voice-chat.js | 8 ++++- 2 files changed, 50 insertions(+), 33 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 78254e4..5fbe12a 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -87,6 +87,7 @@ class Conversation: self.session_id = session_id self.segments: List[Segment] = [] self.current_speaker = 0 + self.ai_speaker_id = 1 # Add this property self.last_activity = time.time() self.is_processing = False @@ -209,7 +210,9 @@ def process_audio_queue(session_id, q): continue except Exception as e: logger.error(f"Error processing audio for {session_id}: {str(e)}") - socketio.emit('error', {'message': str(e)}, room=session_id) + # Create an app context for the socket emit + with app.app_context(): + socketio.emit('error', {'message': str(e)}, room=session_id) finally: logger.info(f"Ending processing thread for session: {session_id}") # Clean up when thread is done @@ -222,7 +225,8 @@ def process_audio_queue(session_id, q): def process_audio_and_respond(session_id, data): """Process audio data and generate a response""" if models.generator is None or models.asr is None or models.llm is None: - socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) + with app.app_context(): + socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) return conversation = active_conversations[session_id] @@ -260,7 +264,8 @@ def process_audio_and_respond(session_id, data): ) # Transcribe audio - socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) + with app.app_context(): + socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) # Use the ASR pipeline to transcribe transcription_result = models.asr( @@ -271,7 
+276,8 @@ def process_audio_and_respond(session_id, data): # If no text was recognized, don't process further if not user_text: - socketio.emit('error', {'message': 'No speech detected'}, room=session_id) + with app.app_context(): + socketio.emit('error', {'message': 'No speech detected'}, room=session_id) return # Add the user's message to conversation history @@ -282,13 +288,15 @@ def process_audio_and_respond(session_id, data): ) # Send transcription to client - socketio.emit('transcription', { - 'text': user_text, - 'speaker': speaker_id - }, room=session_id) + with app.app_context(): + socketio.emit('transcription', { + 'text': user_text, + 'speaker': speaker_id + }, room=session_id) # Generate AI response using Llama - socketio.emit('processing_status', {'status': 'generating'}, room=session_id) + with app.app_context(): + socketio.emit('processing_status', {'status': 'generating'}, room=session_id) # Create prompt from conversation history conversation_history = "" @@ -319,22 +327,23 @@ def process_audio_and_respond(session_id, data): ).strip() # Synthesize speech - socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id) + with app.app_context(): + socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id) + + # Start sending the audio response + socketio.emit('audio_response_start', { + 'text': response_text, + 'total_chunks': 1, + 'chunk_index': 0 + }, room=session_id) - # Generate audio with CSM + # Define AI speaker ID (use a consistent value for the AI's voice) ai_speaker_id = 1 # Use speaker 1 for AI responses - # Start sending the audio response - socketio.emit('audio_response_start', { - 'text': response_text, - 'total_chunks': 1, - 'chunk_index': 0 - }, room=session_id) - # Generate audio audio_tensor = models.generator.generate( text=response_text, - speaker=ai_speaker_id, + speaker=ai_speaker_id, # Use the local variable instead of conversation.ai_speaker_id context=conversation.get_context(), max_audio_length_ms=10_000, temperature=0.9 @@ -343,7 +352,7 @@ def process_audio_and_respond(session_id, data): # Add AI response to conversation history ai_segment = conversation.add_segment( text=response_text, - speaker=ai_speaker_id, + speaker=ai_speaker_id, # Also use the local variable here audio=audio_tensor ) @@ -362,17 +371,18 @@ def process_audio_and_respond(session_id, data): audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}" # Send audio chunk to client - socketio.emit('audio_response_chunk', { - 'chunk': audio_base64, - 'chunk_index': 0, - 'total_chunks': 1, - 'is_last': True - }, room=session_id) - - # Signal completion - socketio.emit('audio_response_complete', { - 'text': response_text - }, room=session_id) + with app.app_context(): + socketio.emit('audio_response_chunk', { + 'chunk': audio_base64, + 'chunk_index': 0, + 'total_chunks': 1, + 'is_last': True + }, room=session_id) + + # Signal completion + socketio.emit('audio_response_complete', { + 'text': response_text + }, room=session_id) finally: # Clean up temp file @@ -381,7 +391,8 @@ def process_audio_and_respond(session_id, data): except Exception as e: logger.error(f"Error processing audio: {str(e)}") - socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) + with app.app_context(): + socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) finally: # Reset processing flag conversation.is_processing = False diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js index 2c9e949..109f426 100644 
--- a/Backend/voice-chat.js +++ b/Backend/voice-chat.js @@ -42,7 +42,8 @@ const state = { silenceTimer: null, volumeUpdateInterval: null, visualizerAnimationFrame: null, - currentSpeaker: 0 + currentSpeaker: 0, + aiSpeakerId: 1 // Define the AI's speaker ID to match server.py }; // Visualizer variables @@ -674,10 +675,14 @@ function handleAudioResponseChunk(data) { streamingAudio.chunks[data.chunk_index] = data.chunk; streamingAudio.receivedChunks++; + // Store audio element reference for later use + streamingAudio.audioElement = audioElement; + // Add to the conversation const messages = elements.conversation.querySelectorAll('.message.ai'); if (messages.length > 0) { const lastAiMessage = messages[messages.length - 1]; + streamingAudio.messageElement = lastAiMessage; // Replace existing audio player if there is one const existingPlayer = lastAiMessage.querySelector('.audio-player'); @@ -690,6 +695,7 @@ function handleAudioResponseChunk(data) { // Create a new message for the AI response const aiMessage = document.createElement('div'); aiMessage.className = 'message ai'; + streamingAudio.messageElement = aiMessage; if (streamingAudio.text) { const textElement = document.createElement('p'); From 5431e1fa5ee78420f2e0546c6ae4e807b4d68f5b Mon Sep 17 00:00:00 2001 From: BGV <26331505+bgv2@users.noreply.github.com> Date: Sun, 30 Mar 2025 02:23:30 -0400 Subject: [PATCH 13/30] api route --- React/src/pages/api/databaseStorage.ts | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/React/src/pages/api/databaseStorage.ts b/React/src/pages/api/databaseStorage.ts index 5ccc7a8..061341a 100644 --- a/React/src/pages/api/databaseStorage.ts +++ b/React/src/pages/api/databaseStorage.ts @@ -23,7 +23,18 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) const { email, codeword, contacts } = req.body; // Perform database operations here - + // query database to see if document with email exists + const existingUser = await mongoose.model('User').findOne({ email }); + if (existingUser) { + // If user exists, update their codeword and contacts + await mongoose.model('User').updateOne({ email }, { codeword, contacts }); + } else { + // If user does not exist, create a new user + const User = mongoose.model('User'); + const newUser = new User({ email, codeword, contacts }); + await newUser.save(); + } + console.log("Codeword:", codeword); console.log("Contacts:", contacts); From fbb3ff40069d14bfb3f9ced6252907c03b4e43dc Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 02:29:11 -0400 Subject: [PATCH 14/30] Demo Update 17 --- Backend/server.py | 85 ++++++++++++++++++++++++++++++++++--------- Backend/voice-chat.js | 52 ++++++++++++++++---------- 2 files changed, 101 insertions(+), 36 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 5fbe12a..52a85e9 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -61,22 +61,42 @@ def load_models(): global models logger.info("Loading CSM 1B model...") - models.generator = load_csm_1b(device=DEVICE) + try: + models.generator = load_csm_1b(device=DEVICE) + logger.info("CSM 1B model loaded successfully") + socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) + except Exception as e: + logger.error(f"Error loading CSM 1B model: {str(e)}") + socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) logger.info("Loading ASR pipeline...") - models.asr = pipeline( - "automatic-speech-recognition", - model="openai/whisper-small", - 
device=DEVICE - ) + try: + models.asr = pipeline( + "automatic-speech-recognition", + model="openai/whisper-small", + device=DEVICE, + language="en", # Force English language + return_attention_mask=True # Add attention mask + ) + logger.info("ASR pipeline loaded successfully") + socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) + except Exception as e: + logger.error(f"Error loading ASR pipeline: {str(e)}") + socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) logger.info("Loading Llama 3.2 model...") - models.llm = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-3.2-1B", - device_map=DEVICE, - torch_dtype=torch.bfloat16 - ) - models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + try: + models.llm = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B", + device_map=DEVICE, + torch_dtype=torch.bfloat16 + ) + models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + logger.info("Llama 3.2 model loaded successfully") + socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) + except Exception as e: + logger.error(f"Error loading Llama 3.2 model: {str(e)}") + socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) # Load models in a background thread threading.Thread(target=load_models, daemon=True).start() @@ -118,6 +138,20 @@ def health_check(): "models_loaded": models.generator is not None and models.llm is not None }) +# Add a system status endpoint +@app.route('/api/status') +def system_status(): + return jsonify({ + "status": "ok", + "cuda_available": torch.cuda.is_available(), + "device": DEVICE, + "models": { + "generator": models.generator is not None, + "asr": models.asr is not None, + "llm": models.llm is not None + } + }) + # Socket event handlers @socketio.on('connect') def handle_connect(auth=None): @@ -225,10 +259,12 @@ def process_audio_queue(session_id, q): def process_audio_and_respond(session_id, data): """Process audio data and generate a response""" if models.generator is None or models.asr is None or models.llm is None: + logger.warning("Models not yet loaded!") with app.app_context(): socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) return + logger.info(f"Processing audio for session {session_id}") conversation = active_conversations[session_id] try: @@ -238,9 +274,15 @@ def process_audio_and_respond(session_id, data): # Process base64 audio data audio_data = data['audio'] speaker_id = data['speaker'] + logger.info(f"Received audio from speaker {speaker_id}") # Convert from base64 to WAV - audio_bytes = base64.b64decode(audio_data.split(',')[1]) + try: + audio_bytes = base64.b64decode(audio_data.split(',')[1]) + logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes") + except Exception as e: + logger.error(f"Error decoding base64 audio: {str(e)}") + raise # Save to temporary file for processing with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: @@ -308,11 +350,19 @@ def process_audio_and_respond(session_id, data): prompt = f"{conversation_history}Assistant: " # Generate response with Llama - input_ids = models.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE) - + input_tokens = models.tokenizer( + prompt, + return_tensors="pt", + padding=True, + return_attention_mask=True + ) + input_ids = input_tokens.input_ids.to(DEVICE) + attention_mask = input_tokens.attention_mask.to(DEVICE) + with torch.no_grad(): generated_ids = 
models.llm.generate( input_ids, + attention_mask=attention_mask, # Add the attention mask max_new_tokens=100, temperature=0.7, top_p=0.9, @@ -437,5 +487,6 @@ cleanup_thread.start() # Start the server if __name__ == '__main__': port = int(os.environ.get('PORT', 5000)) - logger.info(f"Starting server on port {port}") - socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True) \ No newline at end of file + debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true' + logger.info(f"Starting server on port {port} (debug={debug_mode})") + socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js index 109f426..5c3f247 100644 --- a/Backend/voice-chat.js +++ b/Backend/voice-chat.js @@ -378,25 +378,39 @@ function handleSpeechState(isSilent) { if (state.isSpeaking) { state.isSpeaking = false; - // Get the current audio data and send it - const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max - state.analyser.getFloatTimeDomainData(audioBuffer); - - // Create WAV blob - const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate); - - // Convert to base64 - const reader = new FileReader(); - reader.onloadend = function() { - sendAudioChunk(reader.result, state.currentSpeaker); - }; - reader.readAsDataURL(wavBlob); - - // Update button state - elements.streamButton.classList.add('processing'); - elements.streamButton.innerHTML = ' Processing...'; - - addSystemMessage('Processing your message...'); + try { + // Get the current audio data and send it + const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max + state.analyser.getFloatTimeDomainData(audioBuffer); + + // Check if audio has content + const hasAudioContent = audioBuffer.some(sample => Math.abs(sample) > 0.01); + + if (!hasAudioContent) { + console.warn('Audio buffer appears to be empty or very quiet'); + addSystemMessage('No speech detected. Please try again and speak clearly.'); + return; + } + + // Create WAV blob + const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate); + + // Convert to base64 + const reader = new FileReader(); + reader.onloadend = function() { + sendAudioChunk(reader.result, state.currentSpeaker); + }; + reader.readAsDataURL(wavBlob); + + // Update button state + elements.streamButton.classList.add('processing'); + elements.streamButton.innerHTML = ' Processing...'; + + addSystemMessage('Processing your message...'); + } catch (e) { + console.error('Error recording audio:', e); + addSystemMessage('Error recording audio. 
Please try again.'); + } } }, CLIENT_SILENCE_DURATION_MS); } From 6622d0c605e47b777ff9be03bb0fde0b0824f969 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 02:31:58 -0400 Subject: [PATCH 15/30] Demo Fixes 4 --- Backend/server.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 52a85e9..992e674 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -71,13 +71,15 @@ def load_models(): logger.info("Loading ASR pipeline...") try: + # Initialize the pipeline without the language parameter in the constructor models.asr = pipeline( "automatic-speech-recognition", model="openai/whisper-small", - device=DEVICE, - language="en", # Force English language - return_attention_mask=True # Add attention mask + device=DEVICE ) + + # Configure the model with the appropriate options + # Note that for whisper, language should be set during inference, not initialization logger.info("ASR pipeline loaded successfully") socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) except Exception as e: @@ -312,7 +314,8 @@ def process_audio_and_respond(session_id, data): # Use the ASR pipeline to transcribe transcription_result = models.asr( {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate}, - return_timestamps=False + return_timestamps=False, + generate_kwargs={"language": "en"} # Set language during inference ) user_text = transcription_result['text'].strip() From e9c90d4d680943335c4805d94888186b36bf92f8 Mon Sep 17 00:00:00 2001 From: BGV <26331505+bgv2@users.noreply.github.com> Date: Sun, 30 Mar 2025 02:36:12 -0400 Subject: [PATCH 16/30] close to making db work --- React/src/app/page.tsx | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/React/src/app/page.tsx b/React/src/app/page.tsx index 1dcc93b..e19cf15 100644 --- a/React/src/app/page.tsx +++ b/React/src/app/page.tsx @@ -134,7 +134,39 @@ export default async function Home() { className="bg-emerald-500 text-fuchsia-300" type="button">Add - +
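Aside: PATCH 15/30 above moves Whisper's language selection out of the pipeline constructor and into the inference call. Below is a minimal standalone sketch of that call pattern, assuming the same openai/whisper-small model and the array/sampling_rate input format already used in server.py; the librosa loading step and the sample.wav file name are illustrative placeholders, not part of the patch.

    # Sketch: force English at inference time via generate_kwargs,
    # since the ASR pipeline constructor does not accept a language argument.
    import librosa
    from transformers import pipeline

    asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")

    # Load audio as a mono float array at 16 kHz, Whisper's expected sample rate.
    speech, sr = librosa.load("sample.wav", sr=16000)

    result = asr(
        {"array": speech, "sampling_rate": sr},
        return_timestamps=False,
        generate_kwargs={"language": "en"},  # language is set per call, not at load time
    )
    print(result["text"].strip())
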
From 10902f1d713b13295d762bb76b55709f531c34ba Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 02:40:08 -0400 Subject: [PATCH 17/30] Demo Update 18 --- Backend/server.py | 122 +++++++++++++++++++++++++++--------------- Backend/voice-chat.js | 94 +++++++++++++++++++++++++++++++- 2 files changed, 171 insertions(+), 45 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 992e674..4cc4f91 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -8,6 +8,7 @@ import logging import threading import queue import tempfile +import gc from typing import Dict, List, Optional, Tuple import torch @@ -18,6 +19,9 @@ from flask_socketio import SocketIO, emit from flask_cors import CORS from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +# Import WhisperX for better transcription +import whisperx + from generator import load_csm_1b, Segment from dataclasses import dataclass @@ -52,7 +56,10 @@ class AppModels: generator = None tokenizer = None llm = None - asr = None + whisperx_model = None + whisperx_align_model = None + whisperx_align_metadata = None + diarize_model = None models = AppModels() @@ -69,22 +76,16 @@ def load_models(): logger.error(f"Error loading CSM 1B model: {str(e)}") socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) - logger.info("Loading ASR pipeline...") + logger.info("Loading WhisperX model...") try: - # Initialize the pipeline without the language parameter in the constructor - models.asr = pipeline( - "automatic-speech-recognition", - model="openai/whisper-small", - device=DEVICE - ) - - # Configure the model with the appropriate options - # Note that for whisper, language should be set during inference, not initialization - logger.info("ASR pipeline loaded successfully") - socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) + # Use WhisperX instead of the regular Whisper + compute_type = "float16" if DEVICE == "cuda" else "float32" + models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type) + logger.info("WhisperX model loaded successfully") + socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'}) except Exception as e: - logger.error(f"Error loading ASR pipeline: {str(e)}") - socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) + logger.error(f"Error loading WhisperX model: {str(e)}") + socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)}) logger.info("Loading Llama 3.2 model...") try: @@ -149,7 +150,7 @@ def system_status(): "device": DEVICE, "models": { "generator": models.generator is not None, - "asr": models.asr is not None, + "whisperx": models.whisperx_model is not None, "llm": models.llm is not None } }) @@ -259,8 +260,8 @@ def process_audio_queue(session_id, q): del user_queues[session_id] def process_audio_and_respond(session_id, data): - """Process audio data and generate a response""" - if models.generator is None or models.asr is None or models.llm is None: + """Process audio data and generate a response using WhisperX""" + if models.generator is None or models.whisperx_model is None or models.llm is None: logger.warning("Models not yet loaded!") with app.app_context(): socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) @@ -292,7 +293,57 @@ def process_audio_and_respond(session_id, data): temp_path = temp_file.name try: - # Load audio file + # Load audio using WhisperX + with app.app_context(): + 
socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) + + # Load audio with WhisperX instead of torchaudio + audio = whisperx.load_audio(temp_path) + + # Transcribe using WhisperX + batch_size = 16 # Adjust based on available memory + result = models.whisperx_model.transcribe(audio, batch_size=batch_size) + + # Get the detected language + language_code = result["language"] + logger.info(f"Detected language: {language_code}") + + # Load alignment model if not already loaded + if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None): + # Clear previous models to save memory + if models.whisperx_align_model is not None: + del models.whisperx_align_model + del models.whisperx_align_metadata + gc.collect() + torch.cuda.empty_cache() if DEVICE == "cuda" else None + + models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( + language_code=language_code, device=DEVICE + ) + models.last_language = language_code + + # Align the transcript + result = whisperx.align( + result["segments"], + models.whisperx_align_model, + models.whisperx_align_metadata, + audio, + DEVICE, + return_char_alignments=False + ) + + # Combine all segments into a single transcript + user_text = ' '.join([segment['text'] for segment in result['segments']]) + + # If no text was recognized, don't process further + if not user_text or len(user_text.strip()) == 0: + with app.app_context(): + socketio.emit('error', {'message': 'No speech detected'}, room=session_id) + return + + logger.info(f"Transcription: {user_text}") + + # Load audio for CSM input waveform, sample_rate = torchaudio.load(temp_path) # Normalize to mono if needed @@ -307,24 +358,6 @@ def process_audio_and_respond(session_id, data): new_freq=models.generator.sample_rate ) - # Transcribe audio - with app.app_context(): - socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) - - # Use the ASR pipeline to transcribe - transcription_result = models.asr( - {"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate}, - return_timestamps=False, - generate_kwargs={"language": "en"} # Set language during inference - ) - user_text = transcription_result['text'].strip() - - # If no text was recognized, don't process further - if not user_text: - with app.app_context(): - socketio.emit('error', {'message': 'No speech detected'}, room=session_id) - return - # Add the user's message to conversation history user_segment = conversation.add_segment( text=user_text, @@ -336,7 +369,8 @@ def process_audio_and_respond(session_id, data): with app.app_context(): socketio.emit('transcription', { 'text': user_text, - 'speaker': speaker_id + 'speaker': speaker_id, + 'segments': result['segments'] # Send detailed segments info }, room=session_id) # Generate AI response using Llama @@ -365,7 +399,7 @@ def process_audio_and_respond(session_id, data): with torch.no_grad(): generated_ids = models.llm.generate( input_ids, - attention_mask=attention_mask, # Add the attention mask + attention_mask=attention_mask, max_new_tokens=100, temperature=0.7, top_p=0.9, @@ -390,13 +424,13 @@ def process_audio_and_respond(session_id, data): 'chunk_index': 0 }, room=session_id) - # Define AI speaker ID (use a consistent value for the AI's voice) - ai_speaker_id = 1 # Use speaker 1 for AI responses + # Define AI speaker ID + ai_speaker_id = conversation.ai_speaker_id # Generate audio audio_tensor = models.generator.generate( text=response_text, - 
speaker=ai_speaker_id, # Use the local variable instead of conversation.ai_speaker_id + speaker=ai_speaker_id, context=conversation.get_context(), max_audio_length_ms=10_000, temperature=0.9 @@ -405,7 +439,7 @@ def process_audio_and_respond(session_id, data): # Add AI response to conversation history ai_segment = conversation.add_segment( text=response_text, - speaker=ai_speaker_id, # Also use the local variable here + speaker=ai_speaker_id, audio=audio_tensor ) @@ -444,6 +478,8 @@ def process_audio_and_respond(session_id, data): except Exception as e: logger.error(f"Error processing audio: {str(e)}") + import traceback + logger.error(traceback.format_exc()) with app.app_context(): socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) finally: diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js index 5c3f247..705d5ab 100644 --- a/Backend/voice-chat.js +++ b/Backend/voice-chat.js @@ -627,7 +627,59 @@ function addSystemMessage(text) { // Handle transcription response from server function handleTranscription(data) { const speaker = data.speaker === 0 ? 'user' : 'ai'; - addMessage(data.text, speaker); + + // Create the message div + const messageDiv = addMessage(data.text, speaker); + + // If we have detailed segments from WhisperX, add timestamps + if (data.segments && data.segments.length > 0) { + // Add a timestamps container + const timestampsContainer = document.createElement('div'); + timestampsContainer.className = 'timestamps-container'; + timestampsContainer.style.display = 'none'; // Hidden by default + + // Add a toggle button + const toggleButton = document.createElement('button'); + toggleButton.className = 'timestamp-toggle'; + toggleButton.textContent = 'Show Timestamps'; + toggleButton.onclick = function() { + const isHidden = timestampsContainer.style.display === 'none'; + timestampsContainer.style.display = isHidden ? 'block' : 'none'; + toggleButton.textContent = isHidden ? 
'Hide Timestamps' : 'Show Timestamps'; + }; + + // Add timestamps for each segment + data.segments.forEach(segment => { + const timestampDiv = document.createElement('div'); + timestampDiv.className = 'timestamp'; + + // Format start and end times + const startTime = formatTime(segment.start); + const endTime = formatTime(segment.end); + + timestampDiv.innerHTML = ` + [${startTime} - ${endTime}] + ${segment.text} + `; + + timestampsContainer.appendChild(timestampDiv); + }); + + // Add the timestamp elements to the message + messageDiv.appendChild(toggleButton); + messageDiv.appendChild(timestampsContainer); + } + + return messageDiv; +} + +// Helper function to format time in seconds to MM:SS.ms format +function formatTime(seconds) { + const mins = Math.floor(seconds / 60); + const secs = Math.floor(seconds % 60); + const ms = Math.floor((seconds % 1) * 1000); + + return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`; } // Handle context update from server @@ -804,7 +856,7 @@ function finalizeStreamingAudio() { // Add CSS styles for new UI elements document.addEventListener('DOMContentLoaded', function() { - // Add styles for processing state + // Add styles for processing state and timestamps const style = document.createElement('style'); style.textContent = ` .message.processing { @@ -833,6 +885,44 @@ document.addEventListener('DOMContentLoaded', function() { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } } + + /* Timestamp styles */ + .timestamp-toggle { + font-size: 0.75em; + padding: 4px 8px; + margin-top: 8px; + background-color: #f0f0f0; + border: 1px solid #ddd; + border-radius: 4px; + cursor: pointer; + } + + .timestamp-toggle:hover { + background-color: #e0e0e0; + } + + .timestamps-container { + margin-top: 8px; + padding: 8px; + background-color: #f9f9f9; + border-radius: 4px; + font-size: 0.85em; + } + + .timestamp { + margin-bottom: 4px; + padding: 2px 0; + } + + .timestamp .time { + color: #666; + font-family: monospace; + margin-right: 8px; + } + + .timestamp .text { + color: #333; + } `; document.head.appendChild(style); }); From 4fb2c9bc52aeeed95eba576eddb93fb9bc8cca5f Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 02:53:30 -0400 Subject: [PATCH 18/30] Demo Update 20 --- Backend/index.html | 48 ++++++++++++++++++ Backend/requirements.txt | 4 ++ Backend/server.py | 96 ++++++++++++++++-------------------- Backend/voice-chat.js | 103 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 196 insertions(+), 55 deletions(-) diff --git a/Backend/index.html b/Backend/index.html index 01bd5f7..e69ec9a 100644 --- a/Backend/index.html +++ b/Backend/index.html @@ -260,6 +260,47 @@ font-size: 0.8em; color: #777; } + + /* Model status indicators */ + .model-status { + display: flex; + gap: 8px; + } + + .model-indicator { + padding: 3px 6px; + border-radius: 4px; + font-size: 0.7em; + font-weight: bold; + } + + .model-indicator.loading { + background-color: #ffd54f; + color: #000; + } + + .model-indicator.loaded { + background-color: #4CAF50; + color: white; + } + + .model-indicator.error { + background-color: #f44336; + color: white; + } + + .message-timestamp { + font-size: 0.7em; + color: #888; + margin-top: 4px; + text-align: right; + } + + .simple-timestamp { + font-size: 0.8em; + color: #888; + margin-top: 5px; + } @@ -276,6 +317,13 @@
 Disconnected
+
+ <div class="model-status">
+     <div id="csmStatus" class="model-indicator loading">CSM</div>
+     <div id="asrStatus" class="model-indicator loading">ASR</div>
+     <div id="llmStatus" class="model-indicator loading">LLM</div>
+ </div>
+
diff --git a/Backend/requirements.txt b/Backend/requirements.txt index ba8a04f..1e05eb3 100644 --- a/Backend/requirements.txt +++ b/Backend/requirements.txt @@ -1,7 +1,11 @@ +flask==2.2.5 +flask-socketio==5.3.6 +flask-cors==4.0.0 torch==2.4.0 torchaudio==2.4.0 tokenizers==0.21.0 transformers==4.49.0 +librosa==0.10.1 huggingface_hub==0.28.1 moshi==0.2.2 torchtune==0.4.0 diff --git a/Backend/server.py b/Backend/server.py index 4cc4f91..ab56e77 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -56,12 +56,8 @@ class AppModels: generator = None tokenizer = None llm = None - whisperx_model = None - whisperx_align_model = None - whisperx_align_metadata = None - diarize_model = None - -models = AppModels() + asr_model = None + asr_processor = None def load_models(): """Load all required models""" @@ -76,16 +72,22 @@ def load_models(): logger.error(f"Error loading CSM 1B model: {str(e)}") socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) - logger.info("Loading WhisperX model...") + logger.info("Loading Whisper ASR model...") try: - # Use WhisperX instead of the regular Whisper - compute_type = "float16" if DEVICE == "cuda" else "float32" - models.whisperx_model = whisperx.load_model("large-v2", DEVICE, compute_type=compute_type) - logger.info("WhisperX model loaded successfully") - socketio.emit('model_status', {'model': 'whisperx', 'status': 'loaded'}) + # Use regular Whisper instead of WhisperX to avoid compatibility issues + from transformers import WhisperProcessor, WhisperForConditionalGeneration + + # Use a smaller model for faster processing + model_id = "openai/whisper-small" + + models.asr_processor = WhisperProcessor.from_pretrained(model_id) + models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE) + + logger.info("Whisper ASR model loaded successfully") + socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) except Exception as e: - logger.error(f"Error loading WhisperX model: {str(e)}") - socketio.emit('model_status', {'model': 'whisperx', 'status': 'error', 'message': str(e)}) + logger.error(f"Error loading ASR model: {str(e)}") + socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) logger.info("Loading Llama 3.2 model...") try: @@ -141,7 +143,8 @@ def health_check(): "models_loaded": models.generator is not None and models.llm is not None }) -# Add a system status endpoint +# Fix the system_status function: + @app.route('/api/status') def system_status(): return jsonify({ @@ -150,7 +153,7 @@ def system_status(): "device": DEVICE, "models": { "generator": models.generator is not None, - "whisperx": models.whisperx_model is not None, + "asr": models.asr_model is not None, # Use the correct model name "llm": models.llm is not None } }) @@ -260,8 +263,8 @@ def process_audio_queue(session_id, q): del user_queues[session_id] def process_audio_and_respond(session_id, data): - """Process audio data and generate a response using WhisperX""" - if models.generator is None or models.whisperx_model is None or models.llm is None: + """Process audio data and generate a response using standard Whisper""" + if models.generator is None or models.asr_model is None or models.llm is None: logger.warning("Models not yet loaded!") with app.app_context(): socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) @@ -293,47 +296,33 @@ def process_audio_and_respond(session_id, data): temp_path = temp_file.name try: - # Load audio using WhisperX + # Notify 
client that transcription is starting with app.app_context(): socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) - # Load audio with WhisperX instead of torchaudio - audio = whisperx.load_audio(temp_path) + # Load audio for ASR processing + import librosa + speech_array, sampling_rate = librosa.load(temp_path, sr=16000) - # Transcribe using WhisperX - batch_size = 16 # Adjust based on available memory - result = models.whisperx_model.transcribe(audio, batch_size=batch_size) + # Convert to required format + input_features = models.asr_processor( + speech_array, + sampling_rate=sampling_rate, + return_tensors="pt" + ).input_features.to(DEVICE) - # Get the detected language - language_code = result["language"] - logger.info(f"Detected language: {language_code}") - - # Load alignment model if not already loaded - if models.whisperx_align_model is None or language_code != getattr(models, 'last_language', None): - # Clear previous models to save memory - if models.whisperx_align_model is not None: - del models.whisperx_align_model - del models.whisperx_align_metadata - gc.collect() - torch.cuda.empty_cache() if DEVICE == "cuda" else None - - models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( - language_code=language_code, device=DEVICE - ) - models.last_language = language_code - - # Align the transcript - result = whisperx.align( - result["segments"], - models.whisperx_align_model, - models.whisperx_align_metadata, - audio, - DEVICE, - return_char_alignments=False + # Generate token ids + predicted_ids = models.asr_model.generate( + input_features, + language="en", + task="transcribe" ) - # Combine all segments into a single transcript - user_text = ' '.join([segment['text'] for segment in result['segments']]) + # Decode the predicted ids to text + user_text = models.asr_processor.batch_decode( + predicted_ids, + skip_special_tokens=True + )[0] # If no text was recognized, don't process further if not user_text or len(user_text.strip()) == 0: @@ -369,8 +358,7 @@ def process_audio_and_respond(session_id, data): with app.app_context(): socketio.emit('transcription', { 'text': user_text, - 'speaker': speaker_id, - 'segments': result['segments'] # Send detailed segments info + 'speaker': speaker_id }, room=session_id) # Generate AI response using Llama diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js index 705d5ab..e4f1272 100644 --- a/Backend/voice-chat.js +++ b/Backend/voice-chat.js @@ -105,8 +105,25 @@ function setupSocketConnection() { }); state.socket.on('error', (data) => { - addSystemMessage(`Error: ${data.message}`); console.error('Server error:', data.message); + + // Make the error more user-friendly + let userMessage = data.message; + + // Check for common errors and provide more helpful messages + if (data.message.includes('Models still loading')) { + userMessage = 'The AI models are still loading. Please wait a moment and try again.'; + } else if (data.message.includes('No speech detected')) { + userMessage = 'No speech was detected. 
Please speak clearly and try again.'; + } + + addSystemMessage(`Error: ${userMessage}`); + + // Reset button state if it was processing + if (elements.streamButton.classList.contains('processing')) { + elements.streamButton.classList.remove('processing'); + elements.streamButton.innerHTML = ' Start Conversation'; + } }); // Register message handlers @@ -115,6 +132,9 @@ function setupSocketConnection() { state.socket.on('streaming_status', handleStreamingStatus); state.socket.on('processing_status', handleProcessingStatus); + // Add model status handlers + state.socket.on('model_status', handleModelStatusUpdate); + // Handlers for incremental audio streaming state.socket.on('audio_response_start', handleAudioResponseStart); state.socket.on('audio_response_chunk', handleAudioResponseChunk); @@ -189,6 +209,27 @@ function startStreaming() { return; } + // Check if models are loaded via the API + fetch('/api/status') + .then(response => response.json()) + .then(data => { + if (!data.models.generator || !data.models.asr || !data.models.llm) { + addSystemMessage('Still loading AI models. Please wait...'); + return; + } + + // Continue with recording if models are loaded + initializeRecording(); + }) + .catch(error => { + console.error('Error checking model status:', error); + // Try anyway, the server will respond with an error if models aren't ready + initializeRecording(); + }); +} + +// Extracted the recording initialization to a separate function +function initializeRecording() { // Request microphone access navigator.mediaDevices.getUserMedia({ audio: true, video: false }) .then(stream => { @@ -600,6 +641,13 @@ function addMessage(text, type) { textElement.textContent = text; messageDiv.appendChild(textElement); + // Add timestamp to every message + const timestamp = new Date().toLocaleTimeString(); + const timeLabel = document.createElement('div'); + timeLabel.className = 'message-timestamp'; + timeLabel.textContent = timestamp; + messageDiv.appendChild(timeLabel); + elements.conversation.appendChild(messageDiv); // Auto-scroll to the bottom @@ -668,6 +716,13 @@ function handleTranscription(data) { // Add the timestamp elements to the message messageDiv.appendChild(toggleButton); messageDiv.appendChild(timestampsContainer); + } else { + // No timestamp data available - add a simple timestamp for the entire message + const timestamp = new Date().toLocaleTimeString(); + const timeLabel = document.createElement('div'); + timeLabel.className = 'simple-timestamp'; + timeLabel.textContent = timestamp; + messageDiv.appendChild(timeLabel); } return messageDiv; @@ -854,6 +909,52 @@ function finalizeStreamingAudio() { streamingAudio.audioElement = null; } +// Handle model status updates +function handleModelStatusUpdate(data) { + const { model, status, message } = data; + + if (status === 'loaded') { + console.log(`Model ${model} loaded successfully`); + addSystemMessage(`${model.toUpperCase()} model loaded successfully`); + + // Update UI to show model is ready + const modelStatusElement = document.getElementById(`${model}Status`); + if (modelStatusElement) { + modelStatusElement.classList.remove('loading'); + modelStatusElement.classList.add('loaded'); + modelStatusElement.title = 'Model loaded successfully'; + } + + // Check if the required models are loaded to enable conversation + checkAllModelsLoaded(); + } else if (status === 'error') { + console.error(`Error loading ${model} model: ${message}`); + addSystemMessage(`Error loading ${model.toUpperCase()} model: ${message}`); + + // Update UI to 
show model loading failed + const modelStatusElement = document.getElementById(`${model}Status`); + if (modelStatusElement) { + modelStatusElement.classList.remove('loading'); + modelStatusElement.classList.add('error'); + modelStatusElement.title = `Error: ${message}`; + } + } +} + +// Check if all required models are loaded and enable UI accordingly +function checkAllModelsLoaded() { + // When all models are loaded, enable the stream button if it was disabled + const allLoaded = + document.getElementById('csmStatus')?.classList.contains('loaded') && + document.getElementById('asrStatus')?.classList.contains('loaded') && + document.getElementById('llmStatus')?.classList.contains('loaded'); + + if (allLoaded) { + elements.streamButton.disabled = false; + addSystemMessage('All models loaded. Ready for conversation!'); + } +} + // Add CSS styles for new UI elements document.addEventListener('DOMContentLoaded', function() { // Add styles for processing state and timestamps From 13441150139cf6ca2e1924cdb0232381487a6fcd Mon Sep 17 00:00:00 2001 From: BGV <26331505+bgv2@users.noreply.github.com> Date: Sun, 30 Mar 2025 02:53:45 -0400 Subject: [PATCH 19/30] storage api route --- React/src/pages/api/databaseStorage.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/React/src/pages/api/databaseStorage.ts b/React/src/pages/api/databaseStorage.ts index 061341a..aa01d37 100644 --- a/React/src/pages/api/databaseStorage.ts +++ b/React/src/pages/api/databaseStorage.ts @@ -10,6 +10,11 @@ async function connectToDatabase() { // Only connect if not already connected await mongoose.connect(uri, clientOptions); console.log("Connected to MongoDB!"); + mongoose.model("User", new mongoose.Schema({ + email: { type: String, required: true, unique: true }, + codeword: { type: String, required: true }, + contacts: [{ type: String }], + })); } } From bfaffef68447643652a8680ee1400cca894689ff Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 02:57:57 -0400 Subject: [PATCH 20/30] Demo Fixes 5 --- Backend/index.html | 29 +++++++++++++++++++++++++++++ Backend/server.py | 19 ++++++++++++++++++- Backend/voice-chat.js | 15 +++++++++++++-- 3 files changed, 60 insertions(+), 3 deletions(-) diff --git a/Backend/index.html b/Backend/index.html index e69ec9a..9950a00 100644 --- a/Backend/index.html +++ b/Backend/index.html @@ -301,6 +301,30 @@ color: #888; margin-top: 5px; } + + /* Add this to your existing styles */ + .loading-progress { + width: 100%; + max-width: 150px; + margin-right: 10px; + } + + progress { + width: 100%; + height: 8px; + border-radius: 4px; + overflow: hidden; + } + + progress::-webkit-progress-bar { + background-color: #eee; + border-radius: 4px; + } + + progress::-webkit-progress-value { + background-color: var(--primary-color); + border-radius: 4px; + } @@ -318,6 +342,11 @@ Disconnected + +
+ 0% +
+
CSM
diff --git a/Backend/server.py b/Backend/server.py index ab56e77..c4d38e0 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -59,17 +59,28 @@ class AppModels: asr_model = None asr_processor = None +# Initialize the models object +models = AppModels() + def load_models(): """Load all required models""" global models + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0}) + logger.info("Loading CSM 1B model...") try: models.generator = load_csm_1b(device=DEVICE) logger.info("CSM 1B model loaded successfully") socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) + progress = 33 + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress}) + if DEVICE == "cuda": + torch.cuda.empty_cache() except Exception as e: - logger.error(f"Error loading CSM 1B model: {str(e)}") + import traceback + error_details = traceback.format_exc() + logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) logger.info("Loading Whisper ASR model...") @@ -85,6 +96,10 @@ def load_models(): logger.info("Whisper ASR model loaded successfully") socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) + progress = 66 + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress}) + if DEVICE == "cuda": + torch.cuda.empty_cache() except Exception as e: logger.error(f"Error loading ASR model: {str(e)}") socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) @@ -99,6 +114,8 @@ def load_models(): models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") logger.info("Llama 3.2 model loaded successfully") socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) + progress = 100 + socketio.emit('model_status', {'model': 'overall', 'status': 'loaded', 'progress': progress}) except Exception as e: logger.error(f"Error loading Llama 3.2 model: {str(e)}") socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js index e4f1272..2efd76d 100644 --- a/Backend/voice-chat.js +++ b/Backend/voice-chat.js @@ -909,9 +909,20 @@ function finalizeStreamingAudio() { streamingAudio.audioElement = null; } -// Handle model status updates +// Enhance the handleModelStatusUpdate function: + function handleModelStatusUpdate(data) { - const { model, status, message } = data; + const { model, status, message, progress } = data; + + if (model === 'overall' && status === 'loading') { + // Update overall loading progress + const progressBar = document.getElementById('modelLoadingProgress'); + if (progressBar) { + progressBar.value = progress; + progressBar.textContent = `${progress}%`; + } + return; + } if (status === 'loaded') { console.log(`Model ${model} loaded successfully`); From fdb92ff0613a65ff8d6a4d180ad317a51a1f81f9 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 03:03:14 -0400 Subject: [PATCH 21/30] Demo Fixes 6 --- Backend/server.py | 67 ++++++++++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index c4d38e0..563534c 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -112,6 +112,15 @@ def load_models(): torch_dtype=torch.bfloat16 ) models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + + # Configure all special tokens + 
models.tokenizer.pad_token = models.tokenizer.eos_token + models.tokenizer.padding_side = "left" # For causal language modeling + + # Inform the model about the pad token + if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None: + models.llm.config.pad_token_id = models.tokenizer.pad_token_id + logger.info("Llama 3.2 model loaded successfully") socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) progress = 100 @@ -392,31 +401,41 @@ def process_audio_and_respond(session_id, data): prompt = f"{conversation_history}Assistant: " # Generate response with Llama - input_tokens = models.tokenizer( - prompt, - return_tensors="pt", - padding=True, - return_attention_mask=True - ) - input_ids = input_tokens.input_ids.to(DEVICE) - attention_mask = input_tokens.attention_mask.to(DEVICE) - - with torch.no_grad(): - generated_ids = models.llm.generate( - input_ids, - attention_mask=attention_mask, - max_new_tokens=100, - temperature=0.7, - top_p=0.9, - do_sample=True, - pad_token_id=models.tokenizer.eos_token_id + try: + # Ensure pad token is set + if models.tokenizer.pad_token is None: + models.tokenizer.pad_token = models.tokenizer.eos_token + + input_tokens = models.tokenizer( + prompt, + return_tensors="pt", + padding=True, + return_attention_mask=True ) - - # Decode the response - response_text = models.tokenizer.decode( - generated_ids[0][input_ids.shape[1]:], - skip_special_tokens=True - ).strip() + input_ids = input_tokens.input_ids.to(DEVICE) + attention_mask = input_tokens.attention_mask.to(DEVICE) + + with torch.no_grad(): + generated_ids = models.llm.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=100, + temperature=0.7, + top_p=0.9, + do_sample=True, + pad_token_id=models.tokenizer.eos_token_id + ) + + # Decode the response + response_text = models.tokenizer.decode( + generated_ids[0][input_ids.shape[1]:], + skip_special_tokens=True + ).strip() + except Exception as e: + logger.error(f"Error generating response: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + response_text = "I'm sorry, I encountered an error while processing your request." 
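# --- Illustrative aside: the pad-token setup above, shown as a standalone sketch. ---
# Llama-style tokenizers ship without a pad token, so padded inputs and generate()
# need one configured (here EOS is reused) plus an explicit attention mask. This is
# only a hedged sketch against the public transformers API; the checkpoint name is
# the same one the patch loads, everything else is illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_reply(prompt: str, model_id: str = "meta-llama/Llama-3.2-1B") -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    # Reuse EOS as the pad token and left-pad, as the patch does.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    batch = tokenizer(prompt, return_tensors="pt", padding=True, return_attention_mask=True)
    with torch.no_grad():
        output_ids = model.generate(
            batch.input_ids,
            attention_mask=batch.attention_mask,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Return only the continuation, without the echoed prompt tokens.
    return tokenizer.decode(output_ids[0][batch.input_ids.shape[1]:], skip_special_tokens=True).strip()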
# Synthesize speech with app.app_context(): From 284dd509727b082396404e1d24443d8581195420 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 03:09:57 -0400 Subject: [PATCH 22/30] Demo Fixes 7 --- Backend/server.py | 89 ++++++++++++++++++++++++++++++++++--------- Backend/voice-chat.js | 14 ++++++- 2 files changed, 83 insertions(+), 20 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 563534c..8145ab0 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -25,6 +25,10 @@ import whisperx from generator import load_csm_1b, Segment from dataclasses import dataclass +# Add these imports at the top +import psutil +import gc + # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -68,13 +72,13 @@ def load_models(): socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0}) - logger.info("Loading CSM 1B model...") + # CSM 1B loading try: + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'}) models.generator = load_csm_1b(device=DEVICE) logger.info("CSM 1B model loaded successfully") socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) - progress = 33 - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33}) if DEVICE == "cuda": torch.cuda.empty_cache() except Exception as e: @@ -83,8 +87,9 @@ def load_models(): logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) - logger.info("Loading Whisper ASR model...") + # Whisper loading try: + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'}) # Use regular Whisper instead of WhisperX to avoid compatibility issues from transformers import WhisperProcessor, WhisperForConditionalGeneration @@ -96,16 +101,16 @@ def load_models(): logger.info("Whisper ASR model loaded successfully") socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) - progress = 66 - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': progress}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66}) if DEVICE == "cuda": torch.cuda.empty_cache() except Exception as e: logger.error(f"Error loading ASR model: {str(e)}") socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) - logger.info("Loading Llama 3.2 model...") + # Llama loading try: + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'}) models.llm = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-3.2-1B", device_map=DEVICE, @@ -123,8 +128,8 @@ def load_models(): logger.info("Llama 3.2 model loaded successfully") socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) - progress = 100 - socketio.emit('model_status', {'model': 'overall', 'status': 'loaded', 'progress': progress}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'}) except Exception as e: logger.error(f"Error loading Llama 3.2 model: {str(e)}") socketio.emit('model_status', 
{'model': 'llm', 'status': 'error', 'message': str(e)}) @@ -184,6 +189,39 @@ def system_status(): } }) +# Add a new endpoint to check system resources +@app.route('/api/system_resources') +def system_resources(): + # Get CPU usage + cpu_percent = psutil.cpu_percent(interval=0.1) + + # Get memory usage + memory = psutil.virtual_memory() + memory_used_gb = memory.used / (1024 ** 3) + memory_total_gb = memory.total / (1024 ** 3) + memory_percent = memory.percent + + # Get GPU memory if available + gpu_memory = {} + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + gpu_memory[f"gpu_{i}"] = { + "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3), + "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3), + "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3) + } + + return jsonify({ + "cpu_percent": cpu_percent, + "memory": { + "used_gb": memory_used_gb, + "total_gb": memory_total_gb, + "percent": memory_percent + }, + "gpu_memory": gpu_memory, + "active_sessions": len(active_conversations) + }) + # Socket event handlers @socketio.on('connect') def handle_connect(auth=None): @@ -331,18 +369,33 @@ def process_audio_and_respond(session_id, data): speech_array, sampling_rate = librosa.load(temp_path, sr=16000) # Convert to required format - input_features = models.asr_processor( + processor_output = models.asr_processor( speech_array, sampling_rate=sampling_rate, - return_tensors="pt" - ).input_features.to(DEVICE) - - # Generate token ids - predicted_ids = models.asr_model.generate( - input_features, - language="en", - task="transcribe" + return_tensors="pt", + padding=True, # Add padding + return_attention_mask=True # Request attention mask ) + input_features = processor_output.input_features.to(DEVICE) + attention_mask = processor_output.get('attention_mask', None) + + if attention_mask is not None: + attention_mask = attention_mask.to(DEVICE) + + # Generate token ids with attention mask + predicted_ids = models.asr_model.generate( + input_features, + attention_mask=attention_mask, + language="en", + task="transcribe" + ) + else: + # Fallback if attention mask is not available + predicted_ids = models.asr_model.generate( + input_features, + language="en", + task="transcribe" + ) # Decode the predicted ids to text user_text = models.asr_processor.batch_decode( diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js index 2efd76d..dc2db04 100644 --- a/Backend/voice-chat.js +++ b/Backend/voice-chat.js @@ -43,7 +43,9 @@ const state = { volumeUpdateInterval: null, visualizerAnimationFrame: null, currentSpeaker: 0, - aiSpeakerId: 1 // Define the AI's speaker ID to match server.py + aiSpeakerId: 1, // Define the AI's speaker ID to match server.py + transcriptionRetries: 0, + maxTranscriptionRetries: 3 }; // Visualizer variables @@ -429,7 +431,15 @@ function handleSpeechState(isSilent) { if (!hasAudioContent) { console.warn('Audio buffer appears to be empty or very quiet'); - addSystemMessage('No speech detected. Please try again and speak clearly.'); + + if (state.transcriptionRetries < state.maxTranscriptionRetries) { + state.transcriptionRetries++; + const retryMessage = `No speech detected (attempt ${state.transcriptionRetries}/${state.maxTranscriptionRetries}). Please speak louder and try again.`; + addSystemMessage(retryMessage); + } else { + state.transcriptionRetries = 0; + addSystemMessage('Multiple attempts failed to detect speech. 
Please check your microphone and try again.'); + } return; } From c8551f90b361c8abe3e73392670a9a8259268b71 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 03:19:23 -0400 Subject: [PATCH 23/30] Demo Fixes 8 --- Backend/server.py | 131 ++++++++++++++++++++++++++++------------------ 1 file changed, 81 insertions(+), 50 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 8145ab0..e912a9d 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -60,8 +60,11 @@ class AppModels: generator = None tokenizer = None llm = None - asr_model = None - asr_processor = None + whisperx_model = None + whisperx_align_model = None + whisperx_align_metadata = None + diarize_model = None + last_language = None # Initialize the models object models = AppModels() @@ -87,25 +90,27 @@ def load_models(): logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) - # Whisper loading + # WhisperX loading try: socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'}) - # Use regular Whisper instead of WhisperX to avoid compatibility issues - from transformers import WhisperProcessor, WhisperForConditionalGeneration + # Use WhisperX for better transcription with timestamps + import whisperx - # Use a smaller model for faster processing - model_id = "openai/whisper-small" + # Use compute_type based on device + compute_type = "float16" if DEVICE == "cuda" else "float32" - models.asr_processor = WhisperProcessor.from_pretrained(model_id) - models.asr_model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE) + # Load the WhisperX model (smaller model for faster processing) + models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type) - logger.info("Whisper ASR model loaded successfully") + logger.info("WhisperX model loaded successfully") socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66}) if DEVICE == "cuda": torch.cuda.empty_cache() except Exception as e: - logger.error(f"Error loading ASR model: {str(e)}") + import traceback + error_details = traceback.format_exc() + logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}") socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) # Llama loading @@ -184,7 +189,7 @@ def system_status(): "device": DEVICE, "models": { "generator": models.generator is not None, - "asr": models.asr_model is not None, # Use the correct model name + "asr": models.whisperx_model is not None, # Use the correct model name "llm": models.llm is not None } }) @@ -327,8 +332,8 @@ def process_audio_queue(session_id, q): del user_queues[session_id] def process_audio_and_respond(session_id, data): - """Process audio data and generate a response using standard Whisper""" - if models.generator is None or models.asr_model is None or models.llm is None: + """Process audio data and generate a response using WhisperX""" + if models.generator is None or models.whisperx_model is None or models.llm is None: logger.warning("Models not yet loaded!") with app.app_context(): socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) @@ -364,44 +369,69 @@ def process_audio_and_respond(session_id, data): with app.app_context(): socketio.emit('processing_status', {'status': 
'transcribing'}, room=session_id) - # Load audio for ASR processing - import librosa - speech_array, sampling_rate = librosa.load(temp_path, sr=16000) + # Load audio using WhisperX + import whisperx + audio = whisperx.load_audio(temp_path) - # Convert to required format - processor_output = models.asr_processor( - speech_array, - sampling_rate=sampling_rate, - return_tensors="pt", - padding=True, # Add padding - return_attention_mask=True # Request attention mask - ) - input_features = processor_output.input_features.to(DEVICE) - attention_mask = processor_output.get('attention_mask', None) - - if attention_mask is not None: - attention_mask = attention_mask.to(DEVICE) + # Check audio length and add a warning for short clips + audio_length = len(audio) / 16000 # assuming 16kHz sample rate + if audio_length < 1.0: + logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality") + + # Transcribe using WhisperX + batch_size = 16 # adjust based on your GPU memory + logger.info("Running WhisperX transcription...") + + # Handle the warning about audio being shorter than 30s by suppressing it + import warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="audio is shorter than 30s") + result = models.whisperx_model.transcribe(audio, batch_size=batch_size) + + # Get the detected language + language_code = result["language"] + logger.info(f"Detected language: {language_code}") + + # Check if alignment model needs to be loaded or updated + if models.whisperx_align_model is None or language_code != models.last_language: + # Clean up old models if they exist + if models.whisperx_align_model is not None: + del models.whisperx_align_model + del models.whisperx_align_metadata + if DEVICE == "cuda": + gc.collect() + torch.cuda.empty_cache() - # Generate token ids with attention mask - predicted_ids = models.asr_model.generate( - input_features, - attention_mask=attention_mask, - language="en", - task="transcribe" - ) - else: - # Fallback if attention mask is not available - predicted_ids = models.asr_model.generate( - input_features, - language="en", - task="transcribe" + # Load new alignment model for the detected language + logger.info(f"Loading alignment model for language: {language_code}") + models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( + language_code=language_code, device=DEVICE ) + models.last_language = language_code - # Decode the predicted ids to text - user_text = models.asr_processor.batch_decode( - predicted_ids, - skip_special_tokens=True - )[0] + # Align the transcript to get word-level timestamps + if result["segments"] and len(result["segments"]) > 0: + logger.info("Aligning transcript...") + result = whisperx.align( + result["segments"], + models.whisperx_align_model, + models.whisperx_align_metadata, + audio, + DEVICE, + return_char_alignments=False + ) + + # Process the segments for better output + for segment in result["segments"]: + # Round timestamps for better display + segment["start"] = round(segment["start"], 2) + segment["end"] = round(segment["end"], 2) + # Add a confidence score if not present + if "confidence" not in segment: + segment["confidence"] = 1.0 # Default confidence + + # Extract the full text from all segments + user_text = ' '.join([segment['text'] for segment in result['segments']]) # If no text was recognized, don't process further if not user_text or len(user_text.strip()) == 0: @@ -433,11 +463,12 @@ def process_audio_and_respond(session_id, data): 
audio=waveform.squeeze() ) - # Send transcription to client + # Send transcription to client with detailed segments with app.app_context(): socketio.emit('transcription', { 'text': user_text, - 'speaker': speaker_id + 'speaker': speaker_id, + 'segments': result['segments'] # Include the detailed segments with timestamps }, room=session_id) # Generate AI response using Llama From accc5c49e312deb03545a3c9d6aeb18ed8d53ee0 Mon Sep 17 00:00:00 2001 From: BGV <26331505+bgv2@users.noreply.github.com> Date: Sun, 30 Mar 2025 03:31:30 -0400 Subject: [PATCH 24/30] onclick still not firing --- React/src/app/page.tsx | 69 ++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/React/src/app/page.tsx b/React/src/app/page.tsx index e19cf15..c91ed2b 100644 --- a/React/src/app/page.tsx +++ b/React/src/app/page.tsx @@ -1,10 +1,8 @@ -"use client"; import { useState } from "react"; import { auth0 } from "../lib/auth0"; export default async function Home() { - const [contacts, setContacts] = useState([]); const [codeword, setCodeword] = useState(""); @@ -13,6 +11,38 @@ export default async function Home() { console.log("Session:", session?.user); + function saveToDB() { + //e.preventDefault(); + alert("Saving contacts..."); + // const contactInputs = document.querySelectorAll(".text-input") as NodeListOf; + // const contactValues = Array.from(contactInputs).map(input => input.value); + // console.log("Contact values:", contactValues); + // // save codeword and contacts to database + // fetch("/api/databaseStorage", { + // method: "POST", + // headers: { + // "Content-Type": "application/json", + // }, + // body: JSON.stringify({ + // email: session?.user?.email || "", + // codeword: (document.getElementById("codeword") as HTMLInputElement)?.value, + // contacts: contactValues, + // }), + // }) + // .then((response) => { + // if (response.ok) { + // alert("Contacts saved successfully!"); + // } else { + // alert("Error saving contacts."); + // } + // }) + // .catch((error) => { + // console.error("Error:", error); + // alert("Error saving contacts."); + // }); + + } + // If no session, show sign-up and login buttons if (!session) { @@ -134,38 +164,9 @@ export default async function Home() { className="bg-emerald-500 text-fuchsia-300" type="button">Add -
@@ -181,4 +182,6 @@ export default async function Home() {
); + + } From d4a7cf0e2fcdd02ae37fd0df4e598d670eb9294e Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 03:43:08 -0400 Subject: [PATCH 25/30] Frontend Fixed --- Backend/api/app.py | 136 ++++ Backend/{ => api}/generator.py | 0 Backend/{ => api}/models.py | 0 Backend/api/routes.py | 74 ++ Backend/api/socket_handlers.py | 392 ++++++++++ Backend/{ => api}/watermarking.py | 0 Backend/index.html | 419 ----------- Backend/requirements.txt | 13 - Backend/run_csm.py | 117 --- Backend/server.py | 646 +--------------- Backend/setup.py | 13 - Backend/voice-chat.js | 1054 --------------------------- React/src/app/auth/session/route.ts | 12 + React/src/app/page.tsx | 236 +++--- 14 files changed, 777 insertions(+), 2335 deletions(-) create mode 100644 Backend/api/app.py rename Backend/{ => api}/generator.py (100%) rename Backend/{ => api}/models.py (100%) create mode 100644 Backend/api/routes.py create mode 100644 Backend/api/socket_handlers.py rename Backend/{ => api}/watermarking.py (100%) delete mode 100644 Backend/index.html delete mode 100644 Backend/requirements.txt delete mode 100644 Backend/run_csm.py delete mode 100644 Backend/setup.py delete mode 100644 Backend/voice-chat.js create mode 100644 React/src/app/auth/session/route.ts diff --git a/Backend/api/app.py b/Backend/api/app.py new file mode 100644 index 0000000..018061f --- /dev/null +++ b/Backend/api/app.py @@ -0,0 +1,136 @@ +import os +import logging +import threading +from dataclasses import dataclass +from flask import Flask +from flask_socketio import SocketIO +from flask_cors import CORS + +# Configure logging +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Configure device +import torch +if torch.cuda.is_available(): + DEVICE = "cuda" +elif torch.backends.mps.is_available(): + DEVICE = "mps" +else: + DEVICE = "cpu" + +logger.info(f"Using device: {DEVICE}") + +# Initialize Flask app +app = Flask(__name__, static_folder='../', static_url_path='') +CORS(app) +socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120) + +# Global variables for conversation state +active_conversations = {} +user_queues = {} +processing_threads = {} + +# Model storage +@dataclass +class AppModels: + generator = None + tokenizer = None + llm = None + whisperx_model = None + whisperx_align_model = None + whisperx_align_metadata = None + last_language = None + +models = AppModels() + +def load_models(): + """Load all required models""" + from generator import load_csm_1b + import whisperx + import gc + from transformers import AutoModelForCausalLM, AutoTokenizer + global models + + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0}) + + # CSM 1B loading + try: + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'}) + models.generator = load_csm_1b(device=DEVICE) + logger.info("CSM 1B model loaded successfully") + socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33}) + if DEVICE == "cuda": + torch.cuda.empty_cache() + except Exception as e: + import traceback + error_details = traceback.format_exc() + logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") + socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) + + # WhisperX loading + try: + socketio.emit('model_status', 
{'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'}) + # Use WhisperX for better transcription with timestamps + # Use compute_type based on device + compute_type = "float16" if DEVICE == "cuda" else "float32" + + # Load the WhisperX model (smaller model for faster processing) + models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type) + + logger.info("WhisperX model loaded successfully") + socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66}) + if DEVICE == "cuda": + torch.cuda.empty_cache() + except Exception as e: + import traceback + error_details = traceback.format_exc() + logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}") + socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) + + # Llama loading + try: + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'}) + models.llm = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B", + device_map=DEVICE, + torch_dtype=torch.bfloat16 + ) + models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + + # Configure all special tokens + models.tokenizer.pad_token = models.tokenizer.eos_token + models.tokenizer.padding_side = "left" # For causal language modeling + + # Inform the model about the pad token + if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None: + models.llm.config.pad_token_id = models.tokenizer.pad_token_id + + logger.info("Llama 3.2 model loaded successfully") + socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'}) + socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'}) + except Exception as e: + logger.error(f"Error loading Llama 3.2 model: {str(e)}") + socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) + +# Load models in a background thread +threading.Thread(target=load_models, daemon=True).start() + +# Import routes and socket handlers +from api.routes import register_routes +from api.socket_handlers import register_handlers + +# Register routes and socket handlers +register_routes(app) +register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE) + +# Run server if executed directly +if __name__ == '__main__': + port = int(os.environ.get('PORT', 5000)) + debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true' + logger.info(f"Starting server on port {port} (debug={debug_mode})") + socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/generator.py b/Backend/api/generator.py similarity index 100% rename from Backend/generator.py rename to Backend/api/generator.py diff --git a/Backend/models.py b/Backend/api/models.py similarity index 100% rename from Backend/models.py rename to Backend/api/models.py diff --git a/Backend/api/routes.py b/Backend/api/routes.py new file mode 100644 index 0000000..af1bfce --- /dev/null +++ b/Backend/api/routes.py @@ -0,0 +1,74 @@ +import os +import torch +import psutil +from flask import send_from_directory, jsonify, request + +def register_routes(app): + """Register HTTP routes for the 
application""" + + @app.route('/') + def index(): + """Serve the main application page""" + return send_from_directory(app.static_folder, 'index.html') + + @app.route('/voice-chat.js') + def serve_js(): + """Serve the JavaScript file""" + return send_from_directory(app.static_folder, 'voice-chat.js') + + @app.route('/api/status') + def system_status(): + """Return the system status""" + # Import here to avoid circular imports + from api.app import models, DEVICE + + return jsonify({ + "status": "ok", + "cuda_available": torch.cuda.is_available(), + "device": DEVICE, + "models": { + "generator": models.generator is not None, + "asr": models.whisperx_model is not None, + "llm": models.llm is not None + }, + "versions": { + "transformers": "4.49.0", # Replace with actual version + "torch": torch.__version__ + } + }) + + @app.route('/api/system_resources') + def system_resources(): + """Return system resource usage""" + # Import here to avoid circular imports + from api.app import active_conversations, DEVICE + + # Get CPU usage + cpu_percent = psutil.cpu_percent(interval=0.1) + + # Get memory usage + memory = psutil.virtual_memory() + memory_used_gb = memory.used / (1024 ** 3) + memory_total_gb = memory.total / (1024 ** 3) + memory_percent = memory.percent + + # Get GPU memory if available + gpu_memory = {} + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + gpu_memory[f"gpu_{i}"] = { + "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3), + "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3), + "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3) + } + + return jsonify({ + "cpu_percent": cpu_percent, + "memory": { + "used_gb": memory_used_gb, + "total_gb": memory_total_gb, + "percent": memory_percent + }, + "gpu_memory": gpu_memory, + "active_sessions": len(active_conversations) + }) \ No newline at end of file diff --git a/Backend/api/socket_handlers.py b/Backend/api/socket_handlers.py new file mode 100644 index 0000000..20513e9 --- /dev/null +++ b/Backend/api/socket_handlers.py @@ -0,0 +1,392 @@ +import os +import io +import base64 +import time +import threading +import queue +import tempfile +import gc +import logging +import traceback +from typing import Dict, List, Optional + +import torch +import torchaudio +import numpy as np +from flask import request +from flask_socketio import emit + +# Import conversation model +from generator import Segment + +logger = logging.getLogger(__name__) + +# Conversation data structure +class Conversation: + def __init__(self, session_id): + self.session_id = session_id + self.segments: List[Segment] = [] + self.current_speaker = 0 + self.ai_speaker_id = 1 # Default AI speaker ID + self.last_activity = time.time() + self.is_processing = False + + def add_segment(self, text, speaker, audio): + segment = Segment(text=text, speaker=speaker, audio=audio) + self.segments.append(segment) + self.last_activity = time.time() + return segment + + def get_context(self, max_segments=10): + """Return the most recent segments for context""" + return self.segments[-max_segments:] if self.segments else [] + +def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE): + """Register Socket.IO event handlers""" + + @socketio.on('connect') + def handle_connect(auth=None): + """Handle client connection""" + session_id = request.sid + logger.info(f"Client connected: {session_id}") + + # Initialize conversation data + if session_id not in active_conversations: + 
active_conversations[session_id] = Conversation(session_id) + user_queues[session_id] = queue.Queue() + processing_threads[session_id] = threading.Thread( + target=process_audio_queue, + args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, DEVICE), + daemon=True + ) + processing_threads[session_id].start() + + emit('connection_status', {'status': 'connected'}) + + @socketio.on('disconnect') + def handle_disconnect(reason=None): + """Handle client disconnection""" + session_id = request.sid + logger.info(f"Client disconnected: {session_id}. Reason: {reason}") + + # Cleanup + if session_id in active_conversations: + # Mark for deletion rather than immediately removing + # as the processing thread might still be accessing it + active_conversations[session_id].is_processing = False + user_queues[session_id].put(None) # Signal thread to terminate + + @socketio.on('audio_data') + def handle_audio_data(data): + """Handle incoming audio data""" + session_id = request.sid + logger.info(f"Received audio data from {session_id}") + + # Check if the models are loaded + if models.generator is None or models.whisperx_model is None or models.llm is None: + emit('error', {'message': 'Models still loading, please wait'}) + return + + # Check if we're already processing for this session + if session_id in active_conversations and active_conversations[session_id].is_processing: + emit('error', {'message': 'Still processing previous audio, please wait'}) + return + + # Add to processing queue + if session_id in user_queues: + user_queues[session_id].put(data) + else: + emit('error', {'message': 'Session not initialized, please refresh the page'}) + +def process_audio_queue(session_id, q, app, socketio, models, active_conversations, DEVICE): + """Background thread to process audio chunks for a session""" + logger.info(f"Started processing thread for session: {session_id}") + + try: + while session_id in active_conversations: + try: + # Get the next audio chunk with a timeout + data = q.get(timeout=120) + if data is None: # Termination signal + break + + # Process the audio and generate a response + process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE) + + except queue.Empty: + # Timeout, check if session is still valid + continue + except Exception as e: + logger.error(f"Error processing audio for {session_id}: {str(e)}") + # Create an app context for the socket emit + with app.app_context(): + socketio.emit('error', {'message': str(e)}, room=session_id) + finally: + logger.info(f"Ending processing thread for session: {session_id}") + # Clean up when thread is done + with app.app_context(): + if session_id in active_conversations: + del active_conversations[session_id] + if session_id in user_queues: + del user_queues[session_id] + +def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE): + """Process audio data and generate a response using WhisperX""" + if models.generator is None or models.whisperx_model is None or models.llm is None: + logger.warning("Models not yet loaded!") + with app.app_context(): + socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) + return + + logger.info(f"Processing audio for session {session_id}") + conversation = active_conversations[session_id] + + try: + # Set processing flag + conversation.is_processing = True + + # Process base64 audio data + audio_data = data['audio'] + speaker_id = data['speaker'] + 
logger.info(f"Received audio from speaker {speaker_id}") + + # Convert from base64 to WAV + try: + audio_bytes = base64.b64decode(audio_data.split(',')[1]) + logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes") + except Exception as e: + logger.error(f"Error decoding base64 audio: {str(e)}") + raise + + # Save to temporary file for processing + with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: + temp_file.write(audio_bytes) + temp_path = temp_file.name + + try: + # Notify client that transcription is starting + with app.app_context(): + socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) + + # Load audio using WhisperX + import whisperx + audio = whisperx.load_audio(temp_path) + + # Check audio length and add a warning for short clips + audio_length = len(audio) / 16000 # assuming 16kHz sample rate + if audio_length < 1.0: + logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality") + + # Transcribe using WhisperX + batch_size = 16 # adjust based on your GPU memory + logger.info("Running WhisperX transcription...") + + # Handle the warning about audio being shorter than 30s by suppressing it + import warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="audio is shorter than 30s") + result = models.whisperx_model.transcribe(audio, batch_size=batch_size) + + # Get the detected language + language_code = result["language"] + logger.info(f"Detected language: {language_code}") + + # Check if alignment model needs to be loaded or updated + if models.whisperx_align_model is None or language_code != models.last_language: + # Clean up old models if they exist + if models.whisperx_align_model is not None: + del models.whisperx_align_model + del models.whisperx_align_metadata + if DEVICE == "cuda": + gc.collect() + torch.cuda.empty_cache() + + # Load new alignment model for the detected language + logger.info(f"Loading alignment model for language: {language_code}") + models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( + language_code=language_code, device=DEVICE + ) + models.last_language = language_code + + # Align the transcript to get word-level timestamps + if result["segments"] and len(result["segments"]) > 0: + logger.info("Aligning transcript...") + result = whisperx.align( + result["segments"], + models.whisperx_align_model, + models.whisperx_align_metadata, + audio, + DEVICE, + return_char_alignments=False + ) + + # Process the segments for better output + for segment in result["segments"]: + # Round timestamps for better display + segment["start"] = round(segment["start"], 2) + segment["end"] = round(segment["end"], 2) + # Add a confidence score if not present + if "confidence" not in segment: + segment["confidence"] = 1.0 # Default confidence + + # Extract the full text from all segments + user_text = ' '.join([segment['text'] for segment in result['segments']]) + + # If no text was recognized, don't process further + if not user_text or len(user_text.strip()) == 0: + with app.app_context(): + socketio.emit('error', {'message': 'No speech detected'}, room=session_id) + return + + logger.info(f"Transcription: {user_text}") + + # Load audio for CSM input + waveform, sample_rate = torchaudio.load(temp_path) + + # Normalize to mono if needed + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + # Resample to the CSM sample rate if needed + if sample_rate != models.generator.sample_rate: + 
waveform = torchaudio.functional.resample( + waveform, + orig_freq=sample_rate, + new_freq=models.generator.sample_rate + ) + + # Add the user's message to conversation history + user_segment = conversation.add_segment( + text=user_text, + speaker=speaker_id, + audio=waveform.squeeze() + ) + + # Send transcription to client with detailed segments + with app.app_context(): + socketio.emit('transcription', { + 'text': user_text, + 'speaker': speaker_id, + 'segments': result['segments'] # Include the detailed segments with timestamps + }, room=session_id) + + # Generate AI response using Llama + with app.app_context(): + socketio.emit('processing_status', {'status': 'generating'}, room=session_id) + + # Create prompt from conversation history + conversation_history = "" + for segment in conversation.segments[-5:]: # Last 5 segments for context + role = "User" if segment.speaker == 0 else "Assistant" + conversation_history += f"{role}: {segment.text}\n" + + # Add final prompt + prompt = f"{conversation_history}Assistant: " + + # Generate response with Llama + try: + # Ensure pad token is set + if models.tokenizer.pad_token is None: + models.tokenizer.pad_token = models.tokenizer.eos_token + + input_tokens = models.tokenizer( + prompt, + return_tensors="pt", + padding=True, + return_attention_mask=True + ) + input_ids = input_tokens.input_ids.to(DEVICE) + attention_mask = input_tokens.attention_mask.to(DEVICE) + + with torch.no_grad(): + generated_ids = models.llm.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=100, + temperature=0.7, + top_p=0.9, + do_sample=True, + pad_token_id=models.tokenizer.eos_token_id + ) + + # Decode the response + response_text = models.tokenizer.decode( + generated_ids[0][input_ids.shape[1]:], + skip_special_tokens=True + ).strip() + except Exception as e: + logger.error(f"Error generating response: {str(e)}") + logger.error(traceback.format_exc()) + response_text = "I'm sorry, I encountered an error while processing your request." 
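# --- Illustrative aside: the transcript-style prompt used above, standalone. ---
# A plain causal LM prompted with "User:/Assistant:" turns will often keep writing
# past the assistant's turn, so a common follow-up step (not done in this patch;
# purely an assumption for illustration) is to cut the continuation at the next
# turn marker.
from typing import List, Tuple

def build_prompt(turns: List[Tuple[int, str]], user_speaker: int = 0) -> str:
    # Same five-turn window as the handler above: speaker 0 maps to "User",
    # anything else to "Assistant", and the prompt ends mid-"Assistant:" turn.
    lines = []
    for speaker, text in turns[-5:]:
        role = "User" if speaker == user_speaker else "Assistant"
        lines.append(f"{role}: {text}")
    return "\n".join(lines) + "\nAssistant: "

def trim_reply(raw_continuation: str) -> str:
    # Keep only the assistant's first turn; drop any invented "User:" turn.
    return raw_continuation.split("\nUser:")[0].strip()

# Example usage (values are placeholders):
#   prompt = build_prompt([(0, "Hey, how are you?"), (1, "Doing well!"), (0, "Tell me a joke.")])
#   reply = trim_reply(llm_generated_text)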
+ + # Synthesize speech + with app.app_context(): + socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id) + + # Start sending the audio response + socketio.emit('audio_response_start', { + 'text': response_text, + 'total_chunks': 1, + 'chunk_index': 0 + }, room=session_id) + + # Define AI speaker ID + ai_speaker_id = conversation.ai_speaker_id + + # Generate audio + audio_tensor = models.generator.generate( + text=response_text, + speaker=ai_speaker_id, + context=conversation.get_context(), + max_audio_length_ms=10_000, + temperature=0.9 + ) + + # Add AI response to conversation history + ai_segment = conversation.add_segment( + text=response_text, + speaker=ai_speaker_id, + audio=audio_tensor + ) + + # Convert audio to WAV format + with io.BytesIO() as wav_io: + torchaudio.save( + wav_io, + audio_tensor.unsqueeze(0).cpu(), + models.generator.sample_rate, + format="wav" + ) + wav_io.seek(0) + wav_data = wav_io.read() + + # Convert WAV data to base64 + audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}" + + # Send audio chunk to client + with app.app_context(): + socketio.emit('audio_response_chunk', { + 'chunk': audio_base64, + 'chunk_index': 0, + 'total_chunks': 1, + 'is_last': True + }, room=session_id) + + # Signal completion + socketio.emit('audio_response_complete', { + 'text': response_text + }, room=session_id) + + finally: + # Clean up temp file + if os.path.exists(temp_path): + os.unlink(temp_path) + + except Exception as e: + logger.error(f"Error processing audio: {str(e)}") + logger.error(traceback.format_exc()) + with app.app_context(): + socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) + finally: + # Reset processing flag + conversation.is_processing = False \ No newline at end of file diff --git a/Backend/watermarking.py b/Backend/api/watermarking.py similarity index 100% rename from Backend/watermarking.py rename to Backend/api/watermarking.py diff --git a/Backend/index.html b/Backend/index.html deleted file mode 100644 index 9950a00..0000000 --- a/Backend/index.html +++ /dev/null @@ -1,419 +0,0 @@ - - - - - - CSM Voice Chat - - - - - - -
[deleted Backend/index.html markup; only its visible UI text survives: "CSM Voice Chat", "Talk naturally with the AI using your voice", "Conversation", connection status "Disconnected", loading progress "0%", model badges "CSM" / "ASR" / "LLM", "Controls", "Click the button below to start and stop recording.", "Start speaking to see audio visualization", "Settings", "Powered by CSM 1B & Llama 3.2 | Whisper for speech recognition"]
- - - - - \ No newline at end of file diff --git a/Backend/requirements.txt b/Backend/requirements.txt deleted file mode 100644 index 1e05eb3..0000000 --- a/Backend/requirements.txt +++ /dev/null @@ -1,13 +0,0 @@ -flask==2.2.5 -flask-socketio==5.3.6 -flask-cors==4.0.0 -torch==2.4.0 -torchaudio==2.4.0 -tokenizers==0.21.0 -transformers==4.49.0 -librosa==0.10.1 -huggingface_hub==0.28.1 -moshi==0.2.2 -torchtune==0.4.0 -torchao==0.9.0 -silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master \ No newline at end of file diff --git a/Backend/run_csm.py b/Backend/run_csm.py deleted file mode 100644 index 0062973..0000000 --- a/Backend/run_csm.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import torch -import torchaudio -from huggingface_hub import hf_hub_download -from generator import load_csm_1b, Segment -from dataclasses import dataclass - -# Disable Triton compilation -os.environ["NO_TORCH_COMPILE"] = "1" - -# Default prompts are available at https://hf.co/sesame/csm-1b -prompt_filepath_conversational_a = hf_hub_download( - repo_id="sesame/csm-1b", - filename="prompts/conversational_a.wav" -) -prompt_filepath_conversational_b = hf_hub_download( - repo_id="sesame/csm-1b", - filename="prompts/conversational_b.wav" -) - -SPEAKER_PROMPTS = { - "conversational_a": { - "text": ( - "like revising for an exam I'd have to try and like keep up the momentum because I'd " - "start really early I'd be like okay I'm gonna start revising now and then like " - "you're revising for ages and then I just like start losing steam I didn't do that " - "for the exam we had recently to be fair that was a more of a last minute scenario " - "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I " - "sort of start the day with this not like a panic but like a" - ), - "audio": prompt_filepath_conversational_a - }, - "conversational_b": { - "text": ( - "like a super Mario level. Like it's very like high detail. And like, once you get " - "into the park, it just like, everything looks like a computer game and they have all " - "these, like, you know, if, if there's like a, you know, like in a Mario game, they " - "will have like a question block. And if you like, you know, punch it, a coin will " - "come out. So like everyone, when they come into the park, they get like this little " - "bracelet and then you can go punching question blocks around." 
- ), - "audio": prompt_filepath_conversational_b - } -} - -def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor: - audio_tensor, sample_rate = torchaudio.load(audio_path) - audio_tensor = audio_tensor.squeeze(0) - # Resample is lazy so we can always call it - audio_tensor = torchaudio.functional.resample( - audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate - ) - return audio_tensor - -def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment: - audio_tensor = load_prompt_audio(audio_path, sample_rate) - return Segment(text=text, speaker=speaker, audio=audio_tensor) - -def main(): - # Select the best available device, skipping MPS due to float64 limitations - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - print(f"Using device: {device}") - - # Load model - generator = load_csm_1b(device) - - # Prepare prompts - prompt_a = prepare_prompt( - SPEAKER_PROMPTS["conversational_a"]["text"], - 0, - SPEAKER_PROMPTS["conversational_a"]["audio"], - generator.sample_rate - ) - - prompt_b = prepare_prompt( - SPEAKER_PROMPTS["conversational_b"]["text"], - 1, - SPEAKER_PROMPTS["conversational_b"]["audio"], - generator.sample_rate - ) - - # Generate conversation - conversation = [ - {"text": "Hey how are you doing?", "speaker_id": 0}, - {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1}, - {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0}, - {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1} - ] - - # Generate each utterance - generated_segments = [] - prompt_segments = [prompt_a, prompt_b] - - for utterance in conversation: - print(f"Generating: {utterance['text']}") - audio_tensor = generator.generate( - text=utterance['text'], - speaker=utterance['speaker_id'], - context=prompt_segments + generated_segments, - max_audio_length_ms=10_000, - ) - generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)) - - # Concatenate all generations - all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0) - torchaudio.save( - "full_conversation.wav", - all_audio.unsqueeze(0).cpu(), - generator.sample_rate - ) - print("Successfully generated full_conversation.wav") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/Backend/server.py b/Backend/server.py index e912a9d..b8af6b7 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -1,639 +1,19 @@ -import os -import io -import base64 -import time -import json -import uuid -import logging -import threading -import queue -import tempfile -import gc -from typing import Dict, List, Optional, Tuple +""" +CSM Voice Chat Server +A voice chat application that uses CSM 1B for voice synthesis, +WhisperX for speech recognition, and Llama 3.2 for language generation. 
+""" -import torch -import torchaudio -import numpy as np -from flask import Flask, request, jsonify, send_from_directory -from flask_socketio import SocketIO, emit -from flask_cors import CORS -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +# Start the Flask application +from api.app import app, socketio -# Import WhisperX for better transcription -import whisperx - -from generator import load_csm_1b, Segment -from dataclasses import dataclass - -# Add these imports at the top -import psutil -import gc - -# Configure logging -logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -# Initialize Flask app -app = Flask(__name__, static_folder='.') -CORS(app) -socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120) - -# Configure device -if torch.cuda.is_available(): - DEVICE = "cuda" -elif torch.backends.mps.is_available(): - DEVICE = "mps" -else: - DEVICE = "cpu" - -logger.info(f"Using device: {DEVICE}") - -# Global variables -active_conversations = {} -user_queues = {} -processing_threads = {} - -# Load models -@dataclass -class AppModels: - generator = None - tokenizer = None - llm = None - whisperx_model = None - whisperx_align_model = None - whisperx_align_metadata = None - diarize_model = None - last_language = None - -# Initialize the models object -models = AppModels() - -def load_models(): - """Load all required models""" - global models - - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0}) - - # CSM 1B loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'}) - models.generator = load_csm_1b(device=DEVICE) - logger.info("CSM 1B model loaded successfully") - socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33}) - if DEVICE == "cuda": - torch.cuda.empty_cache() - except Exception as e: - import traceback - error_details = traceback.format_exc() - logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") - socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) - - # WhisperX loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'}) - # Use WhisperX for better transcription with timestamps - import whisperx - - # Use compute_type based on device - compute_type = "float16" if DEVICE == "cuda" else "float32" - - # Load the WhisperX model (smaller model for faster processing) - models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type) - - logger.info("WhisperX model loaded successfully") - socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66}) - if DEVICE == "cuda": - torch.cuda.empty_cache() - except Exception as e: - import traceback - error_details = traceback.format_exc() - logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}") - socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) - - # Llama loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'}) - models.llm = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-3.2-1B", - 
device_map=DEVICE, - torch_dtype=torch.bfloat16 - ) - models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") - - # Configure all special tokens - models.tokenizer.pad_token = models.tokenizer.eos_token - models.tokenizer.padding_side = "left" # For causal language modeling - - # Inform the model about the pad token - if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None: - models.llm.config.pad_token_id = models.tokenizer.pad_token_id - - logger.info("Llama 3.2 model loaded successfully") - socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'}) - except Exception as e: - logger.error(f"Error loading Llama 3.2 model: {str(e)}") - socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) - -# Load models in a background thread -threading.Thread(target=load_models, daemon=True).start() - -# Conversation data structure -class Conversation: - def __init__(self, session_id): - self.session_id = session_id - self.segments: List[Segment] = [] - self.current_speaker = 0 - self.ai_speaker_id = 1 # Add this property - self.last_activity = time.time() - self.is_processing = False - - def add_segment(self, text, speaker, audio): - segment = Segment(text=text, speaker=speaker, audio=audio) - self.segments.append(segment) - self.last_activity = time.time() - return segment - - def get_context(self, max_segments=10): - """Return the most recent segments for context""" - return self.segments[-max_segments:] if self.segments else [] - -# Routes -@app.route('/') -def index(): - return send_from_directory('.', 'index.html') - -@app.route('/voice-chat.js') -def voice_chat_js(): - return send_from_directory('.', 'voice-chat.js') - -@app.route('/api/health') -def health_check(): - return jsonify({ - "status": "ok", - "cuda_available": torch.cuda.is_available(), - "models_loaded": models.generator is not None and models.llm is not None - }) - -# Fix the system_status function: - -@app.route('/api/status') -def system_status(): - return jsonify({ - "status": "ok", - "cuda_available": torch.cuda.is_available(), - "device": DEVICE, - "models": { - "generator": models.generator is not None, - "asr": models.whisperx_model is not None, # Use the correct model name - "llm": models.llm is not None - } - }) - -# Add a new endpoint to check system resources -@app.route('/api/system_resources') -def system_resources(): - # Get CPU usage - cpu_percent = psutil.cpu_percent(interval=0.1) - - # Get memory usage - memory = psutil.virtual_memory() - memory_used_gb = memory.used / (1024 ** 3) - memory_total_gb = memory.total / (1024 ** 3) - memory_percent = memory.percent - - # Get GPU memory if available - gpu_memory = {} - if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - gpu_memory[f"gpu_{i}"] = { - "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3), - "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3), - "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3) - } - - return jsonify({ - "cpu_percent": cpu_percent, - "memory": { - "used_gb": memory_used_gb, - "total_gb": memory_total_gb, - "percent": memory_percent - }, - "gpu_memory": gpu_memory, - "active_sessions": len(active_conversations) - }) - -# Socket event handlers -@socketio.on('connect') -def 
handle_connect(auth=None): - session_id = request.sid - logger.info(f"Client connected: {session_id}") - - # Initialize conversation data - if session_id not in active_conversations: - active_conversations[session_id] = Conversation(session_id) - user_queues[session_id] = queue.Queue() - processing_threads[session_id] = threading.Thread( - target=process_audio_queue, - args=(session_id, user_queues[session_id]), - daemon=True - ) - processing_threads[session_id].start() - - emit('connection_status', {'status': 'connected'}) - -@socketio.on('disconnect') -def handle_disconnect(reason=None): - session_id = request.sid - logger.info(f"Client disconnected: {session_id}. Reason: {reason}") - - # Cleanup - if session_id in active_conversations: - # Mark for deletion rather than immediately removing - # as the processing thread might still be accessing it - active_conversations[session_id].is_processing = False - user_queues[session_id].put(None) # Signal thread to terminate - -@socketio.on('start_stream') -def handle_start_stream(): - session_id = request.sid - logger.info(f"Starting stream for client: {session_id}") - emit('streaming_status', {'status': 'active'}) - -@socketio.on('stop_stream') -def handle_stop_stream(): - session_id = request.sid - logger.info(f"Stopping stream for client: {session_id}") - emit('streaming_status', {'status': 'inactive'}) - -@socketio.on('clear_context') -def handle_clear_context(): - session_id = request.sid - logger.info(f"Clearing context for client: {session_id}") - - if session_id in active_conversations: - active_conversations[session_id].segments = [] - emit('context_updated', {'status': 'cleared'}) - -@socketio.on('audio_chunk') -def handle_audio_chunk(data): - session_id = request.sid - audio_data = data.get('audio', '') - speaker_id = int(data.get('speaker', 0)) - - if not audio_data or not session_id in active_conversations: - return - - # Update the current speaker - active_conversations[session_id].current_speaker = speaker_id - - # Queue audio for processing - user_queues[session_id].put({ - 'audio': audio_data, - 'speaker': speaker_id - }) - - emit('processing_status', {'status': 'transcribing'}) - -def process_audio_queue(session_id, q): - """Background thread to process audio chunks for a session""" - logger.info(f"Started processing thread for session: {session_id}") - - try: - while session_id in active_conversations: - try: - # Get the next audio chunk with a timeout - data = q.get(timeout=120) - if data is None: # Termination signal - break - - # Process the audio and generate a response - process_audio_and_respond(session_id, data) - - except queue.Empty: - # Timeout, check if session is still valid - continue - except Exception as e: - logger.error(f"Error processing audio for {session_id}: {str(e)}") - # Create an app context for the socket emit - with app.app_context(): - socketio.emit('error', {'message': str(e)}, room=session_id) - finally: - logger.info(f"Ending processing thread for session: {session_id}") - # Clean up when thread is done - with app.app_context(): - if session_id in active_conversations: - del active_conversations[session_id] - if session_id in user_queues: - del user_queues[session_id] - -def process_audio_and_respond(session_id, data): - """Process audio data and generate a response using WhisperX""" - if models.generator is None or models.whisperx_model is None or models.llm is None: - logger.warning("Models not yet loaded!") - with app.app_context(): - socketio.emit('error', {'message': 'Models still loading, 
please wait'}, room=session_id) - return - - logger.info(f"Processing audio for session {session_id}") - conversation = active_conversations[session_id] - - try: - # Set processing flag - conversation.is_processing = True - - # Process base64 audio data - audio_data = data['audio'] - speaker_id = data['speaker'] - logger.info(f"Received audio from speaker {speaker_id}") - - # Convert from base64 to WAV - try: - audio_bytes = base64.b64decode(audio_data.split(',')[1]) - logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes") - except Exception as e: - logger.error(f"Error decoding base64 audio: {str(e)}") - raise - - # Save to temporary file for processing - with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: - temp_file.write(audio_bytes) - temp_path = temp_file.name - - try: - # Notify client that transcription is starting - with app.app_context(): - socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) - - # Load audio using WhisperX - import whisperx - audio = whisperx.load_audio(temp_path) - - # Check audio length and add a warning for short clips - audio_length = len(audio) / 16000 # assuming 16kHz sample rate - if audio_length < 1.0: - logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality") - - # Transcribe using WhisperX - batch_size = 16 # adjust based on your GPU memory - logger.info("Running WhisperX transcription...") - - # Handle the warning about audio being shorter than 30s by suppressing it - import warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="audio is shorter than 30s") - result = models.whisperx_model.transcribe(audio, batch_size=batch_size) - - # Get the detected language - language_code = result["language"] - logger.info(f"Detected language: {language_code}") - - # Check if alignment model needs to be loaded or updated - if models.whisperx_align_model is None or language_code != models.last_language: - # Clean up old models if they exist - if models.whisperx_align_model is not None: - del models.whisperx_align_model - del models.whisperx_align_metadata - if DEVICE == "cuda": - gc.collect() - torch.cuda.empty_cache() - - # Load new alignment model for the detected language - logger.info(f"Loading alignment model for language: {language_code}") - models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( - language_code=language_code, device=DEVICE - ) - models.last_language = language_code - - # Align the transcript to get word-level timestamps - if result["segments"] and len(result["segments"]) > 0: - logger.info("Aligning transcript...") - result = whisperx.align( - result["segments"], - models.whisperx_align_model, - models.whisperx_align_metadata, - audio, - DEVICE, - return_char_alignments=False - ) - - # Process the segments for better output - for segment in result["segments"]: - # Round timestamps for better display - segment["start"] = round(segment["start"], 2) - segment["end"] = round(segment["end"], 2) - # Add a confidence score if not present - if "confidence" not in segment: - segment["confidence"] = 1.0 # Default confidence - - # Extract the full text from all segments - user_text = ' '.join([segment['text'] for segment in result['segments']]) - - # If no text was recognized, don't process further - if not user_text or len(user_text.strip()) == 0: - with app.app_context(): - socketio.emit('error', {'message': 'No speech detected'}, room=session_id) - return - - logger.info(f"Transcription: 
{user_text}") - - # Load audio for CSM input - waveform, sample_rate = torchaudio.load(temp_path) - - # Normalize to mono if needed - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) - - # Resample to the CSM sample rate if needed - if sample_rate != models.generator.sample_rate: - waveform = torchaudio.functional.resample( - waveform, - orig_freq=sample_rate, - new_freq=models.generator.sample_rate - ) - - # Add the user's message to conversation history - user_segment = conversation.add_segment( - text=user_text, - speaker=speaker_id, - audio=waveform.squeeze() - ) - - # Send transcription to client with detailed segments - with app.app_context(): - socketio.emit('transcription', { - 'text': user_text, - 'speaker': speaker_id, - 'segments': result['segments'] # Include the detailed segments with timestamps - }, room=session_id) - - # Generate AI response using Llama - with app.app_context(): - socketio.emit('processing_status', {'status': 'generating'}, room=session_id) - - # Create prompt from conversation history - conversation_history = "" - for segment in conversation.segments[-5:]: # Last 5 segments for context - role = "User" if segment.speaker == 0 else "Assistant" - conversation_history += f"{role}: {segment.text}\n" - - # Add final prompt - prompt = f"{conversation_history}Assistant: " - - # Generate response with Llama - try: - # Ensure pad token is set - if models.tokenizer.pad_token is None: - models.tokenizer.pad_token = models.tokenizer.eos_token - - input_tokens = models.tokenizer( - prompt, - return_tensors="pt", - padding=True, - return_attention_mask=True - ) - input_ids = input_tokens.input_ids.to(DEVICE) - attention_mask = input_tokens.attention_mask.to(DEVICE) - - with torch.no_grad(): - generated_ids = models.llm.generate( - input_ids, - attention_mask=attention_mask, - max_new_tokens=100, - temperature=0.7, - top_p=0.9, - do_sample=True, - pad_token_id=models.tokenizer.eos_token_id - ) - - # Decode the response - response_text = models.tokenizer.decode( - generated_ids[0][input_ids.shape[1]:], - skip_special_tokens=True - ).strip() - except Exception as e: - logger.error(f"Error generating response: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - response_text = "I'm sorry, I encountered an error while processing your request." 
- - # Synthesize speech - with app.app_context(): - socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id) - - # Start sending the audio response - socketio.emit('audio_response_start', { - 'text': response_text, - 'total_chunks': 1, - 'chunk_index': 0 - }, room=session_id) - - # Define AI speaker ID - ai_speaker_id = conversation.ai_speaker_id - - # Generate audio - audio_tensor = models.generator.generate( - text=response_text, - speaker=ai_speaker_id, - context=conversation.get_context(), - max_audio_length_ms=10_000, - temperature=0.9 - ) - - # Add AI response to conversation history - ai_segment = conversation.add_segment( - text=response_text, - speaker=ai_speaker_id, - audio=audio_tensor - ) - - # Convert audio to WAV format - with io.BytesIO() as wav_io: - torchaudio.save( - wav_io, - audio_tensor.unsqueeze(0).cpu(), - models.generator.sample_rate, - format="wav" - ) - wav_io.seek(0) - wav_data = wav_io.read() - - # Convert WAV data to base64 - audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}" - - # Send audio chunk to client - with app.app_context(): - socketio.emit('audio_response_chunk', { - 'chunk': audio_base64, - 'chunk_index': 0, - 'total_chunks': 1, - 'is_last': True - }, room=session_id) - - # Signal completion - socketio.emit('audio_response_complete', { - 'text': response_text - }, room=session_id) - - finally: - # Clean up temp file - if os.path.exists(temp_path): - os.unlink(temp_path) - - except Exception as e: - logger.error(f"Error processing audio: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - with app.app_context(): - socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) - finally: - # Reset processing flag - conversation.is_processing = False - -# Error handler -@socketio.on_error() -def error_handler(e): - logger.error(f"SocketIO error: {str(e)}") - -# Periodic cleanup of inactive sessions -def cleanup_inactive_sessions(): - """Remove sessions that have been inactive for too long""" - current_time = time.time() - inactive_timeout = 3600 # 1 hour - - for session_id in list(active_conversations.keys()): - conversation = active_conversations[session_id] - if (current_time - conversation.last_activity > inactive_timeout and - not conversation.is_processing): - - logger.info(f"Cleaning up inactive session: {session_id}") - - # Signal processing thread to terminate - if session_id in user_queues: - user_queues[session_id].put(None) - - # Remove from active conversations - del active_conversations[session_id] - -# Start cleanup thread -def start_cleanup_thread(): - while True: - try: - cleanup_inactive_sessions() - except Exception as e: - logger.error(f"Error in cleanup thread: {str(e)}") - time.sleep(300) # Run every 5 minutes - -cleanup_thread = threading.Thread(target=start_cleanup_thread, daemon=True) -cleanup_thread.start() - -# Start the server if __name__ == '__main__': + import os + port = int(os.environ.get('PORT', 5000)) debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true' - logger.info(f"Starting server on port {port} (debug={debug_mode})") + + print(f"Starting server on port {port} (debug={debug_mode})") + print("Visit http://localhost:5000 to access the application") + socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/setup.py b/Backend/setup.py deleted file mode 100644 index 8eddb95..0000000 --- a/Backend/setup.py +++ /dev/null @@ -1,13 +0,0 @@ -from 
setuptools import setup, find_packages -import os - -# Read requirements from requirements.txt -with open('requirements.txt') as f: - requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')] - -setup( - name='csm', - version='0.1.0', - packages=find_packages(), - install_requires=requirements, -) diff --git a/Backend/voice-chat.js b/Backend/voice-chat.js deleted file mode 100644 index dc2db04..0000000 --- a/Backend/voice-chat.js +++ /dev/null @@ -1,1054 +0,0 @@ -/** - * CSM AI Voice Chat Client - * - * A web client that connects to a CSM AI voice chat server and enables - * real-time voice conversation with an AI assistant. - */ - -// Configuration constants -const SERVER_URL = window.location.hostname === 'localhost' ? - 'http://localhost:5000' : window.location.origin; -const ENERGY_WINDOW_SIZE = 15; -const CLIENT_SILENCE_DURATION_MS = 750; - -// DOM elements -const elements = { - conversation: document.getElementById('conversation'), - streamButton: document.getElementById('streamButton'), - clearButton: document.getElementById('clearButton'), - thresholdSlider: document.getElementById('thresholdSlider'), - thresholdValue: document.getElementById('thresholdValue'), - visualizerCanvas: document.getElementById('audioVisualizer'), - visualizerLabel: document.getElementById('visualizerLabel'), - volumeLevel: document.getElementById('volumeLevel'), - statusDot: document.getElementById('statusDot'), - statusText: document.getElementById('statusText'), - speakerSelection: document.getElementById('speakerSelect'), - autoPlayResponses: document.getElementById('autoPlayResponses'), - showVisualizer: document.getElementById('showVisualizer') -}; - -// Application state -const state = { - socket: null, - audioContext: null, - analyser: null, - microphone: null, - streamProcessor: null, - isStreaming: false, - isSpeaking: false, - silenceThreshold: 0.01, - energyWindow: [], - silenceTimer: null, - volumeUpdateInterval: null, - visualizerAnimationFrame: null, - currentSpeaker: 0, - aiSpeakerId: 1, // Define the AI's speaker ID to match server.py - transcriptionRetries: 0, - maxTranscriptionRetries: 3 -}; - -// Visualizer variables -let canvasContext = null; -let visualizerBufferLength = 0; -let visualizerDataArray = null; - -// Audio streaming state -const streamingAudio = { - messageElement: null, - audioElement: null, - chunks: [], - totalChunks: 0, - receivedChunks: 0, - text: '', - complete: false -}; - -// Initialize the application -function initializeApp() { - // Initialize the UI elements - initializeUIElements(); - - // Initialize socket.io connection - setupSocketConnection(); - - // Setup event listeners - setupEventListeners(); - - // Initialize visualizer - setupVisualizer(); - - // Show welcome message - addSystemMessage('Welcome to CSM Voice Chat! 
Click "Start Conversation" to begin.'); -} - -// Initialize UI elements -function initializeUIElements() { - // Update threshold display - if (elements.thresholdValue) { - elements.thresholdValue.textContent = state.silenceThreshold.toFixed(3); - } -} - -// Setup Socket.IO connection -function setupSocketConnection() { - state.socket = io(SERVER_URL); - - // Connection events - state.socket.on('connect', () => { - updateConnectionStatus(true); - addSystemMessage('Connected to server.'); - }); - - state.socket.on('disconnect', () => { - updateConnectionStatus(false); - addSystemMessage('Disconnected from server.'); - stopStreaming(false); - }); - - state.socket.on('error', (data) => { - console.error('Server error:', data.message); - - // Make the error more user-friendly - let userMessage = data.message; - - // Check for common errors and provide more helpful messages - if (data.message.includes('Models still loading')) { - userMessage = 'The AI models are still loading. Please wait a moment and try again.'; - } else if (data.message.includes('No speech detected')) { - userMessage = 'No speech was detected. Please speak clearly and try again.'; - } - - addSystemMessage(`Error: ${userMessage}`); - - // Reset button state if it was processing - if (elements.streamButton.classList.contains('processing')) { - elements.streamButton.classList.remove('processing'); - elements.streamButton.innerHTML = ' Start Conversation'; - } - }); - - // Register message handlers - state.socket.on('transcription', handleTranscription); - state.socket.on('context_updated', handleContextUpdate); - state.socket.on('streaming_status', handleStreamingStatus); - state.socket.on('processing_status', handleProcessingStatus); - - // Add model status handlers - state.socket.on('model_status', handleModelStatusUpdate); - - // Handlers for incremental audio streaming - state.socket.on('audio_response_start', handleAudioResponseStart); - state.socket.on('audio_response_chunk', handleAudioResponseChunk); - state.socket.on('audio_response_complete', handleAudioResponseComplete); -} - -// Setup event listeners -function setupEventListeners() { - // Stream button - elements.streamButton.addEventListener('click', toggleStreaming); - - // Clear button - elements.clearButton.addEventListener('click', clearConversation); - - // Threshold slider - if (elements.thresholdSlider) { - elements.thresholdSlider.addEventListener('input', updateThreshold); - } - - // Speaker selection - elements.speakerSelection.addEventListener('change', () => { - state.currentSpeaker = parseInt(elements.speakerSelection.value); - }); - - // Visualizer toggle - if (elements.showVisualizer) { - elements.showVisualizer.addEventListener('change', toggleVisualizerVisibility); - } -} - -// Setup audio visualizer -function setupVisualizer() { - if (!elements.visualizerCanvas) return; - - canvasContext = elements.visualizerCanvas.getContext('2d'); - - // Set canvas dimensions - elements.visualizerCanvas.width = elements.visualizerCanvas.offsetWidth; - elements.visualizerCanvas.height = elements.visualizerCanvas.offsetHeight; - - // Initialize visualization data array - visualizerDataArray = new Uint8Array(128); - - // Start the visualizer animation - drawVisualizer(); -} - -// Update connection status UI -function updateConnectionStatus(isConnected) { - if (isConnected) { - elements.statusDot.classList.add('active'); - elements.statusText.textContent = 'Connected'; - } else { - elements.statusDot.classList.remove('active'); - elements.statusText.textContent = 
'Disconnected'; - } -} - -// Toggle streaming state -function toggleStreaming() { - if (state.isStreaming) { - stopStreaming(); - } else { - startStreaming(); - } -} - -// Start streaming audio to the server -function startStreaming() { - if (!state.socket || !state.socket.connected) { - addSystemMessage('Not connected to server. Please refresh the page.'); - return; - } - - // Check if models are loaded via the API - fetch('/api/status') - .then(response => response.json()) - .then(data => { - if (!data.models.generator || !data.models.asr || !data.models.llm) { - addSystemMessage('Still loading AI models. Please wait...'); - return; - } - - // Continue with recording if models are loaded - initializeRecording(); - }) - .catch(error => { - console.error('Error checking model status:', error); - // Try anyway, the server will respond with an error if models aren't ready - initializeRecording(); - }); -} - -// Extracted the recording initialization to a separate function -function initializeRecording() { - // Request microphone access - navigator.mediaDevices.getUserMedia({ audio: true, video: false }) - .then(stream => { - state.isStreaming = true; - elements.streamButton.classList.add('recording'); - elements.streamButton.innerHTML = ' Stop Recording'; - - // Initialize Web Audio API - state.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 }); - state.microphone = state.audioContext.createMediaStreamSource(stream); - state.analyser = state.audioContext.createAnalyser(); - state.analyser.fftSize = 2048; - - // Setup analyzer for visualizer - visualizerBufferLength = state.analyser.frequencyBinCount; - visualizerDataArray = new Uint8Array(visualizerBufferLength); - - state.microphone.connect(state.analyser); - - // Create processor node for audio data - const processorNode = state.audioContext.createScriptProcessor(4096, 1, 1); - processorNode.onaudioprocess = handleAudioProcess; - state.analyser.connect(processorNode); - processorNode.connect(state.audioContext.destination); - state.streamProcessor = processorNode; - - state.silenceTimer = null; - state.energyWindow = []; - state.isSpeaking = false; - - // Notify server - state.socket.emit('start_stream'); - - // Start volume meter updates - state.volumeUpdateInterval = setInterval(updateVolumeMeter, 100); - - // Make sure visualizer is visible if enabled - if (elements.showVisualizer && elements.showVisualizer.checked) { - elements.visualizerLabel.style.opacity = '0'; - } - - addSystemMessage('Recording started. Speak now...'); - }) - .catch(error => { - console.error('Error accessing microphone:', error); - addSystemMessage('Could not access microphone. 
Please check permissions.'); - }); -} - -// Stop streaming audio -function stopStreaming(notifyServer = true) { - if (state.isStreaming) { - state.isStreaming = false; - elements.streamButton.classList.remove('recording'); - elements.streamButton.classList.remove('processing'); - elements.streamButton.innerHTML = ' Start Conversation'; - - // Clean up audio resources - if (state.streamProcessor) { - state.streamProcessor.disconnect(); - state.streamProcessor = null; - } - - if (state.analyser) { - state.analyser.disconnect(); - state.analyser = null; - } - - if (state.microphone) { - state.microphone.disconnect(); - state.microphone = null; - } - - if (state.audioContext) { - state.audioContext.close().catch(err => console.warn('Error closing audio context:', err)); - state.audioContext = null; - } - - // Clear any pending silence timer - if (state.silenceTimer) { - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } - - // Clear volume meter updates - if (state.volumeUpdateInterval) { - clearInterval(state.volumeUpdateInterval); - state.volumeUpdateInterval = null; - - // Reset volume meter - if (elements.volumeLevel) { - elements.volumeLevel.style.width = '0%'; - } - } - - // Show visualizer label - if (elements.visualizerLabel) { - elements.visualizerLabel.style.opacity = '0.7'; - } - - // Notify server if needed - if (notifyServer && state.socket && state.socket.connected) { - state.socket.emit('stop_stream'); - } - - addSystemMessage('Recording stopped.'); - } -} - -// Handle audio processing -function handleAudioProcess(event) { - if (!state.isStreaming) return; - - const inputData = event.inputBuffer.getChannelData(0); - const energy = calculateAudioEnergy(inputData); - updateEnergyWindow(energy); - - const averageEnergy = calculateAverageEnergy(); - const isSilent = averageEnergy < state.silenceThreshold; - - handleSpeechState(isSilent); -} - -// Calculate audio energy (volume) -function calculateAudioEnergy(buffer) { - let sum = 0; - for (let i = 0; i < buffer.length; i++) { - sum += buffer[i] * buffer[i]; - } - return Math.sqrt(sum / buffer.length); -} - -// Update energy window for averaging -function updateEnergyWindow(energy) { - state.energyWindow.push(energy); - if (state.energyWindow.length > ENERGY_WINDOW_SIZE) { - state.energyWindow.shift(); - } -} - -// Calculate average energy from window -function calculateAverageEnergy() { - if (state.energyWindow.length === 0) return 0; - - const sum = state.energyWindow.reduce((acc, val) => acc + val, 0); - return sum / state.energyWindow.length; -} - -// Update the threshold from the slider -function updateThreshold() { - state.silenceThreshold = parseFloat(elements.thresholdSlider.value); - elements.thresholdValue.textContent = state.silenceThreshold.toFixed(3); -} - -// Update the volume meter display -function updateVolumeMeter() { - if (!state.isStreaming || !state.energyWindow.length || !elements.volumeLevel) return; - - const avgEnergy = calculateAverageEnergy(); - - // Scale energy to percentage (0-100) - // Energy values are typically very small (e.g., 0.001 to 0.1) - const scaleFactor = 1000; - const percentage = Math.min(100, Math.max(0, avgEnergy * scaleFactor)); - - // Update volume meter width - elements.volumeLevel.style.width = `${percentage}%`; - - // Change color based on level - if (percentage > 70) { - elements.volumeLevel.style.backgroundColor = '#ff5252'; - } else if (percentage > 30) { - elements.volumeLevel.style.backgroundColor = '#4CAF50'; - } else { - 
elements.volumeLevel.style.backgroundColor = '#4c84ff'; - } -} - -// Handle speech/silence state transitions -function handleSpeechState(isSilent) { - if (state.isSpeaking) { - if (isSilent) { - // User was speaking but now is silent - if (!state.silenceTimer) { - state.silenceTimer = setTimeout(() => { - // Silence lasted long enough, consider speech done - if (state.isSpeaking) { - state.isSpeaking = false; - - try { - // Get the current audio data and send it - const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max - state.analyser.getFloatTimeDomainData(audioBuffer); - - // Check if audio has content - const hasAudioContent = audioBuffer.some(sample => Math.abs(sample) > 0.01); - - if (!hasAudioContent) { - console.warn('Audio buffer appears to be empty or very quiet'); - - if (state.transcriptionRetries < state.maxTranscriptionRetries) { - state.transcriptionRetries++; - const retryMessage = `No speech detected (attempt ${state.transcriptionRetries}/${state.maxTranscriptionRetries}). Please speak louder and try again.`; - addSystemMessage(retryMessage); - } else { - state.transcriptionRetries = 0; - addSystemMessage('Multiple attempts failed to detect speech. Please check your microphone and try again.'); - } - return; - } - - // Create WAV blob - const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate); - - // Convert to base64 - const reader = new FileReader(); - reader.onloadend = function() { - sendAudioChunk(reader.result, state.currentSpeaker); - }; - reader.readAsDataURL(wavBlob); - - // Update button state - elements.streamButton.classList.add('processing'); - elements.streamButton.innerHTML = ' Processing...'; - - addSystemMessage('Processing your message...'); - } catch (e) { - console.error('Error recording audio:', e); - addSystemMessage('Error recording audio. 
Please try again.'); - } - } - }, CLIENT_SILENCE_DURATION_MS); - } - } else { - // User is still speaking, reset silence timer - if (state.silenceTimer) { - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } - } - } else { - if (!isSilent) { - // User started speaking - state.isSpeaking = true; - if (state.silenceTimer) { - clearTimeout(state.silenceTimer); - state.silenceTimer = null; - } - } - } -} - -// Send audio chunk to server -function sendAudioChunk(audioData, speaker) { - if (state.socket && state.socket.connected) { - state.socket.emit('audio_chunk', { - audio: audioData, - speaker: speaker - }); - } -} - -// Create WAV blob from audio data -function createWavBlob(audioData, sampleRate) { - const numChannels = 1; - const bitsPerSample = 16; - const bytesPerSample = bitsPerSample / 8; - - // Create buffer for WAV file - const buffer = new ArrayBuffer(44 + audioData.length * bytesPerSample); - const view = new DataView(buffer); - - // Write WAV header - // "RIFF" chunk descriptor - writeString(view, 0, 'RIFF'); - view.setUint32(4, 36 + audioData.length * bytesPerSample, true); - writeString(view, 8, 'WAVE'); - - // "fmt " sub-chunk - writeString(view, 12, 'fmt '); - view.setUint32(16, 16, true); // subchunk1size - view.setUint16(20, 1, true); // audio format (PCM) - view.setUint16(22, numChannels, true); - view.setUint32(24, sampleRate, true); - view.setUint32(28, sampleRate * numChannels * bytesPerSample, true); // byte rate - view.setUint16(32, numChannels * bytesPerSample, true); // block align - view.setUint16(34, bitsPerSample, true); - - // "data" sub-chunk - writeString(view, 36, 'data'); - view.setUint32(40, audioData.length * bytesPerSample, true); - - // Write audio data - const audioDataStart = 44; - for (let i = 0; i < audioData.length; i++) { - const sample = Math.max(-1, Math.min(1, audioData[i])); - const value = sample < 0 ? 
sample * 0x8000 : sample * 0x7FFF; - view.setInt16(audioDataStart + i * bytesPerSample, value, true); - } - - return new Blob([buffer], { type: 'audio/wav' }); -} - -// Helper function to write strings to DataView -function writeString(view, offset, string) { - for (let i = 0; i < string.length; i++) { - view.setUint8(offset + i, string.charCodeAt(i)); - } -} - -// Clear conversation history -function clearConversation() { - elements.conversation.innerHTML = ''; - if (state.socket && state.socket.connected) { - state.socket.emit('clear_context'); - } - addSystemMessage('Conversation cleared.'); -} - -// Draw audio visualizer -function drawVisualizer() { - if (!canvasContext || !elements.visualizerCanvas) { - state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer); - return; - } - - state.visualizerAnimationFrame = requestAnimationFrame(drawVisualizer); - - // Skip drawing if visualizer is hidden or not enabled - if (elements.showVisualizer && !elements.showVisualizer.checked) { - if (elements.visualizerCanvas.style.opacity !== '0') { - elements.visualizerCanvas.style.opacity = '0'; - } - return; - } else if (elements.visualizerCanvas.style.opacity !== '1') { - elements.visualizerCanvas.style.opacity = '1'; - } - - // Get frequency data if available - if (state.isStreaming && state.analyser) { - try { - state.analyser.getByteFrequencyData(visualizerDataArray); - } catch (e) { - console.warn('Error getting frequency data:', e); - } - } else { - // Fade out when not streaming - for (let i = 0; i < visualizerDataArray.length; i++) { - visualizerDataArray[i] = Math.max(0, visualizerDataArray[i] - 5); - } - } - - // Clear canvas - canvasContext.fillStyle = 'rgb(0, 0, 0)'; - canvasContext.fillRect(0, 0, elements.visualizerCanvas.width, elements.visualizerCanvas.height); - - // Draw gradient bars - const width = elements.visualizerCanvas.width; - const height = elements.visualizerCanvas.height; - const barCount = Math.min(visualizerBufferLength, 64); - const barWidth = width / barCount - 1; - - for (let i = 0; i < barCount; i++) { - const index = Math.floor(i * visualizerBufferLength / barCount); - const value = visualizerDataArray[index]; - - // Use logarithmic scale for better audio visualization - const logFactor = 20; - const scaledValue = Math.log(1 + (value / 255) * logFactor) / Math.log(1 + logFactor); - const barHeight = scaledValue * height; - - // Position bars - const x = i * (barWidth + 1); - const y = height - barHeight; - - // Create color gradient based on frequency and amplitude - const hue = i / barCount * 360; // Full color spectrum - const saturation = 80 + (value / 255 * 20); // Higher values more saturated - const lightness = 40 + (value / 255 * 20); // Dynamic brightness - - // Draw main bar - canvasContext.fillStyle = `hsl(${hue}, ${saturation}%, ${lightness}%)`; - canvasContext.fillRect(x, y, barWidth, barHeight); - - // Add highlight effect - if (barHeight > 5) { - const gradient = canvasContext.createLinearGradient( - x, y, - x, y + barHeight * 0.5 - ); - gradient.addColorStop(0, `hsla(${hue}, ${saturation}%, ${lightness + 20}%, 0.4)`); - gradient.addColorStop(1, `hsla(${hue}, ${saturation}%, ${lightness}%, 0)`); - canvasContext.fillStyle = gradient; - canvasContext.fillRect(x, y, barWidth, barHeight * 0.5); - - // Add highlight on top of the bar - canvasContext.fillStyle = `hsla(${hue}, ${saturation - 20}%, ${lightness + 30}%, 0.7)`; - canvasContext.fillRect(x, y, barWidth, 2); - } - } -} - -// Toggle visualizer visibility -function 
toggleVisualizerVisibility() { - const isVisible = elements.showVisualizer.checked; - elements.visualizerCanvas.style.opacity = isVisible ? '1' : '0'; -} - -// Add a message to the conversation -function addMessage(text, type) { - if (!elements.conversation) return; - - const messageDiv = document.createElement('div'); - messageDiv.className = `message ${type}`; - - const textElement = document.createElement('p'); - textElement.textContent = text; - messageDiv.appendChild(textElement); - - // Add timestamp to every message - const timestamp = new Date().toLocaleTimeString(); - const timeLabel = document.createElement('div'); - timeLabel.className = 'message-timestamp'; - timeLabel.textContent = timestamp; - messageDiv.appendChild(timeLabel); - - elements.conversation.appendChild(messageDiv); - - // Auto-scroll to the bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - return messageDiv; -} - -// Add a system message to the conversation -function addSystemMessage(text) { - if (!elements.conversation) return; - - const messageDiv = document.createElement('div'); - messageDiv.className = 'message system'; - messageDiv.textContent = text; - - elements.conversation.appendChild(messageDiv); - - // Auto-scroll to the bottom - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - return messageDiv; -} - -// Handle transcription response from server -function handleTranscription(data) { - const speaker = data.speaker === 0 ? 'user' : 'ai'; - - // Create the message div - const messageDiv = addMessage(data.text, speaker); - - // If we have detailed segments from WhisperX, add timestamps - if (data.segments && data.segments.length > 0) { - // Add a timestamps container - const timestampsContainer = document.createElement('div'); - timestampsContainer.className = 'timestamps-container'; - timestampsContainer.style.display = 'none'; // Hidden by default - - // Add a toggle button - const toggleButton = document.createElement('button'); - toggleButton.className = 'timestamp-toggle'; - toggleButton.textContent = 'Show Timestamps'; - toggleButton.onclick = function() { - const isHidden = timestampsContainer.style.display === 'none'; - timestampsContainer.style.display = isHidden ? 'block' : 'none'; - toggleButton.textContent = isHidden ? 
'Hide Timestamps' : 'Show Timestamps'; - }; - - // Add timestamps for each segment - data.segments.forEach(segment => { - const timestampDiv = document.createElement('div'); - timestampDiv.className = 'timestamp'; - - // Format start and end times - const startTime = formatTime(segment.start); - const endTime = formatTime(segment.end); - - timestampDiv.innerHTML = ` - [${startTime} - ${endTime}] - ${segment.text} - `; - - timestampsContainer.appendChild(timestampDiv); - }); - - // Add the timestamp elements to the message - messageDiv.appendChild(toggleButton); - messageDiv.appendChild(timestampsContainer); - } else { - // No timestamp data available - add a simple timestamp for the entire message - const timestamp = new Date().toLocaleTimeString(); - const timeLabel = document.createElement('div'); - timeLabel.className = 'simple-timestamp'; - timeLabel.textContent = timestamp; - messageDiv.appendChild(timeLabel); - } - - return messageDiv; -} - -// Helper function to format time in seconds to MM:SS.ms format -function formatTime(seconds) { - const mins = Math.floor(seconds / 60); - const secs = Math.floor(seconds % 60); - const ms = Math.floor((seconds % 1) * 1000); - - return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}.${ms.toString().padStart(3, '0')}`; -} - -// Handle context update from server -function handleContextUpdate(data) { - if (data.status === 'cleared') { - elements.conversation.innerHTML = ''; - addSystemMessage('Conversation context cleared.'); - } -} - -// Handle streaming status updates from server -function handleStreamingStatus(data) { - if (data.status === 'active') { - console.log('Server acknowledged streaming is active'); - } else if (data.status === 'inactive') { - console.log('Server acknowledged streaming is inactive'); - } -} - -// Handle processing status updates -function handleProcessingStatus(data) { - switch (data.status) { - case 'transcribing': - addSystemMessage('Transcribing your message...'); - break; - case 'generating': - addSystemMessage('Generating response...'); - break; - case 'synthesizing': - addSystemMessage('Synthesizing voice...'); - break; - } -} - -// Handle the start of an audio streaming response -function handleAudioResponseStart(data) { - console.log(`Expecting ${data.total_chunks} audio chunks`); - - // Reset streaming state - streamingAudio.chunks = []; - streamingAudio.totalChunks = data.total_chunks; - streamingAudio.receivedChunks = 0; - streamingAudio.text = data.text; - streamingAudio.complete = false; -} - -// Handle an incoming audio chunk -function handleAudioResponseChunk(data) { - // Create or update audio element for playback - const audioElement = document.createElement('audio'); - if (elements.autoPlayResponses.checked) { - audioElement.autoplay = true; - } - audioElement.controls = true; - audioElement.className = 'audio-player'; - audioElement.src = data.chunk; - - // Store the chunk - streamingAudio.chunks[data.chunk_index] = data.chunk; - streamingAudio.receivedChunks++; - - // Store audio element reference for later use - streamingAudio.audioElement = audioElement; - - // Add to the conversation - const messages = elements.conversation.querySelectorAll('.message.ai'); - if (messages.length > 0) { - const lastAiMessage = messages[messages.length - 1]; - streamingAudio.messageElement = lastAiMessage; - - // Replace existing audio player if there is one - const existingPlayer = lastAiMessage.querySelector('.audio-player'); - if (existingPlayer) { - lastAiMessage.replaceChild(audioElement, 
existingPlayer); - } else { - lastAiMessage.appendChild(audioElement); - } - } else { - // Create a new message for the AI response - const aiMessage = document.createElement('div'); - aiMessage.className = 'message ai'; - streamingAudio.messageElement = aiMessage; - - if (streamingAudio.text) { - const textElement = document.createElement('p'); - textElement.textContent = streamingAudio.text; - aiMessage.appendChild(textElement); - } - - aiMessage.appendChild(audioElement); - elements.conversation.appendChild(aiMessage); - } - - // Auto-scroll - elements.conversation.scrollTop = elements.conversation.scrollHeight; - - // If this is the last chunk or we've received all expected chunks - if (data.is_last || streamingAudio.receivedChunks >= streamingAudio.totalChunks) { - streamingAudio.complete = true; - - // Reset stream button if we're still streaming - if (state.isStreaming) { - elements.streamButton.classList.remove('processing'); - elements.streamButton.innerHTML = ' Listening...'; - } - } -} - -// Handle completion of audio streaming -function handleAudioResponseComplete(data) { - console.log('Audio response complete:', data); - streamingAudio.complete = true; - - // Make sure we finalize the audio even if some chunks were missed - finalizeStreamingAudio(); - - // Update UI to normal state - if (state.isStreaming) { - elements.streamButton.innerHTML = ' Listening...'; - elements.streamButton.classList.add('recording'); - elements.streamButton.classList.remove('processing'); - } -} - -// Finalize streaming audio by combining chunks and updating the UI -function finalizeStreamingAudio() { - if (!streamingAudio.messageElement || streamingAudio.chunks.length === 0) { - return; - } - - try { - // For more sophisticated audio streaming, you would need to properly concatenate - // the WAV files, but for now we'll use the last chunk as the complete audio - // since it should contain the entire response due to how the server is implementing it - const lastChunkIndex = streamingAudio.chunks.length - 1; - const audioData = streamingAudio.chunks[lastChunkIndex] || streamingAudio.chunks[0]; - - // Update the audio element with the complete audio - if (streamingAudio.audioElement) { - streamingAudio.audioElement.src = audioData; - - // Auto-play if enabled and not already playing - if (elements.autoPlayResponses && elements.autoPlayResponses.checked && - streamingAudio.audioElement.paused) { - streamingAudio.audioElement.play() - .catch(err => { - console.warn('Auto-play failed:', err); - addSystemMessage('Auto-play failed. 
Please click play to hear the response.'); - }); - } - } - - // Remove loading indicator and processing class - if (streamingAudio.messageElement) { - const loadingElement = streamingAudio.messageElement.querySelector('.loading-indicator'); - if (loadingElement) { - streamingAudio.messageElement.removeChild(loadingElement); - } - streamingAudio.messageElement.classList.remove('processing'); - } - - console.log('Audio response finalized and ready for playback'); - } catch (e) { - console.error('Error finalizing streaming audio:', e); - } - - // Reset streaming audio state - streamingAudio.chunks = []; - streamingAudio.totalChunks = 0; - streamingAudio.receivedChunks = 0; - streamingAudio.messageElement = null; - streamingAudio.audioElement = null; -} - -// Enhance the handleModelStatusUpdate function: - -function handleModelStatusUpdate(data) { - const { model, status, message, progress } = data; - - if (model === 'overall' && status === 'loading') { - // Update overall loading progress - const progressBar = document.getElementById('modelLoadingProgress'); - if (progressBar) { - progressBar.value = progress; - progressBar.textContent = `${progress}%`; - } - return; - } - - if (status === 'loaded') { - console.log(`Model ${model} loaded successfully`); - addSystemMessage(`${model.toUpperCase()} model loaded successfully`); - - // Update UI to show model is ready - const modelStatusElement = document.getElementById(`${model}Status`); - if (modelStatusElement) { - modelStatusElement.classList.remove('loading'); - modelStatusElement.classList.add('loaded'); - modelStatusElement.title = 'Model loaded successfully'; - } - - // Check if the required models are loaded to enable conversation - checkAllModelsLoaded(); - } else if (status === 'error') { - console.error(`Error loading ${model} model: ${message}`); - addSystemMessage(`Error loading ${model.toUpperCase()} model: ${message}`); - - // Update UI to show model loading failed - const modelStatusElement = document.getElementById(`${model}Status`); - if (modelStatusElement) { - modelStatusElement.classList.remove('loading'); - modelStatusElement.classList.add('error'); - modelStatusElement.title = `Error: ${message}`; - } - } -} - -// Check if all required models are loaded and enable UI accordingly -function checkAllModelsLoaded() { - // When all models are loaded, enable the stream button if it was disabled - const allLoaded = - document.getElementById('csmStatus')?.classList.contains('loaded') && - document.getElementById('asrStatus')?.classList.contains('loaded') && - document.getElementById('llmStatus')?.classList.contains('loaded'); - - if (allLoaded) { - elements.streamButton.disabled = false; - addSystemMessage('All models loaded. 
Ready for conversation!'); - } -} - -// Add CSS styles for new UI elements -document.addEventListener('DOMContentLoaded', function() { - // Add styles for processing state and timestamps - const style = document.createElement('style'); - style.textContent = ` - .message.processing { - opacity: 0.8; - } - - .loading-indicator { - display: flex; - align-items: center; - margin-top: 8px; - font-size: 0.9em; - color: #666; - } - - .loading-spinner { - width: 16px; - height: 16px; - border: 2px solid #ddd; - border-top: 2px solid var(--primary-color); - border-radius: 50%; - margin-right: 8px; - animation: spin 1s linear infinite; - } - - @keyframes spin { - 0% { transform: rotate(0deg); } - 100% { transform: rotate(360deg); } - } - - /* Timestamp styles */ - .timestamp-toggle { - font-size: 0.75em; - padding: 4px 8px; - margin-top: 8px; - background-color: #f0f0f0; - border: 1px solid #ddd; - border-radius: 4px; - cursor: pointer; - } - - .timestamp-toggle:hover { - background-color: #e0e0e0; - } - - .timestamps-container { - margin-top: 8px; - padding: 8px; - background-color: #f9f9f9; - border-radius: 4px; - font-size: 0.85em; - } - - .timestamp { - margin-bottom: 4px; - padding: 2px 0; - } - - .timestamp .time { - color: #666; - font-family: monospace; - margin-right: 8px; - } - - .timestamp .text { - color: #333; - } - `; - document.head.appendChild(style); -}); - -// Initialize the application when DOM is fully loaded -document.addEventListener('DOMContentLoaded', initializeApp); - diff --git a/React/src/app/auth/session/route.ts b/React/src/app/auth/session/route.ts new file mode 100644 index 0000000..9299d4a --- /dev/null +++ b/React/src/app/auth/session/route.ts @@ -0,0 +1,12 @@ +import { NextResponse } from "next/server"; +import { auth0 } from "../../../lib/auth0"; + +export async function GET() { + try { + const session = await auth0.getSession(); + return NextResponse.json({ session }); + } catch (error) { + console.error("Error getting session:", error); + return NextResponse.json({ session: null }, { status: 500 }); + } +} diff --git a/React/src/app/page.tsx b/React/src/app/page.tsx index c91ed2b..21e0862 100644 --- a/React/src/app/page.tsx +++ b/React/src/app/page.tsx @@ -1,67 +1,94 @@ -import { useState } from "react"; -import { auth0 } from "../lib/auth0"; +"use client"; +import { useState, useEffect } from "react"; +import { useRouter } from "next/navigation"; - -export default async function Home() { - +export default function Home() { const [contacts, setContacts] = useState([]); const [codeword, setCodeword] = useState(""); + const [session, setSession] = useState(null); + const [loading, setLoading] = useState(true); + const router = useRouter(); - const session = await auth0.getSession(); - - console.log("Session:", session?.user); + useEffect(() => { + // Fetch session data from an API route + fetch("/auth/session") + .then((response) => response.json()) + .then((data) => { + setSession(data.session); + setLoading(false); + }) + .catch((error) => { + console.error("Failed to fetch session:", error); + setLoading(false); + }); + }, []); function saveToDB() { - //e.preventDefault(); alert("Saving contacts..."); - // const contactInputs = document.querySelectorAll(".text-input") as NodeListOf; - // const contactValues = Array.from(contactInputs).map(input => input.value); - // console.log("Contact values:", contactValues); - // // save codeword and contacts to database - // fetch("/api/databaseStorage", { - // method: "POST", - // headers: { - // "Content-Type": 
"application/json", - // }, - // body: JSON.stringify({ - // email: session?.user?.email || "", - // codeword: (document.getElementById("codeword") as HTMLInputElement)?.value, - // contacts: contactValues, - // }), - // }) - // .then((response) => { - // if (response.ok) { - // alert("Contacts saved successfully!"); - // } else { - // alert("Error saving contacts."); - // } - // }) - // .catch((error) => { - // console.error("Error:", error); - // alert("Error saving contacts."); - // }); - + const contactInputs = document.querySelectorAll( + ".text-input" + ) as NodeListOf; + const contactValues = Array.from(contactInputs).map((input) => input.value); + + fetch("/api/databaseStorage", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + email: session?.user?.email || "", + codeword: codeword, + contacts: contactValues, + }), + }) + .then((response) => { + if (response.ok) { + alert("Contacts saved successfully!"); + } else { + alert("Error saving contacts."); + } + }) + .catch((error) => { + console.error("Error:", error); + alert("Error saving contacts."); + }); } + if (loading) { + return
Loading...
; + } // If no session, show sign-up and login buttons - if (!session) { - + if (!session) { return (
- + - +
-          Fauxcall
-          Set emergency contacts
-          If you stop speaking or say the codeword, these contacts will be notified
+          Fauxcall
+          Set emergency contacts
+          If you stop speaking or say the codeword, these contacts will be
+          notified

{/* form for setting codeword */} -
e.preventDefault()}> + e.preventDefault()} + > + className="bg-blue-500 text-white font-semibold font-lg rounded-md p-2" + type="submit" + > + Set codeword +
{/* form for adding contacts */} -
e.preventDefault()}> + e.preventDefault()} + > - +
); @@ -107,25 +145,42 @@ export default async function Home() {

Welcome, {session.user.name}!

-          Fauxcall
-          Set emergency contacts
-          If you stop speaking or say the codeword, these contacts will be notified
- {/* form for setting codeword */} -
e.preventDefault()}> - setCodeword(e.target.value)} - placeholder="Codeword" - className="border border-gray-300 rounded-md p-2" - /> - -
- {/* form for adding contacts */} -
e.preventDefault()}> + type="submit" + > + Set codeword + +
+ {/* form for adding contacts */} +
e.preventDefault()} + > - - - + + -
- + className="bg-slate-500 text-yellow-300 text-stretch-50% font-lg rounded-md p-2" + > + Save + + +

@@ -182,6 +248,4 @@ export default async function Home() {

); - - } From 4d49c2a9872ca6030e26076fb3548010e59abfd1 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 03:46:20 -0400 Subject: [PATCH 26/30] Demo Fixes 9 --- Backend/api/socket_handlers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Backend/api/socket_handlers.py b/Backend/api/socket_handlers.py index 20513e9..b09a2cf 100644 --- a/Backend/api/socket_handlers.py +++ b/Backend/api/socket_handlers.py @@ -43,6 +43,7 @@ class Conversation: def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE): """Register Socket.IO event handlers""" + # No need for global references, just use the parameters directly @socketio.on('connect') def handle_connect(auth=None): @@ -127,7 +128,7 @@ def process_audio_queue(session_id, q, app, socketio, models, active_conversatio with app.app_context(): if session_id in active_conversations: del active_conversations[session_id] - if session_id in user_queues: + if session_id in user_queues: # Use the passed-in reference del user_queues[session_id] def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE): From e1c66f1f59c53a4fa783f1346a13e266be72a0b1 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 06:56:28 -0400 Subject: [PATCH 27/30] backend restart --- Backend/api/app.py | 136 ----------- Backend/api/routes.py | 74 ------ Backend/api/socket_handlers.py | 393 ------------------------------ Backend/app.py | 229 +++++++++++++++++ Backend/{api => }/generator.py | 0 Backend/index.html | 212 ++++++++++++++++ Backend/{api => }/models.py | 0 Backend/run_csm.py | 117 +++++++++ Backend/server.py | 19 -- Backend/{api => }/watermarking.py | 0 10 files changed, 558 insertions(+), 622 deletions(-) delete mode 100644 Backend/api/app.py delete mode 100644 Backend/api/routes.py delete mode 100644 Backend/api/socket_handlers.py create mode 100644 Backend/app.py rename Backend/{api => }/generator.py (100%) create mode 100644 Backend/index.html rename Backend/{api => }/models.py (100%) create mode 100644 Backend/run_csm.py delete mode 100644 Backend/server.py rename Backend/{api => }/watermarking.py (100%) diff --git a/Backend/api/app.py b/Backend/api/app.py deleted file mode 100644 index 018061f..0000000 --- a/Backend/api/app.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -import logging -import threading -from dataclasses import dataclass -from flask import Flask -from flask_socketio import SocketIO -from flask_cors import CORS - -# Configure logging -logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -# Configure device -import torch -if torch.cuda.is_available(): - DEVICE = "cuda" -elif torch.backends.mps.is_available(): - DEVICE = "mps" -else: - DEVICE = "cpu" - -logger.info(f"Using device: {DEVICE}") - -# Initialize Flask app -app = Flask(__name__, static_folder='../', static_url_path='') -CORS(app) -socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120) - -# Global variables for conversation state -active_conversations = {} -user_queues = {} -processing_threads = {} - -# Model storage -@dataclass -class AppModels: - generator = None - tokenizer = None - llm = None - whisperx_model = None - whisperx_align_model = None - whisperx_align_metadata = None - last_language = None - -models = AppModels() - -def load_models(): - """Load all required models""" - from generator import load_csm_1b - import whisperx - import gc - from 
transformers import AutoModelForCausalLM, AutoTokenizer - global models - - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0}) - - # CSM 1B loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'}) - models.generator = load_csm_1b(device=DEVICE) - logger.info("CSM 1B model loaded successfully") - socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33}) - if DEVICE == "cuda": - torch.cuda.empty_cache() - except Exception as e: - import traceback - error_details = traceback.format_exc() - logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}") - socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)}) - - # WhisperX loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'}) - # Use WhisperX for better transcription with timestamps - # Use compute_type based on device - compute_type = "float16" if DEVICE == "cuda" else "float32" - - # Load the WhisperX model (smaller model for faster processing) - models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type) - - logger.info("WhisperX model loaded successfully") - socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66}) - if DEVICE == "cuda": - torch.cuda.empty_cache() - except Exception as e: - import traceback - error_details = traceback.format_exc() - logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}") - socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)}) - - # Llama loading - try: - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'}) - models.llm = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-3.2-1B", - device_map=DEVICE, - torch_dtype=torch.bfloat16 - ) - models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") - - # Configure all special tokens - models.tokenizer.pad_token = models.tokenizer.eos_token - models.tokenizer.padding_side = "left" # For causal language modeling - - # Inform the model about the pad token - if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None: - models.llm.config.pad_token_id = models.tokenizer.pad_token_id - - logger.info("Llama 3.2 model loaded successfully") - socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'}) - socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'}) - except Exception as e: - logger.error(f"Error loading Llama 3.2 model: {str(e)}") - socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)}) - -# Load models in a background thread -threading.Thread(target=load_models, daemon=True).start() - -# Import routes and socket handlers -from api.routes import register_routes -from api.socket_handlers import register_handlers - -# Register routes and socket handlers -register_routes(app) -register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE) - -# Run server if executed directly -if 
__name__ == '__main__': - port = int(os.environ.get('PORT', 5000)) - debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true' - logger.info(f"Starting server on port {port} (debug={debug_mode})") - socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/api/routes.py b/Backend/api/routes.py deleted file mode 100644 index af1bfce..0000000 --- a/Backend/api/routes.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import torch -import psutil -from flask import send_from_directory, jsonify, request - -def register_routes(app): - """Register HTTP routes for the application""" - - @app.route('/') - def index(): - """Serve the main application page""" - return send_from_directory(app.static_folder, 'index.html') - - @app.route('/voice-chat.js') - def serve_js(): - """Serve the JavaScript file""" - return send_from_directory(app.static_folder, 'voice-chat.js') - - @app.route('/api/status') - def system_status(): - """Return the system status""" - # Import here to avoid circular imports - from api.app import models, DEVICE - - return jsonify({ - "status": "ok", - "cuda_available": torch.cuda.is_available(), - "device": DEVICE, - "models": { - "generator": models.generator is not None, - "asr": models.whisperx_model is not None, - "llm": models.llm is not None - }, - "versions": { - "transformers": "4.49.0", # Replace with actual version - "torch": torch.__version__ - } - }) - - @app.route('/api/system_resources') - def system_resources(): - """Return system resource usage""" - # Import here to avoid circular imports - from api.app import active_conversations, DEVICE - - # Get CPU usage - cpu_percent = psutil.cpu_percent(interval=0.1) - - # Get memory usage - memory = psutil.virtual_memory() - memory_used_gb = memory.used / (1024 ** 3) - memory_total_gb = memory.total / (1024 ** 3) - memory_percent = memory.percent - - # Get GPU memory if available - gpu_memory = {} - if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - gpu_memory[f"gpu_{i}"] = { - "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3), - "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3), - "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3) - } - - return jsonify({ - "cpu_percent": cpu_percent, - "memory": { - "used_gb": memory_used_gb, - "total_gb": memory_total_gb, - "percent": memory_percent - }, - "gpu_memory": gpu_memory, - "active_sessions": len(active_conversations) - }) \ No newline at end of file diff --git a/Backend/api/socket_handlers.py b/Backend/api/socket_handlers.py deleted file mode 100644 index b09a2cf..0000000 --- a/Backend/api/socket_handlers.py +++ /dev/null @@ -1,393 +0,0 @@ -import os -import io -import base64 -import time -import threading -import queue -import tempfile -import gc -import logging -import traceback -from typing import Dict, List, Optional - -import torch -import torchaudio -import numpy as np -from flask import request -from flask_socketio import emit - -# Import conversation model -from generator import Segment - -logger = logging.getLogger(__name__) - -# Conversation data structure -class Conversation: - def __init__(self, session_id): - self.session_id = session_id - self.segments: List[Segment] = [] - self.current_speaker = 0 - self.ai_speaker_id = 1 # Default AI speaker ID - self.last_activity = time.time() - self.is_processing = False - - def add_segment(self, text, speaker, audio): - segment = Segment(text=text, speaker=speaker, audio=audio) - 
self.segments.append(segment) - self.last_activity = time.time() - return segment - - def get_context(self, max_segments=10): - """Return the most recent segments for context""" - return self.segments[-max_segments:] if self.segments else [] - -def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE): - """Register Socket.IO event handlers""" - # No need for global references, just use the parameters directly - - @socketio.on('connect') - def handle_connect(auth=None): - """Handle client connection""" - session_id = request.sid - logger.info(f"Client connected: {session_id}") - - # Initialize conversation data - if session_id not in active_conversations: - active_conversations[session_id] = Conversation(session_id) - user_queues[session_id] = queue.Queue() - processing_threads[session_id] = threading.Thread( - target=process_audio_queue, - args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, DEVICE), - daemon=True - ) - processing_threads[session_id].start() - - emit('connection_status', {'status': 'connected'}) - - @socketio.on('disconnect') - def handle_disconnect(reason=None): - """Handle client disconnection""" - session_id = request.sid - logger.info(f"Client disconnected: {session_id}. Reason: {reason}") - - # Cleanup - if session_id in active_conversations: - # Mark for deletion rather than immediately removing - # as the processing thread might still be accessing it - active_conversations[session_id].is_processing = False - user_queues[session_id].put(None) # Signal thread to terminate - - @socketio.on('audio_data') - def handle_audio_data(data): - """Handle incoming audio data""" - session_id = request.sid - logger.info(f"Received audio data from {session_id}") - - # Check if the models are loaded - if models.generator is None or models.whisperx_model is None or models.llm is None: - emit('error', {'message': 'Models still loading, please wait'}) - return - - # Check if we're already processing for this session - if session_id in active_conversations and active_conversations[session_id].is_processing: - emit('error', {'message': 'Still processing previous audio, please wait'}) - return - - # Add to processing queue - if session_id in user_queues: - user_queues[session_id].put(data) - else: - emit('error', {'message': 'Session not initialized, please refresh the page'}) - -def process_audio_queue(session_id, q, app, socketio, models, active_conversations, DEVICE): - """Background thread to process audio chunks for a session""" - logger.info(f"Started processing thread for session: {session_id}") - - try: - while session_id in active_conversations: - try: - # Get the next audio chunk with a timeout - data = q.get(timeout=120) - if data is None: # Termination signal - break - - # Process the audio and generate a response - process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE) - - except queue.Empty: - # Timeout, check if session is still valid - continue - except Exception as e: - logger.error(f"Error processing audio for {session_id}: {str(e)}") - # Create an app context for the socket emit - with app.app_context(): - socketio.emit('error', {'message': str(e)}, room=session_id) - finally: - logger.info(f"Ending processing thread for session: {session_id}") - # Clean up when thread is done - with app.app_context(): - if session_id in active_conversations: - del active_conversations[session_id] - if session_id in user_queues: # Use the passed-in reference - 
del user_queues[session_id] - -def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE): - """Process audio data and generate a response using WhisperX""" - if models.generator is None or models.whisperx_model is None or models.llm is None: - logger.warning("Models not yet loaded!") - with app.app_context(): - socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id) - return - - logger.info(f"Processing audio for session {session_id}") - conversation = active_conversations[session_id] - - try: - # Set processing flag - conversation.is_processing = True - - # Process base64 audio data - audio_data = data['audio'] - speaker_id = data['speaker'] - logger.info(f"Received audio from speaker {speaker_id}") - - # Convert from base64 to WAV - try: - audio_bytes = base64.b64decode(audio_data.split(',')[1]) - logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes") - except Exception as e: - logger.error(f"Error decoding base64 audio: {str(e)}") - raise - - # Save to temporary file for processing - with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file: - temp_file.write(audio_bytes) - temp_path = temp_file.name - - try: - # Notify client that transcription is starting - with app.app_context(): - socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id) - - # Load audio using WhisperX - import whisperx - audio = whisperx.load_audio(temp_path) - - # Check audio length and add a warning for short clips - audio_length = len(audio) / 16000 # assuming 16kHz sample rate - if audio_length < 1.0: - logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality") - - # Transcribe using WhisperX - batch_size = 16 # adjust based on your GPU memory - logger.info("Running WhisperX transcription...") - - # Handle the warning about audio being shorter than 30s by suppressing it - import warnings - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="audio is shorter than 30s") - result = models.whisperx_model.transcribe(audio, batch_size=batch_size) - - # Get the detected language - language_code = result["language"] - logger.info(f"Detected language: {language_code}") - - # Check if alignment model needs to be loaded or updated - if models.whisperx_align_model is None or language_code != models.last_language: - # Clean up old models if they exist - if models.whisperx_align_model is not None: - del models.whisperx_align_model - del models.whisperx_align_metadata - if DEVICE == "cuda": - gc.collect() - torch.cuda.empty_cache() - - # Load new alignment model for the detected language - logger.info(f"Loading alignment model for language: {language_code}") - models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model( - language_code=language_code, device=DEVICE - ) - models.last_language = language_code - - # Align the transcript to get word-level timestamps - if result["segments"] and len(result["segments"]) > 0: - logger.info("Aligning transcript...") - result = whisperx.align( - result["segments"], - models.whisperx_align_model, - models.whisperx_align_metadata, - audio, - DEVICE, - return_char_alignments=False - ) - - # Process the segments for better output - for segment in result["segments"]: - # Round timestamps for better display - segment["start"] = round(segment["start"], 2) - segment["end"] = round(segment["end"], 2) - # Add a confidence score if not present - if "confidence" not in segment: - 
segment["confidence"] = 1.0 # Default confidence - - # Extract the full text from all segments - user_text = ' '.join([segment['text'] for segment in result['segments']]) - - # If no text was recognized, don't process further - if not user_text or len(user_text.strip()) == 0: - with app.app_context(): - socketio.emit('error', {'message': 'No speech detected'}, room=session_id) - return - - logger.info(f"Transcription: {user_text}") - - # Load audio for CSM input - waveform, sample_rate = torchaudio.load(temp_path) - - # Normalize to mono if needed - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) - - # Resample to the CSM sample rate if needed - if sample_rate != models.generator.sample_rate: - waveform = torchaudio.functional.resample( - waveform, - orig_freq=sample_rate, - new_freq=models.generator.sample_rate - ) - - # Add the user's message to conversation history - user_segment = conversation.add_segment( - text=user_text, - speaker=speaker_id, - audio=waveform.squeeze() - ) - - # Send transcription to client with detailed segments - with app.app_context(): - socketio.emit('transcription', { - 'text': user_text, - 'speaker': speaker_id, - 'segments': result['segments'] # Include the detailed segments with timestamps - }, room=session_id) - - # Generate AI response using Llama - with app.app_context(): - socketio.emit('processing_status', {'status': 'generating'}, room=session_id) - - # Create prompt from conversation history - conversation_history = "" - for segment in conversation.segments[-5:]: # Last 5 segments for context - role = "User" if segment.speaker == 0 else "Assistant" - conversation_history += f"{role}: {segment.text}\n" - - # Add final prompt - prompt = f"{conversation_history}Assistant: " - - # Generate response with Llama - try: - # Ensure pad token is set - if models.tokenizer.pad_token is None: - models.tokenizer.pad_token = models.tokenizer.eos_token - - input_tokens = models.tokenizer( - prompt, - return_tensors="pt", - padding=True, - return_attention_mask=True - ) - input_ids = input_tokens.input_ids.to(DEVICE) - attention_mask = input_tokens.attention_mask.to(DEVICE) - - with torch.no_grad(): - generated_ids = models.llm.generate( - input_ids, - attention_mask=attention_mask, - max_new_tokens=100, - temperature=0.7, - top_p=0.9, - do_sample=True, - pad_token_id=models.tokenizer.eos_token_id - ) - - # Decode the response - response_text = models.tokenizer.decode( - generated_ids[0][input_ids.shape[1]:], - skip_special_tokens=True - ).strip() - except Exception as e: - logger.error(f"Error generating response: {str(e)}") - logger.error(traceback.format_exc()) - response_text = "I'm sorry, I encountered an error while processing your request." 
- - # Synthesize speech - with app.app_context(): - socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id) - - # Start sending the audio response - socketio.emit('audio_response_start', { - 'text': response_text, - 'total_chunks': 1, - 'chunk_index': 0 - }, room=session_id) - - # Define AI speaker ID - ai_speaker_id = conversation.ai_speaker_id - - # Generate audio - audio_tensor = models.generator.generate( - text=response_text, - speaker=ai_speaker_id, - context=conversation.get_context(), - max_audio_length_ms=10_000, - temperature=0.9 - ) - - # Add AI response to conversation history - ai_segment = conversation.add_segment( - text=response_text, - speaker=ai_speaker_id, - audio=audio_tensor - ) - - # Convert audio to WAV format - with io.BytesIO() as wav_io: - torchaudio.save( - wav_io, - audio_tensor.unsqueeze(0).cpu(), - models.generator.sample_rate, - format="wav" - ) - wav_io.seek(0) - wav_data = wav_io.read() - - # Convert WAV data to base64 - audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}" - - # Send audio chunk to client - with app.app_context(): - socketio.emit('audio_response_chunk', { - 'chunk': audio_base64, - 'chunk_index': 0, - 'total_chunks': 1, - 'is_last': True - }, room=session_id) - - # Signal completion - socketio.emit('audio_response_complete', { - 'text': response_text - }, room=session_id) - - finally: - # Clean up temp file - if os.path.exists(temp_path): - os.unlink(temp_path) - - except Exception as e: - logger.error(f"Error processing audio: {str(e)}") - logger.error(traceback.format_exc()) - with app.app_context(): - socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id) - finally: - # Reset processing flag - conversation.is_processing = False \ No newline at end of file diff --git a/Backend/app.py b/Backend/app.py new file mode 100644 index 0000000..091de8e --- /dev/null +++ b/Backend/app.py @@ -0,0 +1,229 @@ +import os +import io +import base64 +import time +import torch +import torchaudio +import numpy as np +from flask import Flask, render_template, request +from flask_socketio import SocketIO, emit +from transformers import AutoModelForCausalLM, AutoTokenizer +import speech_recognition as sr +from generator import load_csm_1b, Segment +from collections import deque + +app = Flask(__name__) +app.config['SECRET_KEY'] = 'your-secret-key' +socketio = SocketIO(app, cors_allowed_origins="*") + +# Select the best available device +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" +print(f"Using device: {device}") + +# Initialize CSM model for audio generation +print("Loading CSM model...") +csm_generator = load_csm_1b(device=device) + +# Initialize Llama 3.2 model for response generation +print("Loading Llama 3.2 model...") +llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources +llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id) +llm_model = AutoModelForCausalLM.from_pretrained( + llm_model_id, + torch_dtype=torch.bfloat16, + device_map=device +) + +# Initialize speech recognition +recognizer = sr.Recognizer() + +# Store conversation context +conversation_context = {} # session_id -> context + +@app.route('/') +def index(): + return render_template('index.html') + +@socketio.on('connect') +def handle_connect(): + print(f"Client connected: {request.sid}") + conversation_context[request.sid] = { + 'segments': [], + 'speakers': [0, 1], # 0 = user, 1 = bot + 
'audio_buffer': deque(maxlen=10), # Store recent audio chunks + 'is_speaking': False, + 'silence_start': None + } + emit('ready', {'message': 'Connection established'}) + +@socketio.on('disconnect') +def handle_disconnect(): + print(f"Client disconnected: {request.sid}") + if request.sid in conversation_context: + del conversation_context[request.sid] + +@socketio.on('start_speaking') +def handle_start_speaking(): + if request.sid in conversation_context: + conversation_context[request.sid]['is_speaking'] = True + conversation_context[request.sid]['audio_buffer'].clear() + print(f"User {request.sid} started speaking") + +@socketio.on('audio_chunk') +def handle_audio_chunk(data): + if request.sid not in conversation_context: + return + + context = conversation_context[request.sid] + + # Decode audio data + audio_data = base64.b64decode(data['audio']) + audio_numpy = np.frombuffer(audio_data, dtype=np.float32) + audio_tensor = torch.tensor(audio_numpy) + + # Add to buffer + context['audio_buffer'].append(audio_tensor) + + # Check for silence to detect end of speech + if context['is_speaking'] and is_silence(audio_tensor): + if context['silence_start'] is None: + context['silence_start'] = time.time() + elif time.time() - context['silence_start'] > 1.0: # 1 second of silence + # Process the complete utterance + process_user_utterance(request.sid) + else: + context['silence_start'] = None + +@socketio.on('stop_speaking') +def handle_stop_speaking(): + if request.sid in conversation_context: + conversation_context[request.sid]['is_speaking'] = False + process_user_utterance(request.sid) + print(f"User {request.sid} stopped speaking") + +def is_silence(audio_tensor, threshold=0.02): + """Check if an audio chunk is silence based on amplitude threshold""" + return torch.mean(torch.abs(audio_tensor)) < threshold + +def process_user_utterance(session_id): + """Process completed user utterance, generate response and send audio back""" + context = conversation_context[session_id] + + if not context['audio_buffer']: + return + + # Combine audio chunks + full_audio = torch.cat(list(context['audio_buffer']), dim=0) + context['audio_buffer'].clear() + context['is_speaking'] = False + context['silence_start'] = None + + # Convert audio to 16kHz for speech recognition + audio_16k = torchaudio.functional.resample( + full_audio, + orig_freq=44100, # Assuming 44.1kHz from client + new_freq=16000 + ) + + # Transcribe speech + try: + # Convert to wav format for speech_recognition + audio_data = io.BytesIO() + torchaudio.save(audio_data, audio_16k.unsqueeze(0), 16000, format="wav") + audio_data.seek(0) + + with sr.AudioFile(audio_data) as source: + audio = recognizer.record(source) + user_text = recognizer.recognize_google(audio) + print(f"Transcribed: {user_text}") + + # Add to conversation segments + user_segment = Segment( + text=user_text, + speaker=0, # User is speaker 0 + audio=full_audio + ) + context['segments'].append(user_segment) + + # Generate bot response + bot_response = generate_llm_response(user_text, context['segments']) + print(f"Bot response: {bot_response}") + + # Convert to audio using CSM + bot_audio = generate_audio_response(bot_response, context['segments']) + + # Convert audio to base64 for sending over websocket + audio_bytes = io.BytesIO() + torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav") + audio_bytes.seek(0) + audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') + + # Add bot response to conversation history + bot_segment 
= Segment( + text=bot_response, + speaker=1, # Bot is speaker 1 + audio=bot_audio + ) + context['segments'].append(bot_segment) + + # Send transcribed text to client + emit('transcription', {'text': user_text}, room=session_id) + + # Send audio response to client + emit('audio_response', { + 'audio': audio_b64, + 'text': bot_response + }, room=session_id) + + except Exception as e: + print(f"Error processing speech: {e}") + emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id) + +def generate_llm_response(user_text, conversation_segments): + """Generate text response using Llama 3.2""" + # Format conversation history for the LLM + conversation_history = "" + for segment in conversation_segments[-5:]: # Use last 5 utterances for context + speaker_name = "User" if segment.speaker == 0 else "Assistant" + conversation_history += f"{speaker_name}: {segment.text}\n" + + # Add the current user query + conversation_history += f"User: {user_text}\nAssistant:" + + # Generate response + inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device) + output = llm_model.generate( + inputs.input_ids, + max_new_tokens=150, + temperature=0.7, + top_p=0.9, + do_sample=True + ) + + response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + return response.strip() + +def generate_audio_response(text, conversation_segments): + """Generate audio response using CSM""" + # Use the last few conversation segments as context + context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments + + # Generate audio for bot response + audio = csm_generator.generate( + text=text, + speaker=1, # Bot is speaker 1 + context=context_segments, + max_audio_length_ms=10000, # 10 seconds max + temperature=0.9, + topk=50 + ) + + return audio + +if __name__ == '__main__': + socketio.run(app, host='0.0.0.0', port=5000, debug=True) \ No newline at end of file diff --git a/Backend/api/generator.py b/Backend/generator.py similarity index 100% rename from Backend/api/generator.py rename to Backend/generator.py diff --git a/Backend/index.html b/Backend/index.html new file mode 100644 index 0000000..e1f5f94 --- /dev/null +++ b/Backend/index.html @@ -0,0 +1,212 @@ + + + + + + + Audio Conversation Bot + + + + +

[Backend/index.html markup not recoverable; surviving text: page title "Audio Conversation Bot", status placeholder "Not connected"]
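
The page added here is the in-browser client for the Socket.IO events that app.py registers ('start_speaking', 'audio_chunk', 'stop_speaking', plus 'transcription' and 'audio_response' coming back). The same exchange can also be smoke-tested from a script. The sketch below is a hypothetical client, not part of this patch, built on the python-socketio package; it assumes a mono float32 clip recorded at 44.1 kHz (the rate handle_audio_chunk resamples from) and a server on localhost:5000.

    # test_client.py - hypothetical smoke test for the voice Socket.IO endpoints (not in the repo)
    import base64
    import time
    import numpy as np
    import socketio  # pip install "python-socketio[client]"

    sio = socketio.Client()

    @sio.on('ready')
    def on_ready(data):
        print("server ready:", data)

    @sio.on('transcription')
    def on_transcription(data):
        print("transcribed:", data['text'])

    @sio.on('audio_response')
    def on_audio_response(data):
        # data['audio'] is the base64-encoded WAV produced by CSM
        with open("reply.wav", "wb") as f:
            f.write(base64.b64decode(data['audio']))
        print("bot said:", data['text'])

    @sio.on('error')
    def on_error(data):
        print("server error:", data['message'])

    sio.connect("http://localhost:5000")

    # Assumption: audio.npy holds mono float32 samples at 44.1 kHz.
    samples = np.load("audio.npy").astype(np.float32)

    sio.emit('start_speaking')
    chunk_size = 4410  # ~100 ms per chunk; purely illustrative
    for start in range(0, len(samples), chunk_size):
        chunk = samples[start:start + chunk_size]
        sio.emit('audio_chunk', {'audio': base64.b64encode(chunk.tobytes()).decode('ascii')})
        time.sleep(0.1)
    sio.emit('stop_speaking')

    time.sleep(30)  # wait for the transcription and synthesized reply
    sio.disconnect()
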
+ + + + \ No newline at end of file diff --git a/Backend/api/models.py b/Backend/models.py similarity index 100% rename from Backend/api/models.py rename to Backend/models.py diff --git a/Backend/run_csm.py b/Backend/run_csm.py new file mode 100644 index 0000000..0062973 --- /dev/null +++ b/Backend/run_csm.py @@ -0,0 +1,117 @@ +import os +import torch +import torchaudio +from huggingface_hub import hf_hub_download +from generator import load_csm_1b, Segment +from dataclasses import dataclass + +# Disable Triton compilation +os.environ["NO_TORCH_COMPILE"] = "1" + +# Default prompts are available at https://hf.co/sesame/csm-1b +prompt_filepath_conversational_a = hf_hub_download( + repo_id="sesame/csm-1b", + filename="prompts/conversational_a.wav" +) +prompt_filepath_conversational_b = hf_hub_download( + repo_id="sesame/csm-1b", + filename="prompts/conversational_b.wav" +) + +SPEAKER_PROMPTS = { + "conversational_a": { + "text": ( + "like revising for an exam I'd have to try and like keep up the momentum because I'd " + "start really early I'd be like okay I'm gonna start revising now and then like " + "you're revising for ages and then I just like start losing steam I didn't do that " + "for the exam we had recently to be fair that was a more of a last minute scenario " + "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I " + "sort of start the day with this not like a panic but like a" + ), + "audio": prompt_filepath_conversational_a + }, + "conversational_b": { + "text": ( + "like a super Mario level. Like it's very like high detail. And like, once you get " + "into the park, it just like, everything looks like a computer game and they have all " + "these, like, you know, if, if there's like a, you know, like in a Mario game, they " + "will have like a question block. And if you like, you know, punch it, a coin will " + "come out. So like everyone, when they come into the park, they get like this little " + "bracelet and then you can go punching question blocks around." + ), + "audio": prompt_filepath_conversational_b + } +} + +def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor: + audio_tensor, sample_rate = torchaudio.load(audio_path) + audio_tensor = audio_tensor.squeeze(0) + # Resample is lazy so we can always call it + audio_tensor = torchaudio.functional.resample( + audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate + ) + return audio_tensor + +def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment: + audio_tensor = load_prompt_audio(audio_path, sample_rate) + return Segment(text=text, speaker=speaker, audio=audio_tensor) + +def main(): + # Select the best available device, skipping MPS due to float64 limitations + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + print(f"Using device: {device}") + + # Load model + generator = load_csm_1b(device) + + # Prepare prompts + prompt_a = prepare_prompt( + SPEAKER_PROMPTS["conversational_a"]["text"], + 0, + SPEAKER_PROMPTS["conversational_a"]["audio"], + generator.sample_rate + ) + + prompt_b = prepare_prompt( + SPEAKER_PROMPTS["conversational_b"]["text"], + 1, + SPEAKER_PROMPTS["conversational_b"]["audio"], + generator.sample_rate + ) + + # Generate conversation + conversation = [ + {"text": "Hey how are you doing?", "speaker_id": 0}, + {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1}, + {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0}, + {"text": "Me too! 
This is some cool stuff, isn't it?", "speaker_id": 1} + ] + + # Generate each utterance + generated_segments = [] + prompt_segments = [prompt_a, prompt_b] + + for utterance in conversation: + print(f"Generating: {utterance['text']}") + audio_tensor = generator.generate( + text=utterance['text'], + speaker=utterance['speaker_id'], + context=prompt_segments + generated_segments, + max_audio_length_ms=10_000, + ) + generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)) + + # Concatenate all generations + all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0) + torchaudio.save( + "full_conversation.wav", + all_audio.unsqueeze(0).cpu(), + generator.sample_rate + ) + print("Successfully generated full_conversation.wav") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Backend/server.py b/Backend/server.py deleted file mode 100644 index b8af6b7..0000000 --- a/Backend/server.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -CSM Voice Chat Server -A voice chat application that uses CSM 1B for voice synthesis, -WhisperX for speech recognition, and Llama 3.2 for language generation. -""" - -# Start the Flask application -from api.app import app, socketio - -if __name__ == '__main__': - import os - - port = int(os.environ.get('PORT', 5000)) - debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true' - - print(f"Starting server on port {port} (debug={debug_mode})") - print("Visit http://localhost:5000 to access the application") - - socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/Backend/api/watermarking.py b/Backend/watermarking.py similarity index 100% rename from Backend/api/watermarking.py rename to Backend/watermarking.py From a4f282fbcca0332ce5afcc1048ef3b91bc8d3467 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 07:18:14 -0400 Subject: [PATCH 28/30] Demo Update 21 --- Backend/index.html | 346 ++++++++++++++++++++++++++++++++++++++------- Backend/req.txt | 1 + Backend/server.py | 256 +++++++++++++++++++++++++++++++++ 3 files changed, 552 insertions(+), 51 deletions(-) create mode 100644 Backend/req.txt create mode 100644 Backend/server.py diff --git a/Backend/index.html b/Backend/index.html index e1f5f94..359ed41 100644 --- a/Backend/index.html +++ b/Backend/index.html @@ -1,86 +1,225 @@ - - Audio Conversation Bot + Voice Assistant - CSM & Whisper -

[index.html markup not recoverable; surviving diff text: heading changed from "Audio Conversation Bot" to "Voice Assistant with CSM & Whisper", status text changed from "Not connected" to "Connecting to server..."]
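
The rewritten server.py added in this patch ends a user utterance once incoming chunks stay below an amplitude threshold for a full second (is_silence plus the silence_start timer in handle_audio_chunk). The following is a distilled sketch of that endpointing loop; the 0.02 threshold and 1.0 s hold come from the patch, while the ~100 ms chunk duration and real-time chunk arrival are assumptions about the client.

    # Endpointing sketch distilled from handle_audio_chunk / is_silence in server.py.
    import time
    from collections import deque
    import torch

    SILENCE_THRESHOLD = 0.02   # mean absolute amplitude below this counts as silence
    SILENCE_HOLD_S = 1.0       # end the utterance after this much continuous silence

    def is_silence(chunk: torch.Tensor) -> bool:
        return torch.mean(torch.abs(chunk)) < SILENCE_THRESHOLD

    def utterances(chunks):
        """Yield one concatenated utterance each time the speaker pauses long enough.

        Assumes chunks (mono float32 tensors, ~100 ms each) arrive in real time,
        as they do over the socket. The patch caps its buffer at 10 chunks via
        deque(maxlen=10); this sketch keeps it unbounded for clarity.
        """
        buffer = deque()
        silence_start = None
        for chunk in chunks:
            buffer.append(chunk)
            if is_silence(chunk):
                silence_start = silence_start or time.time()
                if time.time() - silence_start > SILENCE_HOLD_S:
                    yield torch.cat(list(buffer))
                    buffer.clear()
                    silence_start = None
            else:
                silence_start = None
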
\ No newline at end of file diff --git a/Backend/req.txt b/Backend/req.txt new file mode 100644 index 0000000..a3edbdc --- /dev/null +++ b/Backend/req.txt @@ -0,0 +1 @@ +pip install faster-whisper \ No newline at end of file diff --git a/Backend/server.py b/Backend/server.py new file mode 100644 index 0000000..978b87c --- /dev/null +++ b/Backend/server.py @@ -0,0 +1,256 @@ +import os +import io +import base64 +import time +import torch +import torchaudio +import numpy as np +from flask import Flask, render_template, request +from flask_socketio import SocketIO, emit +from transformers import AutoModelForCausalLM, AutoTokenizer +from faster_whisper import WhisperModel +from generator import load_csm_1b, Segment +from collections import deque + +app = Flask(__name__) +app.config['SECRET_KEY'] = 'your-secret-key' +socketio = SocketIO(app, cors_allowed_origins="*") + +# Select the best available device +if torch.cuda.is_available(): + device = "cuda" + whisper_compute_type = "float16" +elif torch.backends.mps.is_available(): + device = "mps" + whisper_compute_type = "float32" +else: + device = "cpu" + whisper_compute_type = "int8" + +print(f"Using device: {device}") + +# Initialize Faster-Whisper for transcription +print("Loading Whisper model...") +whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type) + +# Initialize CSM model for audio generation +print("Loading CSM model...") +csm_generator = load_csm_1b(device=device) + +# Initialize Llama 3.2 model for response generation +print("Loading Llama 3.2 model...") +llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources +llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id) +llm_model = AutoModelForCausalLM.from_pretrained( + llm_model_id, + torch_dtype=torch.bfloat16, + device_map=device +) + +# Store conversation context +conversation_context = {} # session_id -> context + +@app.route('/') +def index(): + return render_template('index.html') + +@socketio.on('connect') +def handle_connect(): + print(f"Client connected: {request.sid}") + conversation_context[request.sid] = { + 'segments': [], + 'speakers': [0, 1], # 0 = user, 1 = bot + 'audio_buffer': deque(maxlen=10), # Store recent audio chunks + 'is_speaking': False, + 'silence_start': None + } + emit('ready', {'message': 'Connection established'}) + +@socketio.on('disconnect') +def handle_disconnect(): + print(f"Client disconnected: {request.sid}") + if request.sid in conversation_context: + del conversation_context[request.sid] + +@socketio.on('start_speaking') +def handle_start_speaking(): + if request.sid in conversation_context: + conversation_context[request.sid]['is_speaking'] = True + conversation_context[request.sid]['audio_buffer'].clear() + print(f"User {request.sid} started speaking") + +@socketio.on('audio_chunk') +def handle_audio_chunk(data): + if request.sid not in conversation_context: + return + + context = conversation_context[request.sid] + + # Decode audio data + audio_data = base64.b64decode(data['audio']) + audio_numpy = np.frombuffer(audio_data, dtype=np.float32) + audio_tensor = torch.tensor(audio_numpy) + + # Add to buffer + context['audio_buffer'].append(audio_tensor) + + # Check for silence to detect end of speech + if context['is_speaking'] and is_silence(audio_tensor): + if context['silence_start'] is None: + context['silence_start'] = time.time() + elif time.time() - context['silence_start'] > 1.0: # 1 second of silence + # Process the complete utterance + process_user_utterance(request.sid) + 
else: + context['silence_start'] = None + +@socketio.on('stop_speaking') +def handle_stop_speaking(): + if request.sid in conversation_context: + conversation_context[request.sid]['is_speaking'] = False + process_user_utterance(request.sid) + print(f"User {request.sid} stopped speaking") + +def is_silence(audio_tensor, threshold=0.02): + """Check if an audio chunk is silence based on amplitude threshold""" + return torch.mean(torch.abs(audio_tensor)) < threshold + +def process_user_utterance(session_id): + """Process completed user utterance, generate response and send audio back""" + context = conversation_context[session_id] + + if not context['audio_buffer']: + return + + # Combine audio chunks + full_audio = torch.cat(list(context['audio_buffer']), dim=0) + context['audio_buffer'].clear() + context['is_speaking'] = False + context['silence_start'] = None + + # Save audio to temporary WAV file for Whisper transcription + temp_audio_path = f"temp_audio_{session_id}.wav" + torchaudio.save( + temp_audio_path, + full_audio.unsqueeze(0), + 44100 # Assuming 44.1kHz from client + ) + + # Transcribe speech using Faster-Whisper + try: + segments, info = whisper_model.transcribe(temp_audio_path, beam_size=5) + + # Collect all text from segments + user_text = "" + for segment in segments: + segment_text = segment.text.strip() + print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}") + user_text += segment_text + " " + + user_text = user_text.strip() + + # Cleanup temp file + if os.path.exists(temp_audio_path): + os.remove(temp_audio_path) + + if not user_text: + print("No speech detected.") + return + + print(f"Transcribed: {user_text}") + + # Add to conversation segments + user_segment = Segment( + text=user_text, + speaker=0, # User is speaker 0 + audio=full_audio + ) + context['segments'].append(user_segment) + + # Generate bot response + bot_response = generate_llm_response(user_text, context['segments']) + print(f"Bot response: {bot_response}") + + # Convert to audio using CSM + bot_audio = generate_audio_response(bot_response, context['segments']) + + # Convert audio to base64 for sending over websocket + audio_bytes = io.BytesIO() + torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav") + audio_bytes.seek(0) + audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') + + # Add bot response to conversation history + bot_segment = Segment( + text=bot_response, + speaker=1, # Bot is speaker 1 + audio=bot_audio + ) + context['segments'].append(bot_segment) + + # Send transcribed text to client + emit('transcription', {'text': user_text}, room=session_id) + + # Send audio response to client + emit('audio_response', { + 'audio': audio_b64, + 'text': bot_response + }, room=session_id) + + except Exception as e: + print(f"Error processing speech: {e}") + emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id) + # Cleanup temp file in case of error + if os.path.exists(temp_audio_path): + os.remove(temp_audio_path) + +def generate_llm_response(user_text, conversation_segments): + """Generate text response using Llama 3.2""" + # Format conversation history for the LLM + conversation_history = "" + for segment in conversation_segments[-5:]: # Use last 5 utterances for context + speaker_name = "User" if segment.speaker == 0 else "Assistant" + conversation_history += f"{speaker_name}: {segment.text}\n" + + # Add the current user query + conversation_history += f"User: {user_text}\nAssistant:" + + # Generate 
response + inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device) + output = llm_model.generate( + inputs.input_ids, + max_new_tokens=150, + temperature=0.7, + top_p=0.9, + do_sample=True + ) + + response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + return response.strip() + +def generate_audio_response(text, conversation_segments): + """Generate audio response using CSM""" + # Use the last few conversation segments as context + context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments + + # Generate audio for bot response + audio = csm_generator.generate( + text=text, + speaker=1, # Bot is speaker 1 + context=context_segments, + max_audio_length_ms=10000, # 10 seconds max + temperature=0.9, + topk=50 + ) + + return audio + +if __name__ == '__main__': + # Ensure the existing index.html file is in the correct location + if not os.path.exists('templates'): + os.makedirs('templates') + + if os.path.exists('index.html') and not os.path.exists('templates/index.html'): + os.rename('index.html', 'templates/index.html') + + socketio.run(app, host='0.0.0.0', port=5000, debug=False) \ No newline at end of file From 263127ed18e541bb8815ff3cc1ed53eaefb535e2 Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 07:23:39 -0400 Subject: [PATCH 29/30] Demo Fixes 10 --- Backend/server.py | 293 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 203 insertions(+), 90 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 978b87c..352f5cd 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -8,9 +8,17 @@ import numpy as np from flask import Flask, render_template, request from flask_socketio import SocketIO, emit from transformers import AutoModelForCausalLM, AutoTokenizer -from faster_whisper import WhisperModel -from generator import load_csm_1b, Segment from collections import deque +import requests +import huggingface_hub +from generator import load_csm_1b, Segment + +# Configure environment with longer timeouts +os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "600" # 10 minutes timeout for downloads +requests.adapters.DEFAULT_TIMEOUT = 60 # Increase default requests timeout + +# Create a models directory for caching +os.makedirs("models", exist_ok=True) app = Flask(__name__) app.config['SECRET_KEY'] = 'your-secret-key' @@ -29,23 +37,50 @@ else: print(f"Using device: {device}") -# Initialize Faster-Whisper for transcription -print("Loading Whisper model...") -whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type) +# Initialize models with proper error handling +whisper_model = None +csm_generator = None +llm_model = None +llm_tokenizer = None -# Initialize CSM model for audio generation -print("Loading CSM model...") -csm_generator = load_csm_1b(device=device) - -# Initialize Llama 3.2 model for response generation -print("Loading Llama 3.2 model...") -llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources -llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id) -llm_model = AutoModelForCausalLM.from_pretrained( - llm_model_id, - torch_dtype=torch.bfloat16, - device_map=device -) +def load_models(): + global whisper_model, csm_generator, llm_model, llm_tokenizer + + # Initialize Faster-Whisper for transcription + try: + print("Loading Whisper model...") + # Import here to avoid immediate import errors if package is missing + from faster_whisper import WhisperModel + whisper_model = 
WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper") + print("Whisper model loaded successfully") + except Exception as e: + print(f"Error loading Whisper model: {e}") + print("Will use backup speech recognition method if available") + + # Initialize CSM model for audio generation + try: + print("Loading CSM model...") + csm_generator = load_csm_1b(device=device) + print("CSM model loaded successfully") + except Exception as e: + print(f"Error loading CSM model: {e}") + print("Audio generation will not be available") + + # Initialize Llama 3.2 model for response generation + try: + print("Loading Llama 3.2 model...") + llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources + llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama") + llm_model = AutoModelForCausalLM.from_pretrained( + llm_model_id, + torch_dtype=torch.bfloat16, + device_map=device, + cache_dir="./models/llama" + ) + print("Llama 3.2 model loaded successfully") + except Exception as e: + print(f"Error loading Llama 3.2 model: {e}") + print("Will use a fallback response generation method") # Store conversation context conversation_context = {} # session_id -> context @@ -128,7 +163,7 @@ def process_user_utterance(session_id): context['is_speaking'] = False context['silence_start'] = None - # Save audio to temporary WAV file for Whisper transcription + # Save audio to temporary WAV file for transcription temp_audio_path = f"temp_audio_{session_id}.wav" torchaudio.save( temp_audio_path, @@ -136,25 +171,17 @@ def process_user_utterance(session_id): 44100 # Assuming 44.1kHz from client ) - # Transcribe speech using Faster-Whisper try: - segments, info = whisper_model.transcribe(temp_audio_path, beam_size=5) + # Try using Whisper first if available + if whisper_model is not None: + user_text = transcribe_with_whisper(temp_audio_path) + else: + # Fallback to Google's speech recognition + user_text = transcribe_with_google(temp_audio_path) - # Collect all text from segments - user_text = "" - for segment in segments: - segment_text = segment.text.strip() - print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}") - user_text += segment_text + " " - - user_text = user_text.strip() - - # Cleanup temp file - if os.path.exists(temp_audio_path): - os.remove(temp_audio_path) - if not user_text: print("No speech detected.") + emit('error', {'message': 'No speech detected. 
Please try again.'}, room=session_id) return print(f"Transcribed: {user_text}") @@ -171,79 +198,158 @@ def process_user_utterance(session_id): bot_response = generate_llm_response(user_text, context['segments']) print(f"Bot response: {bot_response}") - # Convert to audio using CSM - bot_audio = generate_audio_response(bot_response, context['segments']) - - # Convert audio to base64 for sending over websocket - audio_bytes = io.BytesIO() - torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav") - audio_bytes.seek(0) - audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') - - # Add bot response to conversation history - bot_segment = Segment( - text=bot_response, - speaker=1, # Bot is speaker 1 - audio=bot_audio - ) - context['segments'].append(bot_segment) - # Send transcribed text to client emit('transcription', {'text': user_text}, room=session_id) - # Send audio response to client - emit('audio_response', { - 'audio': audio_b64, - 'text': bot_response - }, room=session_id) + # Generate and send audio response if CSM is available + if csm_generator is not None: + # Convert to audio using CSM + bot_audio = generate_audio_response(bot_response, context['segments']) + + # Convert audio to base64 for sending over websocket + audio_bytes = io.BytesIO() + torchaudio.save(audio_bytes, bot_audio.unsqueeze(0).cpu(), csm_generator.sample_rate, format="wav") + audio_bytes.seek(0) + audio_b64 = base64.b64encode(audio_bytes.read()).decode('utf-8') + + # Add bot response to conversation history + bot_segment = Segment( + text=bot_response, + speaker=1, # Bot is speaker 1 + audio=bot_audio + ) + context['segments'].append(bot_segment) + + # Send audio response to client + emit('audio_response', { + 'audio': audio_b64, + 'text': bot_response + }, room=session_id) + else: + # Send text-only response if audio generation isn't available + emit('text_response', {'text': bot_response}, room=session_id) + + # Add text-only bot response to conversation history + bot_segment = Segment( + text=bot_response, + speaker=1, # Bot is speaker 1 + audio=torch.zeros(1) # Placeholder empty audio + ) + context['segments'].append(bot_segment) except Exception as e: print(f"Error processing speech: {e}") emit('error', {'message': f'Error processing speech: {str(e)}'}, room=session_id) - # Cleanup temp file in case of error + finally: + # Cleanup temp file if os.path.exists(temp_audio_path): os.remove(temp_audio_path) +def transcribe_with_whisper(audio_path): + """Transcribe audio using Faster-Whisper""" + segments, info = whisper_model.transcribe(audio_path, beam_size=5) + + # Collect all text from segments + user_text = "" + for segment in segments: + segment_text = segment.text.strip() + print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment_text}") + user_text += segment_text + " " + + return user_text.strip() + +def transcribe_with_google(audio_path): + """Fallback transcription using Google's speech recognition""" + import speech_recognition as sr + recognizer = sr.Recognizer() + + with sr.AudioFile(audio_path) as source: + audio = recognizer.record(source) + try: + text = recognizer.recognize_google(audio) + return text + except sr.UnknownValueError: + return "" + except sr.RequestError: + # If Google API fails, try a basic energy-based VAD approach + # This is a very basic fallback and won't give good results + return "[Speech detected but transcription failed]" + def generate_llm_response(user_text, conversation_segments): - """Generate text response using 
Llama 3.2""" - # Format conversation history for the LLM - conversation_history = "" - for segment in conversation_segments[-5:]: # Use last 5 utterances for context - speaker_name = "User" if segment.speaker == 0 else "Assistant" - conversation_history += f"{speaker_name}: {segment.text}\n" + """Generate text response using available model""" + if llm_model is not None and llm_tokenizer is not None: + # Format conversation history for the LLM + conversation_history = "" + for segment in conversation_segments[-5:]: # Use last 5 utterances for context + speaker_name = "User" if segment.speaker == 0 else "Assistant" + conversation_history += f"{speaker_name}: {segment.text}\n" + + # Add the current user query + conversation_history += f"User: {user_text}\nAssistant:" + + try: + # Generate response + inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device) + output = llm_model.generate( + inputs.input_ids, + max_new_tokens=150, + temperature=0.7, + top_p=0.9, + do_sample=True + ) + + response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + return response.strip() + except Exception as e: + print(f"Error generating response with LLM: {e}") + return fallback_response(user_text) + else: + return fallback_response(user_text) + +def fallback_response(user_text): + """Generate a simple fallback response when LLM is not available""" + # Simple rule-based responses + user_text_lower = user_text.lower() - # Add the current user query - conversation_history += f"User: {user_text}\nAssistant:" + if "hello" in user_text_lower or "hi" in user_text_lower: + return "Hello! I'm a simple fallback assistant. The main language model couldn't be loaded, so I have limited capabilities." - # Generate response - inputs = llm_tokenizer(conversation_history, return_tensors="pt").to(device) - output = llm_model.generate( - inputs.input_ids, - max_new_tokens=150, - temperature=0.7, - top_p=0.9, - do_sample=True - ) + elif "how are you" in user_text_lower: + return "I'm functioning within my limited capabilities. How can I assist you today?" - response = llm_tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) - return response.strip() + elif "thank" in user_text_lower: + return "You're welcome! Let me know if there's anything else I can help with." + + elif "bye" in user_text_lower or "goodbye" in user_text_lower: + return "Goodbye! Have a great day!" + + elif any(q in user_text_lower for q in ["what", "who", "where", "when", "why", "how"]): + return "I'm running in fallback mode and can't answer complex questions. Please try again when the main language model is available." + + else: + return "I understand you said something about that. Unfortunately, I'm running in fallback mode with limited capabilities. Please try again later when the main model is available." 
def generate_audio_response(text, conversation_segments): """Generate audio response using CSM""" - # Use the last few conversation segments as context - context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments - - # Generate audio for bot response - audio = csm_generator.generate( - text=text, - speaker=1, # Bot is speaker 1 - context=context_segments, - max_audio_length_ms=10000, # 10 seconds max - temperature=0.9, - topk=50 - ) - - return audio + try: + # Use the last few conversation segments as context + context_segments = conversation_segments[-4:] if len(conversation_segments) > 4 else conversation_segments + + # Generate audio for bot response + audio = csm_generator.generate( + text=text, + speaker=1, # Bot is speaker 1 + context=context_segments, + max_audio_length_ms=10000, # 10 seconds max + temperature=0.9, + topk=50 + ) + + return audio + except Exception as e: + print(f"Error generating audio: {e}") + # Return silence as fallback + return torch.zeros(csm_generator.sample_rate * 3) # 3 seconds of silence if __name__ == '__main__': # Ensure the existing index.html file is in the correct location @@ -253,4 +359,11 @@ if __name__ == '__main__': if os.path.exists('index.html') and not os.path.exists('templates/index.html'): os.rename('index.html', 'templates/index.html') + # Load models asynchronously before starting the server + print("Starting model loading...") + # In a production environment, you could load models in a separate thread + load_models() + + # Start the server + print("Starting Flask SocketIO server...") socketio.run(app, host='0.0.0.0', port=5000, debug=False) \ No newline at end of file From 30388d816f382e6808e2dd2830217d13f13a57be Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sun, 30 Mar 2025 07:30:23 -0400 Subject: [PATCH 30/30] Demo Fixes 11 --- Backend/server.py | 45 ++++++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 352f5cd..9e98d60 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -24,14 +24,26 @@ app = Flask(__name__) app.config['SECRET_KEY'] = 'your-secret-key' socketio = SocketIO(app, cors_allowed_origins="*") -# Select the best available device -if torch.cuda.is_available(): - device = "cuda" - whisper_compute_type = "float16" -elif torch.backends.mps.is_available(): - device = "mps" - whisper_compute_type = "float32" -else: +# Check for CUDA availability and handle potential CUDA/cuDNN issues +try: + cuda_available = torch.cuda.is_available() + # Try to initialize CUDA to check if libraries are properly loaded + if cuda_available: + _ = torch.zeros(1).cuda() + device = "cuda" + whisper_compute_type = "float16" + print("CUDA is available and initialized successfully") + elif torch.backends.mps.is_available(): + device = "mps" + whisper_compute_type = "float32" + print("MPS is available (Apple Silicon)") + else: + device = "cpu" + whisper_compute_type = "int8" + print("Using CPU (CUDA/MPS not available)") +except Exception as e: + print(f"Error initializing CUDA: {e}") + print("Falling back to CPU") device = "cpu" whisper_compute_type = "int8" @@ -51,7 +63,9 @@ def load_models(): print("Loading Whisper model...") # Import here to avoid immediate import errors if package is missing from faster_whisper import WhisperModel - whisper_model = WhisperModel("base", device=device, compute_type=whisper_compute_type, download_root="./models/whisper") + # Force CPU for Whisper if we had CUDA issues + 
whisper_device = device if device != "cpu" else "cpu" + whisper_model = WhisperModel("base", device=whisper_device, compute_type=whisper_compute_type, download_root="./models/whisper") print("Whisper model loaded successfully") except Exception as e: print(f"Error loading Whisper model: {e}") @@ -60,7 +74,9 @@ def load_models(): # Initialize CSM model for audio generation try: print("Loading CSM model...") - csm_generator = load_csm_1b(device=device) + # Force CPU for CSM if we had CUDA issues + csm_device = device if device != "cpu" else "cpu" + csm_generator = load_csm_1b(device=csm_device) print("CSM model loaded successfully") except Exception as e: print(f"Error loading CSM model: {e}") @@ -71,11 +87,14 @@ def load_models(): print("Loading Llama 3.2 model...") llm_model_id = "meta-llama/Llama-3.2-1B" # Choose appropriate size based on resources llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id, cache_dir="./models/llama") + # Force CPU for LLM if we had CUDA issues + llm_device = device if device != "cpu" else "cpu" llm_model = AutoModelForCausalLM.from_pretrained( llm_model_id, - torch_dtype=torch.bfloat16, - device_map=device, - cache_dir="./models/llama" + torch_dtype=torch.bfloat16 if llm_device != "cpu" else torch.float32, + device_map=llm_device, + cache_dir="./models/llama", + low_cpu_mem_usage=True ) print("Llama 3.2 model loaded successfully") except Exception as e: