[Docs]: Link buttons, add autobackend, BaseModel and ops (#130)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
af6e3c536b
commit
8996c5c6cf
10 changed files with 562 additions and 96 deletions
|
|
@ -17,11 +17,36 @@ from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_wei
|
|||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
# YOLOv5 base model
|
||||
'''
|
||||
The BaseModel class is a base class for all the models in the Ultralytics YOLO family.
|
||||
'''
|
||||
|
||||
def forward(self, x, profile=False, visualize=False):
|
||||
return self._forward_once(x, profile, visualize) # single-scale inference, train
|
||||
"""
|
||||
> `forward` is a wrapper for `_forward_once` that runs the model on a single scale
|
||||
|
||||
Args:
|
||||
x: the input image
|
||||
profile: whether to profile the model. Defaults to False
|
||||
visualize: if True, will return the intermediate feature maps. Defaults to False
|
||||
|
||||
Returns:
|
||||
The output of the network.
|
||||
"""
|
||||
return self._forward_once(x, profile, visualize)
|
||||
|
||||
def _forward_once(self, x, profile=False, visualize=False):
|
||||
"""
|
||||
> Forward pass of the network
|
||||
|
||||
Args:
|
||||
x: input to the model
|
||||
profile: if True, the time taken for each layer will be printed. Defaults to False
|
||||
visualize: If True, it will save the feature maps of the model. Defaults to False
|
||||
|
||||
Returns:
|
||||
The last layer of the model.
|
||||
"""
|
||||
y, dt = [], [] # outputs
|
||||
for m in self.model:
|
||||
if m.f != -1: # if not from previous layer
|
||||
|
|
@ -36,6 +61,15 @@ class BaseModel(nn.Module):
|
|||
return x
|
||||
|
||||
def _profile_one_layer(self, m, x, dt):
|
||||
"""
|
||||
It takes a model, an input, and a list of times, and it profiles the model on the input, appending
|
||||
the time to the list
|
||||
|
||||
Args:
|
||||
m: the model
|
||||
x: the input image
|
||||
dt: list of time taken for each layer
|
||||
"""
|
||||
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
||||
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
||||
t = time_sync()
|
||||
|
|
@ -48,7 +82,13 @@ class BaseModel(nn.Module):
|
|||
if c:
|
||||
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
||||
|
||||
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
||||
def fuse(self):
|
||||
"""
|
||||
> It takes a model and fuses the Conv2d() and BatchNorm2d() layers into a single layer
|
||||
|
||||
Returns:
|
||||
The model is being returned.
|
||||
"""
|
||||
LOGGER.info('Fusing layers... ')
|
||||
for m in self.model.modules():
|
||||
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
||||
|
|
@ -58,11 +98,27 @@ class BaseModel(nn.Module):
|
|||
self.info()
|
||||
return self
|
||||
|
||||
def info(self, verbose=False, imgsz=640): # print model information
|
||||
def info(self, verbose=False, imgsz=640):
|
||||
"""
|
||||
Prints model information
|
||||
|
||||
Args:
|
||||
verbose: if True, prints out the model information. Defaults to False
|
||||
imgsz: the size of the image that the model will be trained on. Defaults to 640
|
||||
"""
|
||||
model_info(self, verbose, imgsz)
|
||||
|
||||
def _apply(self, fn):
|
||||
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
||||
"""
|
||||
`_apply()` is a function that applies a function to all the tensors in the model that are not
|
||||
parameters or registered buffers
|
||||
|
||||
Args:
|
||||
fn: the function to apply to the model
|
||||
|
||||
Returns:
|
||||
A model that is a Detect() object.
|
||||
"""
|
||||
self = super()._apply(fn)
|
||||
m = self.model[-1] # Detect()
|
||||
if isinstance(m, (Detect, Segment)):
|
||||
|
|
@ -72,6 +128,12 @@ class BaseModel(nn.Module):
|
|||
return self
|
||||
|
||||
def load(self, weights):
|
||||
"""
|
||||
> This function loads the weights of the model from a file
|
||||
|
||||
Args:
|
||||
weights: The weights to load into the model.
|
||||
"""
|
||||
# Force all tasks to implement this function
|
||||
raise NotImplementedError("This function needs to be implemented by derived classes!")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue