Add docformatter to pre-commit (#5279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
This commit is contained in:
parent
c7aa83da31
commit
7517667a33
90 changed files with 1396 additions and 497 deletions
|
|
@ -25,12 +25,13 @@ from .val import NASValidator
|
|||
class NAS(Model):
|
||||
|
||||
def __init__(self, model='yolo_nas_s.pt') -> None:
|
||||
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
|
||||
assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
|
||||
super().__init__(model, task='detect')
|
||||
|
||||
@smart_inference_mode()
|
||||
def _load(self, weights: str, task: str):
|
||||
# Load or create new NAS model
|
||||
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
|
||||
import super_gradients
|
||||
suffix = Path(weights).suffix
|
||||
if suffix == '.pt':
|
||||
|
|
@ -58,4 +59,5 @@ class NAS(Model):
|
|||
|
||||
@property
|
||||
def task_map(self):
|
||||
"""Returns a dictionary mapping tasks to respective predictor and validator classes."""
|
||||
return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue