Frontend Fixed
Backend/api/app.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import os
import logging
import threading
from dataclasses import dataclass
from flask import Flask
from flask_socketio import SocketIO
from flask_cors import CORS

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configure device
import torch
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

logger.info(f"Using device: {DEVICE}")

# Initialize Flask app
app = Flask(__name__, static_folder='../', static_url_path='')
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)

# Global variables for conversation state
active_conversations = {}
user_queues = {}
processing_threads = {}

# Model storage
@dataclass
class AppModels:
    generator = None
    tokenizer = None
    llm = None
    whisperx_model = None
    whisperx_align_model = None
    whisperx_align_metadata = None
    last_language = None

models = AppModels()

def load_models():
    """Load all required models"""
    from generator import load_csm_1b
    import whisperx
    import gc
    from transformers import AutoModelForCausalLM, AutoTokenizer
    global models

    socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})

    # CSM 1B loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
        models.generator = load_csm_1b(device=DEVICE)
        logger.info("CSM 1B model loaded successfully")
        socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})

    # WhisperX loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
        # Use WhisperX for better transcription with timestamps
        # Use compute_type based on device
        compute_type = "float16" if DEVICE == "cuda" else "float32"

        # Load the WhisperX model (smaller model for faster processing)
        models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)

        logger.info("WhisperX model loaded successfully")
        socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
        socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})

    # Llama loading
    try:
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
        models.llm = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            device_map=DEVICE,
            torch_dtype=torch.bfloat16
        )
        models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

        # Configure all special tokens
        models.tokenizer.pad_token = models.tokenizer.eos_token
        models.tokenizer.padding_side = "left"  # For causal language modeling

        # Inform the model about the pad token
        if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
            models.llm.config.pad_token_id = models.tokenizer.pad_token_id

        logger.info("Llama 3.2 model loaded successfully")
        socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
        socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
    except Exception as e:
        logger.error(f"Error loading Llama 3.2 model: {str(e)}")
        socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})

# Load models in a background thread
threading.Thread(target=load_models, daemon=True).start()

# Import routes and socket handlers
from api.routes import register_routes
from api.socket_handlers import register_handlers

# Register routes and socket handlers
register_routes(app)
register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE)

# Run server if executed directly
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
    logger.info(f"Starting server on port {port} (debug={debug_mode})")
    socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
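Note: the 'model_status' events emitted during load_models() can be watched from any Socket.IO client. A minimal monitoring sketch follows; it assumes the python-socketio client package and a server on localhost:5000, neither of which is part of this commit.

# Sketch only: observe the 'model_status' events emitted by load_models().
# Assumes the python-socketio client package and default host/port.
import socketio

sio = socketio.Client()

@sio.on('model_status')
def on_model_status(payload):
    # Payloads mirror the emits above, e.g. {'model': 'overall', 'status': 'loading', 'progress': 33}
    model = payload.get('model')
    status = payload.get('status')
    progress = payload.get('progress')
    suffix = f" ({progress}%)" if progress is not None else ""
    print(f"{model}: {status}{suffix}")

if __name__ == '__main__':
    sio.connect('http://localhost:5000')
    sio.wait()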
Backend/api/routes.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import os
import torch
import psutil
from flask import send_from_directory, jsonify, request

def register_routes(app):
    """Register HTTP routes for the application"""

    @app.route('/')
    def index():
        """Serve the main application page"""
        return send_from_directory(app.static_folder, 'index.html')

    @app.route('/voice-chat.js')
    def serve_js():
        """Serve the JavaScript file"""
        return send_from_directory(app.static_folder, 'voice-chat.js')

    @app.route('/api/status')
    def system_status():
        """Return the system status"""
        # Import here to avoid circular imports
        from api.app import models, DEVICE

        return jsonify({
            "status": "ok",
            "cuda_available": torch.cuda.is_available(),
            "device": DEVICE,
            "models": {
                "generator": models.generator is not None,
                "asr": models.whisperx_model is not None,
                "llm": models.llm is not None
            },
            "versions": {
                "transformers": "4.49.0",  # Replace with actual version
                "torch": torch.__version__
            }
        })

    @app.route('/api/system_resources')
    def system_resources():
        """Return system resource usage"""
        # Import here to avoid circular imports
        from api.app import active_conversations, DEVICE

        # Get CPU usage
        cpu_percent = psutil.cpu_percent(interval=0.1)

        # Get memory usage
        memory = psutil.virtual_memory()
        memory_used_gb = memory.used / (1024 ** 3)
        memory_total_gb = memory.total / (1024 ** 3)
        memory_percent = memory.percent

        # Get GPU memory if available
        gpu_memory = {}
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_memory[f"gpu_{i}"] = {
                    "allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
                    "reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
                    "max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
                }

        return jsonify({
            "cpu_percent": cpu_percent,
            "memory": {
                "used_gb": memory_used_gb,
                "total_gb": memory_total_gb,
                "percent": memory_percent
            },
            "gpu_memory": gpu_memory,
            "active_sessions": len(active_conversations)
        })
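Note: the two endpoints registered above can be polled over plain HTTP. A short sketch, assuming the requests library and the default localhost:5000 address (both assumptions):

# Sketch only: poll /api/status and /api/system_resources.
# Assumes the requests library and default host/port.
import requests

BASE = "http://localhost:5000"

status = requests.get(f"{BASE}/api/status").json()
print("device:", status["device"], "| models loaded:", status["models"])

resources = requests.get(f"{BASE}/api/system_resources").json()
print("cpu %:", resources["cpu_percent"], "| active sessions:", resources["active_sessions"])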
Backend/api/socket_handlers.py (new file, 392 lines)
@@ -0,0 +1,392 @@
import os
import io
import base64
import time
import threading
import queue
import tempfile
import gc
import logging
import traceback
from typing import Dict, List, Optional

import torch
import torchaudio
import numpy as np
from flask import request
from flask_socketio import emit

# Import conversation model
from generator import Segment

logger = logging.getLogger(__name__)

# Conversation data structure
class Conversation:
    def __init__(self, session_id):
        self.session_id = session_id
        self.segments: List[Segment] = []
        self.current_speaker = 0
        self.ai_speaker_id = 1  # Default AI speaker ID
        self.last_activity = time.time()
        self.is_processing = False

    def add_segment(self, text, speaker, audio):
        segment = Segment(text=text, speaker=speaker, audio=audio)
        self.segments.append(segment)
        self.last_activity = time.time()
        return segment

    def get_context(self, max_segments=10):
        """Return the most recent segments for context"""
        return self.segments[-max_segments:] if self.segments else []

def register_handlers(socketio, app, models, active_conversations, user_queues, processing_threads, DEVICE):
    """Register Socket.IO event handlers"""

    @socketio.on('connect')
    def handle_connect(auth=None):
        """Handle client connection"""
        session_id = request.sid
        logger.info(f"Client connected: {session_id}")

        # Initialize conversation data
        if session_id not in active_conversations:
            active_conversations[session_id] = Conversation(session_id)
            user_queues[session_id] = queue.Queue()
            processing_threads[session_id] = threading.Thread(
                target=process_audio_queue,
                args=(session_id, user_queues[session_id], app, socketio, models, active_conversations, DEVICE),
                daemon=True
            )
            processing_threads[session_id].start()

        emit('connection_status', {'status': 'connected'})

    @socketio.on('disconnect')
    def handle_disconnect(reason=None):
        """Handle client disconnection"""
        session_id = request.sid
        logger.info(f"Client disconnected: {session_id}. Reason: {reason}")

        # Cleanup
        if session_id in active_conversations:
            # Mark for deletion rather than immediately removing
            # as the processing thread might still be accessing it
            active_conversations[session_id].is_processing = False
            user_queues[session_id].put(None)  # Signal thread to terminate

    @socketio.on('audio_data')
    def handle_audio_data(data):
        """Handle incoming audio data"""
        session_id = request.sid
        logger.info(f"Received audio data from {session_id}")

        # Check if the models are loaded
        if models.generator is None or models.whisperx_model is None or models.llm is None:
            emit('error', {'message': 'Models still loading, please wait'})
            return

        # Check if we're already processing for this session
        if session_id in active_conversations and active_conversations[session_id].is_processing:
            emit('error', {'message': 'Still processing previous audio, please wait'})
            return

        # Add to processing queue
        if session_id in user_queues:
            user_queues[session_id].put(data)
        else:
            emit('error', {'message': 'Session not initialized, please refresh the page'})

def process_audio_queue(session_id, q, app, socketio, models, active_conversations, DEVICE):
    """Background thread to process audio chunks for a session"""
    logger.info(f"Started processing thread for session: {session_id}")

    try:
        while session_id in active_conversations:
            try:
                # Get the next audio chunk with a timeout
                data = q.get(timeout=120)
                if data is None:  # Termination signal
                    break

                # Process the audio and generate a response
                process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE)

            except queue.Empty:
                # Timeout, check if session is still valid
                continue
    except Exception as e:
        logger.error(f"Error processing audio for {session_id}: {str(e)}")
        # Create an app context for the socket emit
        with app.app_context():
            socketio.emit('error', {'message': str(e)}, room=session_id)
    finally:
        logger.info(f"Ending processing thread for session: {session_id}")
        # Clean up when thread is done
        with app.app_context():
            if session_id in active_conversations:
                del active_conversations[session_id]
            if session_id in user_queues:
                del user_queues[session_id]

def process_audio_and_respond(session_id, data, app, socketio, models, active_conversations, DEVICE):
    """Process audio data and generate a response using WhisperX"""
    if models.generator is None or models.whisperx_model is None or models.llm is None:
        logger.warning("Models not yet loaded!")
        with app.app_context():
            socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
        return

    logger.info(f"Processing audio for session {session_id}")
    conversation = active_conversations[session_id]

    try:
        # Set processing flag
        conversation.is_processing = True

        # Process base64 audio data
        audio_data = data['audio']
        speaker_id = data['speaker']
        logger.info(f"Received audio from speaker {speaker_id}")

        # Convert from base64 to WAV
        try:
            audio_bytes = base64.b64decode(audio_data.split(',')[1])
            logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
        except Exception as e:
            logger.error(f"Error decoding base64 audio: {str(e)}")
            raise

        # Save to temporary file for processing
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name

        try:
            # Notify client that transcription is starting
            with app.app_context():
                socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)

            # Load audio using WhisperX
            import whisperx
            audio = whisperx.load_audio(temp_path)

            # Check audio length and add a warning for short clips
            audio_length = len(audio) / 16000  # assuming 16kHz sample rate
            if audio_length < 1.0:
                logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")

            # Transcribe using WhisperX
            batch_size = 16  # adjust based on your GPU memory
            logger.info("Running WhisperX transcription...")

            # Handle the warning about audio being shorter than 30s by suppressing it
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", message="audio is shorter than 30s")
                result = models.whisperx_model.transcribe(audio, batch_size=batch_size)

            # Get the detected language
            language_code = result["language"]
            logger.info(f"Detected language: {language_code}")

            # Check if alignment model needs to be loaded or updated
            if models.whisperx_align_model is None or language_code != models.last_language:
                # Clean up old models if they exist
                if models.whisperx_align_model is not None:
                    del models.whisperx_align_model
                    del models.whisperx_align_metadata
                    if DEVICE == "cuda":
                        gc.collect()
                        torch.cuda.empty_cache()

                # Load new alignment model for the detected language
                logger.info(f"Loading alignment model for language: {language_code}")
                models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
                    language_code=language_code, device=DEVICE
                )
                models.last_language = language_code

            # Align the transcript to get word-level timestamps
            if result["segments"] and len(result["segments"]) > 0:
                logger.info("Aligning transcript...")
                result = whisperx.align(
                    result["segments"],
                    models.whisperx_align_model,
                    models.whisperx_align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False
                )

                # Process the segments for better output
                for segment in result["segments"]:
                    # Round timestamps for better display
                    segment["start"] = round(segment["start"], 2)
                    segment["end"] = round(segment["end"], 2)
                    # Add a confidence score if not present
                    if "confidence" not in segment:
                        segment["confidence"] = 1.0  # Default confidence

            # Extract the full text from all segments
            user_text = ' '.join([segment['text'] for segment in result['segments']])

            # If no text was recognized, don't process further
            if not user_text or len(user_text.strip()) == 0:
                with app.app_context():
                    socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
                return

            logger.info(f"Transcription: {user_text}")

            # Load audio for CSM input
            waveform, sample_rate = torchaudio.load(temp_path)

            # Normalize to mono if needed
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Resample to the CSM sample rate if needed
            if sample_rate != models.generator.sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    orig_freq=sample_rate,
                    new_freq=models.generator.sample_rate
                )

            # Add the user's message to conversation history
            user_segment = conversation.add_segment(
                text=user_text,
                speaker=speaker_id,
                audio=waveform.squeeze()
            )

            # Send transcription to client with detailed segments
            with app.app_context():
                socketio.emit('transcription', {
                    'text': user_text,
                    'speaker': speaker_id,
                    'segments': result['segments']  # Include the detailed segments with timestamps
                }, room=session_id)

            # Generate AI response using Llama
            with app.app_context():
                socketio.emit('processing_status', {'status': 'generating'}, room=session_id)

            # Create prompt from conversation history
            conversation_history = ""
            for segment in conversation.segments[-5:]:  # Last 5 segments for context
                role = "User" if segment.speaker == 0 else "Assistant"
                conversation_history += f"{role}: {segment.text}\n"

            # Add final prompt
            prompt = f"{conversation_history}Assistant: "

            # Generate response with Llama
            try:
                # Ensure pad token is set
                if models.tokenizer.pad_token is None:
                    models.tokenizer.pad_token = models.tokenizer.eos_token

                input_tokens = models.tokenizer(
                    prompt,
                    return_tensors="pt",
                    padding=True,
                    return_attention_mask=True
                )
                input_ids = input_tokens.input_ids.to(DEVICE)
                attention_mask = input_tokens.attention_mask.to(DEVICE)

                with torch.no_grad():
                    generated_ids = models.llm.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=100,
                        temperature=0.7,
                        top_p=0.9,
                        do_sample=True,
                        pad_token_id=models.tokenizer.eos_token_id
                    )

                # Decode the response
                response_text = models.tokenizer.decode(
                    generated_ids[0][input_ids.shape[1]:],
                    skip_special_tokens=True
                ).strip()
            except Exception as e:
                logger.error(f"Error generating response: {str(e)}")
                logger.error(traceback.format_exc())
                response_text = "I'm sorry, I encountered an error while processing your request."

            # Synthesize speech
            with app.app_context():
                socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)

                # Start sending the audio response
                socketio.emit('audio_response_start', {
                    'text': response_text,
                    'total_chunks': 1,
                    'chunk_index': 0
                }, room=session_id)

            # Define AI speaker ID
            ai_speaker_id = conversation.ai_speaker_id

            # Generate audio
            audio_tensor = models.generator.generate(
                text=response_text,
                speaker=ai_speaker_id,
                context=conversation.get_context(),
                max_audio_length_ms=10_000,
                temperature=0.9
            )

            # Add AI response to conversation history
            ai_segment = conversation.add_segment(
                text=response_text,
                speaker=ai_speaker_id,
                audio=audio_tensor
            )

            # Convert audio to WAV format
            with io.BytesIO() as wav_io:
                torchaudio.save(
                    wav_io,
                    audio_tensor.unsqueeze(0).cpu(),
                    models.generator.sample_rate,
                    format="wav"
                )
                wav_io.seek(0)
                wav_data = wav_io.read()

            # Convert WAV data to base64
            audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"

            # Send audio chunk to client
            with app.app_context():
                socketio.emit('audio_response_chunk', {
                    'chunk': audio_base64,
                    'chunk_index': 0,
                    'total_chunks': 1,
                    'is_last': True
                }, room=session_id)

                # Signal completion
                socketio.emit('audio_response_complete', {
                    'text': response_text
                }, room=session_id)

        finally:
            # Clean up temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        logger.error(traceback.format_exc())
        with app.app_context():
            socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
    finally:
        # Reset processing flag
        conversation.is_processing = False
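Note: the handlers above define the event flow a client follows: emit 'audio_data' with a base64 WAV data URL plus a speaker id, then listen for 'transcription' and 'audio_response_chunk'. A minimal client-side sketch, assuming the python-socketio client package, the default localhost:5000 address, and a placeholder recording file (all assumptions, not part of this commit):

# Sketch only: client side of the 'audio_data' exchange handled above.
# Assumes the python-socketio client package; 'question.wav' is a placeholder.
import base64
import socketio

sio = socketio.Client()

@sio.on('transcription')
def on_transcription(payload):
    print("heard:", payload['text'])

@sio.on('audio_response_chunk')
def on_audio_chunk(payload):
    # 'chunk' is a data:audio/wav;base64 URL, the same format the client sends
    wav_bytes = base64.b64decode(payload['chunk'].split(',')[1])
    print(f"received {len(wav_bytes)} bytes of synthesized audio")

sio.connect('http://localhost:5000')
with open('question.wav', 'rb') as f:  # placeholder recording
    audio_b64 = base64.b64encode(f.read()).decode('utf-8')
sio.emit('audio_data', {'audio': f'data:audio/wav;base64,{audio_b64}', 'speaker': 0})
sio.wait()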
@@ -1,419 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>CSM Voice Chat</title>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
    <!-- Socket.IO client library -->
    <script src="https://cdn.socket.io/4.6.0/socket.io.min.js"></script>
    <style>
        :root {
            --primary-color: #4c84ff;
            --secondary-color: #3367d6;
            --text-color: #333;
            --background-color: #f9f9f9;
            --card-background: #ffffff;
            --accent-color: #ff5252;
            --success-color: #4CAF50;
            --border-color: #e0e0e0;
            --shadow-color: rgba(0, 0, 0, 0.1);
        }

        * {
            box-sizing: border-box;
            margin: 0;
            padding: 0;
        }

        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            background-color: var(--background-color);
            color: var(--text-color);
            line-height: 1.6;
            max-width: 1000px;
            margin: 0 auto;
            padding: 20px;
            transition: all 0.3s ease;
        }

        header {
            text-align: center;
            margin-bottom: 30px;
        }

        h1 {
            color: var(--primary-color);
            font-size: 2.5rem;
            margin-bottom: 10px;
        }

        .subtitle {
            color: #666;
            font-weight: 300;
        }

        .app-container {
            display: grid;
            grid-template-columns: 1fr;
            gap: 20px;
        }

        @media (min-width: 768px) {
            .app-container {
                grid-template-columns: 2fr 1fr;
            }
        }

        .chat-container, .control-panel {
            background-color: var(--card-background);
            border-radius: 12px;
            box-shadow: 0 4px 12px var(--shadow-color);
            padding: 20px;
        }

        .control-panel {
            display: flex;
            flex-direction: column;
            gap: 20px;
        }

        .chat-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 15px;
            padding-bottom: 10px;
            border-bottom: 1px solid var(--border-color);
        }

        .conversation {
            height: 400px;
            overflow-y: auto;
            padding: 10px;
            border-radius: 8px;
            background-color: #f7f9fc;
            margin-bottom: 20px;
            scroll-behavior: smooth;
        }

        .message {
            margin-bottom: 15px;
            padding: 12px 15px;
            border-radius: 12px;
            max-width: 85%;
            position: relative;
            animation: fade-in 0.3s ease-out forwards;
        }

        @keyframes fade-in {
            from { opacity: 0; transform: translateY(10px); }
            to { opacity: 1; transform: translateY(0); }
        }

        .user {
            background-color: #e3f2fd;
            color: #0d47a1;
            margin-left: auto;
            border-bottom-right-radius: 4px;
        }

        .ai {
            background-color: #f1f1f1;
            color: #37474f;
            margin-right: auto;
            border-bottom-left-radius: 4px;
        }

        .system {
            background-color: #f8f9fa;
            font-style: italic;
            color: #666;
            text-align: center;
            max-width: 90%;
            margin: 10px auto;
            font-size: 0.9em;
            padding: 8px 12px;
            border-radius: 8px;
        }

        .audio-player {
            width: 100%;
            margin-top: 8px;
            border-radius: 8px;
        }

        button {
            padding: 12px 20px;
            border-radius: 8px;
            border: none;
            background-color: var(--primary-color);
            color: white;
            font-weight: 600;
            cursor: pointer;
            transition: all 0.2s ease;
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 8px;
            flex: 1;
        }

        button:hover {
            background-color: var(--secondary-color);
        }

        button.recording {
            background-color: var(--accent-color);
            animation: pulse 1.5s infinite;
        }

        button.processing {
            background-color: #ffa000;
        }

        @keyframes pulse {
            0% { opacity: 1; }
            50% { opacity: 0.7; }
            100% { opacity: 1; }
        }

        .status-indicator {
            display: flex;
            align-items: center;
            gap: 10px;
            font-size: 0.9em;
            color: #555;
        }

        .status-dot {
            width: 12px;
            height: 12px;
            border-radius: 50%;
            background-color: #ccc;
        }

        .status-dot.active {
            background-color: var(--success-color);
        }

        /* Audio visualizer styles */
        .visualizer-container {
            margin-top: 15px;
            position: relative;
            width: 100%;
            height: 100px;
            background-color: #000;
            border-radius: 8px;
            overflow: hidden;
        }

        #audioVisualizer {
            width: 100%;
            height: 100%;
            transition: opacity 0.3s;
        }

        #visualizerLabel {
            position: absolute;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
            color: rgba(255, 255, 255, 0.7);
            font-size: 0.9em;
            pointer-events: none;
            transition: opacity 0.3s;
        }

        .volume-meter {
            height: 8px;
            width: 100%;
            background-color: #eee;
            border-radius: 4px;
            margin-top: 8px;
            overflow: hidden;
        }

        #volumeLevel {
            height: 100%;
            width: 0%;
            background-color: var(--primary-color);
            border-radius: 4px;
            transition: width 0.1s ease, background-color 0.2s;
        }

        .settings-toggles {
            display: flex;
            flex-direction: column;
            gap: 12px;
        }

        .toggle-switch {
            display: flex;
            align-items: center;
            gap: 8px;
        }

        footer {
            margin-top: 30px;
            text-align: center;
            font-size: 0.8em;
            color: #777;
        }

        /* Model status indicators */
        .model-status {
            display: flex;
            gap: 8px;
        }

        .model-indicator {
            padding: 3px 6px;
            border-radius: 4px;
            font-size: 0.7em;
            font-weight: bold;
        }

        .model-indicator.loading {
            background-color: #ffd54f;
            color: #000;
        }

        .model-indicator.loaded {
            background-color: #4CAF50;
            color: white;
        }

        .model-indicator.error {
            background-color: #f44336;
            color: white;
        }

        .message-timestamp {
            font-size: 0.7em;
            color: #888;
            margin-top: 4px;
            text-align: right;
        }

        .simple-timestamp {
            font-size: 0.8em;
            color: #888;
            margin-top: 5px;
        }

        /* Add this to your existing styles */
        .loading-progress {
            width: 100%;
            max-width: 150px;
            margin-right: 10px;
        }

        progress {
            width: 100%;
            height: 8px;
            border-radius: 4px;
            overflow: hidden;
        }

        progress::-webkit-progress-bar {
            background-color: #eee;
            border-radius: 4px;
        }

        progress::-webkit-progress-value {
            background-color: var(--primary-color);
            border-radius: 4px;
        }
    </style>
</head>
<body>
    <header>
        <h1>CSM Voice Chat</h1>
        <p class="subtitle">Talk naturally with the AI using your voice</p>
    </header>

    <div class="app-container">
        <div class="chat-container">
            <div class="chat-header">
                <h2>Conversation</h2>
                <div class="status-indicator">
                    <div id="statusDot" class="status-dot"></div>
                    <span id="statusText">Disconnected</span>
                </div>

                <!-- Add this above the model status indicators in the chat-header div -->
                <div class="loading-progress">
                    <progress id="modelLoadingProgress" max="100" value="0">0%</progress>
                </div>

                <!-- Add this model status panel -->
                <div class="model-status">
                    <div id="csmStatus" class="model-indicator loading" title="Loading CSM model...">CSM</div>
                    <div id="asrStatus" class="model-indicator loading" title="Loading ASR model...">ASR</div>
                    <div id="llmStatus" class="model-indicator loading" title="Loading LLM model...">LLM</div>
                </div>
            </div>
            <div id="conversation" class="conversation"></div>
        </div>

        <div class="control-panel">
            <div>
                <h3>Controls</h3>
                <p>Click the button below to start and stop recording.</p>
                <div class="button-row">
                    <button id="streamButton">
                        <i class="fas fa-microphone"></i>
                        Start Conversation
                    </button>
                    <button id="clearButton">
                        <i class="fas fa-trash"></i>
                        Clear
                    </button>
                </div>

                <!-- Audio visualizer and volume meter -->
                <div class="visualizer-container">
                    <canvas id="audioVisualizer"></canvas>
                    <div id="visualizerLabel">Start speaking to see audio visualization</div>
                </div>
                <div class="volume-meter">
                    <div id="volumeLevel"></div>
                </div>
            </div>

            <div class="settings-panel">
                <h3>Settings</h3>
                <div class="settings-toggles">
                    <div class="toggle-switch">
                        <input type="checkbox" id="autoPlayResponses" checked>
                        <label for="autoPlayResponses">Autoplay Responses</label>
                    </div>
                    <div class="toggle-switch">
                        <input type="checkbox" id="showVisualizer" checked>
                        <label for="showVisualizer">Show Audio Visualizer</label>
                    </div>
                    <div>
                        <label for="speakerSelect">Speaker Voice:</label>
                        <select id="speakerSelect">
                            <option value="0">Speaker 0 (You)</option>
                            <option value="1">Speaker 1 (AI)</option>
                        </select>
                    </div>
                    <div>
                        <label for="thresholdSlider">Silence Threshold: <span id="thresholdValue">0.010</span></label>
                        <input type="range" id="thresholdSlider" min="0.001" max="0.05" step="0.001" value="0.01">
                    </div>
                </div>
            </div>
        </div>
    </div>

    <footer>
        <p>Powered by CSM 1B & Llama 3.2 | Whisper for speech recognition</p>
    </footer>

    <!-- Load external JavaScript file -->
    <script src="voice-chat.js"></script>
</body>
</html>
@@ -1,13 +0,0 @@
flask==2.2.5
flask-socketio==5.3.6
flask-cors==4.0.0
torch==2.4.0
torchaudio==2.4.0
tokenizers==0.21.0
transformers==4.49.0
librosa==0.10.1
huggingface_hub==0.28.1
moshi==0.2.2
torchtune==0.4.0
torchao==0.9.0
silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
@@ -1,117 +0,0 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass

# Disable Triton compilation
os.environ["NO_TORCH_COMPILE"] = "1"

# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_b.wav"
)

SPEAKER_PROMPTS = {
    "conversational_a": {
        "text": (
            "like revising for an exam I'd have to try and like keep up the momentum because I'd "
            "start really early I'd be like okay I'm gonna start revising now and then like "
            "you're revising for ages and then I just like start losing steam I didn't do that "
            "for the exam we had recently to be fair that was a more of a last minute scenario "
            "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
            "sort of start the day with this not like a panic but like a"
        ),
        "audio": prompt_filepath_conversational_a
    },
    "conversational_b": {
        "text": (
            "like a super Mario level. Like it's very like high detail. And like, once you get "
            "into the park, it just like, everything looks like a computer game and they have all "
            "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
            "will have like a question block. And if you like, you know, punch it, a coin will "
            "come out. So like everyone, when they come into the park, they get like this little "
            "bracelet and then you can go punching question blocks around."
        ),
        "audio": prompt_filepath_conversational_b
    }
}

def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = audio_tensor.squeeze(0)
    # Resample is lazy so we can always call it
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
    )
    return audio_tensor

def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
    audio_tensor = load_prompt_audio(audio_path, sample_rate)
    return Segment(text=text, speaker=speaker, audio=audio_tensor)

def main():
    # Select the best available device, skipping MPS due to float64 limitations
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load model
    generator = load_csm_1b(device)

    # Prepare prompts
    prompt_a = prepare_prompt(
        SPEAKER_PROMPTS["conversational_a"]["text"],
        0,
        SPEAKER_PROMPTS["conversational_a"]["audio"],
        generator.sample_rate
    )

    prompt_b = prepare_prompt(
        SPEAKER_PROMPTS["conversational_b"]["text"],
        1,
        SPEAKER_PROMPTS["conversational_b"]["audio"],
        generator.sample_rate
    )

    # Generate conversation
    conversation = [
        {"text": "Hey how are you doing?", "speaker_id": 0},
        {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
        {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
        {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
    ]

    # Generate each utterance
    generated_segments = []
    prompt_segments = [prompt_a, prompt_b]

    for utterance in conversation:
        print(f"Generating: {utterance['text']}")
        audio_tensor = generator.generate(
            text=utterance['text'],
            speaker=utterance['speaker_id'],
            context=prompt_segments + generated_segments,
            max_audio_length_ms=10_000,
        )
        generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))

    # Concatenate all generations
    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
    torchaudio.save(
        "full_conversation.wav",
        all_audio.unsqueeze(0).cpu(),
        generator.sample_rate
    )
    print("Successfully generated full_conversation.wav")

if __name__ == "__main__":
    main()
@@ -1,639 +1,19 @@
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
import threading
|
||||
import queue
|
||||
import tempfile
|
||||
import gc
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
"""
|
||||
CSM Voice Chat Server
|
||||
A voice chat application that uses CSM 1B for voice synthesis,
|
||||
WhisperX for speech recognition, and Llama 3.2 for language generation.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
from flask import Flask, request, jsonify, send_from_directory
|
||||
from flask_socketio import SocketIO, emit
|
||||
from flask_cors import CORS
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
# Start the Flask application
|
||||
from api.app import app, socketio
|
||||
|
||||
# Import WhisperX for better transcription
|
||||
import whisperx
|
||||
|
||||
from generator import load_csm_1b, Segment
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Add these imports at the top
|
||||
import psutil
|
||||
import gc
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize Flask app
|
||||
app = Flask(__name__, static_folder='.')
|
||||
CORS(app)
|
||||
socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=120)
|
||||
|
||||
# Configure device
|
||||
if torch.cuda.is_available():
|
||||
DEVICE = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
DEVICE = "mps"
|
||||
else:
|
||||
DEVICE = "cpu"
|
||||
|
||||
logger.info(f"Using device: {DEVICE}")
|
||||
|
||||
# Global variables
|
||||
active_conversations = {}
|
||||
user_queues = {}
|
||||
processing_threads = {}
|
||||
|
||||
# Load models
|
||||
@dataclass
|
||||
class AppModels:
|
||||
generator = None
|
||||
tokenizer = None
|
||||
llm = None
|
||||
whisperx_model = None
|
||||
whisperx_align_model = None
|
||||
whisperx_align_metadata = None
|
||||
diarize_model = None
|
||||
last_language = None
|
||||
|
||||
# Initialize the models object
|
||||
models = AppModels()
|
||||
|
||||
def load_models():
|
||||
"""Load all required models"""
|
||||
global models
|
||||
|
||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 0})
|
||||
|
||||
# CSM 1B loading
|
||||
try:
|
||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 10, 'message': 'Loading CSM voice model'})
|
||||
models.generator = load_csm_1b(device=DEVICE)
|
||||
logger.info("CSM 1B model loaded successfully")
|
||||
socketio.emit('model_status', {'model': 'csm', 'status': 'loaded'})
|
||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 33})
|
||||
if DEVICE == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Error loading CSM 1B model: {str(e)}\n{error_details}")
|
||||
socketio.emit('model_status', {'model': 'csm', 'status': 'error', 'message': str(e)})
|
||||
|
||||
# WhisperX loading
|
||||
try:
|
||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 40, 'message': 'Loading speech recognition model'})
|
||||
# Use WhisperX for better transcription with timestamps
|
||||
import whisperx
|
||||
|
||||
# Use compute_type based on device
|
||||
compute_type = "float16" if DEVICE == "cuda" else "float32"
|
||||
|
||||
# Load the WhisperX model (smaller model for faster processing)
|
||||
models.whisperx_model = whisperx.load_model("small", DEVICE, compute_type=compute_type)
|
||||
|
||||
logger.info("WhisperX model loaded successfully")
|
||||
socketio.emit('model_status', {'model': 'asr', 'status': 'loaded'})
|
||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 66})
|
||||
if DEVICE == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Error loading WhisperX model: {str(e)}\n{error_details}")
|
||||
socketio.emit('model_status', {'model': 'asr', 'status': 'error', 'message': str(e)})
|
||||
|
||||
# Llama loading
|
||||
try:
|
||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 70, 'message': 'Loading language model'})
|
||||
models.llm = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.2-1B",
|
||||
device_map=DEVICE,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
models.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
|
||||
|
||||
# Configure all special tokens
|
||||
models.tokenizer.pad_token = models.tokenizer.eos_token
|
||||
models.tokenizer.padding_side = "left" # For causal language modeling
|
||||
|
||||
# Inform the model about the pad token
|
||||
if hasattr(models.llm.config, "pad_token_id") and models.llm.config.pad_token_id is None:
|
||||
models.llm.config.pad_token_id = models.tokenizer.pad_token_id
|
||||
|
||||
logger.info("Llama 3.2 model loaded successfully")
|
||||
socketio.emit('model_status', {'model': 'llm', 'status': 'loaded'})
|
||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loading', 'progress': 100, 'message': 'All models loaded successfully'})
|
||||
socketio.emit('model_status', {'model': 'overall', 'status': 'loaded'})
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading Llama 3.2 model: {str(e)}")
|
||||
socketio.emit('model_status', {'model': 'llm', 'status': 'error', 'message': str(e)})
|
||||
|
||||
# Load models in a background thread
|
||||
threading.Thread(target=load_models, daemon=True).start()
|
||||
|
||||
# Conversation data structure
|
||||
class Conversation:
|
||||
def __init__(self, session_id):
|
||||
self.session_id = session_id
|
||||
self.segments: List[Segment] = []
|
||||
self.current_speaker = 0
|
||||
self.ai_speaker_id = 1 # Add this property
|
||||
self.last_activity = time.time()
|
||||
self.is_processing = False
|
||||
|
||||
def add_segment(self, text, speaker, audio):
|
||||
segment = Segment(text=text, speaker=speaker, audio=audio)
|
||||
self.segments.append(segment)
|
||||
self.last_activity = time.time()
|
||||
return segment
|
||||
|
||||
def get_context(self, max_segments=10):
|
||||
"""Return the most recent segments for context"""
|
||||
return self.segments[-max_segments:] if self.segments else []
|
||||
|
||||
# Routes
|
||||
@app.route('/')
|
||||
def index():
|
||||
return send_from_directory('.', 'index.html')
|
||||
|
||||
@app.route('/voice-chat.js')
|
||||
def voice_chat_js():
|
||||
return send_from_directory('.', 'voice-chat.js')
|
||||
|
||||
@app.route('/api/health')
|
||||
def health_check():
|
||||
return jsonify({
|
||||
"status": "ok",
|
||||
"cuda_available": torch.cuda.is_available(),
|
||||
"models_loaded": models.generator is not None and models.llm is not None
|
||||
})
|
||||
|
||||
# Fix the system_status function:
|
||||
|
||||
@app.route('/api/status')
|
||||
def system_status():
|
||||
return jsonify({
|
||||
"status": "ok",
|
||||
"cuda_available": torch.cuda.is_available(),
|
||||
"device": DEVICE,
|
||||
"models": {
|
||||
"generator": models.generator is not None,
|
||||
"asr": models.whisperx_model is not None, # Use the correct model name
|
||||
"llm": models.llm is not None
|
||||
}
|
||||
})
|
||||
|
||||
# Add a new endpoint to check system resources
|
||||
@app.route('/api/system_resources')
|
||||
def system_resources():
|
||||
# Get CPU usage
|
||||
cpu_percent = psutil.cpu_percent(interval=0.1)
|
||||
|
||||
# Get memory usage
|
||||
memory = psutil.virtual_memory()
|
||||
memory_used_gb = memory.used / (1024 ** 3)
|
||||
memory_total_gb = memory.total / (1024 ** 3)
|
||||
memory_percent = memory.percent
|
||||
|
||||
# Get GPU memory if available
|
||||
gpu_memory = {}
|
||||
if torch.cuda.is_available():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
gpu_memory[f"gpu_{i}"] = {
|
||||
"allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
|
||||
"reserved": torch.cuda.memory_reserved(i) / (1024 ** 3),
|
||||
"max_allocated": torch.cuda.max_memory_allocated(i) / (1024 ** 3)
|
||||
}
|
||||
|
||||
return jsonify({
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory": {
|
||||
"used_gb": memory_used_gb,
|
||||
"total_gb": memory_total_gb,
|
||||
"percent": memory_percent
|
||||
},
|
||||
"gpu_memory": gpu_memory,
|
||||
"active_sessions": len(active_conversations)
|
||||
})
|
||||
|
||||
# Socket event handlers
|
||||
@socketio.on('connect')
|
||||
def handle_connect(auth=None):
|
||||
session_id = request.sid
|
||||
logger.info(f"Client connected: {session_id}")
|
||||
|
||||
# Initialize conversation data
|
||||
if session_id not in active_conversations:
|
||||
active_conversations[session_id] = Conversation(session_id)
|
||||
user_queues[session_id] = queue.Queue()
|
||||
processing_threads[session_id] = threading.Thread(
|
||||
target=process_audio_queue,
|
||||
args=(session_id, user_queues[session_id]),
|
||||
daemon=True
|
||||
)
|
||||
processing_threads[session_id].start()
|
||||
|
||||
emit('connection_status', {'status': 'connected'})
|
||||
|
||||
@socketio.on('disconnect')
|
||||
def handle_disconnect(reason=None):
|
||||
session_id = request.sid
|
||||
logger.info(f"Client disconnected: {session_id}. Reason: {reason}")
|
||||
|
||||
# Cleanup
|
||||
if session_id in active_conversations:
|
||||
# Mark for deletion rather than immediately removing
|
||||
# as the processing thread might still be accessing it
|
||||
active_conversations[session_id].is_processing = False
|
||||
user_queues[session_id].put(None) # Signal thread to terminate
|
||||
|
||||
@socketio.on('start_stream')
|
||||
def handle_start_stream():
|
||||
session_id = request.sid
|
||||
logger.info(f"Starting stream for client: {session_id}")
|
||||
emit('streaming_status', {'status': 'active'})
|
||||
|
||||
@socketio.on('stop_stream')
|
||||
def handle_stop_stream():
|
||||
session_id = request.sid
|
||||
logger.info(f"Stopping stream for client: {session_id}")
|
||||
emit('streaming_status', {'status': 'inactive'})
|
||||
|
||||
@socketio.on('clear_context')
|
||||
def handle_clear_context():
|
||||
session_id = request.sid
|
||||
logger.info(f"Clearing context for client: {session_id}")
|
||||
|
||||
if session_id in active_conversations:
|
||||
active_conversations[session_id].segments = []
|
||||
emit('context_updated', {'status': 'cleared'})
|
||||
|
||||
@socketio.on('audio_chunk')
|
||||
def handle_audio_chunk(data):
|
||||
session_id = request.sid
|
||||
audio_data = data.get('audio', '')
|
||||
speaker_id = int(data.get('speaker', 0))
|
||||
|
||||
if not audio_data or not session_id in active_conversations:
|
||||
return
|
||||
|
||||
# Update the current speaker
|
||||
active_conversations[session_id].current_speaker = speaker_id
|
||||
|
||||
# Queue audio for processing
|
||||
user_queues[session_id].put({
|
||||
'audio': audio_data,
|
||||
'speaker': speaker_id
|
||||
})
|
||||
|
||||
emit('processing_status', {'status': 'transcribing'})
|
||||
|
||||
def process_audio_queue(session_id, q):
|
||||
"""Background thread to process audio chunks for a session"""
|
||||
logger.info(f"Started processing thread for session: {session_id}")
|
||||
|
||||
try:
|
||||
while session_id in active_conversations:
|
||||
try:
|
||||
# Get the next audio chunk with a timeout
|
||||
data = q.get(timeout=120)
|
||||
if data is None: # Termination signal
|
||||
break
|
||||
|
||||
# Process the audio and generate a response
|
||||
process_audio_and_respond(session_id, data)
|
||||
|
||||
except queue.Empty:
|
||||
# Timeout, check if session is still valid
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio for {session_id}: {str(e)}")
|
||||
# Create an app context for the socket emit
|
||||
with app.app_context():
|
||||
socketio.emit('error', {'message': str(e)}, room=session_id)
|
||||
finally:
|
||||
logger.info(f"Ending processing thread for session: {session_id}")
|
||||
# Clean up when thread is done
|
||||
with app.app_context():
|
||||
if session_id in active_conversations:
|
||||
del active_conversations[session_id]
|
||||
if session_id in user_queues:
|
||||
del user_queues[session_id]
|
||||
|
||||
def process_audio_and_respond(session_id, data):
|
||||
"""Process audio data and generate a response using WhisperX"""
|
||||
if models.generator is None or models.whisperx_model is None or models.llm is None:
|
||||
logger.warning("Models not yet loaded!")
|
||||
with app.app_context():
|
||||
socketio.emit('error', {'message': 'Models still loading, please wait'}, room=session_id)
|
||||
return
|
||||
|
||||
logger.info(f"Processing audio for session {session_id}")
|
||||
conversation = active_conversations[session_id]
|
||||
|
||||
try:
|
||||
# Set processing flag
|
||||
conversation.is_processing = True
|
||||
|
||||
# Process base64 audio data
|
||||
audio_data = data['audio']
|
||||
speaker_id = data['speaker']
|
||||
logger.info(f"Received audio from speaker {speaker_id}")
|
||||
|
||||
# Convert from base64 to WAV
|
||||
try:
|
||||
audio_bytes = base64.b64decode(audio_data.split(',')[1])
|
||||
logger.info(f"Decoded audio bytes: {len(audio_bytes)} bytes")
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding base64 audio: {str(e)}")
|
||||
raise
|
||||
|
||||
# Save to temporary file for processing
|
||||
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
|
||||
temp_file.write(audio_bytes)
|
||||
temp_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Notify client that transcription is starting
|
||||
with app.app_context():
|
||||
socketio.emit('processing_status', {'status': 'transcribing'}, room=session_id)
|
||||
|
||||
# Load audio using WhisperX
|
||||
import whisperx
|
||||
audio = whisperx.load_audio(temp_path)
|
||||
|
||||
# Check audio length and add a warning for short clips
|
||||
audio_length = len(audio) / 16000 # assuming 16kHz sample rate
|
||||
if audio_length < 1.0:
|
||||
logger.warning(f"Audio is very short ({audio_length:.2f}s), may affect transcription quality")
|
||||
|
||||
# Transcribe using WhisperX
|
||||
batch_size = 16 # adjust based on your GPU memory
|
||||
logger.info("Running WhisperX transcription...")
|
||||
|
||||
# Handle the warning about audio being shorter than 30s by suppressing it
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", message="audio is shorter than 30s")
|
||||
result = models.whisperx_model.transcribe(audio, batch_size=batch_size)
|
||||
|
||||
# Get the detected language
|
||||
language_code = result["language"]
|
||||
logger.info(f"Detected language: {language_code}")
|
||||
|
||||
# Check if alignment model needs to be loaded or updated
|
||||
if models.whisperx_align_model is None or language_code != models.last_language:
|
||||
# Clean up old models if they exist
|
||||
if models.whisperx_align_model is not None:
|
||||
del models.whisperx_align_model
|
||||
del models.whisperx_align_metadata
|
||||
if DEVICE == "cuda":
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Load new alignment model for the detected language
|
||||
logger.info(f"Loading alignment model for language: {language_code}")
|
||||
models.whisperx_align_model, models.whisperx_align_metadata = whisperx.load_align_model(
|
||||
language_code=language_code, device=DEVICE
|
||||
)
|
||||
models.last_language = language_code
|
||||
|
||||
# Align the transcript to get word-level timestamps
|
||||
if result["segments"] and len(result["segments"]) > 0:
|
||||
logger.info("Aligning transcript...")
|
||||
result = whisperx.align(
|
||||
result["segments"],
|
||||
models.whisperx_align_model,
|
||||
models.whisperx_align_metadata,
|
||||
audio,
|
||||
DEVICE,
|
||||
return_char_alignments=False
|
||||
)
|
||||
|
||||
# Process the segments for better output
|
||||
for segment in result["segments"]:
|
||||
# Round timestamps for better display
|
||||
segment["start"] = round(segment["start"], 2)
|
||||
segment["end"] = round(segment["end"], 2)
|
||||
# Add a confidence score if not present
|
||||
if "confidence" not in segment:
|
||||
segment["confidence"] = 1.0 # Default confidence
|
||||
|
||||
# Extract the full text from all segments
user_text = ' '.join([segment['text'] for segment in result['segments']])

# If no text was recognized, don't process further
if not user_text or len(user_text.strip()) == 0:
with app.app_context():
socketio.emit('error', {'message': 'No speech detected'}, room=session_id)
return

logger.info(f"Transcription: {user_text}")

# Load audio for CSM input
waveform, sample_rate = torchaudio.load(temp_path)

# Normalize to mono if needed
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)

# Resample to the CSM sample rate if needed
if sample_rate != models.generator.sample_rate:
waveform = torchaudio.functional.resample(
waveform,
orig_freq=sample_rate,
new_freq=models.generator.sample_rate
)

# Add the user's message to conversation history
user_segment = conversation.add_segment(
text=user_text,
speaker=speaker_id,
audio=waveform.squeeze()
)

# Send transcription to client with detailed segments
with app.app_context():
socketio.emit('transcription', {
'text': user_text,
'speaker': speaker_id,
'segments': result['segments'] # Include the detailed segments with timestamps
}, room=session_id)

# Generate AI response using Llama
with app.app_context():
socketio.emit('processing_status', {'status': 'generating'}, room=session_id)

# Create prompt from conversation history
conversation_history = ""
for segment in conversation.segments[-5:]: # Last 5 segments for context
role = "User" if segment.speaker == 0 else "Assistant"
conversation_history += f"{role}: {segment.text}\n"

# Add final prompt
prompt = f"{conversation_history}Assistant: "
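# For illustration, with two prior turns the assembled prompt looks like:
#   User: How do I reset my password?
#   Assistant: You can change it from the settings page.
#   User: Thanks, and how do I log out?
#   Assistant:
# The model is expected to continue the text after the trailing "Assistant: " marker.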
# Generate response with Llama
try:
# Ensure pad token is set
if models.tokenizer.pad_token is None:
models.tokenizer.pad_token = models.tokenizer.eos_token

input_tokens = models.tokenizer(
prompt,
return_tensors="pt",
padding=True,
return_attention_mask=True
)
input_ids = input_tokens.input_ids.to(DEVICE)
attention_mask = input_tokens.attention_mask.to(DEVICE)

with torch.no_grad():
generated_ids = models.llm.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=100,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=models.tokenizer.eos_token_id
)

# Decode the response
response_text = models.tokenizer.decode(
generated_ids[0][input_ids.shape[1]:],
skip_special_tokens=True
).strip()
except Exception as e:
logger.error(f"Error generating response: {str(e)}")
import traceback
logger.error(traceback.format_exc())
response_text = "I'm sorry, I encountered an error while processing your request."
# Synthesize speech
with app.app_context():
socketio.emit('processing_status', {'status': 'synthesizing'}, room=session_id)

# Start sending the audio response
socketio.emit('audio_response_start', {
'text': response_text,
'total_chunks': 1,
'chunk_index': 0
}, room=session_id)

# Define AI speaker ID
ai_speaker_id = conversation.ai_speaker_id

# Generate audio
audio_tensor = models.generator.generate(
text=response_text,
speaker=ai_speaker_id,
context=conversation.get_context(),
max_audio_length_ms=10_000,
temperature=0.9
)
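# Note: max_audio_length_ms=10_000 caps each spoken reply at roughly ten seconds,
# so very long text responses may be truncated in the generated audio.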
# Add AI response to conversation history
ai_segment = conversation.add_segment(
text=response_text,
speaker=ai_speaker_id,
audio=audio_tensor
)

# Convert audio to WAV format
with io.BytesIO() as wav_io:
torchaudio.save(
wav_io,
audio_tensor.unsqueeze(0).cpu(),
models.generator.sample_rate,
format="wav"
)
wav_io.seek(0)
wav_data = wav_io.read()

# Convert WAV data to base64
audio_base64 = f"data:audio/wav;base64,{base64.b64encode(wav_data).decode('utf-8')}"
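# The data URL above can be assigned directly to an <audio> element's src on the
# client; everything is sent as a single chunk here (total_chunks is always 1).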
# Send audio chunk to client
with app.app_context():
socketio.emit('audio_response_chunk', {
'chunk': audio_base64,
'chunk_index': 0,
'total_chunks': 1,
'is_last': True
}, room=session_id)

# Signal completion
socketio.emit('audio_response_complete', {
'text': response_text
}, room=session_id)
finally:
# Clean up temp file
if os.path.exists(temp_path):
os.unlink(temp_path)

except Exception as e:
logger.error(f"Error processing audio: {str(e)}")
import traceback
logger.error(traceback.format_exc())
with app.app_context():
socketio.emit('error', {'message': f'Error: {str(e)}'}, room=session_id)
finally:
# Reset processing flag
conversation.is_processing = False
# Error handler
@socketio.on_error()
def error_handler(e):
logger.error(f"SocketIO error: {str(e)}")

# Periodic cleanup of inactive sessions
def cleanup_inactive_sessions():
"""Remove sessions that have been inactive for too long"""
current_time = time.time()
inactive_timeout = 3600 # 1 hour

for session_id in list(active_conversations.keys()):
conversation = active_conversations[session_id]
if (current_time - conversation.last_activity > inactive_timeout and
not conversation.is_processing):

logger.info(f"Cleaning up inactive session: {session_id}")

# Signal processing thread to terminate
if session_id in user_queues:
user_queues[session_id].put(None)

# Remove from active conversations
del active_conversations[session_id]
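# Note: this sweep assumes conversation.last_activity is refreshed elsewhere
# (e.g. set to time.time() whenever new audio arrives for the session), so only
# sessions that have been quiet for an hour and are not mid-processing are removed.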
# Start cleanup thread
def start_cleanup_thread():
while True:
try:
cleanup_inactive_sessions()
except Exception as e:
logger.error(f"Error in cleanup thread: {str(e)}")
time.sleep(300) # Run every 5 minutes

cleanup_thread = threading.Thread(target=start_cleanup_thread, daemon=True)
cleanup_thread.start()
# Start the server
if __name__ == '__main__':
import os

port = int(os.environ.get('PORT', 5000))
debug_mode = os.environ.get('DEBUG', 'False').lower() == 'true'
logger.info(f"Starting server on port {port} (debug={debug_mode})")

print(f"Starting server on port {port} (debug={debug_mode})")
print(f"Visit http://localhost:{port} to access the application")

socketio.run(app, host='0.0.0.0', port=port, debug=debug_mode, allow_unsafe_werkzeug=True)
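A quick way to exercise the audio pipeline above is a minimal Python Socket.IO client. This is only an illustrative sketch, not part of the commit: it assumes the server is running locally on the default port, that python-socketio is installed, and that saving the reply to reply.wav is an acceptable stand-in for playback; the event names ('transcription', 'audio_response_chunk', 'audio_response_complete') come from the handler above. The server itself honours the PORT and DEBUG environment variables, e.g. PORT=8080 DEBUG=true python Backend/api/app.py.

import base64
import socketio  # pip install "python-socketio[client]"

sio = socketio.Client()

@sio.on('transcription')
def on_transcription(data):
    # Echo back what the server transcribed from the uploaded audio
    print("You said:", data['text'])

@sio.on('audio_response_chunk')
def on_audio_chunk(data):
    # 'chunk' is a data URL ("data:audio/wav;base64,..."); strip the prefix and decode
    wav_bytes = base64.b64decode(data['chunk'].split(',', 1)[1])
    with open('reply.wav', 'wb') as f:
        f.write(wav_bytes)

@sio.on('audio_response_complete')
def on_complete(data):
    print("Assistant replied:", data['text'])

sio.connect('http://localhost:5000')
sio.wait()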
@@ -1,13 +0,0 @@
from setuptools import setup, find_packages
import os

# Read requirements from requirements.txt
with open('requirements.txt') as f:
requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]

setup(
name='csm',
version='0.1.0',
packages=find_packages(),
install_requires=requirements,
)
File diff suppressed because it is too large
12
React/src/app/auth/session/route.ts
Normal file
12
React/src/app/auth/session/route.ts
Normal file
@@ -0,0 +1,12 @@
import { NextResponse } from "next/server";
import { auth0 } from "../../../lib/auth0";

export async function GET() {
try {
const session = await auth0.getSession();
return NextResponse.json({ session });
} catch (error) {
console.error("Error getting session:", error);
return NextResponse.json({ session: null }, { status: 500 });
}
}
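For reference, the new /auth/session route simply wraps auth0.getSession(): a signed-in client should get back JSON shaped roughly like { "session": { "user": { "name": ..., "email": ... } } }, and { "session": null } when there is no session or the lookup fails (the exact user fields depend on the Auth0 tenant). The reworked page below consumes it with fetch("/auth/session").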
@@ -1,67 +1,94 @@
import { useState } from "react";
import { auth0 } from "../lib/auth0";


export default async function Home() {
"use client";
import { useState, useEffect } from "react";
import { useRouter } from "next/navigation";

export default function Home() {
const [contacts, setContacts] = useState<string[]>([]);
const [codeword, setCodeword] = useState("");
const [session, setSession] = useState<any>(null);
const [loading, setLoading] = useState(true);
const router = useRouter();

const session = await auth0.getSession();

console.log("Session:", session?.user);
useEffect(() => {
// Fetch session data from an API route
fetch("/auth/session")
.then((response) => response.json())
.then((data) => {
setSession(data.session);
setLoading(false);
})
.catch((error) => {
console.error("Failed to fetch session:", error);
setLoading(false);
});
}, []);

function saveToDB() {
//e.preventDefault();
alert("Saving contacts...");
// const contactInputs = document.querySelectorAll(".text-input") as NodeListOf<HTMLInputElement>;
// const contactValues = Array.from(contactInputs).map(input => input.value);
// console.log("Contact values:", contactValues);
// // save codeword and contacts to database
// fetch("/api/databaseStorage", {
// method: "POST",
// headers: {
// "Content-Type": "application/json",
// },
// body: JSON.stringify({
// email: session?.user?.email || "",
// codeword: (document.getElementById("codeword") as HTMLInputElement)?.value,
// contacts: contactValues,
// }),
// })
// .then((response) => {
// if (response.ok) {
// alert("Contacts saved successfully!");
// } else {
// alert("Error saving contacts.");
// }
// })
// .catch((error) => {
// console.error("Error:", error);
// alert("Error saving contacts.");
// });
const contactInputs = document.querySelectorAll(
".text-input"
) as NodeListOf<HTMLInputElement>;
const contactValues = Array.from(contactInputs).map((input) => input.value);
fetch("/api/databaseStorage", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
email: session?.user?.email || "",
|
||||
codeword: codeword,
|
||||
contacts: contactValues,
|
||||
}),
|
||||
})
|
||||
.then((response) => {
|
||||
if (response.ok) {
|
||||
alert("Contacts saved successfully!");
|
||||
} else {
|
||||
alert("Error saving contacts.");
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error("Error:", error);
|
||||
alert("Error saving contacts.");
|
||||
});
|
||||
}
|
||||
|
||||
if (loading) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
|
||||
// If no session, show sign-up and login buttons
|
||||
if (!session) {
|
||||
|
||||
return (
|
||||
<div className="space-y-7 bg-indigo-800 items-center justify-items-center min-h-screen p-8 pb-20 gap-16 sm:p-20 font-[family-name:var(--font-geist-sans)]">
|
||||
<main className="space-x-2 flex flex-row gap-[32px] row-start-2 items-center sm:items-start">
|
||||
<a href="/auth/login?screen_hint=signup">
|
||||
<button className="box-content w-32 border-2 h-16 text-2xl bg-violet-900 text-green-300">Sign up</button>
|
||||
<button className="box-content w-32 border-2 h-16 text-2xl bg-violet-900 text-green-300">
|
||||
Sign up
|
||||
</button>
|
||||
</a>
|
||||
<a href="/auth/login">
|
||||
<button className = "box-content w-32 border-2 h-16 text-2xl bg-violet-900 text-green-400">Log in</button>
|
||||
<button className="box-content w-32 border-2 h-16 text-2xl bg-violet-900 text-green-400">
|
||||
Log in
|
||||
</button>
|
||||
</a>
|
||||
</main>
|
||||
<h1 className="space-y-3 text-6xl text-lime-500 subpixel-antialiased font-stretch-semi-expanded font-serif">Fauxcall</h1>
|
||||
<h2 className="space-y-3 text-6x1 text-red-700 antialiased font-mono">Set emergency contacts</h2>
|
||||
<p>If you stop speaking or say the codeword, these contacts will be notified</p>
|
||||
<h1 className="space-y-3 text-6xl text-lime-500 subpixel-antialiased font-stretch-semi-expanded font-serif">
|
||||
Fauxcall
|
||||
</h1>
|
||||
<h2 className="space-y-3 text-6x1 text-red-700 antialiased font-mono">
|
||||
Set emergency contacts
|
||||
</h2>
|
||||
<p>
|
||||
If you stop speaking or say the codeword, these contacts will be
|
||||
notified
|
||||
</p>
|
||||
{/* form for setting codeword */}
|
||||
<form className="flex flex-col gap-[32px] row-start-2 items-center sm:items-start" onSubmit={(e) => e.preventDefault()}>
|
||||
<form
|
||||
className="flex flex-col gap-[32px] row-start-2 items-center sm:items-start"
|
||||
onSubmit={(e) => e.preventDefault()}
|
||||
>
|
||||
<input
|
||||
type="text"
|
||||
value={codeword}
|
||||
@@ -71,10 +98,16 @@ export default async function Home() {
|
||||
/>
|
||||
<button
|
||||
className="bg-blue-500 text-white font-semibold font-lg rounded-md p-2"
|
||||
type="submit">Set codeword</button>
|
||||
type="submit"
|
||||
>
|
||||
Set codeword
|
||||
</button>
|
||||
</form>
|
||||
{/* form for adding contacts */}
|
||||
<form className="space-y-5 flex flex-col gap-[32px] row-start-2 items-center sm:items-start" onSubmit={(e) => e.preventDefault()}>
|
||||
<form
|
||||
className="space-y-5 flex flex-col gap-[32px] row-start-2 items-center sm:items-start"
|
||||
onSubmit={(e) => e.preventDefault()}
|
||||
>
|
||||
<input
|
||||
type="text"
|
||||
value={contacts}
|
||||
@@ -97,7 +130,12 @@ export default async function Home() {
|
||||
className="border border-gray-300 rounded-md p-2"
|
||||
/>
|
||||
<button type="button">Add</button>
|
||||
<button className="bg-slate-500 text-yellow-300 text-stretch-50% font-lg rounded-md p-2" type="submit">Set contacts</button>
|
||||
<button
|
||||
className="bg-slate-500 text-yellow-300 text-stretch-50% font-lg rounded-md p-2"
|
||||
type="submit"
|
||||
>
|
||||
Set contacts
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
);
@@ -108,11 +146,21 @@ export default async function Home() {
<main className="flex flex-col gap-[32px] row-start-2 items-center sm:items-start">
<h1>Welcome, {session.user.name}!</h1>

<h1 className="space-y-3 text-6xl text-lime-500 subpixel-antialiased font-stretch-semi-expanded font-serif">Fauxcall</h1>
<h2 className="space-y-3 text-6x1 text-red-700 antialiased font-mono">Set emergency contacts</h2>
<p>If you stop speaking or say the codeword, these contacts will be notified</p>
<h1 className="space-y-3 text-6xl text-lime-500 subpixel-antialiased font-stretch-semi-expanded font-serif">
Fauxcall
</h1>
<h2 className="space-y-3 text-6x1 text-red-700 antialiased font-mono">
Set emergency contacts
</h2>
<p>
If you stop speaking or say the codeword, these contacts will be
notified
</p>
{/* form for setting codeword */}
<form className="flex flex-col gap-[32px] row-start-2 items-center sm:items-start" onSubmit={(e) => e.preventDefault()}>
<form
className="flex flex-col gap-[32px] row-start-2 items-center sm:items-start"
onSubmit={(e) => e.preventDefault()}
>
<input
type="text"
value={codeword}
@@ -122,10 +170,17 @@ export default async function Home() {
/>
<button
className="bg-blue-500 text-white font-semibold font-lg rounded-md p-2"
type="submit">Set codeword</button>
type="submit"
>
Set codeword
</button>
</form>
{/* form for adding contacts */}
<form id="Contacts" className="space-y-5 flex flex-col gap-[32px] row-start-2 items-center sm:items-start" onSubmit={(e) => e.preventDefault()}>
<form
id="Contacts"
className="space-y-5 flex flex-col gap-[32px] row-start-2 items-center sm:items-start"
onSubmit={(e) => e.preventDefault()}
>
<input
type="text"
value={contacts}
@@ -154,24 +209,35 @@ export default async function Home() {
placeholder="Write down an emergency contact"
className="text-input border border-gray-300 rounded-md p-2"
/>
<button onClick={() => {
<button
onClick={() => {
alert("Adding contact...");
let elem = document.getElementsByClassName("text-input")[0] as HTMLElement;
let elem = document.getElementsByClassName(
"text-input"
)[0] as HTMLElement;
console.log("Element:", elem);
let d = elem.cloneNode(true) as HTMLElement;
document.getElementById("Contacts")?.appendChild(d);
}}
className="bg-emerald-500 text-fuchsia-300"
type="button">Add</button>
type="button"
>
Add
</button>

<button
type="button"
onClick={saveToDB}
className="bg-slate-500 text-yellow-300 text-stretch-50% font-lg rounded-md p-2">Save</button>
className="bg-slate-500 text-yellow-300 text-stretch-50% font-lg rounded-md p-2"
>
Save
</button>
</form>
<div>
<a href="/call">
<button className="bg-zinc-700 text-lime-300 font-semibold font-lg rounded-md p-2">Call</button>
<button className="bg-zinc-700 text-lime-300 font-semibold font-lg rounded-md p-2">
Call
</button>
</a>
</div>
<p>
@@ -182,6 +248,4 @@ export default async function Home() {
</main>
</div>
);

}
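Note on the new saveToDB flow above: the POST to /api/databaseStorage carries a JSON body of the form { email, codeword, contacts }, where contacts is the array of values read from the .text-input fields; the server-side handler for that endpoint is not shown in this diff.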