Update .pre-commit-config.yaml (#1026)

This commit is contained in:
Glenn Jocher 2023-02-17 22:26:40 +01:00 committed by GitHub
parent 9047d737f4
commit edd3ff1669
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
76 changed files with 928 additions and 935 deletions

View file

@ -20,11 +20,11 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "segment"
overrides['task'] = 'segment'
super().__init__(cfg, overrides)
def get_model(self, cfg=None, weights=None, verbose=True):
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
@ -43,13 +43,13 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
return self.compute_loss(preds, batch)
def plot_training_samples(self, batch, ni):
images = batch["img"]
masks = batch["masks"]
cls = batch["cls"].squeeze(-1)
bboxes = batch["bboxes"]
paths = batch["im_file"]
batch_idx = batch["batch_idx"]
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f"train_batch{ni}.jpg")
images = batch['img']
masks = batch['masks']
cls = batch['cls'].squeeze(-1)
bboxes = batch['bboxes']
paths = batch['im_file']
batch_idx = batch['batch_idx']
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
def plot_metrics(self):
plot_results(file=self.csv, segment=True) # save results.png
@ -80,15 +80,15 @@ class SegLoss(Loss):
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
masks = batch["masks"].to(self.device).float()
masks = batch['masks'].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
@ -135,13 +135,13 @@ class SegLoss(Loss):
def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
# Mask loss for one image
pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
def train(cfg=DEFAULT_CFG, use_python=False):
model = cfg.model or "yolov8n-seg.pt"
data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
@ -153,5 +153,5 @@ def train(cfg=DEFAULT_CFG, use_python=False):
trainer.train()
if __name__ == "__main__":
if __name__ == '__main__':
train()