Backend Server Code
Backend/test.py (new file, 50 lines added)
@@ -0,0 +1,50 @@
import torch
import torchaudio

# load_csm_1b and Segment are provided by generator.py in the CSM repo
from generator import load_csm_1b, Segment

# Select the best available device: Apple Silicon (MPS), then CUDA, then CPU
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

generator = load_csm_1b(device=device)

# Conversation context: speaker ids, transcripts, and audio files for four prior utterances
speakers = [0, 1, 0, 0]
transcripts = [
    "Hey how are you doing.",
    "Pretty good, pretty good.",
    "I'm great.",
    "So happy to be speaking to you.",
]
audio_paths = [
    "utterance_0.wav",
    "utterance_1.wav",
    "utterance_2.wav",
    "utterance_3.wav",
]

def load_audio(audio_path):
    # Load a wav file and resample it to the generator's sample rate
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = torchaudio.functional.resample(
        audio_tensor.squeeze(0), orig_freq=sample_rate, new_freq=generator.sample_rate
    )
    return audio_tensor

# Wrap each prior utterance in a Segment for use as generation context
segments = [
    Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
    for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]

# Generate up to 10 seconds of speech for speaker 1, conditioned on the context above
audio = generator.generate(
    text="Me too, this is some cool stuff huh?",
    speaker=1,
    context=segments,
    max_audio_length_ms=10_000,
)

torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)