mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-04 06:33:45 +00:00
ran ruff
This commit is contained in:
parent
b2d83d0a9c
commit
515597345a
@ -18,13 +18,18 @@ except ModuleNotFoundError:
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum, InputTensorEnum
|
||||
from frigate.detectors.detector_config import (
|
||||
BaseDetectorConfig,
|
||||
ModelTypeEnum,
|
||||
InputTensorEnum,
|
||||
)
|
||||
from frigate.util.model import post_process_yolo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DETECTOR_KEY = "memryx"
|
||||
|
||||
|
||||
# Configuration class for model settings
|
||||
class ModelConfig(BaseModel):
|
||||
path: str = Field(default=None, title="Model Path") # Path to the DFP file
|
||||
@ -37,7 +42,7 @@ class MemryXDetectorConfig(BaseDetectorConfig):
|
||||
|
||||
|
||||
class MemryXDetector(DetectionApi):
|
||||
type_key = DETECTOR_KEY # Set the type key
|
||||
type_key = DETECTOR_KEY # Set the type key
|
||||
supported_models = [
|
||||
ModelTypeEnum.ssd,
|
||||
ModelTypeEnum.yolonas,
|
||||
@ -54,7 +59,9 @@ class MemryXDetector(DetectionApi):
|
||||
if "model_type" in getattr(model_cfg, "__fields_set__", set()):
|
||||
detector_config.model.model_type = model_cfg.model_type
|
||||
else:
|
||||
logger.info("model_type not set in config — defaulting to yolonas for MemryX.")
|
||||
logger.info(
|
||||
"model_type not set in config — defaulting to yolonas for MemryX."
|
||||
)
|
||||
detector_config.model.model_type = ModelTypeEnum.yolonas
|
||||
|
||||
self.capture_queue = Queue(maxsize=10)
|
||||
@ -65,13 +72,13 @@ class MemryXDetector(DetectionApi):
|
||||
self.memx_model_path = detector_config.model.path # Path to .dfp file
|
||||
self.memx_post_model = None # Path to .post file
|
||||
self.expected_post_model = None
|
||||
|
||||
|
||||
self.memx_device_path = detector_config.device # Device path
|
||||
# Parse the device string to split PCIe:<index>
|
||||
device_str = self.memx_device_path
|
||||
self.device_id = []
|
||||
self.device_id.append(int(device_str.split(":")[1]))
|
||||
|
||||
|
||||
self.memx_model_height = detector_config.model.height
|
||||
self.memx_model_width = detector_config.model.width
|
||||
self.memx_model_type = detector_config.model.model_type
|
||||
@ -80,41 +87,51 @@ class MemryXDetector(DetectionApi):
|
||||
|
||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||
model_mapping = {
|
||||
(640, 640): ("https://developer.memryx.com/example_files/2p0_frigate/yolov9_640.zip", "yolov9_640"),
|
||||
(320, 320): ("https://developer.memryx.com/example_files/2p0_frigate/yolov9_320.zip", "yolov9_320")
|
||||
(640, 640): (
|
||||
"https://developer.memryx.com/example_files/2p0_frigate/yolov9_640.zip",
|
||||
"yolov9_640",
|
||||
),
|
||||
(320, 320): (
|
||||
"https://developer.memryx.com/example_files/2p0_frigate/yolov9_320.zip",
|
||||
"yolov9_320",
|
||||
),
|
||||
}
|
||||
self.model_url, self.model_folder = model_mapping.get(
|
||||
(self.memx_model_height, self.memx_model_width),
|
||||
("https://developer.memryx.com/example_files/2p0_frigate/yolov9_320.zip", "yolov9_320")
|
||||
)
|
||||
self.expected_dfp_model = (
|
||||
"YOLO_v9_small_onnx.dfp"
|
||||
(
|
||||
"https://developer.memryx.com/example_files/2p0_frigate/yolov9_320.zip",
|
||||
"yolov9_320",
|
||||
),
|
||||
)
|
||||
self.expected_dfp_model = "YOLO_v9_small_onnx.dfp"
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
||||
model_mapping = {
|
||||
(640, 640): ("https://developer.memryx.com/example_files/2p0_frigate/yolonas_640.zip", "yolonas_640"),
|
||||
(320, 320): ("https://developer.memryx.com/example_files/2p0_frigate/yolonas_320.zip", "yolonas_320")
|
||||
(640, 640): (
|
||||
"https://developer.memryx.com/example_files/2p0_frigate/yolonas_640.zip",
|
||||
"yolonas_640",
|
||||
),
|
||||
(320, 320): (
|
||||
"https://developer.memryx.com/example_files/2p0_frigate/yolonas_320.zip",
|
||||
"yolonas_320",
|
||||
),
|
||||
}
|
||||
self.model_url, self.model_folder = model_mapping.get(
|
||||
(self.memx_model_height, self.memx_model_width),
|
||||
("https://developer.memryx.com/example_files/2p0_frigate/yolonas_320.zip", "yolonas_320")
|
||||
)
|
||||
self.expected_dfp_model = (
|
||||
"yolo_nas_s.dfp"
|
||||
)
|
||||
self.expected_post_model = (
|
||||
"yolo_nas_s_post.onnx"
|
||||
(
|
||||
"https://developer.memryx.com/example_files/2p0_frigate/yolonas_320.zip",
|
||||
"yolonas_320",
|
||||
),
|
||||
)
|
||||
self.expected_dfp_model = "yolo_nas_s.dfp"
|
||||
self.expected_post_model = "yolo_nas_s_post.onnx"
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.yolox:
|
||||
self.model_folder = "yolox"
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/2p0_frigate/yolox.zip"
|
||||
)
|
||||
self.expected_dfp_model = (
|
||||
"YOLOX_640_640_3_onnx.dfp"
|
||||
)
|
||||
self.expected_dfp_model = "YOLOX_640_640_3_onnx.dfp"
|
||||
self.set_strides_grids()
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.ssd:
|
||||
@ -122,12 +139,8 @@ class MemryXDetector(DetectionApi):
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/2p0_frigate/ssd.zip"
|
||||
)
|
||||
self.expected_dfp_model = (
|
||||
"SSDlite_MobileNet_v2_320_320_3_onnx.dfp"
|
||||
)
|
||||
self.expected_post_model = (
|
||||
"SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
|
||||
)
|
||||
self.expected_dfp_model = "SSDlite_MobileNet_v2_320_320_3_onnx.dfp"
|
||||
self.expected_post_model = "SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
|
||||
|
||||
self.check_and_prepare_model()
|
||||
logger.info(
|
||||
@ -143,9 +156,9 @@ class MemryXDetector(DetectionApi):
|
||||
self.accl = AsyncAccl(
|
||||
self.memx_model_path,
|
||||
device_ids=self.device_id, # AsyncAccl device ids
|
||||
local_mode=True
|
||||
local_mode=True,
|
||||
)
|
||||
|
||||
|
||||
# Models that use cropped post-processing sections (YOLO-NAS and SSD)
|
||||
# --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
|
||||
if self.memx_post_model:
|
||||
@ -161,19 +174,13 @@ class MemryXDetector(DetectionApi):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize MemryX model: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def load_yolo_constants(self):
|
||||
base = f"{self.cache_dir}/{self.model_folder}"
|
||||
# constants for yolov9 post-processing
|
||||
self.const_A = np.load(
|
||||
f"{base}/_model_22_Constant_9_output_0.npy"
|
||||
)
|
||||
self.const_B = np.load(
|
||||
f"{base}/_model_22_Constant_10_output_0.npy"
|
||||
)
|
||||
self.const_C = np.load(
|
||||
f"{base}/_model_22_Constant_12_output_0.npy"
|
||||
)
|
||||
self.const_A = np.load(f"{base}/_model_22_Constant_9_output_0.npy")
|
||||
self.const_B = np.load(f"{base}/_model_22_Constant_10_output_0.npy")
|
||||
self.const_C = np.load(f"{base}/_model_22_Constant_12_output_0.npy")
|
||||
|
||||
def check_and_prepare_model(self):
|
||||
"""Check if models exist; if not, download and extract them."""
|
||||
@ -182,7 +189,11 @@ class MemryXDetector(DetectionApi):
|
||||
|
||||
model_subdir = os.path.join(self.cache_dir, self.model_folder)
|
||||
dfp_path = os.path.join(model_subdir, self.expected_dfp_model)
|
||||
post_path = os.path.join(model_subdir, self.expected_post_model) if self.expected_post_model else None
|
||||
post_path = (
|
||||
os.path.join(model_subdir, self.expected_post_model)
|
||||
if self.expected_post_model
|
||||
else None
|
||||
)
|
||||
|
||||
dfp_exists = os.path.exists(dfp_path)
|
||||
post_exists = os.path.exists(post_path) if post_path else True
|
||||
@ -210,7 +221,11 @@ class MemryXDetector(DetectionApi):
|
||||
|
||||
# Re-assign model paths after extraction
|
||||
self.memx_model_path = os.path.join(model_subdir, self.expected_dfp_model)
|
||||
self.memx_post_model = os.path.join(model_subdir, self.expected_post_model) if self.expected_post_model else None
|
||||
self.memx_post_model = (
|
||||
os.path.join(model_subdir, self.expected_post_model)
|
||||
if self.expected_post_model
|
||||
else None
|
||||
)
|
||||
|
||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||
self.load_yolo_constants()
|
||||
@ -232,7 +247,9 @@ class MemryXDetector(DetectionApi):
|
||||
if self.memx_model_type == ModelTypeEnum.yolonas:
|
||||
if tensor_input.ndim == 4 and tensor_input.shape[1:] == (320, 320, 3):
|
||||
logger.debug("Transposing tensor from NHWC to NCHW for YOLO-NAS")
|
||||
tensor_input = np.transpose(tensor_input, (0, 3, 1, 2)) # (1, H, W, C) → (1, C, H, W)
|
||||
tensor_input = np.transpose(
|
||||
tensor_input, (0, 3, 1, 2)
|
||||
) # (1, H, W, C) → (1, C, H, W)
|
||||
tensor_input = tensor_input.astype(np.float32)
|
||||
tensor_input /= 255
|
||||
|
||||
@ -390,7 +407,6 @@ class MemryXDetector(DetectionApi):
|
||||
return reshaped
|
||||
|
||||
def post_process_yolox(self, output):
|
||||
|
||||
output_785 = output[0] # 785
|
||||
output_794 = output[1] # 794
|
||||
output_795 = output[2] # 795
|
||||
@ -528,7 +544,6 @@ class MemryXDetector(DetectionApi):
|
||||
def process_output(self, *outputs):
|
||||
"""Output callback function -- receives frames from the MX3 and triggers post-processing"""
|
||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||
|
||||
conv_out1 = outputs[0]
|
||||
conv_out2 = outputs[1]
|
||||
conv_out3 = outputs[2]
|
||||
|
||||
@ -97,10 +97,10 @@ class AsyncLocalObjectDetector(BaseLocalDetector):
|
||||
def async_send_input(self, tensor_input: np.ndarray, connection_id):
|
||||
tensor_input = self._transform_input(tensor_input)
|
||||
return self.detect_api.send_input(connection_id, tensor_input)
|
||||
|
||||
|
||||
def async_receive_output(self):
|
||||
return self.detect_api.receive_output()
|
||||
|
||||
|
||||
|
||||
def prepare_detector(name, out_events):
|
||||
threading.current_thread().name = f"detector:{name}"
|
||||
@ -136,10 +136,7 @@ def run_detector(
|
||||
start: Value,
|
||||
detector_config: BaseDetectorConfig,
|
||||
):
|
||||
|
||||
stop_event, frame_manager, outputs, logger = prepare_detector(
|
||||
name, out_events
|
||||
)
|
||||
stop_event, frame_manager, outputs, logger = prepare_detector(name, out_events)
|
||||
|
||||
object_detector = LocalObjectDetector(detector_config=detector_config)
|
||||
|
||||
@ -179,10 +176,7 @@ def async_run_detector(
|
||||
start: Value,
|
||||
detector_config: BaseDetectorConfig,
|
||||
):
|
||||
|
||||
stop_event, frame_manager, outputs, logger = prepare_detector(
|
||||
name, out_events
|
||||
)
|
||||
stop_event, frame_manager, outputs, logger = prepare_detector(name, out_events)
|
||||
|
||||
object_detector = AsyncLocalObjectDetector(detector_config=detector_config)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user