Complete Refactor
This commit is contained in:
28
Backend/src/models/audio_model.py
Normal file
28
Backend/src/models/audio_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
@dataclass
|
||||
class AudioModel:
|
||||
model: torch.nn.Module
|
||||
sample_rate: int
|
||||
|
||||
def __post_init__(self):
|
||||
self.model.eval()
|
||||
|
||||
def process_audio(self, audio_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
processed_audio = self.model(audio_tensor)
|
||||
return processed_audio
|
||||
|
||||
def resample_audio(self, audio_tensor: torch.Tensor, target_sample_rate: int) -> torch.Tensor:
|
||||
if self.sample_rate != target_sample_rate:
|
||||
resampled_audio = torchaudio.functional.resample(audio_tensor, orig_freq=self.sample_rate, new_freq=target_sample_rate)
|
||||
return resampled_audio
|
||||
return audio_tensor
|
||||
|
||||
def save_model(self, path: str):
|
||||
torch.save(self.model.state_dict(), path)
|
||||
|
||||
def load_model(self, path: str):
|
||||
self.model.load_state_dict(torch.load(path))
|
||||
self.model.eval()
|
||||
Reference in New Issue
Block a user