General refactoring and improvements (#373)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-01-15 14:44:25 +01:00 committed by GitHub
parent ac628c0d3e
commit 583eac0e80
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 304 additions and 309 deletions

View file

@ -57,7 +57,7 @@ class BaseModel(nn.Module):
x = m(x) # run
y.append(x if m.i in self.save else None) # save output
if visualize:
pass
LOGGER.info('visualize feature not yet supported')
# TODO: feature_visualization(x, m.type, m.i, save_dir=visualize)
return x
@ -106,8 +106,8 @@ class BaseModel(nn.Module):
Prints model information
Args:
verbose (bool): if True, prints out the model information. Defaults to False
imgsz (int): the size of the image that the model will be trained on. Defaults to 640
verbose (bool): if True, prints out the model information. Defaults to False
imgsz (int): the size of the image that the model will be trained on. Defaults to 640
"""
model_info(self, verbose, imgsz)
@ -117,10 +117,10 @@ class BaseModel(nn.Module):
parameters or registered buffers
Args:
fn: the function to apply to the model
fn: the function to apply to the model
Returns:
A model that is a Detect() object.
A model that is a Detect() object.
"""
self = super()._apply(fn)
m = self.model[-1] # Detect()
@ -135,7 +135,7 @@ class BaseModel(nn.Module):
This function loads the weights of the model from a file
Args:
weights (str): The weights to load into the model.
weights (str): The weights to load into the model.
"""
# Force all tasks to implement this function
raise NotImplementedError("This function needs to be implemented by derived classes!")