ultralytics 8.1.30 add advanced HUB train arguments (#9110)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
a62cdab53a
commit
8617fcf32d
3 changed files with 18 additions and 19 deletions
|
|
@ -170,10 +170,19 @@ class HUBTrainingSession:
|
|||
|
||||
return api_key, model_id, filename
|
||||
|
||||
def _set_train_args(self, **kwargs):
|
||||
"""Initializes training arguments and creates a model entry on the Ultralytics HUB."""
|
||||
def _set_train_args(self):
|
||||
"""
|
||||
Initializes training arguments and creates a model entry on the Ultralytics HUB.
|
||||
|
||||
This method sets up training arguments based on the model's state and updates them with any additional
|
||||
arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
|
||||
or requires specific file setup.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is already trained, if required dataset information is missing, or if there are
|
||||
issues with the provided training arguments.
|
||||
"""
|
||||
if self.model.is_trained():
|
||||
# Model is already trained
|
||||
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
|
||||
|
||||
if self.model.is_resumable():
|
||||
|
|
@ -182,26 +191,16 @@ class HUBTrainingSession:
|
|||
self.model_file = self.model.get_weights_url("last")
|
||||
else:
|
||||
# Model has no saved weights
|
||||
def get_train_args(config):
|
||||
"""Parses an identifier to extract API key, model ID, and filename if applicable."""
|
||||
return {
|
||||
"batch": config["batchSize"],
|
||||
"epochs": config["epochs"],
|
||||
"imgsz": config["imageSize"],
|
||||
"patience": config["patience"],
|
||||
"device": config["device"],
|
||||
"cache": config["cache"],
|
||||
"data": self.model.get_dataset_url(),
|
||||
}
|
||||
self.train_args = self.model.data.get("train_args") # new response
|
||||
|
||||
self.train_args = get_train_args(self.model.data.get("config"))
|
||||
# Set the model file as either a *.pt or *.yaml file
|
||||
self.model_file = (
|
||||
self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
|
||||
)
|
||||
|
||||
if not self.train_args.get("data"):
|
||||
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix
|
||||
if "data" not in self.train_args:
|
||||
# RF bug - datasets are sometimes not exported
|
||||
raise ValueError("Dataset may still be processing. Please wait a minute and try again.")
|
||||
|
||||
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
|
||||
self.model_id = self.model.id
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue