General trainer cleanup (#147)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia 2023-01-07 19:25:48 +05:30 committed by GitHub
parent f8a13c49a0
commit 0e5a7ae623
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 196 additions and 60 deletions

View file

@ -18,10 +18,15 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
# BaseTrainer python usage
class SegmentationTrainer(v8.detect.DetectionTrainer):
def load_model(self, model_cfg=None, weights=None, verbose=True):
model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
overrides["task"] = "segment"
super().__init__(config, overrides)
def get_model(self, cfg=None, weights=None, verbose=True):
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
if weights:
model.load(weights, verbose)
model.load(weights)
return model
def get_validator(self):