Demo Update 11
This commit is contained in:
@@ -1,23 +1,51 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
import torch
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Segment:
|
||||||
|
speaker: int
|
||||||
|
text: str
|
||||||
|
# (num_samples,), sample_rate = 24_000
|
||||||
|
audio: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Ensure audio is a tensor if provided
|
||||||
|
if self.audio is not None and not isinstance(self.audio, torch.Tensor):
|
||||||
|
self.audio = torch.tensor(self.audio, dtype=torch.float32)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Conversation:
|
class Conversation:
|
||||||
context: List[str] = field(default_factory=list)
|
context: List[str] = field(default_factory=list)
|
||||||
|
segments: List[Segment] = field(default_factory=list)
|
||||||
current_speaker: Optional[int] = None
|
current_speaker: Optional[int] = None
|
||||||
|
|
||||||
def add_message(self, message: str, speaker: int):
|
def add_message(self, message: str, speaker: int):
|
||||||
self.context.append(f"Speaker {speaker}: {message}")
|
self.context.append(f"Speaker {speaker}: {message}")
|
||||||
self.current_speaker = speaker
|
self.current_speaker = speaker
|
||||||
|
|
||||||
|
def add_segment(self, segment: Segment):
|
||||||
|
self.segments.append(segment)
|
||||||
|
self.context.append(f"Speaker {segment.speaker}: {segment.text}")
|
||||||
|
self.current_speaker = segment.speaker
|
||||||
|
|
||||||
def get_context(self) -> List[str]:
|
def get_context(self) -> List[str]:
|
||||||
return self.context
|
return self.context
|
||||||
|
|
||||||
|
def get_segments(self) -> List[Segment]:
|
||||||
|
return self.segments
|
||||||
|
|
||||||
def clear_context(self):
|
def clear_context(self):
|
||||||
self.context.clear()
|
self.context.clear()
|
||||||
|
self.segments.clear()
|
||||||
self.current_speaker = None
|
self.current_speaker = None
|
||||||
|
|
||||||
def get_last_message(self) -> Optional[str]:
|
def get_last_message(self) -> Optional[str]:
|
||||||
if self.context:
|
if self.context:
|
||||||
return self.context[-1]
|
return self.context[-1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_last_segment(self) -> Optional[Segment]:
|
||||||
|
if self.segments:
|
||||||
|
return self.segments[-1]
|
||||||
return None
|
return None
|
||||||
Reference in New Issue
Block a user