Simplify Results() class (#4579)
This commit is contained in:
parent
e9f596430f
commit
2db35afad5
4 changed files with 30 additions and 39 deletions
|
|
@ -217,7 +217,7 @@ def test_all_model_yamls():
|
|||
for m in (ROOT / 'cfg' / 'models').rglob('*.yaml'):
|
||||
if 'rtdetr' in m.name:
|
||||
if TORCH_1_9: # torch<=1.8 issue - TypeError: __init__() got an unexpected keyword argument 'batch_first'
|
||||
RTDETR(m.name)(SOURCE, imgsz=640)
|
||||
RTDETR(m.name)(SOURCE, imgsz=640) # must be 640
|
||||
else:
|
||||
YOLO(m.name)
|
||||
|
||||
|
|
@ -225,8 +225,8 @@ def test_all_model_yamls():
|
|||
def test_workflow():
|
||||
model = YOLO(MODEL)
|
||||
model.train(data='coco8.yaml', epochs=1, imgsz=32)
|
||||
model.val()
|
||||
model.predict(SOURCE)
|
||||
model.val(imgsz=32)
|
||||
model.predict(SOURCE, imgsz=32)
|
||||
model.export(format='onnx') # export a model to ONNX format
|
||||
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ def test_predict_callback_and_setup():
|
|||
|
||||
dataset = load_inference_source(source=SOURCE)
|
||||
bs = dataset.bs # noqa access predictor properties
|
||||
results = model.predict(dataset, stream=True) # source already setup
|
||||
results = model.predict(dataset, stream=True, imgsz=160) # source already setup
|
||||
for r, im0, bs in results:
|
||||
print('test_callback', im0.shape)
|
||||
print('test_callback', bs)
|
||||
|
|
@ -254,7 +254,7 @@ def test_predict_callback_and_setup():
|
|||
def test_results():
|
||||
for m in 'yolov8n-pose.pt', 'yolov8n-seg.pt', 'yolov8n.pt', 'yolov8n-cls.pt':
|
||||
model = YOLO(m)
|
||||
results = model([SOURCE, SOURCE])
|
||||
results = model([SOURCE, SOURCE], imgsz=160)
|
||||
for r in results:
|
||||
r = r.cpu().numpy()
|
||||
r = r.to(device='cpu', dtype=torch.float32)
|
||||
|
|
@ -263,10 +263,7 @@ def test_results():
|
|||
r.tojson(normalize=True)
|
||||
r.plot(pil=True)
|
||||
r.plot(conf=True, boxes=True)
|
||||
print(r)
|
||||
print(r.path)
|
||||
for k in r.keys:
|
||||
print(getattr(r, k))
|
||||
print(r, len(r), r.path)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not ONLINE, reason='environment is offline')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue