Model coverage cleanup (#4585)
This commit is contained in:
parent
c635418a27
commit
deac7575b1
12 changed files with 132 additions and 175 deletions
|
|
@ -267,10 +267,11 @@ class PositionEmbeddingRandom(nn.Module):
|
|||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
self.register_buffer(
|
||||
'positional_encoding_gaussian_matrix',
|
||||
scale * torch.randn((2, num_pos_feats)),
|
||||
)
|
||||
self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
|
||||
|
||||
# Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.backends.cudnn.deterministic = False
|
||||
|
||||
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
"""Positionally encode points that are normalized to [0,1]."""
|
||||
|
|
|
|||
|
|
@ -20,12 +20,14 @@ class Sam(nn.Module):
|
|||
mask_threshold: float = 0.0
|
||||
image_format: str = 'RGB'
|
||||
|
||||
def __init__(self,
|
||||
image_encoder: ImageEncoderViT,
|
||||
prompt_encoder: PromptEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
pixel_mean: List[float] = None,
|
||||
pixel_std: List[float] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder: ImageEncoderViT,
|
||||
prompt_encoder: PromptEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
pixel_mean: List[float] = (123.675, 116.28, 103.53),
|
||||
pixel_std: List[float] = (58.395, 57.12, 57.375)
|
||||
) -> None:
|
||||
"""
|
||||
SAM predicts object masks from an image and input prompts.
|
||||
|
||||
|
|
@ -37,10 +39,6 @@ class Sam(nn.Module):
|
|||
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
||||
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
||||
"""
|
||||
if pixel_mean is None:
|
||||
pixel_mean = [123.675, 116.28, 103.53]
|
||||
if pixel_std is None:
|
||||
pixel_std = [58.395, 57.12, 57.375]
|
||||
super().__init__()
|
||||
self.image_encoder = image_encoder
|
||||
self.prompt_encoder = prompt_encoder
|
||||
|
|
|
|||
|
|
@ -30,40 +30,6 @@ class Conv2d_BN(torch.nn.Sequential):
|
|||
torch.nn.init.constant_(bn.bias, 0)
|
||||
self.add_module('bn', bn)
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
c, bn = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
w = c.weight * w[:, None, None, None]
|
||||
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
m = torch.nn.Conv2d(w.size(1) * self.c.groups,
|
||||
w.size(0),
|
||||
w.shape[2:],
|
||||
stride=self.c.stride,
|
||||
padding=self.c.padding,
|
||||
dilation=self.c.dilation,
|
||||
groups=self.c.groups)
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
|
||||
# NOTE: This module and timm package is needed only for training.
|
||||
# from ultralytics.utils.checks import check_requirements
|
||||
# check_requirements('timm')
|
||||
# from timm.models.layers import DropPath as TimmDropPath
|
||||
# from timm.models.layers import trunc_normal_
|
||||
# class DropPath(TimmDropPath):
|
||||
#
|
||||
# def __init__(self, drop_prob=None):
|
||||
# super().__init__(drop_prob=drop_prob)
|
||||
# self.drop_prob = drop_prob
|
||||
#
|
||||
# def __repr__(self):
|
||||
# msg = super().__repr__()
|
||||
# msg += f'(drop_prob={self.drop_prob})'
|
||||
# return msg
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
|
||||
|
|
|
|||
|
|
@ -153,8 +153,7 @@ class Predictor(BasePredictor):
|
|||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
||||
bboxes *= r
|
||||
if masks is not None:
|
||||
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device)
|
||||
masks = masks[:, None, :, :]
|
||||
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||
|
||||
points = (points, labels) if points is not None else None
|
||||
# Embed prompts
|
||||
|
|
@ -257,9 +256,7 @@ class Predictor(BasePredictor):
|
|||
pred_bbox = batched_mask_to_box(pred_mask).float()
|
||||
keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
|
||||
if not torch.all(keep_mask):
|
||||
pred_bbox = pred_bbox[keep_mask]
|
||||
pred_mask = pred_mask[keep_mask]
|
||||
pred_score = pred_score[keep_mask]
|
||||
pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]
|
||||
|
||||
crop_masks.append(pred_mask)
|
||||
crop_bboxes.append(pred_bbox)
|
||||
|
|
@ -288,9 +285,7 @@ class Predictor(BasePredictor):
|
|||
if len(crop_regions) > 1:
|
||||
scores = 1 / region_areas
|
||||
keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
|
||||
pred_masks = pred_masks[keep]
|
||||
pred_bboxes = pred_bboxes[keep]
|
||||
pred_scores = pred_scores[keep]
|
||||
pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]
|
||||
|
||||
return pred_masks, pred_scores, pred_bboxes
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue