mirror of
https://github.com/SirBlobby/Hoya26.git
synced 2026-02-03 19:24:34 -05:00
187 lines
6.1 KiB
Python
import cv2
|
|
import numpy as np
|
|
import os
|
|
from typing import List, Dict, Optional
|
|
|
|
from ..config import (
|
|
ULTRALYTICS_AVAILABLE,
|
|
YOLO26_MODELS,
|
|
COLORS,
|
|
DEFAULT_CONF_THRESHOLD,
|
|
DEFAULT_IOU_THRESHOLD,
|
|
)
|
|
|
|
if ULTRALYTICS_AVAILABLE:
|
|
from ultralytics import YOLO
|
|
|
|
class YOLO26Detector:
    """Thin wrapper around an Ultralytics YOLO model for detection, tracking and drawing.

    Detections are returned as plain dicts with the keys ``"bbox"`` (integer
    ``(x1, y1, x2, y2)`` corner coordinates), ``"label"``, ``"confidence"`` and
    ``"class_id"``; :meth:`detect_and_track` additionally sets ``"track_id"``
    (``None`` when the tracker has not assigned an id).
    """

    def __init__(self,
                 model_size: str = "nano",
                 model_path: Optional[str] = None,
                 conf_threshold: float = DEFAULT_CONF_THRESHOLD,
                 iou_threshold: float = DEFAULT_IOU_THRESHOLD,
                 device: str = "auto"):
        """Load the YOLO26 weights.

        Args:
            model_size: Key into ``YOLO26_MODELS``; unknown keys fall back to
                ``"nano"``.
            model_path: Path to a weights file. Used only if it exists on disk;
                otherwise ``model_size`` decides which weights are loaded.
            conf_threshold: Default minimum confidence for returned detections.
            iou_threshold: IoU threshold passed to the model as ``iou=``.
            device: Device string forwarded to Ultralytics, or ``"auto"`` to
                pass ``None`` and let Ultralytics pick the device.

        Raises:
            RuntimeError: If the ``ultralytics`` package is not available.
        """
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.device = device
        self.model = None

        if not ULTRALYTICS_AVAILABLE:
            raise RuntimeError("Ultralytics not installed. Run: pip install ultralytics")

        if model_path and os.path.exists(model_path):
            model_name = model_path
        else:
            if model_path:
                # Make the fallback visible: otherwise a mistyped path silently
                # loads the default-size weights instead.
                print(f"Model path '{model_path}' not found, "
                      f"falling back to model size '{model_size}'")
            if model_size in YOLO26_MODELS:
                model_name = YOLO26_MODELS[model_size]
            else:
                print(f"Unknown model size '{model_size}', defaulting to 'nano'")
                model_name = YOLO26_MODELS["nano"]

        print(f"Loading YOLO26 model: {model_name}")
        self.model = YOLO(model_name)
        print("YOLO26 model loaded successfully!")
        print(f"Classes: {len(self.model.names)} | Device: {device}")

    def _device_arg(self) -> Optional[str]:
        """Return the ``device=`` value for Ultralytics calls (``None`` for "auto")."""
        return None if self.device == "auto" else self.device

    def _parse_results(self, results, include_track_ids: bool = False) -> List[Dict]:
        """Convert Ultralytics result objects into plain detection dicts.

        Args:
            results: Iterable of result objects as returned by ``self.model(...)``
                or ``self.model.track(...)``.
            include_track_ids: When True, every dict gets a ``"track_id"`` key
                (``None`` if the result carries no ids).

        Returns:
            One dict per box, in the order the model produced them.
        """
        names = self.model.names
        detections: List[Dict] = []

        for result in results:
            boxes = result.boxes
            if boxes is None:
                continue

            # Copy each tensor to host memory once, rather than once per box.
            xyxy_all = boxes.xyxy.cpu().numpy()
            conf_all = boxes.conf.cpu().numpy()
            cls_all = boxes.cls.cpu().numpy()
            ids_all = None
            if include_track_ids and boxes.id is not None:
                ids_all = boxes.id.cpu().numpy()

            for idx, (xyxy, conf_val, cls_val) in enumerate(zip(xyxy_all, conf_all, cls_all)):
                x1, y1, x2, y2 = map(int, xyxy)
                class_id = int(cls_val)

                det = {
                    "bbox": (x1, y1, x2, y2),
                    "label": names[class_id],
                    "confidence": float(conf_val),
                    "class_id": class_id,
                }
                if include_track_ids:
                    det["track_id"] = int(ids_all[idx]) if ids_all is not None else None
                detections.append(det)

        return detections

    def detect(self,
               frame: np.ndarray,
               conf_threshold: Optional[float] = None,
               classes: Optional[List[int]] = None) -> List[Dict]:
        """Run single-frame detection.

        Args:
            frame: Image to run the model on.
            conf_threshold: Per-call override of ``self.conf_threshold``.
            classes: Optional list of class ids to restrict detection to.

        Returns:
            Detection dicts (no ``"track_id"`` key); ``[]`` if no model is loaded.
        """
        if self.model is None:
            return []

        conf = self.conf_threshold if conf_threshold is None else conf_threshold

        results = self.model(
            frame,
            conf=conf,
            iou=self.iou_threshold,
            device=self._device_arg(),
            classes=classes,
            verbose=False,
        )
        return self._parse_results(results, include_track_ids=False)

    def detect_and_track(self,
                         frame: np.ndarray,
                         conf_threshold: Optional[float] = None,
                         tracker: str = "bytetrack.yaml") -> List[Dict]:
        """Run detection plus multi-object tracking on one frame.

        ``persist=True`` is passed so Ultralytics keeps tracker state between
        successive calls, which is what lets ids stay stable across frames.

        Args:
            frame: Image to run the model on.
            conf_threshold: Per-call override of ``self.conf_threshold``.
            tracker: Tracker config name/path forwarded to ``model.track``.

        Returns:
            Detection dicts including ``"track_id"`` (``None`` when no id was
            assigned); ``[]`` if no model is loaded.
        """
        if self.model is None:
            return []

        conf = self.conf_threshold if conf_threshold is None else conf_threshold

        results = self.model.track(
            frame,
            conf=conf,
            iou=self.iou_threshold,
            device=self._device_arg(),
            tracker=tracker,
            persist=True,
            verbose=False,
        )
        return self._parse_results(results, include_track_ids=True)

    def draw_detections(self,
                        frame: np.ndarray,
                        detections: List[Dict],
                        show_labels: bool = True,
                        show_conf: bool = True) -> np.ndarray:
        """Return a copy of ``frame`` with boxes and optional labels drawn on it.

        Box colour is chosen from ``COLORS`` by confidence: > 0.7 "high_conf",
        > 0.5 "medium_conf", otherwise "low_conf". The input frame is not modified.

        Args:
            frame: Image to draw on (copied first).
            detections: Dicts as produced by :meth:`detect` / :meth:`detect_and_track`.
            show_labels: Draw a filled label tag with class name (and track id).
            show_conf: Append the confidence to the label text.
        """
        result = frame.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1

        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            label = det["label"]
            conf = det["confidence"]
            track_id = det.get("track_id")

            if conf > 0.7:
                color = COLORS["high_conf"]
            elif conf > 0.5:
                color = COLORS["medium_conf"]
            else:
                color = COLORS["low_conf"]

            cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)

            if not show_labels:
                continue

            label_parts = [label]
            if track_id is not None:
                label_parts.append(f"ID:{track_id}")
            if show_conf:
                label_parts.append(f"{conf:.2f}")
            label_text = " | ".join(label_parts)

            (text_w, text_h), _ = cv2.getTextSize(label_text, font, font_scale, thickness)
            tag_h = text_h + 8
            # Put the tag above the box when it fits; for boxes touching the top
            # of the frame, draw it just inside the box so it is not clipped away.
            tag_top = y1 - tag_h if y1 - tag_h >= 0 else y1

            cv2.rectangle(
                result,
                (x1, tag_top),
                (x1 + text_w + 4, tag_top + tag_h),
                color,
                -1,
            )
            cv2.putText(
                result,
                label_text,
                (x1 + 2, tag_top + tag_h - 4),
                font,
                font_scale,
                (0, 0, 0),
                thickness,
            )

        return result

    def get_class_names(self) -> Dict[int, str]:
        """Return the model's class-id -> name mapping, or ``{}`` if no model is loaded."""
        return self.model.names if self.model is not None else {}
|