Complete Refactor

This commit is contained in:
2025-03-30 01:28:07 -04:00
parent 158afc78c7
commit 46be33b10a
28 changed files with 723 additions and 3081 deletions

View File

@@ -0,0 +1,28 @@
from dataclasses import dataclass
import torch
@dataclass
class AudioModel:
model: torch.nn.Module
sample_rate: int
def __post_init__(self):
self.model.eval()
def process_audio(self, audio_tensor: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
processed_audio = self.model(audio_tensor)
return processed_audio
def resample_audio(self, audio_tensor: torch.Tensor, target_sample_rate: int) -> torch.Tensor:
if self.sample_rate != target_sample_rate:
resampled_audio = torchaudio.functional.resample(audio_tensor, orig_freq=self.sample_rate, new_freq=target_sample_rate)
return resampled_audio
return audio_tensor
def save_model(self, path: str):
torch.save(self.model.state_dict(), path)
def load_model(self, path: str):
self.model.load_state_dict(torch.load(path))
self.model.eval()