Minor Results.to_sql cleanup (#19081)

Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2025-02-07 11:19:58 +08:00 committed by GitHub
parent c526a652ab
commit 116813d3c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 20 deletions

View file

@ -967,7 +967,7 @@ class Results(SimpleClass):
# Convert results to a list of dictionaries
data = self.summary(normalize=normalize, decimals=decimals)
if not data:
if len(data) == 0:
LOGGER.warning("⚠️ No results to save to SQL. Results dict is empty")
return
@ -977,31 +977,20 @@ class Results(SimpleClass):
# Create table if it doesn't exist
columns = (
"id INTEGER PRIMARY KEY AUTOINCREMENT, class_name TEXT, confidence REAL, "
"box TEXT, masks TEXT, kpts TEXT, obb TEXT"
"id INTEGER PRIMARY KEY AUTOINCREMENT, class_name TEXT, confidence REAL, box TEXT, masks TEXT, kpts TEXT"
)
cursor.execute(f"CREATE TABLE IF NOT EXISTS {table_name} ({columns})")
# Insert data into the table
for i, item in enumerate(data):
detect, obb = None, None # necessary to reinit these variables inside for loop to avoid duplication
class_name = item.get("name")
box = item.get("box", {})
# Serialize the box as JSON for 'detect' and 'obb' based on key presence
if all(key in box for key in ["x1", "y1", "x2", "y2"]) and not any(key in box for key in ["x3", "x4"]):
detect = json.dumps(box)
if all(key in box for key in ["x1", "y1", "x2", "y2", "x3", "x4"]):
obb = json.dumps(box)
for item in data:
cursor.execute(
f"INSERT INTO {table_name} (class_name, confidence, box, masks, kpts, obb) VALUES (?, ?, ?, ?, ?, ?)",
f"INSERT INTO {table_name} (class_name, confidence, box, masks, kpts) VALUES (?, ?, ?, ?, ?)",
(
class_name,
item.get("name"),
item.get("confidence"),
detect,
json.dumps(item.get("segments", {}).get("x", [])),
json.dumps(item.get("keypoints", {}).get("x", [])),
obb,
json.dumps(item.get("box", {})),
json.dumps(item.get("segments", {})),
json.dumps(item.get("keypoints", {})),
),
)

View file

@ -131,7 +131,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
floor (int): Minimum allowed value for image size.
Returns:
(List[int]): Updated image size.
(List[int] | int): Updated image size.
"""
# Convert stride to integer if it is a tensor
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)