#!/usr/bin/python
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Script to dump TensorFlow weights in TRT v1 and v2 dump format.
# The V1 format is for TensorRT 4.0. The V2 format is for TensorRT 4.0 and later.

import sys
import struct
import argparse

try:
    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
except ImportError as err:
    sys.stderr.write("""Error: Failed to import module ({})""".format(err))
    # Exit with a non-zero status so calling shells/scripts can detect the
    # failure (a bare sys.exit() reports success).
    sys.exit(1)


def _parse_bool_flag(text):
    """Convert a command-line token into a bool.

    Using ``type=bool`` with argparse applies ``bool()`` to the raw string,
    so every non-empty value -- including "False" and "0" -- becomes True.
    This converter interprets the usual spellings explicitly instead.

    Args:
        text: The raw string taken from the command line.

    Returns:
        True or False according to the spelling of ``text``. The empty
        string maps to False, matching what ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not a recognised
            boolean spelling.
    """
    token = text.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(text))


parser = argparse.ArgumentParser(description="TensorFlow Weight Dumper")
parser.add_argument(
    "-m",
    "--model",
    required=True,
    help="The checkpoint file basename, example basename(model.ckpt-766908.data-00000-of-00001) -> model.ckpt-766908",
)
parser.add_argument("-o", "--output", required=True, help="The weight file to dump all the weights to.")
# nargs="?" with const=True keeps "-1 True" / "-1 False" working and also
# accepts a bare "-1" as "enable v1 output".
parser.add_argument(
    "-1",
    "--wtsv1",
    required=False,
    nargs="?",
    const=True,
    default=False,
    type=_parse_bool_flag,
    help="Dump the weights in the wts v1.",
)

opt = parser.parse_args()

# Tell the user which on-disk layout is about to be produced.
# NOTE(review): several of these strings end in a dangling space; any
# placeholder text that originally followed may have been lost -- confirm
# against the original file before relying on the wording.
if opt.wtsv1:
    print("Outputting the trained weights in TensorRT's wts v1 format. This format is documented as:")
    print("Line 0: ")
    print("Line 1-Num: [buffer name] [buffer type] [buffer size] ")
else:
    print("Outputting the trained weights in TensorRT's wts v2 format. This format is documented as:")
    print("Line 0: ")
    print("Line 1-Num: [buffer name] [buffer type] [(buffer shape{e.g. (1, 2, 3)}] ")

# Checkpoint basename to read from and output basename to write to.
inputbase = opt.model
outputbase = opt.output

# NOTE(review): the visible text is cut off at this point, in the middle of
# the next definition. The fragment is preserved verbatim below as a comment
# because its remainder cannot be seen from here; restore the complete
# definition from the original file.
# def float_to_hex(f): return hex(struct.unpack("