50 lines
1.2 KiB
Python
50 lines
1.2 KiB
Python
import os
|
|
import torch
|
|
import torchaudio
|
|
from huggingface_hub import hf_hub_download
|
|
from generator import load_csm_1b, Segment
|
|
from dataclasses import dataclass
|
|
|
|
# Choose the compute backend. CPU is the fallback; it is overridden by Apple
# MPS when available, otherwise by CUDA when available (MPS is checked first).
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
|
|
|
|
# Build the CSM-1B generator on the selected device. Its `sample_rate`
# attribute is read later when resampling the context audio and when saving
# the generated waveform.
generator = load_csm_1b(device=device)
|
|
|
|
# Context conversation: one (speaker id, transcript) pair per utterance.
# Keeping each pair on one row stops the derived lists from drifting apart.
_context_turns = [
    (0, "Hey how are you doing."),
    (1, "Pretty good, pretty good."),
    (0, "I'm great."),
    (0, "So happy to be speaking to you."),
]
speakers = [speaker_id for speaker_id, _ in _context_turns]
transcripts = [line for _, line in _context_turns]
# One WAV file per utterance, named by its position in the conversation.
audio_paths = [f"utterance_{index}.wav" for index in range(len(_context_turns))]
|
|
|
|
def load_audio(audio_path):
    """Load an audio file as a mono waveform at the generator's sample rate.

    Args:
        audio_path: Location of the audio file, passed straight to
            ``torchaudio.load``.

    Returns:
        The waveform with the channel axis removed, resampled from the
        file's sample rate to ``generator.sample_rate``.
    """
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    # torchaudio.load returns a (channels, frames) tensor by default.
    # squeeze(0) only removes the channel axis for mono files, so a
    # multi-channel file would stay 2-D; average its channels instead.
    if audio_tensor.size(0) > 1:
        mono_audio = audio_tensor.mean(dim=0)
    else:
        mono_audio = audio_tensor.squeeze(0)
    return torchaudio.functional.resample(
        mono_audio, orig_freq=sample_rate, new_freq=generator.sample_rate
    )
|
|
|
|
# Build the context for generation: one Segment per utterance, combining its
# transcript, its speaker id and the audio loaded from the matching file.
# zip stops at the shortest of the three lists.
segments = [
    Segment(text=line, speaker=speaker_id, audio=load_audio(wav_path))
    for wav_path, speaker_id, line in zip(audio_paths, speakers, transcripts)
]
|
|
|
|
# Generate the next utterance for speaker 1, conditioned on the context
# segments built above. max_audio_length_ms caps the output length
# (10_000 ms, going by the parameter name).
audio = generator.generate(
    context=segments,
    speaker=1,
    text="Me too, this is some cool stuff huh?",
    max_audio_length_ms=10_000,
)

# Add a leading axis and move the tensor to the CPU before writing the file
# at the generator's sample rate.
torchaudio.save(
    "audio.wav",
    audio.unsqueeze(0).cpu(),
    generator.sample_rate,
)