2020. 3. 5. 19:47ㆍResearch
PyTorch IR
Reference Link: https://github.com/pytorch/pytorch/wiki/PyTorch-IR
PyTorch는 SSA 기반의 IR을 사용한다.
In compiler design, static single assignment form (often abbreviated as SSA form or simply SSA) is a property of an intermediate representation (IR), which requires that each variable is assigned exactly once, and every variable is defined before it is used.
- Graph는 일반적으로 프로그램 representation에 있어서 프로그램 전체를 감싸는 가장 바깥의 컨테이너에 해당한다.
- Block은 함수에 해당하는 컴포넌트이다. input들과 output들에 대한 리스트, 그리고 Node들에 대한 topological order list를 포함한다. 모든 Graph는 마치 main 함수와 같이 가장 top-level의 Block을 가지고 있다.
- Node는 function call을 표현하는 컴포넌트이다. (여러 인자와 리턴 값을 포함한다.)
- 프로그램의 각 중간 값들(intermediate values)은 Value로써 표현된다.
- root Block들은 input 값들에 대한 리스트를 가지고 있고, 모든 Node는 그것들을 입력으로 받아서 결과를 반환한다. 각 Value는 그에 해당하는 Type을 가지고 있다.
Ownership model
위에서 설명한 구조에서 각 구성 요소들은 보통 pointer를 통해 서로에게 전달된다. 모든 Node들과 Value들은 Graph에 의해서 소유 되며(owned), 서로 다른 Graph 간에는 공유될 수 없다. Type은 Value에 의해서만 참조될 수 있고, 그것들은 shared_ptr로 wrapping 되어있다.
각 Block은 doubly-linked list인 node_list_를 통해서 순서가 표현된다. 이 순서는 실제로 JIT interpreter를 통해서 bytecode로 컴파일 되었을 때 각 operation들이 실행되는 순서에 해당한다. 프로그래머는 반드시 자신이 만든 Node가 node_list_의 어딘가에 나타나도록 하여야하고, 그 리스트는 유효한 topological order를 가져야 한다.
Examples
def f(a, b):
c = a + b
d = c * c
e = torch.tanh(d * c)
return d + (e + e)
위와 같은 파이썬 프로그램이 있을 때, 만약 이를 IR로 변환을 한다면 다음과 같은 Graph 형태로 표현된다.
실제로 Python 프로그램을 IR로 변환하는 작업이 어떻게 이루어지는지는 뒤에서 설명한다.
graph(%0 : Double(2)
%1 : Double(2)) {
%2 : int = prim::Constant[value=1]() # default 값
%3 : Double(2) = aten::add(%0, %1, %2)
%4 : Double(2) = aten::mul(%3, %3)
%5 : Double(2) = aten::mul(%4, %3)
%6 : Double(2) = aten::tanh(%5)
%7 : Double(2) = aten::add(%6, %6, %2)
%8 : Double(2) = aten::add(%5, %7, %2)
return (%8);
}
aten::add가 표현하는 정확한 계산식은 x+ay이기 때문에 3개의 인자를 받는다. 그런데 만약 a의 값을 입력하지 않았다면 default 값으로 1이 들어간다. %2 : int = prim::Constant[value=1]()은 그 default 값을 나타낸다.
위의 graph 구성 요소를 살펴보면 다음과 같다.
- graph는 Graph에 해당한다.
- %x는 Value에 해당한다.
- %x : Double(2)는 Value %x에 대한 type annotation이다.
- %x : T1, %y : T2 = namespace::name(%z, %w)는 Node에 해당한다. namespace::name은 해당 Node의 operator를 나타내고, %z와 %w를 입력으로 받아서 T1 타입의 %x와 T2 타입의 %y를 결과값으로 리턴한다.
노드들은 attribute을 통해서 추가적인 정보를 표현한다. int = prim::Constant[value=1]()를 보면 해당 노드가 호출되었을 때 value라는 attribute을 리턴하는 것을 알 수 있다. 다음은 attach 가능한 타입들의 리스트이다.
- int64_t
- double
- Tensor
- Graph (operator fusion을 위해서 subgraph를 slicing 하는 경우 등에 사용)
- std::string
- lists of above (단, nesting 되면 안 됨)
Supported Types
- int, float, bool - 스칼라 값
- Dynamic - 어떠한 static info도 없는 tensor
- Float(*, *) - partial static info가 있는 tensor. 다음의 정보를 포함해야한다.
- 타입 정보(Byte, Short, Int, Long, Half, Float, Double)
- dimension들의 개수
- data가 할당된 device
- grad를 사용할지 말지에 대한 boolean flag
- Float(1, 3, 224, 224) - 모든 static info가 주어진 tensor (거의 사용되지 않음)
Control Flow
Block은 Node 사이에 끼어넣어질 수 있으며, 이는 control flow를 구현하는데 활용된다. 이 경우엔 마치 lambda expression이 인자로 전달되었다고 보면된다.
현재 사용되는 control flow combinator는 두 가지가 있다.
1. prim::If
%y_1, ..., %y_r = prim::If(%condition)
block0() { # TRUE BRANCH, never takes arguments, has to return r outputs
%t_1, ..., %t_k = some::node(%a_value_from_outer_block)
-> (%t_1, ..., %t_r)
}
block1() { # FALSE BRANCH, never takes arguments, has to return r outputs
%f_1, ..., %f_m = some::node(%a_value_from_outer_block)
-> (%f_1, ..., %f_r)
}
- 조건문을 구현하는데 사용된다.
- %y_1, ..., %y_r에 해당하는 값들은 런타임에서의 %condition에 따라서 %t_1, ..., %t_k 또는 %f_1, ..., %f_m 중 하나가 된다.
예시를 하나 살펴보면 다음과 같다.
def f(a, b, c):
d = a + b
if c:
e = d + d
else:
e = b + d
return e
graph(%a : Dynamic
%b : Dynamic
%c : Dynamic) {
%2 : int = prim::Constant[value=1]()
%3 : Dynamic = aten::add(%a, %b, %2)
%5 : Dynamic = prim::If(%c)
block0() {
%6 : int = prim::Constant[value=1]()
%7 : Dynamic = aten::add(%3, %3, %6)
-> (%7)
}
block1() {
%8 : int = prim::Constant[value=1]()
%9 : Dynamic = aten::add(%b, %3, %8)
-> (%9)
}
return (%5);
}
2. prim::Loop
while과 for를 커버한다.
%y_1, ..., %y_r = prim::Loop(%max_trip_count, %initial_condition, %x_1, ..., %x_r)
block0(%i, %a_1, ..., %a_r) {
%b_1, ..., %b_m = some::node(%a_value_from_outer_block, %a_1)
%iter_condition = some::other_node(%a_2)
-> (%iter_condition, %b_1, ..., %b_r)
}
이를 Python-like pseudo code로 나타내면 다음과 같다.
y_1, ..., y_r = x_1, ..., x_r
condition = initial_condition
i = 0while condition and i < max_trip_count:
a_1, ..., a_r = y_1, ..., y_r
############################################################# Actual body of the loop
b_1, ..., b_m = some::node(a_value_from_outside_of_the_loop, a_1)
iter_condition = some::node(a_2)
############################################################
y_1, ..., y_r = b_1, ..., b_r
condition = iter_condition
i += 1
Note that translations of for loops simply pass in a constant true for both %initial_condition and %iter_condition, while for while loops %max_trip_count is set to the largest value of int64_t, and %i is unused. Those patterns are recognized by our interpreter and optimized accordingly (e.g. while loops don't maintain the loop counter).
보다 구체적인 실제 예시를 들자면, 다음과 같은 프로그램이 주어졌을 때
def f(x):
z = x
for i in range(x.size(0)):
z = z * z
return z
다음과 같은 IR로 변환 가능하다.
graph(%z.1 : Dynamic) {
%3 : bool = prim::Constant[value=1]()
%1 : int = prim::Constant[value=0]()
%2 : int = aten::size(%z.1, %1)
%z : Dynamic = prim::Loop(%2, %3, %z.1)
block0(%i : int, %5 : Dynamic) {
%z.2 : Dynamic = aten::mul(%5, %5)
-> (%3, %z.2)
}
return (%z);
}
Function Calls
Graph에서 다른 Graph를 호출하는 경우, 모든 함수들은 caller에게 callee의 body가 inlining 되듯이 표현된다. 재귀 호출은 아직 지원되지 않는다.
Node Overloading
PyTorch IR은 함수 오버로딩을 지원한다. 예를 들어, aten::add의 경우에 다음과 같이 몇 가지 함수들이 오버로딩 되어 있음을 알 수 있다.
- aten::add(Tensor self, Tensor other) -> Tensor
- aten::add(Tensor self, Scalar other) -> Tensor
- aten::add(int self, int other) -> int
- aten::add(float self, float other) -> float
All of the strings above can actually be parsed into FunctionSchema objects, which hold all this infomation in a machine-readable way. A Node can be queried for its schema using the schema() method (it will check the argument types, and will try to match one of the options for its kind()).
Note that the chosen overload is not shown in any way in the textual output. If you're unsure which function does a node resolve to, you might need to check the type annotations of its input values.
JIT Interpreter Bytecode
Graph는 조작하기 쉬운 데이터 표현을 위해서 사용되는 자료구조일 뿐 이를 직접적으로 bytecode로 해석(변환)하기 위한 의도로 사용되는 건 아니다. Graph는 우선 std::function 리스트와 레지스터 사용과 관련된 metadata를 가지고 있는 Code 객체로 변환된다. 그 이후에 Code 객체는 InterpreterState 객체를 통해서 실행된다.
JIT interpreter는 스택 기반의 VM(지역 변수를 저장하기 위한 레지스터들을 가지고 있음)으로, 인자를 모든 명령들로 넘기기 위한 단일 스택을 가지고 있다. 앞서 언급한 Code 객체의 metadata는 레지스터와 스택 간에 어떻게 load/store를 할지를 구성하기 위한 정보를 담고 있다.
Stack은 std::vector<IValue>로 표현되는데, IValue는 JIT이 받을 수 있는 모든 종류의 value를 표현하는 type이다.
Important Files
모든 중요한 파일들은 torch/csrc/jit에 있다.
- ir.h - implementation of Graph, Block, Node, Value
- type.h - implementation of Type
- interpreter.cpp - JIT interpreter (Code, InterpreterImpl)
- ivalue.h - implementation of IValue
- stack.h - implementation of Stack
- graph_executor.cpp - a runner for graphs that will specialize them to different argument configurations
- tracer.h - tracer for PyTorch code (generates straight line Graphs from any code)
- operator.cpp - infrastructure for overload resolution and custom operator registration
- script/ - compiler from TorchScript (think Python AST) to Graphs
- passes/*.cpp - optimization passes
- fusers/**/* - CUDA and CPU codegens for pointwise subgraphs
- autodiff.cpp - symbolic AD for Graphs
- symbolic_variable.h - a helper to make Graph building easier
IR Construction
PyTorch IR을 만들기 위한 세 가지 방법이 존재한다.
Tracing
This means that you run arbitrary Python/C++ code using PyTorch operators, and we record a straight line trace (control flow gets unrolled and inlined). Good for simple models, bad if you really have data dependent control flow (and it's not only used for metaprogramming). The relevant entry point for this is torch.jit.trace.
TorchScript
This method implements a simple Python-like language (it's in fact a subset of Python that conforms to its semantics) and a compiler from it to the IR. Great if you need to retain control flow, but a bit annoying if you need more advanced language features.
Manual construction
This doesn't really happen anywhere outside of the optimization passes, and is probably not recommended. SymbolicVariable is a helper that overloads many Tensor operators and makes them insert Nodes into its Graph instead of doing actual compute.
Graph Manipulation
PyTorch에서는 편리하게 IR을 만들 수 있도록 여러 API를 제공한다.
As mentioend previously, the IR is really optimized to be easy to manipulate and change. TO help with that there are numerous methods on Graphs, Nodes and Values, and we maintain a lot of extra metadata that allows to quickly check certain conditions (e.g. looking up all use sites of a single Value takes constant time, because we have this information cached). Here's a list of the most relevant methods you can find (think of ArrayRef as of an std::vector, Symbol is an interned string)
Graph
- ArrayRef<Value*> inputs()
- ArrayRef<Value*> outputs()
- graph_node_list nodes()
- Value* addInput()
- Value* insertInput(size_t offset)
- Value* eraseInput(size_t offset)
- size_t registerOutput(Value *output);
- void eraseOutput(size_t offset)
- Value* insert(Symbol opname, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs = {});
- This is the most convenient method of adding more nodes to the Graph. An example call looks like this:
- graph->insert(aten::add, {some_value, 3}) (note how C++ values will get inserted into the Graph as constants automatically).
- Block* block() (returns the top-level block)
- void lint() (throws if the graph violates invariants like having the node list being a valid topological order of data dependencies)
- void dump() (prints the graph to stdout -- useful for debugging)
Value
- const TypePtr& type()
- Node* node() (producer of this Value)
- size_t offset() (offset into the output list of the node())
- void replaceAllUsesWith(Value* other)
- const use_list& uses() (use_list is std::vector<Use>, where Use is a struct containing a Node* and offset into its input list)
- Graph* owningGraph()
Node
- Symbol kind()
- ArrayRef<Value*> inputs()
- ArrayRef<Value*> outputs()
- Value* namedInput(Symbol name) (lets you look up inputs by their names instead of depdending on the positional order)
- bool is_constant(Symbol name) (return true if input name is a constant)
- optional<IValue> get(Symbol name) (if is_constant(name), returns an IValue containing its value)
- optional<T> get(Symbol name) (same as above but returns T instead of IValue)
- Value* addInput(Value* value)
- Value* insertInput(size_t offset, Value* value)
- Value* replaceInput(size_t offset, Value* newValue)
- Value* replaceInputWith(Value* from, Value* to)
- Value* addOutput()
- Value* insertOutput(size_t offset)
- void eraseOutput(size_t offset)
- ArrayRef<Block*> blocks()
- Block* addBlock()
- void eraseBlock(size_t offset)
- void destroy() (This is dangerous! All references to Values produced by this node, and to the node itself become invalid!)
- void dump() (Debug print to stdout)
- Block* owningBlock()
- Graph* owningGraph()