Add docformatter to pre-commit (#5279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
parent c7aa83da31
commit 7517667a33
90 changed files with 1396 additions and 497 deletions
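
A note on the pattern below: most hunks in this diff are mechanical docformatter output plus hand-added docstrings from the same PR. docformatter normalizes docstrings toward PEP 257, which is why one-line summaries gain a trailing period and a capitalized first word, multi-line summaries are re-wrapped to the line length limit, and a blank line is inserted between a summary and its body. A minimal before/after sketch of the rewrite (the function here is hypothetical, not from the diff):

def area(w, h):
    """compute the area of a rectangle
    given its width and height"""
    return w * h

# after docformatter:
def area(w, h):
    """Compute the area of a rectangle given its width and height."""
    return w * h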
ultralytics/models/fastsam/model.py
@@ -22,7 +22,7 @@ class FastSAM(Model):
     """
 
     def __init__(self, model='FastSAM-x.pt'):
-        """Call the __init__ method of the parent class (YOLO) with the updated default model"""
+        """Call the __init__ method of the parent class (YOLO) with the updated default model."""
         if str(model) == 'FastSAM.pt':
             model = 'FastSAM-x.pt'
         assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
@@ -30,4 +30,5 @@ class FastSAM(Model):
 
     @property
     def task_map(self):
+        """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
         return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
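
For orientation, a minimal usage sketch of the class this file defines (not part of the diff; assumes the FastSAM-x.pt weights are available or downloadable):

from ultralytics.models.fastsam import FastSAM

model = FastSAM('FastSAM-x.pt')   # 'FastSAM.pt' is transparently swapped for 'FastSAM-x.pt'
results = model('bus.jpg')        # dispatches to FastSAMPredictor via task_map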
ultralytics/models/fastsam/predict.py
@@ -11,10 +11,12 @@ from ultralytics.utils import DEFAULT_CFG, ops
 class FastSAMPredictor(DetectionPredictor):
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes FastSAMPredictor class by inheriting from DetectionPredictor and setting task to 'segment'."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = 'segment'
 
     def postprocess(self, preds, img, orig_imgs):
+        """Postprocesses the predictions, applies non-max suppression, scales the boxes, and returns the results."""
         p = ops.non_max_suppression(
             preds[0],
             self.args.conf,
ultralytics/models/fastsam/prompt.py
@@ -15,6 +15,7 @@ from ultralytics.utils import TQDM
 class FastSAMPrompt:
 
     def __init__(self, source, results, device='cuda') -> None:
+        """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
         self.device = device
         self.results = results
         self.source = source
@@ -30,6 +31,7 @@ class FastSAMPrompt:
 
     @staticmethod
     def _segment_image(image, bbox):
+        """Segments the given image according to the provided bounding box coordinates."""
         image_array = np.array(image)
         segmented_image_array = np.zeros_like(image_array)
         x1, y1, x2, y2 = bbox
@@ -45,6 +47,9 @@ class FastSAMPrompt:
 
     @staticmethod
     def _format_results(result, filter=0):
+        """Formats detection results into list of annotations each containing ID, segmentation, bounding box, score and
+        area.
+        """
         annotations = []
         n = len(result.masks.data) if result.masks is not None else 0
         for i in range(n):
@@ -61,6 +66,9 @@ class FastSAMPrompt:
 
     @staticmethod
     def _get_bbox_from_mask(mask):
+        """Applies morphological transformations to the mask, displays it, and if with_contours is True, draws
+        contours.
+        """
         mask = mask.astype(np.uint8)
         contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         x1, y1, w, h = cv2.boundingRect(contours[0])
@@ -195,6 +203,7 @@ class FastSAMPrompt:
 
     @torch.no_grad()
     def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
+        """Processes images and text with a model, calculates similarity, and returns softmax score."""
         preprocessed_images = [preprocess(image).to(device) for image in elements]
         tokenized_text = self.clip.tokenize([search_text]).to(device)
         stacked_images = torch.stack(preprocessed_images)
@@ -206,6 +215,7 @@ class FastSAMPrompt:
         return probs[:, 0].softmax(dim=0)
 
     def _crop_image(self, format_results):
+        """Crops an image based on provided annotation format and returns cropped images and related data."""
         if os.path.isdir(self.source):
             raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
         image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
@@ -229,6 +239,7 @@ class FastSAMPrompt:
         return cropped_boxes, cropped_images, not_crop, filter_id, annotations
 
     def box_prompt(self, bbox):
+        """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
         if self.results[0].masks is not None:
             assert (bbox[2] != 0 and bbox[3] != 0)
             if os.path.isdir(self.source):
@@ -261,7 +272,8 @@ class FastSAMPrompt:
             self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
         return self.results
 
-    def point_prompt(self, points, pointlabel):  # numpy 处理
+    def point_prompt(self, points, pointlabel):  # numpy
+        """Adjusts points on detected masks based on user input and returns the modified results."""
         if self.results[0].masks is not None:
             if os.path.isdir(self.source):
                 raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
@@ -284,6 +296,7 @@ class FastSAMPrompt:
         return self.results
 
     def text_prompt(self, text):
+        """Processes a text prompt, applies it to existing results and returns the updated results."""
         if self.results[0].masks is not None:
             format_results = self._format_results(self.results[0], 0)
             cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
@@ -296,4 +309,5 @@ class FastSAMPrompt:
         return self.results
 
     def everything_prompt(self):
+        """Returns the processed results from the previous methods in the class."""
         return self.results
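
A sketch of how these prompt methods are meant to chain after a prediction (the image path and box coordinates are placeholders):

from ultralytics.models.fastsam import FastSAM, FastSAMPrompt

everything_results = FastSAM('FastSAM-x.pt')('bus.jpg')
prompt_process = FastSAMPrompt('bus.jpg', everything_results, device='cpu')
ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])  # or point_prompt(...), text_prompt(...), everything_prompt()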
ultralytics/models/nas/model.py
@@ -25,12 +25,13 @@ from .val import NASValidator
 class NAS(Model):
 
     def __init__(self, model='yolo_nas_s.pt') -> None:
+        """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
         assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
         super().__init__(model, task='detect')
 
     @smart_inference_mode()
     def _load(self, weights: str, task: str):
-        # Load or create new NAS model
+        """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
         import super_gradients
         suffix = Path(weights).suffix
         if suffix == '.pt':
@@ -58,4 +59,5 @@ class NAS(Model):
 
     @property
     def task_map(self):
+        """Returns a dictionary mapping tasks to respective predictor and validator classes."""
         return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
ultralytics/models/rtdetr/model.py
@@ -1,7 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-"""
-RT-DETR model interface
-"""
+"""RT-DETR model interface."""
 from ultralytics.engine.model import Model
 from ultralytics.nn.tasks import RTDETRDetectionModel
 
@@ -11,17 +9,17 @@ from .val import RTDETRValidator
 
 
 class RTDETR(Model):
-    """
-    RTDETR model interface.
-    """
+    """RTDETR model interface."""
 
     def __init__(self, model='rtdetr-l.pt') -> None:
+        """Initializes the RTDETR model with the given model file, defaulting to 'rtdetr-l.pt'."""
         if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
             raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
         super().__init__(model=model, task='detect')
 
     @property
     def task_map(self):
+        """Returns a dictionary mapping task names to corresponding Ultralytics task classes for RTDETR model."""
         return {
             'detect': {
                 'predictor': RTDETRPredictor,
ultralytics/models/rtdetr/predict.py
@@ -48,7 +48,8 @@ class RTDETRPredictor(BasePredictor):
         return results
 
     def pre_transform(self, im):
-        """Pre-transform input image before inference.
+        """
+        Pre-transform input image before inference.
 
         Args:
             im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
ultralytics/models/rtdetr/train.py
@@ -37,7 +37,8 @@ class RTDETRTrainer(DetectionTrainer):
         return model
 
     def build_dataset(self, img_path, mode='val', batch=None):
-        """Build RTDETR Dataset
+        """
+        Build RTDETR Dataset.
 
         Args:
             img_path (str): Path to the folder containing images.
ultralytics/models/rtdetr/val.py
@@ -16,6 +16,7 @@ __all__ = 'RTDETRValidator',  # tuple or list
 class RTDETRDataset(YOLODataset):
 
     def __init__(self, *args, data=None, **kwargs):
+        """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
         super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
 
     # NOTE: add stretch version load_image for rtdetr mosaic
ultralytics/models/sam/amg.py
@@ -32,9 +32,10 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
 
 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
     """
-    Computes the stability score for a batch of masks. The stability
-    score is the IoU between the binary masks obtained by thresholding
-    the predicted mask logits at high and low values.
+    Computes the stability score for a batch of masks.
+
+    The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
+    and low values.
     """
     # One mask is always contained inside the other.
    # Save memory by preventing unnecessary cast to torch.int64
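
Because the high-threshold mask is always contained in the low-threshold mask (as the retained comment notes), the IoU here reduces to a ratio of areas. A standalone sketch of the computation:

import torch

def stability_score_sketch(mask_logits: torch.Tensor, mask_threshold: float = 0.0, offset: float = 1.0) -> torch.Tensor:
    # Area of the mask binarized at the high threshold over area at the low threshold
    intersections = (mask_logits > (mask_threshold + offset)).sum(dim=(-1, -2), dtype=torch.float32)
    unions = (mask_logits > (mask_threshold - offset)).sum(dim=(-1, -2), dtype=torch.float32)
    return intersections / unions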
@@ -60,7 +61,11 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
 
 def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
                         overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
-    """Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer."""
+    """
+    Generates a list of crop boxes of different sizes.
+
+    Each layer has (2**i)**2 boxes for the ith layer.
+    """
     crop_boxes, layer_idxs = [], []
     im_h, im_w = im_size
     short_side = min(im_h, im_w)
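
The "(2**i)**2 boxes for the ith layer" rule is easy to sanity-check: layer 0 is the single full-image box, layer 1 splits each side in two for 4 crops, layer 2 yields 16, and so on:

n_layers = 2
print([(2 ** i) ** 2 for i in range(n_layers + 1)])  # [1, 4, 16] -> 21 crop boxes in total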
@@ -145,8 +150,9 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
 
 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
     """
-    Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
-    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+    Calculates boxes in XYXY format around masks.
+
+    Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
     """
     # torch.max below raises an error on empty inputs, just skip in this case
     if torch.numel(masks) == 0:
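
A simplified single-mask version of what batched_mask_to_box computes; the real function vectorizes this over arbitrary leading dimensions with a max/argmax trick:

import torch

def mask_to_xyxy(mask: torch.Tensor) -> torch.Tensor:
    # mask: (H, W) boolean tensor; empty masks map to [0, 0, 0, 0]
    if not mask.any():
        return torch.zeros(4)
    ys, xs = torch.where(mask)
    return torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).float()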
ultralytics/models/sam/model.py
@@ -1,7 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-"""
-SAM model interface
-"""
+"""SAM model interface."""
 
 from pathlib import Path
 
@@ -13,16 +11,16 @@ from .predict import Predictor
 
 
 class SAM(Model):
-    """
-    SAM model interface.
-    """
+    """SAM model interface."""
 
     def __init__(self, model='sam_b.pt') -> None:
+        """Initializes the SAM model instance with the specified pre-trained model file."""
         if model and Path(model).suffix not in ('.pt', '.pth'):
             raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
         super().__init__(model=model, task='segment')
 
     def _load(self, weights: str, task=None):
+        """Loads the provided weights into the SAM model."""
         self.model = build_sam(weights)
 
     def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
@@ -48,4 +46,5 @@ class SAM(Model):
 
     @property
     def task_map(self):
+        """Returns a dictionary mapping the 'segment' task to its corresponding 'Predictor'."""
         return {'segment': {'predictor': Predictor}}
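
A minimal usage sketch for the SAM interface above, matching the predict signature shown in the diff (prompt coordinates are placeholders):

from ultralytics.models.sam import SAM

model = SAM('sam_b.pt')
results = model.predict('bus.jpg', bboxes=[100, 100, 400, 500])    # box prompt
results = model.predict('bus.jpg', points=[300, 400], labels=[1])  # foreground point prompt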
ultralytics/models/sam/modules/decoders.py
@@ -98,7 +98,11 @@ class MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Predicts masks. See 'forward' for more details."""
+        """
+        Predicts masks.
+
+        See 'forward' for more details.
+        """
         # Concatenate output tokens
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
         output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
ultralytics/models/sam/modules/encoders.py
@@ -100,6 +100,9 @@ class ImageEncoderViT(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Processes input through patch embedding, applies positional embedding if present, and passes through blocks
+        and neck.
+        """
         x = self.patch_embed(x)
         if self.pos_embed is not None:
             x = x + self.pos_embed
@@ -157,8 +160,8 @@ class PromptEncoder(nn.Module):
 
     def get_dense_pe(self) -> torch.Tensor:
         """
-        Returns the positional encoding used to encode point prompts,
-        applied to a dense set of points the shape of the image encoding.
+        Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
+        image encoding.
 
         Returns:
             torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
@@ -204,9 +207,7 @@ class PromptEncoder(nn.Module):
         boxes: Optional[torch.Tensor],
         masks: Optional[torch.Tensor],
     ) -> int:
-        """
-        Gets the batch size of the output given the batch size of the input prompts.
-        """
+        """Gets the batch size of the output given the batch size of the input prompts."""
         if points is not None:
             return points[0].shape[0]
         elif boxes is not None:
@@ -217,6 +218,7 @@ class PromptEncoder(nn.Module):
         return 1
 
     def _get_device(self) -> torch.device:
+        """Returns the device of the first point embedding's weight tensor."""
         return self.point_embeddings[0].weight.device
 
     def forward(
@@ -259,11 +261,10 @@ class PromptEncoder(nn.Module):
 
 
 class PositionEmbeddingRandom(nn.Module):
-    """
-    Positional encoding using random spatial frequencies.
-    """
+    """Positional encoding using random spatial frequencies."""
 
     def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        """Initializes a position embedding using random spatial frequencies."""
         super().__init__()
         if scale is None or scale <= 0.0:
             scale = 1.0
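
"Random spatial frequencies" here refers to random Fourier features: coordinates are projected through a fixed random Gaussian matrix and passed through sin and cos. A compact sketch of the idea (dimensions are illustrative, not the module's exact code):

import torch

num_pos_feats, scale = 64, 1.0
gaussian_matrix = scale * torch.randn(2, num_pos_feats)      # fixed at init time
coords = torch.rand(10, 2) * 2 - 1                           # points normalized to [-1, 1]
proj = 2 * torch.pi * coords @ gaussian_matrix
pe = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)   # (10, 128) embedding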
@@ -304,7 +305,7 @@ class PositionEmbeddingRandom(nn.Module):
 
 
 class Block(nn.Module):
-    """Transformer blocks with support of window attention and residual propagation blocks"""
+    """Transformer blocks with support of window attention and residual propagation blocks."""
 
     def __init__(
         self,
@@ -351,6 +352,7 @@ class Block(nn.Module):
         self.window_size = window_size
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
         shortcut = x
         x = self.norm1(x)
         # Window partition
@@ -404,6 +406,7 @@ class Attention(nn.Module):
             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
         B, H, W, _ = x.shape
         # qkv with shape (3, B, nHead, H * W, C)
         qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -448,6 +451,7 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
                        hw: Tuple[int, int]) -> torch.Tensor:
     """
     Window unpartition into original sequences and removing padding.
+
     Args:
         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
         window_size (int): window size.
@@ -540,9 +544,7 @@ def add_decomposed_rel_pos(
 
 
 class PatchEmbed(nn.Module):
-    """
-    Image to Patch Embedding.
-    """
+    """Image to Patch Embedding."""
 
     def __init__(
         self,
@@ -565,4 +567,5 @@ class PatchEmbed(nn.Module):
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Computes patch embedding by applying convolution and transposing resulting tensor."""
         return self.proj(x).permute(0, 2, 3, 1)  # B C H W -> B H W C
ultralytics/models/sam/modules/tiny_encoder.py
@@ -23,6 +23,9 @@ from ultralytics.utils.instance import to_2tuple
 class Conv2d_BN(torch.nn.Sequential):
 
     def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
+        """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
+        drop path.
+        """
         super().__init__()
         self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
         bn = torch.nn.BatchNorm2d(b)
@@ -34,6 +37,9 @@ class Conv2d_BN(torch.nn.Sequential):
 class PatchEmbed(nn.Module):
 
     def __init__(self, in_chans, embed_dim, resolution, activation):
+        """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
+        function.
+        """
         super().__init__()
         img_size: Tuple[int, int] = to_2tuple(resolution)
         self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
@@ -48,12 +54,16 @@ class PatchEmbed(nn.Module):
         )
 
     def forward(self, x):
+        """Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
         return self.seq(x)
 
 
 class MBConv(nn.Module):
 
     def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
+        """Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
+        function.
+        """
         super().__init__()
         self.in_chans = in_chans
         self.hidden_chans = int(in_chans * expand_ratio)
@@ -73,6 +83,7 @@ class MBConv(nn.Module):
             self.drop_path = nn.Identity()
 
     def forward(self, x):
+        """Implements the forward pass for the model architecture."""
         shortcut = x
         x = self.conv1(x)
         x = self.act1(x)
@@ -87,6 +98,9 @@ class MBConv(nn.Module):
 class PatchMerging(nn.Module):
 
     def __init__(self, input_resolution, dim, out_dim, activation):
+        """Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
+        optional parameters.
+        """
         super().__init__()
 
         self.input_resolution = input_resolution
@@ -99,6 +113,7 @@ class PatchMerging(nn.Module):
         self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
 
     def forward(self, x):
+        """Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
         if x.ndim == 3:
             H, W = self.input_resolution
             B = len(x)
@@ -149,6 +164,7 @@ class ConvLayer(nn.Module):
                 input_resolution, dim=dim, out_dim=out_dim, activation=activation)
 
     def forward(self, x):
+        """Processes the input through a series of convolutional layers and returns the activated output."""
         for blk in self.blocks:
             x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
         return x if self.downsample is None else self.downsample(x)
@@ -157,6 +173,7 @@ class ConvLayer(nn.Module):
 class Mlp(nn.Module):
 
     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
@@ -167,6 +184,7 @@ class Mlp(nn.Module):
         self.drop = nn.Dropout(drop)
 
     def forward(self, x):
+        """Applies operations on input x and returns modified x, runs downsample if not None."""
         x = self.norm(x)
         x = self.fc1(x)
         x = self.act(x)
@@ -216,6 +234,7 @@ class Attention(torch.nn.Module):
 
     @torch.no_grad()
     def train(self, mode=True):
+        """Sets the module in training mode and handles attribute 'ab' based on the mode."""
         super().train(mode)
         if mode and hasattr(self, 'ab'):
             del self.ab
@@ -298,6 +317,9 @@ class TinyViTBlock(nn.Module):
         self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
 
     def forward(self, x):
+        """Applies attention-based transformation or padding to input 'x' before passing it through a local
+        convolution.
+        """
         H, W = self.input_resolution
         B, L, C = x.shape
         assert L == H * W, 'input feature has wrong size'
@@ -337,6 +359,9 @@ class TinyViTBlock(nn.Module):
         return x + self.drop_path(self.mlp(x))
 
     def extra_repr(self) -> str:
+        """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
+        attentions heads, window size, and MLP ratio.
+        """
         return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
                f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
 
@@ -402,23 +427,28 @@ class BasicLayer(nn.Module):
                 input_resolution, dim=dim, out_dim=out_dim, activation=activation)
 
     def forward(self, x):
+        """Performs forward propagation on the input tensor and returns a normalized tensor."""
         for blk in self.blocks:
             x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
         return x if self.downsample is None else self.downsample(x)
 
     def extra_repr(self) -> str:
+        """Returns a string representation of the extra_repr function with the layer's parameters."""
         return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
 
 
 class LayerNorm2d(nn.Module):
+    """A PyTorch implementation of Layer Normalization in 2D."""
 
     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        """Initialize LayerNorm2d with the number of channels and an optional epsilon."""
         super().__init__()
         self.weight = nn.Parameter(torch.ones(num_channels))
         self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps = eps
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform a forward pass, normalizing the input tensor."""
         u = x.mean(1, keepdim=True)
         s = (x - u).pow(2).mean(1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.eps)
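
LayerNorm2d above normalizes over the channel dimension at every spatial location, so it should agree with nn.LayerNorm applied to a channels-last view. A quick check of that reading (shapes are illustrative):

import torch
import torch.nn as nn

x = torch.randn(2, 8, 4, 4)  # (N, C, H, W)
ref = nn.LayerNorm(8, eps=1e-6, elementwise_affine=False)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
out = (x - u) / torch.sqrt(s + 1e-6)
print(torch.allclose(out, ref, atol=1e-5))  # True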
@@ -518,6 +548,7 @@ class TinyViT(nn.Module):
         )
 
     def set_layer_lr_decay(self, layer_lr_decay):
+        """Sets the learning rate decay for each layer in the TinyViT model."""
         decay_rate = layer_lr_decay
 
         # layers -> blocks (depth)
@@ -525,6 +556,7 @@ class TinyViT(nn.Module):
         lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
 
         def _set_lr_scale(m, scale):
+            """Sets the learning rate scale for each layer in the model based on the layer's depth."""
             for p in m.parameters():
                 p.lr_scale = scale
 
@@ -544,12 +576,14 @@ class TinyViT(nn.Module):
             p.param_name = k
 
         def _check_lr_scale(m):
+            """Checks if the learning rate scale attribute is present in module's parameters."""
             for p in m.parameters():
                 assert hasattr(p, 'lr_scale'), p.param_name
 
         self.apply(_check_lr_scale)
 
     def _init_weights(self, m):
+        """Initializes weights for linear layers and layer normalization in the given module."""
         if isinstance(m, nn.Linear):
             # NOTE: This initialization is needed only for training.
             # trunc_normal_(m.weight, std=.02)
@@ -561,11 +595,12 @@ class TinyViT(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay_keywords(self):
+        """Returns a dictionary of parameter names where weight decay should not be applied."""
         return {'attention_biases'}
 
     def forward_features(self, x):
-        # x: (N, C, H, W)
-        x = self.patch_embed(x)
+        """Runs the input through the model layers and returns the transformed output."""
+        x = self.patch_embed(x)  # x input is (N, C, H, W)
 
         x = self.layers[0](x)
         start_i = 1
@@ -579,4 +614,5 @@ class TinyViT(nn.Module):
         return self.neck(x)
 
     def forward(self, x):
+        """Executes a forward pass on the input tensor through the constructed model layers."""
         return self.forward_features(x)
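
The set_layer_lr_decay logic above gives the earliest blocks the smallest learning-rate scale; block i out of depth blocks is scaled by decay_rate ** (depth - i - 1). A numeric illustration:

decay_rate, depth = 0.8, 4
print([decay_rate ** (depth - i - 1) for i in range(depth)])  # ~[0.512, 0.64, 0.8, 1.0]; the last block trains at the full rate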
ultralytics/models/sam/modules/transformer.py
@@ -21,8 +21,7 @@ class TwoWayTransformer(nn.Module):
         attention_downsample_rate: int = 2,
     ) -> None:
         """
-        A transformer decoder that attends to an input image using
-        queries whose positional embedding is supplied.
+        A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
 
         Args:
             depth (int): number of layers in the transformer
@@ -171,8 +170,7 @@ class TwoWayAttentionBlock(nn.Module):
 
 
 class Attention(nn.Module):
-    """
-    An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+    """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
     values.
     """
 
ultralytics/models/sam/predict.py
@@ -19,6 +19,7 @@ from .build import build_sam
 class Predictor(BasePredictor):
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes the Predictor class with default or provided configuration, overrides, and callbacks."""
         if overrides is None:
             overrides = {}
         overrides.update(dict(task='segment', mode='predict', imgsz=1024))
@@ -34,7 +35,8 @@ class Predictor(BasePredictor):
         self.segment_all = False
 
     def preprocess(self, im):
-        """Prepares input image before inference.
+        """
+        Prepares input image before inference.
 
         Args:
             im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
@@ -189,7 +191,8 @@ class Predictor(BasePredictor):
                  stability_score_thresh=0.95,
                  stability_score_offset=0.95,
                  crop_nms_thresh=0.7):
-        """Segment the whole image.
+        """
+        Segment the whole image.
 
         Args:
             im (torch.Tensor): The preprocessed image, (N, C, H, W).
@@ -360,14 +363,15 @@ class Predictor(BasePredictor):
         self.prompts = prompts
 
     def reset_image(self):
+        """Resets the image and its features to None."""
         self.im = None
         self.features = None
 
     @staticmethod
     def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
         """
-        Removes small disconnected regions and holes in masks, then reruns
-        box NMS to remove any new duplicates. Requires open-cv as a dependency.
+        Removes small disconnected regions and holes in masks, then reruns box NMS to remove any new duplicates.
+        Requires open-cv as a dependency.
 
         Args:
             masks (torch.Tensor): Masks, (N, H, W).
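
A simplified sketch of the per-mask hole filling that remove_small_regions performs, using OpenCV connected components (the function name is hypothetical; the real method also removes small islands and reruns box NMS across the batch afterwards):

import cv2
import numpy as np

def fill_small_holes(mask: np.ndarray, area_thresh: float) -> np.ndarray:
    # Treat background regions (holes) as foreground and measure their areas
    holes = (~mask.astype(bool)).astype(np.uint8)
    n, regions, stats, _ = cv2.connectedComponentsWithStats(holes, 8)
    small = [i for i in range(1, n) if stats[i, cv2.CC_STAT_AREA] < area_thresh]
    return mask.astype(bool) | np.isin(regions, small)  # fill every small hole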
ultralytics/models/utils/loss.py
@@ -47,6 +47,7 @@ class DETRLoss(nn.Module):
         self.device = None
 
     def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
+        """Computes the classification loss based on predictions, target values, and ground truth scores."""
         # logits: [b, query, num_classes], gt_class: list[[n, 1]]
         name_class = f'loss_class{postfix}'
         bs, nq = pred_scores.shape[:2]
@@ -68,6 +69,9 @@ class DETRLoss(nn.Module):
         return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
 
     def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
+        """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
+        boxes.
+        """
         # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
         name_bbox = f'loss_bbox{postfix}'
         name_giou = f'loss_giou{postfix}'
@@ -125,7 +129,7 @@ class DETRLoss(nn.Module):
                        postfix='',
                        masks=None,
                        gt_mask=None):
-        """Get auxiliary losses"""
+        """Get auxiliary losses."""
         # NOTE: loss class, bbox, giou, mask, dice
         loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
         if match_indices is None and self.use_uni_match:
@@ -166,12 +170,14 @@ class DETRLoss(nn.Module):
 
     @staticmethod
     def _get_index(match_indices):
+        """Returns batch indices, source indices, and destination indices from provided match indices."""
         batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
         src_idx = torch.cat([src for (src, _) in match_indices])
         dst_idx = torch.cat([dst for (_, dst) in match_indices])
         return (batch_idx, src_idx), dst_idx
 
     def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
+        """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
         pred_assigned = torch.cat([
             t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
             for t, (I, _) in zip(pred_bboxes, match_indices)])
@@ -190,7 +196,7 @@ class DETRLoss(nn.Module):
                    gt_mask=None,
                    postfix='',
                    match_indices=None):
-        """Get losses"""
+        """Get losses."""
         if match_indices is None:
             match_indices = self.matcher(pred_bboxes,
                                          pred_scores,
@@ -250,22 +256,43 @@ class DETRLoss(nn.Module):
 
 
 class RTDETRDetectionLoss(DETRLoss):
+    """
+    Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
+
+    This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
+    an additional denoising training loss when provided with denoising metadata.
+    """
 
     def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
+        """
+        Forward pass to compute the detection loss.
+
+        Args:
+            preds (tuple): Predicted bounding boxes and scores.
+            batch (dict): Batch data containing ground truth information.
+            dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
+            dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
+            dn_meta (dict, optional): Metadata for denoising. Default is None.
+
+        Returns:
+            (dict): Dictionary containing the total loss and, if applicable, the denoising loss.
+        """
         pred_bboxes, pred_scores = preds
         total_loss = super().forward(pred_bboxes, pred_scores, batch)
 
+        # Check for denoising metadata to compute denoising training loss
         if dn_meta is not None:
             dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
             assert len(batch['gt_groups']) == len(dn_pos_idx)
 
-            # Denoising match indices
+            # Get the match indices for denoising
             match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
 
-            # Compute denoising training loss
+            # Compute the denoising training loss
             dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
             total_loss.update(dn_loss)
         else:
+            # If no denoising metadata is provided, set denoising loss to zero
             total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
 
         return total_loss
@@ -276,12 +303,12 @@ class RTDETRDetectionLoss(DETRLoss):
         Get the match indices for denoising.
 
         Args:
-            dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
-            dn_num_group (int): The number of groups of denoising.
-            gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
+            dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
+            dn_num_group (int): Number of denoising groups.
+            gt_groups (List[int]): List of integers representing the number of ground truths for each image.
 
         Returns:
-            dn_match_indices (List(tuple)): Matched indices.
+            (List[tuple]): List of tuples containing matched indices for denoising.
         """
         dn_match_indices = []
         idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
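
A sketch of the indexing get_dn_match_indices builds, with illustrative numbers: every denoising group must match the same per-image ground-truth indices, offset by where that image's ground truths start in the flattened batch:

import torch

gt_groups = [2, 3]                                             # gts per image
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)  # tensor([0, 2]) - per-image offsets
dn_num_group = 2
for i, num_gt in enumerate(gt_groups):
    gt_idx = (torch.arange(num_gt) + idx_groups[i]).repeat(dn_num_group)
    print(gt_idx)  # image 0: tensor([0, 1, 0, 1]); image 1: tensor([2, 3, 4, 2, 3, 4])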
ultralytics/models/utils/ops.py
@@ -11,8 +11,8 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
 
 class HungarianMatcher(nn.Module):
     """
-    A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in
-    an end-to-end fashion.
+    A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an
+    end-to-end fashion.
 
     HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
     function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
@@ -32,6 +32,9 @@ class HungarianMatcher(nn.Module):
     """
 
     def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
+        """Initializes HungarianMatcher with cost coefficients, Focal Loss, mask prediction, sample points, and alpha
+        gamma factors.
+        """
         super().__init__()
         if cost_gain is None:
             cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
@@ -45,8 +48,8 @@ class HungarianMatcher(nn.Module):
     def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
         """
         Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
-        (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching
-        between predictions and ground truth based on these costs.
+        (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching between
+        predictions and ground truth based on these costs.
 
         Args:
             pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
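
The "optimal matching" in this forward pass is the classic Hungarian step, which the matcher delegates to scipy. A standalone sketch of that final step on a random cost matrix (shapes are illustrative):

import torch
from scipy.optimize import linear_sum_assignment

cost = torch.rand(8, 3)  # 8 predicted queries vs 3 ground-truth boxes
rows, cols = linear_sum_assignment(cost.numpy())
# query rows[k] is assigned to ground truth cols[k], minimizing the total cost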
@@ -153,9 +156,9 @@ def get_cdn_group(batch,
                   box_noise_scale=1.0,
                   training=False):
     """
-    Get contrastive denoising training group. This function creates a contrastive denoising training group with
-    positive and negative samples from the ground truths (gt). It applies noise to the class labels and bounding
-    box coordinates, and returns the modified labels, bounding boxes, attention mask and meta information.
+    Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
+    and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
+    and returns the modified labels, bounding boxes, attention mask and meta information.
 
     Args:
         batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
@@ -191,12 +194,12 @@ def get_cdn_group(batch,
     gt_bbox = batch['bboxes']  # bs*num, 4
     b_idx = batch['batch_idx']
 
-    # each group has positive and negative queries.
+    # Each group has positive and negative queries.
     dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
     dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
     dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )
 
-    # positive and negative mask
+    # Positive and negative mask
     # (bs*num*num_group, ), the second total_num*num_group part as negative samples
     neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
 
@@ -220,10 +223,9 @@ def get_cdn_group(batch,
         known_bbox += rand_part * diff
         known_bbox.clip_(min=0.0, max=1.0)
         dn_bbox = xyxy2xywh(known_bbox)
-        dn_bbox = inverse_sigmoid(dn_bbox)
+        dn_bbox = torch.logit(dn_bbox, eps=1e-6)  # inverse sigmoid
 
-    # total denoising queries
-    num_dn = int(max_nums * 2 * num_group)
+    num_dn = int(max_nums * 2 * num_group)  # total denoising queries
     # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
     dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
     padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
@@ -256,9 +258,3 @@ def get_cdn_group(batch,
 
     return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
         class_embed.device), dn_meta
-
-
-def inverse_sigmoid(x, eps=1e-6):
-    """Inverse sigmoid function."""
-    x = x.clip(min=0., max=1.)
-    return torch.log(x / (1 - x + eps) + eps)
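
The swap from the local inverse_sigmoid helper to torch.logit in the hunks above is behavior-preserving away from the 0/1 boundaries, where the two formulas differ only in epsilon placement. A quick numerical check:

import torch

x = torch.rand(1000).clip(0.01, 0.99)
a = torch.logit(x, eps=1e-6)
b = torch.log(x / (1 - x + 1e-6) + 1e-6)  # the removed helper's formula
print(torch.allclose(a, b, atol=1e-3))    # True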
ultralytics/models/yolo/classify/predict.py
@@ -26,6 +26,7 @@ class ClassificationPredictor(BasePredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes ClassificationPredictor setting the task to 'classify'."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = 'classify'
 
ultralytics/models/yolo/classify/train.py
@@ -79,6 +79,7 @@ class ClassificationTrainer(BaseTrainer):
         return ckpt
 
     def build_dataset(self, img_path, mode='train', batch=None):
+        """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
         return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
@@ -113,8 +114,9 @@ class ClassificationTrainer(BaseTrainer):
 
     def label_loss_items(self, loss_items=None, prefix='train'):
         """
-        Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
-        segmentation & detection
+        Returns a loss dict with labelled training loss items tensor.
+
+        Not needed for classification but necessary for segmentation & detection
         """
         keys = [f'{prefix}/{x}' for x in self.loss_names]
         if loss_items is None:
ultralytics/models/yolo/classify/val.py
@@ -78,6 +78,7 @@ class ClassificationValidator(BaseValidator):
         return self.metrics.results_dict
 
     def build_dataset(self, img_path):
+        """Creates and returns a ClassificationDataset instance using given image path and preprocessing parameters."""
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
     def get_dataloader(self, dataset_path, batch_size):
ultralytics/models/yolo/detect/train.py
@@ -57,7 +57,7 @@ class DetectionTrainer(BaseTrainer):
         return batch
 
     def set_model_attributes(self):
-        """nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)."""
+        """Nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)."""
         # self.args.box *= 3 / nl  # scale to layers
         # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
         # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
@@ -80,8 +80,9 @@ class DetectionTrainer(BaseTrainer):
 
     def label_loss_items(self, loss_items=None, prefix='train'):
         """
-        Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
-        segmentation & detection
+        Returns a loss dict with labelled training loss items tensor.
+
+        Not needed for classification but necessary for segmentation & detection
        """
         keys = [f'{prefix}/{x}' for x in self.loss_names]
         if loss_items is not None:
ultralytics/models/yolo/model.py
@@ -6,13 +6,11 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
 
 
 class YOLO(Model):
-    """
-    YOLO (You Only Look Once) object detection model.
-    """
+    """YOLO (You Only Look Once) object detection model."""
 
     @property
     def task_map(self):
-        """Map head to model, trainer, validator, and predictor classes"""
+        """Map head to model, trainer, validator, and predictor classes."""
         return {
             'classify': {
                 'model': ClassificationModel,
ultralytics/models/yolo/pose/predict.py
@@ -21,6 +21,7 @@ class PosePredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = 'pose'
         if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
ultralytics/models/yolo/segment/predict.py
@@ -21,10 +21,12 @@ class SegmentationPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = 'segment'
 
     def postprocess(self, preds, img, orig_imgs):
+        """Applies non-max suppression and processes detections for each image in an input batch."""
         p = ops.non_max_suppression(preds[0],
                                     self.args.conf,
                                     self.args.iou,
ultralytics/models/yolo/segment/val.py
@@ -144,7 +144,7 @@ class SegmentationValidator(DetectionValidator):
 
     def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
         """
-        Return correct prediction matrix
+        Return correct prediction matrix.
 
         Args:
             detections (array[N, 6]), x1, y1, x2, y2, conf, class