Backend Server Code

2025-03-29 20:36:18 -04:00
parent 15a04c9d54
commit 5da627097d
11 changed files with 1307 additions and 0 deletions

Backend/server.py Normal file

@@ -0,0 +1,156 @@
import os
import base64
import json
import asyncio
import torch
import torchaudio
import numpy as np
from io import BytesIO
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from generator import load_csm_1b, Segment
import uvicorn
# Select device
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# Initialize the model
generator = load_csm_1b(device=device)
app = FastAPI()
# Add CORS middleware to allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins in development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Connection manager to handle multiple clients
class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)
manager = ConnectionManager()
# Helper functions to convert audio data between base64 strings and tensors
async def decode_audio_data(audio_data: str) -> torch.Tensor:
    """Decode base64 audio data to a torch tensor"""
    # Strip the data-URL prefix if present, then decode the base64 payload
    binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data)

    # Load audio from the in-memory binary data
    buf = BytesIO(binary_data)
    audio_tensor, sample_rate = torchaudio.load(buf)

    # Resample to the generator's sample rate if needed
    if sample_rate != generator.sample_rate:
        audio_tensor = torchaudio.functional.resample(
            audio_tensor.squeeze(0),
            orig_freq=sample_rate,
            new_freq=generator.sample_rate,
        )
    else:
        audio_tensor = audio_tensor.squeeze(0)
    return audio_tensor
async def encode_audio_data(audio_tensor: torch.Tensor) -> str:
    """Encode torch tensor audio to a base64 data URL string"""
    buf = BytesIO()
    torchaudio.save(buf, audio_tensor.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
    buf.seek(0)
    audio_base64 = base64.b64encode(buf.read()).decode('utf-8')
    return f"data:audio/wav;base64,{audio_base64}"
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    context_segments = []  # Store conversation context for this connection
    try:
        while True:
            # Receive a JSON request from the client
            data = await websocket.receive_text()
            request = json.loads(data)
            action = request.get("action")

            if action == "generate":
                text = request.get("text", "")
                speaker_id = request.get("speaker", 0)

                # Generate an audio response conditioned on the conversation so far
                print(f"Generating audio for: '{text}' with speaker {speaker_id}")
                audio_tensor = generator.generate(
                    text=text,
                    speaker=speaker_id,
                    context=context_segments,
                    max_audio_length_ms=10_000,
                )

                # Add the generated turn to the conversation context
                context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))

                # Convert audio to base64 and send it back to the client
                audio_base64 = await encode_audio_data(audio_tensor)
                await websocket.send_json({
                    "type": "audio_response",
                    "audio": audio_base64
                })

            elif action == "add_to_context":
                text = request.get("text", "")
                speaker_id = request.get("speaker", 0)
                audio_data = request.get("audio", "")

                # Convert the received audio to a tensor and add it to the context
                audio_tensor = await decode_audio_data(audio_data)
                context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor))
                await websocket.send_json({
                    "type": "context_updated",
                    "message": "Audio added to context"
                })

            elif action == "clear_context":
                context_segments = []
                await websocket.send_json({
                    "type": "context_updated",
                    "message": "Context cleared"
                })
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        print("Client disconnected")
    except Exception as e:
        print(f"Error: {str(e)}")
        await websocket.send_json({
            "type": "error",
            "message": str(e)
        })
        manager.disconnect(websocket)
if __name__ == "__main__":
    uvicorn.run(app, host="localhost", port=8000)
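
For reference, below is a minimal client sketch that exercises the WebSocket protocol defined above. It is not part of this commit: the `websockets` package, the example text, and the output filename are illustrative assumptions; only the `/ws` route, the action names, and the response shapes come from server.py.

# Hypothetical client sketch (not in this commit); assumes `pip install websockets`.
import asyncio
import base64
import json

import websockets


async def main():
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # Ask the server to synthesize a reply for speaker 0
        await ws.send(json.dumps({
            "action": "generate",
            "text": "Hello there!",
            "speaker": 0,
        }))
        response = json.loads(await ws.recv())
        if response["type"] == "audio_response":
            # The audio arrives as a data URL; strip the prefix and save the WAV bytes
            wav_bytes = base64.b64decode(response["audio"].split(",", 1)[1])
            with open("reply.wav", "wb") as f:
                f.write(wav_bytes)

        # Reset the server-side conversation context
        await ws.send(json.dumps({"action": "clear_context"}))
        print(json.loads(await ws.recv()))


asyncio.run(main())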