Code Refactor ruff check --fix --extend-select I (#13672)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
c8514a6754
commit
6227d8f8a1
6 changed files with 30 additions and 24 deletions
|
|
@ -293,8 +293,12 @@ class DetectionModel(BaseModel):
|
|||
if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
|
||||
s = 256 # 2x min stride
|
||||
m.inplace = self.inplace
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
|
||||
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
||||
|
||||
def _forward(x):
|
||||
"""Performs a forward pass through the model, handling different Detect subclass types accordingly."""
|
||||
return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
|
||||
|
||||
m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
|
||||
self.stride = m.stride
|
||||
m.bias_init() # only run once
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue