Update .pre-commit-config.yaml (#1026)

This commit is contained in:
Glenn Jocher 2023-02-17 22:26:40 +01:00 committed by GitHub
parent 9047d737f4
commit edd3ff1669
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
76 changed files with 928 additions and 935 deletions

View file

@ -160,7 +160,7 @@ class BaseModel(nn.Module):
weights (str): The weights to load into the model.
"""
# Force all tasks to implement this function
raise NotImplementedError("This function needs to be implemented by derived classes!")
raise NotImplementedError('This function needs to be implemented by derived classes!')
class DetectionModel(BaseModel):
@ -249,7 +249,7 @@ class SegmentationModel(DetectionModel):
super().__init__(cfg, ch, nc, verbose)
def _forward_augment(self, x):
raise NotImplementedError("WARNING ⚠️ SegmentationModel has not supported augment inference yet!")
raise NotImplementedError('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')
class ClassificationModel(BaseModel):
@ -292,7 +292,7 @@ class ClassificationModel(BaseModel):
self.info()
def load(self, weights):
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
csd = model.float().state_dict()
csd = intersect_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load
@ -341,10 +341,10 @@ def torch_safe_load(weight):
return torch.load(file, map_location='cpu') # load
except ModuleNotFoundError as e:
if e.name == 'omegaconf': # e.name is missing module name
LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
f"\nRecommend fixes are to train a new model using updated ultralytics package or to "
f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
LOGGER.warning(f'WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements.'
f'\nAutoInstall will run now for {e.name} but this feature will be removed in the future.'
f'\nRecommend fixes are to train a new model using updated ultralytics package or to '
f'download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0')
if e.name != 'models':
check_requirements(e.name) # install missing module
return torch.load(file, map_location='cpu') # load
@ -489,13 +489,13 @@ def guess_model_task(model):
def cfg2task(cfg):
# Guess from YAML dictionary
m = cfg["head"][-1][-2].lower() # output module name
if m in ["classify", "classifier", "cls", "fc"]:
return "classify"
if m in ["detect"]:
return "detect"
if m in ["segment"]:
return "segment"
m = cfg['head'][-1][-2].lower() # output module name
if m in ['classify', 'classifier', 'cls', 'fc']:
return 'classify'
if m in ['detect']:
return 'detect'
if m in ['segment']:
return 'segment'
# Guess from model cfg
if isinstance(model, dict):
@ -513,22 +513,22 @@ def guess_model_task(model):
for m in model.modules():
if isinstance(m, Detect):
return "detect"
return 'detect'
elif isinstance(m, Segment):
return "segment"
return 'segment'
elif isinstance(m, Classify):
return "classify"
return 'classify'
# Guess from model filename
if isinstance(model, (str, Path)):
model = Path(model).stem
if '-seg' in model:
return "segment"
return 'segment'
elif '-cls' in model:
return "classify"
return 'classify'
else:
return "detect"
return 'detect'
# Unable to determine task from model
raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, '
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")