# Patch summary (review notes only; the unified diff follows unchanged).
#
# NOTE: the patch text below has had its physical line breaks collapsed, so
# hunk lines run together and the original leading indentation is not
# visible. It is kept byte-for-byte as found and would need its newlines
# restored before it could be applied.
#
# Target file: Backend/server.py
#
# Hunk 1 (@@ -52,24 +52,31 @@) -- decode_audio_data(audio_data: str):
#   * The previous body is re-added inside a `try:` with an
#     `except Exception as e:` handler.
#   * torchaudio.load is now called with format="wav" on the BytesIO buffer.
#   * On any exception the error is printed and the function returns
#     torch.zeros(generator.sample_rate // 2) instead of raising.
#   * The resample / squeeze(0) statements are re-added with the same text.
#
# Hunk 2 (@@ -95,43 +102,57 @@) -- websocket_endpoint, the "generate" and
# "add_to_context" branches:
#   * Each branch body is wrapped in try/except Exception; on failure the
#     error is printed and a {"type": "error", "message": ...} JSON message
#     is sent on the websocket.
#   * The "clear_context" branch appears only as trailing context.
#
# NOTE(review): the added comment "Save to a temporary WAV file first" sits
# above `temp_file = BytesIO(binary_data)`, which wraps the decoded bytes in
# an in-memory buffer; nothing in this hunk writes a file.
# NOTE(review): because decode_audio_data now catches Exception itself and
# returns the zero tensor, a decode failure in the "add_to_context" branch
# would still append a Segment and send "context_updated"; the new except
# block in that branch is not reached by decode errors. Confirm intended.
# NOTE(review): format="wav" is now fixed; presumably audio sent in any other
# container takes the silent-fallback path -- verify against what the client
# actually sends.
diff --git a/Backend/server.py b/Backend/server.py index 92d0f1f..e8ed1ae 100644 --- a/Backend/server.py +++ b/Backend/server.py @@ -52,24 +52,31 @@ manager = ConnectionManager() # Helper function to convert audio data async def decode_audio_data(audio_data: str) -> torch.Tensor: """Decode base64 audio data to a torch tensor""" - # Decode base64 audio data - binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data) - - # Load audio from binary data - buf = BytesIO(binary_data) - audio_tensor, sample_rate = torchaudio.load(buf) - - # Resample if needed - if sample_rate != generator.sample_rate: - audio_tensor = torchaudio.functional.resample( - audio_tensor.squeeze(0), - orig_freq=sample_rate, - new_freq=generator.sample_rate - ) - else: - audio_tensor = audio_tensor.squeeze(0) + try: + # Decode base64 audio data + binary_data = base64.b64decode(audio_data.split(',')[1] if ',' in audio_data else audio_data) - return audio_tensor + # Save to a temporary WAV file first + temp_file = BytesIO(binary_data) + + # Load audio from binary data, explicitly specifying the format + audio_tensor, sample_rate = torchaudio.load(temp_file, format="wav") + + # Resample if needed + if sample_rate != generator.sample_rate: + audio_tensor = torchaudio.functional.resample( + audio_tensor.squeeze(0), + orig_freq=sample_rate, + new_freq=generator.sample_rate + ) + else: + audio_tensor = audio_tensor.squeeze(0) + + return audio_tensor + except Exception as e: + print(f"Error decoding audio: {str(e)}") + # Return a small silent audio segment as fallback + return torch.zeros(generator.sample_rate // 2) # 0.5 seconds of silence async def encode_audio_data(audio_tensor: torch.Tensor) -> str: @@ -95,43 +102,57 @@ async def websocket_endpoint(websocket: WebSocket): action = request.get("action") if action == "generate": - text = request.get("text", "") - speaker_id = request.get("speaker", 0) - - # Generate audio response - print(f"Generating audio for: 
'{text}' with speaker {speaker_id}") - audio_tensor = generator.generate( - text=text, - speaker=speaker_id, - context=context_segments, - max_audio_length_ms=10_000, - ) - - # Add to conversation context - context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) - - # Convert audio to base64 and send back to client - audio_base64 = await encode_audio_data(audio_tensor) - await websocket.send_json({ - "type": "audio_response", - "audio": audio_base64 - }) + try: + text = request.get("text", "") + speaker_id = request.get("speaker", 0) + + # Generate audio response + print(f"Generating audio for: '{text}' with speaker {speaker_id}") + audio_tensor = generator.generate( + text=text, + speaker=speaker_id, + context=context_segments, + max_audio_length_ms=10_000, + ) + + # Add to conversation context + context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) + + # Convert audio to base64 and send back to client + audio_base64 = await encode_audio_data(audio_tensor) + await websocket.send_json({ + "type": "audio_response", + "audio": audio_base64 + }) + except Exception as e: + print(f"Error generating audio: {str(e)}") + await websocket.send_json({ + "type": "error", + "message": f"Error generating audio: {str(e)}" + }) elif action == "add_to_context": - text = request.get("text", "") - speaker_id = request.get("speaker", 0) - audio_data = request.get("audio", "") - - # Convert received audio to tensor - audio_tensor = await decode_audio_data(audio_data) - - # Add to conversation context - context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) - - await websocket.send_json({ - "type": "context_updated", - "message": "Audio added to context" - }) + try: + text = request.get("text", "") + speaker_id = request.get("speaker", 0) + audio_data = request.get("audio", "") + + # Convert received audio to tensor + audio_tensor = await decode_audio_data(audio_data) + + # Add to conversation context 
+ context_segments.append(Segment(text=text, speaker=speaker_id, audio=audio_tensor)) + + await websocket.send_json({ + "type": "context_updated", + "message": "Audio added to context" + }) + except Exception as e: + print(f"Error adding to context: {str(e)}") + await websocket.send_json({ + "type": "error", + "message": f"Error processing audio: {str(e)}" + }) elif action == "clear_context": context_segments = []