28 lines
921 B
Python
28 lines
921 B
Python
from dataclasses import dataclass
|
|
import torch
|
|
|
|
@dataclass
class AudioModel:
    """Inference wrapper that pairs a torch module with an audio sample rate.

    The wrapped module is put into evaluation mode when the instance is
    created and again after weights are reloaded.

    Attributes:
        model: Module called by ``process_audio`` and whose ``state_dict`` is
            saved/restored by ``save_model``/``load_model``.
        sample_rate: Rate passed as ``orig_freq`` when ``resample_audio``
            resamples (presumably in Hz -- confirm against callers).
    """

    model: torch.nn.Module
    sample_rate: int

    def __post_init__(self) -> None:
        """Switch the wrapped module to evaluation mode."""
        self.model.eval()

    def process_audio(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """Run ``audio_tensor`` through the model with gradients disabled.

        Args:
            audio_tensor: Input forwarded unchanged to ``self.model``.

        Returns:
            Whatever the wrapped module returns for ``audio_tensor``.
        """
        with torch.no_grad():
            return self.model(audio_tensor)

    def resample_audio(self, audio_tensor: torch.Tensor, target_sample_rate: int) -> torch.Tensor:
        """Resample ``audio_tensor`` from ``self.sample_rate`` to the target rate.

        Args:
            audio_tensor: Audio to resample.
            target_sample_rate: Desired rate, passed as ``new_freq``.

        Returns:
            The same tensor object when the rates already match, otherwise
            the output of ``torchaudio.functional.resample``.

        Raises:
            ImportError: If resampling is required and ``torchaudio`` is not
                installed.
        """
        if self.sample_rate == target_sample_rate:
            return audio_tensor

        # Imported here because the module never imports torchaudio at file
        # level; without this the call below fails with NameError. Keeping
        # the import on the resampling path also means the rest of the class
        # works without torchaudio installed.
        import torchaudio

        return torchaudio.functional.resample(
            audio_tensor,
            orig_freq=self.sample_rate,
            new_freq=target_sample_rate,
        )

    def save_model(self, path: str) -> None:
        """Write the model's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.model.state_dict(), path)

    def load_model(self, path: str) -> None:
        """Load a ``state_dict`` from ``path`` into the model and re-enter eval mode.

        Args:
            path: File previously written by ``save_model`` (or any
                compatible ``state_dict`` checkpoint).
        """
        # NOTE: torch.load is pickle-based; only load checkpoints from
        # trusted sources.
        # map_location="cpu" lets checkpoints saved from a CUDA model load on
        # CPU-only hosts; load_state_dict then copies the values into the
        # module's existing parameters on whatever device they live on.
        state_dict = torch.load(path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()