Code Refactor ruff check --fix --extend-select I (#13672)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
c8514a6754
commit
6227d8f8a1
6 changed files with 30 additions and 24 deletions
18
.github/workflows/publish.yml
vendored
18
.github/workflows/publish.yml
vendored
|
|
@ -88,7 +88,7 @@ jobs:
|
||||||
GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')
|
GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')
|
||||||
CURRENT_TAG = os.getenv('CURRENT_TAG')
|
CURRENT_TAG = os.getenv('CURRENT_TAG')
|
||||||
PREVIOUS_TAG = os.getenv('PREVIOUS_TAG')
|
PREVIOUS_TAG = os.getenv('PREVIOUS_TAG')
|
||||||
|
|
||||||
# Check for required environment variables
|
# Check for required environment variables
|
||||||
if not all([OPENAI_AZURE_API_KEY, OPENAI_AZURE_ENDPOINT, OPENAI_AZURE_API_VERSION, GITHUB_TOKEN, CURRENT_TAG, PREVIOUS_TAG]):
|
if not all([OPENAI_AZURE_API_KEY, OPENAI_AZURE_ENDPOINT, OPENAI_AZURE_API_VERSION, GITHUB_TOKEN, CURRENT_TAG, PREVIOUS_TAG]):
|
||||||
print(OPENAI_AZURE_API_KEY)
|
print(OPENAI_AZURE_API_KEY)
|
||||||
|
|
@ -98,24 +98,24 @@ jobs:
|
||||||
print(CURRENT_TAG)
|
print(CURRENT_TAG)
|
||||||
print(PREVIOUS_TAG)
|
print(PREVIOUS_TAG)
|
||||||
raise ValueError("One or more required environment variables are missing.")
|
raise ValueError("One or more required environment variables are missing.")
|
||||||
|
|
||||||
latest_tag = f"v{CURRENT_TAG}"
|
latest_tag = f"v{CURRENT_TAG}"
|
||||||
previous_tag = f"v{PREVIOUS_TAG}"
|
previous_tag = f"v{PREVIOUS_TAG}"
|
||||||
repo = 'ultralytics/ultralytics'
|
repo = 'ultralytics/ultralytics'
|
||||||
headers = {"Authorization": f"token {GITHUB_TOKEN}", "Accept": "application/vnd.github.v3.diff"}
|
headers = {"Authorization": f"token {GITHUB_TOKEN}", "Accept": "application/vnd.github.v3.diff"}
|
||||||
|
|
||||||
# Get the diff between the tags
|
# Get the diff between the tags
|
||||||
url = f"https://api.github.com/repos/{repo}/compare/{previous_tag}...{latest_tag}"
|
url = f"https://api.github.com/repos/{repo}/compare/{previous_tag}...{latest_tag}"
|
||||||
response = requests.get(url, headers=headers)
|
response = requests.get(url, headers=headers)
|
||||||
diff = response.text if response.status_code == 200 else f"Failed to get diff: {response.content}"
|
diff = response.text if response.status_code == 200 else f"Failed to get diff: {response.content}"
|
||||||
|
|
||||||
# Set up OpenAI client
|
# Set up OpenAI client
|
||||||
client = openai.AzureOpenAI(
|
client = openai.AzureOpenAI(
|
||||||
api_key=OPENAI_AZURE_API_KEY,
|
api_key=OPENAI_AZURE_API_KEY,
|
||||||
api_version=OPENAI_AZURE_API_VERSION,
|
api_version=OPENAI_AZURE_API_VERSION,
|
||||||
azure_endpoint=OPENAI_AZURE_ENDPOINT
|
azure_endpoint=OPENAI_AZURE_ENDPOINT
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare messages for OpenAI completion
|
# Prepare messages for OpenAI completion
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
|
|
@ -131,17 +131,17 @@ jobs:
|
||||||
f"\n\nHere's the release diff:\n\n{diff[:96000]}",
|
f"\n\nHere's the release diff:\n\n{diff[:96000]}",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
completion = client.chat.completions.create(model="gpt-4o-2024-05-13", messages=messages)
|
completion = client.chat.completions.create(model="gpt-4o-2024-05-13", messages=messages)
|
||||||
summary = completion.choices[0].message.content.strip()
|
summary = completion.choices[0].message.content.strip()
|
||||||
except openai.error.OpenAIError as e:
|
except openai.error.OpenAIError as e:
|
||||||
print(f"Failed to generate summary: {e}")
|
print(f"Failed to generate summary: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Get the latest commit message
|
# Get the latest commit message
|
||||||
commit_message = subprocess.run(['git', 'log', '-1', '--pretty=%B'], check=True, text=True, capture_output=True).stdout.split("\n")[0].strip()
|
commit_message = subprocess.run(['git', 'log', '-1', '--pretty=%B'], check=True, text=True, capture_output=True).stdout.split("\n")[0].strip()
|
||||||
|
|
||||||
# Prepare release data
|
# Prepare release data
|
||||||
release = {
|
release = {
|
||||||
'tag_name': latest_tag,
|
'tag_name': latest_tag,
|
||||||
|
|
@ -150,7 +150,7 @@ jobs:
|
||||||
'draft': False,
|
'draft': False,
|
||||||
'prerelease': False
|
'prerelease': False
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create the release on GitHub
|
# Create the release on GitHub
|
||||||
release_url = f"https://api.github.com/repos/{repo}/releases"
|
release_url = f"https://api.github.com/repos/{repo}/releases"
|
||||||
release_response = requests.post(release_url, headers=headers, data=json.dumps(release))
|
release_response = requests.post(release_url, headers=headers, data=json.dumps(release))
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,10 @@ keywords: MLflow, Ultralytics YOLO, logging, metrics, parameters, model artifact
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
|
## ::: ultralytics.utils.callbacks.mlflow.sanitize_dict
|
||||||
|
|
||||||
|
<br><br>
|
||||||
|
|
||||||
## ::: ultralytics.utils.callbacks.mlflow.on_pretrain_routine_end
|
## ::: ultralytics.utils.callbacks.mlflow.on_pretrain_routine_end
|
||||||
|
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
|
||||||
|
|
@ -221,8 +221,7 @@ names:
|
||||||
204: cape
|
204: cape
|
||||||
205: cappuccino/coffee cappuccino
|
205: cappuccino/coffee cappuccino
|
||||||
206: car/car automobile/auto/auto automobile/automobile
|
206: car/car automobile/auto/auto automobile/automobile
|
||||||
207: railcar/railcar part of a train/railway car/railway car part of a train/railroad
|
207: railcar/railcar part of a train/railway car/railway car part of a train/railroad car/railroad car part of a train
|
||||||
car/railroad car part of a train
|
|
||||||
208: elevator car
|
208: elevator car
|
||||||
209: car battery/automobile battery
|
209: car battery/automobile battery
|
||||||
210: identity card
|
210: identity card
|
||||||
|
|
@ -241,8 +240,7 @@ names:
|
||||||
223: cast/plaster cast/plaster bandage
|
223: cast/plaster cast/plaster bandage
|
||||||
224: cat
|
224: cat
|
||||||
225: cauliflower
|
225: cauliflower
|
||||||
226: cayenne/cayenne spice/cayenne pepper/cayenne pepper spice/red pepper/red pepper
|
226: cayenne/cayenne spice/cayenne pepper/cayenne pepper spice/red pepper/red pepper spice
|
||||||
spice
|
|
||||||
227: CD player
|
227: CD player
|
||||||
228: celery
|
228: celery
|
||||||
229: cellular telephone/cellular phone/cellphone/mobile phone/smart phone
|
229: cellular telephone/cellular phone/cellphone/mobile phone/smart phone
|
||||||
|
|
@ -258,8 +256,7 @@ names:
|
||||||
239: chessboard
|
239: chessboard
|
||||||
240: chicken/chicken animal
|
240: chicken/chicken animal
|
||||||
241: chickpea/garbanzo
|
241: chickpea/garbanzo
|
||||||
242: chili/chili vegetable/chili pepper/chili pepper vegetable/chilli/chilli vegetable/chilly/chilly
|
242: chili/chili vegetable/chili pepper/chili pepper vegetable/chilli/chilli vegetable/chilly/chilly vegetable/chile/chile vegetable
|
||||||
vegetable/chile/chile vegetable
|
|
||||||
243: chime/gong
|
243: chime/gong
|
||||||
244: chinaware
|
244: chinaware
|
||||||
245: crisp/crisp potato chip/potato chip
|
245: crisp/crisp potato chip/potato chip
|
||||||
|
|
@ -1061,8 +1058,7 @@ names:
|
||||||
1041: sweater
|
1041: sweater
|
||||||
1042: sweatshirt
|
1042: sweatshirt
|
||||||
1043: sweet potato
|
1043: sweet potato
|
||||||
1044: swimsuit/swimwear/bathing suit/swimming costume/bathing costume/swimming trunks/bathing
|
1044: swimsuit/swimwear/bathing suit/swimming costume/bathing costume/swimming trunks/bathing trunks
|
||||||
trunks
|
|
||||||
1045: sword
|
1045: sword
|
||||||
1046: syringe
|
1046: syringe
|
||||||
1047: Tabasco sauce
|
1047: Tabasco sauce
|
||||||
|
|
|
||||||
|
|
@ -259,7 +259,7 @@ def layout():
|
||||||
|
|
||||||
with col2:
|
with col2:
|
||||||
similarity_form(selected_imgs)
|
similarity_form(selected_imgs)
|
||||||
display_labels = st.checkbox("Labels", value=False, key="display_labels")
|
st.checkbox("Labels", value=False, key="display_labels")
|
||||||
utralytics_explorer_docs_callback()
|
utralytics_explorer_docs_callback()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -293,8 +293,12 @@ class DetectionModel(BaseModel):
|
||||||
if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
|
if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
|
||||||
s = 256 # 2x min stride
|
s = 256 # 2x min stride
|
||||||
m.inplace = self.inplace
|
m.inplace = self.inplace
|
||||||
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
|
|
||||||
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
def _forward(x):
|
||||||
|
"""Performs a forward pass through the model, handling different Detect subclass types accordingly."""
|
||||||
|
return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
|
||||||
|
|
||||||
|
m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
|
||||||
self.stride = m.stride
|
self.stride = m.stride
|
||||||
m.bias_init() # only run once
|
m.bias_init() # only run once
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -34,11 +34,13 @@ try:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
PREFIX = colorstr("MLflow: ")
|
PREFIX = colorstr("MLflow: ")
|
||||||
SANITIZE = lambda x: {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
|
|
||||||
|
|
||||||
except (ImportError, AssertionError):
|
except (ImportError, AssertionError):
|
||||||
mlflow = None
|
mlflow = None
|
||||||
|
|
||||||
|
def sanitize_dict(x):
|
||||||
|
"""Sanitize dictionary keys by removing parentheses and converting values to floats."""
|
||||||
|
return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
|
||||||
|
|
||||||
def on_pretrain_routine_end(trainer):
|
def on_pretrain_routine_end(trainer):
|
||||||
"""
|
"""
|
||||||
|
|
@ -88,8 +90,8 @@ def on_train_epoch_end(trainer):
|
||||||
if mlflow:
|
if mlflow:
|
||||||
mlflow.log_metrics(
|
mlflow.log_metrics(
|
||||||
metrics={
|
metrics={
|
||||||
**SANITIZE(trainer.lr),
|
**sanitize_dict(trainer.lr),
|
||||||
**SANITIZE(trainer.label_loss_items(trainer.tloss, prefix="train")),
|
**sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix="train")),
|
||||||
},
|
},
|
||||||
step=trainer.epoch,
|
step=trainer.epoch,
|
||||||
)
|
)
|
||||||
|
|
@ -98,7 +100,7 @@ def on_train_epoch_end(trainer):
|
||||||
def on_fit_epoch_end(trainer):
|
def on_fit_epoch_end(trainer):
|
||||||
"""Log training metrics at the end of each fit epoch to MLflow."""
|
"""Log training metrics at the end of each fit epoch to MLflow."""
|
||||||
if mlflow:
|
if mlflow:
|
||||||
mlflow.log_metrics(metrics=SANITIZE(trainer.metrics), step=trainer.epoch)
|
mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch)
|
||||||
|
|
||||||
|
|
||||||
def on_train_end(trainer):
|
def on_train_end(trainer):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue