Update .pre-commit-config.yaml (#1026)

Glenn Jocher, 2023-02-17 22:26:40 +01:00 (committed by GitHub)
parent 9047d737f4
commit edd3ff1669
76 changed files with 928 additions and 935 deletions


@@ -127,11 +127,11 @@ class AutoBackend(nn.Module):
                 w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
             network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
             if network.get_parameters()[0].get_layout().empty:
-                network.get_parameters()[0].set_layout(Layout("NCHW"))
+                network.get_parameters()[0].set_layout(Layout('NCHW'))
             batch_dim = get_batch(network)
             if batch_dim.is_static:
                 batch_size = batch_dim.get_length()
-            executable_network = ie.compile_model(network, device_name="CPU")  # device_name="MYRIAD" for Intel NCS2
+            executable_network = ie.compile_model(network, device_name='CPU')  # device_name="MYRIAD" for Intel NCS2
         elif engine:  # TensorRT
             LOGGER.info(f'Loading {w} for TensorRT inference...')
             import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
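
For context, the OpenVINO branch touched above reduces to the following standalone sketch; the `yolov8n_openvino_model` directory name is a hypothetical example, not from this diff:

```python
# Minimal sketch of the OpenVINO loading path, assuming an exported
# yolov8n_openvino_model/ directory (hypothetical) with an .xml/.bin pair.
from pathlib import Path

from openvino.runtime import Core, Layout, get_batch  # OpenVINO >= 2022.1

ie = Core()
w = next(Path('yolov8n_openvino_model').glob('*.xml'))  # find the *.xml file
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
if network.get_parameters()[0].get_layout().empty:
    network.get_parameters()[0].set_layout(Layout('NCHW'))  # batch, channels, height, width
batch_dim = get_batch(network)
if batch_dim.is_static:
    batch_size = batch_dim.get_length()
executable_network = ie.compile_model(network, device_name='CPU')
```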
@@ -184,7 +184,7 @@ class AutoBackend(nn.Module):
             import tensorflow as tf

             def wrap_frozen_graph(gd, inputs, outputs):
-                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped
                 ge = x.graph.as_graph_element
                 return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
@@ -198,7 +198,7 @@ class AutoBackend(nn.Module):
             gd = tf.Graph().as_graph_def()  # TF GraphDef
             with open(w, 'rb') as f:
                 gd.ParseFromString(f.read())
-            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
+            frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
         elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
             try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                 from tflite_runtime.interpreter import Interpreter, load_delegate
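
The two GraphDef hunks above work together; a self-contained sketch of the same pattern follows. The path 'yolov8n.pb' and the tensor names 'x:0' / 'Identity:0' are assumptions for illustration (the real code derives the outputs via gd_outputs(gd), which is not shown in this diff):

```python
# Load a frozen TensorFlow GraphDef and prune it to a callable function.
import tensorflow as tf

def wrap_frozen_graph(gd, inputs, outputs):
    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped
    ge = x.graph.as_graph_element
    return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

gd = tf.Graph().as_graph_def()  # empty GraphDef to parse into
with open('yolov8n.pb', 'rb') as f:  # hypothetical frozen graph path
    gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs='Identity:0')  # assumed tensor names
```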
@@ -220,9 +220,9 @@ class AutoBackend(nn.Module):
             output_details = interpreter.get_output_details()  # outputs
             # load metadata
             with contextlib.suppress(zipfile.BadZipFile):
-                with zipfile.ZipFile(w, "r") as model:
+                with zipfile.ZipFile(w, 'r') as model:
                     meta_file = model.namelist()[0]
-                    meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
+                    meta = ast.literal_eval(model.read(meta_file).decode('utf-8'))
                     stride, names = int(meta['stride']), meta['names']
         elif tfjs:  # TF.js
             raise NotImplementedError('YOLOv8 TF.js inference is not supported')
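
The metadata read above relies on an exported .tflite file also being readable as a zip archive whose first member is a Python-literal dict. A minimal sketch, where the filename and the stride/names fallbacks are assumptions:

```python
# Read Ultralytics export metadata out of a .tflite zip container, if present.
import ast
import contextlib
import zipfile

w = 'yolov8n.tflite'  # hypothetical exported model
stride, names = 32, None  # assumed fallbacks when no metadata is attached
with contextlib.suppress(zipfile.BadZipFile):  # plain .tflite is not a valid zip
    with zipfile.ZipFile(w, 'r') as model:
        meta_file = model.namelist()[0]
        meta = ast.literal_eval(model.read(meta_file).decode('utf-8'))
        stride, names = int(meta['stride']), meta['names']
```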
@@ -251,8 +251,8 @@ class AutoBackend(nn.Module):
         else:
             from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
             raise TypeError(f"model='{w}' is not a supported model format. "
-                            "See https://docs.ultralytics.com/tasks/detection/#export for help."
-                            f"\n\n{EXPORT_FORMATS_TABLE}")
+                            'See https://docs.ultralytics.com/tasks/detection/#export for help.'
+                            f'\n\n{EXPORT_FORMATS_TABLE}')

         # Load external metadata YAML
         if xml or saved_model or paddle:
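
The mixed quoting left behind in this hunk is expected: a quote-normalizing pre-commit hook (presumably double-quote-string-fixer, inferred from the commit title rather than confirmed by this diff) rewrites only literals that contain no single quotes, so the fragment with the embedded model='{w}' keeps its double quotes:

```python
# Implicitly concatenated string literals: fragment 2 is normalized to single
# quotes, while fragment 1 must stay double-quoted because it embeds ' marks.
msg = ("model='x' is not a supported model format. "
       'See the export docs for help.')
```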
@@ -410,5 +410,5 @@ class AutoBackend(nn.Module):
         url = urlparse(p)  # if url may be Triton inference server
         types = [s in Path(p).name for s in sf]
         types[8] &= not types[9]  # tflite &= not edgetpu
-        triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
+        triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
         return types + [triton]
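
The Triton check in this final hunk can be exercised in isolation; a small sketch with a hypothetical endpoint URL:

```python
# A source that matches no export-format suffix but parses as an http/grpc URL
# with a network location is treated as a Triton Inference Server endpoint.
from urllib.parse import urlparse

p = 'http://localhost:8000/yolov8n'  # hypothetical Triton endpoint
url = urlparse(p)
triton = all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
assert triton  # scheme contains 'http' and netloc is 'localhost:8000'
```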