Surya Vemulapalli
2025-03-29 21:37:00 -04:00
11 changed files with 1328 additions and 0 deletions

Backend/.gitignore vendored Normal file
@@ -0,0 +1,46 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual Environment
.env
.venv
env/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# Project specific
.python-version
*.wav
output_*/
basic_audio.wav
full_conversation.wav
context_audio.wav
# Model files
*.pt
*.ckpt

Backend/README.md Normal file
@@ -0,0 +1,154 @@
# CSM
**2025/03/13** - We are releasing the 1B CSM variant. The checkpoint is [hosted on Hugging Face](https://huggingface.co/sesame/csm_1b).
---
CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes.
A fine-tuned variant of CSM powers the [interactive voice demo](https://www.sesame.com/voicedemo) shown in our [blog post](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice).
A hosted [Hugging Face space](https://huggingface.co/spaces/sesame/csm-1b) is also available for testing audio generation.
## Requirements
* A CUDA-compatible GPU
* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions
* Similarly, Python 3.10 is recommended, but newer versions may be fine
* For some audio operations, `ffmpeg` may be required
* Access to the following Hugging Face models:
* [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
* [CSM-1B](https://huggingface.co/sesame/csm-1b)
### Setup
```bash
git clone git@github.com:SesameAILabs/csm.git
cd csm
python3.10 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
# Disable lazy compilation in Mimi
export NO_TORCH_COMPILE=1
# You will need access to CSM-1B and Llama-3.2-1B
huggingface-cli login
```
### Windows Setup
The `triton` package cannot be installed on Windows. Use `pip install triton-windows` instead.
## Quickstart
This script generates a conversation between two speakers, using a prompt for each speaker.
```bash
python run_csm.py
```
## Usage
If you want to write your own applications with CSM, the following examples show basic usage.
#### Generate a sentence
This will use a random speaker identity, as no prompt or context is provided.
```python
from generator import load_csm_1b
import torchaudio
import torch
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
generator = load_csm_1b(device=device)
audio = generator.generate(
text="Hello from Sesame.",
speaker=0,
context=[],
max_audio_length_ms=10_000,
)
torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```
#### Generate with context
CSM sounds best when provided with context. You can prompt or provide context to the model using a `Segment` for each speaker's utterance.
NOTE: The following example is illustrative only; the referenced audio files do not exist. It shows how to provide context to CSM.
```python
from generator import Segment
speakers = [0, 1, 0, 0]
transcripts = [
"Hey how are you doing.",
"Pretty good, pretty good.",
"I'm great.",
"So happy to be speaking to you.",
]
audio_paths = [
"utterance_0.wav",
"utterance_1.wav",
"utterance_2.wav",
"utterance_3.wav",
]
def load_audio(audio_path):
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = torchaudio.functional.resample(
audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
)
return audio_tensor
segments = [
Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]
audio = generator.generate(
text="Me too, this is some cool stuff huh?",
speaker=1,
context=segments,
max_audio_length_ms=10_000,
)
torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```
## FAQ
**Does this model come with any voices?**
The model open-sourced here is a base generation model. It is capable of producing a variety of voices, but it has not been fine-tuned on any specific voice.
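A minimal sketch of steering the voice with an audio prompt, using the reference prompt clip that ships with the checkpoint (the file name and download call are taken from `run_csm.py` in this repo; the transcript placeholder must be replaced with the text that matches your prompt audio):
```python
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment

generator = load_csm_1b(device="cuda")

# Reference prompt clip hosted alongside the checkpoint (see run_csm.py).
prompt_path = hf_hub_download(repo_id="sesame/csm-1b", filename="prompts/conversational_a.wav")
prompt_audio, sr = torchaudio.load(prompt_path)
prompt_audio = torchaudio.functional.resample(
    prompt_audio.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
)

voice_prompt = Segment(
    text="...",  # paste the matching transcript from SPEAKER_PROMPTS in run_csm.py
    speaker=0,
    audio=prompt_audio,
)

audio = generator.generate(
    text="Nice to meet you.",
    speaker=0,  # reuse the prompt's speaker id so the voice carries over
    context=[voice_prompt],
    max_audio_length_ms=10_000,
)
```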
**Can I converse with the model?**
CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. We suggest using a separate LLM for text generation.
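One way to build a conversational loop is to let a separate text LLM produce the reply and hand that text to CSM for speech. A rough sketch, where `reply_with_llm` is a hypothetical placeholder for whatever text model you choose (it is not part of this repo):
```python
import torch
from generator import load_csm_1b, Segment

generator = load_csm_1b(device="cuda")
history: list[Segment] = []

def reply_with_llm(user_text: str) -> str:
    # Placeholder: call your text LLM of choice and return its reply as plain text.
    raise NotImplementedError

def respond(user_text: str, user_audio: torch.Tensor, ai_speaker: int = 1) -> torch.Tensor:
    # Keep both sides of the exchange in the context so CSM hears the full conversation.
    history.append(Segment(text=user_text, speaker=0, audio=user_audio))
    reply = reply_with_llm(user_text)
    audio = generator.generate(
        text=reply,
        speaker=ai_speaker,
        context=history,
        max_audio_length_ms=10_000,
    )
    history.append(Segment(text=reply, speaker=ai_speaker, audio=audio))
    return audio
```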
**Does it support other languages?**
The model has some capacity for non-English languages due to data contamination in the training data, but it is unlikely to perform well.
## Misuse and abuse ⚠️
This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:
- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.
- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.
- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.
By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.
---
## Authors
Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.

Backend/generator.py Normal file
@@ -0,0 +1,176 @@
from dataclasses import dataclass
from typing import List, Tuple
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from models import Model
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
@dataclass
class Segment:
speaker: int
text: str
# (num_samples,), sample_rate = 24_000
audio: torch.Tensor
def load_llama3_tokenizer():
"""
https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
"""
tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.bos_token
eos = tokenizer.eos_token
tokenizer._tokenizer.post_processor = TemplateProcessing(
single=f"{bos}:0 $A:0 {eos}:0",
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
)
return tokenizer
class Generator:
def __init__(
self,
model: Model,
):
self._model = model
self._model.setup_caches(1)
self._text_tokenizer = load_llama3_tokenizer()
device = next(model.parameters()).device
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)
mimi.set_num_codebooks(32)
self._audio_tokenizer = mimi
self._watermarker = load_watermarker(device=device)
self.sample_rate = mimi.sample_rate
self.device = device
def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
frame_tokens = []
frame_masks = []
text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
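# Each frame has 33 columns: 32 Mimi codebook slots followed by a final text-token slot.
# Text frames only fill the last column; audio frames (below) fill the first 32.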
text_frame = torch.zeros(len(text_tokens), 33).long()
text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
text_frame[:, -1] = torch.tensor(text_tokens)
text_frame_mask[:, -1] = True
frame_tokens.append(text_frame.to(self.device))
frame_masks.append(text_frame_mask.to(self.device))
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert audio.ndim == 1, "Audio must be single channel"
frame_tokens = []
frame_masks = []
# (K, T)
audio = audio.to(self.device)
audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
# add EOS frame
eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
audio_frame_mask[:, :-1] = True
frame_tokens.append(audio_frame)
frame_masks.append(audio_frame_mask)
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
(seq_len, 33), (seq_len, 33)
"""
text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
@torch.inference_mode()
def generate(
self,
text: str,
speaker: int,
context: List[Segment],
max_audio_length_ms: float = 90_000,
temperature: float = 0.9,
topk: int = 50,
) -> torch.Tensor:
self._model.reset_caches()
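# One frame is generated per 80 ms of audio, so this caps the number of decoding steps.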
max_generation_len = int(max_audio_length_ms / 80)
tokens, tokens_mask = [], []
for segment in context:
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
tokens.append(segment_tokens)
tokens_mask.append(segment_tokens_mask)
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
tokens.append(gen_segment_tokens)
tokens_mask.append(gen_segment_tokens_mask)
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
samples = []
curr_tokens = prompt_tokens.unsqueeze(0)
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
max_seq_len = 2048
max_context_len = max_seq_len - max_generation_len
if curr_tokens.size(1) >= max_context_len:
raise ValueError(
f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
)
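# Decode frame by frame; an all-zero frame is the end-of-audio signal.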
for _ in range(max_generation_len):
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
if torch.all(sample == 0):
break # eos
samples.append(sample)
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
curr_tokens_mask = torch.cat(
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
).unsqueeze(1)
curr_pos = curr_pos[:, -1:] + 1
audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
# This applies an imperceptible watermark to identify audio as AI-generated.
# Watermarking ensures transparency, dissuades misuse, and enables traceability.
# Please be a responsible AI citizen and keep the watermarking in place.
# If using CSM 1B in another application, use your own private key and keep it secret.
audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
return audio
def load_csm_1b(device: str = "cuda") -> Generator:
model = Model.from_pretrained("sesame/csm-1b")
model.to(device=device, dtype=torch.bfloat16)
generator = Generator(model)
return generator

Backend/index.html Normal file
@@ -0,0 +1,304 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Sesame AI Voice Chat</title>
<style>
body {
font-family: 'Arial', sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
}
.conversation {
border: 1px solid #ccc;
border-radius: 8px;
padding: 15px;
height: 300px;
overflow-y: auto;
margin-bottom: 15px;
}
.message {
margin-bottom: 10px;
padding: 8px;
border-radius: 8px;
}
.user {
background-color: #e3f2fd;
text-align: right;
}
.ai {
background-color: #f1f1f1;
}
.controls {
display: flex;
flex-direction: column;
gap: 10px;
}
.input-row {
display: flex;
gap: 10px;
}
input[type="text"] {
flex-grow: 1;
padding: 8px;
border-radius: 4px;
border: 1px solid #ccc;
}
button {
padding: 8px 16px;
border-radius: 4px;
border: none;
background-color: #4CAF50;
color: white;
cursor: pointer;
}
button:hover {
background-color: #45a049;
}
.recording {
background-color: #f44336;
}
select {
padding: 8px;
border-radius: 4px;
border: 1px solid #ccc;
}
</style>
</head>
<body>
<h1>Sesame AI Voice Chat</h1>
<div class="conversation" id="conversation"></div>
<div class="controls">
<div class="input-row">
<input type="text" id="textInput" placeholder="Type your message...">
<select id="speakerSelect">
<option value="0">Speaker 0</option>
<option value="1">Speaker 1</option>
</select>
<button id="sendText">Send</button>
</div>
<div class="input-row">
<button id="recordAudio">Record Audio</button>
<button id="clearContext">Clear Context</button>
</div>
</div>
<script>
let ws;
let mediaRecorder;
let audioChunks = [];
let isRecording = false;
// DOM elements
const conversationEl = document.getElementById('conversation');
const textInputEl = document.getElementById('textInput');
const speakerSelectEl = document.getElementById('speakerSelect');
const sendTextBtn = document.getElementById('sendText');
const recordAudioBtn = document.getElementById('recordAudio');
const clearContextBtn = document.getElementById('clearContext');
// Connect to WebSocket
function connectWebSocket() {
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${wsProtocol}//${window.location.hostname}:8000/ws`;
ws = new WebSocket(wsUrl);
ws.onopen = () => {
console.log('WebSocket connected');
addSystemMessage('Connected to server');
};
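// The server replies with JSON messages: audio_response (base64 WAV data URL), error, or context_updated.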
ws.onmessage = (event) => {
const response = JSON.parse(event.data);
console.log('Received:', response);
if (response.type === 'audio_response') {
// Play audio response
const audio = new Audio(response.audio);
audio.play();
// Add message to conversation
addAIMessage(response.audio);
} else if (response.type === 'error') {
addSystemMessage(`Error: ${response.message}`);
} else if (response.type === 'context_updated') {
addSystemMessage(response.message);
}
};
ws.onclose = () => {
console.log('WebSocket disconnected');
addSystemMessage('Disconnected from server. Reconnecting...');
setTimeout(connectWebSocket, 3000);
};
ws.onerror = (error) => {
console.error('WebSocket error:', error);
addSystemMessage('Connection error');
};
}
// Add message to conversation
function addUserMessage(text) {
const messageEl = document.createElement('div');
messageEl.classList.add('message', 'user');
messageEl.textContent = text;
conversationEl.appendChild(messageEl);
conversationEl.scrollTop = conversationEl.scrollHeight;
}
function addAIMessage(audioSrc) {
const messageEl = document.createElement('div');
messageEl.classList.add('message', 'ai');
const audioEl = document.createElement('audio');
audioEl.controls = true;
audioEl.src = audioSrc;
messageEl.appendChild(audioEl);
conversationEl.appendChild(messageEl);
conversationEl.scrollTop = conversationEl.scrollHeight;
}
function addSystemMessage(text) {
const messageEl = document.createElement('div');
messageEl.classList.add('message');
messageEl.textContent = text;
conversationEl.appendChild(messageEl);
conversationEl.scrollTop = conversationEl.scrollHeight;
}
// Send text for audio generation
function sendTextForGeneration() {
const text = textInputEl.value.trim();
const speaker = parseInt(speakerSelectEl.value);
if (!text) return;
addUserMessage(text);
textInputEl.value = '';
const request = {
action: 'generate',
text: text,
speaker: speaker
};
ws.send(JSON.stringify(request));
}
// Audio recording functions
async function setupRecording() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = (event) => {
if (event.data.size > 0) {
audioChunks.push(event.data);
}
};
mediaRecorder.onstop = async () => {
const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
const audioUrl = URL.createObjectURL(audioBlob);
// Add audio to conversation
addUserMessage('Recorded audio:');
const messageEl = document.createElement('div');
messageEl.classList.add('message', 'user');
const audioEl = document.createElement('audio');
audioEl.controls = true;
audioEl.src = audioUrl;
messageEl.appendChild(audioEl);
conversationEl.appendChild(messageEl);
// Convert to base64
const reader = new FileReader();
reader.readAsDataURL(audioBlob);
reader.onloadend = () => {
const base64Audio = reader.result;
const text = textInputEl.value.trim() || "Recorded audio";
const speaker = parseInt(speakerSelectEl.value);
// Send to server
const request = {
action: 'add_to_context',
text: text,
speaker: speaker,
audio: base64Audio
};
ws.send(JSON.stringify(request));
textInputEl.value = '';
};
audioChunks = [];
recordAudioBtn.textContent = 'Record Audio';
recordAudioBtn.classList.remove('recording');
};
console.log('Recording setup completed');
return true;
} catch (err) {
console.error('Error setting up recording:', err);
addSystemMessage(`Microphone access error: ${err.message}`);
return false;
}
}
function toggleRecording() {
if (isRecording) {
mediaRecorder.stop();
isRecording = false;
} else {
if (!mediaRecorder) {
setupRecording().then(success => {
if (success) startRecording();
});
} else {
startRecording();
}
}
}
function startRecording() {
audioChunks = [];
mediaRecorder.start();
isRecording = true;
recordAudioBtn.textContent = 'Stop Recording';
recordAudioBtn.classList.add('recording');
}
// Event listeners
sendTextBtn.addEventListener('click', sendTextForGeneration);
textInputEl.addEventListener('keypress', (e) => {
if (e.key === 'Enter') sendTextForGeneration();
});
recordAudioBtn.addEventListener('click', toggleRecording);
clearContextBtn.addEventListener('click', () => {
ws.send(JSON.stringify({
action: 'clear_context'
}));
});
// Initialize
window.addEventListener('load', () => {
connectWebSocket();
setupRecording();
});
</script>
</body>
</html>

Backend/models.py Normal file
@@ -0,0 +1,203 @@
from dataclasses import dataclass
import torch
import torch.nn as nn
import torchtune
from huggingface_hub import PyTorchModelHubMixin
from torchtune.models import llama3_2
def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
return llama3_2.llama3_2(
vocab_size=128_256,
num_layers=16,
num_heads=32,
num_kv_heads=8,
embed_dim=2048,
max_seq_len=2048,
intermediate_dim=8192,
attn_dropout=0.0,
norm_eps=1e-5,
rope_base=500_000,
scale_factor=32,
)
def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
return llama3_2.llama3_2(
vocab_size=128_256,
num_layers=4,
num_heads=8,
num_kv_heads=2,
embed_dim=1024,
max_seq_len=2048,
intermediate_dim=8192,
attn_dropout=0.0,
norm_eps=1e-5,
rope_base=500_000,
scale_factor=32,
)
FLAVORS = {
"llama-1B": llama3_2_1B,
"llama-100M": llama3_2_100M,
}
def _prepare_transformer(model):
embed_dim = model.tok_embeddings.embedding_dim
model.tok_embeddings = nn.Identity()
model.output = nn.Identity()
return model, embed_dim
def _create_causal_mask(seq_len: int, device: torch.device):
return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
"""
Args:
mask: (max_seq_len, max_seq_len)
input_pos: (batch_size, seq_len)
Returns:
(batch_size, seq_len, max_seq_len)
"""
r = mask[input_pos, :]
return r
def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
q = torch.empty_like(probs).exponential_(1)
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
logits = logits / temperature
filter_value: float = -float("Inf")
indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
scores_processed = logits.masked_fill(indices_to_remove, filter_value)
scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
probs = torch.nn.functional.softmax(scores_processed, dim=-1)
sample_token = _multinomial_sample_one_no_sync(probs)
return sample_token
@dataclass
class ModelArgs:
backbone_flavor: str
decoder_flavor: str
text_vocab_size: int
audio_vocab_size: int
audio_num_codebooks: int
class Model(
nn.Module,
PyTorchModelHubMixin,
repo_url="https://github.com/SesameAILabs/csm",
pipeline_tag="text-to-speech",
license="apache-2.0",
):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
def setup_caches(self, max_batch_size: int) -> torch.Tensor:
"""Setup KV caches and return a causal mask."""
dtype = next(self.parameters()).dtype
device = next(self.parameters()).device
with device:
self.backbone.setup_caches(max_batch_size, dtype)
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
def generate_frame(
self,
tokens: torch.Tensor,
tokens_mask: torch.Tensor,
input_pos: torch.Tensor,
temperature: float,
topk: int,
) -> torch.Tensor:
"""
Args:
tokens: (batch_size, seq_len, audio_num_codebooks+1)
tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
input_pos: (batch_size, seq_len) positions for each token
Returns:
(batch_size, audio_num_codebooks) sampled tokens
"""
dtype = next(self.parameters()).dtype
b, s, _ = tokens.size()
assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
embeds = self._embed_tokens(tokens)
masked_embeds = embeds * tokens_mask.unsqueeze(-1)
h = masked_embeds.sum(dim=2)
h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
last_h = h[:, -1, :]
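# Stage 1: the backbone predicts the zeroth codebook from its final hidden state.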
c0_logits = self.codebook0_head(last_h)
c0_sample = sample_topk(c0_logits, topk, temperature)
c0_embed = self._embed_audio(0, c0_sample)
curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
curr_sample = c0_sample.clone()
curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
# Decoder caches must be reset every frame.
self.decoder.reset_caches()
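# Stage 2: the smaller decoder autoregressively predicts the remaining codebooks,
# conditioned on the projected backbone state and the previous codebook embedding.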
for i in range(1, self.config.audio_num_codebooks):
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
dtype=dtype
)
ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
ci_sample = sample_topk(ci_logits, topk, temperature)
ci_embed = self._embed_audio(i, ci_sample)
curr_h = ci_embed
curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
curr_pos = curr_pos[:, -1:] + 1
return curr_sample
def reset_caches(self):
self.backbone.reset_caches()
self.decoder.reset_caches()
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
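# Offset each codebook's token ids into its own slice of the shared audio embedding table.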
audio_tokens = tokens[:, :, :-1] + (
self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
)
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
)
return torch.cat([audio_embeds, text_embeds], dim=-2)

Backend/requirements.txt Normal file
@@ -0,0 +1,9 @@
torch==2.4.0
torchaudio==2.4.0
tokenizers==0.21.0
transformers==4.49.0
huggingface_hub==0.28.1
moshi==0.2.2
torchtune==0.4.0
torchao==0.9.0
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master

Backend/run_csm.py Normal file
@@ -0,0 +1,117 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass
# Disable Triton compilation
os.environ["NO_TORCH_COMPILE"] = "1"
# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
repo_id="sesame/csm-1b",
filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
repo_id="sesame/csm-1b",
filename="prompts/conversational_b.wav"
)
SPEAKER_PROMPTS = {
"conversational_a": {
"text": (
"like revising for an exam I'd have to try and like keep up the momentum because I'd "
"start really early I'd be like okay I'm gonna start revising now and then like "
"you're revising for ages and then I just like start losing steam I didn't do that "
"for the exam we had recently to be fair that was a more of a last minute scenario "
"but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
"sort of start the day with this not like a panic but like a"
),
"audio": prompt_filepath_conversational_a
},
"conversational_b": {
"text": (
"like a super Mario level. Like it's very like high detail. And like, once you get "
"into the park, it just like, everything looks like a computer game and they have all "
"these, like, you know, if, if there's like a, you know, like in a Mario game, they "
"will have like a question block. And if you like, you know, punch it, a coin will "
"come out. So like everyone, when they come into the park, they get like this little "
"bracelet and then you can go punching question blocks around."
),
"audio": prompt_filepath_conversational_b
}
}
def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = audio_tensor.squeeze(0)
# Resampling is a no-op when the sample rates already match, so it is safe to call unconditionally
audio_tensor = torchaudio.functional.resample(
audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
)
return audio_tensor
def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
audio_tensor = load_prompt_audio(audio_path, sample_rate)
return Segment(text=text, speaker=speaker, audio=audio_tensor)
def main():
# Select the best available device, skipping MPS due to float64 limitations
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
print(f"Using device: {device}")
# Load model
generator = load_csm_1b(device)
# Prepare prompts
prompt_a = prepare_prompt(
SPEAKER_PROMPTS["conversational_a"]["text"],
0,
SPEAKER_PROMPTS["conversational_a"]["audio"],
generator.sample_rate
)
prompt_b = prepare_prompt(
SPEAKER_PROMPTS["conversational_b"]["text"],
1,
SPEAKER_PROMPTS["conversational_b"]["audio"],
generator.sample_rate
)
# Generate conversation
conversation = [
{"text": "Hey how are you doing?", "speaker_id": 0},
{"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
{"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
{"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
]
# Generate each utterance
generated_segments = []
prompt_segments = [prompt_a, prompt_b]
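# Condition each utterance on both speaker prompts plus everything generated so far,
# so voices stay consistent across the conversation.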
for utterance in conversation:
print(f"Generating: {utterance['text']}")
audio_tensor = generator.generate(
text=utterance['text'],
speaker=utterance['speaker_id'],
context=prompt_segments + generated_segments,
max_audio_length_ms=10_000,
)
generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))
# Concatenate all generations
all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
torchaudio.save(
"full_conversation.wav",
all_audio.unsqueeze(0).cpu(),
generator.sample_rate
)
print("Successfully generated full_conversation.wav")
if __name__ == "__main__":
main()

Backend/server.py Normal file
@@ -0,0 +1,177 @@
import os
import base64
import json
import asyncio
import torch
import torchaudio
import numpy as np
from io import BytesIO
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from generator import load_csm_1b, Segment
import uvicorn
# Select device
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
print(f"Using device: {device}")
# Initialize the model
generator = load_csm_1b(device=device)
app = FastAPI()
# Add CORS middleware to allow cross-origin requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allow all origins in development
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Connection manager to handle multiple clients
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
manager = ConnectionManager()
# Helper function to convert audio data
async def decode_audio_data(audio_data: str) -> torch.Tensor:
"""Decode base64 audio data to a torch tensor"""
try:
# Decode base64 audio data
binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data)
# Save to a temporary WAV file first
temp_file = BytesIO(binary_data)
# Load audio from binary data, explicitly specifying the format
audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
# Resample if needed
if sample_rate != generator.sample_rate:
audio_tensor = torchaudio.functional.resample(
audio_tensor.squeeze(0),
orig_freq=sample_rate,
new_freq=generator.sample_rate
)
else:
audio_tensor = audio_tensor.squeeze(0)
return audio_tensor
except Exception as e:
print(f"Error decoding audio: {str(e)}")
# Return a small silent audio segment as fallback
return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence
async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
"""Encode torch tensor audio to base64 string"""
buf = BytesIO()
torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
buf.seek(0)
audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
return f"data:audio/wav;base64,{audio_base64}"
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
context_segments = [] # Store conversation context
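# Expected client messages (JSON):
#   {"action": "generate", "text": ..., "speaker": ...}
#   {"action": "add_to_context", "text": ..., "speaker": ..., "audio": <base64 WAV>}
#   {"action": "clear_context"}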
try:
while True:
# Receive JSON data from client
data = await websocket.receive_text()
request = json.loads(data)
action = request.get("action")
if action == "generate":
try:
text = request.get("text", "")
speaker_id = request.get("speaker", 0)
# Generate audio response
print(f"Generating audio for: '{text}' with speaker {speaker_id}")
audio_tensor = generator.generate(
text=text,
speaker=speaker_id,
context=context_segments,
max_audio_length_ms=10_000,
)
# Add to conversation context
context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))
# Convert audio to base64 and send back to client
audio_base64 = await encode_audio_data(audio_tensor)
await websocket.send_json({
"type": "audio_response",
"audio": audio_base64
})
except Exception as e:
print(f"Error generating audio: {str(e)}")
await websocket.send_json({
"type": "error",
"message": f"Error generating audio: {str(e)}"
})
elif action == "add_to_context":
try:
text = request.get("text", "")
speaker_id = request.get("speaker", 0)
audio_data = request.get("audio", "")
# Convert received audio to tensor
audio_tensor = await decode_audio_data(audio_data)
# Add to conversation context
context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))
await websocket.send_json({
"type": "context_updated",
"message": "Audio added to context"
})
except Exception as e:
print(f"Error adding to context: {str(e)}")
await websocket.send_json({
"type": "error",
"message": f"Error processing audio: {str(e)}"
})
elif action == "clear_context":
context_segments = []
await websocket.send_json({
"type": "context_updated",
"message": "Context cleared"
})
except WebSocketDisconnect:
manager.disconnect(websocket)
print("Client disconnected")
except Exception as e:
print(f"Error: {str(e)}")
await websocket.send_json({
"type": "error",
"message": str(e)
})
manager.disconnect(websocket)
if __name__ == "__main__":
uvicorn.run(app, host="localhost", port=8000)

Backend/setup.py Normal file
@@ -0,0 +1,13 @@
from setuptools import setup, find_packages
import os
# Read requirements from requirements.txt
with open('requirements.txt') as f:
requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]
setup(
name='csm',
version='0.1.0',
packages=find_packages(),
install_requires=requirements,
)

Backend/test.py Normal file
@@ -0,0 +1,50 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
generator = load_csm_1b(device=device)
speakers = [0, 1, 0, 0]
transcripts = [
"Hey how are you doing.",
"Pretty good, pretty good.",
"I'm great.",
"So happy to be speaking to you.",
]
audio_paths = [
"utterance_0.wav",
"utterance_1.wav",
"utterance_2.wav",
"utterance_3.wav",
]
def load_audio(audio_path):
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = torchaudio.functional.resample(
audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
)
return audio_tensor
segments = [
Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]
audio = generator.generate(
text="Me too, this is some cool stuff huh?",
speaker=1,
context=segments,
max_audio_length_ms=10_000,
)
torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)

Backend/watermarking.py Normal file
@@ -0,0 +1,79 @@
import argparse
import silentcipher
import torch
import torchaudio
# This watermark key is public; it is not secure.
# If using CSM 1B in another application, use a new private key and keep it secret.
CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]
def cli_check_audio() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--audio_path", type=str, required=True)
args = parser.parse_args()
check_audio_from_file(args.audio_path)
def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
model = silentcipher.get_model(
model_type="44.1k",
device=device,
)
return model
@torch.inference_mode()
def watermark(
watermarker: silentcipher.server.Model,
audio_array: torch.Tensor,
sample_rate: int,
watermark_key: list[int],
) -> tuple[torch.Tensor, int]:
audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
output_sample_rate = min(44100, sample_rate)
encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
return encoded, output_sample_rate
@torch.inference_mode()
def verify(
watermarker: silentcipher.server.Model,
watermarked_audio: torch.Tensor,
sample_rate: int,
watermark_key: list[int],
) -> bool:
watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
is_watermarked = result["status"]
if is_watermarked:
is_csm_watermarked = result["messages"][0] == watermark_key
else:
is_csm_watermarked = False
return is_watermarked and is_csm_watermarked
def check_audio_from_file(audio_path: str) -> None:
watermarker = load_watermarker(device="cuda")
audio_array, sample_rate = load_audio(audio_path)
is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)
outcome = "Watermarked" if is_watermarked else "Not watermarked"
print(f"{outcome}: {audio_path}")
def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
audio_array, sample_rate = torchaudio.load(audio_path)
audio_array = audio_array.mean(dim=0)
return audio_array, int(sample_rate)
if __name__ == "__main__":
cli_check_audio()