This commit is contained in:
Surya Vemulapalli
2025-03-30 01:29:42 -04:00
28 changed files with 723 additions and 3081 deletions

View File

@@ -1,154 +1,71 @@
# CSM # csm-conversation-bot
**2025/03/13** - We are releasing the 1B CSM variant. The checkpoint is [hosted on Hugging Face](https://huggingface.co/sesame/csm_1b). ## Overview
The CSM Conversation Bot is an application that utilizes advanced audio processing and language model technologies to facilitate real-time voice conversations with an AI assistant. The bot processes audio streams, converts spoken input into text, generates responses using the Llama 3.2 model, and converts the text back into audio for seamless interaction.
--- ## Project Structure
```
CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes. csm-conversation-bot
├── api
A fine-tuned variant of CSM powers the [interactive voice demo](https://www.sesame.com/voicedemo) shown in our [blog post](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). │ ├── app.py # Main entry point for the API
│ ├── routes.py # Defines API routes
A hosted [Hugging Face space](https://huggingface.co/spaces/sesame/csm-1b) is also available for testing audio generation. │ └── socket_handlers.py # Manages Socket.IO events
├── src
## Requirements │ ├── audio
│ │ ├── processor.py # Audio processing functions
* A CUDA-compatible GPU │ │ └── streaming.py # Audio streaming management
* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions │ ├── llm
* Similarly, Python 3.10 is recommended, but newer versions may be fine │ │ ├── generator.py # Response generation using Llama 3.2
* For some audio operations, `ffmpeg` may be required │ │ └── tokenizer.py # Text tokenization functions
* Access to the following Hugging Face models: │ ├── models
* [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) │ ├── audio_model.py # Audio processing model
* [CSM-1B](https://huggingface.co/sesame/csm-1b) │ └── conversation.py # Conversation state management
│ ├── services
### Setup │ │ ├── transcription_service.py # Audio to text conversion
│ │ └── tts_service.py # Text to speech conversion
```bash │ └── utils
git clone git@github.com:SesameAILabs/csm.git │ ├── config.py # Configuration settings
cd csm │ └── logger.py # Logging utilities
python3.10 -m venv .venv ├── static
source .venv/bin/activate │ ├── css
pip install -r requirements.txt │ │ └── styles.css # CSS styles for the web interface
│ ├── js
# Disable lazy compilation in Mimi │ │ └── client.js # Client-side JavaScript
export NO_TORCH_COMPILE=1 │ └── index.html # Main HTML file for the web interface
├── templates
# You will need access to CSM-1B and Llama-3.2-1B │ └── index.html # Template for rendering the main HTML page
huggingface-cli login ├── config.py # Main configuration settings
├── requirements.txt # Python dependencies
├── server.py # Entry point for running the application
└── README.md # Documentation for the project
``` ```
### Windows Setup ## Installation
1. Clone the repository:
```
git clone https://github.com/yourusername/csm-conversation-bot.git
cd csm-conversation-bot
```
The `triton` package cannot be installed in Windows. Instead use `pip install triton-windows`. 2. Install the required dependencies:
```
pip install -r requirements.txt
```
## Quickstart 3. Configure the application settings in `config.py` as needed.
This script will generate a conversation between 2 characters, using a prompt for each character.
```bash
python run_csm.py
```
## Usage ## Usage
1. Start the server:
```
python server.py
```
If you want to write your own applications with CSM, the following examples show basic usage. 2. Open your web browser and navigate to `http://localhost:5000` to access the application.
#### Generate a sentence 3. Use the interface to start a conversation with the AI assistant.
This will use a random speaker identity, as no prompt or context is provided. ## Contributing
Contributions are welcome! Please submit a pull request or open an issue for any enhancements or bug fixes.
```python ## License
from generator import load_csm_1b This project is licensed under the MIT License. See the LICENSE file for more details.
import torchaudio
import torch
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
generator = load_csm_1b(device=device)
audio = generator.generate(
text="Hello from Sesame.",
speaker=0,
context=[],
max_audio_length_ms=10_000,
)
torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```
#### Generate with context
CSM sounds best when provided with context. You can prompt or provide context to the model using a `Segment` for each speaker's utterance.
NOTE: The following example is instructional and the audio files do not exist. It is intended as an example for using context with CSM.
```python
from generator import Segment
speakers = [0, 1, 0, 0]
transcripts = [
"Hey how are you doing.",
"Pretty good, pretty good.",
"I'm great.",
"So happy to be speaking to you.",
]
audio_paths = [
"utterance_0.wav",
"utterance_1.wav",
"utterance_2.wav",
"utterance_3.wav",
]
def load_audio(audio_path):
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = torchaudio.functional.resample(
audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
)
return audio_tensor
segments = [
Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]
audio = generator.generate(
text="Me too, this is some cool stuff huh?",
speaker=1,
context=segments,
max_audio_length_ms=10_000,
)
torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```
## FAQ
**Does this model come with any voices?**
The model open-sourced here is a base generation model. It is capable of producing a variety of voices, but it has not been fine-tuned on any specific voice.
**Can I converse with the model?**
CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. We suggest using a separate LLM for text generation.
**Does it support other languages?**
The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well.
## Misuse and abuse ⚠️
This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:
- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.
- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.
- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.
By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.
---
## Authors
Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.

22
Backend/api/app.py Normal file
View File

@@ -0,0 +1,22 @@
from flask import Flask
from flask_socketio import SocketIO
from src.utils.config import Config
from src.utils.logger import setup_logger
from api.routes import setup_routes
from api.socket_handlers import setup_socket_handlers
def create_app():
app = Flask(__name__)
app.config.from_object(Config)
setup_logger(app)
setup_routes(app)
setup_socket_handlers(app)
return app
app = create_app()
socketio = SocketIO(app)
if __name__ == "__main__":
socketio.run(app, host='0.0.0.0', port=5000)

29
Backend/api/routes.py Normal file
View File

@@ -0,0 +1,29 @@
from flask import Blueprint, request, jsonify
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService
api = Blueprint('api', __name__)
transcription_service = TranscriptionService()
tts_service = TextToSpeechService()
@api.route('/transcribe', methods=['POST'])
def transcribe_audio():
audio_data = request.files.get('audio')
if not audio_data:
return jsonify({'error': 'No audio file provided'}), 400
text = transcription_service.transcribe(audio_data)
return jsonify({'transcription': text})
@api.route('/generate-response', methods=['POST'])
def generate_response():
data = request.json
user_input = data.get('input')
if not user_input:
return jsonify({'error': 'No input provided'}), 400
response_text = tts_service.generate_response(user_input)
audio_data = tts_service.text_to_speech(response_text)
return jsonify({'response': response_text, 'audio': audio_data})

View File

@@ -0,0 +1,32 @@
from flask import request
from flask_socketio import SocketIO, emit
from src.audio.processor import process_audio
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService
from src.llm.generator import load_csm_1b
socketio = SocketIO()
transcription_service = TranscriptionService()
tts_service = TextToSpeechService()
generator = load_csm_1b()
@socketio.on('audio_stream')
def handle_audio_stream(data):
audio_data = data['audio']
speaker_id = data['speaker']
# Process the incoming audio
processed_audio = process_audio(audio_data)
# Transcribe the audio to text
transcription = transcription_service.transcribe(processed_audio)
# Generate a response using the LLM
response_text = generator.generate(text=transcription, speaker=speaker_id)
# Convert the response text back to audio
response_audio = tts_service.convert_text_to_speech(response_text)
# Emit the response audio back to the client
emit('audio_response', {'audio': response_audio, 'speaker': speaker_id})

13
Backend/config.py Normal file
View File

@@ -0,0 +1,13 @@
from pathlib import Path
class Config:
def __init__(self):
self.MODEL_PATH = Path("path/to/your/model")
self.AUDIO_MODEL_PATH = Path("path/to/your/audio/model")
self.WATERMARK_KEY = "your_watermark_key"
self.SOCKETIO_CORS = "*"
self.API_KEY = "your_api_key"
self.DEBUG = True
self.LOGGING_LEVEL = "INFO"
self.TTS_SERVICE_URL = "http://localhost:5001/tts"
self.TRANSCRIPTION_SERVICE_URL = "http://localhost:5002/transcribe"

View File

@@ -1,492 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Sesame AI Voice Chat</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
<!-- Socket.IO client library -->
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
<style>
:root {
--primary-color: #4c84ff;
--secondary-color: #3367d6;
--text-color: #333;
--background-color: #f9f9f9;
--card-background: #ffffff;
--accent-color: #ff5252;
--success-color: #4CAF50;
--border-color: #e0e0e0;
--shadow-color: rgba(0, 0, 0, 0.1);
}
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background-color: var(--background-color);
color: var(--text-color);
line-height: 1.6;
max-width: 1000px;
margin: 0 auto;
padding: 20px;
transition: all 0.3s ease;
}
header {
text-align: center;
margin-bottom: 30px;
}
h1 {
color: var(--primary-color);
font-size: 2.5rem;
margin-bottom: 10px;
}
.subtitle {
color: #666;
font-weight: 300;
}
.app-container {
display: grid;
grid-template-columns: 1fr;
gap: 20px;
}
@media (min-width: 768px) {
.app-container {
grid-template-columns: 1fr 1fr;
}
}
.chat-container, .control-panel {
background-color: var(--card-background);
border-radius: 12px;
box-shadow: 0 4px 12px var(--shadow-color);
padding: 20px;
}
.control-panel {
display: flex;
flex-direction: column;
gap: 20px;
}
.chat-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 15px;
padding-bottom: 10px;
border-bottom: 1px solid var(--border-color);
}
.conversation {
height: 400px;
overflow-y: auto;
padding: 10px;
border-radius: 8px;
background-color: #f7f9fc;
margin-bottom: 20px;
scroll-behavior: smooth;
}
.message {
margin-bottom: 15px;
padding: 12px 15px;
border-radius: 12px;
max-width: 85%;
position: relative;
animation: fade-in 0.3s ease-out forwards;
}
@keyframes fade-in {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
.user {
background-color: #e3f2fd;
color: #0d47a1;
margin-left: auto;
border-bottom-right-radius: 4px;
}
.ai {
background-color: #f1f1f1;
color: #37474f;
margin-right: auto;
border-bottom-left-radius: 4px;
}
.system {
background-color: #f8f9fa;
font-style: italic;
color: #666;
text-align: center;
max-width: 90%;
margin: 10px auto;
font-size: 0.9em;
padding: 8px 12px;
border-radius: 8px;
}
.message-time {
font-size: 0.7em;
color: #999;
position: absolute;
bottom: 5px;
right: 10px;
}
.audio-player {
width: 100%;
margin-top: 8px;
border-radius: 8px;
}
.visualizer-section {
margin-bottom: 20px;
}
.visualizer-container {
height: 150px;
background-color: #000;
border-radius: 12px;
overflow: hidden;
position: relative;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
}
.visualizer-label {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
color: rgba(255, 255, 255, 0.7);
font-size: 1rem;
text-align: center;
pointer-events: none;
transition: opacity 0.3s ease;
z-index: 1;
}
#audioVisualizer {
width: 100%;
height: 100%;
display: block;
}
.controls {
display: flex;
gap: 15px;
flex-wrap: wrap;
}
.control-group {
flex: 1;
min-width: 200px;
}
.control-label {
font-weight: 600;
margin-bottom: 10px;
color: #555;
}
.button-row {
display: flex;
gap: 10px;
margin-top: 15px;
}
button {
padding: 12px 20px;
border-radius: 8px;
border: none;
background-color: var(--primary-color);
color: white;
font-weight: 600;
cursor: pointer;
transition: all 0.2s ease;
display: flex;
align-items: center;
justify-content: center;
gap: 8px;
flex: 1;
}
button:hover {
background-color: var(--secondary-color);
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
button:active {
transform: translateY(0);
}
button.recording {
background-color: var(--accent-color);
animation: pulse 1.5s infinite;
}
button.processing {
background-color: #ff9800;
}
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.8; }
100% { opacity: 1; }
}
select, .slider-container {
width: 100%;
padding: 10px;
border-radius: 8px;
border: 1px solid var(--border-color);
background-color: white;
margin-bottom: 15px;
}
.slider-container {
display: flex;
flex-direction: column;
gap: 5px;
}
.slider-label {
display: flex;
justify-content: space-between;
}
input[type="range"] {
width: 100%;
cursor: pointer;
}
.volume-indicator {
height: 30px;
background: linear-gradient(to right, #4CAF50, #FFEB3B, #F44336);
border-radius: 4px;
margin-top: 5px;
position: relative;
overflow: hidden;
}
.volume-level {
height: 100%;
width: 0%;
background-color: rgba(0, 0, 0, 0.5);
position: absolute;
right: 0;
top: 0;
transition: width 0.1s ease;
}
.status-indicator {
display: flex;
align-items: center;
gap: 8px;
padding: 10px;
border-radius: 8px;
background-color: #f5f5f5;
margin-top: 20px;
}
.status-dot {
width: 12px;
height: 12px;
border-radius: 50%;
background-color: #ccc;
transition: background-color 0.3s ease;
}
.status-dot.active {
background-color: var(--success-color);
}
.status-text {
font-size: 0.9em;
color: #666;
}
/* Custom Scrollbar */
.conversation::-webkit-scrollbar {
width: 8px;
}
.conversation::-webkit-scrollbar-track {
background: #f1f1f1;
border-radius: 10px;
}
.conversation::-webkit-scrollbar-thumb {
background: #ccc;
border-radius: 10px;
}
.conversation::-webkit-scrollbar-thumb:hover {
background: #aaa;
}
/* Settings Panel */
.settings-panel {
margin-top: 20px;
}
.settings-toggles {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(150px, 1fr));
gap: 10px;
margin-top: 10px;
}
.toggle-switch {
display: flex;
align-items: center;
}
.toggle-switch input {
opacity: 0;
width: 0;
height: 0;
}
.toggle-switch label {
position: relative;
display: inline-block;
width: 50px;
height: 24px;
background-color: #ccc;
border-radius: 34px;
transition: .4s;
margin-right: 10px;
cursor: pointer;
}
.toggle-switch label:before {
position: absolute;
content: "";
height: 16px;
width: 16px;
left: 4px;
bottom: 4px;
background-color: white;
transition: .4s;
border-radius: 50%;
}
.toggle-switch input:checked + label {
background-color: var(--primary-color);
}
.toggle-switch input:checked + label:before {
transform: translateX(26px);
}
footer {
text-align: center;
margin-top: 40px;
padding-top: 20px;
border-top: 1px solid var(--border-color);
color: #888;
font-size: 0.9em;
}
</style>
</head>
<body>
<header>
<h1>Sesame AI Voice Chat</h1>
<p class="subtitle">Speak naturally and have a conversation with AI</p>
</header>
<div class="app-container">
<div class="chat-container">
<div class="chat-header">
<h2>Conversation</h2>
<button id="clearButton" class="small-button">
<i class="fas fa-trash"></i> Clear Chat
</button>
</div>
<div class="conversation" id="conversation"></div>
</div>
<div class="control-panel">
<div class="visualizer-section">
<h3>Audio Visualizer</h3>
<div class="visualizer-container">
<canvas id="audioVisualizer"></canvas>
<div id="visualizerLabel" class="visualizer-label">Speak to see audio visualization</div>
</div>
</div>
<div class="controls">
<div class="control-group">
<div class="control-label">Voice Settings</div>
<select id="speakerSelect">
<option value="0">Speaker 0 (You)</option>
<option value="1">Speaker 1 (AI)</option>
</select>
<div class="slider-container">
<div class="slider-label">
<span>Silence Threshold</span>
<span id="thresholdValue">0.01</span>
</div>
<input type="range" id="thresholdSlider" min="0.001" max="0.1" step="0.001" value="0.01">
</div>
<div class="volume-indicator">
<div id="volumeLevel" class="volume-level"></div>
</div>
</div>
<div class="control-group">
<div class="control-label">Conversation Controls</div>
<div class="button-row">
<button id="streamButton" class="main-button">
<i class="fas fa-microphone"></i> Start Conversation
</button>
</div>
</div>
</div>
<div class="settings-panel">
<div class="control-label">Settings</div>
<div class="settings-toggles">
<div class="toggle-switch">
<input type="checkbox" id="autoPlayResponses" checked>
<label for="autoPlayResponses"></label>
<span>Auto-play responses</span>
</div>
<div class="toggle-switch">
<input type="checkbox" id="showVisualizer" checked>
<label for="showVisualizer"></label>
<span>Show visualizer</span>
</div>
</div>
</div>
<div class="status-indicator">
<div class="status-dot" id="statusDot"></div>
<div class="status-text" id="statusText">Not connected</div>
</div>
</div>
</div>
<footer>
<p>Powered by Sesame AI | WhisperX for speech recognition</p>
</footer>
<!-- Load our JavaScript file -->
<script src="./voice-chat.js"></script>
</body>
</html>

View File

@@ -1,203 +0,0 @@
from dataclasses import dataclass
import torch
import torch.nn as nn
import torchtune
from huggingface_hub import PyTorchModelHubMixin
from torchtune.models import llama3_2
def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
return llama3_2.llama3_2(
vocab_size=128_256,
num_layers=16,
num_heads=32,
num_kv_heads=8,
embed_dim=2048,
max_seq_len=2048,
intermediate_dim=8192,
attn_dropout=0.0,
norm_eps=1e-5,
rope_base=500_000,
scale_factor=32,
)
def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
return llama3_2.llama3_2(
vocab_size=128_256,
num_layers=4,
num_heads=8,
num_kv_heads=2,
embed_dim=1024,
max_seq_len=2048,
intermediate_dim=8192,
attn_dropout=0.0,
norm_eps=1e-5,
rope_base=500_000,
scale_factor=32,
)
FLAVORS = {
"llama-1B": llama3_2_1B,
"llama-100M": llama3_2_100M,
}
def _prepare_transformer(model):
embed_dim = model.tok_embeddings.embedding_dim
model.tok_embeddings = nn.Identity()
model.output = nn.Identity()
return model, embed_dim
def _create_causal_mask(seq_len: int, device: torch.device):
return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
"""
Args:
mask: (max_seq_len, max_seq_len)
input_pos: (batch_size, seq_len)
Returns:
(batch_size, seq_len, max_seq_len)
"""
r = mask[input_pos, :]
return r
def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
q = torch.empty_like(probs).exponential_(1)
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
logits = logits / temperature
filter_value: float = -float("Inf")
indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
scores_processed = logits.masked_fill(indices_to_remove, filter_value)
scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
probs = torch.nn.functional.softmax(scores_processed, dim=-1)
sample_token = _multinomial_sample_one_no_sync(probs)
return sample_token
@dataclass
class ModelArgs:
backbone_flavor: str
decoder_flavor: str
text_vocab_size: int
audio_vocab_size: int
audio_num_codebooks: int
class Model(
nn.Module,
PyTorchModelHubMixin,
repo_url="https://github.com/SesameAILabs/csm",
pipeline_tag="text-to-speech",
license="apache-2.0",
):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
def setup_caches(self, max_batch_size: int) -> torch.Tensor:
"""Setup KV caches and return a causal mask."""
dtype = next(self.parameters()).dtype
device = next(self.parameters()).device
with device:
self.backbone.setup_caches(max_batch_size, dtype)
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
def generate_frame(
self,
tokens: torch.Tensor,
tokens_mask: torch.Tensor,
input_pos: torch.Tensor,
temperature: float,
topk: int,
) -> torch.Tensor:
"""
Args:
tokens: (batch_size, seq_len, audio_num_codebooks+1)
tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
input_pos: (batch_size, seq_len) positions for each token
mask: (batch_size, seq_len, max_seq_len
Returns:
(batch_size, audio_num_codebooks) sampled tokens
"""
dtype = next(self.parameters()).dtype
b, s, _ = tokens.size()
assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
embeds = self._embed_tokens(tokens)
masked_embeds = embeds * tokens_mask.unsqueeze(-1)
h = masked_embeds.sum(dim=2)
h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
last_h = h[:, -1, :]
c0_logits = self.codebook0_head(last_h)
c0_sample = sample_topk(c0_logits, topk, temperature)
c0_embed = self._embed_audio(0, c0_sample)
curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
curr_sample = c0_sample.clone()
curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
# Decoder caches must be reset every frame.
self.decoder.reset_caches()
for i in range(1, self.config.audio_num_codebooks):
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
dtype=dtype
)
ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
ci_sample = sample_topk(ci_logits, topk, temperature)
ci_embed = self._embed_audio(i, ci_sample)
curr_h = ci_embed
curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
curr_pos = curr_pos[:, -1:] + 1
return curr_sample
def reset_caches(self):
self.backbone.reset_caches()
self.decoder.reset_caches()
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
audio_tokens = tokens[:, :, :-1] + (
self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
)
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
)
return torch.cat([audio_embeds, text_embeds], dim=-2)

View File

@@ -1,9 +1,16 @@
torch==2.4.0 Flask==2.2.2
torchaudio==2.4.0 Flask-SocketIO==5.3.2
tokenizers==0.21.0 torch>=2.0.0
transformers==4.49.0 torchaudio>=2.0.0
huggingface_hub==0.28.1 transformers>=4.30.0
moshi==0.2.2 huggingface-hub>=0.14.0
torchtune==0.4.0 python-dotenv==0.19.2
torchao==0.9.0 numpy>=1.21.6
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master scipy>=1.7.3
soundfile==0.10.3.post1
requests==2.28.1
pydub==0.25.1
python-socketio==5.7.2
eventlet==0.33.3
whisper>=20230314
ffmpeg-python>=0.2.0

View File

@@ -1,117 +0,0 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass
# Disable Triton compilation
os.environ["NO_TORCH_COMPILE"] = "1"
# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
repo_id="sesame/csm-1b",
filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
repo_id="sesame/csm-1b",
filename="prompts/conversational_b.wav"
)
SPEAKER_PROMPTS = {
"conversational_a": {
"text": (
"like revising for an exam I'd have to try and like keep up the momentum because I'd "
"start really early I'd be like okay I'm gonna start revising now and then like "
"you're revising for ages and then I just like start losing steam I didn't do that "
"for the exam we had recently to be fair that was a more of a last minute scenario "
"but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
"sort of start the day with this not like a panic but like a"
),
"audio": prompt_filepath_conversational_a
},
"conversational_b": {
"text": (
"like a super Mario level. Like it's very like high detail. And like, once you get "
"into the park, it just like, everything looks like a computer game and they have all "
"these, like, you know, if, if there's like a, you know, like in a Mario game, they "
"will have like a question block. And if you like, you know, punch it, a coin will "
"come out. So like everyone, when they come into the park, they get like this little "
"bracelet and then you can go punching question blocks around."
),
"audio": prompt_filepath_conversational_b
}
}
def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = audio_tensor.squeeze(0)
# Resample is lazy so we can always call it
audio_tensor = torchaudio.functional.resample(
audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
)
return audio_tensor
def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
audio_tensor = load_prompt_audio(audio_path, sample_rate)
return Segment(text=text, speaker=speaker, audio=audio_tensor)
def main():
# Select the best available device, skipping MPS due to float64 limitations
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
print(f"Using device: {device}")
# Load model
generator = load_csm_1b(device)
# Prepare prompts
prompt_a = prepare_prompt(
SPEAKER_PROMPTS["conversational_a"]["text"],
0,
SPEAKER_PROMPTS["conversational_a"]["audio"],
generator.sample_rate
)
prompt_b = prepare_prompt(
SPEAKER_PROMPTS["conversational_b"]["text"],
1,
SPEAKER_PROMPTS["conversational_b"]["audio"],
generator.sample_rate
)
# Generate conversation
conversation = [
{"text": "Hey how are you doing?", "speaker_id": 0},
{"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
{"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
{"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
]
# Generate each utterance
generated_segments = []
prompt_segments = [prompt_a, prompt_b]
for utterance in conversation:
print(f"Generating: {utterance['text']}")
audio_tensor = generator.generate(
text=utterance['text'],
speaker=utterance['speaker_id'],
context=prompt_segments + generated_segments,
max_audio_length_ms=10_000,
)
generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))
# Concatenate all generations
all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
torchaudio.save(
"full_conversation.wav",
all_audio.unsqueeze(0).cpu(),
generator.sample_rate
)
print("Successfully generated full_conversation.wav")
if __name__ == "__main__":
main()

View File

@@ -1,904 +1,53 @@
import os import os
import base64
import json
import time
import math
import gc
import logging import logging
import numpy as np
import torch import torch
import torchaudio import eventlet
import base64
import tempfile
from io import BytesIO from io import BytesIO
from typing import List, Dict, Any, Optional from flask import Flask, render_template, request, jsonify
from flask import Flask, request, send_from_directory, Response from flask_socketio import SocketIO, emit
from flask_cors import CORS import whisper
from flask_socketio import SocketIO, emit, disconnect import torchaudio
from generator import load_csm_1b, Segment from src.models.conversation import Segment
from collections import deque from src.services.tts_service import load_csm_1b
from threading import Lock from src.llm.generator import generate_llm_response
from transformers import pipeline from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from src.audio.streaming import AudioStreamer
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
level=logging.INFO, logger = logging.getLogger(__name__)
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("sesame-server")
# Determine best compute device app = Flask(__name__, static_folder='static', template_folder='templates')
if torch.backends.mps.is_available(): app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'your-secret-key')
device = "mps" socketio = SocketIO(app)
elif torch.cuda.is_available():
try:
# Test CUDA functionality
torch.rand(10, device="cuda")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
device = "cuda"
logger.info("CUDA is fully functional")
except Exception as e:
logger.warning(f"CUDA available but not working correctly: {e}")
device = "cpu"
else:
device = "cpu"
logger.info("Using CPU")
# Constants and Configuration # Initialize services
SILENCE_THRESHOLD = 0.01 transcription_service = TranscriptionService()
SILENCE_DURATION_SEC = 0.75 tts_service = TextToSpeechService()
MAX_BUFFER_SIZE = 30 # Maximum chunks to buffer before processing audio_streamer = AudioStreamer()
CHUNK_SIZE_MS = 500 # Size of audio chunks when streaming responses
# Define the base directory and static files directory @socketio.on('audio_input')
base_dir = os.path.dirname(os.path.abspath(__file__)) def handle_audio_input(data):
static_dir = os.path.join(base_dir, "static") audio_chunk = data['audio']
os.makedirs(static_dir, exist_ok=True) speaker_id = data['speaker']
# Define a simple energy-based speech detector # Process audio and convert to text
class SpeechDetector: text = transcription_service.transcribe(audio_chunk)
def __init__(self): logging.info(f"Transcribed text: {text}")
self.min_speech_energy = 0.01
self.speech_window = 0.2 # seconds
def detect_speech(self, audio_tensor, sample_rate): # Generate response using Llama 3.2
# Calculate frame size based on window size response_text = tts_service.generate_response(text, speaker_id)
frame_size = int(sample_rate * self.speech_window) logging.info(f"Generated response: {response_text}")
# If audio is shorter than frame size, use the entire audio # Convert response text to audio
if audio_tensor.shape[0] < frame_size: audio_response = tts_service.text_to_speech(response_text, speaker_id)
frames = [audio_tensor]
else:
# Split audio into frames
frames = [audio_tensor[i:i+frame_size] for i in range(0, len(audio_tensor), frame_size)]
# Calculate energy per frame # Stream audio response back to client
energies = [torch.mean(frame**2).item() for frame in frames] socketio.emit('audio_response', {'audio': audio_response})
# Determine if there's speech based on energy threshold if __name__ == '__main__':
has_speech = any(e > self.min_speech_energy for e in energies) socketio.run(app, host='0.0.0.0', port=5000)
return has_speech
speech_detector = SpeechDetector()
logger.info("Initialized simple speech detector")
# Model Loading Functions
def load_speech_models():
"""Load speech generation and recognition models"""
# Load CSM (existing code)
generator = load_csm_1b(device=device)
# Load Whisper model for speech recognition
try:
logger.info(f"Loading speech recognition model on {device}...")
# Try with newer API first
try:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
model_id = "openai/whisper-small"
# Load model and processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map=device,
)
processor = AutoProcessor.from_pretrained(model_id)
# Create pipeline with specific parameters
speech_recognizer = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
device=device,
)
except Exception as api_error:
logger.warning(f"Newer API loading failed: {api_error}, trying simpler approach")
# Fallback to simpler API
speech_recognizer = pipeline(
"automatic-speech-recognition",
model="openai/whisper-small",
device=device
)
logger.info("Speech recognition model loaded successfully")
return generator, speech_recognizer
except Exception as e:
logger.error(f"Error loading speech recognition model: {e}")
return generator, None
# Unpack both models
generator, speech_recognizer = load_speech_models()
# Initialize Llama 3.2 model for conversation responses
def load_llm_model():
"""Load Llama 3.2 model for generating text responses"""
try:
logger.info("Loading Llama 3.2 model for conversational responses...")
model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Determine compute device for LLM
llm_device = "cpu" # Default to CPU for LLM
# Use CUDA if available and there's enough VRAM
if device == "cuda" and torch.cuda.is_available():
try:
free_mem = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)
# If we have at least 2GB free, use CUDA for LLM
if free_mem > 2 * 1024 * 1024 * 1024:
llm_device = "cuda"
except:
pass
logger.info(f"Using {llm_device} for Llama 3.2 model")
# Load the model with lower precision for efficiency
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16 if llm_device == "cuda" else torch.float32,
device_map=llm_device
)
# Create a pipeline for easier inference
llm = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_length=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1
)
logger.info("Llama 3.2 model loaded successfully")
return llm
except Exception as e:
logger.error(f"Error loading Llama 3.2 model: {e}")
return None
# Load the LLM model
llm = load_llm_model()
# Set up Flask and Socket.IO
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
# Socket connection management
thread_lock = Lock()
active_clients = {} # Map client_id to client context
# Audio Utility Functions
def decode_audio_data(audio_data: str) -> torch.Tensor:
"""Decode base64 audio data to a torch tensor with improved error handling"""
try:
# Skip empty audio data
if not audio_data or len(audio_data) < 100:
logger.warning("Empty or too short audio data received")
return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence
# Extract the actual base64 content
if ',' in audio_data:
audio_data = audio_data.split(',')[1]
# Decode base64 audio data
try:
binary_data = base64.b64decode(audio_data)
logger.debug(f"Decoded base64 data: {len(binary_data)} bytes")
# Check if we have enough data for a valid WAV
if len(binary_data) < 44: # WAV header is 44 bytes
logger.warning("Data too small to be a valid WAV file")
return torch.zeros(generator.sample_rate // 2)
except Exception as e:
logger.error(f"Base64 decoding error: {e}")
return torch.zeros(generator.sample_rate // 2)
# Multiple approaches to handle audio data
audio_tensor = None
sample_rate = None
# Approach 1: Direct loading with torchaudio
try:
with BytesIO(binary_data) as temp_file:
temp_file.seek(0)
audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
logger.debug(f"Loaded audio: shape={audio_tensor.shape}, rate={sample_rate}Hz")
# Validate tensor
if audio_tensor.numel() == 0 or torch.isnan(audio_tensor).any():
raise ValueError("Invalid audio tensor")
except Exception as e:
logger.warning(f"Direct loading failed: {e}")
# Approach 2: Using wave module and numpy
try:
temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
with open(temp_path, 'wb') as f:
f.write(binary_data)
import wave
with wave.open(temp_path, 'rb') as wf:
n_channels = wf.getnchannels()
sample_width = wf.getsampwidth()
sample_rate = wf.getframerate()
n_frames = wf.getnframes()
frames = wf.readframes(n_frames)
# Convert to numpy array
if sample_width == 2: # 16-bit audio
data = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
elif sample_width == 1: # 8-bit audio
data = np.frombuffer(frames, dtype=np.uint8).astype(np.float32) / 128.0 - 1.0
else:
raise ValueError(f"Unsupported sample width: {sample_width}")
# Convert to mono if needed
if n_channels > 1:
data = data.reshape(-1, n_channels)
data = data.mean(axis=1)
# Convert to torch tensor
audio_tensor = torch.from_numpy(data)
logger.info(f"Loaded audio using wave: shape={audio_tensor.shape}")
# Clean up temp file
if os.path.exists(temp_path):
os.remove(temp_path)
except Exception as e2:
logger.error(f"All audio loading methods failed: {e2}")
return torch.zeros(generator.sample_rate // 2)
# Format corrections
if audio_tensor is None:
return torch.zeros(generator.sample_rate // 2)
# Ensure audio is mono
if len(audio_tensor.shape) > 1 and audio_tensor.shape[0] > 1:
audio_tensor = torch.mean(audio_tensor, dim=0)
# Ensure 1D tensor
audio_tensor = audio_tensor.squeeze()
# Resample if needed
if sample_rate != generator.sample_rate:
try:
logger.debug(f"Resampling from {sample_rate}Hz to {generator.sample_rate}Hz")
resampler = torchaudio.transforms.Resample(
orig_freq=sample_rate,
new_freq=generator.sample_rate
)
audio_tensor = resampler(audio_tensor)
except Exception as e:
logger.warning(f"Resampling error: {e}")
# Normalize audio to avoid issues
if torch.abs(audio_tensor).max() > 0:
audio_tensor = audio_tensor / torch.abs(audio_tensor).max()
return audio_tensor
except Exception as e:
logger.error(f"Unhandled error in decode_audio_data: {e}")
return torch.zeros(generator.sample_rate // 2)
def encode_audio_data(audio_tensor: torch.Tensor) -> str:
"""Encode torch tensor audio to base64 string"""
try:
buf = BytesIO()
torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
buf.seek(0)
audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
return f"data:audio/wav;base64,{audio_base64}"
except Exception as e:
logger.error(f"Error encoding audio: {e}")
# Return a minimal silent audio file
silence = torch.zeros(generator.sample_rate // 2).unsqueeze(0)
buf = BytesIO()
torchaudio.save(buf, silence, generator.sample_rate, format="wav")
buf.seek(0)
return f"data:audio/wav;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
def process_speech(audio_tensor: torch.Tensor, client_id: str) -> str:
"""Process speech with speech recognition"""
if not speech_recognizer:
# Fallback to basic detection if model failed to load
return detect_speech_energy(audio_tensor)
try:
# Save audio to temp file for Whisper
temp_path = os.path.join(base_dir, f"temp_{time.time()}.wav")
torchaudio.save(temp_path, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate)
# Perform speech recognition - handle the warning differently
# Just pass the path without any additional parameters
try:
# First try - use default parameters
result = speech_recognizer(temp_path)
transcription = result["text"]
except Exception as whisper_error:
logger.warning(f"First transcription attempt failed: {whisper_error}")
# Try with explicit parameters for older versions of transformers
import numpy as np
import soundfile as sf
# Load audio as numpy array
audio_np, sr = sf.read(temp_path)
if sr != 16000:
# Whisper expects 16kHz audio
from scipy import signal
audio_np = signal.resample(audio_np, int(len(audio_np) * 16000 / sr))
# Try with numpy array directly
result = speech_recognizer(audio_np)
transcription = result["text"]
# Clean up temp file
if os.path.exists(temp_path):
os.remove(temp_path)
# Return empty string if no speech detected
if not transcription or transcription.isspace():
return "I didn't detect any speech. Could you please try again?"
logger.info(f"Transcription successful: '{transcription}'")
return transcription
except Exception as e:
logger.error(f"Speech recognition error: {e}")
return "Sorry, I couldn't understand what you said. Could you try again?"
def detect_speech_energy(audio_tensor: torch.Tensor) -> str:
"""Basic speech detection based on audio energy levels"""
# Calculate audio energy
energy = torch.mean(torch.abs(audio_tensor)).item()
logger.debug(f"Audio energy detected: {energy:.6f}")
# Generate response based on energy level
if energy > 0.1: # Louder speech
return "I heard you speaking clearly. How can I help you today?"
elif energy > 0.05: # Moderate speech
return "I heard you say something. Could you please repeat that?"
elif energy > 0.02: # Soft speech
return "I detected some speech, but it was quite soft. Could you speak up a bit?"
else: # Very soft or no speech
return "I didn't detect any speech. Could you please try again?"
def generate_response(text: str, conversation_history: List[Segment]) -> str:
"""Generate a contextual response based on the transcribed text using Llama 3.2"""
# If LLM is not available, use simple responses
if llm is None:
return generate_simple_response(text)
try:
# Create a conversational prompt based on history
# Format: recent conversation turns (up to 4) + current user input
history_str = ""
# Add up to 4 recent conversation turns (excluding the current one)
recent_segments = [
seg for seg in conversation_history[-8:]
if seg.text and not seg.text.isspace()
]
for i, segment in enumerate(recent_segments):
speaker_name = "User" if segment.speaker == 0 else "Assistant"
history_str += f"{speaker_name}: {segment.text}\n"
# Construct the prompt for Llama 3.2
prompt = f"""<|system|>
You are Sesame, a helpful, friendly and concise voice assistant.
Keep your responses conversational, natural, and to the point.
Respond to the user's latest message in the context of the conversation.
<|end|>
{history_str}
User: {text}
Assistant:"""
logger.debug(f"LLM Prompt: {prompt}")
# Generate response with the LLM
result = llm(
prompt,
max_new_tokens=150,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1
)
# Extract the generated text
response = result[0]["generated_text"]
# Extract just the Assistant's response (after the prompt)
response = response.split("Assistant:")[-1].strip()
# Clean up and ensure it's not too long for TTS
response = response.split("\n")[0].strip()
if len(response) > 200:
response = response[:197] + "..."
logger.info(f"LLM response: {response}")
return response
except Exception as e:
logger.error(f"Error generating LLM response: {e}")
# Fall back to simple responses
return generate_simple_response(text)
def generate_simple_response(text: str) -> str:
"""Generate a simple rule-based response as fallback"""
responses = {
"hello": "Hello there! How can I help you today?",
"hi": "Hi there! What can I do for you?",
"how are you": "I'm doing well, thanks for asking! How about you?",
"what is your name": "I'm Sesame, your voice assistant. How can I help you?",
"who are you": "I'm Sesame, an AI voice assistant. I'm here to chat with you!",
"bye": "Goodbye! It was nice chatting with you.",
"thank you": "You're welcome! Is there anything else I can help with?",
"weather": "I don't have real-time weather data, but I hope it's nice where you are!",
"help": "I can chat with you using natural voice. Just speak normally and I'll respond.",
"what can you do": "I can have a conversation with you, answer questions, and provide assistance with various topics.",
}
text_lower = text.lower()
# Check for matching keywords
for key, response in responses.items():
if key in text_lower:
return response
# Default responses based on text length
if not text:
return "I didn't catch that. Could you please repeat?"
elif len(text) < 10:
return "Thanks for your message. Could you elaborate a bit more?"
else:
return f"I heard you say something about that. Can you tell me more?"
# Flask Routes
@app.route('/')
def index():
return send_from_directory(base_dir, 'index.html')
@app.route('/favicon.ico')
def favicon():
if os.path.exists(os.path.join(static_dir, 'favicon.ico')):
return send_from_directory(static_dir, 'favicon.ico')
return Response(status=204)
@app.route('/voice-chat.js')
def voice_chat_js():
return send_from_directory(base_dir, 'voice-chat.js')
@app.route('/static/<path:path>')
def serve_static(path):
return send_from_directory(static_dir, path)
# Socket.IO Event Handlers
@socketio.on('connect')
def handle_connect():
client_id = request.sid
logger.info(f"Client connected: {client_id}")
# Initialize client context
active_clients[client_id] = {
'context_segments': [],
'streaming_buffer': [],
'is_streaming': False,
'is_silence': False,
'last_active_time': time.time(),
'energy_window': deque(maxlen=10)
}
emit('status', {'type': 'connected', 'message': 'Connected to server'})
@socketio.on('disconnect')
def handle_disconnect():
client_id = request.sid
if client_id in active_clients:
del active_clients[client_id]
logger.info(f"Client disconnected: {client_id}")
@socketio.on('generate')
def handle_generate(data):
client_id = request.sid
if client_id not in active_clients:
emit('error', {'message': 'Client not registered'})
return
try:
text = data.get('text', '')
speaker_id = data.get('speaker', 0)
logger.info(f"Generating audio for: '{text}' with speaker {speaker_id}")
# Generate audio response
audio_tensor = generator.generate(
text=text,
speaker=speaker_id,
context=active_clients[client_id]['context_segments'],
max_audio_length_ms=10_000,
)
# Add to conversation context
active_clients[client_id]['context_segments'].append(
Segment(text=text, speaker=speaker_id, audio=audio_tensor)
)
# Convert audio to base64 and send back to client
audio_base64 = encode_audio_data(audio_tensor)
emit('audio_response', {
'type': 'audio_response',
'audio': audio_base64,
'text': text
})
except Exception as e:
logger.error(f"Error generating audio: {e}")
emit('error', {
'type': 'error',
'message': f"Error generating audio: {str(e)}"
})
@socketio.on('add_to_context')
def handle_add_to_context(data):
client_id = request.sid
if client_id not in active_clients:
emit('error', {'message': 'Client not registered'})
return
try:
text = data.get('text', '')
speaker_id = data.get('speaker', 0)
audio_data = data.get('audio', '')
# Convert received audio to tensor
audio_tensor = decode_audio_data(audio_data)
# Add to conversation context
active_clients[client_id]['context_segments'].append(
Segment(text=text, speaker=speaker_id, audio=audio_tensor)
)
emit('context_updated', {
'type': 'context_updated',
'message': 'Audio added to context'
})
except Exception as e:
logger.error(f"Error adding to context: {e}")
emit('error', {
'type': 'error',
'message': f"Error processing audio: {str(e)}"
})
@socketio.on('clear_context')
def handle_clear_context():
client_id = request.sid
if client_id in active_clients:
active_clients[client_id]['context_segments'] = []
emit('context_updated', {
'type': 'context_updated',
'message': 'Context cleared'
})
@socketio.on('stream_audio')
def handle_stream_audio(data):
client_id = request.sid
if client_id not in active_clients:
emit('error', {'message': 'Client not registered'})
return
client = active_clients[client_id]
try:
speaker_id = data.get('speaker', 0)
audio_data = data.get('audio', '')
# Skip if no audio data (might be just a connection test)
if not audio_data:
logger.debug("Empty audio data received, ignoring")
return
# Convert received audio to tensor
audio_chunk = decode_audio_data(audio_data)
# Start streaming mode if not already started
if not client['is_streaming']:
client['is_streaming'] = True
client['streaming_buffer'] = []
client['energy_window'].clear()
client['is_silence'] = False
client['last_active_time'] = time.time()
logger.info(f"[{client_id[:8]}] Streaming started with speaker ID: {speaker_id}")
emit('streaming_status', {
'type': 'streaming_status',
'status': 'started'
})
# Calculate audio energy for silence detection
chunk_energy = torch.mean(torch.abs(audio_chunk)).item()
client['energy_window'].append(chunk_energy)
avg_energy = sum(client['energy_window']) / len(client['energy_window'])
# Check if audio is silent
current_silence = avg_energy < SILENCE_THRESHOLD
# Track silence transition
if not client['is_silence'] and current_silence:
# Transition to silence
client['is_silence'] = True
client['last_active_time'] = time.time()
elif client['is_silence'] and not current_silence:
# User started talking again
client['is_silence'] = False
# Add chunk to buffer regardless of silence state
client['streaming_buffer'].append(audio_chunk)
# Check if silence has persisted long enough to consider "stopped talking"
silence_elapsed = time.time() - client['last_active_time']
if client['is_silence'] and silence_elapsed >= SILENCE_DURATION_SEC and len(client['streaming_buffer']) > 0:
# User has stopped talking - process the collected audio
logger.info(f"[{client_id[:8]}] Processing audio after {silence_elapsed:.2f}s of silence")
process_complete_utterance(client_id, client, speaker_id)
# If buffer gets too large without silence, process it anyway
elif len(client['streaming_buffer']) >= MAX_BUFFER_SIZE:
logger.info(f"[{client_id[:8]}] Processing long audio segment without silence")
process_complete_utterance(client_id, client, speaker_id, is_incomplete=True)
# Keep half of the buffer for context (sliding window approach)
half_point = len(client['streaming_buffer']) // 2
client['streaming_buffer'] = client['streaming_buffer'][half_point:]
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Error processing streaming audio: {e}")
emit('error', {
'type': 'error',
'message': f"Error processing streaming audio: {str(e)}"
})
def process_complete_utterance(client_id, client, speaker_id, is_incomplete=False):
"""Process a complete utterance (after silence or buffer limit)"""
try:
# Combine audio chunks
full_audio = torch.cat(client['streaming_buffer'], dim=0)
# Process audio to generate a response (using speech recognition)
generated_text = process_speech(full_audio, client_id)
# Add suffix for incomplete utterances
if is_incomplete:
generated_text += " (processing continued speech...)"
# Log the generated text
logger.info(f"[{client_id[:8]}] Generated text: '{generated_text}'")
# Handle the result
if generated_text:
# Add user message to context
user_segment = Segment(text=generated_text, speaker=speaker_id, audio=full_audio)
client['context_segments'].append(user_segment)
# Send the text to client
emit('transcription', {
'type': 'transcription',
'text': generated_text
}, room=client_id)
# Only generate a response if this is a complete utterance
if not is_incomplete:
# Generate a contextual response
response_text = generate_response(generated_text, client['context_segments'])
logger.info(f"[{client_id[:8]}] Generating response: '{response_text}'")
# Let the client know we're processing
emit('processing_status', {
'type': 'processing_status',
'status': 'generating_audio',
'message': 'Generating audio response...'
}, room=client_id)
# Generate audio for the response
try:
# Use a different speaker than the user
ai_speaker_id = 1 if speaker_id == 0 else 0
# Generate the full response
audio_tensor = generator.generate(
text=response_text,
speaker=ai_speaker_id,
context=client['context_segments'],
max_audio_length_ms=10_000,
)
# Add response to context
ai_segment = Segment(
text=response_text,
speaker=ai_speaker_id,
audio=audio_tensor
)
client['context_segments'].append(ai_segment)
# CHANGE HERE: Use the streaming function instead of sending all at once
# Check if the audio is short enough to send at once or if it should be streamed
if audio_tensor.size(0) < generator.sample_rate * 2: # Less than 2 seconds
# For short responses, just send in one go for better responsiveness
audio_base64 = encode_audio_data(audio_tensor)
emit('audio_response', {
'type': 'audio_response',
'text': response_text,
'audio': audio_base64
}, room=client_id)
logger.info(f"[{client_id[:8]}] Short audio response sent in one piece")
else:
# For longer responses, use streaming
logger.info(f"[{client_id[:8]}] Using streaming for audio response")
# Start a new thread for streaming to avoid blocking the main thread
import threading
stream_thread = threading.Thread(
target=stream_audio_to_client,
args=(client_id, audio_tensor, response_text, ai_speaker_id)
)
stream_thread.start()
except Exception as e:
logger.error(f"Error generating audio response: {e}")
emit('error', {
'type': 'error',
'message': "Sorry, there was an error generating the audio response."
}, room=client_id)
else:
# If processing failed, send a notification
emit('error', {
'type': 'error',
'message': "Sorry, I couldn't understand what you said. Could you try again?"
}, room=client_id)
# Only clear buffer for complete utterances
if not is_incomplete:
# Reset state
client['streaming_buffer'] = []
client['energy_window'].clear()
client['is_silence'] = False
client['last_active_time'] = time.time()
except Exception as e:
logger.error(f"Error processing utterance: {e}")
emit('error', {
'type': 'error',
'message': f"Error processing audio: {str(e)}"
}, room=client_id)
@socketio.on('stop_streaming')
def handle_stop_streaming(data):
client_id = request.sid
if client_id not in active_clients:
return
client = active_clients[client_id]
client['is_streaming'] = False
if client['streaming_buffer'] and len(client['streaming_buffer']) > 5:
# Process any remaining audio in the buffer
logger.info(f"[{client_id[:8]}] Processing final audio buffer on stop")
process_complete_utterance(client_id, client, data.get("speaker", 0))
client['streaming_buffer'] = []
emit('streaming_status', {
'type': 'streaming_status',
'status': 'stopped'
})
def stream_audio_to_client(client_id, audio_tensor, text, speaker_id, chunk_size_ms=CHUNK_SIZE_MS):
"""Stream audio to client in chunks to simulate real-time generation"""
try:
if client_id not in active_clients:
logger.warning(f"Client {client_id} not found for streaming")
return
# Calculate chunk size in samples
chunk_size = int(generator.sample_rate * chunk_size_ms / 1000)
total_chunks = math.ceil(audio_tensor.size(0) / chunk_size)
logger.info(f"Streaming audio in {total_chunks} chunks of {chunk_size_ms}ms each")
# Send initial response with text but no audio yet
socketio.emit('audio_response_start', {
'type': 'audio_response_start',
'text': text,
'total_chunks': total_chunks
}, room=client_id)
# Stream each chunk
for i in range(total_chunks):
start_idx = i * chunk_size
end_idx = min(start_idx + chunk_size, audio_tensor.size(0))
# Extract chunk
chunk = audio_tensor[start_idx:end_idx]
# Encode chunk
chunk_base64 = encode_audio_data(chunk)
# Send chunk
socketio.emit('audio_response_chunk', {
'type': 'audio_response_chunk',
'chunk_index': i,
'total_chunks': total_chunks,
'audio': chunk_base64,
'is_last': i == total_chunks - 1
}, room=client_id)
# Brief pause between chunks to simulate streaming
time.sleep(0.1)
# Send completion message
socketio.emit('audio_response_complete', {
'type': 'audio_response_complete',
'text': text
}, room=client_id)
logger.info(f"Audio streaming complete: {total_chunks} chunks sent")
except Exception as e:
logger.error(f"Error streaming audio to client: {e}")
import traceback
traceback.print_exc()
# Main server start
if __name__ == "__main__":
print(f"\n{'='*60}")
print(f"🔊 Sesame AI Voice Chat Server")
print(f"{'='*60}")
print(f"📡 Server Information:")
print(f" - Local URL: http://localhost:5000")
print(f" - Network URL: http://<your-ip-address>:5000")
print(f"{'='*60}")
print(f"🌐 Device: {device.upper()}")
print(f"🧠 Models: Sesame CSM (TTS only)")
print(f"🔧 Serving from: {os.path.join(base_dir, 'index.html')}")
print(f"{'='*60}")
print(f"Ready to receive connections! Press Ctrl+C to stop the server.\n")
socketio.run(app, host="0.0.0.0", port=5000, debug=False)

View File

@@ -1,13 +0,0 @@
from setuptools import setup, find_packages
import os
# Read requirements from requirements.txt
with open('requirements.txt') as f:
requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]
setup(
name='csm',
version='0.1.0',
packages=find_packages(),
install_requires=requirements,
)

View File

@@ -0,0 +1,28 @@
from scipy.io import wavfile
import numpy as np
import torchaudio
def load_audio(file_path):
sample_rate, audio_data = wavfile.read(file_path)
return sample_rate, audio_data
def normalize_audio(audio_data):
audio_data = audio_data.astype(np.float32)
max_val = np.max(np.abs(audio_data))
if max_val > 0:
audio_data /= max_val
return audio_data
def reduce_noise(audio_data, noise_factor=0.1):
noise = np.random.randn(len(audio_data))
noisy_audio = audio_data + noise_factor * noise
return noisy_audio
def save_audio(file_path, sample_rate, audio_data):
torchaudio.save(file_path, torch.tensor(audio_data).unsqueeze(0), sample_rate)
def process_audio(file_path, output_path):
sample_rate, audio_data = load_audio(file_path)
normalized_audio = normalize_audio(audio_data)
denoised_audio = reduce_noise(normalized_audio)
save_audio(output_path, sample_rate, denoised_audio)

View File

@@ -0,0 +1,35 @@
from flask import Blueprint, request
from flask_socketio import SocketIO, emit
from src.audio.processor import process_audio
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService
streaming_bp = Blueprint('streaming', __name__)
socketio = SocketIO()
transcription_service = TranscriptionService()
tts_service = TextToSpeechService()
@socketio.on('audio_stream')
def handle_audio_stream(data):
audio_chunk = data['audio']
speaker_id = data['speaker']
# Process the audio chunk
processed_audio = process_audio(audio_chunk)
# Transcribe the audio to text
transcription = transcription_service.transcribe(processed_audio)
# Generate a response using the LLM
response_text = generate_response(transcription, speaker_id)
# Convert the response text back to audio
response_audio = tts_service.convert_text_to_speech(response_text, speaker_id)
# Emit the response audio back to the client
emit('audio_response', {'audio': response_audio})
def generate_response(transcription, speaker_id):
# Placeholder for the actual response generation logic
return f"Response to: {transcription}"

View File

@@ -15,14 +15,10 @@ from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
class Segment: class Segment:
speaker: int speaker: int
text: str text: str
# (num_samples,), sample_rate = 24_000
audio: torch.Tensor audio: torch.Tensor
def load_llama3_tokenizer(): def load_llama3_tokenizer():
"""
https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
"""
tokenizer_name = "meta-llama/Llama-3.2-1B" tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.bos_token bos = tokenizer.bos_token
@@ -78,10 +74,8 @@ class Generator:
frame_tokens = [] frame_tokens = []
frame_masks = [] frame_masks = []
# (K, T)
audio = audio.to(self.device) audio = audio.to(self.device)
audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0] audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
# add EOS frame
eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device) eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1) audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
@@ -96,10 +90,6 @@ class Generator:
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0) return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]: def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
(seq_len, 33), (seq_len, 33)
"""
text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker) text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
audio_tokens, audio_masks = self._tokenize_audio(segment.audio) audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
@@ -146,7 +136,7 @@ class Generator:
for _ in range(max_generation_len): for _ in range(max_generation_len):
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk) sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
if torch.all(sample == 0): if torch.all(sample == 0):
break # eos break
samples.append(sample) samples.append(sample)
@@ -158,10 +148,6 @@ class Generator:
audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0) audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
# This applies an imperceptible watermark to identify audio as AI-generated.
# Watermarking ensures transparency, dissuades misuse, and enables traceability.
# Please be a responsible AI citizen and keep the watermarking in place.
# If using CSM 1B in another application, use your own private key and keep it secret.
audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK) audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate) audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

View File

@@ -0,0 +1,14 @@
from transformers import AutoTokenizer
def load_llama3_tokenizer():
tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
return tokenizer
def tokenize_text(text: str, tokenizer) -> list:
tokens = tokenizer.encode(text, return_tensors='pt')
return tokens
def decode_tokens(tokens: list, tokenizer) -> str:
text = tokenizer.decode(tokens, skip_special_tokens=True)
return text

View File

@@ -0,0 +1,28 @@
from dataclasses import dataclass
import torch
@dataclass
class AudioModel:
model: torch.nn.Module
sample_rate: int
def __post_init__(self):
self.model.eval()
def process_audio(self, audio_tensor: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
processed_audio = self.model(audio_tensor)
return processed_audio
def resample_audio(self, audio_tensor: torch.Tensor, target_sample_rate: int) -> torch.Tensor:
if self.sample_rate != target_sample_rate:
resampled_audio = torchaudio.functional.resample(audio_tensor, orig_freq=self.sample_rate, new_freq=target_sample_rate)
return resampled_audio
return audio_tensor
def save_model(self, path: str):
torch.save(self.model.state_dict(), path)
def load_model(self, path: str):
self.model.load_state_dict(torch.load(path))
self.model.eval()

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class Conversation:
context: List[str] = field(default_factory=list)
current_speaker: Optional[int] = None
def add_message(self, message: str, speaker: int):
self.context.append(f"Speaker {speaker}: {message}")
self.current_speaker = speaker
def get_context(self) -> List[str]:
return self.context
def clear_context(self):
self.context.clear()
self.current_speaker = None
def get_last_message(self) -> Optional[str]:
if self.context:
return self.context[-1]
return None

View File

@@ -0,0 +1,25 @@
from typing import List
import torchaudio
import torch
from generator import load_csm_1b, Segment
class TranscriptionService:
def __init__(self, model_device: str = "cpu"):
self.generator = load_csm_1b(device=model_device)
def transcribe_audio(self, audio_path: str) -> str:
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = self._resample_audio(audio_tensor, sample_rate)
transcription = self.generator.generate_transcription(audio_tensor)
return transcription
def _resample_audio(self, audio_tensor: torch.Tensor, orig_freq: int) -> torch.Tensor:
target_sample_rate = self.generator.sample_rate
if orig_freq != target_sample_rate:
audio_tensor = torchaudio.functional.resample(audio_tensor.squeeze(0), orig_freq=orig_freq, new_freq=target_sample_rate)
return audio_tensor
def transcribe_audio_stream(self, audio_chunks: List[torch.Tensor]) -> str:
combined_audio = torch.cat(audio_chunks, dim=1)
transcription = self.generator.generate_transcription(combined_audio)
return transcription

View File

@@ -0,0 +1,24 @@
from dataclasses import dataclass
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from src.llm.generator import load_csm_1b
@dataclass
class TextToSpeechService:
generator: any
def __init__(self, device: str = "cuda"):
self.generator = load_csm_1b(device=device)
def text_to_speech(self, text: str, speaker: int = 0) -> torch.Tensor:
audio = self.generator.generate(
text=text,
speaker=speaker,
context=[],
max_audio_length_ms=10000,
)
return audio
def save_audio(self, audio: torch.Tensor, file_path: str):
torchaudio.save(file_path, audio.unsqueeze(0).cpu(), self.generator.sample_rate)

View File

@@ -0,0 +1,23 @@
# filepath: /csm-conversation-bot/csm-conversation-bot/src/utils/config.py
import os
class Config:
# General configuration
DEBUG = os.getenv('DEBUG', 'False') == 'True'
SECRET_KEY = os.getenv('SECRET_KEY', 'your_secret_key_here')
# API configuration
API_URL = os.getenv('API_URL', 'http://localhost:5000')
# Model configuration
LLM_MODEL_PATH = os.getenv('LLM_MODEL_PATH', 'path/to/llm/model')
AUDIO_MODEL_PATH = os.getenv('AUDIO_MODEL_PATH', 'path/to/audio/model')
# Socket.IO configuration
SOCKETIO_MESSAGE_QUEUE = os.getenv('SOCKETIO_MESSAGE_QUEUE', 'redis://localhost:6379/0')
# Logging configuration
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
# Other configurations can be added as needed

View File

@@ -0,0 +1,14 @@
import logging
def setup_logger(name, log_file, level=logging.INFO):
handler = logging.FileHandler(log_file)
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger = logging.getLogger(name)
logger.setLevel(level)
logger.addHandler(handler)
return logger
# Example usage:
# logger = setup_logger('my_logger', 'app.log')

View File

@@ -0,0 +1,105 @@
body {
font-family: 'Arial', sans-serif;
background-color: #f4f4f4;
color: #333;
margin: 0;
padding: 0;
}
header {
background: #4c84ff;
color: #fff;
padding: 10px 0;
text-align: center;
}
h1 {
margin: 0;
font-size: 2.5rem;
}
.container {
width: 80%;
margin: auto;
overflow: hidden;
}
.conversation {
background: #fff;
padding: 20px;
border-radius: 5px;
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
max-height: 400px;
overflow-y: auto;
}
.message {
padding: 10px;
margin: 10px 0;
border-radius: 5px;
}
.user {
background: #e3f2fd;
text-align: right;
}
.ai {
background: #f1f1f1;
text-align: left;
}
.controls {
display: flex;
justify-content: space-between;
margin-top: 20px;
}
button {
padding: 10px 15px;
border: none;
border-radius: 5px;
cursor: pointer;
transition: background 0.3s;
}
button:hover {
background: #3367d6;
color: #fff;
}
.visualizer-container {
height: 150px;
background: #000;
border-radius: 5px;
margin-top: 20px;
}
.visualizer-label {
color: rgba(255, 255, 255, 0.7);
text-align: center;
padding: 10px;
}
.status-indicator {
display: flex;
align-items: center;
margin-top: 10px;
}
.status-dot {
width: 12px;
height: 12px;
border-radius: 50%;
background-color: #ccc;
margin-right: 10px;
}
.status-dot.active {
background-color: #4CAF50;
}
.status-text {
font-size: 0.9em;
color: #666;
}

31
Backend/static/index.html Normal file
View File

@@ -0,0 +1,31 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>CSM Conversation Bot</title>
<link rel="stylesheet" href="css/styles.css">
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
<script src="js/client.js" defer></script>
</head>
<body>
<header>
<h1>CSM Conversation Bot</h1>
<p>Talk to the AI and get responses in real-time!</p>
</header>
<main>
<div id="conversation" class="conversation"></div>
<div class="controls">
<button id="startButton">Start Conversation</button>
<button id="stopButton">Stop Conversation</button>
</div>
<div class="status-indicator">
<div id="statusDot" class="status-dot"></div>
<div id="statusText">Disconnected</div>
</div>
</main>
<footer>
<p>Powered by CSM and Llama 3.2</p>
</footer>
</body>
</html>

131
Backend/static/js/client.js Normal file
View File

@@ -0,0 +1,131 @@
// This file contains the client-side JavaScript code that handles audio streaming and communication with the server.
const SERVER_URL = window.location.hostname === 'localhost' ?
'http://localhost:5000' : window.location.origin;
const elements = {
conversation: document.getElementById('conversation'),
streamButton: document.getElementById('streamButton'),
clearButton: document.getElementById('clearButton'),
speakerSelection: document.getElementById('speakerSelect'),
statusDot: document.getElementById('statusDot'),
statusText: document.getElementById('statusText'),
};
const state = {
socket: null,
isStreaming: false,
currentSpeaker: 0,
};
// Initialize the application
function initializeApp() {
setupSocketConnection();
setupEventListeners();
}
// Setup Socket.IO connection
function setupSocketConnection() {
state.socket = io(SERVER_URL);
state.socket.on('connect', () => {
updateConnectionStatus(true);
});
state.socket.on('disconnect', () => {
updateConnectionStatus(false);
});
state.socket.on('audio_response', handleAudioResponse);
state.socket.on('transcription', handleTranscription);
}
// Setup event listeners
function setupEventListeners() {
elements.streamButton.addEventListener('click', toggleStreaming);
elements.clearButton.addEventListener('click', clearConversation);
elements.speakerSelection.addEventListener('change', (event) => {
state.currentSpeaker = event.target.value;
});
}
// Update connection status UI
function updateConnectionStatus(isConnected) {
elements.statusDot.classList.toggle('active', isConnected);
elements.statusText.textContent = isConnected ? 'Connected' : 'Disconnected';
}
// Toggle streaming state
function toggleStreaming() {
if (state.isStreaming) {
stopStreaming();
} else {
startStreaming();
}
}
// Start streaming audio to the server
function startStreaming() {
if (state.isStreaming) return;
navigator.mediaDevices.getUserMedia({ audio: true })
.then(stream => {
const mediaRecorder = new MediaRecorder(stream);
mediaRecorder.start();
mediaRecorder.ondataavailable = (event) => {
if (event.data.size > 0) {
sendAudioChunk(event.data);
}
};
mediaRecorder.onstop = () => {
state.isStreaming = false;
elements.streamButton.innerHTML = 'Start Conversation';
};
state.isStreaming = true;
elements.streamButton.innerHTML = 'Stop Conversation';
})
.catch(err => {
console.error('Error accessing microphone:', err);
});
}
// Stop streaming audio
function stopStreaming() {
if (!state.isStreaming) return;
// Logic to stop the media recorder would go here
}
// Send audio chunk to server
function sendAudioChunk(audioData) {
const reader = new FileReader();
reader.onloadend = () => {
const arrayBuffer = reader.result;
state.socket.emit('audio_chunk', { audio: arrayBuffer, speaker: state.currentSpeaker });
};
reader.readAsArrayBuffer(audioData);
}
// Handle audio response from server
function handleAudioResponse(data) {
const audioElement = new Audio(URL.createObjectURL(new Blob([data.audio])));
audioElement.play();
}
// Handle transcription response from server
function handleTranscription(data) {
const messageElement = document.createElement('div');
messageElement.textContent = `AI: ${data.transcription}`;
elements.conversation.appendChild(messageElement);
}
// Clear conversation history
function clearConversation() {
elements.conversation.innerHTML = '';
}
// Initialize the application when DOM is fully loaded
document.addEventListener('DOMContentLoaded', initializeApp);

View File

@@ -0,0 +1,31 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>CSM Conversation Bot</title>
<link rel="stylesheet" href="../static/css/styles.css">
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
<script src="../static/js/client.js" defer></script>
</head>
<body>
<header>
<h1>CSM Conversation Bot</h1>
<p>Talk to the AI and get responses in real-time!</p>
</header>
<main>
<div class="chat-container">
<div class="conversation" id="conversation"></div>
<input type="text" id="userInput" placeholder="Type your message..." />
<button id="sendButton">Send</button>
</div>
<div class="status-indicator">
<div class="status-dot" id="statusDot"></div>
<div class="status-text" id="statusText">Not connected</div>
</div>
</main>
<footer>
<p>Powered by CSM and Llama 3.2</p>
</footer>
</body>
</html>

View File

@@ -1,50 +0,0 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
generator = load_csm_1b(device=device)
speakers = [0, 1, 0, 0]
transcripts = [
"Hey how are you doing.",
"Pretty good, pretty good.",
"I'm great.",
"So happy to be speaking to you.",
]
audio_paths = [
"utterance_0.wav",
"utterance_1.wav",
"utterance_2.wav",
"utterance_3.wav",
]
def load_audio(audio_path):
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = torchaudio.functional.resample(
audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
)
return audio_tensor
segments = [
Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]
audio = generator.generate(
text="Me too, this is some cool stuff huh?",
speaker=1,
context=segments,
max_audio_length_ms=10_000,
)
torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)

File diff suppressed because it is too large Load Diff

View File

@@ -1,79 +0,0 @@
import argparse
import silentcipher
import torch
import torchaudio
# This watermark key is public, it is not secure.
# If using CSM 1B in another application, use a new private key and keep it secret.
CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]
def cli_check_audio() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--audio_path", type=str, required=True)
args = parser.parse_args()
check_audio_from_file(args.audio_path)
def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
model = silentcipher.get_model(
model_type="44.1k",
device=device,
)
return model
@torch.inference_mode()
def watermark(
watermarker: silentcipher.server.Model,
audio_array: torch.Tensor,
sample_rate: int,
watermark_key: list[int],
) -> tuple[torch.Tensor, int]:
audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
output_sample_rate = min(44100, sample_rate)
encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
return encoded, output_sample_rate
@torch.inference_mode()
def verify(
watermarker: silentcipher.server.Model,
watermarked_audio: torch.Tensor,
sample_rate: int,
watermark_key: list[int],
) -> bool:
watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
is_watermarked = result["status"]
if is_watermarked:
is_csm_watermarked = result["messages"][0] == watermark_key
else:
is_csm_watermarked = False
return is_watermarked and is_csm_watermarked
def check_audio_from_file(audio_path: str) -> None:
watermarker = load_watermarker(device="cuda")
audio_array, sample_rate = load_audio(audio_path)
is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)
outcome = "Watermarked" if is_watermarked else "Not watermarked"
print(f"{outcome}: {audio_path}")
def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
audio_array, sample_rate = torchaudio.load(audio_path)
audio_array = audio_array.mean(dim=0)
return audio_array, int(sample_rate)
if __name__ == "__main__":
cli_check_audio()