Merge branch 'main' of https://github.com/GamerBoss101/HooHacks-12
46  Backend/.gitignore  (vendored)
@@ -1,46 +0,0 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
.env
.venv
env/
venv/
ENV/

# IDE
.idea/
.vscode/
*.swp
*.swo

# Project specific
.python-version
*.wav
output_*/
basic_audio.wav
full_conversation.wav
context_audio.wav

# Model files
*.pt
*.ckpt
@@ -1,71 +0,0 @@
# csm-conversation-bot

## Overview
The CSM Conversation Bot is an application that utilizes advanced audio processing and language model technologies to facilitate real-time voice conversations with an AI assistant. The bot processes audio streams, converts spoken input into text, generates responses using the Llama 3.2 model, and converts the text back into audio for seamless interaction.

## Project Structure
```
csm-conversation-bot
├── api
│   ├── app.py                    # Main entry point for the API
│   ├── routes.py                 # Defines API routes
│   └── socket_handlers.py        # Manages Socket.IO events
├── src
│   ├── audio
│   │   ├── processor.py          # Audio processing functions
│   │   └── streaming.py          # Audio streaming management
│   ├── llm
│   │   ├── generator.py          # Response generation using Llama 3.2
│   │   └── tokenizer.py          # Text tokenization functions
│   ├── models
│   │   ├── audio_model.py        # Audio processing model
│   │   └── conversation.py       # Conversation state management
│   ├── services
│   │   ├── transcription_service.py  # Audio to text conversion
│   │   └── tts_service.py        # Text to speech conversion
│   └── utils
│       ├── config.py             # Configuration settings
│       └── logger.py             # Logging utilities
├── static
│   ├── css
│   │   └── styles.css            # CSS styles for the web interface
│   ├── js
│   │   └── client.js             # Client-side JavaScript
│   └── index.html                # Main HTML file for the web interface
├── templates
│   └── index.html                # Template for rendering the main HTML page
├── config.py                     # Main configuration settings
├── requirements.txt              # Python dependencies
├── server.py                     # Entry point for running the application
└── README.md                     # Documentation for the project
```

## Installation
1. Clone the repository:
   ```
   git clone https://github.com/yourusername/csm-conversation-bot.git
   cd csm-conversation-bot
   ```

2. Install the required dependencies:
   ```
   pip install -r requirements.txt
   ```

3. Configure the application settings in `config.py` as needed.

## Usage
1. Start the server:
   ```
   python server.py
   ```

2. Open your web browser and navigate to `http://localhost:5000` to access the application.

3. Use the interface to start a conversation with the AI assistant.
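You can also call the REST endpoints in `api/routes.py` directly. A minimal sketch using `requests` (it assumes the `api` blueprint is registered without a URL prefix and that the server runs on the default port above; `sample.wav` is any local recording):

```python
import requests

BASE_URL = "http://localhost:5000"

# POST /transcribe expects a multipart form field named 'audio'
with open("sample.wav", "rb") as f:
    resp = requests.post(f"{BASE_URL}/transcribe", files={"audio": f})
print(resp.json())  # {'transcription': '...'}

# POST /generate-response expects JSON with an 'input' field
resp = requests.post(f"{BASE_URL}/generate-response", json={"input": "Hello there!"})
print(resp.json())  # {'response': '...', 'audio': ...}
```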
## Contributing
Contributions are welcome! Please submit a pull request or open an issue for any enhancements or bug fixes.

## License
This project is licensed under the MIT License. See the LICENSE file for more details.
@@ -1,22 +0,0 @@
from flask import Flask
from flask_socketio import SocketIO
from src.utils.config import Config
from src.utils.logger import setup_logger
from api.routes import setup_routes
from api.socket_handlers import setup_socket_handlers


def create_app():
    app = Flask(__name__)
    app.config.from_object(Config)

    setup_logger(app)
    setup_routes(app)
    setup_socket_handlers(app)

    return app


app = create_app()
socketio = SocketIO(app)

if __name__ == "__main__":
    socketio.run(app, host='0.0.0.0', port=5000)
@@ -1,29 +0,0 @@
from flask import Blueprint, request, jsonify
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService

api = Blueprint('api', __name__)

transcription_service = TranscriptionService()
tts_service = TextToSpeechService()


@api.route('/transcribe', methods=['POST'])
def transcribe_audio():
    audio_data = request.files.get('audio')
    if not audio_data:
        return jsonify({'error': 'No audio file provided'}), 400

    text = transcription_service.transcribe(audio_data)
    return jsonify({'transcription': text})


@api.route('/generate-response', methods=['POST'])
def generate_response():
    data = request.json
    user_input = data.get('input')
    if not user_input:
        return jsonify({'error': 'No input provided'}), 400

    response_text = tts_service.generate_response(user_input)
    audio_data = tts_service.text_to_speech(response_text)

    return jsonify({'response': response_text, 'audio': audio_data})
@@ -1,32 +0,0 @@
from flask import request
from flask_socketio import SocketIO, emit
from src.audio.processor import process_audio
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService
from src.llm.generator import load_csm_1b

socketio = SocketIO()

transcription_service = TranscriptionService()
tts_service = TextToSpeechService()
generator = load_csm_1b()


@socketio.on('audio_stream')
def handle_audio_stream(data):
    audio_data = data['audio']
    speaker_id = data['speaker']

    # Process the incoming audio
    processed_audio = process_audio(audio_data)

    # Transcribe the audio to text
    transcription = transcription_service.transcribe(processed_audio)

    # Generate a response using the LLM
    response_text = generator.generate(text=transcription, speaker=speaker_id)

    # Convert the response text back to audio
    response_audio = tts_service.convert_text_to_speech(response_text)

    # Emit the response audio back to the client
    emit('audio_response', {'audio': response_audio, 'speaker': speaker_id})
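A browser page normally drives this handler, but the same events can be exercised from a small Python client. A minimal sketch using the `python-socketio` client listed in requirements.txt; the event names come from the handler above, while the host/port and the base64 payload encoding are assumptions:

```python
import base64
import socketio  # python-socketio client

sio = socketio.Client()

@sio.on('audio_response')
def on_audio_response(data):
    # The handler echoes back the speaker id along with the synthesized audio.
    print(f"speaker {data['speaker']} responded with {len(data['audio'])} bytes/chars of audio")

sio.connect('http://localhost:5000')  # assumed default host/port from api/app.py

with open('question.wav', 'rb') as f:
    sio.emit('audio_stream', {'audio': base64.b64encode(f.read()).decode('utf-8'),
                              'speaker': 0})
sio.wait()
```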
@@ -1,13 +0,0 @@
from pathlib import Path

class Config:
    def __init__(self):
        self.MODEL_PATH = Path("path/to/your/model")
        self.AUDIO_MODEL_PATH = Path("path/to/your/audio/model")
        self.WATERMARK_KEY = "your_watermark_key"
        self.SOCKETIO_CORS = "*"
        self.API_KEY = "your_api_key"
        self.DEBUG = True
        self.LOGGING_LEVEL = "INFO"
        self.TTS_SERVICE_URL = "http://localhost:5001/tts"
        self.TRANSCRIPTION_SERVICE_URL = "http://localhost:5002/transcribe"
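Hard-coding `API_KEY` and `WATERMARK_KEY` in source makes them easy to leak. A common alternative (a sketch only, not what this file does) is to read them from the environment via `python-dotenv`, which is already pinned in requirements.txt:

```python
import os
from pathlib import Path
from dotenv import load_dotenv

load_dotenv()  # pull values from a local .env file, if present


class Config:
    def __init__(self):
        self.MODEL_PATH = Path(os.getenv("MODEL_PATH", "path/to/your/model"))
        self.WATERMARK_KEY = os.getenv("WATERMARK_KEY", "")
        self.API_KEY = os.getenv("API_KEY", "")
        self.DEBUG = os.getenv("DEBUG", "true").lower() == "true"
```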
@@ -15,10 +15,14 @@ from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
class Segment:
    speaker: int
    text: str
    # (num_samples,), sample_rate = 24_000
    audio: torch.Tensor


def load_llama3_tokenizer():
    """
    https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
    """
    tokenizer_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
@@ -74,8 +78,10 @@ class Generator:
        frame_tokens = []
        frame_masks = []

        # (K, T)
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

@@ -90,6 +96,10 @@ class Generator:
        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (seq_len, 33), (seq_len, 33)
        """
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

@@ -136,7 +146,7 @@ class Generator:
        for _ in range(max_generation_len):
            sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
            if torch.all(sample == 0):
                break  # eos

            samples.append(sample)

@@ -148,6 +158,10 @@ class Generator:

        audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)

        # This applies an imperceptible watermark to identify audio as AI-generated.
        # Watermarking ensures transparency, dissuades misuse, and enables traceability.
        # Please be a responsible AI citizen and keep the watermarking in place.
        # If using CSM 1B in another application, use your own private key and keep it secret.
        audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
        audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

711  Backend/index.html  (new file)
@@ -0,0 +1,711 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>CSM Voice Chat</title>
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
||||
<!-- Socket.IO client library -->
|
||||
<script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
|
||||
<style>
|
||||
:root {
|
||||
--primary-color: #4c84ff;
|
||||
--secondary-color: #3367d6;
|
||||
--text-color: #333;
|
||||
--background-color: #f9f9f9;
|
||||
--card-background: #ffffff;
|
||||
--accent-color: #ff5252;
|
||||
--success-color: #4CAF50;
|
||||
--border-color: #e0e0e0;
|
||||
--shadow-color: rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
background-color: var(--background-color);
|
||||
color: var(--text-color);
|
||||
line-height: 1.6;
|
||||
max-width: 1000px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: var(--primary-color);
|
||||
font-size: 2.5rem;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.subtitle {
|
||||
color: #666;
|
||||
font-weight: 300;
|
||||
}
|
||||
|
||||
.app-container {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
@media (min-width: 768px) {
|
||||
.app-container {
|
||||
grid-template-columns: 2fr 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.chat-container, .control-panel {
|
||||
background-color: var(--card-background);
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 4px 12px var(--shadow-color);
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.control-panel {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
.chat-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 15px;
|
||||
padding-bottom: 10px;
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.conversation {
|
||||
height: 400px;
|
||||
overflow-y: auto;
|
||||
padding: 10px;
|
||||
border-radius: 8px;
|
||||
background-color: #f7f9fc;
|
||||
margin-bottom: 20px;
|
||||
scroll-behavior: smooth;
|
||||
}
|
||||
|
||||
.message {
|
||||
margin-bottom: 15px;
|
||||
padding: 12px 15px;
|
||||
border-radius: 12px;
|
||||
max-width: 85%;
|
||||
position: relative;
|
||||
animation: fade-in 0.3s ease-out forwards;
|
||||
}
|
||||
|
||||
@keyframes fade-in {
|
||||
from { opacity: 0; transform: translateY(10px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
.user {
|
||||
background-color: #e3f2fd;
|
||||
color: #0d47a1;
|
||||
margin-left: auto;
|
||||
border-bottom-right-radius: 4px;
|
||||
}
|
||||
|
||||
.ai {
|
||||
background-color: #f1f1f1;
|
||||
color: #37474f;
|
||||
margin-right: auto;
|
||||
border-bottom-left-radius: 4px;
|
||||
}
|
||||
|
||||
.system {
|
||||
background-color: #f8f9fa;
|
||||
font-style: italic;
|
||||
color: #666;
|
||||
text-align: center;
|
||||
max-width: 90%;
|
||||
margin: 10px auto;
|
||||
font-size: 0.9em;
|
||||
padding: 8px 12px;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.audio-player {
|
||||
width: 100%;
|
||||
margin-top: 8px;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 12px 20px;
|
||||
border-radius: 8px;
|
||||
border: none;
|
||||
background-color: var(--primary-color);
|
||||
color: white;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
background-color: var(--secondary-color);
|
||||
}
|
||||
|
||||
button.recording {
|
||||
background-color: var(--accent-color);
|
||||
animation: pulse 1.5s infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0% { opacity: 1; }
|
||||
50% { opacity: 0.7; }
|
||||
100% { opacity: 1; }
|
||||
}
|
||||
|
||||
.status-indicator {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
font-size: 0.9em;
|
||||
color: #555;
|
||||
}
|
||||
|
||||
.status-dot {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
background-color: #ccc;
|
||||
}
|
||||
|
||||
.status-dot.active {
|
||||
background-color: var(--success-color);
|
||||
}
|
||||
|
||||
footer {
|
||||
margin-top: 30px;
|
||||
text-align: center;
|
||||
font-size: 0.8em;
|
||||
color: #777;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>CSM Voice Chat</h1>
|
||||
<p class="subtitle">Talk naturally with the AI using your voice</p>
|
||||
</header>
|
||||
|
||||
<div class="app-container">
|
||||
<div class="chat-container">
|
||||
<div class="chat-header">
|
||||
<h2>Conversation</h2>
|
||||
<div class="status-indicator">
|
||||
<div id="statusDot" class="status-dot"></div>
|
||||
<span id="statusText">Disconnected</span>
|
||||
</div>
|
||||
</div>
|
||||
<div id="conversation" class="conversation"></div>
|
||||
</div>
|
||||
|
||||
<div class="control-panel">
|
||||
<div>
|
||||
<h3>Controls</h3>
|
||||
<p>Click the button below to start and stop recording.</p>
|
||||
<div class="button-row">
|
||||
<button id="streamButton">
|
||||
<i class="fas fa-microphone"></i>
|
||||
Start Conversation
|
||||
</button>
|
||||
<button id="clearButton">
|
||||
<i class="fas fa-trash"></i>
|
||||
Clear
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="settings-panel">
|
||||
<h3>Settings</h3>
|
||||
<div class="settings-toggles">
|
||||
<div class="toggle-switch">
|
||||
<input type="checkbox" id="autoPlayResponses" checked>
|
||||
<label for="autoPlayResponses">Autoplay Responses</label>
|
||||
</div>
|
||||
<div>
|
||||
<label for="speakerSelect">Speaker Voice:</label>
|
||||
<select id="speakerSelect">
|
||||
<option value="0">Speaker 0 (You)</option>
|
||||
<option value="1">Speaker 1 (AI)</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<footer>
|
||||
<p>Powered by CSM 1B & Llama 3.2 | Whisper for speech recognition</p>
|
||||
</footer>
|
||||
|
||||
<script>
|
||||
// Configuration constants
|
||||
const SERVER_URL = window.location.hostname === 'localhost' ?
|
||||
'http://localhost:5000' : window.location.origin;
|
||||
const ENERGY_WINDOW_SIZE = 15;
|
||||
const CLIENT_SILENCE_DURATION_MS = 750;
|
||||
|
||||
// DOM elements
|
||||
const elements = {
|
||||
conversation: document.getElementById('conversation'),
|
||||
streamButton: document.getElementById('streamButton'),
|
||||
clearButton: document.getElementById('clearButton'),
|
||||
statusDot: document.getElementById('statusDot'),
|
||||
statusText: document.getElementById('statusText'),
|
||||
speakerSelection: document.getElementById('speakerSelect'),
|
||||
autoPlayResponses: document.getElementById('autoPlayResponses')
|
||||
};
|
||||
|
||||
// Application state
|
||||
const state = {
|
||||
socket: null,
|
||||
audioContext: null,
|
||||
analyser: null,
|
||||
microphone: null,
|
||||
streamProcessor: null,
|
||||
isStreaming: false,
|
||||
isSpeaking: false,
|
||||
silenceThreshold: 0.01,
|
||||
energyWindow: [],
|
||||
silenceTimer: null,
|
||||
currentSpeaker: 0
|
||||
};
|
||||
|
||||
// Initialize the application
|
||||
function initializeApp() {
|
||||
// Initialize socket.io connection
|
||||
setupSocketConnection();
|
||||
|
||||
// Setup event listeners
|
||||
setupEventListeners();
|
||||
|
||||
// Show welcome message
|
||||
addSystemMessage('Welcome to CSM Voice Chat! Click "Start Conversation" to begin.');
|
||||
}
|
||||
|
||||
// Setup Socket.IO connection
|
||||
function setupSocketConnection() {
|
||||
state.socket = io(SERVER_URL);
|
||||
|
||||
// Connection events
|
||||
state.socket.on('connect', () => {
|
||||
updateConnectionStatus(true);
|
||||
addSystemMessage('Connected to server.');
|
||||
});
|
||||
|
||||
state.socket.on('disconnect', () => {
|
||||
updateConnectionStatus(false);
|
||||
addSystemMessage('Disconnected from server.');
|
||||
stopStreaming(false);
|
||||
});
|
||||
|
||||
state.socket.on('error', (data) => {
|
||||
addSystemMessage(`Error: ${data.message}`);
|
||||
console.error('Server error:', data.message);
|
||||
});
|
||||
|
||||
// Register message handlers
|
||||
state.socket.on('transcription', handleTranscription);
|
||||
state.socket.on('context_updated', handleContextUpdate);
|
||||
state.socket.on('streaming_status', handleStreamingStatus);
|
||||
state.socket.on('processing_status', handleProcessingStatus);
|
||||
|
||||
// Handlers for incremental audio streaming
|
||||
state.socket.on('audio_response_start', handleAudioResponseStart);
|
||||
state.socket.on('audio_response_chunk', handleAudioResponseChunk);
|
||||
state.socket.on('audio_response_complete', handleAudioResponseComplete);
|
||||
}
|
||||
|
||||
// Setup event listeners
|
||||
function setupEventListeners() {
|
||||
// Stream button
|
||||
elements.streamButton.addEventListener('click', toggleStreaming);
|
||||
|
||||
// Clear button
|
||||
elements.clearButton.addEventListener('click', clearConversation);
|
||||
|
||||
// Speaker selection
|
||||
elements.speakerSelection.addEventListener('change', () => {
|
||||
state.currentSpeaker = parseInt(elements.speakerSelection.value);
|
||||
});
|
||||
}
|
||||
|
||||
// Update connection status UI
|
||||
function updateConnectionStatus(isConnected) {
|
||||
if (isConnected) {
|
||||
elements.statusDot.classList.add('active');
|
||||
elements.statusText.textContent = 'Connected';
|
||||
} else {
|
||||
elements.statusDot.classList.remove('active');
|
||||
elements.statusText.textContent = 'Disconnected';
|
||||
}
|
||||
}
|
||||
|
||||
// Toggle streaming state
|
||||
function toggleStreaming() {
|
||||
if (state.isStreaming) {
|
||||
stopStreaming();
|
||||
} else {
|
||||
startStreaming();
|
||||
}
|
||||
}
|
||||
|
||||
// Start streaming audio to the server
|
||||
function startStreaming() {
|
||||
if (!state.socket || !state.socket.connected) {
|
||||
addSystemMessage('Not connected to server. Please refresh the page.');
|
||||
return;
|
||||
}
|
||||
|
||||
// Request microphone access
|
||||
navigator.mediaDevices.getUserMedia({ audio: true, video: false })
|
||||
.then(stream => {
|
||||
state.isStreaming = true;
|
||||
elements.streamButton.classList.add('recording');
|
||||
elements.streamButton.innerHTML = '<i class="fas fa-stop"></i> Stop Recording';
|
||||
|
||||
// Initialize Web Audio API
|
||||
state.audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });
|
||||
state.microphone = state.audioContext.createMediaStreamSource(stream);
|
||||
state.analyser = state.audioContext.createAnalyser();
|
||||
state.analyser.fftSize = 1024;
|
||||
|
||||
state.microphone.connect(state.analyser);
|
||||
|
||||
// Create processor node for audio data
|
||||
const processorNode = state.audioContext.createScriptProcessor(4096, 1, 1);
|
||||
processorNode.onaudioprocess = handleAudioProcess;
|
||||
state.analyser.connect(processorNode);
|
||||
processorNode.connect(state.audioContext.destination);
|
||||
state.streamProcessor = processorNode;
|
||||
|
||||
state.silenceTimer = null;
|
||||
state.energyWindow = [];
|
||||
state.isSpeaking = false;
|
||||
|
||||
// Notify server
|
||||
state.socket.emit('start_stream');
|
||||
addSystemMessage('Recording started. Speak now...');
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error accessing microphone:', error);
|
||||
addSystemMessage('Could not access microphone. Please check permissions.');
|
||||
});
|
||||
}
|
||||
|
||||
// Stop streaming audio
|
||||
function stopStreaming(notifyServer = true) {
|
||||
if (state.isStreaming) {
|
||||
state.isStreaming = false;
|
||||
elements.streamButton.classList.remove('recording');
|
||||
elements.streamButton.innerHTML = '<i class="fas fa-microphone"></i> Start Conversation';
|
||||
|
||||
// Clean up audio resources
|
||||
if (state.streamProcessor) {
|
||||
state.streamProcessor.disconnect();
|
||||
state.streamProcessor = null;
|
||||
}
|
||||
|
||||
if (state.analyser) {
|
||||
state.analyser.disconnect();
|
||||
state.analyser = null;
|
||||
}
|
||||
|
||||
if (state.microphone) {
|
||||
state.microphone.disconnect();
|
||||
state.microphone = null;
|
||||
}
|
||||
|
||||
if (state.audioContext) {
|
||||
state.audioContext.close();
|
||||
state.audioContext = null;
|
||||
}
|
||||
|
||||
// Clear any pending silence timer
|
||||
if (state.silenceTimer) {
|
||||
clearTimeout(state.silenceTimer);
|
||||
state.silenceTimer = null;
|
||||
}
|
||||
|
||||
// Notify server if needed
|
||||
if (notifyServer && state.socket && state.socket.connected) {
|
||||
state.socket.emit('stop_stream');
|
||||
}
|
||||
|
||||
addSystemMessage('Recording stopped.');
|
||||
}
|
||||
}
|
||||
|
||||
// Handle audio processing
|
||||
function handleAudioProcess(event) {
|
||||
if (!state.isStreaming) return;
|
||||
|
||||
const inputData = event.inputBuffer.getChannelData(0);
|
||||
const energy = calculateAudioEnergy(inputData);
|
||||
updateEnergyWindow(energy);
|
||||
|
||||
const averageEnergy = calculateAverageEnergy();
|
||||
const isSilent = averageEnergy < state.silenceThreshold;
|
||||
|
||||
handleSpeechState(isSilent);
|
||||
}
|
||||
|
||||
// Calculate audio energy (volume)
|
||||
function calculateAudioEnergy(buffer) {
|
||||
let sum = 0;
|
||||
for (let i = 0; i < buffer.length; i++) {
|
||||
sum += buffer[i] * buffer[i];
|
||||
}
|
||||
return Math.sqrt(sum / buffer.length);
|
||||
}
|
||||
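// calculateAudioEnergy returns the RMS level of the buffer (square root of the
// mean squared sample). handleSpeechState averages it over the last
// ENERGY_WINDOW_SIZE frames and treats the input as silence once that average
// stays below silenceThreshold for CLIENT_SILENCE_DURATION_MS.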
|
||||
// Update energy window for averaging
|
||||
function updateEnergyWindow(energy) {
|
||||
state.energyWindow.push(energy);
|
||||
if (state.energyWindow.length > ENERGY_WINDOW_SIZE) {
|
||||
state.energyWindow.shift();
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate average energy from window
|
||||
function calculateAverageEnergy() {
|
||||
if (state.energyWindow.length === 0) return 0;
|
||||
|
||||
const sum = state.energyWindow.reduce((acc, val) => acc + val, 0);
|
||||
return sum / state.energyWindow.length;
|
||||
}
|
||||
|
||||
// Handle speech/silence state transitions
|
||||
function handleSpeechState(isSilent) {
|
||||
if (state.isSpeaking) {
|
||||
if (isSilent) {
|
||||
// User was speaking but now is silent
|
||||
if (!state.silenceTimer) {
|
||||
state.silenceTimer = setTimeout(() => {
|
||||
// Silence lasted long enough, consider speech done
|
||||
if (state.isSpeaking) {
|
||||
state.isSpeaking = false;
|
||||
|
||||
// Get the current audio data and send it
|
||||
const audioBuffer = new Float32Array(state.audioContext.sampleRate * 5); // 5 seconds max
|
||||
state.analyser.getFloatTimeDomainData(audioBuffer);
|
||||
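// Note: getFloatTimeDomainData only copies the analyser's most recent fftSize
// (1024) samples into this buffer; the remainder of the 5-second Float32Array
// stays zero, so only about 64 ms of audio actually ends up in the WAV sent below.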
|
||||
// Create WAV blob
|
||||
const wavBlob = createWavBlob(audioBuffer, state.audioContext.sampleRate);
|
||||
|
||||
// Convert to base64
|
||||
const reader = new FileReader();
|
||||
reader.onloadend = function() {
|
||||
sendAudioChunk(reader.result, state.currentSpeaker);
|
||||
};
|
||||
reader.readAsDataURL(wavBlob);
|
||||
|
||||
addSystemMessage('Processing your message...');
|
||||
}
|
||||
}, CLIENT_SILENCE_DURATION_MS);
|
||||
}
|
||||
} else {
|
||||
// User is still speaking, reset silence timer
|
||||
if (state.silenceTimer) {
|
||||
clearTimeout(state.silenceTimer);
|
||||
state.silenceTimer = null;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!isSilent) {
|
||||
// User started speaking
|
||||
state.isSpeaking = true;
|
||||
if (state.silenceTimer) {
|
||||
clearTimeout(state.silenceTimer);
|
||||
state.silenceTimer = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send audio chunk to server
|
||||
function sendAudioChunk(audioData, speaker) {
|
||||
if (state.socket && state.socket.connected) {
|
||||
state.socket.emit('audio_chunk', {
|
||||
audio: audioData,
|
||||
speaker: speaker
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Create WAV blob from audio data
|
||||
function createWavBlob(audioData, sampleRate) {
|
||||
const numChannels = 1;
|
||||
const bitsPerSample = 16;
|
||||
const bytesPerSample = bitsPerSample / 8;
|
||||
|
||||
// Create buffer for WAV file
|
||||
const buffer = new ArrayBuffer(44 + audioData.length * bytesPerSample);
|
||||
const view = new DataView(buffer);
|
||||
|
||||
// Write WAV header
|
||||
// "RIFF" chunk descriptor
|
||||
writeString(view, 0, 'RIFF');
|
||||
view.setUint32(4, 36 + audioData.length * bytesPerSample, true);
|
||||
writeString(view, 8, 'WAVE');
|
||||
|
||||
// "fmt " sub-chunk
|
||||
writeString(view, 12, 'fmt ');
|
||||
view.setUint32(16, 16, true); // subchunk1size
|
||||
view.setUint16(20, 1, true); // audio format (PCM)
|
||||
view.setUint16(22, numChannels, true);
|
||||
view.setUint32(24, sampleRate, true);
|
||||
view.setUint32(28, sampleRate * numChannels * bytesPerSample, true); // byte rate
|
||||
view.setUint16(32, numChannels * bytesPerSample, true); // block align
|
||||
view.setUint16(34, bitsPerSample, true);
|
||||
|
||||
// "data" sub-chunk
|
||||
writeString(view, 36, 'data');
|
||||
view.setUint32(40, audioData.length * bytesPerSample, true);
|
||||
|
||||
// Write audio data
|
||||
const audioDataStart = 44;
|
||||
for (let i = 0; i < audioData.length; i++) {
|
||||
const sample = Math.max(-1, Math.min(1, audioData[i]));
|
||||
const value = sample < 0 ? sample * 0x8000 : sample * 0x7FFF;
|
||||
view.setInt16(audioDataStart + i * bytesPerSample, value, true);
|
||||
}
|
||||
|
||||
return new Blob([buffer], { type: 'audio/wav' });
|
||||
}
|
||||
|
||||
// Helper function to write strings to DataView
|
||||
function writeString(view, offset, string) {
|
||||
for (let i = 0; i < string.length; i++) {
|
||||
view.setUint8(offset + i, string.charCodeAt(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Clear conversation history
|
||||
function clearConversation() {
|
||||
elements.conversation.innerHTML = '';
|
||||
if (state.socket && state.socket.connected) {
|
||||
state.socket.emit('clear_context');
|
||||
}
|
||||
addSystemMessage('Conversation cleared.');
|
||||
}
|
||||
|
||||
// Handle transcription response from server
|
||||
function handleTranscription(data) {
|
||||
const speaker = data.speaker === 0 ? 'user' : 'ai';
|
||||
addMessage(data.text, speaker);
|
||||
}
|
||||
|
||||
// Handle context update from server
|
||||
function handleContextUpdate(data) {
|
||||
if (data.status === 'cleared') {
|
||||
elements.conversation.innerHTML = '';
|
||||
addSystemMessage('Conversation context cleared.');
|
||||
}
|
||||
}
|
||||
|
||||
// Handle streaming status updates from server
|
||||
function handleStreamingStatus(data) {
|
||||
if (data.status === 'active') {
|
||||
console.log('Server acknowledged streaming is active');
|
||||
} else if (data.status === 'inactive') {
|
||||
console.log('Server acknowledged streaming is inactive');
|
||||
}
|
||||
}
|
||||
|
||||
// Handle processing status updates
|
||||
function handleProcessingStatus(data) {
|
||||
switch (data.status) {
|
||||
case 'transcribing':
|
||||
addSystemMessage('Transcribing your message...');
|
||||
break;
|
||||
case 'generating':
|
||||
addSystemMessage('Generating response...');
|
||||
break;
|
||||
case 'synthesizing':
|
||||
addSystemMessage('Synthesizing voice...');
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the start of an audio streaming response
|
||||
function handleAudioResponseStart(data) {
|
||||
// Prepare for receiving chunked audio
|
||||
console.log(`Expecting ${data.total_chunks} audio chunks`);
|
||||
}
|
||||
|
||||
// Handle an incoming audio chunk
|
||||
function handleAudioResponseChunk(data) {
|
||||
// Create audio element for the response
|
||||
const audioElement = document.createElement('audio');
|
||||
if (elements.autoPlayResponses.checked) {
|
||||
audioElement.autoplay = true;
|
||||
}
|
||||
audioElement.controls = true;
|
||||
audioElement.className = 'audio-player';
|
||||
audioElement.src = data.chunk;
|
||||
|
||||
// Add to the most recent AI message if it exists
|
||||
const messages = elements.conversation.querySelectorAll('.message.ai');
|
||||
if (messages.length > 0) {
|
||||
const lastAiMessage = messages[messages.length - 1];
|
||||
lastAiMessage.appendChild(audioElement);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle completion of audio streaming
|
||||
function handleAudioResponseComplete(data) {
|
||||
// Update the AI message with the full text
|
||||
addMessage(data.text, 'ai');
|
||||
}
|
||||
|
||||
// Add a message to the conversation
|
||||
function addMessage(text, sender) {
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = `message ${sender}`;
|
||||
messageDiv.textContent = text;
|
||||
|
||||
const timeSpan = document.createElement('span');
|
||||
timeSpan.className = 'message-time';
|
||||
const now = new Date();
|
||||
timeSpan.textContent = `${now.getHours().toString().padStart(2, '0')}:${now.getMinutes().toString().padStart(2, '0')}`;
|
||||
messageDiv.appendChild(timeSpan);
|
||||
|
||||
elements.conversation.appendChild(messageDiv);
|
||||
elements.conversation.scrollTop = elements.conversation.scrollHeight;
|
||||
}
|
||||
|
||||
// Add a system message to the conversation
|
||||
function addSystemMessage(message) {
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = 'message system';
|
||||
messageDiv.textContent = message;
|
||||
elements.conversation.appendChild(messageDiv);
|
||||
elements.conversation.scrollTop = elements.conversation.scrollHeight;
|
||||
}
|
||||
|
||||
// Initialize the application when DOM is fully loaded
|
||||
document.addEventListener('DOMContentLoaded', initializeApp);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
203  Backend/models.py  (new file)
@@ -0,0 +1,203 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchtune
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torchtune.models import llama3_2
|
||||
|
||||
|
||||
def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
|
||||
return llama3_2.llama3_2(
|
||||
vocab_size=128_256,
|
||||
num_layers=16,
|
||||
num_heads=32,
|
||||
num_kv_heads=8,
|
||||
embed_dim=2048,
|
||||
max_seq_len=2048,
|
||||
intermediate_dim=8192,
|
||||
attn_dropout=0.0,
|
||||
norm_eps=1e-5,
|
||||
rope_base=500_000,
|
||||
scale_factor=32,
|
||||
)
|
||||
|
||||
|
||||
def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
|
||||
return llama3_2.llama3_2(
|
||||
vocab_size=128_256,
|
||||
num_layers=4,
|
||||
num_heads=8,
|
||||
num_kv_heads=2,
|
||||
embed_dim=1024,
|
||||
max_seq_len=2048,
|
||||
intermediate_dim=8192,
|
||||
attn_dropout=0.0,
|
||||
norm_eps=1e-5,
|
||||
rope_base=500_000,
|
||||
scale_factor=32,
|
||||
)
|
||||
|
||||
|
||||
FLAVORS = {
|
||||
"llama-1B": llama3_2_1B,
|
||||
"llama-100M": llama3_2_100M,
|
||||
}
|
||||
|
||||
|
||||
def _prepare_transformer(model):
|
||||
embed_dim = model.tok_embeddings.embedding_dim
|
||||
model.tok_embeddings = nn.Identity()
|
||||
model.output = nn.Identity()
|
||||
return model, embed_dim
|
||||
|
||||
|
||||
def _create_causal_mask(seq_len: int, device: torch.device):
|
||||
return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
|
||||
|
||||
|
||||
def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
mask: (max_seq_len, max_seq_len)
|
||||
input_pos: (batch_size, seq_len)
|
||||
|
||||
Returns:
|
||||
(batch_size, seq_len, max_seq_len)
|
||||
"""
|
||||
r = mask[input_pos, :]
|
||||
return r
|
||||
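# Shape intuition for _index_causal_mask (illustrative comment): indexing the
# (max_seq_len, max_seq_len) boolean mask with an input_pos of shape
# (batch_size, seq_len) picks one mask row per position, e.g.
#   mask = _create_causal_mask(4, torch.device("cpu"))          # (4, 4)
#   _index_causal_mask(mask, torch.tensor([[0, 1, 2]])).shape   # (1, 3, 4)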
|
||||
|
||||
def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
|
||||
q = torch.empty_like(probs).exponential_(1)
|
||||
return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
|
||||
|
||||
def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
|
||||
logits = logits / temperature
|
||||
|
||||
filter_value: float = -float("Inf")
|
||||
indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
|
||||
scores_processed = logits.masked_fill(indices_to_remove, filter_value)
|
||||
scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
|
||||
probs = torch.nn.functional.softmax(scores_processed, dim=-1)
|
||||
|
||||
sample_token = _multinomial_sample_one_no_sync(probs)
|
||||
return sample_token
|
||||
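# Why the trick in _multinomial_sample_one_no_sync works (quick self-check):
# dividing the probabilities by i.i.d. Exponential(1) noise and taking the
# argmax picks index i with probability probs[i] -- the same distribution as
# torch.multinomial -- while avoiding its device synchronization, e.g.
#   probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
#   q = torch.empty(100_000, 4).exponential_(1)
#   counts = torch.bincount(torch.argmax(probs / q, dim=-1), minlength=4)
#   counts / counts.sum()   # ~ tensor([0.1, 0.2, 0.3, 0.4])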
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
backbone_flavor: str
|
||||
decoder_flavor: str
|
||||
text_vocab_size: int
|
||||
audio_vocab_size: int
|
||||
audio_num_codebooks: int
|
||||
|
||||
|
||||
class Model(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
repo_url="https://github.com/SesameAILabs/csm",
|
||||
pipeline_tag="text-to-speech",
|
||||
license="apache-2.0",
|
||||
):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
|
||||
self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
|
||||
|
||||
self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
|
||||
self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
|
||||
|
||||
self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
|
||||
self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
|
||||
self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
|
||||
|
||||
def setup_caches(self, max_batch_size: int) -> torch.Tensor:
|
||||
"""Setup KV caches and return a causal mask."""
|
||||
dtype = next(self.parameters()).dtype
|
||||
device = next(self.parameters()).device
|
||||
|
||||
with device:
|
||||
self.backbone.setup_caches(max_batch_size, dtype)
|
||||
self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
|
||||
|
||||
self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
|
||||
self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
|
||||
|
||||
def generate_frame(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
tokens_mask: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
temperature: float,
|
||||
topk: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
tokens: (batch_size, seq_len, audio_num_codebooks+1)
|
||||
tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
|
||||
input_pos: (batch_size, seq_len) positions for each token
|
||||
temperature: sampling temperature used by sample_topk
topk: number of highest-probability logits kept when sampling
|
||||
|
||||
Returns:
|
||||
(batch_size, audio_num_codebooks) sampled tokens
|
||||
"""
|
||||
dtype = next(self.parameters()).dtype
|
||||
b, s, _ = tokens.size()
|
||||
|
||||
assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
|
||||
curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
|
||||
embeds = self._embed_tokens(tokens)
|
||||
masked_embeds = embeds * tokens_mask.unsqueeze(-1)
|
||||
h = masked_embeds.sum(dim=2)
|
||||
h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
|
||||
|
||||
last_h = h[:, -1, :]
|
||||
c0_logits = self.codebook0_head(last_h)
|
||||
c0_sample = sample_topk(c0_logits, topk, temperature)
|
||||
c0_embed = self._embed_audio(0, c0_sample)
|
||||
|
||||
curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
|
||||
curr_sample = c0_sample.clone()
|
||||
curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
|
||||
|
||||
# Decoder caches must be reset every frame.
|
||||
self.decoder.reset_caches()
|
||||
for i in range(1, self.config.audio_num_codebooks):
|
||||
curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
|
||||
decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
|
||||
dtype=dtype
|
||||
)
|
||||
ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
|
||||
ci_sample = sample_topk(ci_logits, topk, temperature)
|
||||
ci_embed = self._embed_audio(i, ci_sample)
|
||||
|
||||
curr_h = ci_embed
|
||||
curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
|
||||
curr_pos = curr_pos[:, -1:] + 1
|
||||
|
||||
return curr_sample
|
||||
|
||||
def reset_caches(self):
|
||||
self.backbone.reset_caches()
|
||||
self.decoder.reset_caches()
|
||||
|
||||
def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
|
||||
return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
|
||||
|
||||
def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
|
||||
|
||||
audio_tokens = tokens[:, :, :-1] + (
|
||||
self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
|
||||
)
|
||||
audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
|
||||
tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
|
||||
)
|
||||
|
||||
return torch.cat([audio_embeds, text_embeds], dim=-2)
|
||||
@@ -1,16 +1,9 @@
-Flask==2.2.2
-Flask-SocketIO==5.3.2
-torch>=2.0.0
-torchaudio>=2.0.0
-transformers>=4.30.0
-huggingface-hub>=0.14.0
-python-dotenv==0.19.2
-numpy>=1.21.6
-scipy>=1.7.3
-soundfile==0.10.3.post1
-requests==2.28.1
-pydub==0.25.1
-python-socketio==5.7.2
-eventlet==0.33.3
-whisper>=20230314
-ffmpeg-python>=0.2.0
+torch==2.4.0
+torchaudio==2.4.0
+tokenizers==0.21.0
+transformers==4.49.0
+huggingface_hub==0.28.1
+moshi==0.2.2
+torchtune==0.4.0
+torchao==0.9.0
+silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
117  Backend/run_csm.py  (new file)
@@ -0,0 +1,117 @@
|
||||
import os
|
||||
import torch
|
||||
import torchaudio
|
||||
from huggingface_hub import hf_hub_download
|
||||
from generator import load_csm_1b, Segment
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Disable Triton compilation
|
||||
os.environ["NO_TORCH_COMPILE"] = "1"
|
||||
|
||||
# Default prompts are available at https://hf.co/sesame/csm-1b
|
||||
prompt_filepath_conversational_a = hf_hub_download(
|
||||
repo_id="sesame/csm-1b",
|
||||
filename="prompts/conversational_a.wav"
|
||||
)
|
||||
prompt_filepath_conversational_b = hf_hub_download(
|
||||
repo_id="sesame/csm-1b",
|
||||
filename="prompts/conversational_b.wav"
|
||||
)
|
||||
|
||||
SPEAKER_PROMPTS = {
|
||||
"conversational_a": {
|
||||
"text": (
|
||||
"like revising for an exam I'd have to try and like keep up the momentum because I'd "
|
||||
"start really early I'd be like okay I'm gonna start revising now and then like "
|
||||
"you're revising for ages and then I just like start losing steam I didn't do that "
|
||||
"for the exam we had recently to be fair that was a more of a last minute scenario "
|
||||
"but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
|
||||
"sort of start the day with this not like a panic but like a"
|
||||
),
|
||||
"audio": prompt_filepath_conversational_a
|
||||
},
|
||||
"conversational_b": {
|
||||
"text": (
|
||||
"like a super Mario level. Like it's very like high detail. And like, once you get "
|
||||
"into the park, it just like, everything looks like a computer game and they have all "
|
||||
"these, like, you know, if, if there's like a, you know, like in a Mario game, they "
|
||||
"will have like a question block. And if you like, you know, punch it, a coin will "
|
||||
"come out. So like everyone, when they come into the park, they get like this little "
|
||||
"bracelet and then you can go punching question blocks around."
|
||||
),
|
||||
"audio": prompt_filepath_conversational_b
|
||||
}
|
||||
}
|
||||
|
||||
def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
|
||||
audio_tensor, sample_rate = torchaudio.load(audio_path)
|
||||
audio_tensor = audio_tensor.squeeze(0)
|
||||
# Resample is lazy so we can always call it
|
||||
audio_tensor = torchaudio.functional.resample(
|
||||
audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
|
||||
)
|
||||
return audio_tensor
|
||||
|
||||
def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
|
||||
audio_tensor = load_prompt_audio(audio_path, sample_rate)
|
||||
return Segment(text=text, speaker=speaker, audio=audio_tensor)
|
||||
|
||||
def main():
|
||||
# Select the best available device, skipping MPS due to float64 limitations
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Load model
|
||||
generator = load_csm_1b(device)
|
||||
|
||||
# Prepare prompts
|
||||
prompt_a = prepare_prompt(
|
||||
SPEAKER_PROMPTS["conversational_a"]["text"],
|
||||
0,
|
||||
SPEAKER_PROMPTS["conversational_a"]["audio"],
|
||||
generator.sample_rate
|
||||
)
|
||||
|
||||
prompt_b = prepare_prompt(
|
||||
SPEAKER_PROMPTS["conversational_b"]["text"],
|
||||
1,
|
||||
SPEAKER_PROMPTS["conversational_b"]["audio"],
|
||||
generator.sample_rate
|
||||
)
|
||||
|
||||
# Generate conversation
|
||||
conversation = [
|
||||
{"text": "Hey how are you doing?", "speaker_id": 0},
|
||||
{"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
|
||||
{"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
|
||||
{"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
|
||||
]
|
||||
|
||||
# Generate each utterance
|
||||
generated_segments = []
|
||||
prompt_segments = [prompt_a, prompt_b]
|
||||
|
||||
for utterance in conversation:
|
||||
print(f"Generating: {utterance['text']}")
|
||||
audio_tensor = generator.generate(
|
||||
text=utterance['text'],
|
||||
speaker=utterance['speaker_id'],
|
||||
context=prompt_segments + generated_segments,
|
||||
max_audio_length_ms=10_000,
|
||||
)
|
||||
generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))
|
||||
|
||||
# Concatenate all generations
|
||||
all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
|
||||
torchaudio.save(
|
||||
"full_conversation.wav",
|
||||
all_audio.unsqueeze(0).cpu(),
|
||||
generator.sample_rate
|
||||
)
|
||||
print("Successfully generated full_conversation.wav")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,53 +1,426 @@
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
import eventlet
|
||||
import io
|
||||
import base64
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
import threading
|
||||
import queue
|
||||
import tempfile
|
||||
from io import BytesIO
|
||||
from flask import Flask, render_template, request, jsonify
|
||||
from flask_socketio import SocketIO, emit
|
||||
import whisper
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from src.models.conversation import Segment
|
||||
from src.services.tts_service import load_csm_1b
|
||||
from src.llm.generator import generate_llm_response
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from src.audio.streaming import AudioStreamer
|
||||
from src.services.transcription_service import TranscriptionService
|
||||
from src.services.tts_service import TextToSpeechService
|
||||
import numpy as np
|
||||
from flask import Flask, request, jsonify, send_from_directory
|
||||
from flask_socketio import SocketIO, emit
|
||||
from flask_cors import CORS
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
|
||||
from generator import load_csm_1b, Segment
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = Flask(__name__, static_folder='static', template_folder='templates')
|
||||
app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'your-secret-key')
|
||||
socketio = SocketIO(app)
|
||||
# Initialize Flask app
|
||||
app = Flask(__name__, static_folder='.')
|
||||
CORS(app)
|
||||
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)
|
||||
|
||||
# Initialize services
|
||||
transcription_service = TranscriptionService()
|
||||
tts_service = TextToSpeechService()
|
||||
audio_streamer = AudioStreamer()
|
||||
# Configure device
|
||||
if torch.cuda.is_available():
|
||||
DEVICE = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
DEVICE = "mps"
|
||||
else:
|
||||
DEVICE = "cpu"
|
||||
|
||||
@socketio.on('audio_input')
|
||||
def handle_audio_input(data):
|
||||
audio_chunk = data['audio']
|
||||
speaker_id = data['speaker']
|
||||
logger.info(f"Using device: {DEVICE}")
|
||||
|
||||
# Global variables
|
||||
active_conversations = {}
|
||||
user_queues = {}
|
||||
processing_threads = {}
|
||||
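# Per-session bookkeeping: each connected client gets a Conversation (its Segment
# history), a queue.Queue that handle_audio_chunk pushes incoming audio onto, and
# a daemon thread running process_audio_queue that drains it, so the
# transcribe -> generate -> synthesize pipeline never blocks the Socket.IO handlers.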
|
||||
# Load models
|
||||
@dataclass
|
||||
class AppModels:
|
||||
generator = None
|
||||
tokenizer = None
|
||||
llm = None
|
||||
asr = None
|
||||
|
||||
models = AppModels()
|
||||
|
||||
def load_models():
|
||||
"""Load all required models"""
|
||||
global models
|
||||
|
||||
# Process audio and convert to text
|
||||
text = transcription_service.transcribe(audio_chunk)
|
||||
logging.info(f"Transcribed text: {text}")
|
||||
|
||||
# Generate response using Llama 3.2
|
||||
response_text = tts_service.generate_response(text, speaker_id)
|
||||
logging.info(f"Generated response: {response_text}")
|
||||
|
||||
# Convert response text to audio
|
||||
audio_response = tts_service.text_to_speech(response_text, speaker_id)
|
||||
logger.info("Loading CSM 1B model...")
|
||||
models.generator = load_csm_1b(device=DEVICE)
|
||||
|
||||
# Stream audio response back to client
|
||||
socketio.emit('audio_response', {'audio': audio_response})
|
||||
logger.info("Loading ASR pipeline...")
|
||||
models.asr = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model="openai/whisper-small",
|
||||
device=DEVICE
|
||||
)
|
||||
|
||||
logger.info("Loading Llama 3.2 model...")
|
||||
models.llm = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.2-1B",
|
||||
device_map=DEVICE,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
|
||||
|
||||
# Load models in a background thread
|
||||
threading.Thread(target=load_models, daemon=True).start()
|
||||
|
||||
# Conversation data structure
|
||||
class Conversation:
|
||||
def __init__(self, session_id):
|
||||
self.session_id = session_id
|
||||
self.segments: List[Segment] = []
|
||||
self.current_speaker = 0
|
||||
self.last_activity = time.time()
|
||||
self.is_processing = False
|
||||
|
||||
def add_segment(self, text, speaker, audio):
|
||||
segment = Segment(text=text, speaker=speaker, audio=audio)
|
||||
self.segments.append(segment)
|
||||
self.last_activity = time.time()
|
||||
return segment
|
||||
|
||||
def get_context(self, max_segments=10):
|
||||
"""Return the most recent segments for context"""
|
||||
return self.segments[-max_segments:] if self.segments else []
|
||||
|
||||
# Routes
|
||||
@app.route('/')
|
||||
def index():
|
||||
return send_from_directory('.', 'index.html')
|
||||
|
||||
@app.route('/api/health')
|
||||
def health_check():
|
||||
return jsonify({
|
||||
"status": "ok",
|
||||
"cuda_available": torch.cuda.is_available(),
|
||||
"models_loaded": models.generator is not None and models.llm is not None
|
||||
})
|
||||
|
||||
# Socket event handlers
|
||||
@socketio.on('connect')
|
||||
def handle_connect():
|
||||
session_id = request.sid
|
||||
logger.info(f"Client connected: {session_id}")
|
||||
|
||||
# Initialize conversation data
|
||||
if session_id not in active_conversations:
|
||||
active_conversations[session_id] = Conversation(session_id)
|
||||
user_queues[session_id] = queue.Queue()
|
||||
processing_threads[session_id] = threading.Thread(
|
||||
target=process_audio_queue,
|
||||
args=(session_id, user_queues[session_id]),
|
||||
daemon=True
|
||||
)
|
||||
processing_threads[session_id].start()
|
||||
|
||||
emit('connection_status', {'status': 'connected'})
|
||||
|
||||
@socketio.on('disconnect')
|
||||
def handle_disconnect():
|
||||
session_id = request.sid
|
||||
logger.info(f"Client disconnected: {session_id}")
|
||||
|
||||
# Cleanup
|
||||
if session_id in active_conversations:
|
||||
# Mark for deletion rather than immediately removing
|
||||
# as the processing thread might still be accessing it
|
||||
active_conversations[session_id].is_processing = False
|
||||
user_queues[session_id].put(None) # Signal thread to terminate
|
||||
|
||||
@socketio.on('start_stream')
|
||||
def handle_start_stream():
|
||||
session_id = request.sid
|
||||
logger.info(f"Starting stream for client: {session_id}")
|
||||
emit('streaming_status', {'status': 'active'})
|
||||
|
||||
@socketio.on('stop_stream')
|
||||
def handle_stop_stream():
|
||||
session_id = request.sid
|
||||
logger.info(f"Stopping stream for client: {session_id}")
|
||||
emit('streaming_status', {'status': 'inactive'})
|
||||
|
||||
@socketio.on('clear_context')
|
||||
def handle_clear_context():
|
||||
session_id = request.sid
|
||||
logger.info(f"Clearing context for client: {session_id}")
|
||||
|
||||
if session_id in active_conversations:
|
||||
active_conversations[session_id].segments = []
|
||||
emit('context_updated', {'status': 'cleared'})
|
||||
|
||||
@socketio.on('audio_chunk')
|
||||
def handle_audio_chunk(data):
|
||||
session_id = request.sid
|
||||
audio_data = data.get('audio', '')
|
||||
speaker_id = int(data.get('speaker', 0))
|
||||
|
||||
if not audio_data or session_id not in active_conversations:
|
||||
return
|
||||
|
||||
# Update the current speaker
|
||||
active_conversations[session_id].current_speaker = speaker_id
|
||||
|
||||
# Queue audio for processing
|
||||
user_queues[session_id].put({
|
||||
'audio': audio_data,
|
||||
'speaker': speaker_id
|
||||
})
|
||||
|
||||
emit('processing_status', {'status': 'transcribing'})
|
||||
|
||||
def process_audio_queue(session_id, q):
|
||||
"""Background thread to process audio chunks for a session"""
|
||||
logger.info(f"Started processing thread for session: {session_id}")
|
||||
|
||||
try:
|
||||
while session_id in active_conversations:
|
||||
try:
|
||||
# Get the next audio chunk with a timeout
|
||||
data = q.get(timeout=120)
|
||||
if data is None: # Termination signal
|
||||
break
|
||||
|
||||
# Process the audio and generate a response
|
||||
process_audio_and_respond(session_id, data)
|
||||
|
||||
except queue.Empty:
|
||||
# Timeout, check if session is still valid
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio for {session_id}: {str(e)}")
|
||||
socketio.emit('error', {'message': str(e)}, room=session_id)
|
||||
finally:
|
||||
logger.info(f"Ending processing thread for session: {session_id}")
|
||||
# Clean up when thread is done
|
||||
with app.app_context():
|
||||
if session_id in active_conversations:
|
||||
del active_conversations[session_id]
|
||||
if session_id in user_queues:
|
||||
del user_queues[session_id]
|
||||
|
||||
def process_audio_and_respond(session_id, data):
|
||||
"""Process audio data and generate a response"""
|
||||
if models.generator is None or models.asr is None or models.llm is None:
|
||||
socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
|
||||
return
|
||||
|
||||
conversation = active_conversations[session_id]
|
||||
|
||||
try:
|
||||
# Set processing flag
|
||||
conversation.is_processing = True
|
||||
|
||||
# Process base64 audio data
|
||||
audio_data = data['audio']
|
||||
speaker_id = data['speaker']
|
||||
|
||||
# Convert from base64 to WAV
|
||||
audio_bytes = base64.b64decode(audio_data.split(',')[1])
|
||||
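# audio_data is the data URL produced by the client's FileReader.readAsDataURL
# ("data:audio/wav;base64,...."); split(',')[1] keeps only the base64 payload.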
|
||||
# Save to temporary file for processing
|
||||
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
|
||||
temp_file.write(audio_bytes)
|
||||
temp_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Load audio file
|
||||
waveform, sample_rate = torchaudio.load(temp_path)
|
||||
|
||||
# Normalize to mono if needed
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||
|
||||
# Resample to the CSM sample rate if needed
|
||||
if sample_rate != models.generator.sample_rate:
|
||||
waveform = torchaudio.functional.resample(
|
||||
waveform,
|
||||
orig_freq=sample_rate,
|
||||
new_freq=models.generator.sample_rate
|
||||
)
|
||||
|
||||
# Transcribe audio
|
||||
socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
|
||||
|
||||
# Use the ASR pipeline to transcribe
|
||||
transcription_result = models.asr(
|
||||
{"array": waveform.squeeze().cpu().numpy(), "sampling_rate": models.generator.sample_rate},
|
||||
return_timestamps=False
|
||||
)
|
||||
user_text = transcription_result['text'].strip()
|
||||
|
||||
# If no text was recognized, don't process further
|
||||
if not user_text:
|
||||
socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
|
||||
return
|
||||
|
||||
# Add the user's message to conversation history
|
||||
user_segment = conversation.add_segment(
|
||||
text=user_text,
|
||||
speaker=speaker_id,
|
||||
audio=waveform.squeeze()
|
||||
)
|
||||
|
||||
# Send transcription to client
|
||||
socketio.emit('transcription', {
|
||||
'text': user_text,
|
||||
'speaker': speaker_id
|
||||
}, room=session_id)
|
||||
|
||||
# Generate AI response using Llama
|
||||
socketio.emit('processing_status', {'status': 'generating'}, room=session_id)
|
||||
|
||||
# Create prompt from conversation history
|
||||
conversation_history = ""
|
||||
for segment in conversation.segments[-5:]: # Last 5 segments for context
|
||||
role = "User" if segment.speaker == 0 else "Assistant"
|
||||
conversation_history += f"{role}: {segment.text}\n"
|
||||
|
||||
# Add final prompt
|
||||
prompt = f"{conversation_history}Assistant: "
|
||||
|
||||
# Generate response with Llama
|
||||
input_ids = models.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = models.llm.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
do_sample=True,
|
||||
pad_token_id=models.tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# Decode the response
|
||||
response_text = models.tokenizer.decode(
|
||||
generated_ids[0][input_ids.shape[1]:],
|
||||
skip_special_tokens=True
|
||||
).strip()
|
||||
|
||||
# Synthesize speech
|
||||
socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)
|
||||
|
||||
# Generate audio with CSM
|
||||
ai_speaker_id = 1 # Use speaker 1 for AI responses
|
||||
|
||||
# Start sending the audio response
|
||||
socketio.emit('audio_response_start', {
|
||||
'text': response_text,
|
||||
'total_chunks': 1,
|
||||
'chunk_index': 0
|
||||
}, room=session_id)
|
||||
|
||||
# Generate audio
|
||||
audio_tensor = models.generator.generate(
|
||||
text=response_text,
|
||||
speaker=ai_speaker_id,
|
||||
context=conversation.get_context(),
|
||||
max_audio_length_ms=10_000,
|
||||
temperature=0.9
|
||||
)
|
||||
|
||||
# Add AI response to conversation history
|
||||
ai_segment = conversation.add_segment(
|
||||
text=response_text,
|
||||
speaker=ai_speaker_id,
|
||||
audio=audio_tensor
|
||||
)
|
||||
|
||||
# Convert audio to WAV format
|
||||
with io.BytesIO() as wav_io:
|
||||
torchaudio.save(
|
||||
wav_io,
|
||||
audio_tensor.unsqueeze(0).cpu(),
|
||||
models.generator.sample_rate,
|
||||
format="wav"
|
||||
)
|
||||
wav_io.seek(0)
|
||||
wav_data = wav_io.read()
|
||||
|
||||
# Convert WAV data to base64
|
||||
audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"
|
||||
|
||||
# Send audio chunk to client
|
||||
socketio.emit('audio_response_chunk', {
|
||||
'chunk': audio_base64,
|
||||
'chunk_index': 0,
|
||||
'total_chunks': 1,
|
||||
'is_last': True
|
||||
}, room=session_id)
|
||||
|
||||
# Signal completion
|
||||
socketio.emit('audio_response_complete', {
|
||||
'text': response_text
|
||||
}, room=session_id)
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if os.path.exists(temp_path):
|
||||
os.unlink(temp_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio: {str(e)}")
|
||||
socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
|
||||
finally:
|
||||
# Reset processing flag
|
||||
conversation.is_processing = False
|
||||
|
||||
# Error handler
|
||||
@socketio.on_error()
|
||||
def error_handler(e):
|
||||
logger.error(f"SocketIO error: {str(e)}")
|
||||
|
||||
# Periodic cleanup of inactive sessions
|
||||
def cleanup_inactive_sessions():
|
||||
"""Remove sessions that have been inactive for too long"""
|
||||
current_time = time.time()
|
||||
inactive_timeout = 3600 # 1 hour
|
||||
|
||||
for session_id in list(active_conversations.keys()):
|
||||
conversation = active_conversations[session_id]
|
||||
if (current_time - conversation.last_activity > inactive_timeout and
|
||||
not conversation.is_processing):
|
||||
|
||||
logger.info(f"Cleaning up inactive session: {session_id}")
|
||||
|
||||
# Signal processing thread to terminate
|
||||
if session_id in user_queues:
|
||||
user_queues[session_id].put(None)
|
||||
|
||||
# Remove from active conversations
|
||||
del active_conversations[session_id]
|
||||
|
||||
# Start cleanup thread
|
||||
def start_cleanup_thread():
|
||||
while True:
|
||||
try:
|
||||
cleanup_inactive_sessions()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup thread: {str(e)}")
|
||||
time.sleep(300) # Run every 5 minutes
|
||||
|
||||
cleanup_thread = threading.Thread(target=start_cleanup_thread, daemon=True)
|
||||
cleanup_thread.start()
|
||||
|
||||
# Start the server
|
||||
if __name__ == '__main__':
|
||||
socketio.run(app, host='0.0.0.0', port=5000)
|
||||
port = int(os.environ.get('PORT', 5000))
|
||||
logger.info(f"Starting server on port {port}")
|
||||
socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True)
|
||||
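For local testing of the flow above, a minimal Python Socket.IO client sketch: the server-emitted event names ('processing_status', 'transcription', 'audio_response_chunk', 'error') and the payload format come from the code above, while the 'audio_chunk' inbound event name, the file paths, and the localhost URL are assumptions based on the browser client in this commit and on what process_audio_and_respond expects.

import base64
import socketio  # pip install "python-socketio[client]"

sio = socketio.Client()

@sio.on('processing_status')
def on_status(data):
    print(f"Status: {data['status']}")

@sio.on('transcription')
def on_transcription(data):
    print(f"Heard (speaker {data['speaker']}): {data['text']}")

@sio.on('audio_response_chunk')
def on_audio_chunk(data):
    # 'chunk' is a data URL: "data:audio/wav;base64,<payload>"
    wav_bytes = base64.b64decode(data['chunk'].split(',')[1])
    with open('response.wav', 'wb') as f:
        f.write(wav_bytes)

@sio.on('error')
def on_error(data):
    print(f"Server error: {data['message']}")

sio.connect('http://localhost:5000')

# Send one utterance as a base64 data URL, matching audio_data.split(',')[1] above
with open('utterance.wav', 'rb') as f:  # hypothetical local recording
    audio_b64 = base64.b64encode(f.read()).decode('utf-8')
sio.emit('audio_chunk', {'audio': f"data:audio/wav;base64,{audio_b64}", 'speaker': 0})

sio.wait()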
Backend/setup.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from setuptools import setup, find_packages
import os

# Read requirements from requirements.txt
with open('requirements.txt') as f:
    requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]

setup(
    name='csm',
    version='0.1.0',
    packages=find_packages(),
    install_requires=requirements,
)
@@ -1,28 +0,0 @@
from scipy.io import wavfile
import numpy as np
import torch
import torchaudio

def load_audio(file_path):
    sample_rate, audio_data = wavfile.read(file_path)
    return sample_rate, audio_data

def normalize_audio(audio_data):
    audio_data = audio_data.astype(np.float32)
    max_val = np.max(np.abs(audio_data))
    if max_val > 0:
        audio_data /= max_val
    return audio_data

def reduce_noise(audio_data, noise_factor=0.1):
    # Note: this adds Gaussian noise to the signal rather than removing noise
    noise = np.random.randn(len(audio_data))
    noisy_audio = audio_data + noise_factor * noise
    return noisy_audio

def save_audio(file_path, sample_rate, audio_data):
    torchaudio.save(file_path, torch.tensor(audio_data).unsqueeze(0), sample_rate)

def process_audio(file_path, output_path):
    sample_rate, audio_data = load_audio(file_path)
    normalized_audio = normalize_audio(audio_data)
    denoised_audio = reduce_noise(normalized_audio)
    save_audio(output_path, sample_rate, denoised_audio)
@@ -1,35 +0,0 @@
from flask import Blueprint, request
from flask_socketio import SocketIO, emit
from src.audio.processor import process_audio
from src.services.transcription_service import TranscriptionService
from src.services.tts_service import TextToSpeechService

streaming_bp = Blueprint('streaming', __name__)
socketio = SocketIO()

transcription_service = TranscriptionService()
tts_service = TextToSpeechService()

@socketio.on('audio_stream')
def handle_audio_stream(data):
    audio_chunk = data['audio']
    speaker_id = data['speaker']

    # Process the audio chunk
    processed_audio = process_audio(audio_chunk)

    # Transcribe the audio to text
    transcription = transcription_service.transcribe(processed_audio)

    # Generate a response using the LLM
    response_text = generate_response(transcription, speaker_id)

    # Convert the response text back to audio
    response_audio = tts_service.convert_text_to_speech(response_text, speaker_id)

    # Emit the response audio back to the client
    emit('audio_response', {'audio': response_audio})

def generate_response(transcription, speaker_id):
    # Placeholder for the actual response generation logic
    return f"Response to: {transcription}"
@@ -1,14 +0,0 @@
from transformers import AutoTokenizer

def load_llama3_tokenizer():
    tokenizer_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return tokenizer

def tokenize_text(text: str, tokenizer) -> list:
    tokens = tokenizer.encode(text, return_tensors='pt')
    return tokens

def decode_tokens(tokens: list, tokenizer) -> str:
    text = tokenizer.decode(tokens, skip_special_tokens=True)
    return text
@@ -1,28 +0,0 @@
from dataclasses import dataclass
import torch
import torchaudio

@dataclass
class AudioModel:
    model: torch.nn.Module
    sample_rate: int

    def __post_init__(self):
        self.model.eval()

    def process_audio(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            processed_audio = self.model(audio_tensor)
        return processed_audio

    def resample_audio(self, audio_tensor: torch.Tensor, target_sample_rate: int) -> torch.Tensor:
        if self.sample_rate != target_sample_rate:
            resampled_audio = torchaudio.functional.resample(audio_tensor, orig_freq=self.sample_rate, new_freq=target_sample_rate)
            return resampled_audio
        return audio_tensor

    def save_model(self, path: str):
        torch.save(self.model.state_dict(), path)

    def load_model(self, path: str):
        self.model.load_state_dict(torch.load(path))
        self.model.eval()
@@ -1,51 +0,0 @@
from dataclasses import dataclass, field
from typing import List, Optional
import torch

@dataclass
class Segment:
    speaker: int
    text: str
    # (num_samples,), sample_rate = 24_000
    audio: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Ensure audio is a tensor if provided
        if self.audio is not None and not isinstance(self.audio, torch.Tensor):
            self.audio = torch.tensor(self.audio, dtype=torch.float32)

@dataclass
class Conversation:
    context: List[str] = field(default_factory=list)
    segments: List[Segment] = field(default_factory=list)
    current_speaker: Optional[int] = None

    def add_message(self, message: str, speaker: int):
        self.context.append(f"Speaker {speaker}: {message}")
        self.current_speaker = speaker

    def add_segment(self, segment: Segment):
        self.segments.append(segment)
        self.context.append(f"Speaker {segment.speaker}: {segment.text}")
        self.current_speaker = segment.speaker

    def get_context(self) -> List[str]:
        return self.context

    def get_segments(self) -> List[Segment]:
        return self.segments

    def clear_context(self):
        self.context.clear()
        self.segments.clear()
        self.current_speaker = None

    def get_last_message(self) -> Optional[str]:
        if self.context:
            return self.context[-1]
        return None

    def get_last_segment(self) -> Optional[Segment]:
        if self.segments:
            return self.segments[-1]
        return None
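A minimal usage sketch for the Segment and Conversation dataclasses above (illustrative only, assuming both classes are in scope):

import torch

conv = Conversation()
conv.add_segment(Segment(speaker=0, text="Hello there", audio=torch.zeros(24_000)))
conv.add_message("Hi! How can I help?", speaker=1)

print(conv.get_context())       # ['Speaker 0: Hello there', 'Speaker 1: Hi! How can I help?']
print(conv.get_last_segment())  # the Segment added above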
@@ -1,25 +0,0 @@
from typing import List
import torchaudio
import torch
from generator import load_csm_1b, Segment

class TranscriptionService:
    def __init__(self, model_device: str = "cpu"):
        self.generator = load_csm_1b(device=model_device)

    def transcribe_audio(self, audio_path: str) -> str:
        audio_tensor, sample_rate = torchaudio.load(audio_path)
        audio_tensor = self._resample_audio(audio_tensor, sample_rate)
        transcription = self.generator.generate_transcription(audio_tensor)
        return transcription

    def _resample_audio(self, audio_tensor: torch.Tensor, orig_freq: int) -> torch.Tensor:
        target_sample_rate = self.generator.sample_rate
        if orig_freq != target_sample_rate:
            audio_tensor = torchaudio.functional.resample(audio_tensor.squeeze(0), orig_freq=orig_freq, new_freq=target_sample_rate)
        return audio_tensor

    def transcribe_audio_stream(self, audio_chunks: List[torch.Tensor]) -> str:
        combined_audio = torch.cat(audio_chunks, dim=1)
        transcription = self.generator.generate_transcription(combined_audio)
        return transcription
@@ -1,24 +0,0 @@
from dataclasses import dataclass
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from src.llm.generator import load_csm_1b

@dataclass
class TextToSpeechService:
    generator: any

    def __init__(self, device: str = "cuda"):
        self.generator = load_csm_1b(device=device)

    def text_to_speech(self, text: str, speaker: int = 0) -> torch.Tensor:
        audio = self.generator.generate(
            text=text,
            speaker=speaker,
            context=[],
            max_audio_length_ms=10000,
        )
        return audio

    def save_audio(self, audio: torch.Tensor, file_path: str):
        torchaudio.save(file_path, audio.unsqueeze(0).cpu(), self.generator.sample_rate)
|
# filepath: /csm-conversation-bot/csm-conversation-bot/src/utils/config.py

import os

class Config:
    # General configuration
    DEBUG = os.getenv('DEBUG', 'False') == 'True'
    SECRET_KEY = os.getenv('SECRET_KEY', 'your_secret_key_here')

    # API configuration
    API_URL = os.getenv('API_URL', 'http://localhost:5000')

    # Model configuration
    LLM_MODEL_PATH = os.getenv('LLM_MODEL_PATH', 'path/to/llm/model')
    AUDIO_MODEL_PATH = os.getenv('AUDIO_MODEL_PATH', 'path/to/audio/model')

    # Socket.IO configuration
    SOCKETIO_MESSAGE_QUEUE = os.getenv('SOCKETIO_MESSAGE_QUEUE', 'redis://localhost:6379/0')

    # Logging configuration
    LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')

    # Other configurations can be added as needed
@@ -1,14 +0,0 @@
import logging

def setup_logger(name, log_file, level=logging.INFO):
    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.addHandler(handler)

    return logger

# Example usage:
# logger = setup_logger('my_logger', 'app.log')
@@ -1,105 +0,0 @@
body {
    font-family: 'Arial', sans-serif;
    background-color: #f4f4f4;
    color: #333;
    margin: 0;
    padding: 0;
}

header {
    background: #4c84ff;
    color: #fff;
    padding: 10px 0;
    text-align: center;
}

h1 {
    margin: 0;
    font-size: 2.5rem;
}

.container {
    width: 80%;
    margin: auto;
    overflow: hidden;
}

.conversation {
    background: #fff;
    padding: 20px;
    border-radius: 5px;
    box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
    max-height: 400px;
    overflow-y: auto;
}

.message {
    padding: 10px;
    margin: 10px 0;
    border-radius: 5px;
}

.user {
    background: #e3f2fd;
    text-align: right;
}

.ai {
    background: #f1f1f1;
    text-align: left;
}

.controls {
    display: flex;
    justify-content: space-between;
    margin-top: 20px;
}

button {
    padding: 10px 15px;
    border: none;
    border-radius: 5px;
    cursor: pointer;
    transition: background 0.3s;
}

button:hover {
    background: #3367d6;
    color: #fff;
}

.visualizer-container {
    height: 150px;
    background: #000;
    border-radius: 5px;
    margin-top: 20px;
}

.visualizer-label {
    color: rgba(255, 255, 255, 0.7);
    text-align: center;
    padding: 10px;
}

.status-indicator {
    display: flex;
    align-items: center;
    margin-top: 10px;
}

.status-dot {
    width: 12px;
    height: 12px;
    border-radius: 50%;
    background-color: #ccc;
    margin-right: 10px;
}

.status-dot.active {
    background-color: #4CAF50;
}

.status-text {
    font-size: 0.9em;
    color: #666;
}
@@ -1,31 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>CSM Conversation Bot</title>
    <link rel="stylesheet" href="css/styles.css">
    <script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
    <script src="js/client.js" defer></script>
</head>
<body>
    <header>
        <h1>CSM Conversation Bot</h1>
        <p>Talk to the AI and get responses in real-time!</p>
    </header>
    <main>
        <div id="conversation" class="conversation"></div>
        <div class="controls">
            <button id="startButton">Start Conversation</button>
            <button id="stopButton">Stop Conversation</button>
        </div>
        <div class="status-indicator">
            <div id="statusDot" class="status-dot"></div>
            <div id="statusText">Disconnected</div>
        </div>
    </main>
    <footer>
        <p>Powered by CSM and Llama 3.2</p>
    </footer>
</body>
</html>
@@ -1,131 +0,0 @@
// This file contains the client-side JavaScript code that handles audio streaming and communication with the server.

const SERVER_URL = window.location.hostname === 'localhost' ?
    'http://localhost:5000' : window.location.origin;

const elements = {
    conversation: document.getElementById('conversation'),
    streamButton: document.getElementById('streamButton'),
    clearButton: document.getElementById('clearButton'),
    speakerSelection: document.getElementById('speakerSelect'),
    statusDot: document.getElementById('statusDot'),
    statusText: document.getElementById('statusText'),
};

const state = {
    socket: null,
    isStreaming: false,
    currentSpeaker: 0,
};

// Initialize the application
function initializeApp() {
    setupSocketConnection();
    setupEventListeners();
}

// Setup Socket.IO connection
function setupSocketConnection() {
    state.socket = io(SERVER_URL);

    state.socket.on('connect', () => {
        updateConnectionStatus(true);
    });

    state.socket.on('disconnect', () => {
        updateConnectionStatus(false);
    });

    state.socket.on('audio_response', handleAudioResponse);
    state.socket.on('transcription', handleTranscription);
}

// Setup event listeners
function setupEventListeners() {
    elements.streamButton.addEventListener('click', toggleStreaming);
    elements.clearButton.addEventListener('click', clearConversation);
    elements.speakerSelection.addEventListener('change', (event) => {
        state.currentSpeaker = event.target.value;
    });
}

// Update connection status UI
function updateConnectionStatus(isConnected) {
    elements.statusDot.classList.toggle('active', isConnected);
    elements.statusText.textContent = isConnected ? 'Connected' : 'Disconnected';
}

// Toggle streaming state
function toggleStreaming() {
    if (state.isStreaming) {
        stopStreaming();
    } else {
        startStreaming();
    }
}

// Start streaming audio to the server
function startStreaming() {
    if (state.isStreaming) return;

    navigator.mediaDevices.getUserMedia({ audio: true })
        .then(stream => {
            const mediaRecorder = new MediaRecorder(stream);
            mediaRecorder.start();

            mediaRecorder.ondataavailable = (event) => {
                if (event.data.size > 0) {
                    sendAudioChunk(event.data);
                }
            };

            mediaRecorder.onstop = () => {
                state.isStreaming = false;
                elements.streamButton.innerHTML = 'Start Conversation';
            };

            state.isStreaming = true;
            elements.streamButton.innerHTML = 'Stop Conversation';
        })
        .catch(err => {
            console.error('Error accessing microphone:', err);
        });
}

// Stop streaming audio
function stopStreaming() {
    if (!state.isStreaming) return;

    // Logic to stop the media recorder would go here
}

// Send audio chunk to server
function sendAudioChunk(audioData) {
    const reader = new FileReader();
    reader.onloadend = () => {
        const arrayBuffer = reader.result;
        state.socket.emit('audio_chunk', { audio: arrayBuffer, speaker: state.currentSpeaker });
    };
    reader.readAsArrayBuffer(audioData);
}

// Handle audio response from server
function handleAudioResponse(data) {
    const audioElement = new Audio(URL.createObjectURL(new Blob([data.audio])));
    audioElement.play();
}

// Handle transcription response from server
function handleTranscription(data) {
    const messageElement = document.createElement('div');
    messageElement.textContent = `AI: ${data.transcription}`;
    elements.conversation.appendChild(messageElement);
}

// Clear conversation history
function clearConversation() {
    elements.conversation.innerHTML = '';
}

// Initialize the application when DOM is fully loaded
document.addEventListener('DOMContentLoaded', initializeApp);
@@ -1,31 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>CSM Conversation Bot</title>
    <link rel="stylesheet" href="../static/css/styles.css">
    <script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
    <script src="../static/js/client.js" defer></script>
</head>
<body>
    <header>
        <h1>CSM Conversation Bot</h1>
        <p>Talk to the AI and get responses in real-time!</p>
    </header>
    <main>
        <div class="chat-container">
            <div class="conversation" id="conversation"></div>
            <input type="text" id="userInput" placeholder="Type your message..." />
            <button id="sendButton">Send</button>
        </div>
        <div class="status-indicator">
            <div class="status-dot" id="statusDot"></div>
            <div class="status-text" id="statusText">Not connected</div>
        </div>
    </main>
    <footer>
        <p>Powered by CSM and Llama 3.2</p>
    </footer>
</body>
</html>
Backend/voice-chat.js (new file, 1071 lines)
File diff suppressed because it is too large.
Backend/watermarking.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import argparse

import silentcipher
import torch
import torchaudio

# This watermark key is public, it is not secure.
# If using CSM 1B in another application, use a new private key and keep it secret.
CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]


def cli_check_audio() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio_path", type=str, required=True)
    args = parser.parse_args()

    check_audio_from_file(args.audio_path)


def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
    model = silentcipher.get_model(
        model_type="44.1k",
        device=device,
    )
    return model


@torch.inference_mode()
def watermark(
    watermarker: silentcipher.server.Model,
    audio_array: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> tuple[torch.Tensor, int]:
    audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
    encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)

    output_sample_rate = min(44100, sample_rate)
    encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
    return encoded, output_sample_rate


@torch.inference_mode()
def verify(
    watermarker: silentcipher.server.Model,
    watermarked_audio: torch.Tensor,
    sample_rate: int,
    watermark_key: list[int],
) -> bool:
    watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
    result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)

    is_watermarked = result["status"]
    if is_watermarked:
        is_csm_watermarked = result["messages"][0] == watermark_key
    else:
        is_csm_watermarked = False

    return is_watermarked and is_csm_watermarked


def check_audio_from_file(audio_path: str) -> None:
    watermarker = load_watermarker(device="cuda")

    audio_array, sample_rate = load_audio(audio_path)
    is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)

    outcome = "Watermarked" if is_watermarked else "Not watermarked"
    print(f"{outcome}: {audio_path}")


def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
    audio_array, sample_rate = torchaudio.load(audio_path)
    audio_array = audio_array.mean(dim=0)
    return audio_array, int(sample_rate)


if __name__ == "__main__":
    cli_check_audio()
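An illustrative round-trip sketch using the helpers in this file: embed the public CSM key in a clip, write it out, then verify it. The file names are placeholders and a CUDA device is assumed.

import torchaudio

watermarker = load_watermarker(device="cuda")
audio, sample_rate = load_audio("generated_speech.wav")  # hypothetical input clip

# Embed the key, save the watermarked audio, then check that verification succeeds
encoded, out_sr = watermark(watermarker, audio, sample_rate, CSM_1B_GH_WATERMARK)
torchaudio.save("generated_speech_watermarked.wav", encoded.unsqueeze(0).cpu(), out_sr)
print(verify(watermarker, encoded, out_sr, CSM_1B_GH_WATERMARK))  # expected: True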