Fix Windows non-UTF source filenames (#4524)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-08-24 03:22:26 +02:00 committed by GitHub
parent a7419617a6
commit 1db9afc2e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 129 additions and 95 deletions

View file

@ -13,20 +13,45 @@ import torch
_imshow = cv2.imshow # copy to avoid recursion errors
def imread(filename, flags=cv2.IMREAD_COLOR):
def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
"""Read an image from a file.
Args:
filename (str): Path to the file to read.
flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR.
Returns:
(np.ndarray): The read image.
"""
return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
def imwrite(filename, img):
def imwrite(filename: str, img: np.ndarray, params=None):
"""Write an image to a file.
Args:
filename (str): Path to the file to write.
img (np.ndarray): Image to write.
params (list of ints, optional): Additional parameters. See OpenCV documentation.
Returns:
(bool): True if the file was written, False otherwise.
"""
try:
cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
return True
except Exception:
return False
def imshow(path, im):
_imshow(path.encode('unicode_escape').decode(), im)
def imshow(winname: str, mat: np.ndarray):
"""Displays an image in the specified window.
Args:
winname (str): Name of the window.
mat (np.ndarray): Image to be shown.
"""
_imshow(winname.encode('unicode_escape').decode(), mat)
# PyTorch functions ----------------------------------------------------------------------------------------------------
@ -34,12 +59,17 @@ _torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, **kwargs):
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this.
Args:
*args (tuple): Positional arguments to pass to torch.save.
**kwargs (dict): Keyword arguments to pass to torch.save.
"""
try:
import dill as pickle
import dill as pickle # noqa
except ImportError:
import pickle
if 'pickle_module' not in kwargs:
kwargs['pickle_module'] = pickle
kwargs['pickle_module'] = pickle # noqa
return _torch_save(*args, **kwargs)