Code Refactor ruff check --fix --extend-select I (#13672)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-06-17 11:17:52 +02:00 committed by GitHub
parent c8514a6754
commit 6227d8f8a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 30 additions and 24 deletions

View file

@ -11,6 +11,10 @@ keywords: MLflow, Ultralytics YOLO, logging, metrics, parameters, model artifact
<br><br> <br><br>
## ::: ultralytics.utils.callbacks.mlflow.sanitize_dict
<br><br>
## ::: ultralytics.utils.callbacks.mlflow.on_pretrain_routine_end ## ::: ultralytics.utils.callbacks.mlflow.on_pretrain_routine_end
<br><br> <br><br>

View file

@ -221,8 +221,7 @@ names:
204: cape 204: cape
205: cappuccino/coffee cappuccino 205: cappuccino/coffee cappuccino
206: car/car automobile/auto/auto automobile/automobile 206: car/car automobile/auto/auto automobile/automobile
207: railcar/railcar part of a train/railway car/railway car part of a train/railroad 207: railcar/railcar part of a train/railway car/railway car part of a train/railroad car/railroad car part of a train
car/railroad car part of a train
208: elevator car 208: elevator car
209: car battery/automobile battery 209: car battery/automobile battery
210: identity card 210: identity card
@ -241,8 +240,7 @@ names:
223: cast/plaster cast/plaster bandage 223: cast/plaster cast/plaster bandage
224: cat 224: cat
225: cauliflower 225: cauliflower
226: cayenne/cayenne spice/cayenne pepper/cayenne pepper spice/red pepper/red pepper 226: cayenne/cayenne spice/cayenne pepper/cayenne pepper spice/red pepper/red pepper spice
spice
227: CD player 227: CD player
228: celery 228: celery
229: cellular telephone/cellular phone/cellphone/mobile phone/smart phone 229: cellular telephone/cellular phone/cellphone/mobile phone/smart phone
@ -258,8 +256,7 @@ names:
239: chessboard 239: chessboard
240: chicken/chicken animal 240: chicken/chicken animal
241: chickpea/garbanzo 241: chickpea/garbanzo
242: chili/chili vegetable/chili pepper/chili pepper vegetable/chilli/chilli vegetable/chilly/chilly 242: chili/chili vegetable/chili pepper/chili pepper vegetable/chilli/chilli vegetable/chilly/chilly vegetable/chile/chile vegetable
vegetable/chile/chile vegetable
243: chime/gong 243: chime/gong
244: chinaware 244: chinaware
245: crisp/crisp potato chip/potato chip 245: crisp/crisp potato chip/potato chip
@ -1061,8 +1058,7 @@ names:
1041: sweater 1041: sweater
1042: sweatshirt 1042: sweatshirt
1043: sweet potato 1043: sweet potato
1044: swimsuit/swimwear/bathing suit/swimming costume/bathing costume/swimming trunks/bathing 1044: swimsuit/swimwear/bathing suit/swimming costume/bathing costume/swimming trunks/bathing trunks
trunks
1045: sword 1045: sword
1046: syringe 1046: syringe
1047: Tabasco sauce 1047: Tabasco sauce

View file

@ -259,7 +259,7 @@ def layout():
with col2: with col2:
similarity_form(selected_imgs) similarity_form(selected_imgs)
display_labels = st.checkbox("Labels", value=False, key="display_labels") st.checkbox("Labels", value=False, key="display_labels")
utralytics_explorer_docs_callback() utralytics_explorer_docs_callback()

View file

@ -293,8 +293,12 @@ class DetectionModel(BaseModel):
if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
s = 256 # 2x min stride s = 256 # 2x min stride
m.inplace = self.inplace m.inplace = self.inplace
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward def _forward(x):
"""Performs a forward pass through the model, handling different Detect subclass types accordingly."""
return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
self.stride = m.stride self.stride = m.stride
m.bias_init() # only run once m.bias_init() # only run once
else: else:

View file

@ -34,11 +34,13 @@ try:
from pathlib import Path from pathlib import Path
PREFIX = colorstr("MLflow: ") PREFIX = colorstr("MLflow: ")
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
except (ImportError, AssertionError): except (ImportError, AssertionError):
mlflow = None mlflow = None
def sanitize_dict(x):
"""Sanitize dictionary keys by removing parentheses and converting values to floats."""
return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
def on_pretrain_routine_end(trainer): def on_pretrain_routine_end(trainer):
""" """
@ -88,8 +90,8 @@ def on_train_epoch_end(trainer):
if mlflow: if mlflow:
mlflow.log_metrics( mlflow.log_metrics(
metrics={ metrics={
**SANITIZE(trainer.lr), **sanitize_dict(trainer.lr),
**SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")), **sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix="train")),
}, },
step=trainer.epoch, step=trainer.epoch,
) )
@ -98,7 +100,7 @@ def on_train_epoch_end(trainer):
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):
"""Log training metrics at the end of each fit epoch to MLflow.""" """Log training metrics at the end of each fit epoch to MLflow."""
if mlflow: if mlflow:
mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch) mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch)
def on_train_end(trainer): def on_train_end(trainer):