Object Detection Confusion Matrix - Evaluation API #140

@SkalskiP

Description

Let's start our work on the Evaluation API with sv.ConfusionMatrix. Here is the expected API:

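from __future__ import annotations  # lets methods return ConfusionMatrix by name

from dataclasses import dataclass
from typing import Callable, List

import numpy as np
import supervision as sv


@dataclass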
class ConfusionMatrix:
    matrix: np.ndarray
    classes: List[str]
    conf_threshold: float
    iou_threshold: float

    @classmethod
    def from_detections(
        cls,
        predictions: List[sv.Detections],
        target: List[sv.Detections],
        classes: List[str],
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.5
    ) -> ConfusionMatrix:
        pass

    @classmethod
    def benchmark(
        cls,
        dataset: sv.DetectionDataset,
        callback: Callable[[np.ndarray], sv.Detections],
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.5
    ) -> ConfusionMatrix:
        pass

    def plot(self, target_path: str) -> None:
        pass
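
As a rough illustration of what from_detections might compute per image, here is a minimal numpy-only sketch. It is not the onemetric or supervision implementation. The function names (box_iou, image_confusion_matrix), the greedy highest-IoU-first matching, the row = ground truth / column = prediction layout, and the extra "background" row and column for unmatched boxes are all assumptions made for this example. Boxes are assumed to be in xyxy format with non-zero area.

```python
import numpy as np


def box_iou(boxes_a: np.ndarray, boxes_b: np.ndarray) -> np.ndarray:
    """Pairwise IoU between two sets of xyxy boxes, shape (len(a), len(b))."""
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    top_left = np.maximum(boxes_a[:, None, :2], boxes_b[None, :, :2])
    bottom_right = np.minimum(boxes_a[:, None, 2:], boxes_b[None, :, 2:])
    inter = np.clip(bottom_right - top_left, 0, None).prod(axis=2)
    return inter / (area_a[:, None] + area_b[None, :] - inter)


def image_confusion_matrix(
    pred_xyxy: np.ndarray,
    pred_class: np.ndarray,
    pred_conf: np.ndarray,
    true_xyxy: np.ndarray,
    true_class: np.ndarray,
    num_classes: int,
    conf_threshold: float = 0.3,
    iou_threshold: float = 0.5,
) -> np.ndarray:
    # Assumed layout: rows are ground-truth classes, columns are predicted
    # classes, and index num_classes stands for "background" (no match).
    matrix = np.zeros((num_classes + 1, num_classes + 1), dtype=np.int64)

    # Drop predictions below the confidence threshold.
    keep = pred_conf >= conf_threshold
    pred_xyxy, pred_class = pred_xyxy[keep], pred_class[keep]

    matched_true, matched_pred = set(), set()
    if len(pred_xyxy) and len(true_xyxy):
        iou = box_iou(true_xyxy, pred_xyxy)
        # Candidate (truth, prediction) pairs, visited from highest IoU down.
        pairs = np.argwhere(iou >= iou_threshold)
        order = np.argsort(-iou[pairs[:, 0], pairs[:, 1]])
        for t, p in pairs[order]:
            t, p = int(t), int(p)
            if t in matched_true or p in matched_pred:
                continue  # each box takes part in at most one match
            matched_true.add(t)
            matched_pred.add(p)
            matrix[true_class[t], pred_class[p]] += 1

    # Ground-truth boxes nobody matched: counted as missed (false negatives).
    for t in range(len(true_xyxy)):
        if t not in matched_true:
            matrix[true_class[t], num_classes] += 1
    # Predictions that matched nothing: counted as false positives.
    for p in range(len(pred_xyxy)):
        if p not in matched_pred:
            matrix[num_classes, pred_class[p]] += 1
    return matrix
```

Under these assumptions, the dataset-level matrix would simply be the sum of the per-image matrices, which keeps from_detections a loop over paired predictions and targets.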

Usage example

from_detections

>>> import supervision as sv
>>> from ultralytics import YOLO

>>> model = YOLO('yolov8s.pt')

>>> dataset = sv.DetectionDataset.from_yolo(
...     images_directory_path='...',
...     annotations_directory_path='...',
...     data_yaml_path='...'
... )

>>> predictions, target = [], []

>>> for _, image, labels in dataset:
...     result = model(image)[0]
...     detections = sv.Detections.from_yolov8(result)
...     predictions.append(detections)
...     target.append(labels)

>>> matrix = sv.ConfusionMatrix.from_detections(
...     predictions=predictions,
...     target=target,
...     classes=dataset.classes
... )
>>> matrix.plot('...')

benchmark

>>> import numpy as np
>>> import supervision as sv
>>> from ultralytics import YOLO

>>> model = YOLO('yolov8s.pt')

>>> dataset = sv.DetectionDataset.from_yolo(
...     images_directory_path='...',
...     annotations_directory_path='...',
...     data_yaml_path='...'
... )

>>> def wrapper(image: np.ndarray) -> sv.Detections:
...     result = model(image)[0]
...     return sv.Detections.from_yolov8(result)

>>> matrix = sv.ConfusionMatrix.benchmark(
...     dataset=dataset,
...     callback=wrapper
... )
>>> matrix.plot('...')
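
Since the two usage examples differ only in who runs the model, benchmark could plausibly be a thin wrapper around from_detections. The sketch below shows that idea with a hypothetical helper, collect_predictions, which is not part of the proposed API. It assumes the dataset yields (name, image, labels) triples, as the from_detections example above iterates it, and it uses generic type variables instead of real supervision types so it runs on its own.

```python
from typing import Callable, Iterable, List, Tuple, TypeVar

Image = TypeVar("Image")
Detections = TypeVar("Detections")


def collect_predictions(
    dataset: Iterable[Tuple[str, Image, Detections]],
    callback: Callable[[Image], Detections],
) -> Tuple[List[Detections], List[Detections]]:
    """Run callback on every image and pair its output with the ground truth."""
    predictions: List[Detections] = []
    target: List[Detections] = []
    for _, image, labels in dataset:
        predictions.append(callback(image))  # model output for this image
        target.append(labels)                # ground-truth annotations
    return predictions, target
```

With such a helper, benchmark would reduce to collecting the two lists and passing them, together with the dataset's classes and both thresholds, to from_detections.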

Additional

  • The code should live in supervision.metrics.detection
  • We can use the code in the onemetric pip package as a starting point
  • We shouldn't need any external dependencies beyond numpy and matplotlib

Metadata

Labels

  • API:evaluation (Model evaluation API)
  • enhancement (New feature or request)
  • version: 0.12.0 (Feature to be added in `0.12.0` release)

Projects

Status

Done
