BGV
2025-03-30 01:46:36 -04:00
28 changed files with 2630 additions and 809 deletions

Backend/.gitignore vendored

@@ -1,46 +0,0 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual Environment
.env
.venv
env/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# Project specific
.python-version
*.wav
output_*/
basic_audio.wav
full_conversation.wav
context_audio.wav
# Model files
*.pt
*.ckpt


@@ -1,71 +0,0 @@
# csm-conversation-bot
## Overview
The CSM Conversation Bot enables real-time voice conversations with an AI assistant. It processes incoming audio streams, transcribes spoken input to text, generates a reply with the Llama 3.2 model, and synthesizes the reply back into speech for seamless interaction.
## Project Structure
```
csm-conversation-bot
├── api
│ ├── app.py # Main entry point for the API
│ ├── routes.py # Defines API routes
│ └── socket_handlers.py # Manages Socket.IO events
├── src
│ ├── audio
│ │ ├── processor.py # Audio processing functions
│ │ └── streaming.py # Audio streaming management
│ ├── llm
│ │ ├── generator.py # Response generation using Llama 3.2
│ │ └── tokenizer.py # Text tokenization functions
│ ├── models
│ │ ├── audio_model.py # Audio processing model
│ │ └── conversation.py # Conversation state management
│ ├── services
│ │ ├── transcription_service.py # Audio to text conversion
│ │ └── tts_service.py # Text to speech conversion
│ └── utils
│ ├── config.py # Configuration settings
│ └── logger.py # Logging utilities
├── static
│ ├── css
│ │ └── styles.css # CSS styles for the web interface
│ ├── js
│ │ └── client.js # Client-side JavaScript
│ └── index.html # Main HTML file for the web interface
├── templates
│ └── index.html # Template for rendering the main HTML page
├── config.py # Main configuration settings
├── requirements.txt # Python dependencies
├── server.py # Entry point for running the application
└── README.md # Documentation for the project
```
## Installation
1. Clone the repository:
```
git clone https://github.com/yourusername/csm-conversation-bot.git
cd csm-conversation-bot
```
2. Install the required dependencies:
```
pip install -r requirements.txt
```
3. Configure the application settings in `config.py` as needed.
## Usage
1. Start the server:
```
python server.py
```
2. Open your web browser and navigate to `http://localhost:5000` to access the application.
3. Use the interface to start a conversation with the AI assistant, or call the REST endpoints directly (see the example below).
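For quick checks outside the web interface, the REST endpoints defined in `api/routes.py` can be exercised directly. The sketch below is illustrative only: it assumes the blueprint is registered without a URL prefix, the server is reachable at `http://localhost:5000`, the `requests` package is installed, and a local `sample.wav` recording exists.
```
import requests

BASE_URL = "http://localhost:5000"  # assumed local server address

# Transcribe a local WAV file via the /transcribe route
with open("sample.wav", "rb") as f:  # hypothetical recording
    resp = requests.post(f"{BASE_URL}/transcribe", files={"audio": f})
print(resp.json())  # e.g. {"transcription": "..."}

# Generate a spoken response for a text prompt via /generate-response
resp = requests.post(f"{BASE_URL}/generate-response", json={"input": "Hello there!"})
print(list(resp.json().keys()))  # expected keys: ["response", "audio"]
```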
## Contributing
Contributions are welcome! Please submit a pull request or open an issue for any enhancements or bug fixes.
## License
This project is licensed under the MIT License. See the LICENSE file for more details.


@@ -1,22 +0,0 @@
from flask import Flask
from flask_socketio import SocketIO
from src.utils.config import Config
from src.utils.logger import setup_logger
from api.routes import setup_routes
from api.socket_handlers import setup_socket_handlers
def create_app():
app = Flask(__name__)
app.config.from_object(Config)
setup_logger(app)
setup_routes(app)
setup_socket_handlers(app)
return app
app = create_app()
socketio = SocketIO(app)
if __name__ == "__main__":
socketio.run(app, host='0.0.0.0', port=5000)


@@ -1,29 +0,0 @@
from flask import Blueprint, request, jsonify
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService
api = Blueprint('api', __name__)
transcription_service = TranscriptionService()
tts_service = TextToSpeechService()
@api.route('/transcribe', methods=['POST'])
def transcribe_audio():
audio_data = request.files.get('audio')
if not audio_data:
return jsonify({'error': 'No audio file provided'}), 400
text = transcription_service.transcribe(audio_data)
return jsonify({'transcription': text})
@api.route('/generate-response', methods=['POST'])
def generate_response():
data = request.json
user_input = data.get('input')
if not user_input:
return jsonify({'error': 'No input provided'}), 400
response_text = tts_service.generate_response(user_input)
audio_data = tts_service.text_to_speech(response_text)
return jsonify({'response': response_text, 'audio': audio_data})


@@ -1,32 +0,0 @@
from flask import request
from flask_socketio import SocketIO, emit
from src.audio.processor import process_audio
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService
from src.llm.generator import load_csm_1b
socketio = SocketIO()
transcription_service = TranscriptionService()
tts_service = TextToSpeechService()
generator = load_csm_1b()
@socketio.on('audio_stream')
def handle_audio_stream(data):
audio_data = data['audio']
speaker_id = data['speaker']
# Process the incoming audio
processed_audio = process_audio(audio_data)
# Transcribe the audio to text
transcription = transcription_service.transcribe(processed_audio)
# Generate a response using the LLM
response_text = generator.generate(text=transcription, speaker=speaker_id)
# Convert the response text back to audio
response_audio = tts_service.convert_text_to_speech(response_text)
# Emit the response audio back to the client
emit('audio_response', {'audio': response_audio, 'speaker': speaker_id})


@@ -1,13 +0,0 @@
from pathlib import Path
class Config:
def __init__(self):
self.MODEL_PATH = Path("path/to/your/model")
self.AUDIO_MODEL_PATH = Path("path/to/your/audio/model")
self.WATERMARK_KEY = "your_watermark_key"
self.SOCKETIO_CORS = "*"
self.API_KEY = "your_api_key"
self.DEBUG = True
self.LOGGING_LEVEL = "INFO"
self.TTS_SERVICE_URL = "http://localhost:5001/tts"
self.TRANSCRIPTION_SERVICE_URL = "http://localhost:5002/transcribe"


@@ -15,10 +15,14 @@ from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
class Segment:
speaker: int
text: str
+# (num_samples,), sample_rate = 24_000
audio: torch.Tensor
def load_llama3_tokenizer():
+"""
+https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
+"""
tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.bos_token
@@ -74,8 +78,10 @@ class Generator:
frame_tokens = []
frame_masks = []
+# (K, T)
audio = audio.to(self.device)
audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
+# add EOS frame
eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
@@ -90,6 +96,10 @@ class Generator:
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
+"""
+Returns:
+(seq_len, 33), (seq_len, 33)
+"""
text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
@@ -136,7 +146,7 @@ class Generator:
for _ in range(max_generation_len):
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
if torch.all(sample == 0):
-break
+break  # eos
samples.append(sample)
@@ -148,6 +158,10 @@ class Generator:
audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
+# This applies an imperceptible watermark to identify audio as AI-generated.
+# Watermarking ensures transparency, dissuades misuse, and enables traceability.
+# Please be a responsible AI citizen and keep the watermarking in place.
+# If using CSM 1B in another application, use your own private key and keep it secret.
audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

Backend/index.html Normal file

@@ -0,0 +1,711 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>CSM Voice Chat</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
<!-- Socket.IO client library -->
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
<style>
:root {
--primary-color: #4c84ff;
--secondary-color: #3367d6;
--text-color: #333;
--background-color: #f9f9f9;
--card-background: #ffffff;
--accent-color: #ff5252;
--success-color: #4CAF50;
--border-color: #e0e0e0;
--shadow-color: rgba(0, 0, 0, 0.1);
}
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background-color: var(--background-color);
color: var(--text-color);
line-height: 1.6;
max-width: 1000px;
margin: 0 auto;
padding: 20px;
transition: all 0.3s ease;
}
header {
text-align: center;
margin-bottom: 30px;
}
h1 {
color: var(--primary-color);
font-size: 2.5rem;
margin-bottom: 10px;
}
.subtitle {
color: #666;
font-weight: 300;
}
.app-container {
display: grid;
grid-template-columns: 1fr;
gap: 20px;
}
@media (min-width: 768px) {
.app-container {
grid-template-columns: 2fr 1fr;
}
}
.chat-container, .control-panel {
background-color: var(--card-background);
border-radius: 12px;
box-shadow: 0 4px 12px var(--shadow-color);
padding: 20px;
}
.control-panel {
display: flex;
flex-direction: column;
gap: 20px;
}
.chat-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 15px;
padding-bottom: 10px;
border-bottom: 1px solid var(--border-color);
}
.conversation {
height: 400px;
overflow-y: auto;
padding: 10px;
border-radius: 8px;
background-color: #f7f9fc;
margin-bottom: 20px;
scroll-behavior: smooth;
}
.message {
margin-bottom: 15px;
padding: 12px 15px;
border-radius: 12px;
max-width: 85%;
position: relative;
animation: fade-in 0.3s ease-out forwards;
}
@keyframes fade-in {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
}
.user {
background-color: #e3f2fd;
color: #0d47a1;
margin-left: auto;
border-bottom-right-radius: 4px;
}
.ai {
background-color: #f1f1f1;
color: #37474f;
margin-right: auto;
border-bottom-left-radius: 4px;
}
.system {
background-color: #f8f9fa;
font-style: italic;
color: #666;
text-align: center;
max-width: 90%;
margin: 10px auto;
font-size: 0.9em;
padding: 8px 12px;
border-radius: 8px;
}
.audio-player {
width: 100%;
margin-top: 8px;
border-radius: 8px;
}
button {
padding: 12px 20px;
border-radius: 8px;
border: none;
background-color: var(--primary-color);
color: white;
font-weight: 600;
cursor: pointer;
transition: all 0.2s ease;
display: flex;
align-items: center;
justify-content: center;
gap: 8px;
flex: 1;
}
button:hover {
background-color: var(--secondary-color);
}
button.recording {
background-color: var(--accent-color);
animation: pulse 1.5s infinite;
}
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.7; }
100% { opacity: 1; }
}
.status-indicator {
display: flex;
align-items: center;
gap: 10px;
font-size: 0.9em;
color: #555;
}
.status-dot {
width: 12px;
height: 12px;
border-radius: 50%;
background-color: #ccc;
}
.status-dot.active {
background-color: var(--success-color);
}
footer {
margin-top: 30px;
text-align: center;
font-size: 0.8em;
color: #777;
}
</style>
</head>
<body>
<header>
<h1>CSM Voice Chat</h1>
<p class="subtitle">Talk naturally with the AI using your voice</p>
</header>
<div class="app-container">
<div class="chat-container">
<div class="chat-header">
<h2>Conversation</h2>
<div class="status-indicator">
<div id="statusDot" class="status-dot"></div>
<span id="statusText">Disconnected</span>
</div>
</div>
<div id="conversation" class="conversation"></div>
</div>
<div class="control-panel">
<div>
<h3>Controls</h3>
<p>Click the button below to start and stop recording.</p>
<div class="button-row">
<button id="streamButton">
<i class="fas fa-microphone"></i>
Start Conversation
</button>
<button id="clearButton">
<i class="fas fa-trash"></i>
Clear
</button>
</div>
</div>
<div class="settings-panel">
<h3>Settings</h3>
<div class="settings-toggles">
<div class="toggle-switch">
<input type="checkbox" id="autoPlayResponses" checked>
<label for="autoPlayResponses">Autoplay Responses</label>
</div>
<div>
<label for="speakerSelect">Speaker Voice:</label>
<select id="speakerSelect">
<option value="0">Speaker 0 (You)</option>
<option value="1">Speaker 1 (AI)</option>
</select>
</div>
</div>
</div>
</div>
</div>
<footer>
<p>Powered by CSM 1B & Llama 3.2 | Whisper for speech recognition</p>
</footer>
<script>
// Configuration constants
const SERVER_URL = window.location.hostname === 'localhost' ?
'http://localhost:5000' : window.location.origin;
const ENERGY_WINDOW_SIZE = 15;
const CLIENT_SILENCE_DURATION_MS = 750;
// DOM elements
const elements = {
conversation: document.getElementById('conversation'),
streamButton: document.getElementById('streamButton'),
clearButton: document.getElementById('clearButton'),
statusDot: document.getElementById('statusDot'),
statusText: document.getElementById('statusText'),
speakerSelection: document.getElementById('speakerSelect'),
autoPlayResponses: document.getElementById('autoPlayResponses')
};
// Application state
const state = {
socket: null,
audioContext: null,
analyser: null,
microphone: null,
streamProcessor: null,
isStreaming: false,
isSpeaking: false,
silenceThreshold: 0.01,
energyWindow: [],
silenceTimer: null,
currentSpeaker: 0
};
// Initialize the application
function initializeApp() {
// Initialize socket.io connection
setupSocketConnection();
// Setup event listeners
setupEventListeners();
// Show welcome message
addSystemMessage('Welcome to CSM Voice Chat! Click "Start Conversation" to begin.');
}
// Setup Socket.IO connection
function setupSocketConnection() {
state.socket = io(SERVER_URL);
// Connection events
state.socket.on('connect', () => {
updateConnectionStatus(true);
addSystemMessage('Connected to server.');
});
state.socket.on('disconnect', () => {
updateConnectionStatus(false);
addSystemMessage('Disconnected from server.');
stopStreaming(false);
});
state.socket.on('error', (data) => {
addSystemMessage(`Error: ${data.message}`);
console.error('Server error:', data.message);
});
// Register message handlers
state.socket.on('transcription', handleTranscription);
state.socket.on('context_updated', handleContextUpdate);
state.socket.on('streaming_status', handleStreamingStatus);
state.socket.on('processing_status', handleProcessingStatus);
// Handlers for incremental audio streaming
state.socket.on('audio_response_start', handleAudioResponseStart);
state.socket.on('audio_response_chunk', handleAudioResponseChunk);
state.socket.on('audio_response_complete', handleAudioResponseComplete);
}
// Setup event listeners
function setupEventListeners() {
// Stream button
elements.streamButton.addEventListener('click', toggleStreaming);
// Clear button
elements.clearButton.addEventListener('click', clearConversation);
// Speaker selection
elements.speakerSelection.addEventListener('change', () => {
state.currentSpeaker = parseInt(elements.speakerSelection.value);
});
}
// Update connection status UI
function updateConnectionStatus(isConnected) {
if (isConnected) {
elements.statusDot.classList.add('active');
elements.statusText.textContent = 'Connected';
} else {
elements.statusDot.classList.remove('active');
elements.statusText.textContent = 'Disconnected';
}
}
// Toggle streaming state
function toggleStreaming() {
if (state.isStreaming) {
stopStreaming();
} else {
startStreaming();
}
}
// Start streaming audio to the server
function startStreaming() {
if (!state.socket || !state.socket.connected) {
addSystemMessage('Not connected to server. Please refresh the page.');
return;
}
// Request microphone access
navigator.mediaDevices.getUserMedia({ audio: true, video: false })
.then(stream => {
state.isStreaming = true;
elements.streamButton.classList.add('recording');
elements.streamButton.innerHTML = '<i class="fas fa-stop"></i> Stop Recording';
// Initialize Web Audio API
state.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });
state.microphone = state.audioContext.createMediaStreamSource(stream);
state.analyser = state.audioContext.createAnalyser();
state.analyser.fftSize = 1024;
state.microphone.connect(state.analyser);
// Create processor node for audio data
const processorNode = state.audioContext.createScriptProcessor(4096, 1, 1);
processorNode.onaudioprocess = handleAudioProcess;
state.analyser.connect(processorNode);
processorNode.connect(state.audioContext.destination);
state.streamProcessor = processorNode;
state.silenceTimer = null;
state.energyWindow = [];
state.isSpeaking = false;
// Notify server
state.socket.emit('start_stream');
addSystemMessage('Recording started. Speak now...');
})
.catch(error => {
console.error('Error accessing microphone:', error);
addSystemMessage('Could not access microphone. Please check permissions.');
});
}
// Stop streaming audio
function stopStreaming(notifyServer = true) {
if (state.isStreaming) {
state.isStreaming = false;
elements.streamButton.classList.remove('recording');
elements.streamButton.innerHTML = '<i class="fas fa-microphone"></i> Start Conversation';
// Clean up audio resources
if (state.streamProcessor) {
state.streamProcessor.disconnect();
state.streamProcessor = null;
}
if (state.analyser) {
state.analyser.disconnect();
state.analyser = null;
}
if (state.microphone) {
state.microphone.disconnect();
state.microphone = null;
}
if (state.audioContext) {
state.audioContext.close();
state.audioContext = null;
}
// Clear any pending silence timer
if (state.silenceTimer) {
clearTimeout(state.silenceTimer);
state.silenceTimer = null;
}
// Notify server if needed
if (notifyServer && state.socket && state.socket.connected) {
state.socket.emit('stop_stream');
}
addSystemMessage('Recording stopped.');
}
}
// Handle audio processing
function handleAudioProcess(event) {
if (!state.isStreaming) return;
const inputData = event.inputBuffer.getChannelData(0);
const energy = calculateAudioEnergy(inputData);
updateEnergyWindow(energy);
const averageEnergy = calculateAverageEnergy();
const isSilent = averageEnergy < state.silenceThreshold;
handleSpeechState(isSilent);
}
// Calculate audio energy (volume)
function calculateAudioEnergy(buffer) {
let sum = 0;
for (let i = 0; i < buffer.length; i++) {
sum += buffer[i] * buffer[i];
}
return Math.sqrt(sum / buffer.length);
}
// Update energy window for averaging
function updateEnergyWindow(energy) {
state.energyWindow.push(energy);
if (state.energyWindow.length > ENERGY_WINDOW_SIZE) {
state.energyWindow.shift();
}
}
// Calculate average energy from window
function calculateAverageEnergy() {
if (state.energyWindow.length === 0) return 0;
const sum = state.energyWindow.reduce((acc, val) => acc + val, 0);
return sum / state.energyWindow.length;
}
// Handle speech/silence state transitions
function handleSpeechState(isSilent) {
if (state.isSpeaking) {
if (isSilent) {
// User was speaking but now is silent
if (!state.silenceTimer) {
state.silenceTimer = setTimeout(() => {
// Silence lasted long enough, consider speech done
if (state.isSpeaking) {
state.isSpeaking = false;
// Get the current audio data and send it
const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max
state.analyser.getFloatTimeDomainData(audioBuffer);
// Create WAV blob
const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate);
// Convert to base64
const reader = new FileReader();
reader.onloadend = function() {
sendAudioChunk(reader.result, state.currentSpeaker);
};
reader.readAsDataURL(wavBlob);
addSystemMessage('Processing your message...');
}
}, CLIENT_SILENCE_DURATION_MS);
}
} else {
// User is still speaking, reset silence timer
if (state.silenceTimer) {
clearTimeout(state.silenceTimer);
state.silenceTimer = null;
}
}
} else {
if (!isSilent) {
// User started speaking
state.isSpeaking = true;
if (state.silenceTimer) {
clearTimeout(state.silenceTimer);
state.silenceTimer = null;
}
}
}
}
// Send audio chunk to server
function sendAudioChunk(audioData, speaker) {
if (state.socket && state.socket.connected) {
state.socket.emit('audio_chunk', {
audio: audioData,
speaker: speaker
});
}
}
// Create WAV blob from audio data
function createWavBlob(audioData, sampleRate) {
const numChannels = 1;
const bitsPerSample = 16;
const bytesPerSample = bitsPerSample / 8;
// Create buffer for WAV file
const buffer = new ArrayBuffer(44 + audioData.length * bytesPerSample);
const view = new DataView(buffer);
// Write WAV header
// "RIFF" chunk descriptor
writeString(view, 0, 'RIFF');
view.setUint32(4, 36 + audioData.length * bytesPerSample, true);
writeString(view, 8, 'WAVE');
// "fmt " sub-chunk
writeString(view, 12, 'fmt ');
view.setUint32(16, 16, true); // subchunk1size
view.setUint16(20, 1, true); // audio format (PCM)
view.setUint16(22, numChannels, true);
view.setUint32(24, sampleRate, true);
view.setUint32(28, sampleRate * numChannels * bytesPerSample, true); // byte rate
view.setUint16(32, numChannels * bytesPerSample, true); // block align
view.setUint16(34, bitsPerSample, true);
// "data" sub-chunk
writeString(view, 36, 'data');
view.setUint32(40, audioData.length * bytesPerSample, true);
// Write audio data
const audioDataStart = 44;
for (let i = 0; i < audioData.length; i++) {
const sample = Math.max(-1, Math.min(1, audioData[i]));
const value = sample < 0 ? sample * 0x8000 : sample * 0x7FFF;
view.setInt16(audioDataStart + i * bytesPerSample, value, true);
}
return new Blob([buffer], { type: 'audio/wav' });
}
// Helper function to write strings to DataView
function writeString(view, offset, string) {
for (let i = 0; i < string.length; i++) {
view.setUint8(offset + i, string.charCodeAt(i));
}
}
// Clear conversation history
function clearConversation() {
elements.conversation.innerHTML = '';
if (state.socket && state.socket.connected) {
state.socket.emit('clear_context');
}
addSystemMessage('Conversation cleared.');
}
// Handle transcription response from server
function handleTranscription(data) {
const speaker = data.speaker === 0 ? 'user' : 'ai';
addMessage(data.text, speaker);
}
// Handle context update from server
function handleContextUpdate(data) {
if (data.status === 'cleared') {
elements.conversation.innerHTML = '';
addSystemMessage('Conversation context cleared.');
}
}
// Handle streaming status updates from server
function handleStreamingStatus(data) {
if (data.status === 'active') {
console.log('Server acknowledged streaming is active');
} else if (data.status === 'inactive') {
console.log('Server acknowledged streaming is inactive');
}
}
// Handle processing status updates
function handleProcessingStatus(data) {
switch (data.status) {
case 'transcribing':
addSystemMessage('Transcribing your message...');
break;
case 'generating':
addSystemMessage('Generating response...');
break;
case 'synthesizing':
addSystemMessage('Synthesizing voice...');
break;
}
}
// Handle the start of an audio streaming response
function handleAudioResponseStart(data) {
// Prepare for receiving chunked audio
console.log(`Expecting ${data.total_chunks} audio chunks`);
}
// Handle an incoming audio chunk
function handleAudioResponseChunk(data) {
// Create audio element for the response
const audioElement = document.createElement('audio');
if (elements.autoPlayResponses.checked) {
audioElement.autoplay = true;
}
audioElement.controls = true;
audioElement.className = 'audio-player';
audioElement.src = data.chunk;
// Add to the most recent AI message if it exists
const messages = elements.conversation.querySelectorAll('.message.ai');
if (messages.length > 0) {
const lastAiMessage = messages[messages.length - 1];
lastAiMessage.appendChild(audioElement);
}
}
// Handle completion of audio streaming
function handleAudioResponseComplete(data) {
// Update the AI message with the full text
addMessage(data.text, 'ai');
}
// Add a message to the conversation
function addMessage(text, sender) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${sender}`;
messageDiv.textContent = text;
const timeSpan = document.createElement('span');
timeSpan.className = 'message-time';
const now = new Date();
timeSpan.textContent = `${now.getHours().toString().padStart(2, '0')}:${now.getMinutes().toString().padStart(2, '0')}`;
messageDiv.appendChild(timeSpan);
elements.conversation.appendChild(messageDiv);
elements.conversation.scrollTop = elements.conversation.scrollHeight;
}
// Add a system message to the conversation
function addSystemMessage(message) {
const messageDiv = document.createElement('div');
messageDiv.className = 'message system';
messageDiv.textContent = message;
elements.conversation.appendChild(messageDiv);
elements.conversation.scrollTop = elements.conversation.scrollHeight;
}
// Initialize the application when DOM is fully loaded
document.addEventListener('DOMContentLoaded', initializeApp);
</script>
</body>
</html>

Backend/models.py Normal file

@@ -0,0 +1,203 @@
from dataclasses import dataclass
import torch
import torch.nn as nn
import torchtune
from huggingface_hub import PyTorchModelHubMixin
from torchtune.models import llama3_2
def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
return llama3_2.llama3_2(
vocab_size=128_256,
num_layers=16,
num_heads=32,
num_kv_heads=8,
embed_dim=2048,
max_seq_len=2048,
intermediate_dim=8192,
attn_dropout=0.0,
norm_eps=1e-5,
rope_base=500_000,
scale_factor=32,
)
def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
return llama3_2.llama3_2(
vocab_size=128_256,
num_layers=4,
num_heads=8,
num_kv_heads=2,
embed_dim=1024,
max_seq_len=2048,
intermediate_dim=8192,
attn_dropout=0.0,
norm_eps=1e-5,
rope_base=500_000,
scale_factor=32,
)
FLAVORS = {
"llama-1B": llama3_2_1B,
"llama-100M": llama3_2_100M,
}
def _prepare_transformer(model):
embed_dim = model.tok_embeddings.embedding_dim
model.tok_embeddings = nn.Identity()
model.output = nn.Identity()
return model, embed_dim
def _create_causal_mask(seq_len: int, device: torch.device):
return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
"""
Args:
mask: (max_seq_len, max_seq_len)
input_pos: (batch_size, seq_len)
Returns:
(batch_size, seq_len, max_seq_len)
"""
r = mask[input_pos, :]
return r
def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
q = torch.empty_like(probs).exponential_(1)
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
logits = logits / temperature
filter_value: float = -float("Inf")
indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
scores_processed = logits.masked_fill(indices_to_remove, filter_value)
scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
probs = torch.nn.functional.softmax(scores_processed, dim=-1)
sample_token = _multinomial_sample_one_no_sync(probs)
return sample_token
@dataclass
class ModelArgs:
backbone_flavor: str
decoder_flavor: str
text_vocab_size: int
audio_vocab_size: int
audio_num_codebooks: int
class Model(
nn.Module,
PyTorchModelHubMixin,
repo_url="https://github.com/SesameAILabs/csm",
pipeline_tag="text-to-speech",
license="apache-2.0",
):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
def setup_caches(self, max_batch_size: int) -> torch.Tensor:
"""Setup KV caches and return a causal mask."""
dtype = next(self.parameters()).dtype
device = next(self.parameters()).device
with device:
self.backbone.setup_caches(max_batch_size, dtype)
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
def generate_frame(
self,
tokens: torch.Tensor,
tokens_mask: torch.Tensor,
input_pos: torch.Tensor,
temperature: float,
topk: int,
) -> torch.Tensor:
"""
Args:
tokens: (batch_size, seq_len, audio_num_codebooks+1)
tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
input_pos: (batch_size, seq_len) positions for each token
mask: (batch_size, seq_len, max_seq_len)
Returns:
(batch_size, audio_num_codebooks) sampled tokens
"""
dtype = next(self.parameters()).dtype
b, s, _ = tokens.size()
assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
embeds = self._embed_tokens(tokens)
masked_embeds = embeds * tokens_mask.unsqueeze(-1)
h = masked_embeds.sum(dim=2)
h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
last_h = h[:, -1, :]
c0_logits = self.codebook0_head(last_h)
c0_sample = sample_topk(c0_logits, topk, temperature)
c0_embed = self._embed_audio(0, c0_sample)
curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
curr_sample = c0_sample.clone()
curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
# Decoder caches must be reset every frame.
self.decoder.reset_caches()
for i in range(1, self.config.audio_num_codebooks):
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
dtype=dtype
)
ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
ci_sample = sample_topk(ci_logits, topk, temperature)
ci_embed = self._embed_audio(i, ci_sample)
curr_h = ci_embed
curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
curr_pos = curr_pos[:, -1:] + 1
return curr_sample
def reset_caches(self):
self.backbone.reset_caches()
self.decoder.reset_caches()
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
audio_tokens = tokens[:, :, :-1] + (
self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
)
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
)
return torch.cat([audio_embeds, text_embeds], dim=-2)
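A note on the sampling helpers above: `_multinomial_sample_one_no_sync` draws exponential noise and takes an argmax of `probs / q`, which is equivalent to multinomial sampling but avoids a CUDA synchronization, and `sample_topk` applies temperature and a top-k filter before sampling. A minimal sketch of calling `sample_topk` on dummy logits; the batch size and vocabulary size are illustrative, and it assumes this file is importable as `models` with `torchtune` installed:
```
import torch
from models import sample_topk  # assumes models.py (above) is on the path

# Dummy logits: 2 rows over a hypothetical 2051-entry audio vocabulary
logits = torch.randn(2, 2051)
token = sample_topk(logits, topk=50, temperature=0.9)
print(token.shape, token.dtype)  # torch.Size([2, 1]) torch.int32 -- one sampled id per row
```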


@@ -1,16 +1,9 @@
-Flask==2.2.2
-Flask-SocketIO==5.3.2
-torch>=2.0.0
-torchaudio>=2.0.0
-transformers>=4.30.0
-huggingface-hub>=0.14.0
-python-dotenv==0.19.2
-numpy>=1.21.6
-scipy>=1.7.3
-soundfile==0.10.3.post1
-requests==2.28.1
-pydub==0.25.1
-python-socketio==5.7.2
-eventlet==0.33.3
-whisper>=20230314
-ffmpeg-python>=0.2.0
+torch==2.4.0
+torchaudio==2.4.0
+tokenizers==0.21.0
+transformers==4.49.0
+huggingface_hub==0.28.1
+moshi==0.2.2
+torchtune==0.4.0
+torchao==0.9.0
+silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master

Backend/run_csm.py Normal file

@@ -0,0 +1,117 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass
# Disable Triton compilation
os.environ["NO_TORCH_COMPILE"] = "1"
# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
repo_id="sesame/csm-1b",
filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
repo_id="sesame/csm-1b",
filename="prompts/conversational_b.wav"
)
SPEAKER_PROMPTS = {
"conversational_a": {
"text": (
"like revising for an exam I'd have to try and like keep up the momentum because I'd "
"start really early I'd be like okay I'm gonna start revising now and then like "
"you're revising for ages and then I just like start losing steam I didn't do that "
"for the exam we had recently to be fair that was a more of a last minute scenario "
"but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
"sort of start the day with this not like a panic but like a"
),
"audio": prompt_filepath_conversational_a
},
"conversational_b": {
"text": (
"like a super Mario level. Like it's very like high detail. And like, once you get "
"into the park, it just like, everything looks like a computer game and they have all "
"these, like, you know, if, if there's like a, you know, like in a Mario game, they "
"will have like a question block. And if you like, you know, punch it, a coin will "
"come out. So like everyone, when they come into the park, they get like this little "
"bracelet and then you can go punching question blocks around."
),
"audio": prompt_filepath_conversational_b
}
}
def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = audio_tensor.squeeze(0)
# Resample is lazy so we can always call it
audio_tensor = torchaudio.functional.resample(
audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
)
return audio_tensor
def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
audio_tensor = load_prompt_audio(audio_path, sample_rate)
return Segment(text=text, speaker=speaker, audio=audio_tensor)
def main():
# Select the best available device, skipping MPS due to float64 limitations
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
print(f"Using device: {device}")
# Load model
generator = load_csm_1b(device)
# Prepare prompts
prompt_a = prepare_prompt(
SPEAKER_PROMPTS["conversational_a"]["text"],
0,
SPEAKER_PROMPTS["conversational_a"]["audio"],
generator.sample_rate
)
prompt_b = prepare_prompt(
SPEAKER_PROMPTS["conversational_b"]["text"],
1,
SPEAKER_PROMPTS["conversational_b"]["audio"],
generator.sample_rate
)
# Generate conversation
conversation = [
{"text": "Hey how are you doing?", "speaker_id": 0},
{"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
{"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
{"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
]
# Generate each utterance
generated_segments = []
prompt_segments = [prompt_a, prompt_b]
for utterance in conversation:
print(f"Generating: {utterance['text']}")
audio_tensor = generator.generate(
text=utterance['text'],
speaker=utterance['speaker_id'],
context=prompt_segments + generated_segments,
max_audio_length_ms=10_000,
)
generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))
# Concatenate all generations
all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
torchaudio.save(
"full_conversation.wav",
all_audio.unsqueeze(0).cpu(),
generator.sample_rate
)
print("Successfully generated full_conversation.wav")
if __name__ == "__main__":
main()


@@ -1,53 +1,426 @@
-import os
-import logging
-import torch
-import eventlet
-import base64
-import tempfile
-from io import BytesIO
-from flask import Flask, render_template, request, jsonify
-from flask_socketio import SocketIO, emit
-import whisper
-import torchaudio
-from src.models.conversation import Segment
-from src.services.tts_service import load_csm_1b
-from src.llm.generator import generate_llm_response
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from src.audio.streaming import AudioStreamer
-from src.services.transcription_service import TranscriptionService
-from src.services.tts_service import TextToSpeechService
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-app = Flask(__name__, static_folder='static', template_folder='templates')
-app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'your-secret-key')
-socketio = SocketIO(app)
-# Initialize services
-transcription_service = TranscriptionService()
-tts_service = TextToSpeechService()
-audio_streamer = AudioStreamer()
-@socketio.on('audio_input')
-def handle_audio_input(data):
-audio_chunk = data['audio']
-speaker_id = data['speaker']
-# Process audio and convert to text
-text = transcription_service.transcribe(audio_chunk)
-logging.info(f"Transcribed text: {text}")
-# Generate response using Llama 3.2
-response_text = tts_service.generate_response(text, speaker_id)
-logging.info(f"Generated response: {response_text}")
-# Convert response text to audio
-audio_response = tts_service.text_to_speech(response_text, speaker_id)
-# Stream audio response back to client
-socketio.emit('audio_response', {'audio': audio_response})
import os
import io
import base64
import time
import json
import uuid
import logging
import threading
import queue
import tempfile
from typing import Dict, List, Optional, Tuple
import torch
import torchaudio
import numpy as np
from flask import Flask, request, jsonify, send_from_directory
from flask_socketio import SocketIO, emit
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from generator import load_csm_1b, Segment
from dataclasses import dataclass
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Initialize Flask app
app = Flask(__name__, static_folder='.')
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)
# Configure device
if torch.cuda.is_available():
DEVICE = "cuda"
elif torch.backends.mps.is_available():
DEVICE = "mps"
else:
DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")
# Global variables
active_conversations = {}
user_queues = {}
processing_threads = {}
# Load models
@dataclass
class AppModels:
generator = None
tokenizer = None
llm = None
asr = None
models = AppModels()
def load_models():
"""Load all required models"""
global models
logger.info("Loading CSM 1B model...")
models.generator = load_csm_1b(device=DEVICE)
logger.info("Loading ASR pipeline...")
models.asr = pipeline(
"automatic-speech-recognition",
model="openai/whisper-small",
device=DEVICE
)
logger.info("Loading Llama 3.2 model...")
models.llm = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B",
device_map=DEVICE,
torch_dtype=torch.bfloat16
)
models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
# Load models in a background thread
threading.Thread(target=load_models, daemon=True).start()
# Conversation data structure
class Conversation:
def __init__(self, session_id):
self.session_id = session_id
self.segments: List[Segment] = []
self.current_speaker = 0
self.last_activity = time.time()
self.is_processing = False
def add_segment(self, text, speaker, audio):
segment = Segment(text=text, speaker=speaker, audio=audio)
self.segments.append(segment)
self.last_activity = time.time()
return segment
def get_context(self, max_segments=10):
"""Return the most recent segments for context"""
return self.segments[-max_segments:] if self.segments else []
# Routes
@app.route('/')
def index():
return send_from_directory('.', 'index.html')
@app.route('/api/health')
def health_check():
return jsonify({
"status": "ok",
"cuda_available": torch.cuda.is_available(),
"models_loaded": models.generator is not None and models.llm is not None
})
# Socket event handlers
@socketio.on('connect')
def handle_connect():
session_id = request.sid
logger.info(f"Client connected: {session_id}")
# Initialize conversation data
if session_id not in active_conversations:
active_conversations[session_id] = Conversation(session_id)
user_queues[session_id] = queue.Queue()
processing_threads[session_id] = threading.Thread(
target=process_audio_queue,
args=(session_id, user_queues[session_id]),
daemon=True
)
processing_threads[session_id].start()
emit('connection_status', {'status': 'connected'})
@socketio.on('disconnect')
def handle_disconnect():
session_id = request.sid
logger.info(f"Client disconnected: {session_id}")
# Cleanup
if session_id in active_conversations:
# Mark for deletion rather than immediately removing
# as the processing thread might still be accessing it
active_conversations[session_id].is_processing = False
user_queues[session_id].put(None) # Signal thread to terminate
@socketio.on('start_stream')
def handle_start_stream():
session_id = request.sid
logger.info(f"Starting stream for client: {session_id}")
emit('streaming_status', {'status': 'active'})
@socketio.on('stop_stream')
def handle_stop_stream():
session_id = request.sid
logger.info(f"Stopping stream for client: {session_id}")
emit('streaming_status', {'status': 'inactive'})
@socketio.on('clear_context')
def handle_clear_context():
session_id = request.sid
logger.info(f"Clearing context for client: {session_id}")
if session_id in active_conversations:
active_conversations[session_id].segments = []
emit('context_updated', {'status': 'cleared'})
@socketio.on('audio_chunk')
def handle_audio_chunk(data):
session_id = request.sid
audio_data = data.get('audio', '')
speaker_id = int(data.get('speaker', 0))
if not audio_data or not session_id in active_conversations:
return
# Update the current speaker
active_conversations[session_id].current_speaker = speaker_id
# Queue audio for processing
user_queues[session_id].put({
'audio': audio_data,
'speaker': speaker_id
})
emit('processing_status', {'status': 'transcribing'})
def process_audio_queue(session_id, q):
"""Background thread to process audio chunks for a session"""
logger.info(f"Started processing thread for session: {session_id}")
try:
while session_id in active_conversations:
try:
# Get the next audio chunk with a timeout
data = q.get(timeout=120)
if data is None: # Termination signal
break
# Process the audio and generate a response
process_audio_and_respond(session_id, data)
except queue.Empty:
# Timeout, check if session is still valid
continue
except Exception as e:
logger.error(f"Error processing audio for {session_id}: {str(e)}")
socketio.emit('error', {'message': str(e)}, room=session_id)
finally:
logger.info(f"Ending processing thread for session: {session_id}")
# Clean up when thread is done
with app.app_context():
if session_id in active_conversations:
del active_conversations[session_id]
if session_id in user_queues:
del user_queues[session_id]
def process_audio_and_respond(session_id, data):
"""Process audio data and generate a response"""
if models.generator is None or models.asr is None or models.llm is None:
socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
return
conversation = active_conversations[session_id]
try:
# Set processing flag
conversation.is_processing = True
# Process base64 audio data
audio_data = data['audio']
speaker_id = data['speaker']
# Convert from base64 to WAV
audio_bytes = base64.b64decode(audio_data.split(',')[1])
# Save to temporary file for processing
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
temp_file.write(audio_bytes)
temp_path = temp_file.name
try:
# Load audio file
waveform, sample_rate = torchaudio.load(temp_path)
# Normalize to mono if needed
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# Resample to the CSM sample rate if needed
if sample_rate != models.generator.sample_rate:
waveform = torchaudio.functional.resample(
waveform,
orig_freq=sample_rate,
new_freq=models.generator.sample_rate
)
# Transcribe audio
socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
# Use the ASR pipeline to transcribe
transcription_result = models.asr(
{"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
return_timestamps=False
)
user_text = transcription_result['text'].strip()
# If no text was recognized, don't process further
if not user_text:
socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
return
# Add the user's message to conversation history
user_segment = conversation.add_segment(
text=user_text,
speaker=speaker_id,
audio=waveform.squeeze()
)
# Send transcription to client
socketio.emit('transcription', {
'text': user_text,
'speaker': speaker_id
}, room=session_id)
# Generate AI response using Llama
socketio.emit('processing_status', {'status': 'generating'}, room=session_id)
# Create prompt from conversation history
conversation_history = ""
for segment in conversation.segments[-5:]: # Last 5 segments for context
role = "User" if segment.speaker == 0 else "Assistant"
conversation_history += f"{role}: {segment.text}\n"
# Add final prompt
prompt = f"{conversation_history}Assistant: "
# Generate response with Llama
input_ids = models.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
with torch.no_grad():
generated_ids = models.llm.generate(
input_ids,
max_new_tokens=100,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=models.tokenizer.eos_token_id
)
# Decode the response
response_text = models.tokenizer.decode(
generated_ids[0][input_ids.shape[1]:],
skip_special_tokens=True
).strip()
# Synthesize speech
socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)
# Generate audio with CSM
ai_speaker_id = 1 # Use speaker 1 for AI responses
# Start sending the audio response
socketio.emit('audio_response_start', {
'text': response_text,
'total_chunks': 1,
'chunk_index': 0
}, room=session_id)
# Generate audio
audio_tensor = models.generator.generate(
text=response_text,
speaker=ai_speaker_id,
context=conversation.get_context(),
max_audio_length_ms=10_000,
temperature=0.9
)
# Add AI response to conversation history
ai_segment = conversation.add_segment(
text=response_text,
speaker=ai_speaker_id,
audio=audio_tensor
)
# Convert audio to WAV format
with io.BytesIO() as wav_io:
torchaudio.save(
wav_io,
audio_tensor.unsqueeze(0).cpu(),
models.generator.sample_rate,
format="wav"
)
wav_io.seek(0)
wav_data = wav_io.read()
# Convert WAV data to base64
audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"
# Send audio chunk to client
socketio.emit('audio_response_chunk', {
'chunk': audio_base64,
'chunk_index': 0,
'total_chunks': 1,
'is_last': True
}, room=session_id)
# Signal completion
socketio.emit('audio_response_complete', {
'text': response_text
}, room=session_id)
finally:
# Clean up temp file
if os.path.exists(temp_path):
os.unlink(temp_path)
except Exception as e:
logger.error(f"Error processing audio: {str(e)}")
socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
finally:
# Reset processing flag
conversation.is_processing = False
# Error handler
@socketio.on_error()
def error_handler(e):
logger.error(f"SocketIO error: {str(e)}")
# Periodic cleanup of inactive sessions
def cleanup_inactive_sessions():
"""Remove sessions that have been inactive for too long"""
current_time = time.time()
inactive_timeout = 3600 # 1 hour
for session_id in list(active_conversations.keys()):
conversation = active_conversations[session_id]
if (current_time - conversation.last_activity > inactive_timeout and
not conversation.is_processing):
logger.info(f"Cleaning up inactive session: {session_id}")
# Signal processing thread to terminate
if session_id in user_queues:
user_queues[session_id].put(None)
# Remove from active conversations
del active_conversations[session_id]
# Start cleanup thread
def start_cleanup_thread():
while True:
try:
cleanup_inactive_sessions()
except Exception as e:
logger.error(f"Error in cleanup thread: {str(e)}")
time.sleep(300) # Run every 5 minutes
cleanup_thread = threading.Thread(target=start_cleanup_thread, daemon=True)
cleanup_thread.start()
# Start the server
-if __name__ == '__main__':
-socketio.run(app, host='0.0.0.0', port=5000)
if __name__ == '__main__':
port = int(os.environ.get('PORT', 5000))
logger.info(f"Starting server on port {port}")
socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True)
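The Socket.IO contract implemented above can be exercised without the browser client: the client emits `start_stream` and then `audio_chunk` with a base64 `data:audio/wav` payload, and the server answers with `transcription`, `audio_response_chunk`, and `audio_response_complete` events. A minimal test-client sketch, assuming `python-socketio[client]` is installed (it is not in the pinned requirements) and that `sample.wav` is a short local recording (the file name is illustrative):
```
import base64
import socketio

sio = socketio.Client()

@sio.on('transcription')
def on_transcription(data):
    print(f"[speaker {data['speaker']}] {data['text']}")

@sio.on('audio_response_complete')
def on_complete(data):
    print("AI response:", data['text'])
    sio.disconnect()

sio.connect('http://localhost:5000')
sio.emit('start_stream')

# Encode a short WAV recording as the data URL the server's handler expects
with open('sample.wav', 'rb') as f:
    data_url = 'data:audio/wav;base64,' + base64.b64encode(f.read()).decode('utf-8')

sio.emit('audio_chunk', {'audio': data_url, 'speaker': 0})
sio.wait()  # block until on_complete disconnects
```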

Backend/setup.py Normal file

@@ -0,0 +1,13 @@
from setuptools import setup, find_packages
import os
# Read requirements from requirements.txt
with open('requirements.txt') as f:
requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]
setup(
name='csm',
version='0.1.0',
packages=find_packages(),
install_requires=requirements,
)


@@ -1,28 +0,0 @@
from scipy.io import wavfile
import numpy as np
import torch
import torchaudio
def load_audio(file_path):
sample_rate, audio_data = wavfile.read(file_path)
return sample_rate, audio_data
def normalize_audio(audio_data):
audio_data = audio_data.astype(np.float32)
max_val = np.max(np.abs(audio_data))
if max_val > 0:
audio_data /= max_val
return audio_data
def reduce_noise(audio_data, noise_factor=0.1):
noise = np.random.randn(len(audio_data))
noisy_audio = audio_data + noise_factor * noise
return noisy_audio
def save_audio(file_path, sample_rate, audio_data):
torchaudio.save(file_path, torch.tensor(audio_data).unsqueeze(0), sample_rate)
def process_audio(file_path, output_path):
sample_rate, audio_data = load_audio(file_path)
normalized_audio = normalize_audio(audio_data)
denoised_audio = reduce_noise(normalized_audio)
save_audio(output_path, sample_rate, denoised_audio)


@@ -1,35 +0,0 @@
from flask import Blueprint, request
from flask_socketio import SocketIO, emit
from src.audio.processor import process_audio
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService
streaming_bp = Blueprint('streaming', __name__)
socketio = SocketIO()
transcription_service = TranscriptionService()
tts_service = TextToSpeechService()
@socketio.on('audio_stream')
def handle_audio_stream(data):
audio_chunk = data['audio']
speaker_id = data['speaker']
# Process the audio chunk
processed_audio = process_audio(audio_chunk)
# Transcribe the audio to text
transcription = transcription_service.transcribe(processed_audio)
# Generate a response using the LLM
response_text = generate_response(transcription, speaker_id)
# Convert the response text back to audio
response_audio = tts_service.convert_text_to_speech(response_text, speaker_id)
# Emit the response audio back to the client
emit('audio_response', {'audio': response_audio})
def generate_response(transcription, speaker_id):
# Placeholder for the actual response generation logic
return f"Response to: {transcription}"


@@ -1,14 +0,0 @@
from transformers import AutoTokenizer
def load_llama3_tokenizer():
tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
return tokenizer
def tokenize_text(text: str, tokenizer) -> list:
tokens = tokenizer.encode(text, return_tensors='pt')
return tokens
def decode_tokens(tokens: list, tokenizer) -> str:
text = tokenizer.decode(tokens, skip_special_tokens=True)
return text


@@ -1,28 +0,0 @@
from dataclasses import dataclass
import torch
import torchaudio
@dataclass
class AudioModel:
model: torch.nn.Module
sample_rate: int
def __post_init__(self):
self.model.eval()
def process_audio(self, audio_tensor: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
processed_audio = self.model(audio_tensor)
return processed_audio
def resample_audio(self, audio_tensor: torch.Tensor, target_sample_rate: int) -> torch.Tensor:
if self.sample_rate != target_sample_rate:
resampled_audio = torchaudio.functional.resample(audio_tensor, orig_freq=self.sample_rate, new_freq=target_sample_rate)
return resampled_audio
return audio_tensor
def save_model(self, path: str):
torch.save(self.model.state_dict(), path)
def load_model(self, path: str):
self.model.load_state_dict(torch.load(path))
self.model.eval()


@@ -1,51 +0,0 @@
from dataclasses import dataclass, field
from typing import List, Optional
import torch
@dataclass
class Segment:
speaker: int
text: str
# (num_samples,), sample_rate = 24_000
audio: Optional[torch.Tensor] = None
def __post_init__(self):
# Ensure audio is a tensor if provided
if self.audio is not None and not isinstance(self.audio, torch.Tensor):
self.audio = torch.tensor(self.audio, dtype=torch.float32)
@dataclass
class Conversation:
context: List[str] = field(default_factory=list)
segments: List[Segment] = field(default_factory=list)
current_speaker: Optional[int] = None
def add_message(self, message: str, speaker: int):
self.context.append(f"Speaker {speaker}: {message}")
self.current_speaker = speaker
def add_segment(self, segment: Segment):
self.segments.append(segment)
self.context.append(f"Speaker {segment.speaker}: {segment.text}")
self.current_speaker = segment.speaker
def get_context(self) -> List[str]:
return self.context
def get_segments(self) -> List[Segment]:
return self.segments
def clear_context(self):
self.context.clear()
self.segments.clear()
self.current_speaker = None
def get_last_message(self) -> Optional[str]:
if self.context:
return self.context[-1]
return None
def get_last_segment(self) -> Optional[Segment]:
if self.segments:
return self.segments[-1]
return None


@@ -1,25 +0,0 @@
from typing import List
import torchaudio
import torch
from generator import load_csm_1b, Segment
class TranscriptionService:
def __init__(self, model_device: str = "cpu"):
self.generator = load_csm_1b(device=model_device)
def transcribe_audio(self, audio_path: str) -> str:
audio_tensor, sample_rate = torchaudio.load(audio_path)
audio_tensor = self._resample_audio(audio_tensor, sample_rate)
transcription = self.generator.generate_transcription(audio_tensor)
return transcription
def _resample_audio(self, audio_tensor: torch.Tensor, orig_freq: int) -> torch.Tensor:
target_sample_rate = self.generator.sample_rate
if orig_freq != target_sample_rate:
audio_tensor = torchaudio.functional.resample(audio_tensor.squeeze(0), orig_freq=orig_freq, new_freq=target_sample_rate)
return audio_tensor
def transcribe_audio_stream(self, audio_chunks: List[torch.Tensor]) -> str:
combined_audio = torch.cat(audio_chunks, dim=1)
transcription = self.generator.generate_transcription(combined_audio)
return transcription


@@ -1,24 +0,0 @@
from dataclasses import dataclass
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from src.llm.generator import load_csm_1b
@dataclass
class TextToSpeechService:
generator: any
def __init__(self, device: str = "cuda"):
self.generator = load_csm_1b(device=device)
def text_to_speech(self, text: str, speaker: int = 0) -> torch.Tensor:
audio = self.generator.generate(
text=text,
speaker=speaker,
context=[],
max_audio_length_ms=10000,
)
return audio
def save_audio(self, audio: torch.Tensor, file_path: str):
torchaudio.save(file_path, audio.unsqueeze(0).cpu(), self.generator.sample_rate)


@@ -1,23 +0,0 @@
# filepath: /csm-conversation-bot/csm-conversation-bot/src/utils/config.py
import os


class Config:
    # General configuration
    DEBUG = os.getenv('DEBUG', 'False') == 'True'
    SECRET_KEY = os.getenv('SECRET_KEY', 'your_secret_key_here')

    # API configuration
    API_URL = os.getenv('API_URL', 'http://localhost:5000')

    # Model configuration
    LLM_MODEL_PATH = os.getenv('LLM_MODEL_PATH', 'path/to/llm/model')
    AUDIO_MODEL_PATH = os.getenv('AUDIO_MODEL_PATH', 'path/to/audio/model')

    # Socket.IO configuration
    SOCKETIO_MESSAGE_QUEUE = os.getenv('SOCKETIO_MESSAGE_QUEUE', 'redis://localhost:6379/0')

    # Logging configuration
    LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')

    # Other configurations can be added as needed
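
Because these class attributes are read from the environment at import time, overrides have to be exported before `Config` is imported; a small sketch (the values below are placeholders, not project defaults):

```
import os

# Must happen before the Config import below, or the defaults are already baked in.
os.environ["DEBUG"] = "True"
os.environ["LOG_LEVEL"] = "DEBUG"

from src.utils.config import Config  # assumed import path

print(Config.DEBUG, Config.LOG_LEVEL)  # True DEBUG
```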

View File

@@ -1,14 +0,0 @@
import logging


def setup_logger(name, log_file, level=logging.INFO):
    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.addHandler(handler)
    return logger

# Example usage:
# logger = setup_logger('my_logger', 'app.log')

View File

@@ -1,105 +0,0 @@
body {
    font-family: 'Arial', sans-serif;
    background-color: #f4f4f4;
    color: #333;
    margin: 0;
    padding: 0;
}

header {
    background: #4c84ff;
    color: #fff;
    padding: 10px 0;
    text-align: center;
}

h1 {
    margin: 0;
    font-size: 2.5rem;
}

.container {
    width: 80%;
    margin: auto;
    overflow: hidden;
}

.conversation {
    background: #fff;
    padding: 20px;
    border-radius: 5px;
    box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
    max-height: 400px;
    overflow-y: auto;
}

.message {
    padding: 10px;
    margin: 10px 0;
    border-radius: 5px;
}

.user {
    background: #e3f2fd;
    text-align: right;
}

.ai {
    background: #f1f1f1;
    text-align: left;
}

.controls {
    display: flex;
    justify-content: space-between;
    margin-top: 20px;
}

button {
    padding: 10px 15px;
    border: none;
    border-radius: 5px;
    cursor: pointer;
    transition: background 0.3s;
}

button:hover {
    background: #3367d6;
    color: #fff;
}

.visualizer-container {
    height: 150px;
    background: #000;
    border-radius: 5px;
    margin-top: 20px;
}

.visualizer-label {
    color: rgba(255, 255, 255, 0.7);
    text-align: center;
    padding: 10px;
}

.status-indicator {
    display: flex;
    align-items: center;
    margin-top: 10px;
}

.status-dot {
    width: 12px;
    height: 12px;
    border-radius: 50%;
    background-color: #ccc;
    margin-right: 10px;
}

.status-dot.active {
    background-color: #4CAF50;
}

.status-text {
    font-size: 0.9em;
    color: #666;
}

View File

@@ -1,31 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>CSM Conversation Bot</title>
    <link rel="stylesheet" href="css/styles.css">
    <script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
    <script src="js/client.js" defer></script>
</head>
<body>
    <header>
        <h1>CSM Conversation Bot</h1>
        <p>Talk to the AI and get responses in real-time!</p>
    </header>
    <main>
        <div id="conversation" class="conversation"></div>
        <div class="controls">
            <button id="startButton">Start Conversation</button>
            <button id="stopButton">Stop Conversation</button>
        </div>
        <div class="status-indicator">
            <div id="statusDot" class="status-dot"></div>
            <div id="statusText">Disconnected</div>
        </div>
    </main>
    <footer>
        <p>Powered by CSM and Llama 3.2</p>
    </footer>
</body>
</html>

View File

@@ -1,131 +0,0 @@
// This file contains the client-side JavaScript code that handles audio streaming and communication with the server.

const SERVER_URL = window.location.hostname === 'localhost' ?
    'http://localhost:5000' : window.location.origin;

const elements = {
    conversation: document.getElementById('conversation'),
    streamButton: document.getElementById('streamButton'),
    clearButton: document.getElementById('clearButton'),
    speakerSelection: document.getElementById('speakerSelect'),
    statusDot: document.getElementById('statusDot'),
    statusText: document.getElementById('statusText'),
};

const state = {
    socket: null,
    isStreaming: false,
    currentSpeaker: 0,
    mediaRecorder: null,
    mediaStream: null,
};

// Initialize the application
function initializeApp() {
    setupSocketConnection();
    setupEventListeners();
}

// Setup Socket.IO connection
function setupSocketConnection() {
    state.socket = io(SERVER_URL);

    state.socket.on('connect', () => {
        updateConnectionStatus(true);
    });

    state.socket.on('disconnect', () => {
        updateConnectionStatus(false);
    });

    state.socket.on('audio_response', handleAudioResponse);
    state.socket.on('transcription', handleTranscription);
}

// Setup event listeners
function setupEventListeners() {
    elements.streamButton.addEventListener('click', toggleStreaming);
    elements.clearButton.addEventListener('click', clearConversation);
    elements.speakerSelection.addEventListener('change', (event) => {
        state.currentSpeaker = event.target.value;
    });
}

// Update connection status UI
function updateConnectionStatus(isConnected) {
    elements.statusDot.classList.toggle('active', isConnected);
    elements.statusText.textContent = isConnected ? 'Connected' : 'Disconnected';
}

// Toggle streaming state
function toggleStreaming() {
    if (state.isStreaming) {
        stopStreaming();
    } else {
        startStreaming();
    }
}

// Start streaming audio to the server
function startStreaming() {
    if (state.isStreaming) return;

    navigator.mediaDevices.getUserMedia({ audio: true })
        .then(stream => {
            const mediaRecorder = new MediaRecorder(stream);
            mediaRecorder.start();

            mediaRecorder.ondataavailable = (event) => {
                if (event.data.size > 0) {
                    sendAudioChunk(event.data);
                }
            };

            mediaRecorder.onstop = () => {
                state.isStreaming = false;
                elements.streamButton.innerHTML = 'Start Conversation';
            };

            // Keep references so stopStreaming() can shut the recorder down later.
            state.mediaRecorder = mediaRecorder;
            state.mediaStream = stream;
            state.isStreaming = true;
            elements.streamButton.innerHTML = 'Stop Conversation';
        })
        .catch(err => {
            console.error('Error accessing microphone:', err);
        });
}

// Stop streaming audio
function stopStreaming() {
    if (!state.isStreaming) return;

    // Stop the recorder (triggers onstop above) and release the microphone.
    if (state.mediaRecorder && state.mediaRecorder.state !== 'inactive') {
        state.mediaRecorder.stop();
    }
    if (state.mediaStream) {
        state.mediaStream.getTracks().forEach(track => track.stop());
    }
    state.mediaRecorder = null;
    state.mediaStream = null;
}

// Send audio chunk to server
function sendAudioChunk(audioData) {
    const reader = new FileReader();
    reader.onloadend = () => {
        const arrayBuffer = reader.result;
        state.socket.emit('audio_chunk', { audio: arrayBuffer, speaker: state.currentSpeaker });
    };
    reader.readAsArrayBuffer(audioData);
}

// Handle audio response from server
function handleAudioResponse(data) {
    const audioElement = new Audio(URL.createObjectURL(new Blob([data.audio])));
    audioElement.play();
}

// Handle transcription response from server
function handleTranscription(data) {
    const messageElement = document.createElement('div');
    messageElement.textContent = `AI: ${data.transcription}`;
    elements.conversation.appendChild(messageElement);
}

// Clear conversation history
function clearConversation() {
    elements.conversation.innerHTML = '';
}

// Initialize the application when DOM is fully loaded
document.addEventListener('DOMContentLoaded', initializeApp);

View File

@@ -1,31 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>CSM Conversation Bot</title>
    <link rel="stylesheet" href="../static/css/styles.css">
    <script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
    <script src="../static/js/client.js" defer></script>
</head>
<body>
    <header>
        <h1>CSM Conversation Bot</h1>
        <p>Talk to the AI and get responses in real-time!</p>
    </header>
    <main>
        <div class="chat-container">
            <div class="conversation" id="conversation"></div>
            <input type="text" id="userInput" placeholder="Type your message..." />
            <button id="sendButton">Send</button>
        </div>
        <div class="status-indicator">
            <div class="status-dot" id="statusDot"></div>
            <div class="status-text" id="statusText">Not connected</div>
        </div>
    </main>
    <footer>
        <p>Powered by CSM and Llama 3.2</p>
    </footer>
</body>
</html>

1071
Backend/voice-chat.js Normal file

File diff suppressed because it is too large

79
Backend/watermarking.py Normal file
View File

@@ -0,0 +1,79 @@
import argparse

import silentcipher
import torch
import torchaudio

# This watermark key is public, it is not secure.
# If using CSM 1B in another application, use a new private key and keep it secret.
CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]


def cli_check_audio() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio_path", type=str, required=True)
    args = parser.parse_args()

    check_audio_from_file(args.audio_path)


def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
    model = silentcipher.get_model(
        model_type="44.1k",
        device=device,
    )
    return model


@torch.inference_mode()
def watermark(
    watermarker: silentcipher.server.Model,
    audio_array: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> tuple[torch.Tensor, int]:
    audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
    encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)

    output_sample_rate = min(44100, sample_rate)
    encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
    return encoded, output_sample_rate


@torch.inference_mode()
def verify(
    watermarker: silentcipher.server.Model,
    watermarked_audio: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> bool:
    watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
    result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)

    is_watermarked = result["status"]
    if is_watermarked:
        is_csm_watermarked = result["messages"][0] == watermark_key
    else:
        is_csm_watermarked = False

    return is_watermarked and is_csm_watermarked


def check_audio_from_file(audio_path: str) -> None:
    watermarker = load_watermarker(device="cuda")

    audio_array, sample_rate = load_audio(audio_path)
    is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)

    outcome = "Watermarked" if is_watermarked else "Not watermarked"
    print(f"{outcome}: {audio_path}")


def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
    audio_array, sample_rate = torchaudio.load(audio_path)
    audio_array = audio_array.mean(dim=0)
    return audio_array, int(sample_rate)


if __name__ == "__main__":
    cli_check_audio()
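
A hedged round-trip sketch using the helpers above, assuming a CUDA device, the `silentcipher` package, and a local WAV file (both file paths are illustrative):

```
import torchaudio

from watermarking import CSM_1B_GH_WATERMARK, load_audio, load_watermarker, verify, watermark

watermarker = load_watermarker(device="cuda")

# Embed the public CSM key into a mono clip, then confirm the verifier detects it.
audio, sample_rate = load_audio("generated_speech.wav")  # illustrative input path
encoded, out_sample_rate = watermark(watermarker, audio, sample_rate, CSM_1B_GH_WATERMARK)

torchaudio.save("generated_speech_watermarked.wav", encoded.unsqueeze(0).cpu(), out_sample_rate)
print(verify(watermarker, encoded, out_sample_rate, CSM_1B_GH_WATERMARK))  # expected: True
```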