Code Refactor ruff check --fix --extend-select I (#13672)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-06-17 11:17:52 +02:00 committed by GitHub
parent c8514a6754
commit 6227d8f8a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 30 additions and 24 deletions

View file

@ -293,8 +293,12 @@ class DetectionModel(BaseModel):
if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
s = 256 # 2x min stride
m.inplace = self.inplace
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
def _forward(x):
"""Performs a forward pass through the model, handling different Detect subclass types accordingly."""
return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
self.stride = m.stride
m.bias_init() # only run once
else: