Start Multi-OS CI (#172)
This commit is contained in:
parent
202f7bffa3
commit
f80ff923e7
5 changed files with 26 additions and 25 deletions
|
|
@ -72,7 +72,8 @@ class ClassificationTrainer(BaseTrainer):
|
|||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size if mode == "train" else (batch_size * 2),
|
||||
augment=mode == "train",
|
||||
rank=rank)
|
||||
rank=rank,
|
||||
workers=self.args.workers)
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
batch["img"] = batch["img"].to(self.device)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,10 @@ class ClassificationValidator(BaseValidator):
|
|||
return self.metrics.results_dict
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
return build_classification_dataloader(path=dataset_path, imgsz=self.args.imgsz, batch_size=batch_size)
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size,
|
||||
workers=self.args.workers)
|
||||
|
||||
def print_results(self):
|
||||
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue