Refactor Python code (#13448)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
6a234f3639
commit
1b26838def
22 changed files with 81 additions and 101 deletions
|
|
@ -25,7 +25,7 @@ class FastSAMPrompt:
|
|||
def __init__(self, source, results, device="cuda") -> None:
|
||||
"""Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
|
||||
if isinstance(source, (str, Path)) and os.path.isdir(source):
|
||||
raise ValueError(f"FastSAM only accepts image paths and PIL Image sources, not directories.")
|
||||
raise ValueError("FastSAM only accepts image paths and PIL Image sources, not directories.")
|
||||
self.device = device
|
||||
self.results = results
|
||||
self.source = source
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class NASValidator(DetectionValidator):
|
|||
ultimately producing the final detections.
|
||||
|
||||
Attributes:
|
||||
args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU thresholds.
|
||||
args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
|
||||
lb (torch.Tensor): Optional tensor for multilabel NMS.
|
||||
|
||||
Example:
|
||||
|
|
|
|||
|
|
@ -300,22 +300,22 @@ class DetectionValidator(BaseValidator):
|
|||
|
||||
anno = COCO(str(anno_json)) # init annotations api
|
||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||
eval = COCOeval(anno, pred, "bbox")
|
||||
val = COCOeval(anno, pred, "bbox")
|
||||
else:
|
||||
from lvis import LVIS, LVISEval
|
||||
|
||||
anno = LVIS(str(anno_json)) # init annotations api
|
||||
pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||
eval = LVISEval(anno, pred, "bbox")
|
||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
|
||||
eval.evaluate()
|
||||
eval.accumulate()
|
||||
eval.summarize()
|
||||
val = LVISEval(anno, pred, "bbox")
|
||||
val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
|
||||
val.evaluate()
|
||||
val.accumulate()
|
||||
val.summarize()
|
||||
if self.is_lvis:
|
||||
eval.print_results() # explicitly call print_results
|
||||
val.print_results() # explicitly call print_results
|
||||
# update mAP50-95 and mAP50
|
||||
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
|
||||
eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
|
||||
val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
|
||||
)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"{pkg} unable to run: {e}")
|
||||
|
|
|
|||
|
|
@ -54,16 +54,15 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||
"""
|
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||
if mode == "train":
|
||||
dataset = [
|
||||
build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
|
||||
if isinstance(im_path, str)
|
||||
else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
|
||||
for im_path in img_path
|
||||
]
|
||||
return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
|
||||
else:
|
||||
if mode != "train":
|
||||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
||||
dataset = [
|
||||
build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
|
||||
if isinstance(im_path, str)
|
||||
else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
|
||||
for im_path in img_path
|
||||
]
|
||||
return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
|
||||
|
||||
def get_dataset(self):
|
||||
"""
|
||||
|
|
@ -71,7 +70,7 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
|
||||
Returns None if data format is not recognized.
|
||||
"""
|
||||
final_data = dict()
|
||||
final_data = {}
|
||||
data_yaml = self.args.data
|
||||
assert data_yaml.get("train", False) # object365.yaml
|
||||
assert data_yaml.get("val", False) # lvis.yaml
|
||||
|
|
@ -88,7 +87,7 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|||
grounding_data = data_yaml[s].get("grounding_data")
|
||||
if grounding_data is None:
|
||||
continue
|
||||
grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
|
||||
grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
|
||||
for g in grounding_data:
|
||||
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
|
||||
final_data[s] += grounding_data
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue