Refactor Python code (#13448)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
6a234f3639
commit
1b26838def
22 changed files with 81 additions and 101 deletions
|
|
@ -208,9 +208,10 @@ class RF100Benchmark:
|
|||
|
||||
return self.ds_names, self.ds_cfg_list
|
||||
|
||||
def fix_yaml(self, path):
|
||||
@staticmethod
|
||||
def fix_yaml(path):
|
||||
"""
|
||||
Function to fix yaml train and val path.
|
||||
Function to fix YAML train and val path.
|
||||
|
||||
Args:
|
||||
path (str): YAML file path.
|
||||
|
|
@ -245,32 +246,19 @@ class RF100Benchmark:
|
|||
entries = line.split(" ")
|
||||
entries = list(filter(lambda val: val != "", entries))
|
||||
entries = [e.strip("\n") for e in entries]
|
||||
start_class = False
|
||||
for e in entries:
|
||||
if e == "all":
|
||||
if "(AP)" not in entries:
|
||||
if "(AR)" not in entries:
|
||||
# parse all
|
||||
eval = {}
|
||||
eval["class"] = entries[0]
|
||||
eval["images"] = entries[1]
|
||||
eval["targets"] = entries[2]
|
||||
eval["precision"] = entries[3]
|
||||
eval["recall"] = entries[4]
|
||||
eval["map50"] = entries[5]
|
||||
eval["map95"] = entries[6]
|
||||
eval_lines.append(eval)
|
||||
|
||||
if e in class_names:
|
||||
eval = {}
|
||||
eval["class"] = entries[0]
|
||||
eval["images"] = entries[1]
|
||||
eval["targets"] = entries[2]
|
||||
eval["precision"] = entries[3]
|
||||
eval["recall"] = entries[4]
|
||||
eval["map50"] = entries[5]
|
||||
eval["map95"] = entries[6]
|
||||
eval_lines.append(eval)
|
||||
eval_lines.extend(
|
||||
{
|
||||
"class": entries[0],
|
||||
"images": entries[1],
|
||||
"targets": entries[2],
|
||||
"precision": entries[3],
|
||||
"recall": entries[4],
|
||||
"map50": entries[5],
|
||||
"map95": entries[6],
|
||||
}
|
||||
for e in entries
|
||||
if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
|
||||
)
|
||||
map_val = 0.0
|
||||
if len(eval_lines) > 1:
|
||||
print("There's more dicts")
|
||||
|
|
|
|||
|
|
@ -103,22 +103,22 @@ def on_fit_epoch_end(trainer):
|
|||
|
||||
def on_train_end(trainer):
|
||||
"""Log model artifacts at the end of the training."""
|
||||
if mlflow:
|
||||
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
|
||||
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
|
||||
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
|
||||
mlflow.log_artifact(str(f))
|
||||
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
|
||||
if keep_run_active:
|
||||
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
|
||||
else:
|
||||
mlflow.end_run()
|
||||
LOGGER.debug(f"{PREFIX}mlflow run ended")
|
||||
if not mlflow:
|
||||
return
|
||||
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
|
||||
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
|
||||
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
|
||||
mlflow.log_artifact(str(f))
|
||||
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
|
||||
if keep_run_active:
|
||||
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
|
||||
else:
|
||||
mlflow.end_run()
|
||||
LOGGER.debug(f"{PREFIX}mlflow run ended")
|
||||
|
||||
LOGGER.info(
|
||||
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
|
||||
f"{PREFIX}disable with 'yolo settings mlflow=False'"
|
||||
)
|
||||
LOGGER.info(
|
||||
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'"
|
||||
)
|
||||
|
||||
|
||||
callbacks = (
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
|
|||
"""
|
||||
Create and log a custom metric visualization to wandb.plot.pr_curve.
|
||||
|
||||
This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
|
||||
This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
|
||||
curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
|
||||
different classes.
|
||||
|
||||
|
|
|
|||
|
|
@ -434,10 +434,9 @@ def check_torchvision():
|
|||
|
||||
# Extract only the major and minor versions
|
||||
v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
|
||||
v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
|
||||
|
||||
if v_torch in compatibility_table:
|
||||
compatible_versions = compatibility_table[v_torch]
|
||||
v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
|
||||
if all(v_torchvision != v for v in compatible_versions):
|
||||
print(
|
||||
f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
|
||||
|
|
|
|||
|
|
@ -493,7 +493,7 @@ class Annotator:
|
|||
angle = 360 - angle
|
||||
return angle
|
||||
|
||||
def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2, conf_thres=0.25):
|
||||
def draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2, conf_thres=0.25):
|
||||
"""
|
||||
Draw specific keypoints for gym steps counting.
|
||||
|
||||
|
|
@ -503,6 +503,8 @@ class Annotator:
|
|||
shape (tuple): imgsz for model inference
|
||||
radius (int): Keypoint radius value
|
||||
"""
|
||||
if indices is None:
|
||||
indices = [2, 5, 7]
|
||||
for i, k in enumerate(keypoints):
|
||||
if i in indices:
|
||||
x_coord, y_coord = k[0], k[1]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue