diff --git a/frigate/detectors/detection_runners.py b/frigate/detectors/detection_runners.py index f20165b96..273351ad0 100644 --- a/frigate/detectors/detection_runners.py +++ b/frigate/detectors/detection_runners.py @@ -48,7 +48,7 @@ def is_openvino_gpu_npu_available() -> bool: """ available_devices = get_openvino_available_devices() # Check for GPU, NPU, or other acceleration devices (excluding CPU) - acceleration_devices = ['GPU', 'MYRIAD', 'NPU', 'GNA', 'HDDL'] + acceleration_devices = ["GPU", "MYRIAD", "NPU", "GNA", "HDDL"] return any(device in available_devices for device in acceleration_devices) @@ -354,20 +354,10 @@ def get_optimized_runner( if rknn_path: return RKNNModelRunner(rknn_path) - providers, options = get_ort_providers(device == "CPU", device, **kwargs) - - if device == "CPU": - return ONNXModelRunner( - ort.InferenceSession( - model_path, - providers=providers, - provider_options=options, - ) - ) - if is_openvino_gpu_npu_available(): return OpenVINOModelRunner(model_path, device, **kwargs) + providers, options = get_ort_providers(device == "CPU", device, **kwargs) ortSession = ort.InferenceSession( model_path, providers=providers, diff --git a/frigate/util/model.py b/frigate/util/model.py index c64287660..b988cae1a 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -338,14 +338,16 @@ def get_ort_providers( else: continue elif provider == "OpenVINOExecutionProvider": - os.makedirs(os.path.join(MODEL_CACHE_DIR, "openvino/ort"), exist_ok=True) - providers.append(provider) - options.append( - { - "cache_dir": os.path.join(MODEL_CACHE_DIR, "openvino/ort"), - "device_type": device, - } - ) + # OpenVINO is used directly + if device == "OpenVINO": + os.makedirs(os.path.join(MODEL_CACHE_DIR, "openvino/ort"), exist_ok=True) + providers.append(provider) + options.append( + { + "cache_dir": os.path.join(MODEL_CACHE_DIR, "openvino/ort"), + "device_type": device, + } + ) elif provider == "MIGraphXExecutionProvider": # MIGraphX uses more CPU than ROCM, while also being the same speed if device == "MIGraphX":