Improve tests coverage and speed (#4340)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-08-13 22:24:01 +02:00 committed by GitHub
parent d704507217
commit 9f6d48d3cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 183 additions and 347 deletions

View file

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
@ -8,6 +9,8 @@ import numpy as np
import torch
from PIL import Image
from ultralytics.utils import LOGGER
class FastSAMPrompt:
@ -15,8 +18,8 @@ class FastSAMPrompt:
# self.img_path = img_path
self.device = device
self.results = results
self.img_path = img_path
self.ori_img = cv2.imread(img_path)
self.img_path = str(img_path)
self.ori_img = cv2.imread(self.img_path)
# Import and assign clip
try:
@ -111,7 +114,7 @@ class FastSAMPrompt:
original_w = image.shape[1]
# for macOS only
# plt.switch_backend('TkAgg')
plt.figure(figsize=(original_w / 100, original_h / 100))
fig = plt.figure(figsize=(original_w / 100, original_h / 100))
# Add subplot with no margin.
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
@ -174,21 +177,11 @@ class FastSAMPrompt:
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
save_path = output
if not os.path.exists(save_path):
os.makedirs(save_path)
save_path = Path(output) / result_name
save_path.parent.mkdir(exist_ok=True, parents=True)
plt.axis('off')
fig = plt.gcf()
plt.draw()
try:
buf = fig.canvas.tostring_rgb()
except AttributeError:
fig.canvas.draw()
buf = fig.canvas.tostring_rgb()
cols, rows = fig.canvas.get_width_height()
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
fig.savefig(save_path)
LOGGER.info(f'Saved to {save_path.absolute()}')
# CPU post process
def fast_show_mask(