Backend Server Code

Backend/.gitignore (vendored, new file, 46 lines)
@@ -0,0 +1,46 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
.env
.venv
env/
venv/
ENV/

# IDE
.idea/
.vscode/
*.swp
*.swo

# Project specific
.python-version
*.wav
output_*/
basic_audio.wav
full_conversation.wav
context_audio.wav

# Model files
*.pt
*.ckpt

Backend/README.md (new file, 154 lines)
@@ -0,0 +1,154 @@
# CSM

**2025/03/13** - We are releasing the 1B CSM variant. The checkpoint is [hosted on Hugging Face](https://huggingface.co/sesame/csm_1b).

---

CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes.

A fine-tuned variant of CSM powers the [interactive voice demo](https://www.sesame.com/voicedemo) shown in our [blog post](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice).

A hosted [Hugging Face space](https://huggingface.co/spaces/sesame/csm-1b) is also available for testing audio generation.

## Requirements

* A CUDA-compatible GPU
* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions
* Similarly, Python 3.10 is recommended, but newer versions may be fine
* For some audio operations, `ffmpeg` may be required
* Access to the following Hugging Face models:
  * [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
  * [CSM-1B](https://huggingface.co/sesame/csm-1b)

### Setup

```bash
git clone git@github.com:SesameAILabs/csm.git
cd csm
python3.10 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt

# Disable lazy compilation in Mimi
export NO_TORCH_COMPILE=1

# You will need access to CSM-1B and Llama-3.2-1B
huggingface-cli login
```

### Windows Setup

The `triton` package cannot be installed on Windows. Instead, use `pip install triton-windows`, as shown below.
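
The same Windows-specific step as a command (the package name comes from the note above; the rest of the setup is unchanged):

```bash
# On Windows, install the triton-windows package instead of triton
pip install triton-windows
```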

## Quickstart

This script generates a conversation between two characters, using a prompt for each character.

```bash
python run_csm.py
```

## Usage

If you want to write your own applications with CSM, the following examples show basic usage.

#### Generate a sentence

This will use a random speaker identity, as no prompt or context is provided.

```python
from generator import load_csm_1b
import torchaudio
import torch

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

generator = load_csm_1b(device=device)

audio = generator.generate(
    text="Hello from Sesame.",
    speaker=0,
    context=[],
    max_audio_length_ms=10_000,
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```

#### Generate with context

CSM sounds best when provided with context. You can prompt or provide context to the model using a `Segment` for each speaker's utterance.

NOTE: The following example is instructional; the audio files it references do not exist. It is intended to show how to provide context to CSM.

```python
from generator import Segment

speakers = [0, 1, 0, 0]
transcripts = [
    "Hey how are you doing.",
    "Pretty good, pretty good.",
    "I'm great.",
    "So happy to be speaking to you.",
]
audio_paths = [
    "utterance_0.wav",
    "utterance_1.wav",
    "utterance_2.wav",
    "utterance_3.wav",
]

def load_audio(audio_path):
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = torchaudio.functional.resample(
        audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
    )
    return audio_tensor

segments = [
    Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
    for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]
audio = generator.generate(
    text="Me too, this is some cool stuff huh?",
    speaker=1,
    context=segments,
    max_audio_length_ms=10_000,
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```

## FAQ

**Does this model come with any voices?**

The model open-sourced here is a base generation model. It is capable of producing a variety of voices, but it has not been fine-tuned on any specific voice.

**Can I converse with the model?**

CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. We suggest using a separate LLM for text generation.
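
As a rough illustration of that suggestion, a text LLM can produce the reply and CSM can voice it. This is a sketch only: `get_llm_reply` is a hypothetical placeholder for whatever text model you use (it is not part of this repository), and `generator` and `torchaudio` are assumed from the examples above.

```python
# Sketch: pair a separate text LLM with CSM, which only produces audio.
def get_llm_reply(user_text: str) -> str:
    # Hypothetical placeholder: call your text-generation model here.
    raise NotImplementedError

reply_text = get_llm_reply("Tell me about Sesame.")
audio = generator.generate(
    text=reply_text,
    speaker=0,
    context=[],  # or pass accumulated Segment objects for better quality
    max_audio_length_ms=10_000,
)
torchaudio.save("reply.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```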

**Does it support other languages?**

The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well.

## Misuse and abuse ⚠️

This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:

- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.
- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.
- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.

By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.

---

## Authors

Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.

Backend/generator.py (new file, 176 lines)
@@ -0,0 +1,176 @@
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from models import Model
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark


@dataclass
class Segment:
    speaker: int
    text: str
    # (num_samples,), sample_rate = 24_000
    audio: torch.Tensor


def load_llama3_tokenizer():
    """
    https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
    """
    tokenizer_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
    )

    return tokenizer


class Generator:
    def __init__(
        self,
        model: Model,
    ):
        self._model = model
        self._model.setup_caches(1)

        self._text_tokenizer = load_llama3_tokenizer()

        device = next(model.parameters()).device
        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
        mimi = loaders.get_mimi(mimi_weight, device=device)
        mimi.set_num_codebooks(32)
        self._audio_tokenizer = mimi

        self._watermarker = load_watermarker(device=device)

        self.sample_rate = mimi.sample_rate
        self.device = device

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True

        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        assert audio.ndim == 1, "Audio must be single channel"

        frame_tokens = []
        frame_masks = []

        # (K, T)
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True

        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (seq_len, 33), (seq_len, 33)
        """
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
    ) -> torch.Tensor:
        self._model.reset_caches()

        max_generation_len = int(max_audio_length_ms / 80)
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

        max_seq_len = 2048
        max_context_len = max_seq_len - max_generation_len
        if curr_tokens.size(1) >= max_context_len:
            raise ValueError(
                f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
            )

        for _ in range(max_generation_len):
            sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
            if torch.all(sample == 0):
                break  # eos

            samples.append(sample)

            curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)

        # This applies an imperceptible watermark to identify audio as AI-generated.
        # Watermarking ensures transparency, dissuades misuse, and enables traceability.
        # Please be a responsible AI citizen and keep the watermarking in place.
        # If using CSM 1B in another application, use your own private key and keep it secret.
        audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
        audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

        return audio


def load_csm_1b(device: str = "cuda") -> Generator:
    model = Model.from_pretrained("sesame/csm-1b")
    model.to(device=device, dtype=torch.bfloat16)

    generator = Generator(model)
    return generator

Backend/index.html (new file, 304 lines)
@@ -0,0 +1,304 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Sesame AI Voice Chat</title>
    <style>
        body {
            font-family: 'Arial', sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
        }
        .conversation {
            border: 1px solid #ccc;
            border-radius: 8px;
            padding: 15px;
            height: 300px;
            overflow-y: auto;
            margin-bottom: 15px;
        }
        .message {
            margin-bottom: 10px;
            padding: 8px;
            border-radius: 8px;
        }
        .user {
            background-color: #e3f2fd;
            text-align: right;
        }
        .ai {
            background-color: #f1f1f1;
        }
        .controls {
            display: flex;
            flex-direction: column;
            gap: 10px;
        }
        .input-row {
            display: flex;
            gap: 10px;
        }
        input[type="text"] {
            flex-grow: 1;
            padding: 8px;
            border-radius: 4px;
            border: 1px solid #ccc;
        }
        button {
            padding: 8px 16px;
            border-radius: 4px;
            border: none;
            background-color: #4CAF50;
            color: white;
            cursor: pointer;
        }
        button:hover {
            background-color: #45a049;
        }
        .recording {
            background-color: #f44336;
        }
        select {
            padding: 8px;
            border-radius: 4px;
            border: 1px solid #ccc;
        }
    </style>
</head>
<body>
    <h1>Sesame AI Voice Chat</h1>
    <div class="conversation" id="conversation"></div>

    <div class="controls">
        <div class="input-row">
            <input type="text" id="textInput" placeholder="Type your message...">
            <select id="speakerSelect">
                <option value="0">Speaker 0</option>
                <option value="1">Speaker 1</option>
            </select>
            <button id="sendText">Send</button>
        </div>

        <div class="input-row">
            <button id="recordAudio">Record Audio</button>
            <button id="clearContext">Clear Context</button>
        </div>
    </div>

    <script>
        let ws;
        let mediaRecorder;
        let audioChunks = [];
        let isRecording = false;

        // DOM elements
        const conversationEl = document.getElementById('conversation');
        const textInputEl = document.getElementById('textInput');
        const speakerSelectEl = document.getElementById('speakerSelect');
        const sendTextBtn = document.getElementById('sendText');
        const recordAudioBtn = document.getElementById('recordAudio');
        const clearContextBtn = document.getElementById('clearContext');

        // Connect to WebSocket
        function connectWebSocket() {
            const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
            const wsUrl = `${wsProtocol}//${window.location.hostname}:8000/ws`;

            ws = new WebSocket(wsUrl);

            ws.onopen = () => {
                console.log('WebSocket connected');
                addSystemMessage('Connected to server');
            };

            ws.onmessage = (event) => {
                const response = JSON.parse(event.data);
                console.log('Received:', response);

                if (response.type === 'audio_response') {
                    // Play audio response
                    const audio = new Audio(response.audio);
                    audio.play();

                    // Add message to conversation
                    addAIMessage(response.audio);
                } else if (response.type === 'error') {
                    addSystemMessage(`Error: ${response.message}`);
                } else if (response.type === 'context_updated') {
                    addSystemMessage(response.message);
                }
            };

            ws.onclose = () => {
                console.log('WebSocket disconnected');
                addSystemMessage('Disconnected from server. Reconnecting...');
                setTimeout(connectWebSocket, 3000);
            };

            ws.onerror = (error) => {
                console.error('WebSocket error:', error);
                addSystemMessage('Connection error');
            };
        }

        // Add message to conversation
        function addUserMessage(text) {
            const messageEl = document.createElement('div');
            messageEl.classList.add('message', 'user');
            messageEl.textContent = text;
            conversationEl.appendChild(messageEl);
            conversationEl.scrollTop = conversationEl.scrollHeight;
        }

        function addAIMessage(audioSrc) {
            const messageEl = document.createElement('div');
            messageEl.classList.add('message', 'ai');

            const audioEl = document.createElement('audio');
            audioEl.controls = true;
            audioEl.src = audioSrc;

            messageEl.appendChild(audioEl);
            conversationEl.appendChild(messageEl);
            conversationEl.scrollTop = conversationEl.scrollHeight;
        }

        function addSystemMessage(text) {
            const messageEl = document.createElement('div');
            messageEl.classList.add('message');
            messageEl.textContent = text;
            conversationEl.appendChild(messageEl);
            conversationEl.scrollTop = conversationEl.scrollHeight;
        }

        // Send text for audio generation
        function sendTextForGeneration() {
            const text = textInputEl.value.trim();
            const speaker = parseInt(speakerSelectEl.value);

            if (!text) return;

            addUserMessage(text);
            textInputEl.value = '';

            const request = {
                action: 'generate',
                text: text,
                speaker: speaker
            };

            ws.send(JSON.stringify(request));
        }

        // Audio recording functions
        async function setupRecording() {
            try {
                const stream = await navigator.mediaDevices.getUserMedia({ audio: true });

                mediaRecorder = new MediaRecorder(stream);

                mediaRecorder.ondataavailable = (event) => {
                    if (event.data.size > 0) {
                        audioChunks.push(event.data);
                    }
                };

                mediaRecorder.onstop = async () => {
                    const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
                    const audioUrl = URL.createObjectURL(audioBlob);

                    // Add audio to conversation
                    addUserMessage('Recorded audio:');
                    const messageEl = document.createElement('div');
                    messageEl.classList.add('message', 'user');

                    const audioEl = document.createElement('audio');
                    audioEl.controls = true;
                    audioEl.src = audioUrl;

                    messageEl.appendChild(audioEl);
                    conversationEl.appendChild(messageEl);

                    // Convert to base64
                    const reader = new FileReader();
                    reader.readAsDataURL(audioBlob);
                    reader.onloadend = () => {
                        const base64Audio = reader.result;
                        const text = textInputEl.value.trim() || "Recorded audio";
                        const speaker = parseInt(speakerSelectEl.value);

                        // Send to server
                        const request = {
                            action: 'add_to_context',
                            text: text,
                            speaker: speaker,
                            audio: base64Audio
                        };

                        ws.send(JSON.stringify(request));
                        textInputEl.value = '';
                    };

                    audioChunks = [];
                    recordAudioBtn.textContent = 'Record Audio';
                    recordAudioBtn.classList.remove('recording');
                };

                console.log('Recording setup completed');
                return true;
            } catch (err) {
                console.error('Error setting up recording:', err);
                addSystemMessage(`Microphone access error: ${err.message}`);
                return false;
            }
        }

        function toggleRecording() {
            if (isRecording) {
                mediaRecorder.stop();
                isRecording = false;
            } else {
                if (!mediaRecorder) {
                    setupRecording().then(success => {
                        if (success) startRecording();
                    });
                } else {
                    startRecording();
                }
            }
        }

        function startRecording() {
            audioChunks = [];
            mediaRecorder.start();
            isRecording = true;
            recordAudioBtn.textContent = 'Stop Recording';
            recordAudioBtn.classList.add('recording');
        }

        // Event listeners
        sendTextBtn.addEventListener('click', sendTextForGeneration);

        textInputEl.addEventListener('keypress', (e) => {
            if (e.key === 'Enter') sendTextForGeneration();
        });

        recordAudioBtn.addEventListener('click', toggleRecording);

        clearContextBtn.addEventListener('click', () => {
            ws.send(JSON.stringify({
                action: 'clear_context'
            }));
        });

        // Initialize
        window.addEventListener('load', () => {
            connectWebSocket();
            setupRecording();
        });
    </script>
</body>
</html>

Backend/models.py (new file, 203 lines)
@@ -0,0 +1,203 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
import torchtune
from huggingface_hub import PyTorchModelHubMixin
from torchtune.models import llama3_2


def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=16,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=2048,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )


def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=4,
        num_heads=8,
        num_kv_heads=2,
        embed_dim=1024,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )


FLAVORS = {
    "llama-1B": llama3_2_1B,
    "llama-100M": llama3_2_100M,
}


def _prepare_transformer(model):
    embed_dim = model.tok_embeddings.embedding_dim
    model.tok_embeddings = nn.Identity()
    model.output = nn.Identity()
    return model, embed_dim


def _create_causal_mask(seq_len: int, device: torch.device):
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    """
    Args:
        mask: (max_seq_len, max_seq_len)
        input_pos: (batch_size, seq_len)

    Returns:
        (batch_size, seq_len, max_seq_len)
    """
    r = mask[input_pos, :]
    return r


def _multinomial_sample_one_no_sync(probs):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)


def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    logits = logits / temperature

    filter_value: float = -float("Inf")
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)
    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)

    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token


@dataclass
class ModelArgs:
    backbone_flavor: str
    decoder_flavor: str
    text_vocab_size: int
    audio_vocab_size: int
    audio_num_codebooks: int


class Model(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/SesameAILabs/csm",
    pipeline_tag="text-to-speech",
    license="apache-2.0",
):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config

        self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
        self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())

        self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
        self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)

        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
        self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))

    def setup_caches(self, max_batch_size: int) -> torch.Tensor:
        """Setup KV caches and return a causal mask."""
        dtype = next(self.parameters()).dtype
        device = next(self.parameters()).device

        with device:
            self.backbone.setup_caches(max_batch_size, dtype)
            self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)

        self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
        self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))

    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
    ) -> torch.Tensor:
        """
        Args:
            tokens: (batch_size, seq_len, audio_num_codebooks+1)
            tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
            input_pos: (batch_size, seq_len) positions for each token
            mask: (batch_size, seq_len, max_seq_len)

        Returns:
            (batch_size, audio_num_codebooks) sampled tokens
        """
        dtype = next(self.parameters()).dtype
        b, s, _ = tokens.size()

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
        embeds = self._embed_tokens(tokens)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2)
        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)

        last_h = h[:, -1, :]
        c0_logits = self.codebook0_head(last_h)
        c0_sample = sample_topk(c0_logits, topk, temperature)
        c0_embed = self._embed_audio(0, c0_sample)

        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)

        # Decoder caches must be reset every frame.
        self.decoder.reset_caches()
        for i in range(1, self.config.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
                dtype=dtype
            )
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            ci_sample = sample_topk(ci_logits, topk, temperature)
            ci_embed = self._embed_audio(i, ci_sample)

            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def reset_caches(self):
        self.backbone.reset_caches()
        self.decoder.reset_caches()

    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)

    def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)

        audio_tokens = tokens[:, :, :-1] + (
            self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
        )

        return torch.cat([audio_embeds, text_embeds], dim=-2)

Backend/requirements.txt (new file, 9 lines)
@@ -0,0 +1,9 @@
torch==2.4.0
torchaudio==2.4.0
tokenizers==0.21.0
transformers==4.49.0
huggingface_hub==0.28.1
moshi==0.2.2
torchtune==0.4.0
torchao==0.9.0
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master

Backend/run_csm.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass

# Disable Triton compilation
os.environ["NO_TORCH_COMPILE"] = "1"

# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_b.wav"
)

SPEAKER_PROMPTS = {
    "conversational_a": {
        "text": (
            "like revising for an exam I'd have to try and like keep up the momentum because I'd "
            "start really early I'd be like okay I'm gonna start revising now and then like "
            "you're revising for ages and then I just like start losing steam I didn't do that "
            "for the exam we had recently to be fair that was a more of a last minute scenario "
            "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
            "sort of start the day with this not like a panic but like a"
        ),
        "audio": prompt_filepath_conversational_a
    },
    "conversational_b": {
        "text": (
            "like a super Mario level. Like it's very like high detail. And like, once you get "
            "into the park, it just like, everything looks like a computer game and they have all "
            "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
            "will have like a question block. And if you like, you know, punch it, a coin will "
            "come out. So like everyone, when they come into the park, they get like this little "
            "bracelet and then you can go punching question blocks around."
        ),
        "audio": prompt_filepath_conversational_b
    }
}

def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = audio_tensor.squeeze(0)
    # Resample is lazy so we can always call it
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
    )
    return audio_tensor

def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
    audio_tensor = load_prompt_audio(audio_path, sample_rate)
    return Segment(text=text, speaker=speaker, audio=audio_tensor)

def main():
    # Select the best available device, skipping MPS due to float64 limitations
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load model
    generator = load_csm_1b(device)

    # Prepare prompts
    prompt_a = prepare_prompt(
        SPEAKER_PROMPTS["conversational_a"]["text"],
        0,
        SPEAKER_PROMPTS["conversational_a"]["audio"],
        generator.sample_rate
    )

    prompt_b = prepare_prompt(
        SPEAKER_PROMPTS["conversational_b"]["text"],
        1,
        SPEAKER_PROMPTS["conversational_b"]["audio"],
        generator.sample_rate
    )

    # Generate conversation
    conversation = [
        {"text": "Hey how are you doing?", "speaker_id": 0},
        {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
        {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
        {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
    ]

    # Generate each utterance
    generated_segments = []
    prompt_segments = [prompt_a, prompt_b]

    for utterance in conversation:
        print(f"Generating: {utterance['text']}")
        audio_tensor = generator.generate(
            text=utterance['text'],
            speaker=utterance['speaker_id'],
            context=prompt_segments + generated_segments,
            max_audio_length_ms=10_000,
        )
        generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))

    # Concatenate all generations
    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
    torchaudio.save(
        "full_conversation.wav",
        all_audio.unsqueeze(0).cpu(),
        generator.sample_rate
    )
    print("Successfully generated full_conversation.wav")

if __name__ == "__main__":
    main()

Backend/server.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import os
import base64
import json
import asyncio
import torch
import torchaudio
import numpy as np
from io import BytesIO
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from generator import load_csm_1b, Segment
import uvicorn

# Select device
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Initialize the model
generator = load_csm_1b(device=device)

app = FastAPI()

# Add CORS middleware to allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins in development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Connection manager to handle multiple clients
class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

manager = ConnectionManager()


# Helper function to convert audio data
async def decode_audio_data(audio_data: str) -> torch.Tensor:
    """Decode base64 audio data to a torch tensor"""
    # Decode base64 audio data
    binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data)

    # Load audio from binary data
    buf = BytesIO(binary_data)
    audio_tensor, sample_rate = torchaudio.load(buf)

    # Resample if needed
    if sample_rate != generator.sample_rate:
        audio_tensor = torchaudio.functional.resample(
            audio_tensor.squeeze(0),
            orig_freq=sample_rate,
            new_freq=generator.sample_rate
        )
    else:
        audio_tensor = audio_tensor.squeeze(0)

    return audio_tensor


async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
    """Encode torch tensor audio to base64 string"""
    buf = BytesIO()
    torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
    buf.seek(0)
    audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
    return f"data:audio/wav;base64,{audio_base64}"


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    context_segments = []  # Store conversation context

    try:
        while True:
            # Receive JSON data from client
            data = await websocket.receive_text()
            request = json.loads(data)

            action = request.get("action")

            if action == "generate":
                text = request.get("text", "")
                speaker_id = request.get("speaker", 0)

                # Generate audio response
                print(f"Generating audio for: '{text}' with speaker {speaker_id}")
                audio_tensor = generator.generate(
                    text=text,
                    speaker=speaker_id,
                    context=context_segments,
                    max_audio_length_ms=10_000,
                )

                # Add to conversation context
                context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))

                # Convert audio to base64 and send back to client
                audio_base64 = await encode_audio_data(audio_tensor)
                await websocket.send_json({
                    "type": "audio_response",
                    "audio": audio_base64
                })

            elif action == "add_to_context":
                text = request.get("text", "")
                speaker_id = request.get("speaker", 0)
                audio_data = request.get("audio", "")

                # Convert received audio to tensor
                audio_tensor = await decode_audio_data(audio_data)

                # Add to conversation context
                context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))

                await websocket.send_json({
                    "type": "context_updated",
                    "message": "Audio added to context"
                })

            elif action == "clear_context":
                context_segments = []
                await websocket.send_json({
                    "type": "context_updated",
                    "message": "Context cleared"
                })

    except WebSocketDisconnect:
        manager.disconnect(websocket)
        print("Client disconnected")
    except Exception as e:
        print(f"Error: {str(e)}")
        await websocket.send_json({
            "type": "error",
            "message": str(e)
        })
        manager.disconnect(websocket)


if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)
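
A minimal client sketch for the /ws endpoint defined in server.py above. It assumes the third-party `websockets` package (not listed in requirements.txt) and a server running locally on port 8000; the action and response fields mirror the handler above.

```python
# Sketch: exercise the WebSocket protocol from server.py.
# Assumes: pip install websockets, and server.py running on localhost:8000.
import asyncio
import json

import websockets


async def main():
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # Ask the server to synthesize a sentence for speaker 0.
        await ws.send(json.dumps({"action": "generate", "text": "Hello from Sesame.", "speaker": 0}))
        response = json.loads(await ws.recv())
        if response["type"] == "audio_response":
            # The audio arrives as a data:audio/wav;base64,... string.
            print("Received audio data URI of length", len(response["audio"]))


asyncio.run(main())
```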

Backend/setup.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from setuptools import setup, find_packages
import os

# Read requirements from requirements.txt
with open('requirements.txt') as f:
    requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]

setup(
    name='csm',
    version='0.1.0',
    packages=find_packages(),
    install_requires=requirements,
)
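
With setup.py in place, the backend can also be installed as a package; a sketch, assuming it is run from the Backend directory with the requirements.txt above alongside it:

```bash
# Editable install; install_requires is read from requirements.txt
pip install -e .
```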

Backend/test.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

generator = load_csm_1b(device=device)

speakers = [0, 1, 0, 0]
transcripts = [
    "Hey how are you doing.",
    "Pretty good, pretty good.",
    "I'm great.",
    "So happy to be speaking to you.",
]
audio_paths = [
    "utterance_0.wav",
    "utterance_1.wav",
    "utterance_2.wav",
    "utterance_3.wav",
]

def load_audio(audio_path):
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = torchaudio.functional.resample(
        audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
    )
    return audio_tensor

segments = [
    Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
    for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]

audio = generator.generate(
    text="Me too, this is some cool stuff huh?",
    speaker=1,
    context=segments,
    max_audio_length_ms=10_000,
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)

Backend/watermarking.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import argparse

import silentcipher
import torch
import torchaudio

# This watermark key is public, it is not secure.
# If using CSM 1B in another application, use a new private key and keep it secret.
CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]


def cli_check_audio() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio_path", type=str, required=True)
    args = parser.parse_args()

    check_audio_from_file(args.audio_path)


def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
    model = silentcipher.get_model(
        model_type="44.1k",
        device=device,
    )
    return model


@torch.inference_mode()
def watermark(
    watermarker: silentcipher.server.Model,
    audio_array: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> tuple[torch.Tensor, int]:
    audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
    encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)

    output_sample_rate = min(44100, sample_rate)
    encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
    return encoded, output_sample_rate


@torch.inference_mode()
def verify(
    watermarker: silentcipher.server.Model,
    watermarked_audio: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> bool:
    watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
    result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)

    is_watermarked = result["status"]
    if is_watermarked:
        is_csm_watermarked = result["messages"][0] == watermark_key
    else:
        is_csm_watermarked = False

    return is_watermarked and is_csm_watermarked


def check_audio_from_file(audio_path: str) -> None:
    watermarker = load_watermarker(device="cuda")

    audio_array, sample_rate = load_audio(audio_path)
    is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)

    outcome = "Watermarked" if is_watermarked else "Not watermarked"
    print(f"{outcome}: {audio_path}")


def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
    audio_array, sample_rate = torchaudio.load(audio_path)
    audio_array = audio_array.mean(dim=0)
    return audio_array, int(sample_rate)


if __name__ == "__main__":
    cli_check_audio()
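
The argparse entry point above can be used to check whether a generated file carries the CSM watermark; a sketch, assuming a CUDA device and a previously generated full_conversation.wav:

```bash
python watermarking.py --audio_path full_conversation.wav
```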