Merge branch 'main' of https://github.com/GamerBoss101/HooHacks-12
46  Backend/.gitignore  vendored  Normal file
@@ -0,0 +1,46 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
.env
.venv
env/
venv/
ENV/

# IDE
.idea/
.vscode/
*.swp
*.swo

# Project specific
.python-version
*.wav
output_*/
basic_audio.wav
full_conversation.wav
context_audio.wav

# Model files
*.pt
*.ckpt
154  Backend/README.md  Normal file
@@ -0,0 +1,154 @@
# CSM

**2025/03/13** - We are releasing the 1B CSM variant. The checkpoint is [hosted on Hugging Face](https://huggingface.co/sesame/csm_1b).

---

CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes.

A fine-tuned variant of CSM powers the [interactive voice demo](https://www.sesame.com/voicedemo) shown in our [blog post](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice).

A hosted [Hugging Face space](https://huggingface.co/spaces/sesame/csm-1b) is also available for testing audio generation.

## Requirements

* A CUDA-compatible GPU
* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions
* Similarly, Python 3.10 is recommended, but newer versions may be fine
* For some audio operations, `ffmpeg` may be required
* Access to the following Hugging Face models:
  * [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
  * [CSM-1B](https://huggingface.co/sesame/csm-1b)

### Setup

```bash
git clone git@github.com:SesameAILabs/csm.git
cd csm
python3.10 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt

# Disable lazy compilation in Mimi
export NO_TORCH_COMPILE=1

# You will need access to CSM-1B and Llama-3.2-1B
huggingface-cli login
```

### Windows Setup

The `triton` package cannot be installed on Windows. Use `pip install triton-windows` instead.

## Quickstart

This script generates a conversation between two characters, using an audio prompt for each character.

```bash
python run_csm.py
```

## Usage

If you want to write your own applications with CSM, the following examples show basic usage.

#### Generate a sentence

This will use a random speaker identity, as no prompt or context is provided.

```python
from generator import load_csm_1b
import torchaudio
import torch

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

generator = load_csm_1b(device=device)

audio = generator.generate(
    text="Hello from Sesame.",
    speaker=0,
    context=[],
    max_audio_length_ms=10_000,
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```

#### Generate with context

CSM sounds best when provided with context. You can prompt or provide context to the model using a `Segment` for each speaker's utterance.

NOTE: The following example is illustrative; the referenced audio files do not exist. It is intended to show how to pass context to CSM.

```python
from generator import Segment

speakers = [0, 1, 0, 0]
transcripts = [
    "Hey how are you doing.",
    "Pretty good, pretty good.",
    "I'm great.",
    "So happy to be speaking to you.",
]
audio_paths = [
    "utterance_0.wav",
    "utterance_1.wav",
    "utterance_2.wav",
    "utterance_3.wav",
]

def load_audio(audio_path):
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = torchaudio.functional.resample(
        audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
    )
    return audio_tensor

segments = [
    Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
    for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]
audio = generator.generate(
    text="Me too, this is some cool stuff huh?",
    speaker=1,
    context=segments,
    max_audio_length_ms=10_000,
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```

## FAQ

**Does this model come with any voices?**

The model open-sourced here is a base generation model. It is capable of producing a variety of voices, but it has not been fine-tuned on any specific voice.
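
If you want a consistent voice, the usual approach (mirroring `run_csm.py` in this commit) is to seed the context with an audio prompt for the speaker. A minimal sketch is below; note that the transcript string is a placeholder and should match the prompt audio (the full prompt transcripts ship alongside the checkpoint):

```python
from huggingface_hub import hf_hub_download
from generator import Segment, load_csm_1b
import torchaudio

generator = load_csm_1b(device="cuda")

# Download one of the reference prompts hosted with the sesame/csm-1b checkpoint
prompt_path = hf_hub_download(repo_id="sesame/csm-1b", filename="prompts/conversational_a.wav")
prompt_audio, sr = torchaudio.load(prompt_path)
prompt_audio = torchaudio.functional.resample(
    prompt_audio.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
)

# Placeholder transcript: replace with the actual transcript of the prompt audio
voice_prompt = Segment(speaker=0, text="<transcript of the prompt audio>", audio=prompt_audio)

audio = generator.generate(
    text="Hello from Sesame.",
    speaker=0,
    context=[voice_prompt],
    max_audio_length_ms=10_000,
)
torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
```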

**Can I converse with the model?**

CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. We suggest using a separate LLM for text generation.
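
As a sketch of that split, an external LLM produces the reply text and CSM only voices it. The `my_llm_reply` function below is a hypothetical placeholder, not part of this repo; only `generator.generate` comes from CSM:

```python
from generator import load_csm_1b

generator = load_csm_1b(device="cuda")

def my_llm_reply(user_text: str) -> str:
    # Hypothetical placeholder: call whatever text LLM you use (local or hosted).
    return "It's sunny and warm here today."

reply_text = my_llm_reply("How's the weather over there?")
audio = generator.generate(
    text=reply_text,
    speaker=1,
    context=[],  # optionally pass prior Segments so the voice stays consistent
    max_audio_length_ms=10_000,
)
```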

**Does it support other languages?**

The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well.

## Misuse and abuse ⚠️

This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:

- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.
- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.
- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.

By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.

---

## Authors
Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.
176  Backend/generator.py  Normal file
@@ -0,0 +1,176 @@
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from models import Model
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark


@dataclass
class Segment:
    speaker: int
    text: str
    # (num_samples,), sample_rate = 24_000
    audio: torch.Tensor


def load_llama3_tokenizer():
    """
    https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
    """
    tokenizer_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
    )

    return tokenizer


class Generator:
    def __init__(
        self,
        model: Model,
    ):
        self._model = model
        self._model.setup_caches(1)

        self._text_tokenizer = load_llama3_tokenizer()

        device = next(model.parameters()).device
        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
        mimi = loaders.get_mimi(mimi_weight, device=device)
        mimi.set_num_codebooks(32)
        self._audio_tokenizer = mimi

        self._watermarker = load_watermarker(device=device)

        self.sample_rate = mimi.sample_rate
        self.device = device

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True

        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        assert audio.ndim == 1, "Audio must be single channel"

        frame_tokens = []
        frame_masks = []

        # (K, T)
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True

        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (seq_len, 33), (seq_len, 33)
        """
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
    ) -> torch.Tensor:
        self._model.reset_caches()

        max_generation_len = int(max_audio_length_ms / 80)
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

        max_seq_len = 2048
        max_context_len = max_seq_len - max_generation_len
        if curr_tokens.size(1) >= max_context_len:
            raise ValueError(
                f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
            )

        for _ in range(max_generation_len):
            sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
            if torch.all(sample == 0):
                break  # eos

            samples.append(sample)

            curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)

        # This applies an imperceptible watermark to identify audio as AI-generated.
        # Watermarking ensures transparency, dissuades misuse, and enables traceability.
        # Please be a responsible AI citizen and keep the watermarking in place.
        # If using CSM 1B in another application, use your own private key and keep it secret.
        audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
        audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

        return audio


def load_csm_1b(device: str = "cuda") -> Generator:
    model = Model.from_pretrained("sesame/csm-1b")
    model.to(device=device, dtype=torch.bfloat16)

    generator = Generator(model)
    return generator
304  Backend/index.html  Normal file
@@ -0,0 +1,304 @@
/index.html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Sesame AI Voice Chat</title>
    <style>
        body {
            font-family: 'Arial', sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
        }
        .conversation {
            border: 1px solid #ccc;
            border-radius: 8px;
            padding: 15px;
            height: 300px;
            overflow-y: auto;
            margin-bottom: 15px;
        }
        .message {
            margin-bottom: 10px;
            padding: 8px;
            border-radius: 8px;
        }
        .user {
            background-color: #e3f2fd;
            text-align: right;
        }
        .ai {
            background-color: #f1f1f1;
        }
        .controls {
            display: flex;
            flex-direction: column;
            gap: 10px;
        }
        .input-row {
            display: flex;
            gap: 10px;
        }
        input[type="text"] {
            flex-grow: 1;
            padding: 8px;
            border-radius: 4px;
            border: 1px solid #ccc;
        }
        button {
            padding: 8px 16px;
            border-radius: 4px;
            border: none;
            background-color: #4CAF50;
            color: white;
            cursor: pointer;
        }
        button:hover {
            background-color: #45a049;
        }
        .recording {
            background-color: #f44336;
        }
        select {
            padding: 8px;
            border-radius: 4px;
            border: 1px solid #ccc;
        }
    </style>
</head>
<body>
    <h1>Sesame AI Voice Chat</h1>
    <div class="conversation" id="conversation"></div>

    <div class="controls">
        <div class="input-row">
            <input type="text" id="textInput" placeholder="Type your message...">
            <select id="speakerSelect">
                <option value="0">Speaker 0</option>
                <option value="1">Speaker 1</option>
            </select>
            <button id="sendText">Send</button>
        </div>

        <div class="input-row">
            <button id="recordAudio">Record Audio</button>
            <button id="clearContext">Clear Context</button>
        </div>
    </div>

    <script>
        let ws;
        let mediaRecorder;
        let audioChunks = [];
        let isRecording = false;

        // DOM elements
        const conversationEl = document.getElementById('conversation');
        const textInputEl = document.getElementById('textInput');
        const speakerSelectEl = document.getElementById('speakerSelect');
        const sendTextBtn = document.getElementById('sendText');
        const recordAudioBtn = document.getElementById('recordAudio');
        const clearContextBtn = document.getElementById('clearContext');

        // Connect to WebSocket
        function connectWebSocket() {
            const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
            const wsUrl = `${wsProtocol}//${window.location.hostname}:8000/ws`;

            ws = new WebSocket(wsUrl);

            ws.onopen = () => {
                console.log('WebSocket connected');
                addSystemMessage('Connected to server');
            };

            ws.onmessage = (event) => {
                const response = JSON.parse(event.data);
                console.log('Received:', response);

                if (response.type === 'audio_response') {
                    // Play audio response
                    const audio = new Audio(response.audio);
                    audio.play();

                    // Add message to conversation
                    addAIMessage(response.audio);
                } else if (response.type === 'error') {
                    addSystemMessage(`Error: ${response.message}`);
                } else if (response.type === 'context_updated') {
                    addSystemMessage(response.message);
                }
            };

            ws.onclose = () => {
                console.log('WebSocket disconnected');
                addSystemMessage('Disconnected from server. Reconnecting...');
                setTimeout(connectWebSocket, 3000);
            };

            ws.onerror = (error) => {
                console.error('WebSocket error:', error);
                addSystemMessage('Connection error');
            };
        }

        // Add message to conversation
        function addUserMessage(text) {
            const messageEl = document.createElement('div');
            messageEl.classList.add('message', 'user');
            messageEl.textContent = text;
            conversationEl.appendChild(messageEl);
            conversationEl.scrollTop = conversationEl.scrollHeight;
        }

        function addAIMessage(audioSrc) {
            const messageEl = document.createElement('div');
            messageEl.classList.add('message', 'ai');

            const audioEl = document.createElement('audio');
            audioEl.controls = true;
            audioEl.src = audioSrc;

            messageEl.appendChild(audioEl);
            conversationEl.appendChild(messageEl);
            conversationEl.scrollTop = conversationEl.scrollHeight;
        }

        function addSystemMessage(text) {
            const messageEl = document.createElement('div');
            messageEl.classList.add('message');
            messageEl.textContent = text;
            conversationEl.appendChild(messageEl);
            conversationEl.scrollTop = conversationEl.scrollHeight;
        }

        // Send text for audio generation
        function sendTextForGeneration() {
            const text = textInputEl.value.trim();
            const speaker = parseInt(speakerSelectEl.value);

            if (!text) return;

            addUserMessage(text);
            textInputEl.value = '';

            const request = {
                action: 'generate',
                text: text,
                speaker: speaker
            };

            ws.send(JSON.stringify(request));
        }

        // Audio recording functions
        async function setupRecording() {
            try {
                const stream = await navigator.mediaDevices.getUserMedia({ audio: true });

                mediaRecorder = new MediaRecorder(stream);

                mediaRecorder.ondataavailable = (event) => {
                    if (event.data.size > 0) {
                        audioChunks.push(event.data);
                    }
                };

                mediaRecorder.onstop = async () => {
                    const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
                    const audioUrl = URL.createObjectURL(audioBlob);

                    // Add audio to conversation
                    addUserMessage('Recorded audio:');
                    const messageEl = document.createElement('div');
                    messageEl.classList.add('message', 'user');

                    const audioEl = document.createElement('audio');
                    audioEl.controls = true;
                    audioEl.src = audioUrl;

                    messageEl.appendChild(audioEl);
                    conversationEl.appendChild(messageEl);

                    // Convert to base64
                    const reader = new FileReader();
                    reader.readAsDataURL(audioBlob);
                    reader.onloadend = () => {
                        const base64Audio = reader.result;
                        const text = textInputEl.value.trim() || "Recorded audio";
                        const speaker = parseInt(speakerSelectEl.value);

                        // Send to server
                        const request = {
                            action: 'add_to_context',
                            text: text,
                            speaker: speaker,
                            audio: base64Audio
                        };

                        ws.send(JSON.stringify(request));
                        textInputEl.value = '';
                    };

                    audioChunks = [];
                    recordAudioBtn.textContent = 'Record Audio';
                    recordAudioBtn.classList.remove('recording');
                };

                console.log('Recording setup completed');
                return true;
            } catch (err) {
                console.error('Error setting up recording:', err);
                addSystemMessage(`Microphone access error: ${err.message}`);
                return false;
            }
        }

        function toggleRecording() {
            if (isRecording) {
                mediaRecorder.stop();
                isRecording = false;
            } else {
                if (!mediaRecorder) {
                    setupRecording().then(success => {
                        if (success) startRecording();
                    });
                } else {
                    startRecording();
                }
            }
        }

        function startRecording() {
            audioChunks = [];
            mediaRecorder.start();
            isRecording = true;
            recordAudioBtn.textContent = 'Stop Recording';
            recordAudioBtn.classList.add('recording');
        }

        // Event listeners
        sendTextBtn.addEventListener('click', sendTextForGeneration);

        textInputEl.addEventListener('keypress', (e) => {
            if (e.key === 'Enter') sendTextForGeneration();
        });

        recordAudioBtn.addEventListener('click', toggleRecording);

        clearContextBtn.addEventListener('click', () => {
            ws.send(JSON.stringify({
                action: 'clear_context'
            }));
        });

        // Initialize
        window.addEventListener('load', () => {
            connectWebSocket();
            setupRecording();
        });
    </script>
</body>
</html>
203  Backend/models.py  Normal file
@@ -0,0 +1,203 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
import torchtune
from huggingface_hub import PyTorchModelHubMixin
from torchtune.models import llama3_2


def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=16,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=2048,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )


def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=4,
        num_heads=8,
        num_kv_heads=2,
        embed_dim=1024,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )


FLAVORS = {
    "llama-1B": llama3_2_1B,
    "llama-100M": llama3_2_100M,
}


def _prepare_transformer(model):
    embed_dim = model.tok_embeddings.embedding_dim
    model.tok_embeddings = nn.Identity()
    model.output = nn.Identity()
    return model, embed_dim


def _create_causal_mask(seq_len: int, device: torch.device):
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    """
    Args:
        mask: (max_seq_len, max_seq_len)
        input_pos: (batch_size, seq_len)

    Returns:
        (batch_size, seq_len, max_seq_len)
    """
    r = mask[input_pos, :]
    return r


def _multinomial_sample_one_no_sync(probs):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)


def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    logits = logits / temperature

    filter_value: float = -float("Inf")
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)
    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)

    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token


@dataclass
class ModelArgs:
    backbone_flavor: str
    decoder_flavor: str
    text_vocab_size: int
    audio_vocab_size: int
    audio_num_codebooks: int


class Model(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/SesameAILabs/csm",
    pipeline_tag="text-to-speech",
    license="apache-2.0",
):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config

        self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
        self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())

        self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
        self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)

        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
        self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))

    def setup_caches(self, max_batch_size: int) -> torch.Tensor:
        """Setup KV caches and return a causal mask."""
        dtype = next(self.parameters()).dtype
        device = next(self.parameters()).device

        with device:
            self.backbone.setup_caches(max_batch_size, dtype)
            self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)

        self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
        self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))

    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
    ) -> torch.Tensor:
        """
        Args:
            tokens: (batch_size, seq_len, audio_num_codebooks+1)
            tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
            input_pos: (batch_size, seq_len) positions for each token
            mask: (batch_size, seq_len, max_seq_len)

        Returns:
            (batch_size, audio_num_codebooks) sampled tokens
        """
        dtype = next(self.parameters()).dtype
        b, s, _ = tokens.size()

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
        embeds = self._embed_tokens(tokens)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2)
        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)

        last_h = h[:, -1, :]
        c0_logits = self.codebook0_head(last_h)
        c0_sample = sample_topk(c0_logits, topk, temperature)
        c0_embed = self._embed_audio(0, c0_sample)

        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)

        # Decoder caches must be reset every frame.
        self.decoder.reset_caches()
        for i in range(1, self.config.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
                dtype=dtype
            )
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            ci_sample = sample_topk(ci_logits, topk, temperature)
            ci_embed = self._embed_audio(i, ci_sample)

            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def reset_caches(self):
        self.backbone.reset_caches()
        self.decoder.reset_caches()

    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)

    def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)

        audio_tokens = tokens[:, :, :-1] + (
            self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
        )

        return torch.cat([audio_embeds, text_embeds], dim=-2)
9  Backend/requirements.txt  Normal file
@@ -0,0 +1,9 @@
torch==2.4.0
torchaudio==2.4.0
tokenizers==0.21.0
transformers==4.49.0
huggingface_hub==0.28.1
moshi==0.2.2
torchtune==0.4.0
torchao==0.9.0
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
117  Backend/run_csm.py  Normal file
@@ -0,0 +1,117 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass

# Disable Triton compilation
os.environ["NO_TORCH_COMPILE"] = "1"

# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_b.wav"
)

SPEAKER_PROMPTS = {
    "conversational_a": {
        "text": (
            "like revising for an exam I'd have to try and like keep up the momentum because I'd "
            "start really early I'd be like okay I'm gonna start revising now and then like "
            "you're revising for ages and then I just like start losing steam I didn't do that "
            "for the exam we had recently to be fair that was a more of a last minute scenario "
            "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
            "sort of start the day with this not like a panic but like a"
        ),
        "audio": prompt_filepath_conversational_a
    },
    "conversational_b": {
        "text": (
            "like a super Mario level. Like it's very like high detail. And like, once you get "
            "into the park, it just like, everything looks like a computer game and they have all "
            "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
            "will have like a question block. And if you like, you know, punch it, a coin will "
            "come out. So like everyone, when they come into the park, they get like this little "
            "bracelet and then you can go punching question blocks around."
        ),
        "audio": prompt_filepath_conversational_b
    }
}

def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = audio_tensor.squeeze(0)
    # Resample is lazy so we can always call it
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
    )
    return audio_tensor

def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
    audio_tensor = load_prompt_audio(audio_path, sample_rate)
    return Segment(text=text, speaker=speaker, audio=audio_tensor)

def main():
    # Select the best available device, skipping MPS due to float64 limitations
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load model
    generator = load_csm_1b(device)

    # Prepare prompts
    prompt_a = prepare_prompt(
        SPEAKER_PROMPTS["conversational_a"]["text"],
        0,
        SPEAKER_PROMPTS["conversational_a"]["audio"],
        generator.sample_rate
    )

    prompt_b = prepare_prompt(
        SPEAKER_PROMPTS["conversational_b"]["text"],
        1,
        SPEAKER_PROMPTS["conversational_b"]["audio"],
        generator.sample_rate
    )

    # Generate conversation
    conversation = [
        {"text": "Hey how are you doing?", "speaker_id": 0},
        {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
        {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
        {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
    ]

    # Generate each utterance
    generated_segments = []
    prompt_segments = [prompt_a, prompt_b]

    for utterance in conversation:
        print(f"Generating: {utterance['text']}")
        audio_tensor = generator.generate(
            text=utterance['text'],
            speaker=utterance['speaker_id'],
            context=prompt_segments + generated_segments,
            max_audio_length_ms=10_000,
        )
        generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))

    # Concatenate all generations
    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
    torchaudio.save(
        "full_conversation.wav",
        all_audio.unsqueeze(0).cpu(),
        generator.sample_rate
    )
    print("Successfully generated full_conversation.wav")

if __name__ == "__main__":
    main()
177  Backend/server.py  Normal file
@@ -0,0 +1,177 @@
import os
import base64
import json
import asyncio
import torch
import torchaudio
import numpy as np
from io import BytesIO
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from generator import load_csm_1b, Segment
import uvicorn

# Select device
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Initialize the model
generator = load_csm_1b(device=device)

app = FastAPI()

# Add CORS middleware to allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins in development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Connection manager to handle multiple clients
class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

manager = ConnectionManager()


# Helper function to convert audio data
async def decode_audio_data(audio_data: str) -> torch.Tensor:
    """Decode base64 audio data to a torch tensor"""
    try:
        # Decode base64 audio data
        binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data)

        # Save to a temporary WAV file first
        temp_file = BytesIO(binary_data)

        # Load audio from binary data, explicitly specifying the format
        audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")

        # Resample if needed
        if sample_rate != generator.sample_rate:
            audio_tensor = torchaudio.functional.resample(
                audio_tensor.squeeze(0),
                orig_freq=sample_rate,
                new_freq=generator.sample_rate
            )
        else:
            audio_tensor = audio_tensor.squeeze(0)

        return audio_tensor
    except Exception as e:
        print(f"Error decoding audio: {str(e)}")
        # Return a small silent audio segment as fallback
        return torch.zeros(generator.sample_rate // 2)  # 0.5 seconds of silence


async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
    """Encode torch tensor audio to base64 string"""
    buf = BytesIO()
    torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
    buf.seek(0)
    audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
    return f"data:audio/wav;base64,{audio_base64}"


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    context_segments = []  # Store conversation context

    try:
        while True:
            # Receive JSON data from client
            data = await websocket.receive_text()
            request = json.loads(data)

            action = request.get("action")

            if action == "generate":
                try:
                    text = request.get("text", "")
                    speaker_id = request.get("speaker", 0)

                    # Generate audio response
                    print(f"Generating audio for: '{text}' with speaker {speaker_id}")
                    audio_tensor = generator.generate(
                        text=text,
                        speaker=speaker_id,
                        context=context_segments,
                        max_audio_length_ms=10_000,
                    )

                    # Add to conversation context
                    context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))

                    # Convert audio to base64 and send back to client
                    audio_base64 = await encode_audio_data(audio_tensor)
                    await websocket.send_json({
                        "type": "audio_response",
                        "audio": audio_base64
                    })
                except Exception as e:
                    print(f"Error generating audio: {str(e)}")
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Error generating audio: {str(e)}"
                    })

            elif action == "add_to_context":
                try:
                    text = request.get("text", "")
                    speaker_id = request.get("speaker", 0)
                    audio_data = request.get("audio", "")

                    # Convert received audio to tensor
                    audio_tensor = await decode_audio_data(audio_data)

                    # Add to conversation context
                    context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))

                    await websocket.send_json({
                        "type": "context_updated",
                        "message": "Audio added to context"
                    })
                except Exception as e:
                    print(f"Error adding to context: {str(e)}")
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Error processing audio: {str(e)}"
                    })

            elif action == "clear_context":
                context_segments = []
                await websocket.send_json({
                    "type": "context_updated",
                    "message": "Context cleared"
                })

    except WebSocketDisconnect:
        manager.disconnect(websocket)
        print("Client disconnected")
    except Exception as e:
        print(f"Error: {str(e)}")
        await websocket.send_json({
            "type": "error",
            "message": str(e)
        })
        manager.disconnect(websocket)


if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)
13  Backend/setup.py  Normal file
@@ -0,0 +1,13 @@
from setuptools import setup, find_packages
import os

# Read requirements from requirements.txt
with open('requirements.txt') as f:
    requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]

setup(
    name='csm',
    version='0.1.0',
    packages=find_packages(),
    install_requires=requirements,
)
50  Backend/test.py  Normal file
@@ -0,0 +1,50 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

generator = load_csm_1b(device=device)

speakers = [0, 1, 0, 0]
transcripts = [
    "Hey how are you doing.",
    "Pretty good, pretty good.",
    "I'm great.",
    "So happy to be speaking to you.",
]
audio_paths = [
    "utterance_0.wav",
    "utterance_1.wav",
    "utterance_2.wav",
    "utterance_3.wav",
]

def load_audio(audio_path):
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = torchaudio.functional.resample(
        audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
    )
    return audio_tensor

segments = [
    Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
    for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]

audio = generator.generate(
    text="Me too, this is some cool stuff huh?",
    speaker=1,
    context=segments,
    max_audio_length_ms=10_000,
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
79  Backend/watermarking.py  Normal file
@@ -0,0 +1,79 @@
import argparse

import silentcipher
import torch
import torchaudio

# This watermark key is public, it is not secure.
# If using CSM 1B in another application, use a new private key and keep it secret.
CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]


def cli_check_audio() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio_path", type=str, required=True)
    args = parser.parse_args()

    check_audio_from_file(args.audio_path)


def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
    model = silentcipher.get_model(
        model_type="44.1k",
        device=device,
    )
    return model


@torch.inference_mode()
def watermark(
    watermarker: silentcipher.server.Model,
    audio_array: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> tuple[torch.Tensor, int]:
    audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
    encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)

    output_sample_rate = min(44100, sample_rate)
    encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
    return encoded, output_sample_rate


@torch.inference_mode()
def verify(
    watermarker: silentcipher.server.Model,
    watermarked_audio: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> bool:
    watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
    result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)

    is_watermarked = result["status"]
    if is_watermarked:
        is_csm_watermarked = result["messages"][0] == watermark_key
    else:
        is_csm_watermarked = False

    return is_watermarked and is_csm_watermarked


def check_audio_from_file(audio_path: str) -> None:
    watermarker = load_watermarker(device="cuda")

    audio_array, sample_rate = load_audio(audio_path)
    is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)

    outcome = "Watermarked" if is_watermarked else "Not watermarked"
    print(f"{outcome}: {audio_path}")


def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
    audio_array, sample_rate = torchaudio.load(audio_path)
    audio_array = audio_array.mean(dim=0)
    return audio_array, int(sample_rate)


if __name__ == "__main__":
    cli_check_audio()