Frontend Fixed

Backend/api/app.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import os
import logging
import threading
from dataclasses import dataclass
from flask import Flask
from flask_socketio import SocketIO
from flask_cors import CORS

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure device
import torch
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

logger.info(f"Using device: {DEVICE}")

# Initialize Flask app
app = Flask(__name__, static_folder='../', static_url_path='')
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)

# Global variables for conversation state
active_conversations = {}
user_queues = {}
processing_threads = {}

# Model storage
@dataclass
class AppModels:
    generator = None
    tokenizer = None
    llm = None
    whisperx_model = None
    whisperx_align_model = None
    whisperx_align_metadata = None
    last_language = None

models = AppModels()

def load_models():
    """Load all required models"""
    from generator import load_csm_1b
    import whisperx
    import gc
    from transformers import AutoModelForCausalLM, AutoTokenizer
    global models

    socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})

    # CSM 1B loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
        models.generator = load_csm_1b(device=DEVICE)
        logger.info("CSM 1B model loaded successfully")
        socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})

    # WhisperX loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
        # Use WhisperX for better transcription with timestamps
        # Use compute_type based on device
        compute_type = "float16" if DEVICE == "cuda" else "float32"

        # Load the WhisperX model (smaller model for faster processing)
        models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)

        logger.info("WhisperX model loaded successfully")
        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})

    # Llama loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
        models.llm = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            device_map=DEVICE,
            torch_dtype=torch.bfloat16
        )
        models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

        # Configure all special tokens
        models.tokenizer.pad_token = models.tokenizer.eos_token
        models.tokenizer.padding_side = "left"  # For causal language modeling

        # Inform the model about the pad token
        if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
            models.llm.config.pad_token_id = models.tokenizer.pad_token_id

        logger.info("Llama 3.2 model loaded successfully")
        socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
    except Exception as e:
        logger.error(f"Error loading Llama 3.2 model: {str(e)}")
        socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})

# Load models in a background thread
threading.Thread(target=load_models, daemon=True).start()

# Import routes and socket handlers
from api.routes import register_routes
from api.socket_handlers import register_handlers

# Register routes and socket handlers
register_routes(app)
register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE)

# Run server if executed directly
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
    logger.info(f"Starting server on port {port} (debug={debug_mode})")
    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)

Backend/api/routes.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import os
import torch
import psutil
from flask import send_from_directory, jsonify, request

def register_routes(app):
    """Register HTTP routes for the application"""

    @app.route('/')
    def index():
        """Serve the main application page"""
        return send_from_directory(app.static_folder, 'index.html')

    @app.route('/voice-chat.js')
    def serve_js():
        """Serve the JavaScript file"""
        return send_from_directory(app.static_folder, 'voice-chat.js')

    @app.route('/api/status')
    def system_status():
        """Return the system status"""
        # Import here to avoid circular imports
        from api.app import models, DEVICE

        return jsonify({
            "status": "ok",
            "cuda_available": torch.cuda.is_available(),
            "device": DEVICE,
            "models": {
                "generator": models.generator is not None,
                "asr": models.whisperx_model is not None,
                "llm": models.llm is not None
            },
            "versions": {
                "transformers": "4.49.0",  # Replace with actual version
                "torch": torch.__version__
            }
        })

    @app.route('/api/system_resources')
    def system_resources():
        """Return system resource usage"""
        # Import here to avoid circular imports
        from api.app import active_conversations, DEVICE

        # Get CPU usage
        cpu_percent = psutil.cpu_percent(interval=0.1)

        # Get memory usage
        memory = psutil.virtual_memory()
        memory_used_gb = memory.used / (1024 ** 3)
        memory_total_gb = memory.total / (1024 ** 3)
        memory_percent = memory.percent

        # Get GPU memory if available
        gpu_memory = {}
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_memory[f"gpu_{i}"] = {
                    "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
                    "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
                    "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
                }

        return jsonify({
            "cpu_percent": cpu_percent,
            "memory": {
                "used_gb": memory_used_gb,
                "total_gb": memory_total_gb,
                "percent": memory_percent
            },
            "gpu_memory": gpu_memory,
            "active_sessions": len(active_conversations)
        })

Backend/api/socket_handlers.py (new file, 392 lines)
@@ -0,0 +1,392 @@
import os
import io
import base64
import time
import threading
import queue
import tempfile
import gc
import logging
import traceback
from typing import Dict, List, Optional

import torch
import torchaudio
import numpy as np
from flask import request
from flask_socketio import emit

# Import conversation model
from generator import Segment

logger = logging.getLogger(__name__)

# Conversation data structure
class Conversation:
    def __init__(self, session_id):
        self.session_id = session_id
        self.segments: List[Segment] = []
        self.current_speaker = 0
        self.ai_speaker_id = 1  # Default AI speaker ID
        self.last_activity = time.time()
        self.is_processing = False

    def add_segment(self, text, speaker, audio):
        segment = Segment(text=text, speaker=speaker, audio=audio)
        self.segments.append(segment)
        self.last_activity = time.time()
        return segment

    def get_context(self, max_segments=10):
        """Return the most recent segments for context"""
        return self.segments[-max_segments:] if self.segments else []

def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE):
    """Register Socket.IO event handlers"""

    @socketio.on('connect')
    def handle_connect(auth=None):
        """Handle client connection"""
        session_id = request.sid
        logger.info(f"Client connected: {session_id}")

        # Initialize conversation data
        if session_id not in active_conversations:
            active_conversations[session_id] = Conversation(session_id)
            user_queues[session_id] = queue.Queue()
            processing_threads[session_id] = threading.Thread(
                target=process_audio_queue,
                args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, DEVICE),
                daemon=True
            )
            processing_threads[session_id].start()

        emit('connection_status', {'status': 'connected'})

    @socketio.on('disconnect')
    def handle_disconnect(reason=None):
        """Handle client disconnection"""
        session_id = request.sid
        logger.info(f"Client disconnected: {session_id}. Reason: {reason}")

        # Cleanup
        if session_id in active_conversations:
            # Mark for deletion rather than immediately removing
            # as the processing thread might still be accessing it
            active_conversations[session_id].is_processing = False
            user_queues[session_id].put(None)  # Signal thread to terminate

    @socketio.on('audio_data')
    def handle_audio_data(data):
        """Handle incoming audio data"""
        session_id = request.sid
        logger.info(f"Received audio data from {session_id}")

        # Check if the models are loaded
        if models.generator is None or models.whisperx_model is None or models.llm is None:
            emit('error', {'message': 'Models still loading, please wait'})
            return

        # Check if we're already processing for this session
        if session_id in active_conversations and active_conversations[session_id].is_processing:
            emit('error', {'message': 'Still processing previous audio, please wait'})
            return

        # Add to processing queue
        if session_id in user_queues:
            user_queues[session_id].put(data)
        else:
            emit('error', {'message': 'Session not initialized, please refresh the page'})

def process_audio_queue(session_id, q, app, socketio, models, active_conversations, DEVICE):
    """Background thread to process audio chunks for a session"""
    logger.info(f"Started processing thread for session: {session_id}")

    try:
        while session_id in active_conversations:
            try:
                # Get the next audio chunk with a timeout
                data = q.get(timeout=120)
                if data is None:  # Termination signal
                    break

                # Process the audio and generate a response
                process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE)

            except queue.Empty:
                # Timeout, check if session is still valid
                continue
    except Exception as e:
        logger.error(f"Error processing audio for {session_id}: {str(e)}")
        # Create an app context for the socket emit
        with app.app_context():
            socketio.emit('error', {'message': str(e)}, room=session_id)
    finally:
        logger.info(f"Ending processing thread for session: {session_id}")
        # Clean up when thread is done
        with app.app_context():
            # Imported here to avoid a circular import; the queue registry lives in api.app
            from api.app import user_queues
            if session_id in active_conversations:
                del active_conversations[session_id]
            if session_id in user_queues:
                del user_queues[session_id]

def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE):
    """Process audio data and generate a response using WhisperX"""
    if models.generator is None or models.whisperx_model is None or models.llm is None:
        logger.warning("Models not yet loaded!")
        with app.app_context():
            socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
        return

    logger.info(f"Processing audio for session {session_id}")
    conversation = active_conversations[session_id]

    try:
        # Set processing flag
        conversation.is_processing = True

        # Process base64 audio data
        audio_data = data['audio']
        speaker_id = data['speaker']
        logger.info(f"Received audio from speaker {speaker_id}")

        # Convert from base64 to WAV
        try:
            audio_bytes = base64.b64decode(audio_data.split(',')[1])
            logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
        except Exception as e:
            logger.error(f"Error decoding base64 audio: {str(e)}")
            raise

        # Save to temporary file for processing
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name

        try:
            # Notify client that transcription is starting
            with app.app_context():
                socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)

            # Load audio using WhisperX
            import whisperx
            audio = whisperx.load_audio(temp_path)

            # Check audio length and add a warning for short clips
            audio_length = len(audio) / 16000  # assuming 16kHz sample rate
            if audio_length < 1.0:
                logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")

            # Transcribe using WhisperX
            batch_size = 16  # adjust based on your GPU memory
            logger.info("Running WhisperX transcription...")

            # Handle the warning about audio being shorter than 30s by suppressing it
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="audio is shorter than 30s")
                result = models.whisperx_model.transcribe(audio, batch_size=batch_size)

            # Get the detected language
            language_code = result["language"]
            logger.info(f"Detected language: {language_code}")

            # Check if alignment model needs to be loaded or updated
            if models.whisperx_align_model is None or language_code != models.last_language:
                # Clean up old models if they exist
                if models.whisperx_align_model is not None:
                    del models.whisperx_align_model
                    del models.whisperx_align_metadata
                    if DEVICE == "cuda":
                        gc.collect()
                        torch.cuda.empty_cache()

                # Load new alignment model for the detected language
                logger.info(f"Loading alignment model for language: {language_code}")
                models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
                    language_code=language_code, device=DEVICE
                )
                models.last_language = language_code

            # Align the transcript to get word-level timestamps
            if result["segments"] and len(result["segments"]) > 0:
                logger.info("Aligning transcript...")
                result = whisperx.align(
                    result["segments"],
                    models.whisperx_align_model,
                    models.whisperx_align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False
                )

                # Process the segments for better output
                for segment in result["segments"]:
                    # Round timestamps for better display
                    segment["start"] = round(segment["start"], 2)
                    segment["end"] = round(segment["end"], 2)
                    # Add a confidence score if not present
                    if "confidence" not in segment:
                        segment["confidence"] = 1.0  # Default confidence

            # Extract the full text from all segments
            user_text = ' '.join([segment['text'] for segment in result['segments']])

            # If no text was recognized, don't process further
            if not user_text or len(user_text.strip()) == 0:
                with app.app_context():
                    socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
                return

            logger.info(f"Transcription: {user_text}")

            # Load audio for CSM input
            waveform, sample_rate = torchaudio.load(temp_path)

            # Normalize to mono if needed
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Resample to the CSM sample rate if needed
            if sample_rate != models.generator.sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    orig_freq=sample_rate,
                    new_freq=models.generator.sample_rate
                )

            # Add the user's message to conversation history
            user_segment = conversation.add_segment(
                text=user_text,
                speaker=speaker_id,
                audio=waveform.squeeze()
            )

            # Send transcription to client with detailed segments
            with app.app_context():
                socketio.emit('transcription', {
                    'text': user_text,
                    'speaker': speaker_id,
                    'segments': result['segments']  # Include the detailed segments with timestamps
                }, room=session_id)

            # Generate AI response using Llama
            with app.app_context():
                socketio.emit('processing_status', {'status': 'generating'}, room=session_id)

            # Create prompt from conversation history
            conversation_history = ""
            for segment in conversation.segments[-5:]:  # Last 5 segments for context
                role = "User" if segment.speaker == 0 else "Assistant"
                conversation_history += f"{role}: {segment.text}\n"

            # Add final prompt
            prompt = f"{conversation_history}Assistant: "

            # Generate response with Llama
            try:
                # Ensure pad token is set
                if models.tokenizer.pad_token is None:
                    models.tokenizer.pad_token = models.tokenizer.eos_token

                input_tokens = models.tokenizer(
                    prompt,
                    return_tensors="pt",
                    padding=True,
                    return_attention_mask=True
                )
                input_ids = input_tokens.input_ids.to(DEVICE)
                attention_mask = input_tokens.attention_mask.to(DEVICE)

                with torch.no_grad():
                    generated_ids = models.llm.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=100,
                        temperature=0.7,
                        top_p=0.9,
                        do_sample=True,
                        pad_token_id=models.tokenizer.eos_token_id
                    )

                # Decode the response
                response_text = models.tokenizer.decode(
                    generated_ids[0][input_ids.shape[1]:],
                    skip_special_tokens=True
                ).strip()
            except Exception as e:
                logger.error(f"Error generating response: {str(e)}")
                logger.error(traceback.format_exc())
                response_text = "I'm sorry, I encountered an error while processing your request."

            # Synthesize speech
            with app.app_context():
                socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)

                # Start sending the audio response
                socketio.emit('audio_response_start', {
                    'text': response_text,
                    'total_chunks': 1,
                    'chunk_index': 0
                }, room=session_id)

            # Define AI speaker ID
            ai_speaker_id = conversation.ai_speaker_id

            # Generate audio
            audio_tensor = models.generator.generate(
                text=response_text,
                speaker=ai_speaker_id,
                context=conversation.get_context(),
                max_audio_length_ms=10_000,
                temperature=0.9
            )

            # Add AI response to conversation history
            ai_segment = conversation.add_segment(
                text=response_text,
                speaker=ai_speaker_id,
                audio=audio_tensor
            )

            # Convert audio to WAV format
            with io.BytesIO() as wav_io:
                torchaudio.save(
                    wav_io,
                    audio_tensor.unsqueeze(0).cpu(),
                    models.generator.sample_rate,
                    format="wav"
                )
                wav_io.seek(0)
                wav_data = wav_io.read()

            # Convert WAV data to base64
            audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"

            # Send audio chunk to client
            with app.app_context():
                socketio.emit('audio_response_chunk', {
                    'chunk': audio_base64,
                    'chunk_index': 0,
                    'total_chunks': 1,
                    'is_last': True
                }, room=session_id)

                # Signal completion
                socketio.emit('audio_response_complete', {
                    'text': response_text
                }, room=session_id)

        finally:
            # Clean up temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        logger.error(traceback.format_exc())
        with app.app_context():
            socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
    finally:
        # Reset processing flag
        conversation.is_processing = False

@@ -1,419 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>CSM Voice Chat</title>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
    <!-- Socket.IO client library -->
    <script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
    <style>
        :root {
            --primary-color: #4c84ff;
            --secondary-color: #3367d6;
            --text-color: #333;
            --background-color: #f9f9f9;
            --card-background: #ffffff;
            --accent-color: #ff5252;
            --success-color: #4CAF50;
            --border-color: #e0e0e0;
            --shadow-color: rgba(0, 0, 0, 0.1);
        }

        * {
            box-sizing: border-box;
            margin: 0;
            padding: 0;
        }

        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            background-color: var(--background-color);
            color: var(--text-color);
            line-height: 1.6;
            max-width: 1000px;
            margin: 0 auto;
            padding: 20px;
            transition: all 0.3s ease;
        }

        header {
            text-align: center;
            margin-bottom: 30px;
        }

        h1 {
            color: var(--primary-color);
            font-size: 2.5rem;
            margin-bottom: 10px;
        }

        .subtitle {
            color: #666;
            font-weight: 300;
        }

        .app-container {
            display: grid;
            grid-template-columns: 1fr;
            gap: 20px;
        }

        @media (min-width: 768px) {
            .app-container {
                grid-template-columns: 2fr 1fr;
            }
        }

        .chat-container, .control-panel {
            background-color: var(--card-background);
            border-radius: 12px;
            box-shadow: 0 4px 12px var(--shadow-color);
            padding: 20px;
        }

        .control-panel {
            display: flex;
            flex-direction: column;
            gap: 20px;
        }

        .chat-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 15px;
            padding-bottom: 10px;
            border-bottom: 1px solid var(--border-color);
        }

        .conversation {
            height: 400px;
            overflow-y: auto;
            padding: 10px;
            border-radius: 8px;
            background-color: #f7f9fc;
            margin-bottom: 20px;
            scroll-behavior: smooth;
        }

        .message {
            margin-bottom: 15px;
            padding: 12px 15px;
            border-radius: 12px;
            max-width: 85%;
            position: relative;
            animation: fade-in 0.3s ease-out forwards;
        }

        @keyframes fade-in {
            from { opacity: 0; transform: translateY(10px); }
            to { opacity: 1; transform: translateY(0); }
        }

        .user {
            background-color: #e3f2fd;
            color: #0d47a1;
            margin-left: auto;
            border-bottom-right-radius: 4px;
        }

        .ai {
            background-color: #f1f1f1;
            color: #37474f;
            margin-right: auto;
            border-bottom-left-radius: 4px;
        }

        .system {
            background-color: #f8f9fa;
            font-style: italic;
            color: #666;
            text-align: center;
            max-width: 90%;
            margin: 10px auto;
            font-size: 0.9em;
            padding: 8px 12px;
            border-radius: 8px;
        }

        .audio-player {
            width: 100%;
            margin-top: 8px;
            border-radius: 8px;
        }

        button {
            padding: 12px 20px;
            border-radius: 8px;
            border: none;
            background-color: var(--primary-color);
            color: white;
            font-weight: 600;
            cursor: pointer;
            transition: all 0.2s ease;
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 8px;
            flex: 1;
        }

        button:hover {
            background-color: var(--secondary-color);
        }

        button.recording {
            background-color: var(--accent-color);
            animation: pulse 1.5s infinite;
        }

        button.processing {
            background-color: #ffa000;
        }

        @keyframes pulse {
            0% { opacity: 1; }
            50% { opacity: 0.7; }
            100% { opacity: 1; }
        }

        .status-indicator {
            display: flex;
            align-items: center;
            gap: 10px;
            font-size: 0.9em;
            color: #555;
        }

        .status-dot {
            width: 12px;
            height: 12px;
            border-radius: 50%;
            background-color: #ccc;
        }

        .status-dot.active {
            background-color: var(--success-color);
        }

        /* Audio visualizer styles */
        .visualizer-container {
            margin-top: 15px;
            position: relative;
            width: 100%;
            height: 100px;
            background-color: #000;
            border-radius: 8px;
            overflow: hidden;
        }

        #audioVisualizer {
            width: 100%;
            height: 100%;
            transition: opacity 0.3s;
        }

        #visualizerLabel {
            position: absolute;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
            color: rgba(255, 255, 255, 0.7);
            font-size: 0.9em;
            pointer-events: none;
            transition: opacity 0.3s;
        }

        .volume-meter {
            height: 8px;
            width: 100%;
            background-color: #eee;
            border-radius: 4px;
            margin-top: 8px;
            overflow: hidden;
        }

        #volumeLevel {
            height: 100%;
            width: 0%;
            background-color: var(--primary-color);
            border-radius: 4px;
            transition: width 0.1s ease, background-color 0.2s;
        }

        .settings-toggles {
            display: flex;
            flex-direction: column;
            gap: 12px;
        }

        .toggle-switch {
            display: flex;
            align-items: center;
            gap: 8px;
        }

        footer {
            margin-top: 30px;
            text-align: center;
            font-size: 0.8em;
            color: #777;
        }

        /* Model status indicators */
        .model-status {
            display: flex;
            gap: 8px;
        }

        .model-indicator {
            padding: 3px 6px;
            border-radius: 4px;
            font-size: 0.7em;
            font-weight: bold;
        }

        .model-indicator.loading {
            background-color: #ffd54f;
            color: #000;
        }

        .model-indicator.loaded {
            background-color: #4CAF50;
            color: white;
        }

        .model-indicator.error {
            background-color: #f44336;
            color: white;
        }

        .message-timestamp {
            font-size: 0.7em;
            color: #888;
            margin-top: 4px;
            text-align: right;
        }

        .simple-timestamp {
            font-size: 0.8em;
            color: #888;
            margin-top: 5px;
        }

        /* Add this to your existing styles */
        .loading-progress {
            width: 100%;
            max-width: 150px;
            margin-right: 10px;
        }

        progress {
            width: 100%;
            height: 8px;
            border-radius: 4px;
            overflow: hidden;
        }

        progress::-webkit-progress-bar {
            background-color: #eee;
            border-radius: 4px;
        }

        progress::-webkit-progress-value {
            background-color: var(--primary-color);
            border-radius: 4px;
        }
    </style>
</head>
<body>
    <header>
        <h1>CSM Voice Chat</h1>
        <p class="subtitle">Talk naturally with the AI using your voice</p>
    </header>

    <div class="app-container">
        <div class="chat-container">
            <div class="chat-header">
                <h2>Conversation</h2>
                <div class="status-indicator">
                    <div id="statusDot" class="status-dot"></div>
                    <span id="statusText">Disconnected</span>
                </div>

                <!-- Add this above the model status indicators in the chat-header div -->
                <div class="loading-progress">
                    <progress id="modelLoadingProgress" max="100" value="0">0%</progress>
                </div>

                <!-- Add this model status panel -->
                <div class="model-status">
                    <div id="csmStatus" class="model-indicator loading" title="Loading CSM model...">CSM</div>
                    <div id="asrStatus" class="model-indicator loading" title="Loading ASR model...">ASR</div>
                    <div id="llmStatus" class="model-indicator loading" title="Loading LLM model...">LLM</div>
                </div>
            </div>
            <div id="conversation" class="conversation"></div>
        </div>

        <div class="control-panel">
            <div>
                <h3>Controls</h3>
                <p>Click the button below to start and stop recording.</p>
                <div class="button-row">
                    <button id="streamButton">
                        <i class="fas fa-microphone"></i>
                        Start Conversation
                    </button>
                    <button id="clearButton">
                        <i class="fas fa-trash"></i>
                        Clear
                    </button>
                </div>

                <!-- Audio visualizer and volume meter -->
                <div class="visualizer-container">
                    <canvas id="audioVisualizer"></canvas>
                    <div id="visualizerLabel">Start speaking to see audio visualization</div>
                </div>
                <div class="volume-meter">
                    <div id="volumeLevel"></div>
                </div>
            </div>

            <div class="settings-panel">
                <h3>Settings</h3>
                <div class="settings-toggles">
                    <div class="toggle-switch">
                        <input type="checkbox" id="autoPlayResponses" checked>
                        <label for="autoPlayResponses">Autoplay Responses</label>
                    </div>
                    <div class="toggle-switch">
                        <input type="checkbox" id="showVisualizer" checked>
                        <label for="showVisualizer">Show Audio Visualizer</label>
                    </div>
                    <div>
                        <label for="speakerSelect">Speaker Voice:</label>
                        <select id="speakerSelect">
                            <option value="0">Speaker 0 (You)</option>
                            <option value="1">Speaker 1 (AI)</option>
                        </select>
                    </div>
                    <div>
                        <label for="thresholdSlider">Silence Threshold: <span id="thresholdValue">0.010</span></label>
                        <input type="range" id="thresholdSlider" min="0.001" max="0.05" step="0.001" value="0.01">
                    </div>
                </div>
            </div>
        </div>
    </div>

    <footer>
        <p>Powered by CSM 1B & Llama 3.2 | Whisper for speech recognition</p>
    </footer>

    <!-- Load external JavaScript file -->
    <script src="voice-chat.js"></script>
</body>
</html>

@@ -1,13 +0,0 @@
flask==2.2.5
flask-socketio==5.3.6
flask-cors==4.0.0
torch==2.4.0
torchaudio==2.4.0
tokenizers==0.21.0
transformers==4.49.0
librosa==0.10.1
huggingface_hub==0.28.1
moshi==0.2.2
torchtune==0.4.0
torchao==0.9.0
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master

@@ -1,117 +0,0 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass

# Disable Triton compilation
os.environ["NO_TORCH_COMPILE"] = "1"

# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_b.wav"
)

SPEAKER_PROMPTS = {
    "conversational_a": {
        "text": (
            "like revising for an exam I'd have to try and like keep up the momentum because I'd "
            "start really early I'd be like okay I'm gonna start revising now and then like "
            "you're revising for ages and then I just like start losing steam I didn't do that "
            "for the exam we had recently to be fair that was a more of a last minute scenario "
            "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
            "sort of start the day with this not like a panic but like a"
        ),
        "audio": prompt_filepath_conversational_a
    },
    "conversational_b": {
        "text": (
            "like a super Mario level. Like it's very like high detail. And like, once you get "
            "into the park, it just like, everything looks like a computer game and they have all "
            "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
            "will have like a question block. And if you like, you know, punch it, a coin will "
            "come out. So like everyone, when they come into the park, they get like this little "
            "bracelet and then you can go punching question blocks around."
        ),
        "audio": prompt_filepath_conversational_b
    }
}

def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = audio_tensor.squeeze(0)
    # Resample is lazy so we can always call it
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
    )
    return audio_tensor

def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
    audio_tensor = load_prompt_audio(audio_path, sample_rate)
    return Segment(text=text, speaker=speaker, audio=audio_tensor)

def main():
    # Select the best available device, skipping MPS due to float64 limitations
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load model
    generator = load_csm_1b(device)

    # Prepare prompts
    prompt_a = prepare_prompt(
        SPEAKER_PROMPTS["conversational_a"]["text"],
        0,
        SPEAKER_PROMPTS["conversational_a"]["audio"],
        generator.sample_rate
    )

    prompt_b = prepare_prompt(
        SPEAKER_PROMPTS["conversational_b"]["text"],
        1,
        SPEAKER_PROMPTS["conversational_b"]["audio"],
        generator.sample_rate
    )

    # Generate conversation
    conversation = [
        {"text": "Hey how are you doing?", "speaker_id": 0},
        {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
        {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
        {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
    ]

    # Generate each utterance
    generated_segments = []
    prompt_segments = [prompt_a, prompt_b]

    for utterance in conversation:
        print(f"Generating: {utterance['text']}")
        audio_tensor = generator.generate(
            text=utterance['text'],
            speaker=utterance['speaker_id'],
            context=prompt_segments + generated_segments,
            max_audio_length_ms=10_000,
        )
        generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))

    # Concatenate all generations
    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
    torchaudio.save(
        "full_conversation.wav",
        all_audio.unsqueeze(0).cpu(),
        generator.sample_rate
    )
    print("Successfully generated full_conversation.wav")

if __name__ == "__main__":
    main()
@@ -1,639 +1,19 @@
|
|||||||
import os
|
"""
|
||||||
import io
|
CSM Voice Chat Server
|
||||||
import base64
|
A voice chat application that uses CSM 1B for voice synthesis,
|
||||||
import time
|
WhisperX for speech recognition, and Llama 3.2 for language generation.
|
||||||
import json
|
"""
|
||||||
import uuid
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
import queue
|
|
||||||
import tempfile
|
|
||||||
import gc
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
# Start the Flask application
|
||||||
import torchaudio
|
from api.app import app, socketio
|
||||||
import numpy as np
|
|
||||||
from flask import Flask, request, jsonify, send_from_directory
|
|
||||||
from flask_socketio import SocketIO, emit
|
|
||||||
from flask_cors import CORS
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
||||||
|
|
||||||
# Import WhisperX for better transcription
|
|
||||||
import whisperx
|
|
||||||
|
|
||||||
from generator import load_csm_1b, Segment
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
# Add these imports at the top
|
|
||||||
import psutil
|
|
||||||
import gc
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Initialize Flask app
|
|
||||||
app = Flask(__name__, static_folder='.')
|
|
||||||
CORS(app)
|
|
||||||
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)
|
|
||||||
|
|
||||||
# Configure device
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
DEVICE = "cuda"
|
|
||||||
elif torch.backends.mps.is_available():
|
|
||||||
DEVICE = "mps"
|
|
||||||
else:
|
|
||||||
DEVICE = "cpu"
|
|
||||||
|
|
||||||
logger.info(f"Using device: {DEVICE}")
|
|
||||||
|
|
||||||
# Global variables
|
|
||||||
active_conversations = {}
|
|
||||||
user_queues = {}
|
|
||||||
processing_threads = {}
|
|
||||||
|
|
||||||
# Load models
|
|
||||||
@dataclass
|
|
||||||
class AppModels:
|
|
||||||
generator = None
|
|
||||||
tokenizer = None
|
|
||||||
llm = None
|
|
||||||
whisperx_model = None
|
|
||||||
whisperx_align_model = None
|
|
||||||
whisperx_align_metadata = None
|
|
||||||
diarize_model = None
|
|
||||||
last_language = None
|
|
||||||
|
|
||||||
# Initialize the models object
|
|
||||||
models = AppModels()
|
|
||||||
|
|
||||||
def load_models():
|
|
||||||
"""Load all required models"""
|
|
||||||
global models
|
|
||||||
|
|
||||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
|
|
||||||
|
|
||||||
# CSM 1B loading
|
|
||||||
try:
|
|
||||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
|
|
||||||
models.generator = load_csm_1b(device=DEVICE)
|
|
||||||
logger.info("CSM 1B model loaded successfully")
|
|
||||||
socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
|
|
||||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
|
|
||||||
if DEVICE == "cuda":
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
error_details = traceback.format_exc()
|
|
||||||
logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
|
|
||||||
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
|
|
||||||
|
|
||||||
# WhisperX loading
|
|
||||||
try:
|
|
||||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
|
|
||||||
# Use WhisperX for better transcription with timestamps
|
|
||||||
import whisperx
|
|
||||||
|
|
||||||
# Use compute_type based on device
|
|
||||||
compute_type = "float16" if DEVICE == "cuda" else "float32"
|
|
||||||
|
|
||||||
# Load the WhisperX model (smaller model for faster processing)
|
|
||||||
models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)
|
|
||||||
|
|
||||||
logger.info("WhisperX model loaded successfully")
|
|
||||||
socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
|
|
||||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
|
|
||||||
if DEVICE == "cuda":
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
error_details = traceback.format_exc()
|
|
||||||
logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
|
|
||||||
socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
|
|
||||||
|
|
||||||
# Llama loading
|
|
||||||
try:
|
|
||||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
|
|
||||||
models.llm = AutoModelForCausalLM.from_pretrained(
|
|
||||||
"meta-llama/Llama-3.2-1B",
|
|
||||||
device_map=DEVICE,
|
|
||||||
torch_dtype=torch.bfloat16
|
|
||||||
)
|
|
||||||
models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
|
|
||||||
|
|
||||||
# Configure all special tokens
|
|
||||||
models.tokenizer.pad_token = models.tokenizer.eos_token
|
|
||||||
models.tokenizer.padding_side = "left" # For causal language modeling
|
|
||||||
|
|
||||||
# Inform the model about the pad token
|
|
||||||
if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
|
|
||||||
models.llm.config.pad_token_id = models.tokenizer.pad_token_id
|
|
||||||
|
|
||||||
logger.info("Llama 3.2 model loaded successfully")
|
|
||||||
socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
|
|
||||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
|
|
||||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading Llama 3.2 model: {str(e)}")
|
|
||||||
socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})
|
|
||||||
|
|
||||||
# Load models in a background thread
|
|
||||||
threading.Thread(target=load_models, daemon=True).start()
|
|
||||||
|
|
||||||
# Conversation data structure
|
|
||||||
class Conversation:
|
|
||||||
def __init__(self, session_id):
|
|
||||||
self.session_id = session_id
|
|
||||||
self.segments: List[Segment] = []
|
|
||||||
self.current_speaker = 0
|
|
||||||
self.ai_speaker_id = 1 # Add this property
|
|
||||||
self.last_activity = time.time()
|
|
||||||
self.is_processing = False
|
|
||||||
|
|
||||||
def add_segment(self, text, speaker, audio):
|
|
||||||
segment = Segment(text=text, speaker=speaker, audio=audio)
|
|
||||||
self.segments.append(segment)
|
|
||||||
self.last_activity = time.time()
|
|
||||||
return segment
|
|
||||||
|
|
||||||
def get_context(self, max_segments=10):
|
|
||||||
"""Return the most recent segments for context"""
|
|
||||||
return self.segments[-max_segments:] if self.segments else []
|
|
||||||
|
|
||||||
# Routes
|
|
||||||
@app.route('/')
|
|
||||||
def index():
|
|
||||||
return send_from_directory('.', 'index.html')
|
|
||||||
|
|
||||||
@app.route('/voice-chat.js')
|
|
||||||
def voice_chat_js():
|
|
||||||
return send_from_directory('.', 'voice-chat.js')
|
|
||||||
|
|
||||||
@app.route('/api/health')
|
|
||||||
def health_check():
|
|
||||||
return jsonify({
|
|
||||||
"status": "ok",
|
|
||||||
"cuda_available": torch.cuda.is_available(),
|
|
||||||
"models_loaded": models.generator is not None and models.llm is not None
|
|
||||||
})
|
|
||||||
|
|
||||||
# Fix the system_status function:
|
|
||||||
|
|
||||||
@app.route('/api/status')
|
|
||||||
def system_status():
|
|
||||||
return jsonify({
|
|
||||||
"status": "ok",
|
|
||||||
"cuda_available": torch.cuda.is_available(),
|
|
||||||
"device": DEVICE,
|
|
||||||
"models": {
|
|
||||||
"generator": models.generator is not None,
|
|
||||||
"asr": models.whisperx_model is not None, # Use the correct model name
|
|
||||||
"llm": models.llm is not None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# Add a new endpoint to check system resources
|
|
||||||
@app.route('/api/system_resources')
|
|
||||||
def system_resources():
|
|
||||||
# Get CPU usage
|
|
||||||
cpu_percent = psutil.cpu_percent(interval=0.1)
|
|
||||||
|
|
||||||
# Get memory usage
|
|
||||||
memory = psutil.virtual_memory()
|
|
||||||
memory_used_gb = memory.used / (1024 ** 3)
|
|
||||||
memory_total_gb = memory.total / (1024 ** 3)
|
|
||||||
memory_percent = memory.percent
|
|
||||||
|
|
||||||
# Get GPU memory if available
|
|
||||||
gpu_memory = {}
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
for i in range(torch.cuda.device_count()):
|
|
||||||
gpu_memory[f"gpu_{i}"] = {
|
|
||||||
"allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
|
|
||||||
"reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
|
|
||||||
"max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
|
|
||||||
}
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
"cpu_percent": cpu_percent,
|
|
||||||
"memory": {
|
|
||||||
"used_gb": memory_used_gb,
|
|
||||||
"total_gb": memory_total_gb,
|
|
||||||
"percent": memory_percent
|
|
||||||
},
|
|
||||||
"gpu_memory": gpu_memory,
|
|
||||||
"active_sessions": len(active_conversations)
|
|
||||||
})
|
|
||||||
|
|
||||||
# Socket event handlers
|
|
||||||
@socketio.on('connect')
|
|
||||||
def handle_connect(auth=None):
|
|
||||||
session_id = request.sid
|
|
||||||
logger.info(f"Client connected: {session_id}")
|
|
||||||
|
|
||||||
# Initialize conversation data
|
|
||||||
if session_id not in active_conversations:
|
|
||||||
active_conversations[session_id] = Conversation(session_id)
|
|
||||||
user_queues[session_id] = queue.Queue()
|
|
||||||
processing_threads[session_id] = threading.Thread(
|
|
||||||
target=process_audio_queue,
|
|
||||||
args=(session_id, user_queues[session_id]),
|
|
||||||
daemon=True
|
|
||||||
)
|
|
||||||
processing_threads[session_id].start()
|
|
||||||
|
|
||||||
emit('connection_status', {'status': 'connected'})
|
|
||||||
|
|
||||||
@socketio.on('disconnect')
|
|
||||||
def handle_disconnect(reason=None):
|
|
||||||
session_id = request.sid
|
|
||||||
logger.info(f"Client disconnected: {session_id}. Reason: {reason}")
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
if session_id in active_conversations:
|
|
||||||
# Mark for deletion rather than immediately removing
|
|
||||||
# as the processing thread might still be accessing it
|
|
||||||
active_conversations[session_id].is_processing = False
|
|
||||||
user_queues[session_id].put(None) # Signal thread to terminate
|
|
||||||
|
|
||||||
@socketio.on('start_stream')
|
|
||||||
def handle_start_stream():
|
|
||||||
session_id = request.sid
|
|
||||||
logger.info(f"Starting stream for client: {session_id}")
|
|
||||||
emit('streaming_status', {'status': 'active'})
|
|
||||||
|
|
||||||
@socketio.on('stop_stream')
|
|
||||||
def handle_stop_stream():
|
|
||||||
session_id = request.sid
|
|
||||||
logger.info(f"Stopping stream for client: {session_id}")
|
|
||||||
emit('streaming_status', {'status': 'inactive'})
|
|
||||||
|
|
||||||
@socketio.on('clear_context')
|
|
||||||
def handle_clear_context():
|
|
||||||
session_id = request.sid
|
|
||||||
logger.info(f"Clearing context for client: {session_id}")
|
|
||||||
|
|
||||||
if session_id in active_conversations:
|
|
||||||
active_conversations[session_id].segments = []
|
|
||||||
emit('context_updated', {'status': 'cleared'})
|
|
||||||
|
|
||||||
@socketio.on('audio_chunk')
|
|
||||||
def handle_audio_chunk(data):
|
|
||||||
session_id = request.sid
|
|
||||||
audio_data = data.get('audio', '')
|
|
||||||
speaker_id = int(data.get('speaker', 0))
|
|
||||||
|
|
||||||
if not audio_data or not session_id in active_conversations:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Update the current speaker
|
|
||||||
active_conversations[session_id].current_speaker = speaker_id
|
|
||||||
|
|
||||||
# Queue audio for processing
|
|
||||||
user_queues[session_id].put({
|
|
||||||
'audio': audio_data,
|
|
||||||
'speaker': speaker_id
|
|
||||||
})
|
|
||||||
|
|
||||||
emit('processing_status', {'status': 'transcribing'})
|
|
||||||
|
|
||||||
def process_audio_queue(session_id, q):
    """Background thread to process audio chunks for a session"""
    logger.info(f"Started processing thread for session: {session_id}")

    try:
        while session_id in active_conversations:
            try:
                # Get the next audio chunk with a timeout
                data = q.get(timeout=120)
                if data is None:  # Termination signal
                    break

                # Process the audio and generate a response
                process_audio_and_respond(session_id, data)

            except queue.Empty:
                # Timeout; check whether the session is still valid
                continue
            except Exception as e:
                logger.error(f"Error processing audio for {session_id}: {str(e)}")
                # Create an app context for the socket emit
                with app.app_context():
                    socketio.emit('error', {'message': str(e)}, room=session_id)
    finally:
        logger.info(f"Ending processing thread for session: {session_id}")
        # Clean up when the thread is done
        with app.app_context():
            if session_id in active_conversations:
                del active_conversations[session_id]
            if session_id in user_queues:
                del user_queues[session_id]

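
# For reference, the worker above is assumed to be wired up once per client in
# the connect handler earlier in this file, roughly like this (one queue plus
# one daemon thread per Socket.IO session id):
#
#     user_queues[session_id] = queue.Queue()
#     processing_threads[session_id] = threading.Thread(
#         target=process_audio_queue,
#         args=(session_id, user_queues[session_id]),
#         daemon=True,
#     )
#     processing_threads[session_id].start()
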
def process_audio_and_respond(session_id, data):
    """Process audio data and generate a response using WhisperX"""
    if models.generator is None or models.whisperx_model is None or models.llm is None:
        logger.warning("Models not yet loaded!")
        with app.app_context():
            socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
        return

    logger.info(f"Processing audio for session {session_id}")
    conversation = active_conversations[session_id]

    try:
        # Set the processing flag
        conversation.is_processing = True

        # Process the base64 audio data
        audio_data = data['audio']
        speaker_id = data['speaker']
        logger.info(f"Received audio from speaker {speaker_id}")

        # Convert from base64 to WAV bytes
        try:
            audio_bytes = base64.b64decode(audio_data.split(',')[1])
            logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
        except Exception as e:
            logger.error(f"Error decoding base64 audio: {str(e)}")
            raise
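
        # The incoming payload is expected to be a browser-style data URL such as
        # "data:audio/wav;base64,UklGR...": split(',')[1] above keeps only the
        # base64 body, and the reply audio below is returned in the same format.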
        # Save to a temporary file for processing
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name

        try:
            # Notify the client that transcription is starting
            with app.app_context():
                socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)

            # Load the audio using WhisperX
            import whisperx
            audio = whisperx.load_audio(temp_path)

            # Check the audio length and warn about very short clips
            audio_length = len(audio) / 16000  # assuming a 16 kHz sample rate
            if audio_length < 1.0:
                logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")

            # Transcribe using WhisperX
            batch_size = 16  # adjust based on your GPU memory
            logger.info("Running WhisperX transcription...")

            # Suppress the warning about audio being shorter than 30s
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="audio is shorter than 30s")
                result = models.whisperx_model.transcribe(audio, batch_size=batch_size)

            # Get the detected language
            language_code = result["language"]
            logger.info(f"Detected language: {language_code}")

            # Check whether the alignment model needs to be loaded or updated
            if models.whisperx_align_model is None or language_code != models.last_language:
                # Clean up the old models if they exist
                if models.whisperx_align_model is not None:
                    del models.whisperx_align_model
                    del models.whisperx_align_metadata
                    if DEVICE == "cuda":
                        gc.collect()
                        torch.cuda.empty_cache()

                # Load a new alignment model for the detected language
                logger.info(f"Loading alignment model for language: {language_code}")
                models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
                    language_code=language_code, device=DEVICE
                )
                models.last_language = language_code

            # Align the transcript to get word-level timestamps
            if result["segments"] and len(result["segments"]) > 0:
                logger.info("Aligning transcript...")
                result = whisperx.align(
                    result["segments"],
                    models.whisperx_align_model,
                    models.whisperx_align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False
                )

            # Process the segments for better output
            for segment in result["segments"]:
                # Round timestamps for better display
                segment["start"] = round(segment["start"], 2)
                segment["end"] = round(segment["end"], 2)
                # Add a confidence score if not present
                if "confidence" not in segment:
                    segment["confidence"] = 1.0  # Default confidence
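
            # Each entry in result["segments"] is assumed to look roughly like
            # {"text": "...", "start": 1.23, "end": 4.56, "confidence": 1.0, ...};
            # only text, start, end and confidence are relied on below, the rest
            # is passed through to the client unchanged.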
            # Extract the full text from all segments
            user_text = ' '.join([segment['text'] for segment in result['segments']])

            # If no text was recognized, don't process further
            if not user_text or len(user_text.strip()) == 0:
                with app.app_context():
                    socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
                return

            logger.info(f"Transcription: {user_text}")

            # Load the audio for CSM input
            waveform, sample_rate = torchaudio.load(temp_path)

            # Downmix to mono if needed
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Resample to the CSM sample rate if needed
            if sample_rate != models.generator.sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    orig_freq=sample_rate,
                    new_freq=models.generator.sample_rate
                )

            # Add the user's message to the conversation history
            user_segment = conversation.add_segment(
                text=user_text,
                speaker=speaker_id,
                audio=waveform.squeeze()
            )

            # Send the transcription to the client with detailed segments
            with app.app_context():
                socketio.emit('transcription', {
                    'text': user_text,
                    'speaker': speaker_id,
                    'segments': result['segments']  # Detailed segments with timestamps
                }, room=session_id)

            # Generate the AI response using Llama
            with app.app_context():
                socketio.emit('processing_status', {'status': 'generating'}, room=session_id)

            # Create the prompt from the conversation history
            conversation_history = ""
            for segment in conversation.segments[-5:]:  # Last 5 segments for context
                role = "User" if segment.speaker == 0 else "Assistant"
                conversation_history += f"{role}: {segment.text}\n"

            # Add the final prompt turn
            prompt = f"{conversation_history}Assistant: "
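
            # Illustrative shape of the prompt handed to the LLM (dialogue content
            # invented for the example):
            #
            #     User: can you hear me?
            #     Assistant: yes, loud and clear.
            #     User: what happens if I stop talking?
            #     Assistant:
            #
            # The trailing "Assistant: " cue is what the model is asked to complete.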
            # Generate a response with Llama
            try:
                # Ensure the pad token is set
                if models.tokenizer.pad_token is None:
                    models.tokenizer.pad_token = models.tokenizer.eos_token

                input_tokens = models.tokenizer(
                    prompt,
                    return_tensors="pt",
                    padding=True,
                    return_attention_mask=True
                )
                input_ids = input_tokens.input_ids.to(DEVICE)
                attention_mask = input_tokens.attention_mask.to(DEVICE)

                with torch.no_grad():
                    generated_ids = models.llm.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=100,
                        temperature=0.7,
                        top_p=0.9,
                        do_sample=True,
                        pad_token_id=models.tokenizer.eos_token_id
                    )

                # Decode only the newly generated tokens
                response_text = models.tokenizer.decode(
                    generated_ids[0][input_ids.shape[1]:],
                    skip_special_tokens=True
                ).strip()
            except Exception as e:
                logger.error(f"Error generating response: {str(e)}")
                import traceback
                logger.error(traceback.format_exc())
                response_text = "I'm sorry, I encountered an error while processing your request."

            # Synthesize speech
            with app.app_context():
                socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)

                # Start sending the audio response
                socketio.emit('audio_response_start', {
                    'text': response_text,
                    'total_chunks': 1,
                    'chunk_index': 0
                }, room=session_id)

            # Define the AI speaker ID
            ai_speaker_id = conversation.ai_speaker_id

            # Generate audio
            audio_tensor = models.generator.generate(
                text=response_text,
                speaker=ai_speaker_id,
                context=conversation.get_context(),
                max_audio_length_ms=10_000,
                temperature=0.9
            )

            # Add the AI response to the conversation history
            ai_segment = conversation.add_segment(
                text=response_text,
                speaker=ai_speaker_id,
                audio=audio_tensor
            )

            # Convert the audio to WAV format
            with io.BytesIO() as wav_io:
                torchaudio.save(
                    wav_io,
                    audio_tensor.unsqueeze(0).cpu(),
                    models.generator.sample_rate,
                    format="wav"
                )
                wav_io.seek(0)
                wav_data = wav_io.read()

            # Convert the WAV data to a base64 data URL
            audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"

            # Send the audio chunk to the client
            with app.app_context():
                socketio.emit('audio_response_chunk', {
                    'chunk': audio_base64,
                    'chunk_index': 0,
                    'total_chunks': 1,
                    'is_last': True
                }, room=session_id)

                # Signal completion
                socketio.emit('audio_response_complete', {
                    'text': response_text
                }, room=session_id)

        finally:
            # Clean up the temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        with app.app_context():
            socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
    finally:
        # Reset the processing flag
        conversation.is_processing = False

# Error handler
@socketio.on_error()
def error_handler(e):
    logger.error(f"SocketIO error: {str(e)}")

# Periodic cleanup of inactive sessions
def cleanup_inactive_sessions():
    """Remove sessions that have been inactive for too long"""
    current_time = time.time()
    inactive_timeout = 3600  # 1 hour

    for session_id in list(active_conversations.keys()):
        conversation = active_conversations[session_id]
        if (current_time - conversation.last_activity > inactive_timeout and
                not conversation.is_processing):

            logger.info(f"Cleaning up inactive session: {session_id}")

            # Signal the processing thread to terminate
            if session_id in user_queues:
                user_queues[session_id].put(None)

            # Remove from active conversations
            del active_conversations[session_id]

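
# The cleanup above, like the handlers earlier in this file, only touches a small
# surface of the per-session conversation object. A rough sketch of the interface
# assumed here (the real class is defined elsewhere in the backend; field types
# and defaults are guesses for illustration):
#
#     @dataclass
#     class Conversation:
#         segments: list                  # history of (text, speaker, audio) segments
#         current_speaker: int = 0
#         ai_speaker_id: int = 1          # value assumed for the sketch
#         is_processing: bool = False
#         last_activity: float = 0.0      # refreshed whenever a segment is added
#
#         def add_segment(self, text, speaker, audio): ...
#         def get_context(self): ...      # segments passed to the CSM generator
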
# Start cleanup thread
def start_cleanup_thread():
    while True:
        try:
            cleanup_inactive_sessions()
        except Exception as e:
            logger.error(f"Error in cleanup thread: {str(e)}")
        time.sleep(300)  # Run every 5 minutes


cleanup_thread = threading.Thread(target=start_cleanup_thread, daemon=True)
cleanup_thread.start()

# Start the server
if __name__ == '__main__':
    import os
    port = int(os.environ.get('PORT', 5000))
    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
    logger.info(f"Starting server on port {port} (debug={debug_mode})")
    print(f"Starting server on port {port} (debug={debug_mode})")
    print("Visit http://localhost:5000 to access the application")
    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
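
# Usage sketch (the module name and values are assumptions; PORT and DEBUG are
# the environment variables read above, defaulting to 5000 and False):
#
#     PORT=5000 DEBUG=false python app.py
#
# then open http://localhost:5000 in a browser, or connect a Socket.IO client
# such as the one sketched near the audio_chunk handler.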
@@ -1,13 +0,0 @@
from setuptools import setup, find_packages
import os

# Read requirements from requirements.txt
with open('requirements.txt') as f:
    requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]

setup(
    name='csm',
    version='0.1.0',
    packages=find_packages(),
    install_requires=requirements,
)
File diff suppressed because it is too large
12
React/src/app/auth/session/route.ts
Normal file
@@ -0,0 +1,12 @@
import { NextResponse } from "next/server";
import { auth0 } from "../../../lib/auth0";

export async function GET() {
  try {
    const session = await auth0.getSession();
    return NextResponse.json({ session });
  } catch (error) {
    console.error("Error getting session:", error);
    return NextResponse.json({ session: null }, { status: 500 });
  }
}
@@ -1,67 +1,94 @@
"use client";
import { useState, useEffect } from "react";
import { useRouter } from "next/navigation";

export default function Home() {

  const [contacts, setContacts] = useState<string[]>([]);
  const [codeword, setCodeword] = useState("");
  const [session, setSession] = useState<any>(null);
  const [loading, setLoading] = useState(true);
  const router = useRouter();

  useEffect(() => {
    // Fetch session data from an API route
    fetch("/auth/session")
      .then((response) => response.json())
      .then((data) => {
        setSession(data.session);
        setLoading(false);
      })
      .catch((error) => {
        console.error("Failed to fetch session:", error);
        setLoading(false);
      });
  }, []);

  function saveToDB() {
    alert("Saving contacts...");
    const contactInputs = document.querySelectorAll(
      ".text-input"
    ) as NodeListOf<HTMLInputElement>;
    const contactValues = Array.from(contactInputs).map((input) => input.value);
    fetch("/api/databaseStorage", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        email: session?.user?.email || "",
        codeword: codeword,
        contacts: contactValues,
      }),
    })
      .then((response) => {
        if (response.ok) {
          alert("Contacts saved successfully!");
        } else {
          alert("Error saving contacts.");
        }
      })
      .catch((error) => {
        console.error("Error:", error);
        alert("Error saving contacts.");
      });
  }

  if (loading) {
    return <div>Loading...</div>;
  }

  // If no session, show sign-up and login buttons
  if (!session) {

    return (
      <div className="space-y-7 bg-indigo-800 items-center justify-items-center min-h-screen p-8 pb-20 gap-16 sm:p-20 font-[family-name:var(--font-geist-sans)]">
        <main className="space-x-2 flex flex-row gap-[32px] row-start-2 items-center sm:items-start">
          <a href="/auth/login?screen_hint=signup">
            <button className="box-content w-32 border-2 h-16 text-2xl bg-violet-900 text-green-300">
              Sign up
            </button>
          </a>
          <a href="/auth/login">
            <button className="box-content w-32 border-2 h-16 text-2xl bg-violet-900 text-green-400">
              Log in
            </button>
          </a>
        </main>
        <h1 className="space-y-3 text-6xl text-lime-500 subpixel-antialiased font-stretch-semi-expanded font-serif">
          Fauxcall
        </h1>
        <h2 className="space-y-3 text-6x1 text-red-700 antialiased font-mono">
          Set emergency contacts
        </h2>
        <p>
          If you stop speaking or say the codeword, these contacts will be
          notified
        </p>
        {/* form for setting codeword */}
        <form
          className="flex flex-col gap-[32px] row-start-2 items-center sm:items-start"
          onSubmit={(e) => e.preventDefault()}
        >
          <input
            type="text"
            value={codeword}
@@ -70,11 +97,17 @@ export default async function Home() {
            className="border border-gray-300 rounded-md p-2"
          />
          <button
            className="bg-blue-500 text-white font-semibold font-lg rounded-md p-2"
            type="submit"
          >
            Set codeword
          </button>
        </form>
        {/* form for adding contacts */}
        <form
          className="space-y-5 flex flex-col gap-[32px] row-start-2 items-center sm:items-start"
          onSubmit={(e) => e.preventDefault()}
        >
          <input
            type="text"
            value={contacts}
@@ -97,7 +130,12 @@ export default async function Home() {
            className="border border-gray-300 rounded-md p-2"
          />
          <button type="button">Add</button>
          <button
            className="bg-slate-500 text-yellow-300 text-stretch-50% font-lg rounded-md p-2"
            type="submit"
          >
            Set contacts
          </button>
        </form>
      </div>
    );
@@ -107,25 +145,42 @@ export default async function Home() {
    <div className="grid grid-rows-[20px_1fr_20px] items-center justify-items-center min-h-screen p-8 pb-20 gap-16 sm:p-20 font-[family-name:var(--font-geist-sans)]">
      <main className="flex flex-col gap-[32px] row-start-2 items-center sm:items-start">
        <h1>Welcome, {session.user.name}!</h1>

        <h1 className="space-y-3 text-6xl text-lime-500 subpixel-antialiased font-stretch-semi-expanded font-serif">
          Fauxcall
        </h1>
        <h2 className="space-y-3 text-6x1 text-red-700 antialiased font-mono">
          Set emergency contacts
        </h2>
        <p>
          If you stop speaking or say the codeword, these contacts will be
          notified
        </p>
        {/* form for setting codeword */}
        <form
          className="flex flex-col gap-[32px] row-start-2 items-center sm:items-start"
          onSubmit={(e) => e.preventDefault()}
        >
          <input
            type="text"
            value={codeword}
            onChange={(e) => setCodeword(e.target.value)}
            placeholder="Codeword"
            className="border border-gray-300 rounded-md p-2"
          />
          <button
            className="bg-blue-500 text-white font-semibold font-lg rounded-md p-2"
            type="submit"
          >
            Set codeword
          </button>
        </form>
        {/* form for adding contacts */}
        <form
          id="Contacts"
          className="space-y-5 flex flex-col gap-[32px] row-start-2 items-center sm:items-start"
          onSubmit={(e) => e.preventDefault()}
        >
          <input
            type="text"
            value={contacts}
@@ -154,26 +209,37 @@ export default async function Home() {
            placeholder="Write down an emergency contact"
            className="text-input border border-gray-300 rounded-md p-2"
          />
          <button
            onClick={() => {
              alert("Adding contact...");
              let elem = document.getElementsByClassName(
                "text-input"
              )[0] as HTMLElement;
              console.log("Element:", elem);
              let d = elem.cloneNode(true) as HTMLElement;
              document.getElementById("Contacts")?.appendChild(d);
            }}
            className="bg-emerald-500 text-fuchsia-300"
            type="button"
          >
            Add
          </button>

          <button
            type="button"
            onClick={saveToDB}
            className="bg-slate-500 text-yellow-300 text-stretch-50% font-lg rounded-md p-2"
          >
            Save
          </button>
        </form>
        <div>
          <a href="/call">
            <button className="bg-zinc-700 text-lime-300 font-semibold font-lg rounded-md p-2">
              Call
            </button>
          </a>
        </div>
        <p>
          <a href="/auth/logout">
            <button>Log out</button>
@@ -182,6 +248,4 @@ export default async function Home() {
      </main>
    </div>
  );
}