Refactor Python code (#13448)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-06-09 02:32:17 +02:00 committed by GitHub
parent 6a234f3639
commit 1b26838def
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 81 additions and 101 deletions

View file

@ -208,9 +208,10 @@ class RF100Benchmark:
return self.ds_names, self.ds_cfg_list
def fix_yaml(self, path):
@staticmethod
def fix_yaml(path):
"""
Function to fix yaml train and val path.
Function to fix YAML train and val path.
Args:
path (str): YAML file path.
@ -245,32 +246,19 @@ class RF100Benchmark:
entries = line.split(" ")
entries = list(filter(lambda val: val != "", entries))
entries = [e.strip("\n") for e in entries]
start_class = False
for e in entries:
if e == "all":
if "(AP)" not in entries:
if "(AR)" not in entries:
# parse all
eval = {}
eval["class"] = entries[0]
eval["images"] = entries[1]
eval["targets"] = entries[2]
eval["precision"] = entries[3]
eval["recall"] = entries[4]
eval["map50"] = entries[5]
eval["map95"] = entries[6]
eval_lines.append(eval)
if e in class_names:
eval = {}
eval["class"] = entries[0]
eval["images"] = entries[1]
eval["targets"] = entries[2]
eval["precision"] = entries[3]
eval["recall"] = entries[4]
eval["map50"] = entries[5]
eval["map95"] = entries[6]
eval_lines.append(eval)
eval_lines.extend(
{
"class": entries[0],
"images": entries[1],
"targets": entries[2],
"precision": entries[3],
"recall": entries[4],
"map50": entries[5],
"map95": entries[6],
}
for e in entries
if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
)
map_val = 0.0
if len(eval_lines) > 1:
print("There's more dicts")

View file

@ -103,22 +103,22 @@ def on_fit_epoch_end(trainer):
def on_train_end(trainer):
"""Log model artifacts at the end of the training."""
if mlflow:
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
if keep_run_active:
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
else:
mlflow.end_run()
LOGGER.debug(f"{PREFIX}mlflow run ended")
if not mlflow:
return
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
if keep_run_active:
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
else:
mlflow.end_run()
LOGGER.debug(f"{PREFIX}mlflow run ended")
LOGGER.info(
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
f"{PREFIX}disable with 'yolo settings mlflow=False'"
)
LOGGER.info(
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'"
)
callbacks = (

View file

@ -19,7 +19,7 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
"""
Create and log a custom metric visualization to wandb.plot.pr_curve.
This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
different classes.

View file

@ -434,10 +434,9 @@ def check_torchvision():
# Extract only the major and minor versions
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]
v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
if all(v_torchvision != v for v in compatible_versions):
print(
f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"

View file

@ -493,7 +493,7 @@ class Annotator:
angle = 360 - angle
return angle
def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2, conf_thres=0.25):
def draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2, conf_thres=0.25):
"""
Draw specific keypoints for gym steps counting.
@ -503,6 +503,8 @@ class Annotator:
shape (tuple): imgsz for model inference
radius (int): Keypoint radius value
"""
if indices is None:
indices = [2, 5, 7]
for i, k in enumerate(keypoints):
if i in indices:
x_coord, y_coord = k[0], k[1]