AI | ML/AI 개발 | CUDA

[Python] TFLite to ONNX 변환

깜태 2021. 12. 15. 09:42
728x90

1. pip install tf2onnx

2. python -m tf2onnx.convert --opset 13 --tflite {tflite_file} --output {onnx_file}

3. pip install onnxruntime ( gpu 사용 시, onnxruntime-gpu)

4. onnx 모델 실행

 

결과 예시

 

Onnx 모델 실행 코드 (Python)

import onnxruntime
import numpy as np
import cv2

def load_onnx_model(path, use_gpu):
    ort_session = onnxruntime.InferenceSession(path)
    if use_gpu:
        ort_session.set_providers(['CUDAExecutionProvider'])
    else:
        ort_session.set_providers(['CPUExecutionProvider'])
    return ort_session


def load_img(path):
    img = cv2.imread(path)
    img = cv2.resize(img, (64, 64))
    img = np.transpose(img, axes=(1, 0, 2))
    img = np.expand_dims(img, axis=0).astype(np.float32)/255.0
    return img


if __name__ == "__main__":
    test_img = load_img(path)
    onnx_model = load_onnx_model(onnx_path, use_gpu=True)
    # Output 결과 node 이름 확인이 필요한 경우
    # for output in onnx_model.get_outputs():
    #    print(output.name)
    # 특정 Output만 보고 싶은 경우 예시
    # ort_outs = onnx_model.run(['output_iris'], {'input_1': img})

    ort_outs = onnx_model.run(None, {'input_1': img})
    print(ort_outs)

 

 

 

참고 : https://github.com/onnx/tensorflow-onnx#getting-started

728x90