Fix Windows non-UTF source filenames (#4524)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-08-24 03:22:26 +02:00 committed by GitHub
parent a7419617a6
commit 1db9afc2e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 129 additions and 95 deletions

View file

@ -240,7 +240,7 @@ class ProfileModels:
if path.is_dir():
extensions = ['*.pt', '*.onnx', '*.yaml']
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
elif path.suffix in ('.pt', '.yaml', '.yml'): # add non-existing
elif path.suffix in {'.pt', '.yaml', '.yml'}: # add non-existing
files.append(str(path))
else:
files.extend(glob.glob(str(path)))
@ -262,7 +262,7 @@ class ProfileModels:
data = clipped_data
return data
def profile_tensorrt_model(self, engine_file: str):
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-7):
if not self.trt or not Path(engine_file).is_file():
return 0.0, 0.0
@ -279,7 +279,7 @@ class ProfileModels:
elapsed = time.time() - start_time
# Compute number of runs as higher of min_time or num_timed_runs
num_runs = max(round(self.min_time / elapsed * self.num_warmup_runs), self.num_timed_runs * 50)
num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)
# Timed runs
run_times = []
@ -290,7 +290,7 @@ class ProfileModels:
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
return np.mean(run_times), np.std(run_times)
def profile_onnx_model(self, onnx_file: str):
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-7):
check_requirements('onnxruntime')
import onnxruntime as ort
@ -330,7 +330,7 @@ class ProfileModels:
elapsed = time.time() - start_time
# Compute number of runs as higher of min_time or num_timed_runs
num_runs = max(round(self.min_time / elapsed * self.num_warmup_runs), self.num_timed_runs)
num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)
# Timed runs
run_times = []