README and Docs updates with A100 TensorRT times (#270)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-01-11 21:54:41 +01:00 committed by GitHub
parent 216cf2ddb6
commit e18ae9d8e1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 250 additions and 241 deletions

View file

@ -17,35 +17,36 @@ from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_wei
class BaseModel(nn.Module):
'''
The BaseModel class is a base class for all the models in the Ultralytics YOLO family.
'''
"""
The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.
"""
def forward(self, x, profile=False, visualize=False):
"""
> `forward` is a wrapper for `_forward_once` that runs the model on a single scale
Forward pass of the model on a single scale.
Wrapper for `_forward_once` method.
Args:
x: the input image
profile: whether to profile the model. Defaults to False
visualize: if True, will return the intermediate feature maps. Defaults to False
x (torch.Tensor): The input image tensor
profile (bool): Whether to profile the model, defaults to False
visualize (bool): Whether to return the intermediate feature maps, defaults to False
Returns:
The output of the network.
(torch.Tensor): The output of the network.
"""
return self._forward_once(x, profile, visualize)
def _forward_once(self, x, profile=False, visualize=False):
"""
> Forward pass of the network
Perform a forward pass through the network.
Args:
x: input to the model
profile: if True, the time taken for each layer will be printed. Defaults to False
visualize: If True, it will save the feature maps of the model. Defaults to False
x (torch.Tensor): The input tensor to the model
profile (bool): Print the computation time of each layer if True, defaults to False.
visualize (bool): Save the feature maps of the model if True, defaults to False
Returns:
The last layer of the model.
(torch.Tensor): The last output of the model.
"""
y, dt = [], [] # outputs
for m in self.model:
@ -62,13 +63,15 @@ class BaseModel(nn.Module):
def _profile_one_layer(self, m, x, dt):
"""
It takes a model, an input, and a list of times, and it profiles the model on the input, appending
the time to the list
Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to the provided list.
Args:
m: the model
x: the input image
dt: list of time taken for each layer
m (nn.Module): The layer to be profiled.
x (torch.Tensor): The input data to the layer.
dt (list): A list to store the computation time of the layer.
Returns:
None
"""
c = m == self.model[-1] # is final layer, copy input as inplace fix
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
@ -84,10 +87,10 @@ class BaseModel(nn.Module):
def fuse(self):
"""
> It takes a model and fuses the Conv2d() and BatchNorm2d() layers into a single layer
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the computation efficiency.
Returns:
The model is being returned.
(nn.Module): The fused model is returned.
"""
LOGGER.info('Fusing layers... ')
for m in self.model.modules():
@ -103,8 +106,8 @@ class BaseModel(nn.Module):
Prints model information
Args:
verbose: if True, prints out the model information. Defaults to False
imgsz: the size of the image that the model will be trained on. Defaults to 640
verbose (bool): if True, prints out the model information. Defaults to False
imgsz (int): the size of the image that the model will be trained on. Defaults to 640
"""
model_info(self, verbose, imgsz)
@ -129,10 +132,10 @@ class BaseModel(nn.Module):
def load(self, weights):
"""
> This function loads the weights of the model from a file
This function loads the weights of the model from a file
Args:
weights: The weights to load into the model.
weights (str): The weights to load into the model.
"""
# Force all tasks to implement this function
raise NotImplementedError("This function needs to be implemented by derived classes!")