"""Generate one spoken utterance conditioned on a short recorded conversation.

The script loads the generator returned by ``load_csm_1b``, builds a list of
context ``Segment`` objects from hard-coded transcripts, speaker ids and WAV
files, asks the generator for a new utterance, and writes it to ``audio.wav``.
"""

import os
from dataclasses import dataclass

import torch
import torchaudio
from huggingface_hub import hf_hub_download

from generator import load_csm_1b, Segment

# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

generator = load_csm_1b(device=device)

# The three lists below are parallel: entry i of each describes utterance i.
speakers = [0, 1, 0, 0]
transcripts = [
    "Hey how are you doing.",
    "Pretty good, pretty good.",
    "I'm great.",
    "So happy to be speaking to you.",
]
audio_paths = [
    "utterance_0.wav",
    "utterance_1.wav",
    "utterance_2.wav",
    "utterance_3.wav",
]


def load_audio(audio_path):
    """Load an audio file as a mono waveform at the generator's sample rate.

    Args:
        audio_path: Path of the audio file to read with ``torchaudio.load``.

    Returns:
        A 1-D tensor holding the channel-averaged waveform, resampled to
        ``generator.sample_rate``.
    """
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    # torchaudio.load yields (channels, frames). Averaging over the channel
    # axis always produces a 1-D waveform: it equals the old squeeze(0) result
    # for mono files, and downmixes stereo/multi-channel files, which
    # squeeze(0) would have left 2-D.
    mono_waveform = audio_tensor.mean(dim=0)
    return torchaudio.functional.resample(
        mono_waveform,
        orig_freq=sample_rate,
        new_freq=generator.sample_rate,
    )


# One context segment per prior utterance, in conversation order.
segments = [
    Segment(text=transcript, speaker=speaker, audio=load_audio(audio_path))
    for transcript, speaker, audio_path in zip(transcripts, speakers, audio_paths)
]

audio = generator.generate(
    text="Me too, this is some cool stuff huh?",
    speaker=1,
    context=segments,
    max_audio_length_ms=10_000,
)

# torchaudio.save expects a 2-D tensor, so add a leading axis and move the
# result to the CPU before writing.
torchaudio.save("audio.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)