Files
HooHacks-12/Backend/src/models/audio_model.py
2025-03-30 01:28:07 -04:00

28 lines
921 B
Python

from dataclasses import dataclass
import torch
@dataclass
class AudioModel:
    """Wrap a ``torch.nn.Module`` for gradient-free inference on audio tensors.

    Attributes:
        model: The wrapped network. It is put into evaluation mode on
            construction and again after every :meth:`load_model` call.
        sample_rate: The sample rate that :meth:`resample_audio` treats as the
            source rate of the tensors it receives (presumably in Hz — confirm
            against callers).
    """

    model: torch.nn.Module
    sample_rate: int

    def __post_init__(self) -> None:
        # Inference-only wrapper: eval() switches layers such as Dropout and
        # BatchNorm to their inference behaviour.
        self.model.eval()

    def process_audio(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on *audio_tensor* without tracking gradients.

        The tensor is passed to the model as-is; its required shape, dtype and
        device are dictated by the wrapped network and are not checked here.

        Returns:
            Whatever the wrapped model's forward pass returns.
        """
        with torch.no_grad():
            return self.model(audio_tensor)

    def resample_audio(
        self, audio_tensor: torch.Tensor, target_sample_rate: int
    ) -> torch.Tensor:
        """Resample *audio_tensor* from ``self.sample_rate`` to *target_sample_rate*.

        Args:
            audio_tensor: Audio assumed to be sampled at ``self.sample_rate``.
            target_sample_rate: Desired output sample rate.

        Returns:
            The same tensor object when the rates already match, otherwise the
            output of ``torchaudio.functional.resample``.
        """
        if self.sample_rate == target_sample_rate:
            return audio_tensor
        # Imported here so only the resampling path depends on torchaudio;
        # the equal-rate fast path above works without it installed.
        import torchaudio

        return torchaudio.functional.resample(
            audio_tensor,
            orig_freq=self.sample_rate,
            new_freq=target_sample_rate,
        )

    def save_model(self, path: str) -> None:
        """Save the wrapped model's ``state_dict`` (parameters/buffers only) to *path*."""
        torch.save(self.model.state_dict(), path)

    def load_model(self, path: str, *, map_location="cpu") -> None:
        """Load a ``state_dict`` written by :meth:`save_model` and switch to eval mode.

        Args:
            path: File previously produced by :meth:`save_model`.
            map_location: Forwarded to ``torch.load``. It defaults to ``"cpu"``
                so checkpoints saved from GPU tensors still load on CPU-only
                hosts. ``load_state_dict`` then copies the values into the
                model's existing parameters, wherever those live.
        """
        # weights_only=True: the file only holds tensors, so refuse to unpickle
        # arbitrary objects (protects against malicious checkpoint files).
        state_dict = torch.load(path, map_location=map_location, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.eval()