Refactor Python code (#13448)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent 6a234f3639
commit 1b26838def
22 changed files with 81 additions and 101 deletions
@@ -12,11 +12,11 @@ clap = { version = "4.2.4", features = ["derive"] }
 image = { version = "0.24.7", default-features = false, features = ["jpeg", "png", "webp-encoder"] }
 imageproc = { version = "0.23.0", default-features = false }
 ndarray = { version = "0.15.6" }
-ort = {version = "1.16.3", default-features = false, features = ["load-dynamic", "copy-dylibs", "half"]}
+ort = { version = "1.16.3", default-features = false, features = ["load-dynamic", "copy-dylibs", "half"] }
 rusttype = { version = "0.9", default-features = false }
-anyhow = { version = "1.0.75"}
+anyhow = { version = "1.0.75" }
 regex = { version = "1.5.4" }
-rand = { version ="0.8.5" }
+rand = { version = "0.8.5" }
 chrono = { version = "0.4.30" }
 half = { version = "2.3.1" }
 dirs = { version = "5.0.1" }
@@ -1114,10 +1114,7 @@ class RandomLoadText:
             pos_labels = set(random.sample(pos_labels, k=self.max_samples))
 
         neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
-        neg_labels = []
-        for i in range(num_classes):
-            if i not in pos_labels:
-                neg_labels.append(i)
+        neg_labels = [i for i in range(num_classes) if i not in pos_labels]
         neg_labels = random.sample(neg_labels, k=neg_samples)
 
         sampled_labels = pos_labels + neg_labels
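The `RandomLoadText` change above is the classic loop-to-comprehension refactor: an append loop with a filter condition collapses to one list comprehension. A minimal standalone sketch of the same pattern (the values here are illustrative, not taken from the Ultralytics code):

```python
import random

num_classes = 10
pos_labels = {1, 4, 7}

# Before: build the complement of pos_labels with an explicit loop
neg_labels = []
for i in range(num_classes):
    if i not in pos_labels:
        neg_labels.append(i)

# After: one list comprehension, same result
neg_labels = [i for i in range(num_classes) if i not in pos_labels]

print(random.sample(neg_labels, k=3))  # e.g. [0, 5, 9]
```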
@@ -21,7 +21,7 @@ from ultralytics.data.loaders import (
     autocast_list,
 )
 from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS
-from ultralytics.utils import LINUX, RANK, colorstr
+from ultralytics.utils import RANK, colorstr
 from ultralytics.utils.checks import check_file
 
 
@@ -209,9 +209,10 @@ class Exporter:
         if self.args.optimize:
             assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
             assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
-        if edgetpu and not LINUX:
-            raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
-        elif edgetpu and self.args.batch != 1:  # see github.com/ultralytics/ultralytics/pull/13420
+        if edgetpu:
+            if not LINUX:
+                raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
+            elif self.args.batch != 1:  # see github.com/ultralytics/ultralytics/pull/13420
                 LOGGER.warning("WARNING ⚠️ Edge TPU export requires batch size 1, setting batch=1.")
                 self.args.batch = 1
         if isinstance(model, WorldModel):
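Hoisting the shared `edgetpu` test into an outer `if` removes the repeated `edgetpu and ...` from each branch. A hedged sketch of the shape of this refactor (the function and its flags are placeholders, not the Exporter API):

```python
def check_export(edgetpu: bool, linux: bool, batch: int) -> int:
    """Mirrors the control-flow refactor: one outer test, two inner branches."""
    if edgetpu:
        if not linux:
            raise SystemError("Edge TPU export only supported on Linux.")
        elif batch != 1:
            print("WARNING ⚠️ Edge TPU export requires batch size 1, setting batch=1.")
            batch = 1
    return batch

print(check_export(edgetpu=True, linux=True, batch=8))  # prints warning, returns 1
```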
@@ -742,7 +742,6 @@ class Model(nn.Module):
 
         if hasattr(self.model, "names"):
             return check_class_names(self.model.names)
-        else:
         if not self.predictor:  # export formats will not have predictor defined until predict() is called
             self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
         self.predictor.setup_model(model=self.model, verbose=False)
@@ -319,13 +319,13 @@ class BasePredictor:
             frame = self.dataset.count
         else:
             match = re.search(r"frame (\d+)/", s[i])
-            frame = int(match.group(1)) if match else None  # 0 if frame undetermined
+            frame = int(match[1]) if match else None  # 0 if frame undetermined
 
         self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
         string += "%gx%g " % im.shape[2:]
         result = self.results[i]
         result.save_dir = self.save_dir.__str__()  # used in other locations
-        string += result.verbose() + f"{result.speed['inference']:.1f}ms"
+        string += f"{result.verbose()}{result.speed['inference']:.1f}ms"
 
         # Add predictions to image
         if self.args.save or self.args.show:
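`re.Match` objects have supported `__getitem__` since Python 3.6, so `match[1]` is exact shorthand for `match.group(1)`. A quick demonstration with a made-up log string:

```python
import re

s = "video 1/1 (frame 42/300) path/to/video.mp4"  # illustrative input, not a real log line
match = re.search(r"frame (\d+)/", s)

assert match is not None
assert match.group(1) == match[1] == "42"  # match[1] is shorthand for match.group(1)
frame = int(match[1]) if match else None
print(frame)  # 42
```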
@@ -368,5 +368,5 @@ class HUBTrainingSession:
         Returns:
             None
         """
-        for data in response.iter_content(chunk_size=1024):
+        for _ in response.iter_content(chunk_size=1024):
             pass  # Do nothing with data chunks
@@ -25,7 +25,7 @@ class FastSAMPrompt:
     def __init__(self, source, results, device="cuda") -> None:
         """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
         if isinstance(source, (str, Path)) and os.path.isdir(source):
-            raise ValueError(f"FastSAM only accepts image paths and PIL Image sources, not directories.")
+            raise ValueError("FastSAM only accepts image paths and PIL Image sources, not directories.")
         self.device = device
         self.results = results
         self.source = source
@@ -17,7 +17,7 @@ class NASValidator(DetectionValidator):
     ultimately producing the final detections.
 
     Attributes:
-        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU thresholds.
+        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
        lb (torch.Tensor): Optional tensor for multilabel NMS.
 
     Example:
@@ -300,22 +300,22 @@ class DetectionValidator(BaseValidator):
 
                 anno = COCO(str(anno_json))  # init annotations api
                 pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                eval = COCOeval(anno, pred, "bbox")
+                val = COCOeval(anno, pred, "bbox")
             else:
                 from lvis import LVIS, LVISEval
 
                 anno = LVIS(str(anno_json))  # init annotations api
                 pred = anno._load_json(str(pred_json))  # init predictions api (must pass string, not Path)
-                eval = LVISEval(anno, pred, "bbox")
-            eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
-            eval.evaluate()
-            eval.accumulate()
-            eval.summarize()
+                val = LVISEval(anno, pred, "bbox")
+            val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+            val.evaluate()
+            val.accumulate()
+            val.summarize()
             if self.is_lvis:
-                eval.print_results()  # explicitly call print_results
+                val.print_results()  # explicitly call print_results
             # update mAP50-95 and mAP50
             stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
-                eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
+                val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
             )
         except Exception as e:
             LOGGER.warning(f"{pkg} unable to run: {e}")
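Renaming the local `eval` to `val` stops the variable from shadowing the `eval` builtin (the kind of issue linters flag as A001). A small illustration of why shadowing is worth avoiding; the dict contents are fabricated, not real validator output:

```python
def score_without_shadowing():
    val = {"AP50": 0.52, "AP": 0.37}  # renamed from `eval`, so the builtin stays usable
    return eval("1 + 1"), val["AP50"]  # builtin eval() still works here

def score_with_shadowing():
    eval = {"AP50": 0.52, "AP": 0.37}  # shadows the builtin for this whole scope
    return eval("1 + 1")  # TypeError: 'dict' object is not callable

print(score_without_shadowing())  # (2, 0.52)
try:
    score_with_shadowing()
except TypeError as e:
    print(e)
```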
@@ -54,7 +54,8 @@ class WorldTrainerFromScratch(WorldTrainer):
             batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
         """
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
-        if mode == "train":
+        if mode != "train":
+            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
         dataset = [
             build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
             if isinstance(im_path, str)
@@ -62,8 +63,6 @@ class WorldTrainerFromScratch(WorldTrainer):
             for im_path in img_path
         ]
         return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
-        else:
-            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
 
     def get_dataset(self):
         """
@@ -71,7 +70,7 @@ class WorldTrainerFromScratch(WorldTrainer):
 
         Returns None if data format is not recognized.
         """
-        final_data = dict()
+        final_data = {}
         data_yaml = self.args.data
         assert data_yaml.get("train", False)  # object365.yaml
         assert data_yaml.get("val", False)  # lvis.yaml
@@ -88,7 +87,7 @@ class WorldTrainerFromScratch(WorldTrainer):
             grounding_data = data_yaml[s].get("grounding_data")
             if grounding_data is None:
                 continue
-            grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
+            grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
             for g in grounding_data:
                 assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
             final_data[s] += grounding_data
@@ -320,10 +320,8 @@ class AutoBackend(nn.Module):
             with open(w, "rb") as f:
                 gd.ParseFromString(f.read())
             frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
-            try:  # attempt to retrieve metadata from SavedModel file potentially alongside GraphDef file
+            with contextlib.suppress(StopIteration):  # find metadata in SavedModel alongside GraphDef
                 metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
-            except StopIteration:
-                pass  # no metadata file found
 
         # TFLite or TFLite Edge TPU
         elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
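`contextlib.suppress` replaces the try/except-pass idiom with a single context manager. A self-contained sketch of the same metadata lookup; the `weights` search path is hypothetical:

```python
import contextlib
from pathlib import Path

metadata = None
# Before:
# try:
#     metadata = next(Path("weights").rglob("metadata.yaml"))
# except StopIteration:
#     pass  # no metadata file found
# After: suppress() swallows the StopIteration raised when rglob() finds nothing
with contextlib.suppress(StopIteration):
    metadata = next(Path("weights").rglob("metadata.yaml"))  # hypothetical search path
print(metadata)  # a Path if found, else None
```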
@@ -666,8 +666,7 @@ class CBLinear(nn.Module):
 
     def forward(self, x):
         """Forward pass through CBLinear layer."""
-        outs = self.conv(x).split(self.c2s, dim=1)
-        return outs
+        return self.conv(x).split(self.c2s, dim=1)
 
 
 class CBFuse(nn.Module):
@@ -682,5 +681,4 @@ class CBFuse(nn.Module):
         """Forward pass through CBFuse layer."""
         target_size = xs[-1].shape[2:]
         res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
-        out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
-        return out
+        return torch.sum(torch.stack(res + xs[-1:]), dim=0)
@@ -93,7 +93,7 @@ class AIGym:
                     self.stage[ind] = "up"
                     self.count[ind] += 1
 
-            elif self.pose_type == "pushup" or self.pose_type == "squat":
+            elif self.pose_type in {"pushup", "squat"}:
                 if self.angle[ind] > self.poseup_angle:
                     self.stage[ind] = "up"
                 if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up":
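`x in {"a", "b"}` replaces chained `==`/`or` comparisons; for a literal set of constants, CPython typically compiles the set into a frozenset constant, so the membership test is at least as fast and reads better. Sketch:

```python
pose_type = "squat"

# Before: two comparisons joined with `or`
if pose_type == "pushup" or pose_type == "squat":
    print("matched (chained or)")

# After: a single membership test against a literal set
if pose_type in {"pushup", "squat"}:
    print("matched (set membership)")
```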
@@ -172,7 +172,7 @@ class ObjectCounter:
             if self.draw_tracks:
                 self.annotator.draw_centroid_and_tracks(
                     track_line,
-                    color=self.track_color if self.track_color else colors(int(track_id), True),
+                    color=self.track_color or colors(int(track_id), True),
                     track_thickness=self.track_thickness,
                 )
 
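`a or b` is shorthand for `a if a else b`. One caveat worth knowing: it falls back on any falsy value (empty string, 0, empty tuple), not just `None`, which is safe here because a non-empty color tuple is always truthy. Sketch with placeholder values:

```python
def pick_color(track_color, fallback=(0, 255, 0)):
    # Equivalent to: track_color if track_color else fallback
    return track_color or fallback

print(pick_color((255, 0, 0)))  # (255, 0, 0) — truthy value wins
print(pick_color(None))         # (0, 255, 0) — None falls through
print(pick_color(()))           # (0, 255, 0) — caveat: any falsy value falls through too
```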
@@ -73,8 +73,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
         idx = tracks[:, -1].astype(int)
         predictor.results[i] = predictor.results[i][idx]
 
-        update_args = dict()
-        update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
+        update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])}
         predictor.results[i].update(**update_args)
 
 
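A dict literal with a computed key replaces the two-step `dict()` plus assignment, and literals also skip the `dict` name lookup and call. Sketch of the same shape (the tensor is replaced by a plain list to keep it standalone):

```python
is_obb = False
tracks_but_id = [[10.0, 20.0, 30.0, 40.0]]  # stand-in for torch.as_tensor(tracks[:, :-1])

# Before:
update_args = dict()
update_args["obb" if is_obb else "boxes"] = tracks_but_id

# After: computed key directly in a dict literal
update_args = {"obb" if is_obb else "boxes": tracks_but_id}
print(update_args)  # {'boxes': [[10.0, 20.0, 30.0, 40.0]]}
```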
@@ -44,7 +44,7 @@ class GMC:
         super().__init__()
 
         self.method = method
-        self.downscale = max(1, int(downscale))
+        self.downscale = max(1, downscale)
 
         if self.method == "orb":
             self.detector = cv2.FastFeatureDetector_create(20)
@@ -208,9 +208,10 @@ class RF100Benchmark:
 
         return self.ds_names, self.ds_cfg_list
 
-    def fix_yaml(self, path):
+    @staticmethod
+    def fix_yaml(path):
         """
-        Function to fix yaml train and val path.
+        Function to fix YAML train and val path.
 
         Args:
             path (str): YAML file path.
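`fix_yaml` never touched `self`, so it becomes a `@staticmethod`: the signature now states that the method depends only on its arguments, and it can be called without an instance. A minimal sketch (class name and body are illustrative, not the RF100Benchmark implementation):

```python
class Benchmark:
    @staticmethod
    def fix_yaml(path: str) -> str:
        """No `self` parameter: the method only depends on its arguments."""
        return path.replace("\\", "/")

# Callable on the class or on an instance
print(Benchmark.fix_yaml("data\\train.yaml"))  # data/train.yaml
print(Benchmark().fix_yaml("data\\val.yaml"))  # data/val.yaml
```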
@@ -245,32 +246,19 @@ class RF100Benchmark:
             entries = line.split(" ")
             entries = list(filter(lambda val: val != "", entries))
             entries = [e.strip("\n") for e in entries]
-            start_class = False
-            for e in entries:
-                if e == "all":
-                    if "(AP)" not in entries:
-                        if "(AR)" not in entries:
-                            # parse all
-                            eval = {}
-                            eval["class"] = entries[0]
-                            eval["images"] = entries[1]
-                            eval["targets"] = entries[2]
-                            eval["precision"] = entries[3]
-                            eval["recall"] = entries[4]
-                            eval["map50"] = entries[5]
-                            eval["map95"] = entries[6]
-                            eval_lines.append(eval)
-
-                if e in class_names:
-                    eval = {}
-                    eval["class"] = entries[0]
-                    eval["images"] = entries[1]
-                    eval["targets"] = entries[2]
-                    eval["precision"] = entries[3]
-                    eval["recall"] = entries[4]
-                    eval["map50"] = entries[5]
-                    eval["map95"] = entries[6]
-                    eval_lines.append(eval)
+            eval_lines.extend(
+                {
+                    "class": entries[0],
+                    "images": entries[1],
+                    "targets": entries[2],
+                    "precision": entries[3],
+                    "recall": entries[4],
+                    "map50": entries[5],
+                    "map95": entries[6],
+                }
+                for e in entries
+                if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
+            )
             map_val = 0.0
             if len(eval_lines) > 1:
                 print("There's more dicts")
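The nested loop/if ladder collapses into a single `list.extend` fed by a generator expression whose `if` clause merges the two original conditions (and drops the unused `start_class` flag and builtin-shadowing `eval` along the way). A runnable sketch with a fabricated results line:

```python
class_names = {"person", "car"}
eval_lines = []

line = "person 128 281 0.81 0.66 0.75 0.50"  # fabricated, for illustration only
entries = [e for e in line.split(" ") if e]  # same effect as the filter() step above

# extend() accepts any iterable, including a generator expression;
# one dict is yielded per entry that satisfies the merged condition
eval_lines.extend(
    {
        "class": entries[0],
        "images": entries[1],
        "targets": entries[2],
        "precision": entries[3],
        "recall": entries[4],
        "map50": entries[5],
        "map95": entries[6],
    }
    for e in entries
    if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
)
print(eval_lines)  # one dict, keyed off the matching 'person' entry
```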
@@ -103,7 +103,8 @@ def on_fit_epoch_end(trainer):
 
 def on_train_end(trainer):
     """Log model artifacts at the end of the training."""
-    if mlflow:
+    if not mlflow:
+        return
     mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
     for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
         if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
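Inverting the test to `if not mlflow: return` is the guard-clause pattern: the early exit removes one indentation level from everything that follows. A shape sketch under the assumption that `mlflow` is `None` when its import failed:

```python
mlflow = None  # stand-in: the real module sets this to None on ImportError

def on_train_end(trainer=None):
    """Guard clause exits early when mlflow is unavailable."""
    if not mlflow:
        return  # early exit keeps the main body flat
    mlflow.log_artifact("weights")  # only reached when mlflow imported successfully

on_train_end()  # returns silently, no AttributeError on None
```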
@@ -116,8 +117,7 @@ def on_train_end(trainer):
         LOGGER.debug(f"{PREFIX}mlflow run ended")
 
     LOGGER.info(
-        f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n"
-        f"{PREFIX}disable with 'yolo settings mlflow=False'"
+        f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'"
     )
 
 
@@ -19,7 +19,7 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
     """
     Create and log a custom metric visualization to wandb.plot.pr_curve.
 
-    This function crafts a custom metric visualization that mimics the behavior of wandb's default precision-recall
+    This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
     curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
     different classes.
 
@@ -434,10 +434,9 @@ def check_torchvision():
 
     # Extract only the major and minor versions
     v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
-    v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
-
     if v_torch in compatibility_table:
         compatible_versions = compatibility_table[v_torch]
+        v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
         if all(v_torchvision != v for v in compatible_versions):
             print(
                 f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
@@ -493,7 +493,7 @@ class Annotator:
             angle = 360 - angle
         return angle
 
-    def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2, conf_thres=0.25):
+    def draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2, conf_thres=0.25):
         """
         Draw specific keypoints for gym steps counting.
 
@@ -503,6 +503,8 @@ class Annotator:
             shape (tuple): imgsz for model inference
             radius (int): Keypoint radius value
         """
+        if indices is None:
+            indices = [2, 5, 7]
         for i, k in enumerate(keypoints):
             if i in indices:
                 x_coord, y_coord = k[0], k[1]
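`indices=[2, 5, 7]` was a mutable default argument, evaluated once at function definition; any in-place mutation would then leak across calls. The `None` sentinel plus in-body assignment is the standard fix. A demonstration of the hazard (the function names are illustrative):

```python
def buggy(indices=[2, 5, 7]):
    indices.append(99)  # mutates the single shared default list
    return indices

print(buggy())  # [2, 5, 7, 99]
print(buggy())  # [2, 5, 7, 99, 99] — state leaked from the first call

def fixed(indices=None):
    if indices is None:
        indices = [2, 5, 7]  # fresh list on every call
    indices.append(99)
    return indices

print(fixed())  # [2, 5, 7, 99]
print(fixed())  # [2, 5, 7, 99] — no leakage
```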