Tests and docstrings improvements (#4475)
This commit is contained in:
parent
c659c0fa7b
commit
615ddc9d97
22 changed files with 107 additions and 186 deletions
|
|
@ -97,8 +97,7 @@ class AIFI(TransformerEncoderLayer):
|
|||
out_w = grid_w.flatten()[..., None] @ omega[None]
|
||||
out_h = grid_h.flatten()[..., None] @ omega[None]
|
||||
|
||||
return torch.concat([torch.sin(out_w), torch.cos(out_w),
|
||||
torch.sin(out_h), torch.cos(out_h)], axis=1)[None, :, :]
|
||||
return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]
|
||||
|
||||
|
||||
class TransformerLayer(nn.Module):
|
||||
|
|
@ -170,9 +169,11 @@ class MLP(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
||||
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
||||
class LayerNorm2d(nn.Module):
|
||||
"""
|
||||
LayerNorm2d module from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
|
||||
https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels, eps=1e-6):
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ class DetectionModel(BaseModel):
|
|||
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
||||
if nc and nc != self.yaml['nc']:
|
||||
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
||||
self.yaml['nc'] = nc # override yaml value
|
||||
self.yaml['nc'] = nc # override YAML value
|
||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
||||
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
||||
self.inplace = self.yaml.get('inplace', True)
|
||||
|
|
@ -329,7 +329,7 @@ class ClassificationModel(BaseModel):
|
|||
ch=3,
|
||||
nc=None,
|
||||
cutoff=10,
|
||||
verbose=True): # yaml, model, channels, number of classes, cutoff index, verbose flag
|
||||
verbose=True): # YAML, model, channels, number of classes, cutoff index, verbose flag
|
||||
super().__init__()
|
||||
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
||||
|
|
@ -357,7 +357,7 @@ class ClassificationModel(BaseModel):
|
|||
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
||||
if nc and nc != self.yaml['nc']:
|
||||
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
||||
self.yaml['nc'] = nc # override yaml value
|
||||
self.yaml['nc'] = nc # override YAML value
|
||||
elif not nc and not self.yaml.get('nc', None):
|
||||
raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
|
||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue