diff --git a/Backend/.gitignore b/Backend/.gitignore
new file mode 100644
index 0000000..4b7fc9d
--- /dev/null
+++ b/Backend/.gitignore
@@ -0,0 +1,46 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+.env
+.venv
+env/
+venv/
+ENV/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# Project specific
+.python-version
+*.wav
+output_*/
+basic_audio.wav
+full_conversation.wav
+context_audio.wav
+
+# Model files
+*.pt
+*.ckpt
\ No newline at end of file
diff --git a/Backend/README.md b/Backend/README.md
new file mode 100644
index 0000000..44cab4d
--- /dev/null
+++ b/Backend/README.md
@@ -0,0 +1,154 @@
+# CSM
+
+**2025/03/13** - We are releasing the 1B CSM variant. The checkpoint is [hosted on Hugging Face](https://huggingface.co/sesame/csm-1b).
+
+---
+
+CSM (Conversational Speech Model) is a speech generation model from [Sesame](https://www.sesame.com) that generates RVQ audio codes from text and audio inputs. The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes.
+
+A fine-tuned variant of CSM powers the [interactive voice demo](https://www.sesame.com/voicedemo) shown in our [blog post](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice).
+
+A hosted [Hugging Face space](https://huggingface.co/spaces/sesame/csm-1b) is also available for testing audio generation.
+
+## Requirements
+
+* A CUDA-compatible GPU
+* The code has been tested on CUDA 12.4 and 12.6, but it may also work on other versions
+* Similarly, Python 3.10 is recommended, but newer versions may be fine
+* For some audio operations, `ffmpeg` may be required
+* Access to the following Hugging Face models:
+ * [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
+ * [CSM-1B](https://huggingface.co/sesame/csm-1b)
+
+### Setup
+
+```bash
+git clone git@github.com:SesameAILabs/csm.git
+cd csm
+python3.10 -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt
+
+# Disable lazy compilation in Mimi
+export NO_TORCH_COMPILE=1
+
+# You will need access to CSM-1B and Llama-3.2-1B
+huggingface-cli login
+```
+
+### Windows Setup
+
+The `triton` package cannot be installed on Windows. Use the `triton-windows` package instead.
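+
+For example:
+
+```bash
+pip install triton-windows
+```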
+
+## Quickstart
+
+This script generates a conversation between two speakers, using a prompt for each speaker.
+
+```bash
+python run_csm.py
+```
+
+## Usage
+
+If you want to write your own applications with CSM, the following examples show basic usage.
+
+#### Generate a sentence
+
+This will use a random speaker identity, as no prompt or context is provided.
+
+```python
+from generator import load_csm_1b
+import torchaudio
+import torch
+
+if torch.backends.mps.is_available():
+ device = "mps"
+elif torch.cuda.is_available():
+ device = "cuda"
+else:
+ device = "cpu"
+
+generator = load_csm_1b(device=device)
+
+audio = generator.generate(
+ text="Hello from Sesame.",
+ speaker=0,
+ context=[],
+ max_audio_length_ms=10_000,
+)
+
+torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
+```
+
+#### Generate with context
+
+CSM sounds best when provided with context. You can prompt or provide context to the model using a `Segment` for each speaker's utterance.
+
+NOTE: The following example is illustrative; the referenced audio files do not exist. It is intended to show how to provide conversational context to CSM.
+
+```python
+from generator import Segment
+
+speakers = [0, 1, 0, 0]
+transcripts = [
+ "Hey how are you doing.",
+ "Pretty good, pretty good.",
+ "I'm great.",
+ "So happy to be speaking to you.",
+]
+audio_paths = [
+ "utterance_0.wav",
+ "utterance_1.wav",
+ "utterance_2.wav",
+ "utterance_3.wav",
+]
+
+def load_audio(audio_path):
+ audio_tensor, sample_rate = torchaudio.load(audio_path)
+ audio_tensor = torchaudio.functional.resample(
+ audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
+ )
+ return audio_tensor
+
+segments = [
+ Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
+ for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
+]
+audio = generator.generate(
+ text="Me too, this is some cool stuff huh?",
+ speaker=1,
+ context=segments,
+ max_audio_length_ms=10_000,
+)
+
+torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
+```
+
+## FAQ
+
+**Does this model come with any voices?**
+
+The model open-sourced here is a base generation model. It is capable of producing a variety of voices, but it has not been fine-tuned on any specific voice.
+
+**Can I converse with the model?**
+
+CSM is trained to be an audio generation model and not a general-purpose multimodal LLM. It cannot generate text. We suggest using a separate LLM for text generation.
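+
+A minimal sketch of that pattern, assuming the `generator` from the examples above is already loaded and that a locally available instruction-tuned LLM produces the text (the checkpoint name below is only an illustration):
+
+```python
+from transformers import pipeline
+
+# Hypothetical choice of LLM; any text-generation model can fill this role.
+llm = pipeline("text-generation", model="meta-llama/Llama-3.2-1B-Instruct")
+
+prompt = "Reply in one short sentence: how is your day going?"
+reply = llm(prompt, max_new_tokens=40, return_full_text=False)[0]["generated_text"]
+
+# Speak the LLM's reply with CSM.
+reply_audio = generator.generate(
+    text=reply,
+    speaker=0,
+    context=[],
+    max_audio_length_ms=10_000,
+)
+torchaudio.save("reply.wav", reply_audio.unsqueeze(0).cpu(), generator.sample_rate)
+```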
+
+**Does it support other languages?**
+
+The model has some capacity for non-English languages due to data contamination in the training data, but it is unlikely to perform well in them.
+
+## Misuse and abuse ⚠️
+
+This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:
+
+- **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.
+- **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.
+- **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.
+
+By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.
+
+---
+
+## Authors
+Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.
diff --git a/Backend/generator.py b/Backend/generator.py
new file mode 100644
index 0000000..7bc3634
--- /dev/null
+++ b/Backend/generator.py
@@ -0,0 +1,176 @@
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import torch
+import torchaudio
+from huggingface_hub import hf_hub_download
+from models import Model
+from moshi.models import loaders
+from tokenizers.processors import TemplateProcessing
+from transformers import AutoTokenizer
+from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
+
+
+@dataclass
+class Segment:
+ speaker: int
+ text: str
+ # (num_samples,), sample_rate = 24_000
+ audio: torch.Tensor
+
+
+def load_llama3_tokenizer():
+ """
+ https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
+ """
+ tokenizer_name = "meta-llama/Llama-3.2-1B"
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+ bos = tokenizer.bos_token
+ eos = tokenizer.eos_token
+ tokenizer._tokenizer.post_processor = TemplateProcessing(
+ single=f"{bos}:0 $A:0 {eos}:0",
+ pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
+ special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
+ )
+
+ return tokenizer
+
+
+class Generator:
+ def __init__(
+ self,
+ model: Model,
+ ):
+ self._model = model
+ self._model.setup_caches(1)
+
+ self._text_tokenizer = load_llama3_tokenizer()
+
+ device = next(model.parameters()).device
+ mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
+ mimi = loaders.get_mimi(mimi_weight, device=device)
+ mimi.set_num_codebooks(32)
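+        # Mimi is configured for 32 codebooks; with one text column this yields the 33-wide frames built by the tokenization helpers below.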
+ self._audio_tokenizer = mimi
+
+ self._watermarker = load_watermarker(device=device)
+
+ self.sample_rate = mimi.sample_rate
+ self.device = device
+
+ def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ frame_tokens = []
+ frame_masks = []
+
+ text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
+ text_frame = torch.zeros(len(text_tokens), 33).long()
+ text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
+ text_frame[:, -1] = torch.tensor(text_tokens)
+ text_frame_mask[:, -1] = True
+
+ frame_tokens.append(text_frame.to(self.device))
+ frame_masks.append(text_frame_mask.to(self.device))
+
+ return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
+
+ def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert audio.ndim == 1, "Audio must be single channel"
+
+ frame_tokens = []
+ frame_masks = []
+
+ # (K, T)
+ audio = audio.to(self.device)
+ audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
+ # add EOS frame
+ eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
+ audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
+
+ audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
+ audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
+ audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
+ audio_frame_mask[:, :-1] = True
+
+ frame_tokens.append(audio_frame)
+ frame_masks.append(audio_frame_mask)
+
+ return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
+
+ def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Returns:
+ (seq_len, 33), (seq_len, 33)
+ """
+ text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
+ audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
+
+ return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
+
+ @torch.inference_mode()
+ def generate(
+ self,
+ text: str,
+ speaker: int,
+ context: List[Segment],
+ max_audio_length_ms: float = 90_000,
+ temperature: float = 0.9,
+ topk: int = 50,
+ ) -> torch.Tensor:
+ self._model.reset_caches()
+
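+        # Each generated frame corresponds to 80 ms of audio, so convert the length budget from milliseconds to frames.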
+ max_generation_len = int(max_audio_length_ms / 80)
+ tokens, tokens_mask = [], []
+ for segment in context:
+ segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
+ tokens.append(segment_tokens)
+ tokens_mask.append(segment_tokens_mask)
+
+ gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
+ tokens.append(gen_segment_tokens)
+ tokens_mask.append(gen_segment_tokens_mask)
+
+ prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
+ prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
+
+ samples = []
+ curr_tokens = prompt_tokens.unsqueeze(0)
+ curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
+ curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
+
+ max_seq_len = 2048
+ max_context_len = max_seq_len - max_generation_len
+ if curr_tokens.size(1) >= max_context_len:
+ raise ValueError(
+ f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
+ )
+
+ for _ in range(max_generation_len):
+ sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
+ if torch.all(sample == 0):
+ break # eos
+
+ samples.append(sample)
+
+ curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
+ curr_tokens_mask = torch.cat(
+ [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
+ ).unsqueeze(1)
+ curr_pos = curr_pos[:, -1:] + 1
+
+ audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
+
+ # This applies an imperceptible watermark to identify audio as AI-generated.
+ # Watermarking ensures transparency, dissuades misuse, and enables traceability.
+ # Please be a responsible AI citizen and keep the watermarking in place.
+ # If using CSM 1B in another application, use your own private key and keep it secret.
+ audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
+ audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
+
+ return audio
+
+
+def load_csm_1b(device: str = "cuda") -> Generator:
+ model = Model.from_pretrained("sesame/csm-1b")
+ model.to(device=device, dtype=torch.bfloat16)
+
+ generator = Generator(model)
+ return generator
\ No newline at end of file
diff --git a/Backend/index.html b/Backend/index.html
new file mode 100644
index 0000000..539fe09
--- /dev/null
+++ b/Backend/index.html
@@ -0,0 +1,304 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="utf-8">
+    <title>Sesame AI Voice Chat</title>
+</head>
+<body>
+    <h1>Sesame AI Voice Chat</h1>
+</body>
+</html>
diff --git a/Backend/models.py b/Backend/models.py
new file mode 100644
index 0000000..e9508e7
--- /dev/null
+++ b/Backend/models.py
@@ -0,0 +1,203 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torchtune
+from huggingface_hub import PyTorchModelHubMixin
+from torchtune.models import llama3_2
+
+
+def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
+ return llama3_2.llama3_2(
+ vocab_size=128_256,
+ num_layers=16,
+ num_heads=32,
+ num_kv_heads=8,
+ embed_dim=2048,
+ max_seq_len=2048,
+ intermediate_dim=8192,
+ attn_dropout=0.0,
+ norm_eps=1e-5,
+ rope_base=500_000,
+ scale_factor=32,
+ )
+
+
+def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
+ return llama3_2.llama3_2(
+ vocab_size=128_256,
+ num_layers=4,
+ num_heads=8,
+ num_kv_heads=2,
+ embed_dim=1024,
+ max_seq_len=2048,
+ intermediate_dim=8192,
+ attn_dropout=0.0,
+ norm_eps=1e-5,
+ rope_base=500_000,
+ scale_factor=32,
+ )
+
+
+FLAVORS = {
+ "llama-1B": llama3_2_1B,
+ "llama-100M": llama3_2_100M,
+}
+
+
+def _prepare_transformer(model):
+ embed_dim = model.tok_embeddings.embedding_dim
+ model.tok_embeddings = nn.Identity()
+ model.output = nn.Identity()
+ return model, embed_dim
+
+
+def _create_causal_mask(seq_len: int, device: torch.device):
+ return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
+
+
+def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
+ """
+ Args:
+ mask: (max_seq_len, max_seq_len)
+ input_pos: (batch_size, seq_len)
+
+ Returns:
+ (batch_size, seq_len, max_seq_len)
+ """
+ r = mask[input_pos, :]
+ return r
+
+
+def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
+ q = torch.empty_like(probs).exponential_(1)
+ return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
+ logits = logits / temperature
+
+ filter_value: float = -float("Inf")
+ indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
+ scores_processed = logits.masked_fill(indices_to_remove, filter_value)
+ scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
+ probs = torch.nn.functional.softmax(scores_processed, dim=-1)
+
+ sample_token = _multinomial_sample_one_no_sync(probs)
+ return sample_token
+
+
+@dataclass
+class ModelArgs:
+ backbone_flavor: str
+ decoder_flavor: str
+ text_vocab_size: int
+ audio_vocab_size: int
+ audio_num_codebooks: int
+
+
+class Model(
+ nn.Module,
+ PyTorchModelHubMixin,
+ repo_url="https://github.com/SesameAILabs/csm",
+ pipeline_tag="text-to-speech",
+ license="apache-2.0",
+):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.config = config
+
+ self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
+ self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
+
+ self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
+ self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
+
+ self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
+ self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
+ self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
+
+    def setup_caches(self, max_batch_size: int) -> None:
+        """Set up KV caches and register the causal-mask buffers."""
+ dtype = next(self.parameters()).dtype
+ device = next(self.parameters()).device
+
+ with device:
+ self.backbone.setup_caches(max_batch_size, dtype)
+ self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
+
+ self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
+ self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
+
+ def generate_frame(
+ self,
+ tokens: torch.Tensor,
+ tokens_mask: torch.Tensor,
+ input_pos: torch.Tensor,
+ temperature: float,
+ topk: int,
+ ) -> torch.Tensor:
+ """
+ Args:
+ tokens: (batch_size, seq_len, audio_num_codebooks+1)
+ tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
+ input_pos: (batch_size, seq_len) positions for each token
+
+ Returns:
+ (batch_size, audio_num_codebooks) sampled tokens
+ """
+ dtype = next(self.parameters()).dtype
+ b, s, _ = tokens.size()
+
+ assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
+ curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
+ embeds = self._embed_tokens(tokens)
+ masked_embeds = embeds * tokens_mask.unsqueeze(-1)
+ h = masked_embeds.sum(dim=2)
+ h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
+
+ last_h = h[:, -1, :]
+ c0_logits = self.codebook0_head(last_h)
+ c0_sample = sample_topk(c0_logits, topk, temperature)
+ c0_embed = self._embed_audio(0, c0_sample)
+
+ curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
+ curr_sample = c0_sample.clone()
+ curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
+
+ # Decoder caches must be reset every frame.
+ self.decoder.reset_caches()
+ for i in range(1, self.config.audio_num_codebooks):
+ curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
+ decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
+ dtype=dtype
+ )
+ ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
+ ci_sample = sample_topk(ci_logits, topk, temperature)
+ ci_embed = self._embed_audio(i, ci_sample)
+
+ curr_h = ci_embed
+ curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
+ curr_pos = curr_pos[:, -1:] + 1
+
+ return curr_sample
+
+ def reset_caches(self):
+ self.backbone.reset_caches()
+ self.decoder.reset_caches()
+
+ def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
+ return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
+
+ def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
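+        # Token IDs for codebook i are offset by i * audio_vocab_size so all codebooks share a single embedding table.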
+ text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
+
+ audio_tokens = tokens[:, :, :-1] + (
+ self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
+ )
+ audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
+ tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
+ )
+
+ return torch.cat([audio_embeds, text_embeds], dim=-2)
diff --git a/Backend/requirements.txt b/Backend/requirements.txt
new file mode 100644
index 0000000..ba8a04f
--- /dev/null
+++ b/Backend/requirements.txt
@@ -0,0 +1,9 @@
+torch==2.4.0
+torchaudio==2.4.0
+tokenizers==0.21.0
+transformers==4.49.0
+huggingface_hub==0.28.1
+moshi==0.2.2
+torchtune==0.4.0
+torchao==0.9.0
+silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
\ No newline at end of file
diff --git a/Backend/run_csm.py b/Backend/run_csm.py
new file mode 100644
index 0000000..0062973
--- /dev/null
+++ b/Backend/run_csm.py
@@ -0,0 +1,117 @@
+import os
+import torch
+import torchaudio
+from huggingface_hub import hf_hub_download
+from generator import load_csm_1b, Segment
+from dataclasses import dataclass
+
+# Disable Triton compilation
+os.environ["NO_TORCH_COMPILE"] = "1"
+
+# Default prompts are available at https://hf.co/sesame/csm-1b
+prompt_filepath_conversational_a = hf_hub_download(
+ repo_id="sesame/csm-1b",
+ filename="prompts/conversational_a.wav"
+)
+prompt_filepath_conversational_b = hf_hub_download(
+ repo_id="sesame/csm-1b",
+ filename="prompts/conversational_b.wav"
+)
+
+SPEAKER_PROMPTS = {
+ "conversational_a": {
+ "text": (
+ "like revising for an exam I'd have to try and like keep up the momentum because I'd "
+ "start really early I'd be like okay I'm gonna start revising now and then like "
+ "you're revising for ages and then I just like start losing steam I didn't do that "
+ "for the exam we had recently to be fair that was a more of a last minute scenario "
+ "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
+ "sort of start the day with this not like a panic but like a"
+ ),
+ "audio": prompt_filepath_conversational_a
+ },
+ "conversational_b": {
+ "text": (
+ "like a super Mario level. Like it's very like high detail. And like, once you get "
+ "into the park, it just like, everything looks like a computer game and they have all "
+ "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
+ "will have like a question block. And if you like, you know, punch it, a coin will "
+ "come out. So like everyone, when they come into the park, they get like this little "
+ "bracelet and then you can go punching question blocks around."
+ ),
+ "audio": prompt_filepath_conversational_b
+ }
+}
+
+def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
+ audio_tensor, sample_rate = torchaudio.load(audio_path)
+ audio_tensor = audio_tensor.squeeze(0)
+    # Resampling is a no-op when the sample rates already match, so it is safe to call unconditionally
+ audio_tensor = torchaudio.functional.resample(
+ audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
+ )
+ return audio_tensor
+
+def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
+ audio_tensor = load_prompt_audio(audio_path, sample_rate)
+ return Segment(text=text, speaker=speaker, audio=audio_tensor)
+
+def main():
+ # Select the best available device, skipping MPS due to float64 limitations
+ if torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ print(f"Using device: {device}")
+
+ # Load model
+ generator = load_csm_1b(device)
+
+ # Prepare prompts
+ prompt_a = prepare_prompt(
+ SPEAKER_PROMPTS["conversational_a"]["text"],
+ 0,
+ SPEAKER_PROMPTS["conversational_a"]["audio"],
+ generator.sample_rate
+ )
+
+ prompt_b = prepare_prompt(
+ SPEAKER_PROMPTS["conversational_b"]["text"],
+ 1,
+ SPEAKER_PROMPTS["conversational_b"]["audio"],
+ generator.sample_rate
+ )
+
+ # Generate conversation
+ conversation = [
+ {"text": "Hey how are you doing?", "speaker_id": 0},
+ {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
+ {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
+ {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
+ ]
+
+ # Generate each utterance
+ generated_segments = []
+ prompt_segments = [prompt_a, prompt_b]
+
+ for utterance in conversation:
+ print(f"Generating: {utterance['text']}")
+ audio_tensor = generator.generate(
+ text=utterance['text'],
+ speaker=utterance['speaker_id'],
+ context=prompt_segments + generated_segments,
+ max_audio_length_ms=10_000,
+ )
+ generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))
+
+ # Concatenate all generations
+ all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
+ torchaudio.save(
+ "full_conversation.wav",
+ all_audio.unsqueeze(0).cpu(),
+ generator.sample_rate
+ )
+ print("Successfully generated full_conversation.wav")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/Backend/server.py b/Backend/server.py
new file mode 100644
index 0000000..e8ed1ae
--- /dev/null
+++ b/Backend/server.py
@@ -0,0 +1,177 @@
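+"""FastAPI WebSocket server that exposes CSM speech generation to a browser client.
+
+Clients connect to ws://localhost:8000/ws and exchange JSON messages:
+  - {"action": "generate", "text": ..., "speaker": ...} returns {"type": "audio_response", "audio": <base64 WAV data URL>}
+  - {"action": "add_to_context", "text": ..., "speaker": ..., "audio": <base64 WAV>} returns {"type": "context_updated", ...}
+  - {"action": "clear_context"} returns {"type": "context_updated", ...}
+Errors are reported as {"type": "error", "message": ...}.
+"""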
+import os
+import base64
+import json
+import asyncio
+import torch
+import torchaudio
+import numpy as np
+from io import BytesIO
+from typing import List, Dict, Any, Optional
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from generator import load_csm_1b, Segment
+import uvicorn
+
+# Select device
+if torch.cuda.is_available():
+ device = "cuda"
+else:
+ device = "cpu"
+print(f"Using device: {device}")
+
+# Initialize the model
+generator = load_csm_1b(device=device)
+
+app = FastAPI()
+
+# Add CORS middleware to allow cross-origin requests
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"], # Allow all origins in development
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+# Connection manager to handle multiple clients
+class ConnectionManager:
+ def __init__(self):
+ self.active_connections: List[WebSocket] = []
+
+ async def connect(self, websocket: WebSocket):
+ await websocket.accept()
+ self.active_connections.append(websocket)
+
+ def disconnect(self, websocket: WebSocket):
+ self.active_connections.remove(websocket)
+
+manager = ConnectionManager()
+
+
+# Helper function to convert audio data
+async def decode_audio_data(audio_data: str) -> torch.Tensor:
+ """Decode base64 audio data to a torch tensor"""
+ try:
+ # Decode base64 audio data
+ binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data)
+
+ # Save to a temporary WAV file first
+ temp_file = BytesIO(binary_data)
+
+ # Load audio from binary data, explicitly specifying the format
+ audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav")
+
+ # Resample if needed
+ if sample_rate != generator.sample_rate:
+ audio_tensor = torchaudio.functional.resample(
+ audio_tensor.squeeze(0),
+ orig_freq=sample_rate,
+ new_freq=generator.sample_rate
+ )
+ else:
+ audio_tensor = audio_tensor.squeeze(0)
+
+ return audio_tensor
+ except Exception as e:
+ print(f"Error decoding audio: {str(e)}")
+ # Return a small silent audio segment as fallback
+ return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence
+
+
+async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
+ """Encode torch tensor audio to base64 string"""
+ buf = BytesIO()
+ torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
+ buf.seek(0)
+ audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
+ return f"data:audio/wav;base64,{audio_base64}"
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+ await manager.connect(websocket)
+ context_segments = [] # Store conversation context
+
+ try:
+ while True:
+ # Receive JSON data from client
+ data = await websocket.receive_text()
+ request = json.loads(data)
+
+ action = request.get("action")
+
+ if action == "generate":
+ try:
+ text = request.get("text", "")
+ speaker_id = request.get("speaker", 0)
+
+ # Generate audio response
+ print(f"Generating audio for: '{text}' with speaker {speaker_id}")
+ audio_tensor = generator.generate(
+ text=text,
+ speaker=speaker_id,
+ context=context_segments,
+ max_audio_length_ms=10_000,
+ )
+
+ # Add to conversation context
+ context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))
+
+ # Convert audio to base64 and send back to client
+ audio_base64 = await encode_audio_data(audio_tensor)
+ await websocket.send_json({
+ "type": "audio_response",
+ "audio": audio_base64
+ })
+ except Exception as e:
+ print(f"Error generating audio: {str(e)}")
+ await websocket.send_json({
+ "type": "error",
+ "message": f"Error generating audio: {str(e)}"
+ })
+
+ elif action == "add_to_context":
+ try:
+ text = request.get("text", "")
+ speaker_id = request.get("speaker", 0)
+ audio_data = request.get("audio", "")
+
+ # Convert received audio to tensor
+ audio_tensor = await decode_audio_data(audio_data)
+
+ # Add to conversation context
+ context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))
+
+ await websocket.send_json({
+ "type": "context_updated",
+ "message": "Audio added to context"
+ })
+ except Exception as e:
+ print(f"Error adding to context: {str(e)}")
+ await websocket.send_json({
+ "type": "error",
+ "message": f"Error processing audio: {str(e)}"
+ })
+
+ elif action == "clear_context":
+ context_segments = []
+ await websocket.send_json({
+ "type": "context_updated",
+ "message": "Context cleared"
+ })
+
+ except WebSocketDisconnect:
+ manager.disconnect(websocket)
+ print("Client disconnected")
+ except Exception as e:
+ print(f"Error: {str(e)}")
+ await websocket.send_json({
+ "type": "error",
+ "message": str(e)
+ })
+ manager.disconnect(websocket)
+
+
+if __name__ == "__main__":
+ uvicorn.run(app, host="localhost", port=8000)
\ No newline at end of file
diff --git a/Backend/setup.py b/Backend/setup.py
new file mode 100644
index 0000000..8eddb95
--- /dev/null
+++ b/Backend/setup.py
@@ -0,0 +1,13 @@
+from setuptools import setup, find_packages
+import os
+
+# Read requirements from requirements.txt
+with open('requirements.txt') as f:
+ requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]
+
+setup(
+ name='csm',
+ version='0.1.0',
+ packages=find_packages(),
+ install_requires=requirements,
+)
diff --git a/Backend/test.py b/Backend/test.py
new file mode 100644
index 0000000..34735b1
--- /dev/null
+++ b/Backend/test.py
@@ -0,0 +1,50 @@
+import os
+import torch
+import torchaudio
+from huggingface_hub import hf_hub_download
+from generator import load_csm_1b, Segment
+from dataclasses import dataclass
+
+if torch.backends.mps.is_available():
+ device = "mps"
+elif torch.cuda.is_available():
+ device = "cuda"
+else:
+ device = "cpu"
+
+generator = load_csm_1b(device=device)
+
+speakers = [0, 1, 0, 0]
+transcripts = [
+ "Hey how are you doing.",
+ "Pretty good, pretty good.",
+ "I'm great.",
+ "So happy to be speaking to you.",
+]
+audio_paths = [
+ "utterance_0.wav",
+ "utterance_1.wav",
+ "utterance_2.wav",
+ "utterance_3.wav",
+]
+
+def load_audio(audio_path):
+ audio_tensor, sample_rate = torchaudio.load(audio_path)
+ audio_tensor = torchaudio.functional.resample(
+ audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
+ )
+ return audio_tensor
+
+segments = [
+ Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
+ for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
+]
+
+audio = generator.generate(
+ text="Me too, this is some cool stuff huh?",
+ speaker=1,
+ context=segments,
+ max_audio_length_ms=10_000,
+)
+
+torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)
\ No newline at end of file
diff --git a/Backend/watermarking.py b/Backend/watermarking.py
new file mode 100644
index 0000000..093962f
--- /dev/null
+++ b/Backend/watermarking.py
@@ -0,0 +1,79 @@
+import argparse
+
+import silentcipher
+import torch
+import torchaudio
+
+# This watermark key is public, it is not secure.
+# If using CSM 1B in another application, use a new private key and keep it secret.
+CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]
+
+
+def cli_check_audio() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--audio_path", type=str, required=True)
+ args = parser.parse_args()
+
+ check_audio_from_file(args.audio_path)
+
+
+def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
+ model = silentcipher.get_model(
+ model_type="44.1k",
+ device=device,
+ )
+ return model
+
+
+@torch.inference_mode()
+def watermark(
+ watermarker: silentcipher.server.Model,
+ audio_array: torch.Tensor,
+ sample_rate: int,
+ watermark_key: list[int],
+) -> tuple[torch.Tensor, int]:
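+    # silentcipher's 44.1k model operates at 44.1 kHz, so resample in, embed the watermark, then resample back to the output rate.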
+ audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
+ encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
+
+ output_sample_rate = min(44100, sample_rate)
+ encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
+ return encoded, output_sample_rate
+
+
+@torch.inference_mode()
+def verify(
+ watermarker: silentcipher.server.Model,
+ watermarked_audio: torch.Tensor,
+ sample_rate: int,
+ watermark_key: list[int],
+) -> bool:
+ watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
+ result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
+
+ is_watermarked = result["status"]
+ if is_watermarked:
+ is_csm_watermarked = result["messages"][0] == watermark_key
+ else:
+ is_csm_watermarked = False
+
+ return is_watermarked and is_csm_watermarked
+
+
+def check_audio_from_file(audio_path: str) -> None:
+ watermarker = load_watermarker(device="cuda")
+
+ audio_array, sample_rate = load_audio(audio_path)
+ is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)
+
+ outcome = "Watermarked" if is_watermarked else "Not watermarked"
+ print(f"{outcome}: {audio_path}")
+
+
+def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
+ audio_array, sample_rate = torchaudio.load(audio_path)
+ audio_array = audio_array.mean(dim=0)
+ return audio_array, int(sample_rate)
+
+
+if __name__ == "__main__":
+ cli_check_audio()