Add docstrings and improve comments (#11229)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
ccfc1cf925
commit
d5458f27cd
16 changed files with 34 additions and 17 deletions
|
|
@ -30,7 +30,7 @@ from ultralytics import YOLO
|
|||
|
||||
|
||||
def on_predict_batch_end(predictor):
|
||||
# Retrieve the batch data
|
||||
"""Handle prediction batch end by combining results with corresponding frames; modifies predictor results."""
|
||||
_, image, _, _ = predictor.batch
|
||||
|
||||
# Ensure that image is a list
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from ultralytics.models.yolo.detect import DetectionTrainer
|
|||
|
||||
class CustomTrainer(DetectionTrainer):
|
||||
def get_model(self, cfg, weights):
|
||||
"""Loads a custom detection model given configuration and weight files."""
|
||||
...
|
||||
|
||||
|
||||
|
|
@ -65,16 +66,19 @@ from ultralytics.nn.tasks import DetectionModel
|
|||
|
||||
class MyCustomModel(DetectionModel):
|
||||
def init_criterion(self):
|
||||
"""Initializes the loss function and adds a callback for uploading the model to Google Drive every 10 epochs."""
|
||||
...
|
||||
|
||||
|
||||
class CustomTrainer(DetectionTrainer):
|
||||
def get_model(self, cfg, weights):
|
||||
"""Returns a customized detection model instance configured with specified config and weights."""
|
||||
return MyCustomModel(...)
|
||||
|
||||
|
||||
# callback to upload model weights
|
||||
def log_model(trainer):
|
||||
"""Logs the path of the last model weight used by the trainer."""
|
||||
last_weight_path = trainer.last
|
||||
print(last_weight_path)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue