Fix Windows non-UTF source filenames (#4524)
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
parent
a7419617a6
commit
1db9afc2e5
16 changed files with 129 additions and 95 deletions
|
|
@ -852,9 +852,10 @@ ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter'
|
|||
TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()
|
||||
set_sentry()
|
||||
|
||||
# Apply monkey patches if the script is being run from within the parent directory of the script's location
|
||||
from .patches import imread, imshow, imwrite
|
||||
# Apply monkey patches
|
||||
from .patches import imread, imshow, imwrite, torch_save
|
||||
|
||||
# torch.save = torch_save
|
||||
if Path(inspect.stack()[0].filename).parent.parent.as_posix() in inspect.stack()[-1].filename:
|
||||
torch.save = torch_save
|
||||
if WINDOWS:
|
||||
# Apply cv2 patches for non-ASCII and non-UTF characters in image paths
|
||||
cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow
|
||||
|
|
|
|||
|
|
@ -240,7 +240,7 @@ class ProfileModels:
|
|||
if path.is_dir():
|
||||
extensions = ['*.pt', '*.onnx', '*.yaml']
|
||||
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
|
||||
elif path.suffix in ('.pt', '.yaml', '.yml'): # add non-existing
|
||||
elif path.suffix in {'.pt', '.yaml', '.yml'}: # add non-existing
|
||||
files.append(str(path))
|
||||
else:
|
||||
files.extend(glob.glob(str(path)))
|
||||
|
|
@ -262,7 +262,7 @@ class ProfileModels:
|
|||
data = clipped_data
|
||||
return data
|
||||
|
||||
def profile_tensorrt_model(self, engine_file: str):
|
||||
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-7):
|
||||
if not self.trt or not Path(engine_file).is_file():
|
||||
return 0.0, 0.0
|
||||
|
||||
|
|
@ -279,7 +279,7 @@ class ProfileModels:
|
|||
elapsed = time.time() - start_time
|
||||
|
||||
# Compute number of runs as higher of min_time or num_timed_runs
|
||||
num_runs = max(round(self.min_time / elapsed * self.num_warmup_runs), self.num_timed_runs * 50)
|
||||
num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)
|
||||
|
||||
# Timed runs
|
||||
run_times = []
|
||||
|
|
@ -290,7 +290,7 @@ class ProfileModels:
|
|||
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
|
||||
return np.mean(run_times), np.std(run_times)
|
||||
|
||||
def profile_onnx_model(self, onnx_file: str):
|
||||
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-7):
|
||||
check_requirements('onnxruntime')
|
||||
import onnxruntime as ort
|
||||
|
||||
|
|
@ -330,7 +330,7 @@ class ProfileModels:
|
|||
elapsed = time.time() - start_time
|
||||
|
||||
# Compute number of runs as higher of min_time or num_timed_runs
|
||||
num_runs = max(round(self.min_time / elapsed * self.num_warmup_runs), self.num_timed_runs)
|
||||
num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)
|
||||
|
||||
# Timed runs
|
||||
run_times = []
|
||||
|
|
|
|||
|
|
@ -101,7 +101,11 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
|
|||
zip_file = directory.with_suffix('.zip')
|
||||
compression = ZIP_DEFLATED if compress else ZIP_STORED
|
||||
with ZipFile(zip_file, 'w', compression) as f:
|
||||
for file in tqdm(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress):
|
||||
for file in tqdm(files_to_zip,
|
||||
desc=f'Zipping {directory} to {zip_file}...',
|
||||
unit='file',
|
||||
disable=not progress,
|
||||
bar_format=TQDM_BAR_FORMAT):
|
||||
f.write(file, file.relative_to(directory))
|
||||
|
||||
return zip_file # return path to zip file
|
||||
|
|
@ -159,7 +163,11 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
|
|||
LOGGER.info(f'Skipping {file} unzip (already unzipped)')
|
||||
return path
|
||||
|
||||
for f in tqdm(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress):
|
||||
for f in tqdm(files,
|
||||
desc=f'Unzipping {file} to {Path(path).resolve()}...',
|
||||
unit='file',
|
||||
disable=not progress,
|
||||
bar_format=TQDM_BAR_FORMAT):
|
||||
zipObj.extract(f, path=extract_path)
|
||||
|
||||
return path # return unzip dir
|
||||
|
|
|
|||
|
|
@ -13,20 +13,45 @@ import torch
|
|||
_imshow = cv2.imshow # copy to avoid recursion errors
|
||||
|
||||
|
||||
def imread(filename, flags=cv2.IMREAD_COLOR):
|
||||
def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
|
||||
"""Read an image from a file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the file to read.
|
||||
flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The read image.
|
||||
"""
|
||||
return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
|
||||
|
||||
|
||||
def imwrite(filename, img):
|
||||
def imwrite(filename: str, img: np.ndarray, params=None):
|
||||
"""Write an image to a file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the file to write.
|
||||
img (np.ndarray): Image to write.
|
||||
params (list of ints, optional): Additional parameters. See OpenCV documentation.
|
||||
|
||||
Returns:
|
||||
(bool): True if the file was written, False otherwise.
|
||||
"""
|
||||
try:
|
||||
cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
|
||||
cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def imshow(path, im):
|
||||
_imshow(path.encode('unicode_escape').decode(), im)
|
||||
def imshow(winname: str, mat: np.ndarray):
|
||||
"""Displays an image in the specified window.
|
||||
|
||||
Args:
|
||||
winname (str): Name of the window.
|
||||
mat (np.ndarray): Image to be shown.
|
||||
"""
|
||||
_imshow(winname.encode('unicode_escape').decode(), mat)
|
||||
|
||||
|
||||
# PyTorch functions ----------------------------------------------------------------------------------------------------
|
||||
|
|
@ -34,12 +59,17 @@ _torch_save = torch.save # copy to avoid recursion errors
|
|||
|
||||
|
||||
def torch_save(*args, **kwargs):
|
||||
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
|
||||
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this.
|
||||
|
||||
Args:
|
||||
*args (tuple): Positional arguments to pass to torch.save.
|
||||
**kwargs (dict): Keyword arguments to pass to torch.save.
|
||||
"""
|
||||
try:
|
||||
import dill as pickle
|
||||
import dill as pickle # noqa
|
||||
except ImportError:
|
||||
import pickle
|
||||
|
||||
if 'pickle_module' not in kwargs:
|
||||
kwargs['pickle_module'] = pickle
|
||||
kwargs['pickle_module'] = pickle # noqa
|
||||
return _torch_save(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue