PyTorch를 ONNX로 export하기

2020. 2. 8. 18:24Develop

ONNX란?

ONNX(Open Neural Network Exchange)는 그 이름에서 살펴 볼 수 있듯이, Tensorflow, PyTorch와 같은 서로 다른 DNN 프레임워크 환경에서 만들어진 모델들을 서로 호환되게 사용할 수 있도록 만들어진 공유 플랫폼이다. Tensorflow에서 어떤 모델을 만들고 이를 ONNX 그래프로 export를 하면, 이후에 PyTorch와 같은 다른 프레임워크에서도 그 모델을 import 하여 사용할 수 있다. ONNX의 장점을 정리하자면 다음 두 가지를 뽑을 수 있다.

 

1. Framework Interoperability

위에서 언급했다시피 특정 환경에서 생성된 모델을 다른 환경으로 import하여 자유롭게 사용을 할 수 있다는 것은 ONNX의 최대 강점이다. 예컨대, Tensorflow에서 빠르게 모델을 학습 시킨 뒤에 이를 모바일로 옮겨서 사용을 하는 등 여러가지 방식으로 활용 가능하다.

 

2. Shared Optimization

HW vendor(가속기와 같은 HW 제조업체)의 입장에서 ONNX와 같은 프레임워크 간 공유되는 포맷이 존재하면, 하드웨어 설계시 ONNX representation을 기준으로 최적화를 하면 되기 때문에 효율적이다. (하지만 현재 상황을 보았을 때 ONNX가 얼마나 범용적으로 활용될 수 있을지는 잘 모르겠다...)

 

아주 쉽게 생각하자면, 마치 JSON 포맷이 정보 표현을 위해서 여러 개발자들 사이에서 합의되어 사용되듯, ONNX라는 합의된 DNN 모델 포맷이 존재한다고 생각하면 된다. ONNX 사용과 관련하여 보다 자세한 튜토리얼이 필요하다면 다음 페이지를 참고하면 된다. (링크: https://github.com/onnx/tutorials)

 

 

onnx/tutorials

Tutorials for creating and using ONNX models. Contribute to onnx/tutorials development by creating an account on GitHub.

github.com

 

 

PyTorch 모델을 ONNX로 export 하기

공식 문서 보다 좋은 레퍼런스는 없다. 다소 길 수 있지만 자세하고 정확한 정보는 공식 문서를 읽는 습관을 들이자.

 

 

위의 그림은 PyTorch 모델을 ONNX 그래프로 export하는 전체 과정을 도식화한 것이다.

 

  1. PyTorch 모델과 example input을 인자로 하여 torch.onnx.export 함수를 호출하면, PyTorch의 JIT 컴파일러인 TorchScript를 통해서 trace 혹은 script를 생성한다. (Trace와 Script는 그 생성 방식과 representation에 차이가 있는데 밑에서 좀 더 설명을 하도록 하겠다.) Trace나 script는 PyTorch의 nn.Module을 상속하는 모델의 forward 함수에서 실행되는 코드들에 대한 IR (Intermediate Representation)을 담고 있다. 쉽게 설명하면 forward propagation 시에 호출되는 함수 및 연산들에 대한 최적화된 그래프가 만들어진다.
  2. 생성된 trace / script (PyTorch IR)는 ONNX exporter를 통해서 ONNX IR로 변환되고 여기에서 한 번더 graph optimization이 이루어진다.
  3. 최종적으로 생성된 ONNX 그래프는 .onnx 포맷으로 저장된다.

 

TorchScript란?
PyTorch의 just-in-time 컴파일러인 TorchScript는 기존의 imperative한 PyTorch의 실행 방식 대신, 모델이나 함수의 소스코드를 TorchScript compiler를 통해 TorchScript 코드로 컴파일하는 기능을 제공한다. 이를 통해 Tensorflow의 symblolic graph execution 방식과 같이 여러 optimization을 적용할 수 있고, serialized 된 모델을 Python dependency가 없는 다른 환경에서도 활용할 수 있는 이점이 있다.

(TorchScript와 관련한 포스팅은 이후에 시간이 된다면...)

 

 

Tracing vs Scripting

ONNX exporter는 trace-base와 script-base로 나뉘는데 이 둘 모두 IR graph를 만든다는 점에서 동일하지만 정보를 기록하는 방식에 차이가 있다.

 

1. Trace

Pytorch에서 model(input) 과 같이 nn.Module을 상속하는 모델에게 input을 넘겨주면 해당 모델 클래스의 forward 함수가 실행되면서 forward propagation 한 번이 수행된다. 이 때 forward 함수 내부에서는 conv2d, batch_norm, dropout과 같은 여러 Tensor 연산들을 호출할 수도 있고, 유저가 정의한 함수 혹은 기본 python 코드들이 실행될 수도 있다. 이렇게 forward를 한 번 수행하는 동안 execution path에 존재했던 모든 연산들은 IR로 기록이 된다.

 

여기에서 문제가 될 수 있는 부분은 forward 함수 내부에 dynamic control flow가 존재한다면, trace를 생성하기 위해 한 번 forward가 호출되었을 때 거쳐가지 않은 control path는 추적이 되지 않는다는 점이다. 만약 model에 if문이나 loop와 같은 control flow가 존재한다면 loop unrolling을 통해서 branch가 static하게 고정된다.

 

또한 trace 시엔 forward 함수 호출을 위한 example input이 필요한데, 이 input의 shape에 따라서 graph가 정적으로 고정이 되어 버리기 때문에 이후 tracing된 모델을 활용해서 training, inferencing을 한다면 example input과 정확히 동일한 형태의 input을 넣어야만 한다.

 

2. Script

앞에서 설명한  Trace의 단점은 Python의 dynamic feature를 살리지 못하는 것이라 하였다. Script를 사용하면 이 문제를 해결할 수 있다! 마치 C, C++, Java와 같은 언어들에서 전체 code를 컴파일하여 사용을 하듯, scripting 시에는 forward propagation 시에 실행될 전체 코드에 대해서 컴파일을 하고 TorchScript code인 ScriptModule 인스턴스를 생성한다. 전체 코드를 보고 컴파일을 하기 때문에 당연히 dynamic control flow를 살릴 수 있고, Trace의 경우처럼 example input이 필요하지도 않다.

 

 

Trace와 Script의 장단점을 정리하면 다음과 같다.

 

  Tracing Scripting
장점 Script와 비교했을 때 type 추정으로 인한 문제, Python primitive와의 호환성 문제가 적다. C, Java와 같이 전체 코드를 보고 컴파일을 하기 때문에 dynamic control flow를 살릴 수 있다.
단점 forward를 한 번 수행하는 동안 거쳐간 execution path에 대해서 그래프가 statically fix 되기 때문에 dynamic control flow를 살리지 못한다.  지원하지 않는 파이썬 코드들이 상당히 많고, type 추정, 중간에 attribute이 변하는 경우 등에 문제가 생긴다.

 

 

Usage Example

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, do_constant_folding=True, example_outputs=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True)

 

torch.onnx.export 함수는 PyTorch 모델을 .onnx 포맷으로 export 해주는 핵심 함수이다. 다음 예시에서는 torchvision의 AlexNet을 import 한 뒤, 이를 .onnx 포맷으로 export하고 있다. torch.onnx.export에서는 기본적으로 scripting이 아닌 tracing을 사용하기 때문에 example input을 넣어줘야한다.

 

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = torchvision.models.alexnet(pretrained=True).cuda()
dummy_output = model(dummy_input)

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, example_outputs=dummy_output)

 

만약에 자신이 export 하고 싶은 모델이 dynamic control flow를 포함하고 있고, 이를 살리고 싶다면 해당 모듈의 해당 함수에 @torch.jit.script 데코레이터(decorator)를 추가해주면 된다. 그러면 해당 함수에 대해서는 scripting을 하고, 나머지 부분에 대해서는 tracing을 하여 IR을 생성한다. (이와 같이 trace와 script를 섞어서 사용하는 방식은 TorchScript에서도 동일하게 적용 가능하다.)

 

# tracing과 scripting을 함께 사용하기

# loop 함수에 대해서는 for문이 unrolling 되지 않고 dynamic feature가 보존된다.
@torch.jit.script
def loop(x, y):
    for i in range(int(y)):
        x = x + i
    return x

class LoopModel2(torch.nn.Module):
    def forward(self, x, y):
        return loop(x, y)

model = LoopModel2()
dummy_input = torch.ones(2, 3, dtype=torch.long)
loop_count = torch.tensor(5, dtype=torch.long)
torch.onnx.export(model, (dummy_input, loop_count), 'loop.onnx', verbose=True,
                  input_names=['input_data', 'loop_range'])

 

위의 코드를 통해서 생성된 ONNX graph는 다음과 같다.

 

graph(%input_data : Long(2, 3),
      %loop_range : Long()):
  %2 : Long() = onnx::Constant[value={1}](), scope: LoopModel2/loop
  %3 : Tensor = onnx::Cast[to=9](%2)
  %4 : Long(2, 3) = onnx::Loop(%loop_range, %3, %input_data), scope: LoopModel2/loop # custom_loop.py:240:5
    block0(%i.1 : Long(), %cond : bool, %x.6 : Long(2, 3)):
      %8 : Long(2, 3) = onnx::Add(%x.6, %i.1), scope: LoopModel2/loop # custom_loop.py:241:13
      %9 : Tensor = onnx::Cast[to=9](%2)
      -> (%9, %8)
  return (%4)

 

만약 python 에서 export 된 onnx 모델을 import 하여 사용하고 싶다면 onnx.load 함수를 이용하면 된다.

 

import onnx

# ONNX model을 load
model = onnx.load("alexnet.onnx")

# IR이 제대로 구성되었는지 체크
onnx.checker.check_model(model)

# 위와 같이 읽어서 알아볼 수 있는 형태의 그래프 포멧으로 출력 
onnx.helper.printable_graph(model.graph)

 

 

Limitations

  • PyTorch의 JIT compiler가 아직 완벽하지 않다보니 파이썬으로 구현한 모델에 대해서 완벽하게 support 하지 못한다. 현재까지는 tuple, list, Variable 만이 input / output으로 지원되는 상황이며, dictionary나 string은 일부만 지원되는 상황이다. (dynamic loop up 불가능)
  • 또한 PyTorch와 ONNX의 backend 구현에 차이가 있다보니 모델 구조에 따라서 학습 성능에 문제가 있을 수 있다.