Improve tests coverage and speed (#4340)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
d704507217
commit
9f6d48d3cf
10 changed files with 183 additions and 347 deletions
|
|
@ -1,6 +1,7 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
|
|
@ -8,6 +9,8 @@ import numpy as np
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.utils import LOGGER
|
||||
|
||||
|
||||
class FastSAMPrompt:
|
||||
|
||||
|
|
@ -15,8 +18,8 @@ class FastSAMPrompt:
|
|||
# self.img_path = img_path
|
||||
self.device = device
|
||||
self.results = results
|
||||
self.img_path = img_path
|
||||
self.ori_img = cv2.imread(img_path)
|
||||
self.img_path = str(img_path)
|
||||
self.ori_img = cv2.imread(self.img_path)
|
||||
|
||||
# Import and assign clip
|
||||
try:
|
||||
|
|
@ -111,7 +114,7 @@ class FastSAMPrompt:
|
|||
original_w = image.shape[1]
|
||||
# for macOS only
|
||||
# plt.switch_backend('TkAgg')
|
||||
plt.figure(figsize=(original_w / 100, original_h / 100))
|
||||
fig = plt.figure(figsize=(original_w / 100, original_h / 100))
|
||||
# Add subplot with no margin.
|
||||
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
||||
plt.margins(0, 0)
|
||||
|
|
@ -174,21 +177,11 @@ class FastSAMPrompt:
|
|||
contour_mask = temp / 255 * color.reshape(1, 1, -1)
|
||||
plt.imshow(contour_mask)
|
||||
|
||||
save_path = output
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
save_path = Path(output) / result_name
|
||||
save_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
plt.axis('off')
|
||||
fig = plt.gcf()
|
||||
plt.draw()
|
||||
|
||||
try:
|
||||
buf = fig.canvas.tostring_rgb()
|
||||
except AttributeError:
|
||||
fig.canvas.draw()
|
||||
buf = fig.canvas.tostring_rgb()
|
||||
cols, rows = fig.canvas.get_width_height()
|
||||
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
|
||||
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
|
||||
fig.savefig(save_path)
|
||||
LOGGER.info(f'Saved to {save_path.absolute()}')
|
||||
|
||||
# CPU post process
|
||||
def fast_show_mask(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue