Add benchmarking for RF100 datasets (#10190)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Muhammad Rizwan Munawar 2024-04-26 20:04:25 +05:00 committed by GitHub
parent 2f3e17d23e
commit 5323ee0d58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 180 additions and 0 deletions

View file

@ -25,18 +25,23 @@ NCNN | `ncnn` | yolov8n_ncnn_model/
"""
import glob
import os
import platform
import re
import shutil
import time
from pathlib import Path
import numpy as np
import torch.cuda
import yaml
from ultralytics import YOLO, YOLOWorld
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ARM64, ASSETS, IS_JETSON, IS_RASPBERRYPI, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo
from ultralytics.utils.downloads import safe_download
from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import select_device
@ -152,6 +157,133 @@ def benchmark(
return df
class RF100Benchmark:
    """Benchmark YOLO model performance across the Roboflow 100 (RF100) dataset collection."""

    def __init__(self):
        """Initialize dataset bookkeeping and the metric column names parsed from validation logs."""
        self.ds_names = []  # dataset (project) names, in download order
        self.ds_cfg_list = []  # Path to each dataset's data.yaml
        self.rf = None  # Roboflow client, created in set_key()
        # Column order of a metrics row in the val log: class, images, targets, P, R, mAP50, mAP50-95
        self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"]

    def set_key(self, api_key):
        """
        Set Roboflow API key for processing.

        Args:
            api_key (str): The API key.
        """
        check_requirements("roboflow")
        from roboflow import Roboflow

        self.rf = Roboflow(api_key=api_key)

    def parse_dataset(self, ds_link_txt="datasets_links.txt"):
        """
        Parse dataset links and download datasets into a fresh 'rf-100' working directory.

        Args:
            ds_link_txt (str): Path to dataset_links file.

        Returns:
            (tuple): (ds_names, ds_cfg_list) dataset names and their data.yaml paths.
        """
        # Always start from a clean 'rf-100' directory.
        if os.path.exists("rf-100"):
            shutil.rmtree("rf-100")
        os.mkdir("rf-100")
        os.chdir("rf-100")
        os.mkdir("ultralytics-benchmarks")
        safe_download("https://ultralytics.com/assets/datasets_links.txt")

        with open(ds_link_txt, "r") as file:
            for line in file:
                try:
                    # Expected URL shape: https://app.roboflow.com/<workspace>/<project>/<version>
                    _, url, workspace, project, version = re.split("/+", line.strip())
                    self.ds_names.append(project)
                    proj_version = f"{project}-{version}"
                    if not Path(proj_version).exists():
                        self.rf.workspace(workspace).project(project).version(version).download("yolov8")
                    else:
                        print("Dataset already downloaded.")
                    self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml")
                except Exception:
                    # Best-effort: skip malformed lines and failed downloads, keep processing the rest.
                    continue
        return self.ds_names, self.ds_cfg_list

    def fix_yaml(self, path):
        """
        Rewrite a dataset YAML so train/val point at the standard relative image directories.

        Args:
            path (str): YAML file path.
        """
        with open(path, "r") as file:
            yaml_data = yaml.safe_load(file)
        yaml_data["train"] = "train/images"
        yaml_data["val"] = "valid/images"
        with open(path, "w") as file:
            yaml.safe_dump(yaml_data, file)

    def evaluate(self, yaml_path, val_log_file, eval_log_file, list_ind):
        """
        Parse a validation log and append the dataset's mAP50 result to the evaluation log.

        Args:
            yaml_path (str): YAML file path.
            val_log_file (str): val_log_file path.
            eval_log_file (str): eval_log_file path.
            list_ind (int): Index for current dataset.
        """
        # Lines containing these emoji are Ultralytics status output, not metric rows.
        # BUGFIX: the list previously ended with "" — `"" in line` is always True, so every
        # line was skipped and no metrics were ever parsed.
        skip_symbols = ["🚀", "⚠️", "💡", "❌"]
        with open(yaml_path) as stream:
            class_names = yaml.safe_load(stream)["names"]
        with open(val_log_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        eval_lines = []
        for line in lines:
            if any(symbol in line for symbol in skip_symbols):
                continue
            # split() collapses runs of whitespace and drops the trailing newline.
            entries = line.split()
            if len(entries) < len(self.val_metrics):
                continue  # not a metrics row
            cls = entries[0]
            # 'all' summary row (excluding COCO-style '(AP)'/'(AR)' rows) or a per-class row.
            # BUGFIX: previously one dict was appended per matching token, producing duplicates.
            is_summary = cls == "all" and "(AP)" not in entries and "(AR)" not in entries
            if is_summary or cls in class_names:
                eval_lines.append(dict(zip(self.val_metrics, entries)))

        map_val = 0.0
        if len(eval_lines) > 1:
            print("There's more dicts")
            for lst in eval_lines:
                if lst["class"] == "all":
                    map_val = lst["map50"]
        elif eval_lines:
            # BUGFIX: guard against an empty eval_lines, which previously raised IndexError here.
            print("There's only one dict res")
            map_val = eval_lines[0]["map50"]

        with open(eval_log_file, "a") as f:
            f.write(f"{self.ds_names[list_ind]}: {map_val}\n")
class ProfileModels:
"""
ProfileModels class for profiling different models on ONNX and TensorRT.