ultralytics 8.1.39 add YOLO-World training (#9268)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
18036908d4
commit
e9187c1296
34 changed files with 2166 additions and 100 deletions
|
|
@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
|
|||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.utils import RANK, colorstr
|
||||
from ultralytics.utils.checks import check_file
|
||||
from .dataset import YOLODataset
|
||||
from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset
|
||||
from .utils import PIN_MEMORY
|
||||
|
||||
|
||||
|
|
@ -82,9 +82,10 @@ def seed_worker(worker_id): # noqa
|
|||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
|
||||
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
|
||||
"""Build YOLO Dataset."""
|
||||
return YOLODataset(
|
||||
dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
|
||||
return dataset(
|
||||
img_path=img_path,
|
||||
imgsz=cfg.imgsz,
|
||||
batch_size=batch,
|
||||
|
|
@ -103,6 +104,27 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str
|
|||
)
|
||||
|
||||
|
||||
def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
|
||||
"""Build YOLO Dataset."""
|
||||
return GroundingDataset(
|
||||
img_path=img_path,
|
||||
json_file=json_file,
|
||||
imgsz=cfg.imgsz,
|
||||
batch_size=batch,
|
||||
augment=mode == "train", # augmentation
|
||||
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||
rect=cfg.rect or rect, # rectangular batches
|
||||
cache=cfg.cache or None,
|
||||
single_cls=cfg.single_cls or False,
|
||||
stride=int(stride),
|
||||
pad=0.0 if mode == "train" else 0.5,
|
||||
prefix=colorstr(f"{mode}: "),
|
||||
task=cfg.task,
|
||||
classes=cfg.classes,
|
||||
fraction=cfg.fraction if mode == "train" else 1.0,
|
||||
)
|
||||
|
||||
|
||||
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
||||
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
|
||||
batch = min(batch, len(dataset))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue