ultralytics 8.1.40 search in Python sets {} for speed (#9450)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-04-01 00:16:52 +02:00 committed by GitHub
parent 30484d5925
commit ea527507fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 97 additions and 93 deletions

View file

@ -77,7 +77,7 @@ class YOLODataset(BaseDataset):
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
total = len(self.im_files)
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
raise ValueError(
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
# Display cache
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in (-1, 0):
if exists and LOCAL_RANK in {-1, 0}:
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
if cache["msgs"]:
@ -235,7 +235,7 @@ class YOLODataset(BaseDataset):
value = values[i]
if k == "img":
value = torch.stack(value, 0)
if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
value = torch.cat(value, 0)
new_batch[k] = value
new_batch["batch_idx"] = list(new_batch["batch_idx"])
@ -334,7 +334,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
if LOCAL_RANK in (-1, 0):
if LOCAL_RANK in {-1, 0}:
d = f"{desc} {nf} images, {nc} corrupt"
TQDM(None, desc=d, total=n, initial=n)
if cache["msgs"]: