Frontend Fixed
Backend/api/app.py · 136 lines · new file
@@ -0,0 +1,136 @@
import os
import logging
import threading
from dataclasses import dataclass
from flask import Flask
from flask_socketio import SocketIO
from flask_cors import CORS

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure device
import torch
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

logger.info(f"Using device: {DEVICE}")

# Initialize Flask app
app = Flask(__name__, static_folder='../', static_url_path='')
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)

# Global variables for conversation state
active_conversations = {}
user_queues = {}
processing_threads = {}

# Model storage
@dataclass
class AppModels:
    generator = None
    tokenizer = None
    llm = None
    whisperx_model = None
    whisperx_align_model = None
    whisperx_align_metadata = None
    last_language = None

models = AppModels()

def load_models():
    """Load all required models"""
    from generator import load_csm_1b
    import whisperx
    import gc
    from transformers import AutoModelForCausalLM, AutoTokenizer
    global models

    socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})

    # CSM 1B loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
        models.generator = load_csm_1b(device=DEVICE)
        logger.info("CSM 1B model loaded successfully")
        socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})

    # WhisperX loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
        # Use WhisperX for better transcription with timestamps
        # Use compute_type based on device
        compute_type = "float16" if DEVICE == "cuda" else "float32"

        # Load the WhisperX model (smaller model for faster processing)
        models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)

        logger.info("WhisperX model loaded successfully")
        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})

    # Llama loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
        models.llm = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            device_map=DEVICE,
            torch_dtype=torch.bfloat16
        )
        models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

        # Configure all special tokens
        models.tokenizer.pad_token = models.tokenizer.eos_token
        models.tokenizer.padding_side = "left"  # For causal language modeling

        # Inform the model about the pad token
        if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
            models.llm.config.pad_token_id = models.tokenizer.pad_token_id

        logger.info("Llama 3.2 model loaded successfully")
        socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
    except Exception as e:
        logger.error(f"Error loading Llama 3.2 model: {str(e)}")
        socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})

# Load models in a background thread
threading.Thread(target=load_models, daemon=True).start()

# Import routes and socket handlers
from api.routes import register_routes
from api.socket_handlers import register_handlers

# Register routes and socket handlers
register_routes(app)
register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE)

# Run server if executed directly
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
    logger.info(f"Starting server on port {port} (debug={debug_mode})")
    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
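A minimal client-side sketch for watching the model_status events emitted by load_models above. This is an illustration only: it assumes the python-socketio package is installed and the server is running on http://localhost:5000.

# Sketch: print model loading progress from a Python Socket.IO client (assumed setup).
import socketio

sio = socketio.Client()

@sio.on('model_status')
def on_model_status(data):
    # e.g. {'model': 'overall', 'status': 'loading', 'progress': 33}
    print(f"{data.get('model')}: {data.get('status')} {data.get('progress', '')}")

sio.connect('http://localhost:5000')
sio.wait()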
Backend/api/generator.py · 176 lines · new file
@@ -0,0 +1,176 @@
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from models import Model
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark


@dataclass
class Segment:
    speaker: int
    text: str
    # (num_samples,), sample_rate = 24_000
    audio: torch.Tensor


def load_llama3_tokenizer():
    """
    https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
    """
    tokenizer_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
    )

    return tokenizer


class Generator:
    def __init__(
        self,
        model: Model,
    ):
        self._model = model
        self._model.setup_caches(1)

        self._text_tokenizer = load_llama3_tokenizer()

        device = next(model.parameters()).device
        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
        mimi = loaders.get_mimi(mimi_weight, device=device)
        mimi.set_num_codebooks(32)
        self._audio_tokenizer = mimi

        self._watermarker = load_watermarker(device=device)

        self.sample_rate = mimi.sample_rate
        self.device = device

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True

        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        assert audio.ndim == 1, "Audio must be single channel"

        frame_tokens = []
        frame_masks = []

        # (K, T)
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True

        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (seq_len, 33), (seq_len, 33)
        """
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
    ) -> torch.Tensor:
        self._model.reset_caches()

        max_generation_len = int(max_audio_length_ms / 80)
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

        max_seq_len = 2048
        max_context_len = max_seq_len - max_generation_len
        if curr_tokens.size(1) >= max_context_len:
            raise ValueError(
                f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
            )

        for _ in range(max_generation_len):
            sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
            if torch.all(sample == 0):
                break  # eos

            samples.append(sample)

            curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)

        # This applies an imperceptible watermark to identify audio as AI-generated.
        # Watermarking ensures transparency, dissuades misuse, and enables traceability.
        # Please be a responsible AI citizen and keep the watermarking in place.
        # If using CSM 1B in another application, use your own private key and keep it secret.
        audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
        audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

        return audio


def load_csm_1b(device: str = "cuda") -> Generator:
    model = Model.from_pretrained("sesame/csm-1b")
    model.to(device=device, dtype=torch.bfloat16)

    generator = Generator(model)
    return generator
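For reference, a hedged usage sketch of the Generator defined above. The output path and text are placeholders, and access to the gated sesame/csm-1b checkpoint is assumed.

# Sketch: synthesize speech with load_csm_1b / Generator.generate (assumed setup).
import torchaudio
from generator import load_csm_1b

generator = load_csm_1b(device="cuda")
audio = generator.generate(
    text="Hello from CSM.",
    speaker=0,
    context=[],            # optionally a list of Segment(speaker, text, audio) for voice context
    max_audio_length_ms=10_000,
)
torchaudio.save("out.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)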
Backend/api/models.py · 203 lines · new file
@@ -0,0 +1,203 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
import torchtune
from huggingface_hub import PyTorchModelHubMixin
from torchtune.models import llama3_2


def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=16,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=2048,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )


def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
    return llama3_2.llama3_2(
        vocab_size=128_256,
        num_layers=4,
        num_heads=8,
        num_kv_heads=2,
        embed_dim=1024,
        max_seq_len=2048,
        intermediate_dim=8192,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )


FLAVORS = {
    "llama-1B": llama3_2_1B,
    "llama-100M": llama3_2_100M,
}


def _prepare_transformer(model):
    embed_dim = model.tok_embeddings.embedding_dim
    model.tok_embeddings = nn.Identity()
    model.output = nn.Identity()
    return model, embed_dim


def _create_causal_mask(seq_len: int, device: torch.device):
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    """
    Args:
        mask: (max_seq_len, max_seq_len)
        input_pos: (batch_size, seq_len)

    Returns:
        (batch_size, seq_len, max_seq_len)
    """
    r = mask[input_pos, :]
    return r


def _multinomial_sample_one_no_sync(probs):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)


def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    logits = logits / temperature

    filter_value: float = -float("Inf")
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)
    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)

    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token


@dataclass
class ModelArgs:
    backbone_flavor: str
    decoder_flavor: str
    text_vocab_size: int
    audio_vocab_size: int
    audio_num_codebooks: int


class Model(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/SesameAILabs/csm",
    pipeline_tag="text-to-speech",
    license="apache-2.0",
):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config

        self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
        self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())

        self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
        self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)

        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
        self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))

    def setup_caches(self, max_batch_size: int) -> torch.Tensor:
        """Setup KV caches and return a causal mask."""
        dtype = next(self.parameters()).dtype
        device = next(self.parameters()).device

        with device:
            self.backbone.setup_caches(max_batch_size, dtype)
            self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)

        self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
        self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))

    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
    ) -> torch.Tensor:
        """
        Args:
            tokens: (batch_size, seq_len, audio_num_codebooks+1)
            tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
            input_pos: (batch_size, seq_len) positions for each token

        Returns:
            (batch_size, audio_num_codebooks) sampled tokens
        """
        dtype = next(self.parameters()).dtype
        b, s, _ = tokens.size()

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
        embeds = self._embed_tokens(tokens)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2)
        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)

        last_h = h[:, -1, :]
        c0_logits = self.codebook0_head(last_h)
        c0_sample = sample_topk(c0_logits, topk, temperature)
        c0_embed = self._embed_audio(0, c0_sample)

        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)

        # Decoder caches must be reset every frame.
        self.decoder.reset_caches()
        for i in range(1, self.config.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
                dtype=dtype
            )
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            ci_sample = sample_topk(ci_logits, topk, temperature)
            ci_embed = self._embed_audio(i, ci_sample)

            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def reset_caches(self):
        self.backbone.reset_caches()
        self.decoder.reset_caches()

    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)

    def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)

        audio_tokens = tokens[:, :, :-1] + (
            self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
        )

        return torch.cat([audio_embeds, text_embeds], dim=-2)
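A sketch of constructing the Model directly rather than through Model.from_pretrained. The flavor names and vocabulary sizes come from the code above; the audio vocabulary size is an assumption about the csm-1b checkpoint config and should be checked against it.

# Sketch: instantiate the CSM model from ModelArgs (config values assumed, verify against the checkpoint).
import torch
from models import Model, ModelArgs

config = ModelArgs(
    backbone_flavor="llama-1B",
    decoder_flavor="llama-100M",
    text_vocab_size=128_256,
    audio_vocab_size=2051,       # assumption: Mimi codebook size plus special tokens
    audio_num_codebooks=32,
)
model = Model(config).to(device="cuda", dtype=torch.bfloat16)
model.setup_caches(max_batch_size=1)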
Backend/api/routes.py · 74 lines · new file
@@ -0,0 +1,74 @@
import os
import torch
import psutil
import transformers
from flask import send_from_directory, jsonify, request

def register_routes(app):
    """Register HTTP routes for the application"""

    @app.route('/')
    def index():
        """Serve the main application page"""
        return send_from_directory(app.static_folder, 'index.html')

    @app.route('/voice-chat.js')
    def serve_js():
        """Serve the JavaScript file"""
        return send_from_directory(app.static_folder, 'voice-chat.js')

    @app.route('/api/status')
    def system_status():
        """Return the system status"""
        # Import here to avoid circular imports
        from api.app import models, DEVICE

        return jsonify({
            "status": "ok",
            "cuda_available": torch.cuda.is_available(),
            "device": DEVICE,
            "models": {
                "generator": models.generator is not None,
                "asr": models.whisperx_model is not None,
                "llm": models.llm is not None
            },
            "versions": {
                "transformers": transformers.__version__,
                "torch": torch.__version__
            }
        })

    @app.route('/api/system_resources')
    def system_resources():
        """Return system resource usage"""
        # Import here to avoid circular imports
        from api.app import active_conversations, DEVICE

        # Get CPU usage
        cpu_percent = psutil.cpu_percent(interval=0.1)

        # Get memory usage
        memory = psutil.virtual_memory()
        memory_used_gb = memory.used / (1024 ** 3)
        memory_total_gb = memory.total / (1024 ** 3)
        memory_percent = memory.percent

        # Get GPU memory if available
        gpu_memory = {}
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_memory[f"gpu_{i}"] = {
                    "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
                    "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
                    "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
                }

        return jsonify({
            "cpu_percent": cpu_percent,
            "memory": {
                "used_gb": memory_used_gb,
                "total_gb": memory_total_gb,
                "percent": memory_percent
            },
            "gpu_memory": gpu_memory,
            "active_sessions": len(active_conversations)
        })
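A quick way to exercise these endpoints once the server is up, assuming the requests package and the default port from app.py:

# Sketch: poll the status endpoints exposed by register_routes (assumed local server).
import requests

base = "http://localhost:5000"
print(requests.get(f"{base}/api/status").json())
print(requests.get(f"{base}/api/system_resources").json())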
Backend/api/socket_handlers.py · 392 lines · new file
@@ -0,0 +1,392 @@
import os
import io
import base64
import time
import threading
import queue
import tempfile
import gc
import logging
import traceback
from typing import Dict, List, Optional

import torch
import torchaudio
import numpy as np
from flask import request
from flask_socketio import emit

# Import conversation model
from generator import Segment

logger = logging.getLogger(__name__)

# Conversation data structure
class Conversation:
    def __init__(self, session_id):
        self.session_id = session_id
        self.segments: List[Segment] = []
        self.current_speaker = 0
        self.ai_speaker_id = 1  # Default AI speaker ID
        self.last_activity = time.time()
        self.is_processing = False

    def add_segment(self, text, speaker, audio):
        segment = Segment(text=text, speaker=speaker, audio=audio)
        self.segments.append(segment)
        self.last_activity = time.time()
        return segment

    def get_context(self, max_segments=10):
        """Return the most recent segments for context"""
        return self.segments[-max_segments:] if self.segments else []

def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE):
    """Register Socket.IO event handlers"""

    @socketio.on('connect')
    def handle_connect(auth=None):
        """Handle client connection"""
        session_id = request.sid
        logger.info(f"Client connected: {session_id}")

        # Initialize conversation data
        if session_id not in active_conversations:
            active_conversations[session_id] = Conversation(session_id)
            user_queues[session_id] = queue.Queue()
            processing_threads[session_id] = threading.Thread(
                target=process_audio_queue,
                args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, user_queues, DEVICE),
                daemon=True
            )
            processing_threads[session_id].start()

        emit('connection_status', {'status': 'connected'})

    @socketio.on('disconnect')
    def handle_disconnect(reason=None):
        """Handle client disconnection"""
        session_id = request.sid
        logger.info(f"Client disconnected: {session_id}. Reason: {reason}")

        # Cleanup
        if session_id in active_conversations:
            # Mark for deletion rather than immediately removing
            # as the processing thread might still be accessing it
            active_conversations[session_id].is_processing = False
            user_queues[session_id].put(None)  # Signal thread to terminate

    @socketio.on('audio_data')
    def handle_audio_data(data):
        """Handle incoming audio data"""
        session_id = request.sid
        logger.info(f"Received audio data from {session_id}")

        # Check if the models are loaded
        if models.generator is None or models.whisperx_model is None or models.llm is None:
            emit('error', {'message': 'Models still loading, please wait'})
            return

        # Check if we're already processing for this session
        if session_id in active_conversations and active_conversations[session_id].is_processing:
            emit('error', {'message': 'Still processing previous audio, please wait'})
            return

        # Add to processing queue
        if session_id in user_queues:
            user_queues[session_id].put(data)
        else:
            emit('error', {'message': 'Session not initialized, please refresh the page'})

# user_queues is passed in explicitly so the cleanup below can drop the session's queue.
def process_audio_queue(session_id, q, app, socketio, models, active_conversations, user_queues, DEVICE):
    """Background thread to process audio chunks for a session"""
    logger.info(f"Started processing thread for session: {session_id}")

    try:
        while session_id in active_conversations:
            try:
                # Get the next audio chunk with a timeout
                data = q.get(timeout=120)
                if data is None:  # Termination signal
                    break

                # Process the audio and generate a response
                process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE)

            except queue.Empty:
                # Timeout, check if session is still valid
                continue
            except Exception as e:
                logger.error(f"Error processing audio for {session_id}: {str(e)}")
                # Create an app context for the socket emit
                with app.app_context():
                    socketio.emit('error', {'message': str(e)}, room=session_id)
    finally:
        logger.info(f"Ending processing thread for session: {session_id}")
        # Clean up when thread is done
        with app.app_context():
            if session_id in active_conversations:
                del active_conversations[session_id]
            if session_id in user_queues:
                del user_queues[session_id]

def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE):
    """Process audio data and generate a response using WhisperX"""
    if models.generator is None or models.whisperx_model is None or models.llm is None:
        logger.warning("Models not yet loaded!")
        with app.app_context():
            socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
        return

    logger.info(f"Processing audio for session {session_id}")
    conversation = active_conversations[session_id]

    try:
        # Set processing flag
        conversation.is_processing = True

        # Process base64 audio data
        audio_data = data['audio']
        speaker_id = data['speaker']
        logger.info(f"Received audio from speaker {speaker_id}")

        # Convert from base64 to WAV
        try:
            audio_bytes = base64.b64decode(audio_data.split(',')[1])
            logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
        except Exception as e:
            logger.error(f"Error decoding base64 audio: {str(e)}")
            raise

        # Save to temporary file for processing
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name

        try:
            # Notify client that transcription is starting
            with app.app_context():
                socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)

            # Load audio using WhisperX
            import whisperx
            audio = whisperx.load_audio(temp_path)

            # Check audio length and add a warning for short clips
            audio_length = len(audio) / 16000  # assuming 16kHz sample rate
            if audio_length < 1.0:
                logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")

            # Transcribe using WhisperX
            batch_size = 16  # adjust based on your GPU memory
            logger.info("Running WhisperX transcription...")

            # Handle the warning about audio being shorter than 30s by suppressing it
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="audio is shorter than 30s")
                result = models.whisperx_model.transcribe(audio, batch_size=batch_size)

            # Get the detected language
            language_code = result["language"]
            logger.info(f"Detected language: {language_code}")

            # Check if alignment model needs to be loaded or updated
            if models.whisperx_align_model is None or language_code != models.last_language:
                # Clean up old models if they exist
                if models.whisperx_align_model is not None:
                    del models.whisperx_align_model
                    del models.whisperx_align_metadata
                    if DEVICE == "cuda":
                        gc.collect()
                        torch.cuda.empty_cache()

                # Load new alignment model for the detected language
                logger.info(f"Loading alignment model for language: {language_code}")
                models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
                    language_code=language_code, device=DEVICE
                )
                models.last_language = language_code

            # Align the transcript to get word-level timestamps
            if result["segments"] and len(result["segments"]) > 0:
                logger.info("Aligning transcript...")
                result = whisperx.align(
                    result["segments"],
                    models.whisperx_align_model,
                    models.whisperx_align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False
                )

                # Process the segments for better output
                for segment in result["segments"]:
                    # Round timestamps for better display
                    segment["start"] = round(segment["start"], 2)
                    segment["end"] = round(segment["end"], 2)
                    # Add a confidence score if not present
                    if "confidence" not in segment:
                        segment["confidence"] = 1.0  # Default confidence

            # Extract the full text from all segments
            user_text = ' '.join([segment['text'] for segment in result['segments']])

            # If no text was recognized, don't process further
            if not user_text or len(user_text.strip()) == 0:
                with app.app_context():
                    socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
                return

            logger.info(f"Transcription: {user_text}")

            # Load audio for CSM input
            waveform, sample_rate = torchaudio.load(temp_path)

            # Normalize to mono if needed
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Resample to the CSM sample rate if needed
            if sample_rate != models.generator.sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    orig_freq=sample_rate,
                    new_freq=models.generator.sample_rate
                )

            # Add the user's message to conversation history
            user_segment = conversation.add_segment(
                text=user_text,
                speaker=speaker_id,
                audio=waveform.squeeze()
            )

            # Send transcription to client with detailed segments
            with app.app_context():
                socketio.emit('transcription', {
                    'text': user_text,
                    'speaker': speaker_id,
                    'segments': result['segments']  # Include the detailed segments with timestamps
                }, room=session_id)

            # Generate AI response using Llama
            with app.app_context():
                socketio.emit('processing_status', {'status': 'generating'}, room=session_id)

            # Create prompt from conversation history
            conversation_history = ""
            for segment in conversation.segments[-5:]:  # Last 5 segments for context
                role = "User" if segment.speaker == 0 else "Assistant"
                conversation_history += f"{role}: {segment.text}\n"

            # Add final prompt
            prompt = f"{conversation_history}Assistant: "

            # Generate response with Llama
            try:
                # Ensure pad token is set
                if models.tokenizer.pad_token is None:
                    models.tokenizer.pad_token = models.tokenizer.eos_token

                input_tokens = models.tokenizer(
                    prompt,
                    return_tensors="pt",
                    padding=True,
                    return_attention_mask=True
                )
                input_ids = input_tokens.input_ids.to(DEVICE)
                attention_mask = input_tokens.attention_mask.to(DEVICE)

                with torch.no_grad():
                    generated_ids = models.llm.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=100,
                        temperature=0.7,
                        top_p=0.9,
                        do_sample=True,
                        pad_token_id=models.tokenizer.eos_token_id
                    )

                # Decode the response
                response_text = models.tokenizer.decode(
                    generated_ids[0][input_ids.shape[1]:],
                    skip_special_tokens=True
                ).strip()
            except Exception as e:
                logger.error(f"Error generating response: {str(e)}")
                logger.error(traceback.format_exc())
                response_text = "I'm sorry, I encountered an error while processing your request."

            # Synthesize speech
            with app.app_context():
                socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)

                # Start sending the audio response
                socketio.emit('audio_response_start', {
                    'text': response_text,
                    'total_chunks': 1,
                    'chunk_index': 0
                }, room=session_id)

            # Define AI speaker ID
            ai_speaker_id = conversation.ai_speaker_id

            # Generate audio
            audio_tensor = models.generator.generate(
                text=response_text,
                speaker=ai_speaker_id,
                context=conversation.get_context(),
                max_audio_length_ms=10_000,
                temperature=0.9
            )

            # Add AI response to conversation history
            ai_segment = conversation.add_segment(
                text=response_text,
                speaker=ai_speaker_id,
                audio=audio_tensor
            )

            # Convert audio to WAV format
            with io.BytesIO() as wav_io:
                torchaudio.save(
                    wav_io,
                    audio_tensor.unsqueeze(0).cpu(),
                    models.generator.sample_rate,
                    format="wav"
                )
                wav_io.seek(0)
                wav_data = wav_io.read()

            # Convert WAV data to base64
            audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"

            # Send audio chunk to client
            with app.app_context():
                socketio.emit('audio_response_chunk', {
                    'chunk': audio_base64,
                    'chunk_index': 0,
                    'total_chunks': 1,
                    'is_last': True
                }, room=session_id)

                # Signal completion
                socketio.emit('audio_response_complete', {
                    'text': response_text
                }, room=session_id)

        finally:
            # Clean up temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        logger.error(traceback.format_exc())
        with app.app_context():
            socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
    finally:
        # Reset processing flag
        conversation.is_processing = False
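A client sketch for the full audio round-trip handled above. It is a minimal illustration, assuming python-socketio, a local WAV file named utterance.wav, and the default server address; the payload shape and event names mirror handle_audio_data and process_audio_and_respond.

# Sketch: send one utterance and print the transcription and response text (assumed setup).
import base64
import socketio

sio = socketio.Client()

@sio.on('transcription')
def on_transcription(data):
    print("You said:", data['text'])

@sio.on('audio_response_complete')
def on_complete(data):
    print("Assistant said:", data['text'])

sio.connect('http://localhost:5000')
with open('utterance.wav', 'rb') as f:
    wav_b64 = base64.b64encode(f.read()).decode('utf-8')
sio.emit('audio_data', {'audio': f'data:audio/wav;base64,{wav_b64}', 'speaker': 0})
sio.wait()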
Backend/api/watermarking.py · 79 lines · new file
@@ -0,0 +1,79 @@
import argparse

import silentcipher
import torch
import torchaudio

# This watermark key is public, it is not secure.
# If using CSM 1B in another application, use a new private key and keep it secret.
CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]


def cli_check_audio() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio_path", type=str, required=True)
    args = parser.parse_args()

    check_audio_from_file(args.audio_path)


def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
    model = silentcipher.get_model(
        model_type="44.1k",
        device=device,
    )
    return model


@torch.inference_mode()
def watermark(
    watermarker: silentcipher.server.Model,
    audio_array: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> tuple[torch.Tensor, int]:
    audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
    encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)

    output_sample_rate = min(44100, sample_rate)
    encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
    return encoded, output_sample_rate


@torch.inference_mode()
def verify(
    watermarker: silentcipher.server.Model,
    watermarked_audio: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> bool:
    watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
    result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)

    is_watermarked = result["status"]
    if is_watermarked:
        is_csm_watermarked = result["messages"][0] == watermark_key
    else:
        is_csm_watermarked = False

    return is_watermarked and is_csm_watermarked


def check_audio_from_file(audio_path: str) -> None:
    watermarker = load_watermarker(device="cuda")

    audio_array, sample_rate = load_audio(audio_path)
    is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)

    outcome = "Watermarked" if is_watermarked else "Not watermarked"
    print(f"{outcome}: {audio_path}")


def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
    audio_array, sample_rate = torchaudio.load(audio_path)
    audio_array = audio_array.mean(dim=0)
    return audio_array, int(sample_rate)


if __name__ == "__main__":
    cli_check_audio()
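A generated file can be checked for the public CSM watermark either through the CLI entry point above or programmatically, for example (file name is a placeholder, GPU assumed as in load_watermarker):

# Sketch: verify the CSM watermark on a saved file.
from watermarking import CSM_1B_GH_WATERMARK, load_audio, load_watermarker, verify

watermarker = load_watermarker(device="cuda")
audio, sample_rate = load_audio("out.wav")
print(verify(watermarker, audio, sample_rate, CSM_1B_GH_WATERMARK))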