Demo Update 11

This commit is contained in:
2025-03-30 01:30:14 -04:00
parent f7a8ebf770
commit a0ee0685dc

View File

@@ -1,23 +1,51 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional from typing import List, Optional
import torch
@dataclass
class Segment:
speaker: int
text: str
# (num_samples,), sample_rate = 24_000
audio: Optional[torch.Tensor] = None
def __post_init__(self):
# Ensure audio is a tensor if provided
if self.audio is not None and not isinstance(self.audio, torch.Tensor):
self.audio = torch.tensor(self.audio, dtype=torch.float32)
@dataclass @dataclass
class Conversation: class Conversation:
context: List[str] = field(default_factory=list) context: List[str] = field(default_factory=list)
segments: List[Segment] = field(default_factory=list)
current_speaker: Optional[int] = None current_speaker: Optional[int] = None
def add_message(self, message: str, speaker: int): def add_message(self, message: str, speaker: int):
self.context.append(f"Speaker {speaker}: {message}") self.context.append(f"Speaker {speaker}: {message}")
self.current_speaker = speaker self.current_speaker = speaker
def add_segment(self, segment: Segment):
self.segments.append(segment)
self.context.append(f"Speaker {segment.speaker}: {segment.text}")
self.current_speaker = segment.speaker
def get_context(self) -> List[str]: def get_context(self) -> List[str]:
return self.context return self.context
def get_segments(self) -> List[Segment]:
return self.segments
def clear_context(self): def clear_context(self):
self.context.clear() self.context.clear()
self.segments.clear()
self.current_speaker = None self.current_speaker = None
def get_last_message(self) -> Optional[str]: def get_last_message(self) -> Optional[str]:
if self.context: if self.context:
return self.context[-1] return self.context[-1]
return None
def get_last_segment(self) -> Optional[Segment]:
if self.segments:
return self.segments[-1]
return None return None