Apply Ruff 0.9.0 (#18622)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
cc1e77138c
commit
3902e740cf
22 changed files with 69 additions and 65 deletions
|
|
@ -74,7 +74,7 @@ MODELS = {TASK2MODEL[task] for task in TASKS}
|
|||
|
||||
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
|
||||
SOLUTIONS_HELP_MSG = f"""
|
||||
Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo solutions' usage overview:
|
||||
Arguments received: {str(["yolo"] + ARGV[1:])}. Ultralytics 'yolo solutions' usage overview:
|
||||
|
||||
yolo solutions SOLUTION ARGS
|
||||
|
||||
|
|
@ -104,7 +104,7 @@ SOLUTIONS_HELP_MSG = f"""
|
|||
yolo streamlit-predict
|
||||
"""
|
||||
CLI_HELP_MSG = f"""
|
||||
Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
||||
Arguments received: {str(["yolo"] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
|
|
@ -359,11 +359,11 @@ def check_cfg(cfg, hard=True):
|
|||
)
|
||||
cfg[k] = v = float(v)
|
||||
if not (0.0 <= v <= 1.0):
|
||||
raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
|
||||
raise ValueError(f"'{k}={v}' is an invalid value. Valid '{k}' values are between 0.0 and 1.0.")
|
||||
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
||||
if hard:
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. '{k}' must be an int (i.e. '{k}=8')"
|
||||
)
|
||||
cfg[k] = int(v)
|
||||
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
||||
|
|
|
|||
|
|
@ -271,9 +271,9 @@ class Compose:
|
|||
"""
|
||||
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
|
||||
if isinstance(index, list):
|
||||
assert isinstance(
|
||||
value, list
|
||||
), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
|
||||
assert isinstance(value, list), (
|
||||
f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
|
||||
)
|
||||
if isinstance(index, int):
|
||||
index, value = [index], [value]
|
||||
for i, v in zip(index, value):
|
||||
|
|
|
|||
|
|
@ -242,7 +242,9 @@ def convert_coco(
|
|||
from ultralytics.data.converter import convert_coco
|
||||
|
||||
convert_coco("../datasets/coco/annotations/", use_segments=True, use_keypoints=False, cls91to80=False)
|
||||
convert_coco("../datasets/lvis/annotations/", use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
|
||||
convert_coco(
|
||||
"../datasets/lvis/annotations/", use_segments=True, use_keypoints=False, cls91to80=False, lvis=True
|
||||
)
|
||||
```
|
||||
|
||||
Output:
|
||||
|
|
@ -270,7 +272,7 @@ def convert_coco(
|
|||
data = json.load(f)
|
||||
|
||||
# Create image dict
|
||||
images = {f'{x["id"]:d}': x for x in data["images"]}
|
||||
images = {f"{x['id']:d}": x for x in data["images"]}
|
||||
# Create image-annotations dict
|
||||
imgToAnns = defaultdict(list)
|
||||
for ann in data["annotations"]:
|
||||
|
|
|
|||
|
|
@ -299,7 +299,7 @@ class GroundingDataset(YOLODataset):
|
|||
LOGGER.info("Loading annotation file...")
|
||||
with open(self.json_file) as f:
|
||||
annotations = json.load(f)
|
||||
images = {f'{x["id"]:d}': x for x in annotations["images"]}
|
||||
images = {f"{x['id']:d}": x for x in annotations["images"]}
|
||||
img_to_anns = defaultdict(list)
|
||||
for ann in annotations["annotations"]:
|
||||
img_to_anns[ann["image_id"]].append(ann)
|
||||
|
|
|
|||
|
|
@ -451,7 +451,7 @@ def check_cls_dataset(dataset, split=""):
|
|||
|
||||
# Print to console
|
||||
for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
|
||||
prefix = f'{colorstr(f"{k}:")} {v}...'
|
||||
prefix = f"{colorstr(f'{k}:')} {v}..."
|
||||
if v is None:
|
||||
LOGGER.info(prefix)
|
||||
else:
|
||||
|
|
@ -519,7 +519,7 @@ class HUBDatasetStats:
|
|||
except Exception as e:
|
||||
raise Exception("error/HUB/dataset_stats/init") from e
|
||||
|
||||
self.hub_dir = Path(f'{data["path"]}-hub')
|
||||
self.hub_dir = Path(f"{data['path']}-hub")
|
||||
self.im_dir = self.hub_dir / "images"
|
||||
self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())} # statistics dictionary
|
||||
self.data = data
|
||||
|
|
@ -531,7 +531,7 @@ class HUBDatasetStats:
|
|||
return False, None, path
|
||||
unzip_dir = unzip_file(path, path=path.parent)
|
||||
assert unzip_dir.is_dir(), (
|
||||
f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
|
||||
f"Error unzipping {path}, {unzip_dir} not found. path/to/abc.zip MUST unzip to path/to/abc/"
|
||||
)
|
||||
return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path
|
||||
|
||||
|
|
|
|||
|
|
@ -357,7 +357,7 @@ class Exporter:
|
|||
)
|
||||
self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
|
||||
data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
|
||||
description = f'Ultralytics {self.pretty_name} model {f"trained on {data}" if data else ""}'
|
||||
description = f"Ultralytics {self.pretty_name} model {f'trained on {data}' if data else ''}"
|
||||
self.metadata = {
|
||||
"description": description,
|
||||
"author": "Ultralytics",
|
||||
|
|
@ -377,7 +377,7 @@ class Exporter:
|
|||
|
||||
LOGGER.info(
|
||||
f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
|
||||
f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)'
|
||||
f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)"
|
||||
)
|
||||
|
||||
# Exports
|
||||
|
|
@ -427,11 +427,11 @@ class Exporter:
|
|||
predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
|
||||
q = "int8" if self.args.int8 else "half" if self.args.half else "" # quantization
|
||||
LOGGER.info(
|
||||
f'\nExport complete ({time.time() - t:.1f}s)'
|
||||
f"\nExport complete ({time.time() - t:.1f}s)"
|
||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||
f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
|
||||
f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
|
||||
f'\nVisualize: https://netron.app'
|
||||
f"\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}"
|
||||
f"\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}"
|
||||
f"\nVisualize: https://netron.app"
|
||||
)
|
||||
|
||||
self.run_callbacks("on_export_end")
|
||||
|
|
@ -680,16 +680,16 @@ class Exporter:
|
|||
shutil.rmtree(unzip_dir) # delete unzip dir
|
||||
|
||||
ncnn_args = [
|
||||
f'ncnnparam={f / "model.ncnn.param"}',
|
||||
f'ncnnbin={f / "model.ncnn.bin"}',
|
||||
f'ncnnpy={f / "model_ncnn.py"}',
|
||||
f"ncnnparam={f / 'model.ncnn.param'}",
|
||||
f"ncnnbin={f / 'model.ncnn.bin'}",
|
||||
f"ncnnpy={f / 'model_ncnn.py'}",
|
||||
]
|
||||
|
||||
pnnx_args = [
|
||||
f'pnnxparam={f / "model.pnnx.param"}',
|
||||
f'pnnxbin={f / "model.pnnx.bin"}',
|
||||
f'pnnxpy={f / "model_pnnx.py"}',
|
||||
f'pnnxonnx={f / "model.pnnx.onnx"}',
|
||||
f"pnnxparam={f / 'model.pnnx.param'}",
|
||||
f"pnnxbin={f / 'model.pnnx.bin'}",
|
||||
f"pnnxpy={f / 'model_pnnx.py'}",
|
||||
f"pnnxonnx={f / 'model.pnnx.onnx'}",
|
||||
]
|
||||
|
||||
cmd = [
|
||||
|
|
@ -1139,7 +1139,9 @@ class Exporter:
|
|||
def export_imx(self, prefix=colorstr("IMX:")):
|
||||
"""YOLO IMX export."""
|
||||
gptq = False
|
||||
assert LINUX, "export only supported on Linux. See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter"
|
||||
assert LINUX, (
|
||||
"export only supported on Linux. See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter"
|
||||
)
|
||||
if getattr(self.model, "end2end", False):
|
||||
raise ValueError("IMX export is not supported for end2end models.")
|
||||
if "C2f" not in self.model.__str__():
|
||||
|
|
|
|||
|
|
@ -367,7 +367,7 @@ class BasePredictor:
|
|||
# Save videos and streams
|
||||
if self.dataset.mode in {"stream", "video"}:
|
||||
fps = self.dataset.fps if self.dataset.mode == "video" else 30
|
||||
frames_path = f'{save_path.split(".", 1)[0]}_frames/'
|
||||
frames_path = f"{save_path.split('.', 1)[0]}_frames/"
|
||||
if save_path not in self.vid_writer: # new video
|
||||
if self.args.save_frames:
|
||||
Path(frames_path).mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ class BaseTrainer:
|
|||
# Command
|
||||
cmd, file = generate_ddp_command(world_size, self)
|
||||
try:
|
||||
LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
|
||||
LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
|
||||
subprocess.run(cmd, check=True)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
@ -329,10 +329,10 @@ class BaseTrainer:
|
|||
self.train_time_start = time.time()
|
||||
self.run_callbacks("on_train_start")
|
||||
LOGGER.info(
|
||||
f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||
f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
||||
f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n"
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
|
||||
f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
|
||||
)
|
||||
if self.args.close_mosaic:
|
||||
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
||||
|
|
@ -814,6 +814,6 @@ class BaseTrainer:
|
|||
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
|
||||
LOGGER.info(
|
||||
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
|
||||
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
|
||||
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
|
||||
)
|
||||
return optimizer
|
||||
|
|
|
|||
|
|
@ -224,12 +224,12 @@ class Tuner:
|
|||
|
||||
# Save and print tune results
|
||||
header = (
|
||||
f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
|
||||
f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
|
||||
f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
|
||||
f'{self.prefix}Best fitness metrics are {best_metrics}\n'
|
||||
f'{self.prefix}Best fitness model is {best_save_dir}\n'
|
||||
f'{self.prefix}Best fitness hyperparameters are printed below.\n'
|
||||
f"{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n"
|
||||
f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n"
|
||||
f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n"
|
||||
f"{self.prefix}Best fitness metrics are {best_metrics}\n"
|
||||
f"{self.prefix}Best fitness model is {best_save_dir}\n"
|
||||
f"{self.prefix}Best fitness hyperparameters are printed below.\n"
|
||||
)
|
||||
LOGGER.info("\n" + header)
|
||||
data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
|
||||
|
|
|
|||
|
|
@ -87,9 +87,9 @@ class FastSAMPredictor(SegmentationPredictor):
|
|||
if labels is None:
|
||||
labels = torch.ones(points.shape[0])
|
||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
||||
assert len(labels) == len(
|
||||
points
|
||||
), f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
|
||||
assert len(labels) == len(points), (
|
||||
f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
|
||||
)
|
||||
point_idx = (
|
||||
torch.ones(len(result), dtype=torch.bool, device=self.device)
|
||||
if labels.sum() == 0 # all negative points
|
||||
|
|
|
|||
|
|
@ -479,9 +479,9 @@ class ImageEncoder(nn.Module):
|
|||
self.trunk = trunk
|
||||
self.neck = neck
|
||||
self.scalp = scalp
|
||||
assert (
|
||||
self.trunk.channel_list == self.neck.backbone_channel_list
|
||||
), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
|
||||
assert self.trunk.channel_list == self.neck.backbone_channel_list, (
|
||||
f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
|
||||
)
|
||||
|
||||
def forward(self, sample: torch.Tensor):
|
||||
"""Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
|
||||
|
|
|
|||
|
|
@ -279,9 +279,9 @@ class Predictor(BasePredictor):
|
|||
if labels is None:
|
||||
labels = np.ones(points.shape[:-1])
|
||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
||||
assert (
|
||||
points.shape[-2] == labels.shape[-1]
|
||||
), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
|
||||
assert points.shape[-2] == labels.shape[-1], (
|
||||
f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
|
||||
)
|
||||
points *= r
|
||||
if points.ndim == 2:
|
||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
||||
|
|
@ -552,9 +552,9 @@ class Predictor(BasePredictor):
|
|||
|
||||
def get_im_features(self, im):
|
||||
"""Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
|
||||
assert (
|
||||
isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1]
|
||||
), f"SAM models only support square image size, but got {self.imgsz}."
|
||||
assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
|
||||
f"SAM models only support square image size, but got {self.imgsz}."
|
||||
)
|
||||
self.model.set_imgsz(self.imgsz)
|
||||
return self.model.image_encoder(im)
|
||||
|
||||
|
|
@ -795,9 +795,9 @@ class SAM2Predictor(Predictor):
|
|||
|
||||
def get_im_features(self, im):
|
||||
"""Extracts image features from the SAM image encoder for subsequent processing."""
|
||||
assert (
|
||||
isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1]
|
||||
), f"SAM 2 models only support square image size, but got {self.imgsz}."
|
||||
assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
|
||||
f"SAM 2 models only support square image size, but got {self.imgsz}."
|
||||
)
|
||||
self.model.set_imgsz(self.imgsz)
|
||||
self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]
|
||||
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ class DetectionValidator(BaseValidator):
|
|||
predn,
|
||||
self.args.save_conf,
|
||||
pbatch["ori_shape"],
|
||||
self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt',
|
||||
self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
|
||||
)
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ class OBBValidator(DetectionValidator):
|
|||
classname = self.names[d["category_id"] - 1].replace(" ", "-")
|
||||
p = d["poly"]
|
||||
|
||||
with open(f'{pred_txt / f"Task1_{classname}"}.txt', "a") as f:
|
||||
with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a") as f:
|
||||
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
|
||||
# Save merged results, this could result slightly lower map than using official merging script,
|
||||
# because of the probiou calculation.
|
||||
|
|
@ -197,7 +197,7 @@ class OBBValidator(DetectionValidator):
|
|||
p = [round(i, 3) for i in x[:-2]] # poly
|
||||
score = round(x[-2], 3)
|
||||
|
||||
with open(f'{pred_merged_txt / f"Task1_{classname}"}.txt', "a") as f:
|
||||
with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a") as f:
|
||||
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
|
||||
|
||||
return stats
|
||||
|
|
|
|||
|
|
@ -153,7 +153,7 @@ class PoseValidator(DetectionValidator):
|
|||
pred_kpts,
|
||||
self.args.save_conf,
|
||||
pbatch["ori_shape"],
|
||||
self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt',
|
||||
self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
|
||||
)
|
||||
|
||||
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
pred_masks,
|
||||
self.args.save_conf,
|
||||
pbatch["ori_shape"],
|
||||
self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt',
|
||||
self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
|
||||
)
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class SecurityAlarm(BaseSolution):
|
|||
message["Subject"] = "Security Alert"
|
||||
|
||||
# Add the text message body
|
||||
message_body = f"Ultralytics ALERT!!! " f"{records} objects have been detected!!"
|
||||
message_body = f"Ultralytics ALERT!!! {records} objects have been detected!!"
|
||||
message.attach(MIMEText(message_body))
|
||||
|
||||
# Attach the image
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ def on_train_end(trainer):
|
|||
final=True,
|
||||
)
|
||||
session.alive = False # stop heartbeats
|
||||
LOGGER.info(f"{PREFIX}Done ✅\n" f"{PREFIX}View model at {session.model_url} 🚀")
|
||||
LOGGER.info(f"{PREFIX}Done ✅\n{PREFIX}View model at {session.model_url} 🚀")
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ def on_pretrain_routine_end(trainer):
|
|||
LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
|
||||
mlflow.log_params(dict(trainer.args))
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n" f"{PREFIX}WARNING ⚠️ Not tracking this run")
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n{PREFIX}WARNING ⚠️ Not tracking this run")
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
|
|
|
|||
|
|
@ -707,7 +707,7 @@ def check_amp(model):
|
|||
LOGGER.info(f"{prefix}checks passed ✅")
|
||||
except ConnectionError:
|
||||
LOGGER.warning(
|
||||
f"{prefix}checks skipped ⚠️. " f"Offline and unable to download YOLO11n for AMP checks. {warning_msg}"
|
||||
f"{prefix}checks skipped ⚠️. Offline and unable to download YOLO11n for AMP checks. {warning_msg}"
|
||||
)
|
||||
except (AttributeError, ModuleNotFoundError):
|
||||
LOGGER.warning(
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ if __name__ == "__main__":
|
|||
cfg = DEFAULT_CFG_DICT.copy()
|
||||
cfg.update(save_dir='') # handle the extra key 'save_dir'
|
||||
trainer = {name}(cfg=cfg, overrides=overrides)
|
||||
trainer.args.model = "{getattr(trainer.hub_session, 'model_url', trainer.args.model)}"
|
||||
trainer.args.model = "{getattr(trainer.hub_session, "model_url", trainer.args.model)}"
|
||||
results = trainer.train()
|
||||
"""
|
||||
(USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
|
||||
|
|
|
|||
|
|
@ -432,7 +432,7 @@ class ConfusionMatrix:
|
|||
ax.set_xlabel("True")
|
||||
ax.set_ylabel("Predicted")
|
||||
ax.set_title(title)
|
||||
plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
|
||||
plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png"
|
||||
fig.savefig(plot_fname, dpi=250)
|
||||
plt.close(fig)
|
||||
if on_plot:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue