From 5da627097d80c41d27ce139bcdc2c5a411b9f8ad Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sat, 29 Mar 2025 20:36:18 -0400 Subject: [PATCH 1/2] Backend Server Code --- Backend/.gitignore | 46 ++++++ Backend/README.md | 154 ++++++++++++++++++++ Backend/generator.py | 176 +++++++++++++++++++++++ Backend/index.html | 304 +++++++++++++++++++++++++++++++++++++++ Backend/models.py | 203 ++++++++++++++++++++++++++ Backend/requirements.txt | 9 ++ Backend/run_csm.py | 117 +++++++++++++++ Backend/server.py | 156 ++++++++++++++++++++ Backend/setup.py | 13 ++ Backend/test.py | 50 +++++++ Backend/watermarking.py | 79 ++++++++++ 11 files changed, 1307 insertions(+) create mode 100644 Backend/.gitignore create mode 100644 Backend/README.md create mode 100644 Backend/generator.py create mode 100644 Backend/index.html create mode 100644 Backend/models.py create mode 100644 Backend/requirements.txt create mode 100644 Backend/run_csm.py create mode 100644 Backend/server.py create mode 100644 Backend/setup.py create mode 100644 Backend/test.py create mode 100644 Backend/watermarking.py diff --git a/Backend/.gitignore b/Backend/.gitignore new file mode 100644 index 0000000..4b7fc9d --- /dev/null +++ b/Backend/.gitignore @@ -0,0 +1,46 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +.env +.venv +env/ +venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Project specific +.python-version +*.wav +output_*/ +basic_audio.wav +full_conversation.wav +context_audio.wav + +# Model files +*.pt +*.ckpt \ No newline at end of file diff --git a/Backend/README.md b/Backend/README.md new file mode 100644 index 0000000..44cab4d --- /dev/null +++ b/Backend/README.md @@ -0,0 +1,154 @@ +# CSM + +**2025/03/13** - We are releasing the 1B CSM variant. The checkpoint is [hosted on Hugging Face](https://huggingface.co/sesame/csm_1b). + +--- + +CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes. + +A fine-tuned variant of CSM powers the [interactive voice demo](https://www.sesame.com/voicedemo) shown in our [blog post](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). + +A hosted [Hugging Face space](https://huggingface.co/spaces/sesame/csm-1b) is also available for testing audio generation. + +## Requirements + +* A CUDA-compatible GPU +* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions +* Similarly, Python 3.10 is recommended, but newer versions may be fine +* For some audio operations, `ffmpeg` may be required +* Access to the following Hugging Face models: + * [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) + * [CSM-1B](https://huggingface.co/sesame/csm-1b) + +### Setup + +```bash +git clone git@github.com:SesameAILabs/csm.git +cd csm +python3.10 -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt + +# Disable lazy compilation in Mimi +export NO_TORCH_COMPILE=1 + +# You will need access to CSM-1B and Llama-3.2-1B +huggingface-cli login +``` + +### Windows Setup + +The `triton` package cannot be installed in Windows. Instead use `pip install triton-windows`. 
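Put together, the Windows flow looks roughly like this (a sketch assuming PowerShell, Python 3.10 on PATH, and the standard `venv` layout shown above; adjust paths for your environment):

```powershell
git clone git@github.com:SesameAILabs/csm.git
cd csm
python -m venv .venv
.venv\Scripts\Activate.ps1
pip install -r requirements.txt
pip install triton-windows   # triton itself cannot be installed on Windows (see note above)

# Disable lazy compilation in Mimi
$env:NO_TORCH_COMPILE = "1"

# You will need access to CSM-1B and Llama-3.2-1B
huggingface-cli login
```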
+ +## Quickstart + +This script will generate a conversation between 2 characters, using a prompt for each character. + +```bash +python run_csm.py +``` + +## Usage + +If you want to write your own applications with CSM, the following examples show basic usage. + +#### Generate a sentence + +This will use a random speaker identity, as no prompt or context is provided. + +```python +from generator import load_csm_1b +import torchaudio +import torch + +if torch.backends.mps.is_available(): + device = "mps" +elif torch.cuda.is_available(): + device = "cuda" +else: + device = "cpu" + +generator = load_csm_1b(device=device) + +audio = generator.generate( + text="Hello from Sesame.", + speaker=0, + context=[], + max_audio_length_ms=10_000, +) + +torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) +``` + +#### Generate with context + +CSM sounds best when provided with context. You can prompt or provide context to the model using a `Segment` for each speaker's utterance. + +NOTE: The following example is instructional and the audio files do not exist. It is intended as an example for using context with CSM. + +```python +from generator import Segment + +speakers = [0, 1, 0, 0] +transcripts = [ + "Hey how are you doing.", + "Pretty good, pretty good.", + "I'm great.", + "So happy to be speaking to you.", +] +audio_paths = [ + "utterance_0.wav", + "utterance_1.wav", + "utterance_2.wav", + "utterance_3.wav", +] + +def load_audio(audio_path): + audio_tensor, sample_rate = torchaudio.load(audio_path) + audio_tensor = torchaudio.functional.resample( + audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate + ) + return audio_tensor + +segments = [ + Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path)) + for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths) +] +audio = generator.generate( + text="Me too, this is some cool stuff huh?", + speaker=1, + context=segments, + max_audio_length_ms=10_000, +) + +torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) +``` + +## FAQ + +**Does this model come with any voices?** + +The model open-sourced here is a base generation model. It is capable of producing a variety of voices, but it has not been fine-tuned on any specific voice. + +**Can I converse with the model?** + +CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. We suggest using a separate LLM for text generation. + +**Does it support other languages?** + +The model has some capacity for non-English languages due to data contamination in the training data, but it likely won't do well. + +## Misuse and abuse ⚠️ + +This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following: + +- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent. +- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls. +- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes. + +By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology. 
+ +--- + +## Authors +Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team. diff --git a/Backend/generator.py b/Backend/generator.py new file mode 100644 index 0000000..7bc3634 --- /dev/null +++ b/Backend/generator.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass +from typing import List, Tuple + +import torch +import torchaudio +from huggingface_hub import hf_hub_download +from models import Model +from moshi.models import loaders +from tokenizers.processors import TemplateProcessing +from transformers import AutoTokenizer +from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark + + +@dataclass +class Segment: + speaker: int + text: str + # (num_samples,), sample_rate = 24_000 + audio: torch.Tensor + + +def load_llama3_tokenizer(): + """ + https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992 + """ + tokenizer_name = "meta-llama/Llama-3.2-1B" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + bos = tokenizer.bos_token + eos = tokenizer.eos_token + tokenizer._tokenizer.post_processor = TemplateProcessing( + single=f"{bos}:0 $A:0 {eos}:0", + pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1", + special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)], + ) + + return tokenizer + + +class Generator: + def __init__( + self, + model: Model, + ): + self._model = model + self._model.setup_caches(1) + + self._text_tokenizer = load_llama3_tokenizer() + + device = next(model.parameters()).device + mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME) + mimi = loaders.get_mimi(mimi_weight, device=device) + mimi.set_num_codebooks(32) + self._audio_tokenizer = mimi + + self._watermarker = load_watermarker(device=device) + + self.sample_rate = mimi.sample_rate + self.device = device + + def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]: + frame_tokens = [] + frame_masks = [] + + text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}") + text_frame = torch.zeros(len(text_tokens), 33).long() + text_frame_mask = torch.zeros(len(text_tokens), 33).bool() + text_frame[:, -1] = torch.tensor(text_tokens) + text_frame_mask[:, -1] = True + + frame_tokens.append(text_frame.to(self.device)) + frame_masks.append(text_frame_mask.to(self.device)) + + return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0) + + def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert audio.ndim == 1, "Audio must be single channel" + + frame_tokens = [] + frame_masks = [] + + # (K, T) + audio = audio.to(self.device) + audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0] + # add EOS frame + eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device) + audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1) + + audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device) + audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device) + audio_frame[:, :-1] = audio_tokens.transpose(0, 1) + audio_frame_mask[:, :-1] = True + + frame_tokens.append(audio_frame) + frame_masks.append(audio_frame_mask) + + return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0) + + def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns: + (seq_len, 33), (seq_len, 33) + """ + text_tokens, text_masks = self._tokenize_text_segment(segment.text, 
segment.speaker) + audio_tokens, audio_masks = self._tokenize_audio(segment.audio) + + return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0) + + @torch.inference_mode() + def generate( + self, + text: str, + speaker: int, + context: List[Segment], + max_audio_length_ms: float = 90_000, + temperature: float = 0.9, + topk: int = 50, + ) -> torch.Tensor: + self._model.reset_caches() + + max_generation_len = int(max_audio_length_ms / 80) + tokens, tokens_mask = [], [] + for segment in context: + segment_tokens, segment_tokens_mask = self._tokenize_segment(segment) + tokens.append(segment_tokens) + tokens_mask.append(segment_tokens_mask) + + gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker) + tokens.append(gen_segment_tokens) + tokens_mask.append(gen_segment_tokens_mask) + + prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device) + prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device) + + samples = [] + curr_tokens = prompt_tokens.unsqueeze(0) + curr_tokens_mask = prompt_tokens_mask.unsqueeze(0) + curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device) + + max_seq_len = 2048 + max_context_len = max_seq_len - max_generation_len + if curr_tokens.size(1) >= max_context_len: + raise ValueError( + f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}" + ) + + for _ in range(max_generation_len): + sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk) + if torch.all(sample == 0): + break # eos + + samples.append(sample) + + curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1) + curr_tokens_mask = torch.cat( + [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1 + ).unsqueeze(1) + curr_pos = curr_pos[:, -1:] + 1 + + audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0) + + # This applies an imperceptible watermark to identify audio as AI-generated. + # Watermarking ensures transparency, dissuades misuse, and enables traceability. + # Please be a responsible AI citizen and keep the watermarking in place. + # If using CSM 1B in another application, use your own private key and keep it secret. + audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK) + audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate) + + return audio + + +def load_csm_1b(device: str = "cuda") -> Generator: + model = Model.from_pretrained("sesame/csm-1b") + model.to(device=device, dtype=torch.bfloat16) + + generator = Generator(model) + return generator \ No newline at end of file diff --git a/Backend/index.html b/Backend/index.html new file mode 100644 index 0000000..539fe09 --- /dev/null +++ b/Backend/index.html @@ -0,0 +1,304 @@ +/index.html + + + + + + Sesame AI Voice Chat + + + +

[index.html body omitted: markup stripped in extraction; page title "Sesame AI Voice Chat"]
+ + + + \ No newline at end of file diff --git a/Backend/models.py b/Backend/models.py new file mode 100644 index 0000000..e9508e7 --- /dev/null +++ b/Backend/models.py @@ -0,0 +1,203 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torchtune +from huggingface_hub import PyTorchModelHubMixin +from torchtune.models import llama3_2 + + +def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder: + return llama3_2.llama3_2( + vocab_size=128_256, + num_layers=16, + num_heads=32, + num_kv_heads=8, + embed_dim=2048, + max_seq_len=2048, + intermediate_dim=8192, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500_000, + scale_factor=32, + ) + + +def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder: + return llama3_2.llama3_2( + vocab_size=128_256, + num_layers=4, + num_heads=8, + num_kv_heads=2, + embed_dim=1024, + max_seq_len=2048, + intermediate_dim=8192, + attn_dropout=0.0, + norm_eps=1e-5, + rope_base=500_000, + scale_factor=32, + ) + + +FLAVORS = { + "llama-1B": llama3_2_1B, + "llama-100M": llama3_2_100M, +} + + +def _prepare_transformer(model): + embed_dim = model.tok_embeddings.embedding_dim + model.tok_embeddings = nn.Identity() + model.output = nn.Identity() + return model, embed_dim + + +def _create_causal_mask(seq_len: int, device: torch.device): + return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)) + + +def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor): + """ + Args: + mask: (max_seq_len, max_seq_len) + input_pos: (batch_size, seq_len) + + Returns: + (batch_size, seq_len, max_seq_len) + """ + r = mask[input_pos, :] + return r + + +def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs).exponential_(1) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def sample_topk(logits: torch.Tensor, topk: int, temperature: float): + logits = logits / temperature + + filter_value: float = -float("Inf") + indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None] + scores_processed = logits.masked_fill(indices_to_remove, filter_value) + scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1) + probs = torch.nn.functional.softmax(scores_processed, dim=-1) + + sample_token = _multinomial_sample_one_no_sync(probs) + return sample_token + + +@dataclass +class ModelArgs: + backbone_flavor: str + decoder_flavor: str + text_vocab_size: int + audio_vocab_size: int + audio_num_codebooks: int + + +class Model( + nn.Module, + PyTorchModelHubMixin, + repo_url="https://github.com/SesameAILabs/csm", + pipeline_tag="text-to-speech", + license="apache-2.0", +): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + + self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]()) + self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]()) + + self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim) + self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim) + + self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False) + self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False) + self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size)) + + def setup_caches(self, max_batch_size: int) -> torch.Tensor: + """Setup KV caches and return a 
causal mask.""" + dtype = next(self.parameters()).dtype + device = next(self.parameters()).device + + with device: + self.backbone.setup_caches(max_batch_size, dtype) + self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks) + + self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device)) + self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device)) + + def generate_frame( + self, + tokens: torch.Tensor, + tokens_mask: torch.Tensor, + input_pos: torch.Tensor, + temperature: float, + topk: int, + ) -> torch.Tensor: + """ + Args: + tokens: (batch_size, seq_len, audio_num_codebooks+1) + tokens_mask: (batch_size, seq_len, audio_num_codebooks+1) + input_pos: (batch_size, seq_len) positions for each token + mask: (batch_size, seq_len, max_seq_len + + Returns: + (batch_size, audio_num_codebooks) sampled tokens + """ + dtype = next(self.parameters()).dtype + b, s, _ = tokens.size() + + assert self.backbone.caches_are_enabled(), "backbone caches are not enabled" + curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos) + embeds = self._embed_tokens(tokens) + masked_embeds = embeds * tokens_mask.unsqueeze(-1) + h = masked_embeds.sum(dim=2) + h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype) + + last_h = h[:, -1, :] + c0_logits = self.codebook0_head(last_h) + c0_sample = sample_topk(c0_logits, topk, temperature) + c0_embed = self._embed_audio(0, c0_sample) + + curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1) + curr_sample = c0_sample.clone() + curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1) + + # Decoder caches must be reset every frame. 
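        # (The decoder's KV cache is sized to audio_num_codebooks positions and holds only the
        # current frame, so stale entries from the previous frame must be cleared before decoding
        # the remaining codebooks below.)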
+ self.decoder.reset_caches() + for i in range(1, self.config.audio_num_codebooks): + curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos) + decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to( + dtype=dtype + ) + ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1]) + ci_sample = sample_topk(ci_logits, topk, temperature) + ci_embed = self._embed_audio(i, ci_sample) + + curr_h = ci_embed + curr_sample = torch.cat([curr_sample, ci_sample], dim=1) + curr_pos = curr_pos[:, -1:] + 1 + + return curr_sample + + def reset_caches(self): + self.backbone.reset_caches() + self.decoder.reset_caches() + + def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor: + return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size) + + def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2) + + audio_tokens = tokens[:, :, :-1] + ( + self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device) + ) + audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape( + tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1 + ) + + return torch.cat([audio_embeds, text_embeds], dim=-2) diff --git a/Backend/requirements.txt b/Backend/requirements.txt new file mode 100644 index 0000000..ba8a04f --- /dev/null +++ b/Backend/requirements.txt @@ -0,0 +1,9 @@ +torch==2.4.0 +torchaudio==2.4.0 +tokenizers==0.21.0 +transformers==4.49.0 +huggingface_hub==0.28.1 +moshi==0.2.2 +torchtune==0.4.0 +torchao==0.9.0 +silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master \ No newline at end of file diff --git a/Backend/run_csm.py b/Backend/run_csm.py new file mode 100644 index 0000000..0062973 --- /dev/null +++ b/Backend/run_csm.py @@ -0,0 +1,117 @@ +import os +import torch +import torchaudio +from huggingface_hub import hf_hub_download +from generator import load_csm_1b, Segment +from dataclasses import dataclass + +# Disable Triton compilation +os.environ["NO_TORCH_COMPILE"] = "1" + +# Default prompts are available at https://hf.co/sesame/csm-1b +prompt_filepath_conversational_a = hf_hub_download( + repo_id="sesame/csm-1b", + filename="prompts/conversational_a.wav" +) +prompt_filepath_conversational_b = hf_hub_download( + repo_id="sesame/csm-1b", + filename="prompts/conversational_b.wav" +) + +SPEAKER_PROMPTS = { + "conversational_a": { + "text": ( + "like revising for an exam I'd have to try and like keep up the momentum because I'd " + "start really early I'd be like okay I'm gonna start revising now and then like " + "you're revising for ages and then I just like start losing steam I didn't do that " + "for the exam we had recently to be fair that was a more of a last minute scenario " + "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I " + "sort of start the day with this not like a panic but like a" + ), + "audio": prompt_filepath_conversational_a + }, + "conversational_b": { + "text": ( + "like a super Mario level. Like it's very like high detail. And like, once you get " + "into the park, it just like, everything looks like a computer game and they have all " + "these, like, you know, if, if there's like a, you know, like in a Mario game, they " + "will have like a question block. And if you like, you know, punch it, a coin will " + "come out. 
So like everyone, when they come into the park, they get like this little " + "bracelet and then you can go punching question blocks around." + ), + "audio": prompt_filepath_conversational_b + } +} + +def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor: + audio_tensor, sample_rate = torchaudio.load(audio_path) + audio_tensor = audio_tensor.squeeze(0) + # Resample is lazy so we can always call it + audio_tensor = torchaudio.functional.resample( + audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate + ) + return audio_tensor + +def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment: + audio_tensor = load_prompt_audio(audio_path, sample_rate) + return Segment(text=text, speaker=speaker, audio=audio_tensor) + +def main(): + # Select the best available device, skipping MPS due to float64 limitations + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + print(f"Using device: {device}") + + # Load model + generator = load_csm_1b(device) + + # Prepare prompts + prompt_a = prepare_prompt( + SPEAKER_PROMPTS["conversational_a"]["text"], + 0, + SPEAKER_PROMPTS["conversational_a"]["audio"], + generator.sample_rate + ) + + prompt_b = prepare_prompt( + SPEAKER_PROMPTS["conversational_b"]["text"], + 1, + SPEAKER_PROMPTS["conversational_b"]["audio"], + generator.sample_rate + ) + + # Generate conversation + conversation = [ + {"text": "Hey how are you doing?", "speaker_id": 0}, + {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1}, + {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0}, + {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1} + ] + + # Generate each utterance + generated_segments = [] + prompt_segments = [prompt_a, prompt_b] + + for utterance in conversation: + print(f"Generating: {utterance['text']}") + audio_tensor = generator.generate( + text=utterance['text'], + speaker=utterance['speaker_id'], + context=prompt_segments + generated_segments, + max_audio_length_ms=10_000, + ) + generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor)) + + # Concatenate all generations + all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0) + torchaudio.save( + "full_conversation.wav", + all_audio.unsqueeze(0).cpu(), + generator.sample_rate + ) + print("Successfully generated full_conversation.wav") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Backend/server.py b/Backend/server.py new file mode 100644 index 0000000..92d0f1f --- /dev/null +++ b/Backend/server.py @@ -0,0 +1,156 @@ +import os +import base64 +import json +import asyncio +import torch +import torchaudio +import numpy as np +from io import BytesIO +from typing import List, Dict, Any, Optional +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from generator import load_csm_1b, Segment +import uvicorn + +# Select device +if torch.cuda.is_available(): + device = "cuda" +else: + device = "cpu" +print(f"Using device: {device}") + +# Initialize the model +generator = load_csm_1b(device=device) + +app = FastAPI() + +# Add CORS middleware to allow cross-origin requests +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Allow all origins in development + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Connection manager to handle multiple clients +class 
ConnectionManager: + def __init__(self): + self.active_connections: List[WebSocket] = [] + + async def connect(self, websocket: WebSocket): + await websocket.accept() + self.active_connections.append(websocket) + + def disconnect(self, websocket: WebSocket): + self.active_connections.remove(websocket) + +manager = ConnectionManager() + + +# Helper function to convert audio data +async def decode_audio_data(audio_data: str) -> torch.Tensor: + """Decode base64 audio data to a torch tensor""" + # Decode base64 audio data + binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data) + + # Load audio from binary data + buf = BytesIO(binary_data) + audio_tensor, sample_rate = torchaudio.load(buf) + + # Resample if needed + if sample_rate != generator.sample_rate: + audio_tensor = torchaudio.functional.resample( + audio_tensor.squeeze(0), + orig_freq=sample_rate, + new_freq=generator.sample_rate + ) + else: + audio_tensor = audio_tensor.squeeze(0) + + return audio_tensor + + +async def encode_audio_data(audio_tensor: torch.Tensor) -> str: + """Encode torch tensor audio to base64 string""" + buf = BytesIO() + torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav") + buf.seek(0) + audio_base64 = base64.b64encode(buf.read()).decode('utf-8') + return f"data:audio/wav;base64,{audio_base64}" + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await manager.connect(websocket) + context_segments = [] # Store conversation context + + try: + while True: + # Receive JSON data from client + data = await websocket.receive_text() + request = json.loads(data) + + action = request.get("action") + + if action == "generate": + text = request.get("text", "") + speaker_id = request.get("speaker", 0) + + # Generate audio response + print(f"Generating audio for: '{text}' with speaker {speaker_id}") + audio_tensor = generator.generate( + text=text, + speaker=speaker_id, + context=context_segments, + max_audio_length_ms=10_000, + ) + + # Add to conversation context + context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) + + # Convert audio to base64 and send back to client + audio_base64 = await encode_audio_data(audio_tensor) + await websocket.send_json({ + "type": "audio_response", + "audio": audio_base64 + }) + + elif action == "add_to_context": + text = request.get("text", "") + speaker_id = request.get("speaker", 0) + audio_data = request.get("audio", "") + + # Convert received audio to tensor + audio_tensor = await decode_audio_data(audio_data) + + # Add to conversation context + context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) + + await websocket.send_json({ + "type": "context_updated", + "message": "Audio added to context" + }) + + elif action == "clear_context": + context_segments = [] + await websocket.send_json({ + "type": "context_updated", + "message": "Context cleared" + }) + + except WebSocketDisconnect: + manager.disconnect(websocket) + print("Client disconnected") + except Exception as e: + print(f"Error: {str(e)}") + await websocket.send_json({ + "type": "error", + "message": str(e) + }) + manager.disconnect(websocket) + + +if __name__ == "__main__": + uvicorn.run(app, host="localhost", port=8000) \ No newline at end of file diff --git a/Backend/setup.py b/Backend/setup.py new file mode 100644 index 0000000..8eddb95 --- /dev/null +++ b/Backend/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages +import os + +# Read 
requirements from requirements.txt +with open('requirements.txt') as f: + requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')] + +setup( + name='csm', + version='0.1.0', + packages=find_packages(), + install_requires=requirements, +) diff --git a/Backend/test.py b/Backend/test.py new file mode 100644 index 0000000..34735b1 --- /dev/null +++ b/Backend/test.py @@ -0,0 +1,50 @@ +import os +import torch +import torchaudio +from huggingface_hub import hf_hub_download +from generator import load_csm_1b, Segment +from dataclasses import dataclass + +if torch.backends.mps.is_available(): + device = "mps" +elif torch.cuda.is_available(): + device = "cuda" +else: + device = "cpu" + +generator = load_csm_1b(device=device) + +speakers = [0, 1, 0, 0] +transcripts = [ + "Hey how are you doing.", + "Pretty good, pretty good.", + "I'm great.", + "So happy to be speaking to you.", +] +audio_paths = [ + "utterance_0.wav", + "utterance_1.wav", + "utterance_2.wav", + "utterance_3.wav", +] + +def load_audio(audio_path): + audio_tensor, sample_rate = torchaudio.load(audio_path) + audio_tensor = torchaudio.functional.resample( + audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate + ) + return audio_tensor + +segments = [ + Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path)) + for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths) +] + +audio = generator.generate( + text="Me too, this is some cool stuff huh?", + speaker=1, + context=segments, + max_audio_length_ms=10_000, +) + +torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate) \ No newline at end of file diff --git a/Backend/watermarking.py b/Backend/watermarking.py new file mode 100644 index 0000000..093962f --- /dev/null +++ b/Backend/watermarking.py @@ -0,0 +1,79 @@ +import argparse + +import silentcipher +import torch +import torchaudio + +# This watermark key is public, it is not secure. +# If using CSM 1B in another application, use a new private key and keep it secret. 
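# The five integers below are the message that silentcipher embeds in the audio; verify() treats
# audio as CSM-watermarked only when the decoded message matches this list.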
+CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201] + + +def cli_check_audio() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--audio_path", type=str, required=True) + args = parser.parse_args() + + check_audio_from_file(args.audio_path) + + +def load_watermarker(device: str = "cuda") -> silentcipher.server.Model: + model = silentcipher.get_model( + model_type="44.1k", + device=device, + ) + return model + + +@torch.inference_mode() +def watermark( + watermarker: silentcipher.server.Model, + audio_array: torch.Tensor, + sample_rate: int, + watermark_key: list[int], +) -> tuple[torch.Tensor, int]: + audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100) + encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36) + + output_sample_rate = min(44100, sample_rate) + encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate) + return encoded, output_sample_rate + + +@torch.inference_mode() +def verify( + watermarker: silentcipher.server.Model, + watermarked_audio: torch.Tensor, + sample_rate: int, + watermark_key: list[int], +) -> bool: + watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100) + result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True) + + is_watermarked = result["status"] + if is_watermarked: + is_csm_watermarked = result["messages"][0] == watermark_key + else: + is_csm_watermarked = False + + return is_watermarked and is_csm_watermarked + + +def check_audio_from_file(audio_path: str) -> None: + watermarker = load_watermarker(device="cuda") + + audio_array, sample_rate = load_audio(audio_path) + is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK) + + outcome = "Watermarked" if is_watermarked else "Not watermarked" + print(f"{outcome}: {audio_path}") + + +def load_audio(audio_path: str) -> tuple[torch.Tensor, int]: + audio_array, sample_rate = torchaudio.load(audio_path) + audio_array = audio_array.mean(dim=0) + return audio_array, int(sample_rate) + + +if __name__ == "__main__": + cli_check_audio() From 99a6c7d413a9e9e136086db17560f022b845122f Mon Sep 17 00:00:00 2001 From: GamerBoss101 Date: Sat, 29 Mar 2025 21:18:19 -0400 Subject: [PATCH 2/2] Server Py update --- Backend/server.py | 125 +++++++++++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 52 deletions(-) diff --git a/Backend/server.py b/Backend/server.py index 92d0f1f..e8ed1ae 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -52,24 +52,31 @@ manager = ConnectionManager() # Helper function to convert audio data async def decode_audio_data(audio_data: str) -> torch.Tensor: """Decode base64 audio data to a torch tensor""" - # Decode base64 audio data - binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data) - - # Load audio from binary data - buf = BytesIO(binary_data) - audio_tensor, sample_rate = torchaudio.load(buf) - - # Resample if needed - if sample_rate != generator.sample_rate: - audio_tensor = torchaudio.functional.resample( - audio_tensor.squeeze(0), - orig_freq=sample_rate, - new_freq=generator.sample_rate - ) - else: - audio_tensor = audio_tensor.squeeze(0) + try: + # Decode base64 audio data + binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data) - return audio_tensor + # Save to a temporary WAV file first + temp_file = 
BytesIO(binary_data) + + # Load audio from binary data, explicitly specifying the format + audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav") + + # Resample if needed + if sample_rate != generator.sample_rate: + audio_tensor = torchaudio.functional.resample( + audio_tensor.squeeze(0), + orig_freq=sample_rate, + new_freq=generator.sample_rate + ) + else: + audio_tensor = audio_tensor.squeeze(0) + + return audio_tensor + except Exception as e: + print(f"Error decoding audio: {str(e)}") + # Return a small silent audio segment as fallback + return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence async def encode_audio_data(audio_tensor: torch.Tensor) -> str: @@ -95,43 +102,57 @@ async def websocket_endpoint(websocket: WebSocket): action = request.get("action") if action == "generate": - text = request.get("text", "") - speaker_id = request.get("speaker", 0) - - # Generate audio response - print(f"Generating audio for: '{text}' with speaker {speaker_id}") - audio_tensor = generator.generate( - text=text, - speaker=speaker_id, - context=context_segments, - max_audio_length_ms=10_000, - ) - - # Add to conversation context - context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) - - # Convert audio to base64 and send back to client - audio_base64 = await encode_audio_data(audio_tensor) - await websocket.send_json({ - "type": "audio_response", - "audio": audio_base64 - }) + try: + text = request.get("text", "") + speaker_id = request.get("speaker", 0) + + # Generate audio response + print(f"Generating audio for: '{text}' with speaker {speaker_id}") + audio_tensor = generator.generate( + text=text, + speaker=speaker_id, + context=context_segments, + max_audio_length_ms=10_000, + ) + + # Add to conversation context + context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) + + # Convert audio to base64 and send back to client + audio_base64 = await encode_audio_data(audio_tensor) + await websocket.send_json({ + "type": "audio_response", + "audio": audio_base64 + }) + except Exception as e: + print(f"Error generating audio: {str(e)}") + await websocket.send_json({ + "type": "error", + "message": f"Error generating audio: {str(e)}" + }) elif action == "add_to_context": - text = request.get("text", "") - speaker_id = request.get("speaker", 0) - audio_data = request.get("audio", "") - - # Convert received audio to tensor - audio_tensor = await decode_audio_data(audio_data) - - # Add to conversation context - context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) - - await websocket.send_json({ - "type": "context_updated", - "message": "Audio added to context" - }) + try: + text = request.get("text", "") + speaker_id = request.get("speaker", 0) + audio_data = request.get("audio", "") + + # Convert received audio to tensor + audio_tensor = await decode_audio_data(audio_data) + + # Add to conversation context + context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) + + await websocket.send_json({ + "type": "context_updated", + "message": "Audio added to context" + }) + except Exception as e: + print(f"Error adding to context: {str(e)}") + await websocket.send_json({ + "type": "error", + "message": f"Error processing audio: {str(e)}" + }) elif action == "clear_context": context_segments = []