ultralytics 8.1.39 add YOLO-World training (#9268)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
18036908d4
commit
e9187c1296
34 changed files with 2166 additions and 100 deletions
|
|
@ -35,7 +35,7 @@ class FastSAMPrompt:
|
|||
except ImportError:
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements("git+https://github.com/openai/CLIP.git")
|
||||
check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
||||
import clip
|
||||
self.clip = clip
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.models.yolo import classify, detect, obb, pose, segment
|
||||
from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
|
||||
|
||||
from .model import YOLO, YOLOWorld
|
||||
|
||||
__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
|
||||
__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class DetectionValidator(BaseValidator):
|
|||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.nt_per_class = None
|
||||
self.is_coco = False
|
||||
self.is_lvis = False
|
||||
self.class_map = None
|
||||
self.args.task = "detect"
|
||||
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
||||
|
|
@ -66,8 +67,9 @@ class DetectionValidator(BaseValidator):
|
|||
"""Initialize evaluation metrics for YOLO."""
|
||||
val = self.data.get(self.args.split, "") # validation path
|
||||
self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
|
||||
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
||||
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
|
||||
self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
|
||||
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names)))
|
||||
self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training # run on final val if training COCO
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.metrics.names = self.names
|
||||
|
|
@ -266,7 +268,8 @@ class DetectionValidator(BaseValidator):
|
|||
self.jdict.append(
|
||||
{
|
||||
"image_id": image_id,
|
||||
"category_id": self.class_map[int(p[5])],
|
||||
"category_id": self.class_map[int(p[5])]
|
||||
+ (1 if self.is_lvis else 0), # index starts from 1 if it's lvis
|
||||
"bbox": [round(x, 3) for x in b],
|
||||
"score": round(p[4], 5),
|
||||
}
|
||||
|
|
@ -274,26 +277,42 @@ class DetectionValidator(BaseValidator):
|
|||
|
||||
def eval_json(self, stats):
|
||||
"""Evaluates YOLO output in JSON format and returns performance statistics."""
|
||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||
anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
|
||||
if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
|
||||
pred_json = self.save_dir / "predictions.json" # predictions
|
||||
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
|
||||
anno_json = (
|
||||
self.data["path"]
|
||||
/ "annotations"
|
||||
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
|
||||
) # annotations
|
||||
pkg = "pycocotools" if self.is_coco else "lvis"
|
||||
LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
|
||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
||||
check_requirements("pycocotools>=2.0.6")
|
||||
from pycocotools.coco import COCO # noqa
|
||||
from pycocotools.cocoeval import COCOeval # noqa
|
||||
|
||||
for x in anno_json, pred_json:
|
||||
for x in pred_json, anno_json:
|
||||
assert x.is_file(), f"{x} file not found"
|
||||
anno = COCO(str(anno_json)) # init annotations api
|
||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||
eval = COCOeval(anno, pred, "bbox")
|
||||
check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
|
||||
if self.is_coco:
|
||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
|
||||
from pycocotools.coco import COCO # noqa
|
||||
from pycocotools.cocoeval import COCOeval # noqa
|
||||
|
||||
anno = COCO(str(anno_json)) # init annotations api
|
||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||
eval = COCOeval(anno, pred, "bbox")
|
||||
else:
|
||||
from lvis import LVIS, LVISEval
|
||||
|
||||
anno = LVIS(str(anno_json)) # init annotations api
|
||||
pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||
eval = LVISEval(anno, pred, "bbox")
|
||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
|
||||
eval.evaluate()
|
||||
eval.accumulate()
|
||||
eval.summarize()
|
||||
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
|
||||
if self.is_lvis:
|
||||
eval.print_results() # explicitly call print_results
|
||||
# update mAP50-95 and mAP50
|
||||
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
|
||||
eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
|
||||
)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"pycocotools unable to run: {e}")
|
||||
LOGGER.warning(f"{pkg} unable to run: {e}")
|
||||
return stats
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ class YOLOWorld(Model):
|
|||
"model": WorldModel,
|
||||
"validator": yolo.detect.DetectionValidator,
|
||||
"predictor": yolo.detect.DetectionPredictor,
|
||||
"trainer": yolo.world.WorldTrainer,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
5
ultralytics/models/yolo/world/__init__.py
Normal file
5
ultralytics/models/yolo/world/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .train import WorldTrainer
|
||||
|
||||
__all__ = ["WorldTrainer"]
|
||||
91
ultralytics/models/yolo/world/train.py
Normal file
91
ultralytics/models/yolo/world/train.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import WorldModel
|
||||
from ultralytics.utils import DEFAULT_CFG, RANK
|
||||
from ultralytics.data import build_yolo_dataset
|
||||
from ultralytics.utils.torch_utils import de_parallel
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
import itertools
|
||||
|
||||
try:
|
||||
import clip
|
||||
except ImportError:
|
||||
check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
||||
import clip
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
    """Set evaluation class names on the EMA model and attach a frozen CLIP text encoder to the trainer."""
    if RANK in (-1, 0):
        # Keep only the text before the first '/' in each dataset name for evaluation.
        class_names = [n.split("/")[0] for n in trainer.test_loader.dataset.data["names"].values()]
        de_parallel(trainer.ema.ema).set_classes(class_names, cache_clip_model=False)
    device = next(trainer.model.parameters()).device
    encoder, _ = clip.load("ViT-B/32", device=device)
    # Freeze the text encoder: it only embeds prompts and is never trained.
    for param in encoder.parameters():
        param.requires_grad_(False)
    trainer.text_model = encoder
|
||||
|
||||
|
||||
class WorldTrainer(yolo.detect.DetectionTrainer):
    """
    Trainer for fine-tuning a YOLO-World model on a closed-set (fixed-vocabulary) dataset.

    Example:
        ```python
        from ultralytics.models.yolo.world import WorldTrainer

        args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
        trainer = WorldTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the trainer, defaulting `overrides` to an empty dict."""
        super().__init__(cfg, overrides if overrides is not None else {}, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return a WorldModel built from `cfg`, optionally loading `weights`."""
        # NOTE: `nc` here bounds the number of distinct text samples per image, not the true
        # class count; it is capped at 80 following the official config.
        yaml_cfg = cfg["yaml_file"] if isinstance(cfg, dict) else cfg
        model = WorldModel(
            yaml_cfg,
            ch=3,
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)
        # Register the callback that attaches the CLIP text encoder after pretrain routines.
        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
        return model

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build a YOLO dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): 'train' or 'val'; each mode may customize its augmentations.
            batch (int, optional): Batch size, used for `rect` mode. Defaults to None.
        """
        stride = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=stride, multi_modal=mode == "train"
        )

    def preprocess_batch(self, batch):
        """Preprocess a training batch and attach L2-normalized CLIP text features as `batch['txt_feats']`."""
        batch = super().preprocess_batch(batch)

        # Flatten the per-image text lists, tokenize, and embed with the CLIP text encoder.
        all_texts = list(itertools.chain(*batch["texts"]))
        tokens = clip.tokenize(all_texts).to(batch["img"].device)
        feats = self.text_model.encode_text(tokens).to(dtype=batch["img"].dtype)  # torch.float32
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
        batch["txt_feats"] = feats.reshape(len(batch["texts"]), -1, feats.shape[-1])
        return batch
|
||||
108
ultralytics/models/yolo/world/train_world.py
Normal file
108
ultralytics/models/yolo/world/train_world.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
|
||||
from ultralytics.data.utils import check_det_dataset
|
||||
from ultralytics.models.yolo.world import WorldTrainer
|
||||
from ultralytics.utils.torch_utils import de_parallel
|
||||
from ultralytics.utils import DEFAULT_CFG
|
||||
|
||||
|
||||
class WorldTrainerFromScratch(WorldTrainer):
    """
    Trainer for a YOLO-World model trained from scratch on an open-set mixture of datasets.

    Training data may combine plain YOLO detection datasets with grounding datasets
    (image folder + JSON annotation file); validation uses a single YOLO dataset.

    Example:
        ```python
        from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
        from ultralytics import YOLOWorld

        data = dict(
            train=dict(
                yolo_data=["Objects365.yaml"],
                grounding_data=[
                    dict(
                        img_path="../datasets/flickr30k/images",
                        json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
                    ),
                    dict(
                        img_path="../datasets/GQA/images",
                        json_file="../datasets/GQA/final_mixed_train_no_coco.json",
                    ),
                ],
            ),
            val=dict(yolo_data=["lvis.yaml"]),
        )

        model = YOLOWorld("yolov8s-worldv2.yaml")
        model.train(data=data, trainer=WorldTrainerFromScratch)
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the trainer, defaulting `overrides` to an empty dict."""
        super().__init__(cfg, overrides if overrides is not None else {}, _callbacks)

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build a YOLO dataset, concatenating multiple datasets for training.

        Args:
            img_path (List[str] | str): Image folder path(s); for training, a list mixing
                plain paths (str) and grounding specs (dict with `img_path`/`json_file`).
            mode (str): 'train' or 'val'; each mode may customize its augmentations.
            batch (int, optional): Batch size, used for `rect` mode. Defaults to None.
        """
        stride = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        if mode != "train":
            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=stride)
        datasets = []
        for entry in img_path:
            if isinstance(entry, str):
                datasets.append(build_yolo_dataset(self.args, entry, batch, self.data, stride=stride, multi_modal=True))
            else:
                datasets.append(build_grounding(self.args, entry["img_path"], entry["json_file"], batch, stride=stride))
        return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

    def get_dataset(self):
        """
        Resolve train and val paths from the user-supplied data dict.

        Returns:
            (tuple): the list of train sources and the single val source.
        """
        data_yaml = self.args.data
        assert data_yaml.get("train", False)  # e.g. Objects365.yaml
        assert data_yaml.get("val", False)  # e.g. lvis.yaml
        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
        # LVIS datasets carry a 'minival' annotation path that must be made absolute.
        for ds in data["val"]:
            if ds.get("minival") is None:
                continue
            ds["minival"] = str(ds["path"] / ds["minival"])
        final_data = {}
        for split in ("train", "val"):
            final_data[split] = [ds["train" if split == "train" else val_split] for ds in data[split]]
            # Append grounding data for this split, if any was provided.
            grounding_data = data_yaml[split].get("grounding_data")
            if grounding_data is None:
                continue
            if not isinstance(grounding_data, list):
                grounding_data = [grounding_data]
            for g in grounding_data:
                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
            final_data[split] += grounding_data
        # Training needs `nc` and `names`; take them from the (single) val dataset.
        final_data["nc"] = data["val"][0]["nc"]
        final_data["names"] = data["val"][0]["names"]
        self.data = final_data
        return final_data["train"], final_data["val"][0]

    def plot_training_labels(self):
        """Label plotting is intentionally disabled for this trainer."""
        pass

    def final_eval(self):
        """Run the final evaluation, pointing the validator at the val split's yolo_data source."""
        val = self.args.data["val"]["yolo_data"][0]
        self.validator.args.data = val
        # LVIS validates on its 'minival' split; everything else on 'val'.
        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
        return super().final_eval()
|
||||
Loading…
Add table
Add a link
Reference in a new issue