PyTorch JIT Source Code Read Note

These are my notes from reading PyTorch’s JIT source code. We begin with torch.jit.script and torch.jit.script_method to find the frontend, which compiles Python code into PyTorch’s tree views, and the backend, which compiles tree views into a graph. We then read the structure of the internal representation of PyTorch’s graph. Finally, we go to the graph executor to see how the computation graph is further compiled into instructions, and how the actions of these instructions are defined and executed.

PyTorch is under very active development, so the source code at the time you read this article will not be the same as when I wrote it. To get the same source code as in this article, run the following command:

git checkout 76ab26cc3eff1d7ba822d8db93723f5c9598eead

Starting point: script and script_method

In PyTorch, a Python function can be just-in-time compiled by doing something like:

@torch.jit.script
def f(x):
    return x + x

Here torch.jit.script is a decorator applied to your function f. If you are unfamiliar with Python decorators, please refer to this article.

It is also possible to create a module whose methods are JIT compiled by doing something like:

class MyModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def f(self, x):
        return x * x

    @torch.jit.script_method
    def forward(self, x):
        return x + self.f(x)

Scripting a function

We will start by looking at torch.jit.script. To read torch.jit.script, we begin by looking at torch/jit/__init__.py. To quickly locate script, search def script in your editor, and you will immediately find it:

def script(fn, optimize=True, _frames_up=0):
    if not _enabled:
        return fn
    rcb = createResolutionCallback(_frames_up + 1)
    ast = get_jit_ast(fn, is_method=False)
    graph = _jit_script_compile(ast, rcb)
    mod = ScriptModule()
    mod._create_method_from_graph('forward', graph)
    # TODO: refactor everything so we're not 1) creating a ScriptModule
    # 2) Throwing everything away except for the graph 3) Creating a new
    # ScriptModule and dumping that graph in 4) Re-populating the schema
    # because it was lost doing the previous
    mod.__getattr__('forward').forward_schema(ast, False)
    # Forward docstrings
    mod.__doc__ = fn.__doc__
    return mod

First, createResolutionCallback is called. This function is defined in the same file. The source code tells us that it simply returns a function that maps names to their values in the scope of the caller of script. This callback is used later in C++ to read values from Python.
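To make the idea concrete, here is a minimal sketch of how such a callback could be built with the standard inspect module. It is an illustration under assumptions, not the PyTorch source: the name make_resolution_callback and the lookup order (locals, then globals, then None) are hypothetical choices of mine.

```python
import inspect


def make_resolution_callback(frames_up=0):
    """Return a function that looks up names in a caller's scope.

    Hypothetical helper for illustration only. frames_up=0 means the
    direct caller of make_resolution_callback.
    """
    frame = inspect.currentframe()
    # Walk up from this function's own frame to the requested caller.
    for _ in range(frames_up + 1):
        frame = frame.f_back
    f_locals = frame.f_locals
    f_globals = frame.f_globals

    def resolve(name):
        if name in f_locals:
            return f_locals[name]
        if name in f_globals:
            return f_globals[name]
        return None  # unknown names resolve to None in this sketch

    return resolve


def demo():
    scale = 3  # a local that only exists inside demo()
    return make_resolution_callback()


rcb = demo()
print(rcb("scale"))           # the local captured from demo()
print(rcb("no_such_name_x"))  # None
```

The extra frames_up argument plays the same role as the _frames_up + 1 seen in script: each decorator layer between the user's code and the callback factory adds one frame to skip.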

The get_jit_ast on the next line is imported from torch.jit.frontend. From the name of this function and its owning module, we can tell that this is the frontend of PyTorch’s JIT compiler, which compiles the source code of the scripted function into an abstract syntax tree (AST).

The next line uses _jit_script_compile to compile the AST obtained in the previous step into a computation graph. Searching for _jit_script_compile, we find torch._C._jit_script_compile, which tells us that _jit_script_compile is implemented in C++.

The next few lines create a ScriptModule whose forward method is the compiled graph.

Scripting a module

We start by looking at torch.jit.script_method:

ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))

def script_method(fn):
    if not _enabled:
        return fn
    # ...
    rcb = createResolutionCallback(frames_up=2)
    ast = get_jit_ast(fn, is_method=True)
    return ScriptMethodStub(rcb, ast, fn)

This is similar to script, but instead of creating and returning a module with the compiled function as its forward method, it simply uses a named tuple to store the resolution callback, the AST, and the original function.

This cannot be the end of the story, because a named tuple cannot be called to do the computation. So there must be some magic somewhere that replaces the named tuples with something that actually does the job. For readers familiar with Python’s class metaprogramming, it is not hard to imagine how the magic happens. For those who are not, I recommend the book Fluent Python. I will explain a few of the details here:

In Python, everything is an object, and classes are no exception. Classes are instances of a special kind of class called a metaclass. At import time, when Python sees the following code:

class MyModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def f(self, x):
        return x * x

    @torch.jit.script_method
    def forward(self, x):
        return x + self.f(x)

It will execute the body of the class definition. That is, it compiles return x * x, creates a function object with that compiled code, passes this function object to torch.jit.script_method, and binds the returned named tuple to f. It then does the same for forward. After that, Python has a map of attribute names to values for the class to be constructed. This map is then passed to the metaclass of MyModule to actually construct MyModule as an instance of that metaclass.
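The mechanism can be tried out in pure Python without PyTorch. The following is a toy sketch under assumptions: Stub, stub_method, StubMeta, and ToyModule are hypothetical names standing in for the real classes, and the "compilation" step is replaced by simply binding the original function to the instance.

```python
from collections import namedtuple

# Hypothetical stand-in for ScriptMethodStub: it only remembers the function.
Stub = namedtuple("Stub", ["original_method"])


def stub_method(fn):
    return Stub(fn)


class StubMeta(type):
    def __init__(cls, name, bases, attrs):
        # attrs is the name -> value map built by running the class body.
        stubs = {k: v for k, v in attrs.items() if isinstance(v, Stub)}
        for k in stubs:
            delattr(cls, k)  # remove the uncallable named tuples
        original_init = cls.__init__

        def init_then_register(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            methods = self.__dict__.setdefault("_methods", {})
            for k, stub in stubs.items():
                # Bind the original function to this instance.
                methods[k] = stub.original_method.__get__(self)

        cls.__init__ = init_then_register
        super().__init__(name, bases, attrs)


class ToyModule(metaclass=StubMeta):
    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails.
        methods = self.__dict__.get("_methods", {})
        if attr in methods:
            return methods[attr]
        raise AttributeError(attr)


class MyToyModule(ToyModule):
    @stub_method
    def f(self, x):
        return x * x

    @stub_method
    def forward(self, x):
        return x + self.f(x)


print(MyToyModule().forward(3))  # prints 12
```

Note that the stubs are removed from the class itself, so the methods only come into existence per instance, after the wrapped __init__ has run.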

To see in detail how this is achieved in PyTorch, we should take a look at ScriptMeta and ScriptModule. These two classes are lengthy, so I will not copy their full code here, but instead use pseudocode to show what is done:

class ScriptMeta(type(torch._C.ScriptModule)):
    def __init__(cls, name, bases, attrs):
        # delete all ScriptMethodStub
        @functools.wraps(original_init)
        def init_then_register(self, *args, **kwargs):
            # invoke the original __init__
            self._create_methods(defs, rcbs)
        cls.__init__ = init_then_register
        return super(ScriptMeta, cls).__init__(name, bases, attrs)

class ScriptModule(with_metaclass(ScriptMeta, torch._C.ScriptModule, Module)):
    # ......
    def __getattr__(self, attr):
        if self._has_method(attr):
            # ......
            return self._get_method(attr)
        # .....
        return Module.__getattr__(self, attr)

In the above pseudocode, _create_methods, _has_method, and _get_method are inherited from torch._C.ScriptModule. So a natural question to ask is then: what does torch._C.ScriptModule do? Before answering this question, let’s first take a look at the frontend.

The frontend

A good starting point of the frontend is the get_jit_ast we just saw. This function is defined at torch/jit/frontend.py. The code is:

def get_jit_ast(fn, is_method):
    source = dedent(inspect.getsource(fn))
    py_ast = ast.parse(source)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError("expected a single top-level function")
    type_line = torch.jit.annotations.get_type_line(source)
    ctx = SourceContext(source, _uses_true_division(fn))
    return build_def(ctx, py_ast.body[0], type_line, is_method)

The first four lines of the function body just use standard Python tools (dedent, inspect, and ast) to construct the Python AST, and check that the thing being compiled is “a single top-level function”.

The following line type_line = torch.jit.annotations.get_type_line(source) is interesting. After looking at torch/jit/annotations.py, we can see that PyTorch’s JIT allows the user to specify the type of arguments and return value by writing something like # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor].
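These first steps are easy to reproduce with the standard library alone. The sketch below is my own simplification: it parses a source string directly instead of calling inspect.getsource, and find_type_line is a hypothetical, much cruder substitute for torch.jit.annotations.get_type_line that only looks for a line starting with # type:.

```python
import ast
from textwrap import dedent

SOURCE = '''
    def f(x, y):
        # type: (Tensor, Tensor) -> Tensor
        return x + y
'''


def parse_single_function(source):
    """Parse source and insist on exactly one top-level function."""
    py_ast = ast.parse(dedent(source))
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError("expected a single top-level function")
    return py_ast.body[0]


def find_type_line(source):
    """Hypothetical helper: return the first '# type:' comment, if any."""
    for line in source.splitlines():
        stripped = line.strip()
        if stripped.startswith("# type:"):
            return stripped
    return None


fn_def = parse_single_function(SOURCE)
print(fn_def.name, [a.arg for a in fn_def.args.args])
print(find_type_line(SOURCE))
```

The type comment has to be recovered from the raw source text because comments do not survive ast.parse, which is presumably why get_jit_ast passes source rather than py_ast to get_type_line.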

In the next line, ctx = SourceContext(source, _uses_true_division(fn)), the function _uses_true_division is defined in the same file to handle the different behavior of / in Python 2 with or without from __future__ import division (see PEP 238 for the difference). SourceContext is also defined in the same file. It is a subclass of SourceRangeFactory with an additional field that stores whether division is true division. SourceRangeFactory is imported by from torch._C._jit_tree_views import *. After reading its definition in torch/csrc/jit/script/python_tree_views.cpp, we can see that this is basically a class designed to store ranges of source code, e.g. where in the source code a token is located.

The core is the build_def in the last line, so we move on:

def build_def(ctx, py_def, type_line, is_method):
    returns = []
    ret_body = []
    body = py_def.body
    r = ctx.make_range(py_def.lineno, py_def.col_offset,
                       py_def.col_offset + len("def"))
    param_list = build_param_list(ctx, py_def.args)
    return_type = None
    if getattr(py_def, 'returns', None) is not None:
        return_type = build_expr(ctx, py_def.returns)
    decl = Decl(r, param_list, return_type)
    if type_line is not None:
        type_comment_decl = torch._C.parse_type_comment(type_line)
        decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
    return Def(Ident(r, py_def.name),
               decl,
               build_stmts(ctx, body))

Reading through this, we can see that it basically converts Python’s AST into the internal representation. Names like Decl, Def, and Ident are all imported by from torch._C._jit_tree_views import *. In the last line, we can see that the function body is constructed by build_stmts, so let’s go on to read build_stmts:

def build_stmts(ctx, stmts):
    stmts = [build_stmt(ctx, s) for s in stmts]
    return list(filter(None, stmts))

This is a very simple function: call build_stmt for each item and filter out those not needed. But what is build_stmt? It is defined as: build_stmt = StmtBuilder(). The definition of StmtBuilder looks like:

class StmtBuilder(Builder):
    # ...
    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        if value.__class__.__name__ == 'Str':
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt([build_expr(ctx, value)])
    # ...
    @staticmethod
    def get_assign_lhs_expr(ctx, expr):
        # ...
    # ...
    @staticmethod
    def build_Assign(ctx, stmt):
        #...
    # ......

We can see that this is a class with many static methods that define what to do for different types of Python AST nodes. I will not go deep into how each type is handled, since at this point readers should be able to work out the details of each node type by themselves. So we will stop our reading of the frontend right here.
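As a rough illustration of this dispatch pattern, here is a self-contained toy version. The Builder base class is not shown in this article, so its __call__ below is my assumption of how a callable builder could route each node to a build_<NodeType> method; ToyStmtBuilder and the tuples it returns are hypothetical placeholders for the real tree views.

```python
import ast


class Builder:
    def __call__(self, ctx, node):
        # Dispatch on the class name of the Python AST node.
        method = getattr(self, "build_" + node.__class__.__name__, None)
        if method is None:
            raise NotImplementedError(node.__class__.__name__)
        return method(ctx, node)


class ToyStmtBuilder(Builder):
    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        # A bare string literal statement is a docstring: drop it.
        if isinstance(value, ast.Constant) and isinstance(value.value, str):
            return None
        return ("expr", stmt.lineno)

    @staticmethod
    def build_Assign(ctx, stmt):
        return ("assign", stmt.lineno)

    @staticmethod
    def build_Return(ctx, stmt):
        return ("return", stmt.lineno)


build_stmt = ToyStmtBuilder()


def build_stmts(ctx, stmts):
    # Same shape as the real build_stmts: build each, drop the Nones.
    return list(filter(None, [build_stmt(ctx, s) for s in stmts]))


SOURCE = 'def f(x):\n    "docstring"\n    y = x\n    return y\n'
fn_def = ast.parse(SOURCE).body[0]
print(build_stmts(None, fn_def.body))  # [('assign', 3), ('return', 4)]
```

The docstring statement is built to None and then removed by filter, which is exactly the job of the filter(None, stmts) call in build_stmts.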

ScriptModule and ScriptMethod

To find where ScriptModule in C++ is defined, run grep 'ScriptModule' -r torch/csrc/ and you will locate it at torch/csrc/jit/script/init.cpp:

// torch.jit.ScriptModule is a subclass of this C++ object.
// Methods here are prefixed with _ since they should not be
// public.
py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
    .def(py::init<>())
    .def("save", &Module::save)
    .def("_set_optimized", &Module::set_optimized)
    .def(
        "_define",
        [](std::shared_ptr<Module> m,
           const std::string& script,
           ResolutionCallback rcb, bool has_self) {
          auto self = has_self ? std::make_shared<ModuleValue>(m) : nullptr;
          return defineMethodsInModule(*m, script, pythonResolver(rcb), self);
        })
    .def("_create_methods", [](std::shared_ptr<Module> m, const std::vector<Def>& defs, const std::vector<ResolutionCallback>& rcbs) {
      std::vector<Resolver> resolvers;
      for(auto & callback : rcbs) {
        resolvers.push_back(pythonResolver(callback));
      }
      defineMethodsInModule(
        *m,
        defs,
        resolvers,
        std::make_shared<ModuleValue>(m));
    })
    .def("_get_method",
        [](Module& self, const std::string& name) -> const Method& {
          return self.get_method(name);
        }, py::return_value_policy::reference_internal)
    //.def more ...

py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
    .def("graph", [&](Method& self) {
      return self.graph();
    })
    .def("__call__", invokeScriptMethodFromPython)
    //.def more ...

We can see that ScriptModule is basically a binding for the C++ class Module. Skimming through the list of methods defined here, we can see that it has methods for adding, getting, and checking the existence of methods, parameters, submodules, buffers, etc. The class for methods is Method, which is bound to Python as ScriptMethod. Methods in modules are created by defineMethodsInModule and invoked by invokeScriptMethodFromPython. defineMethodsInModule is a bit complicated, so we will postpone reading it to the backend compiler part of this article. But invokeScriptMethodFromPython is very simple. Searching with grep, we can easily find its definition in torch/csrc/jit/pybind_utils.h:

inline py::object invokeScriptMethodFromPython(
    script::Method& method,
    py::args args, py::kwargs kwargs) {
  auto stack = createStackForSchema(method.getSchema(), std::move(args), std::move(kwargs));
  {
    AutoNoGIL no_gil_guard;
    method.run(stack);
  }
  return createPyObjectForStack(std::move(stack));
}

We can easily tell that it just creates a stack from the input arguments, invokes Method::run to consume the elements on the stack as inputs and leave the outputs of the graph on the stack, and finally converts the elements on the stack into Python objects.

Now let’s move on to Module and Method. It is easy to guess from the names that these classes are defined in torch/csrc/jit/script/module.{h,cpp}. Reading through these two files, we see that Module is just a container: it uses ordered dictionaries to store methods, parameters, and submodules, and provides methods to access or run them.

What Method does is more interesting. One important thing the designer of Method must worry about is that a method has access not only to its arguments but also to other members of the same object, so there must be a mechanism for that kind of access. We will see how this is handled very soon. From its constructor, we can see that a method can be created either directly from a graph and initial class members, or from a method creator. The method creator is invoked lazily: it is not invoked inside the constructor, but waits until someone calls ensure_defined. The following member functions of Method define how a Method object is run:

void run(Stack & stack) {
  for(at::Tensor* tp : member_inputs) {
    stack.push_back(*tp);
  }
  get_executor().run(stack);
}

IValue operator()(std::vector<IValue> stack) {
  checkInputsAgainstSchema(stack);
  run(stack);
  if (stack.size() != 1) {
    return Tuple::create(std::move(stack));
  }
  return stack.front();
}

Looking at the types of the names appearing in the above code, we can see that the graph is an object of type Graph, and the virtual machine that executes the graph is an object of type GraphExecutor. GraphExecutor operates on the data type IValue, and its stack is a vector of IValue. To run a method, one first pushes the arguments onto the stack and invokes Method::run, which pushes the member inputs onto the stack and invokes GraphExecutor::run to run the graph. The graph executor leaves its outputs on the stack.
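This calling convention can be mimicked in a few lines of Python. The sketch below is only an analogy: ToyMethod is a hypothetical class, a plain Python callable stands in for the graph executor, and ordinary numbers stand in for tensors and IValues.

```python
class ToyMethod:
    def __init__(self, fn, member_inputs):
        self.fn = fn                        # stands in for the graph executor
        self.member_inputs = member_inputs  # stands in for parameter slots

    def run(self, stack):
        # Explicit arguments are already on the stack; append the members.
        n_inputs = len(stack) + len(self.member_inputs)
        stack.extend(self.member_inputs)
        # The "executor" pops all inputs and leaves the outputs behind.
        inputs = stack[-n_inputs:]
        del stack[-n_inputs:]
        stack.extend(self.fn(*inputs))

    def __call__(self, *args):
        stack = list(args)
        self.run(stack)
        if len(stack) != 1:
            return tuple(stack)  # several outputs come back as a tuple
        return stack[0]


scale = ToyMethod(lambda x, w: [x * w], member_inputs=[10])
both = ToyMethod(lambda x, w: [x + w, x * w], member_inputs=[10])
print(scale(4))  # 40
print(both(4))   # (14, 40)
```

The caller never passes w explicitly: it is appended by run, just as Method::run appends the tensors in member_inputs.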

At this point, we still don’t know how things like Graph and GraphExecutor work, but before looking deeper into that, let’s pause for a moment to take a look at the backend compiler.

From Python AST to PyTorch IR: part 1

Now let’s move on to read _jit_script_compile. To find where it is located, simply run the command grep _jit_script_compile -r . and we will find something like:

./torch/csrc/jit/script/init.cpp: m.def("_jit_script_compile", [](const Def &def, ResolutionCallback rcb) {

So torch/csrc/jit/script/init.cpp is a good starting point. The complete definition of _jit_script_compile is:

m.def("_jit_script_compile", [](const Def &def, ResolutionCallback rcb) {
  return compileFunction(def, pythonResolver(rcb));
});

So, let’s move on to compileFunction. Using grep to search, we would find its definition in torch/csrc/jit/script/compiler.cpp:

std::shared_ptr<Graph> compileFunction(Def def, const Resolver& resolver) {
  Module m;
  defineMethodsInModule(m, {def}, {resolver}, nullptr);
  return m.get_method(def.name().name()).graph();
}

Here we see the defineMethodsInModule that we met before in the Python bindings for Module. Let’s move on to defineMethodsInModule, in the same file:

void defineMethodsInModule(Module & m, const std::vector<Def>& definitions, const std::vector<Resolver>& resolvers, SugaredValuePtr self) {
  // ......
  for(Def def : definitions) {
    // ......
    auto creator = [def, &table, resolver, self](Method& method) {
      to_ir(def, table, resolver, self, method);
    };
    Method& method = m.create_method(name, creator);
    // ......
  }
  // ......
}

Less important parts of the code are omitted. From the above, we can see that the core of compiling an AST into a compute graph is done in to_ir. Skimming through to_ir, we find that it is a struct of ~1000 lines of code, with member functions that handle different cases of the Python AST. Without knowing PyTorch’s IR, it is not easy to understand what to_ir does. So let’s pause for a moment to take a look at the PyTorch IR and come back later.

The PyTorch IR

A good starting point is the class Graph, located in torch/csrc/jit/ir.h. Skimming through this file, as well as to_ir, we keep seeing things like aten::mul and prim::Constant. What are they? They seem to be very relevant; in fact, they seem to be the nodes in the graph. With some grep searching, we find good documentation of them in torch/csrc/jit/interned_strings.h:

// 'prim' symbols are synthetic operators that occur only in the IR
// and don't have corresponding implementations in ATen.
// 'onnx' symbols correspond to ONNX operators. Their semantics
// are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
// The particular version we are targeting is specified by '_onnx_opset_version'
// in torch.onnx.symbolic
//
// In general, most ONNX operators won't get an entry here, because they
// are handled from the Python end. However, you may occasionally need
// to intern an ONNX symbol here so that you can conveniently write an
// optimization on ONNX operations.
// 'attr' symbols are attribute keys. They are shared between both ONNX and ATen
// operators (you disambiguate their meaning by looking at the operator itself).
// In general, you only need to define attribute keys that are used by
// onnx or prim; ATen attributes are automatically generated in FORALL_ATTR_BASE_SYMBOLS.
// Note [Symbol allocation]
// ~~~~~~~~~~~~~~~~~~~~~~~~
//
// 1. Symbol namespace is split up into namespaces.
//
// 2. The intended access pattern for built-in symbols is onnx::MatMul
// in the torch::jit namespace (this is a Symbol).
//
// Built-in constant definition strategy:
// - Enum is the most convenient way to generate a contiguous sequence
// of numbers for an identifier.
// - However, an enum gives you a fresh type. We want onnx::MatMul to
// be type Symbol, not some random enum type!
// - Therefore, after using enums to generate the sequence of integers,
// we then declare constexpr Symbols to get everything the actual Symbol
// type we want. Symbols must be constexpr to be valid to be "case"ed on.
using unique_t = uint32_t;
static const std::string domain_prefix = "org.PyTorch.";
// A Symbol is like an interned string, but with a little extra
// structure; it is namespaced via SymbolNamespace and the resulting
// intern pointers support efficient namespace testing.
struct TORCH_API Symbol {
// more code omitted ......

This explains very well what those things are: they are instances of Symbol that represent operators. Knowing this level of detail is enough for us, so let’s go back to the IR.
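Before going back, the interning idea itself can be sketched in a few lines. This is a toy under assumptions, not how Symbol is implemented: SymbolTable is a hypothetical class, symbols are plain integers, and the namespace is recovered by splitting the qualified string rather than by any structure encoded in the integer.

```python
class SymbolTable:
    """Map qualified names like 'aten::mul' to small integers and back."""

    def __init__(self):
        self._ids = {}    # qualified name -> integer symbol
        self._names = []  # integer symbol -> qualified name

    def intern(self, qualified_name):
        # The same string always yields the same integer.
        if qualified_name not in self._ids:
            self._ids[qualified_name] = len(self._names)
            self._names.append(qualified_name)
        return self._ids[qualified_name]

    def name(self, symbol):
        return self._names[symbol]

    def namespace(self, symbol):
        return self._names[symbol].split("::", 1)[0]


table = SymbolTable()
mul = table.intern("aten::mul")
const = table.intern("prim::Constant")
print(mul == table.intern("aten::mul"))  # True: comparison is an int compare
print(table.namespace(const))            # prim
```

Once operators are interned, comparing two node kinds or switching on them costs an integer comparison instead of a string comparison, which is the point of the "case"-ability mentioned in the comments above.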

The beginning of the file torch/csrc/jit/ir.h explains very well what things are:

// Graph represents one "function" of computation.
// It uses a simple ownership model where the graph owns all the nodes inside it.
// All references inside the graph are raw pointers.
// Destroying the Graph will invalidate any pointers to nodes in the graph.
struct Graph;
// Node is the base class of the IR graph. It represents one computation
// and dependencies on a list of Values. The "prim-ops", so to speak.
struct Node;
// A Value represents an input or output to node that is either a
// Tensor or an opaque Handle object, as determined by type().
struct Value;
// ......
// A list of nodes, with inputs and outputs
struct Block;
// Each use is represented by this type, see Node::uses()
// 'user' is the consumer of the value, offset is the index into
// 'user's input this where the produces will be found.
struct Use {
Use(Node * user, size_t offset)
: user(user), offset(offset) {}
Node * user;
size_t offset;
};
// ......
// Scope is a node of a trie that represents the tree of nested scopes.
// Individual scopes are pushed and popped from Graph, which holds a
// pointer to the current scope. Each Node in Graph holds a pointer
// to the scope that was current when the node was created.
// The trie never needs to shrink, it only grows until it is disposed
// of when Graph is deallocated. Hence, pointers to scopes held by nodes
// will always be valid as long as Graph is alive.
struct Scope {

Reading through the whole file, we can summarize how it works:

A Graph object owns all Nodes, Values, and Blocks. The internal structure is not maintained by the Graph object, but inside Nodes, Values, and Blocks.

Each Node keeps pointers to its input and output Values. It also maintains pointers to its siblings in a doubly-linked list of Nodes. This doubly-linked list is a topological sort of the Nodes in the Graph. Each Node has a NodeKind, which is an object of type Symbol. A Node also maintains a pointer to the Block owning it, as well as pointers to its subblocks.

Each Value must be an output of some Node, and it has a Node pointer pointing to the Node that outputs this Value. It also has a Use list storing where this Value is used as input.

Each Block maintains pointers to its input and output Nodes, as well as the Node owning this Block.
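The relationships summarized above can be modeled with a small Python sketch. It is a deliberately reduced picture under assumptions: Blocks and Scopes are left out, a Python list replaces the doubly-linked list of Nodes, node kinds are plain strings instead of Symbols, and representing graph inputs as the outputs of a "prim::Param" node is my reading of how the rule that every Value is produced by some Node is satisfied.

```python
class Value:
    def __init__(self, node, name):
        self.node = node  # the Node that produces this Value
        self.name = name
        self.uses = []    # list of (user Node, input offset)


class Node:
    def __init__(self, kind, inputs=(), output_names=()):
        self.kind = kind
        self.inputs = list(inputs)
        self.outputs = [Value(self, n) for n in output_names]
        # Record this node in the use list of every input Value.
        for offset, value in enumerate(self.inputs):
            value.uses.append((self, offset))


class Graph:
    def __init__(self, input_names):
        # Assumption: graph inputs are outputs of a parameter node.
        self.param = Node("prim::Param", output_names=input_names)
        self.nodes = []   # kept in topological (insertion) order
        self.outputs = []

    @property
    def inputs(self):
        return self.param.outputs

    def append(self, kind, inputs, output_names):
        node = Node(kind, inputs, output_names)
        self.nodes.append(node)
        return node


# Toy graph for x + x * x (operator signatures are simplified).
g = Graph(["x"])
x = g.inputs[0]
mul = g.append("aten::mul", [x, x], ["t"])
add = g.append("aten::add", [x, mul.outputs[0]], ["out"])
g.outputs = add.outputs
print([(user.kind, offset) for user, offset in x.uses])
```

Because each Value knows both its producer and its uses, a pass can walk the graph in either direction without consulting the Graph object, which matches the remark that the structure lives in the Nodes and Values rather than in the Graph.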

From Python AST to PyTorch IR: part 2

With the knowledge of IR, let’s go back to read the backend compiler.

In the code in torch/csrc/jit/script/compiler.cpp, we have been seeing SugaredValue many times. What SugaredValue does is explained in torch/csrc/jit/script/compiler.h:

// The AST can contain nodes like `self`, `self.b` or `Python_fn` that
// are not first-class values in the graph representation, but instead
// will be desugared based on how they are used in the AST.
// SugaredValue is used to temporarily represent these values in a way
// that separates their behavior from the AST -> IR converter itself.
// This allows us to keep dependencies on Python minimal.
struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {

From the comments above, together with what we see when skimming through the code, we can see that SugaredValue is a superclass of different types of values. These values might be first-class values like tensors or integers, a ScriptModule such as self, Python modules like torch, or builtin functions like print. Different types of values are handled by different subclasses: SimpleValue for first-class values, BuiltinFunction for operators like aten::relu, BuiltinModule for something like torch, NoneValue for None, PrintValue for print, and CastValue for types like int and float. The subclasses listed above are all defined in torch/csrc/jit/script/compiler.{cpp, h}.

Now let’s move on to read the constructor of the struct to_ir. It basically:

  1. Read the information of parameters from the Python AST, and set them up in graph.
  2. Call emitStatements to emit IR for function body.
  3. Set up the output values for the graph based on the return statement at the end of the function body (compiling functions that have a return statement anywhere other than the end is not supported).

In step 1, there is a small complication: for a function that is a method of some module, the first parameter is always a reference to the object owning this method (the so-called “self”). So it requires a bit of special-casing when checking against the schema. Also, we need to add the identifier for the first parameter to the symbol table (here the symbol table is Environment::value_table, an object of type ValueTable). The inputs to the graph are not only those that appear explicitly in the argument list, but also the members accessed inside the function body. Recall that when we read the code of Method::run, there was a step that pushes members onto the stack. This issue is not handled here; we will see how it is handled later.

In step 2, things start to get complicated. In emitStatements, code emission is dispatched to different specialized private methods of the struct according to the kind of each statement:

void emitStatements(List<Stmt>::const_iterator begin, List<Stmt>::const_iterator end) {
  for (; begin != end; ++begin) {
    auto stmt = *begin;
    switch (stmt.kind()) {
      case TK_IF:
        emitIf(If(stmt));
        break;
      case TK_WHILE:
        emitWhile(While(stmt));
        break;
      case TK_FOR:
        emitFor(For(stmt));
        break;
      case TK_ASSIGN:
        emitAssignment(Assign(stmt));
        break;
      case TK_GLOBAL:
        for (auto ident : Global(stmt).names()) {
          const auto& name = Ident(ident).name();
          environment_stack->setVar(ident.range(), name, graph->addInput(name));
        }
        break;
      case TK_EXPR_STMT: {
        auto exprs = ExprStmt(stmt).exprs();
        for (const auto& expr : exprs) {
          emitSugaredExpr(expr, 0);
        }
      }
      break;
      case TK_RETURN:
        throw ErrorReport(stmt) << "return statements can appear only at the end "
                                << "of the function body";
        break;
    }
  }
}

There are so many specialized emit functions that I will not go over them one by one. I will only go deep into emitSugaredExpr as an example. It is defined as follows:

// any expression that can produce a SugaredValue is handled here
// expressions that only return a single Value* are handled in emitSimpleExpr
std::shared_ptr<SugaredValue> emitSugaredExpr(Expr tree, size_t n_binders) {
  switch(tree.kind()) {
    case TK_VAR:
      return environment_stack->getSugaredVar(Var(tree).name());
    case '.': {
      auto select = Select(tree);
      auto sv = emitSugaredExpr(select.value(), 1);
      return sv->attr(select.range(), method, select.selector().name());
    }
    case TK_APPLY: {
      auto apply = Apply(tree);
      auto inputs = getNamedValues(apply.inputs(), true);
      auto attributes = fmap(apply.attributes(), [&](const Attribute& attr) {
        return NamedValue(attr.range(), attr.name().name(), emitExpr(attr.value()));
      });
      // the apply is directly an identifier 'foo'
      if(apply.callee().kind() == TK_VAR) {
        return emitApplyIdent(Var(apply.callee()).name(), inputs, attributes, n_binders);
      }
      return emitApplyExpr(apply.callee(), inputs, attributes, n_binders);
    } break;
    default:
      return std::make_shared<SimpleValue>(emitSimpleExpr(tree));
  }
}

What it does is basically this: for cases that are guaranteed to produce a SimpleValue, we just call emitSimpleExpr to emit the code. Otherwise, the expression must have one of three forms: foo, foo.bar, or foo(bar). For foo, we just look up foo in the symbol table. For foo.bar, we first emit foo and then look up its attribute bar. For foo(bar), depending on whether foo is an identifier or an expression, we invoke emitApplyIdent or emitApplyExpr to emit the code.
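The same three-way dispatch can be reproduced on Python’s own AST. The toy below is only an analogue, and a loose one: it evaluates expressions immediately instead of emitting IR, a dictionary plays the role of the symbol table, and the function names emit_sugared and evaluate are hypothetical.

```python
import ast
import math


def emit_sugared(expr, env):
    if isinstance(expr, ast.Name):       # the "foo" case
        return env[expr.id]
    if isinstance(expr, ast.Attribute):  # the "foo.bar" case
        return getattr(emit_sugared(expr.value, env), expr.attr)
    if isinstance(expr, ast.Call):       # the "foo(bar)" case
        callee = emit_sugared(expr.func, env)
        args = [emit_sugared(a, env) for a in expr.args]
        return callee(*args)
    # Everything else is treated as a simple value (literals only here).
    return ast.literal_eval(expr)


def evaluate(source, env):
    return emit_sugared(ast.parse(source, mode="eval").body, env)


env = {"math": math, "x": 2.5}
print(evaluate("math.floor(x)", env))   # 2
print(evaluate("math.pow(x, 2)", env))  # 6.25
```

Note how math on its own is not a number: it only becomes useful once an attribute is selected and the result is called. That is the situation SugaredValue exists for, where intermediate values such as modules never appear in the final graph.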

The self argument of the method is handled a bit differently: there is a subclass of SugaredValue called ModuleValue, defined in torch/csrc/jit/script/init.cpp. In its overridden method attr, we see:

if(NamedParameter* v = module->find_parameter(field)) {
  return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
}

Here get_or_add_parameter, defined in torch/csrc/jit/script/module.h, reads:

Value * get_or_add_parameter(at::Tensor* slot) {
  auto it = member_input_index.find(slot);
  if(it != member_input_index.end()) {
    return graph()->inputs().at(it->second);
  }
  // add it as a new parameter
  member_inputs.push_back(slot);
  member_input_index[slot] = graph()->inputs().size();
  return graph()->addInput();
}

That tells us that adding members as parameters of the graph actually happens when code is emitted for self.bar, where the attr method of ModuleValue is called.
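The bookkeeping in get_or_add_parameter is easy to replay in Python. In this sketch, which is an illustration under assumptions rather than a port, LazyInputs is a hypothetical class, graph inputs are represented by name strings, and id(slot) stands in for the tensor pointer used as the map key.

```python
class LazyInputs:
    def __init__(self, explicit_inputs):
        self.graph_inputs = list(explicit_inputs)  # names of graph inputs
        self.member_inputs = []                    # slots pushed at run time
        self.member_input_index = {}               # slot identity -> input index

    def get_or_add_parameter(self, slot):
        key = id(slot)  # stands in for the at::Tensor* pointer
        if key in self.member_input_index:
            return self.graph_inputs[self.member_input_index[key]]
        # First access: add the member as a new graph input.
        self.member_inputs.append(slot)
        self.member_input_index[key] = len(self.graph_inputs)
        self.graph_inputs.append("member_%d" % (len(self.member_inputs) - 1))
        return self.graph_inputs[-1]


method = LazyInputs(["x"])
weight, bias = [1.0], [0.5]  # lists used as stand-ins for parameter slots
first = method.get_or_add_parameter(weight)
again = method.get_or_add_parameter(weight)
other = method.get_or_add_parameter(bias)
print(first, again, other)
print(method.graph_inputs)
```

Since member_inputs and the extra graph inputs grow in the same order, pushing member_inputs after the explicit arguments at run time lines every member up with its graph input.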

The Graph Executor

Now that we have seen how the compilation is done and what PyTorch JIT’s IR looks like, the thing left is how the IR is executed. From above we already know that the executor is obtained by invoking Method::get_executor and run by invoking GraphExecutor::run. Let’s first take a look at Method::get_executor:

GraphExecutor& get_executor() {
  std::call_once(executor_init, [&]{
    executor = GraphExecutor(graph(), optimize);
  });
  return executor;
}

We see that a graph executor is created from a graph, and performs optimization if asked. It is not hard to guess from the name that GraphExecutor is defined in torch/csrc/jit/graph_executor.{h, cpp}.

The constructor and run tell us that GraphExecutor is just a wrapper around GraphExecutorImpl:

GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
    : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}

void GraphExecutor::run(Stack & inputs) {
  return pImpl->run(inputs);
}

So let’s move on to GraphExecutorImpl:

GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
    : graph(prepareGraph(graph))
    , optimize(optimize)
    , num_inputs(this->graph->inputs().size())
    , num_flat_inputs(countFlatInputs(graph))
    , num_outputs(this->graph->outputs().size()) {}

// entry point where execution begins
void run(Stack & stack) {
  AT_CHECK(stack.size() >= num_inputs, "expected ", num_inputs, " inputs, but got only ", stack.size());
  if(tracer::isTracing()) {
    return runTraced(stack);
  }
  auto & execution_plan = optimize ? getOrCompile(stack) : getOrCompileFallback();
  return execution_plan.run(stack);
}

We see that the graph is compiled the first time it runs to get an execution plan. The run method of the execution plan is then called to run the graph. Compilation of the graph into an execution plan is done by getOrCompile or getOrCompileFallback, depending on whether optimization is enabled. These two methods are copied below:

const ExecutionPlan & getOrCompileFallback() {
  std::lock_guard<std::mutex> lock(compile_mutex);
  if(!fallback) {
    auto graph_ = graph->copy();
    runRequiredPasses(graph_);
    fallback = ExecutionPlan(graph_);
  }
  return fallback;
}

const ExecutionPlan & getOrCompile(const Stack& stack) {
  // outside lock guard, to minimize the time holding the lock on the fast path
  // ArgumentSpec even computes its hashCode here.
  ArgumentSpec spec(autograd::GradMode::is_enabled(), last(stack, num_inputs), num_flat_inputs);
  {
    std::lock_guard<std::mutex> lock(compile_mutex);
    auto it = plan_cache.find(spec);
    if (it != plan_cache.end())
      return it->second;
    auto plan = compileSpec(spec);
    auto r = plan_cache.emplace(std::move(spec), std::move(plan));
    return r.first->second;
  }
}

This code explains itself well: if optimization is turned off, we only run the required passes and cache the result. Otherwise, depending on the characteristics of the inputs (ArgumentSpec), we run full optimization and cache the generated plan for each distinct ArgumentSpec. The plan is created by the constructor of ExecutionPlan.
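The caching scheme in getOrCompile is a common pattern and can be sketched in Python as follows. The names are hypothetical, and the spec here is just a tuple of Python type names, a much coarser key than the real ArgumentSpec.

```python
import threading


class ToyGraphExecutor:
    def __init__(self, compile_spec):
        self._compile_spec = compile_spec  # called once per distinct spec
        self._plan_cache = {}
        self._lock = threading.Lock()

    def get_or_compile(self, stack):
        # Build the cache key outside the lock to keep the critical
        # section short, mirroring the comment in getOrCompile.
        spec = tuple(type(v).__name__ for v in stack)
        with self._lock:
            plan = self._plan_cache.get(spec)
            if plan is None:
                plan = self._compile_spec(spec)
                self._plan_cache[spec] = plan
            return plan


compiled = []


def compile_spec(spec):
    compiled.append(spec)  # record every real compilation
    return "plan for (" + ", ".join(spec) + ")"


executor = ToyGraphExecutor(compile_spec)
executor.get_or_compile([1, 2.0])
executor.get_or_compile([3, 4.0])    # same spec: served from the cache
executor.get_or_compile([1.0, 2.0])  # new spec: compiled again
print(compiled)
```

Three calls lead to only two compilations, because the first two stacks share a spec. The trade-off is the same as in the C++ code: specializing on more input properties enables more optimization but produces more cache entries.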

It is worth taking a look at which passes are called:

ExecutionPlan compileSpec(const ArgumentSpec & spec) {
auto opt_graph = graph->copy();
setInputTypes(*opt_graph, spec);
// Phase 1. Specialize to input definedness (this is very important for
// gradient graphs), and run required passes to bring the graph
// to an executable form.
runRequiredPasses(opt_graph);
// Phase 2. Propagate detailed information about the spec through the
// graph (enabled more specializations in later passes).
// Shape propagation sometimes depends on certain arguments being
// constants, and constant propagation doesn't need shape information
// anyway, so it's better to run it first.
ConstantPropagation(opt_graph);
PropagateInputShapes(*opt_graph);
PropagateRequiresGrad(opt_graph);
// Phase 3. Run differentiable optimizations (i.e. simple graph rewrites that
// we can still execute using autograd).
runOptimization(opt_graph, spec);
// Phase 4. If this graph will be differentiated, we need to slice out the
// symbolically differentiable subgraphs for further optimizations.
// Phase 5. Apply non-differentiable optimizations to the graphs we've found
// (or the whole grpah if we know we won't need its derivative).
if (needsGradient(opt_graph)) {
auto diff_nodes = CreateAutodiffSubgraphs(*opt_graph);
for (Node * dnode : diff_nodes) {
auto diff_graph = std::move(dnode->g(attr::Subgraph));
Gradient gradient = differentiate(diff_graph);
runNondiffOptimization(gradient.f);
packGradient(gradient, dnode);
}
InlineAutodiffSubgraphs(opt_graph);
} else {
runNondiffOptimization(opt_graph);
}
// Make sure there are no leftovers from any passes.
EliminateDeadCode(opt_graph);
return ExecutionPlan(opt_graph);
}
void runOptimization(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
EliminateDeadCode(graph);
EliminateCommonSubexpression(graph);
UnrollLoops(graph);
PeepholeOptimize(graph);
CheckInplace(graph);
BatchMM(graph);
}
void runNondiffOptimization(std::shared_ptr<Graph>& graph) {
FuseGraph(graph);
}
// ......
void runRequiredPasses(const std::shared_ptr<Graph>& g) {
specializeUndef(*g);
LowerGradOf(*g);
// implicit inserted expand nodes are not necessarily always valid
// when used inside script methods that might have unstable shapes
// we remove the implicitly created ones, and have shape analysis
// add valid expand nodes when the shapes are stable
RemoveExpands(g);
CanonicalizeOps(g);
EliminateDeadCode(g);
}

I will not go deep into these passes here. Interested readers can find them in torch/csrc/jit/passes/.
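To give a flavor of what such a pass does, here is a toy dead code elimination over a flat list of nodes. It is only a sketch under simplifying assumptions: ToyNode and eliminate_dead_code are hypothetical names, every node has exactly one output and no side effects, and there are no nested blocks. The real EliminateDeadCode pass has to deal with all of these.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyNode:
    """Hypothetical IR node: one output name, an op name, and input names."""
    output: str
    op: str
    inputs: List[str]


def eliminate_dead_code(nodes: List[ToyNode], graph_outputs: List[str]) -> List[ToyNode]:
    live = set(graph_outputs)
    kept = []
    # Walk backwards: nodes are in topological order, so every use of a value
    # is visited before the node that defines it.
    for node in reversed(nodes):
        if node.output in live:
            kept.append(node)
            live.update(node.inputs)
    kept.reverse()
    return kept


nodes = [
    ToyNode("a", "add", ["x", "x"]),
    ToyNode("b", "mul", ["x", "x"]),   # used only by c, which is dead
    ToyNode("c", "neg", ["b"]),        # never used: dead
    ToyNode("d", "relu", ["a"]),
]
optimized = eliminate_dead_code(nodes, graph_outputs=["d"])
```

Only the nodes producing a and d survive, because nothing reachable from the graph output depends on b or c.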

Now it’s time to look at ExecutionPlan:

struct ExecutionPlan {
ExecutionPlan() = default;
ExecutionPlan(std::shared_ptr<Graph> graph)
: code(graph)
, graph(std::move(graph)) {}
void run(Stack& stack) const {
return InterpreterState(code).runOneStage(stack);
}
operator bool() const {
return static_cast<bool>(graph);
}
ExecutionPlanState getDebugState() {
ExecutionPlanState state;
state.code = &code;
state.graph = graph.get();
return state;
}
Code code;
std::shared_ptr<Graph> graph;
};

It just converts the graph into a Code object; the actual execution is done by InterpreterState.
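The division of labor, compile once into Code and create a fresh InterpreterState for every run, can be mimicked in a few lines of Python. All class names below are hypothetical, and the "graph" is simply a list of callables acting on a stack, which is a deliberate oversimplification.

```python
class ToyCode:
    """Hypothetical stand-in for Code: built once from a graph."""
    builds = 0

    def __init__(self, graph):
        ToyCode.builds += 1
        self.instructions = list(graph)   # pretend "compilation"


class ToyInterpreterState:
    """Per-run mutable state: here just a program counter."""

    def __init__(self, code):
        self.code = code
        self.pc = 0

    def run(self, stack):
        while self.pc < len(self.code.instructions):
            self.code.instructions[self.pc](stack)
            self.pc += 1


class ToyExecutionPlan:
    def __init__(self, graph):
        self.code = ToyCode(graph)

    def run(self, stack):
        # A fresh interpreter state per call, so runs never share a pc.
        ToyInterpreterState(self.code).run(stack)


plan = ToyExecutionPlan([
    lambda s: s.append(s.pop() * 2),   # double the top of the stack
    lambda s: s.append(s.pop() + 1),   # then add one
])
first, second = [3], [10]
plan.run(first)
plan.run(second)
```

Both runs reuse the same ToyCode object; only the interpreter state is recreated.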

Compiling to Interpreter Instructions

Code and InterpreterState are defined in torch/csrc/jit/interpreter.{h,cpp}. Both classes are just thin wrappers around their implementations, CodeImpl and InterpreterStateImpl:

Code::Code(std::shared_ptr<Graph>& graph)
: pImpl(new CodeImpl(graph)) {}
Code::~Code() = default;
const std::vector<GraphExecutor*>& Code::grad_executors() {
return pImpl->grad_executors();
}
InterpreterState::InterpreterState(const Code & code)
: pImpl(new InterpreterStateImpl(code)) {}
InterpreterState::~InterpreterState() = default;
void InterpreterState::runOneStage(Stack & stack) {
return pImpl->runOneStage(stack);
}

CodeImpl is a long struct, but its logic is easy to follow. Two of its fields are:

PreprocessGraph preprocess;
std::vector<Instruction> instructions;

Its constructor is:

CodeImpl(std::shared_ptr<Graph>& graph_)
: preprocess(*graph_) {
graph = preprocess.graph;
// std::cout << "into code graph:\n" << *graph << "\n";
insertNodesFromBlock(graph->block());
}

We can see that it does two things: first it preprocesses the graph, and then it emits instructions for the interpreter.

The preprocessing of the graph is explained well in a comment at the beginning of the file:

// Before we translate to intepreter instructions, we do
// some preprocessing of the graph to turn it into a form that is closer
// to what the instructions will look like.
// In particular we:
// * (TODO) desugar Loop trip counts into c = 0, c += 1 instructions in the loop
// * flatten stages so that each stage starts with a load from the stack
// and ends with a store to the stack
// *. computes move_flags (see Outputs), and inserts
// * Drop nodes are inserted for any node that is unused to create a dummy use
// that will cause the interpreter to free the node.
// A drop node is just a node with no outputs that just pops its inputs off the stack,
// to ensure the interpreter release references to nodes that are never used.
// Drop nodes are also inserted when the last use of a node is in some conditionally
// run control flow (e.g. one side of an If) and the interpreter must free
// the node only after the control flow has reconverged
// Outputs are:
// * graph - the post processed copy of g
// * move_flags[n] - a list of booleans, one for each input,
// indicating whether this is the last use of the value. The interpreter
// should generate a move rather than a copy in this case.
// * stage_input_types: the type annotations on the inputs to each stage
// these can be removed once the the backward tracer is no longer used

and in the definition of PreprocessGraph:

struct PreprocessGraph {
PreprocessGraph(Graph & g)
: graph(g.copy()) {
desugarTripCounts(graph->block());
stage_input_types = flattenStages(*graph);
dropUnused(graph->block());
// fill in move_flags by scanning blocks;
move_flags = findLastUses(*graph);
//TODO: desugar Loop trip counts, for now we drop trip counts
}
// Outputs of the preprocessing:
std::shared_ptr<Graph> graph;
// for each input, should we move rather than copy the inputs
std::unordered_map<Node*, std::vector<uint8_t>> move_flags;
std::vector<std::vector<TypePtr>> stage_input_types;
};
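The idea behind move_flags can be sketched as a backward scan that marks the last use of every value. The following is a toy version, not the real findLastUses: the flat program format and the function name are hypothetical, control flow (and therefore Drop nodes) is ignored, and it assumes inputs are loaded left to right, so that when a node uses the same value twice only the rightmost occurrence may be a move.

```python
from typing import Dict, List, Tuple

# Hypothetical flat program: (output name, input names), in execution order.
Program = List[Tuple[str, List[str]]]


def find_last_uses(program: Program) -> Dict[int, List[bool]]:
    """For each node index, flag the inputs whose last use is that node."""
    seen = set()
    move_flags: Dict[int, List[bool]] = {}
    # Scan backwards: the first time a value shows up is its last use.
    for index in range(len(program) - 1, -1, -1):
        _, inputs = program[index]
        flags = [False] * len(inputs)
        # Scan inputs right to left too, so a value used twice by the same
        # node is moved only at its final occurrence.
        for pos in range(len(inputs) - 1, -1, -1):
            if inputs[pos] not in seen:
                seen.add(inputs[pos])
                flags[pos] = True
        move_flags[index] = flags
    return move_flags


program = [
    ("a", ["x", "x"]),   # x is used again below, so both are copies
    ("b", ["a", "x"]),   # last use of both a and x
    ("c", ["b", "b"]),   # b used twice; only the second occurrence may move
]
flags = find_last_uses(program)
```

A True flag means the interpreter could hand the value over without keeping a copy in its register, which releases the value as early as possible.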

insertNodesFromBlock emits the instructions. It is also largely self-explanatory:

void insertNodesFromBlock(Block* block) {
for(auto node : block->nodes()) {
const auto & source_location = node->getSourceLocation();
switch(node->kind()) {
case prim::If: {
// x = if c:
// <then_block>
// -> (vt)
// else:
// <else_block>
// -> (vf)
// turns into:
// JumpNZ c, then
// <else_block>
// x = vf
// Jump end
// then:
// <then_block>
// x = vt
// end:
// prim::Placeholder instructions are replaced with branch instructions
// when the branch target locations are known
auto cond_branch = insertInstruction(prim::Placeholder, source_location, node->inputs(), moveFlags(node), {});
auto then_block = node->blocks()[0];
auto else_block = node->blocks()[1];
insertNodesFromBlock(else_block);
insertAssign(source_location,else_block->outputs(), moveFlags(else_block), node->outputs());
auto jump = insertInstruction(prim::Placeholder, source_location, {}, {}, {});
auto then_block_start = instructions.size();
insertNodesFromBlock(then_block);
insertAssign(source_location, then_block->outputs(), moveFlags(then_block), node->outputs());
createJump(jump, instructions.size());
createJumpNZ(cond_branch, then_block_start);
} break;
case prim::Loop: {
// omitted ......
} break;
default: {
insertInstruction(node);
} break;
}
// each stage ends with a load instruction
// we record where these instructions occur, and use them to
// exit the interpreter
if(node->kind() == prim::Load) {
stage_end.push_back(instructions.size());
}
}
}

Since the nodes are topologically sorted, we just need to iterate over the linked list and generate code for each node.
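The placeholder trick used for prim::If can be reproduced in a toy code generator: emit a placeholder where a branch belongs, keep emitting, and patch the placeholder once the target position is known. The sketch below uses hypothetical names (lower_if, trace) and string opcodes. It also assumes that jump offsets are relative to the instruction following the jump, which is a guess at what relativeJump computes, since that helper is not shown in the excerpt.

```python
from typing import List, Tuple

# (opcode, operand); for jumps the operand is a relative offset.
Instr = Tuple[str, int]


def lower_if(then_ops: List[str], else_ops: List[str]) -> List[Instr]:
    code: List[Instr] = []
    cond_branch = len(code)
    code.append(("placeholder", 0))          # becomes jump_nz once the target is known
    code.extend((op, 0) for op in else_ops)  # the else block is emitted first
    jump = len(code)
    code.append(("placeholder", 0))          # becomes a jump over the then block
    then_start = len(code)
    code.extend((op, 0) for op in then_ops)
    end = len(code)
    # Back-patch. Offsets are relative to the instruction after the jump,
    # matching an interpreter that computes pc = pc + 1 + offset.
    code[jump] = ("jump", end - (jump + 1))
    code[cond_branch] = ("jump_nz", then_start - (cond_branch + 1))
    return code


def trace(code: List[Instr], cond: int) -> List[str]:
    """Run the lowered code and return the names of the ordinary ops executed."""
    executed, pc = [], 0
    while pc < len(code):
        op, operand = code[pc]
        offset = 0
        if op == "jump":
            offset = operand
        elif op == "jump_nz":
            offset = operand if cond != 0 else 0
        else:
            executed.append(op)
        pc = pc + 1 + offset
    return executed


code = lower_if(then_ops=["t1", "t2"], else_ops=["e1"])
taken = trace(code, cond=1)
not_taken = trace(code, cond=0)
```

With a non-zero condition only the then block runs; with a zero condition the else block runs and the unconditional jump skips the then block.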

The Virtual Machine

InterpreterStateImpl is the virtual machine that executes instructions.

InterpreterStateImpl(const Code & code)
: function(code.pImpl),
int_data(function->int_data.data()),
bool_data(function->bool_data),
registers(function->register_size) {
}
void runOneStage(Stack & stack) {
// std::cout << "running stage: " << current_stage << " of " << function->stage_end.size() << "\n";
// std::cout << *function->graph << "\n";
// function->dump(std::cout);
size_t pc = current_pc;
size_t last = function->stage_end[current_stage];
auto & instructions = function->instructions;
while(pc < last) {
// std::cout << "executing " << pc << ": ";
// function->dumpInstruction(std::cout, pc);
// std::cout << "\n";
try {
auto & inst = instructions[pc];
loadTensorsFromRegisters(inst.inputs, stack);
size_t new_pc = pc + 1 + inst.callback(stack);
for(int i = inst.outputs.size - 1; i >= 0; i--) {
int reg = get(inst.outputs,i);
registers[reg] = pop(stack);
// std::cout << "pop reg[" << reg << "];\n" << registers[reg].pImpl << "\n";
}
pc = new_pc;
} catch(std::exception & e) {
if(!instructions[pc].debug_location)
throw; // rethrow original exception
// throw a new exception with enhanced debugging information
instructions[pc].debug_location->wrapAndRethrowException(e, "operation failed in interpreter");
}
}
current_pc = pc;
current_stage++;
}

There is nothing special here; it just mimics the behavior of a processor. We can tell from the code above that the action of each instruction is defined by Instruction::callback, and that branching is implemented by returning a non-zero offset from that callback, which is added to pc + 1. Some of the callbacks are defined inside CodeImpl, such as:

// jump when input is not 0
void createJumpNZ(int from_inst, int to_inst) {
auto & inst = instructions[from_inst];
JIT_ASSERT(inst.debug_name == prim::Placeholder);
auto offset = relativeJump(from_inst, to_inst);
inst.callback = [offset](Stack & stack) {
auto t = pop(stack).toInt();
return (t != 0) ? offset : 0;
};
inst.debug_name = prim::JumpNZ;
}

while others are determined by the node kind:

size_t insertInstruction(Node * n) {
auto inst = insertInstruction(n->kind(), n->getSourceLocation(), n->inputs(), moveFlags(n) , n->outputs());
instructions[inst].callback = getOperation(n);
return inst;
}

where getOperation is defined in torch/csrc/jit/operator.{h, cpp}. Reading further through these two files, we can see that operations are registered by calling registerOperator, which is done by constructing RegisterOperators objects. Using grep RegisterOperators -r torch/csrc/, we can locate where all operations are defined:

torch/csrc/jit/generated/register_aten_ops.cpp:RegisterOperators reg({
torch/csrc/jit/fusers/common/fusion_handle_impl.cpp:RegisterOperators reg_fused_operators({
torch/csrc/jit/custom_operator.h:/// so in the global scope when a `RegisterOperators` object is assigned to a
torch/csrc/jit/custom_operator.h:struct TORCH_API RegisterOperators {
torch/csrc/jit/custom_operator.h: RegisterOperators() = default;
torch/csrc/jit/custom_operator.h: RegisterOperators(std::vector<Operator> operators) {
torch/csrc/jit/custom_operator.h: RegisterOperators(const std::string& name, Implementation&& implementation) {
torch/csrc/jit/custom_operator.h: RegisterOperators& op(
torch/csrc/jit/python_interpreter.cpp:RegisterOperators reg({
torch/csrc/jit/register_special_ops.cpp:RegisterOperators reg({
torch/csrc/jit/graph_executor.cpp:RegisterOperators reg_graph_executor_ops({
torch/csrc/jit/constants.cpp:RegisterOperators reg({
torch/csrc/jit/register_prim_ops.cpp:RegisterOperators reg({
torch/csrc/jit/register_prim_ops.cpp:RegisterOperators reg2({
torch/csrc/jit/test_jit.cpp: RegisterOperators reg({createOperator(
torch/csrc/jit/test_jit.cpp: RegisterOperators reg({createOperator(
torch/csrc/jit/test_jit.cpp: RegisterOperators reg({createOperator(
torch/csrc/jit/test_jit.cpp: RegisterOperators reg(
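The two ideas of this section, operations looked up from a registry and a dispatch loop that moves values between registers and a stack, fit together in a small Python imitation. Everything here is a hypothetical sketch: the toy:: operator names, the register_operator decorator and the Instruction tuple are made up for illustration and do not correspond to PyTorch's API.

```python
from typing import Callable, Dict, List, NamedTuple

Stack = List[int]
Callback = Callable[[Stack], int]   # returns a relative jump offset, 0 = fall through

OPERATORS: Dict[str, Callback] = {}


def register_operator(name: str):
    """Hypothetical decorator playing the role of operator registration."""
    def decorator(fn: Callback) -> Callback:
        OPERATORS[name] = fn
        return fn
    return decorator


@register_operator("toy::add")
def _add(stack: Stack) -> int:
    b, a = stack.pop(), stack.pop()
    stack.append(a + b)
    return 0


@register_operator("toy::mul")
def _mul(stack: Stack) -> int:
    b, a = stack.pop(), stack.pop()
    stack.append(a * b)
    return 0


class Instruction(NamedTuple):
    callback: Callback
    inputs: List[int]    # registers pushed onto the stack before the call
    outputs: List[int]   # registers filled from the stack after the call


def run(instructions: List[Instruction], registers: List[int]) -> None:
    stack: Stack = []
    pc = 0
    while pc < len(instructions):
        inst = instructions[pc]
        stack.extend(registers[r] for r in inst.inputs)
        new_pc = pc + 1 + inst.callback(stack)
        # Pop outputs in reverse so the first output lands in the first register.
        for reg in reversed(inst.outputs):
            registers[reg] = stack.pop()
        pc = new_pc


# Compute (r0 + r1) * r0 into r3, using r2 as a temporary.
registers = [3, 4, 0, 0]
program = [
    Instruction(OPERATORS["toy::add"], inputs=[0, 1], outputs=[2]),
    Instruction(OPERATORS["toy::mul"], inputs=[2, 0], outputs=[3]),
]
run(program, registers)
```

The loop itself knows nothing about addition or multiplication; all behavior lives in the registered callbacks, which is the same separation the interpreter relies on.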

At this point, we have the big picture of PyTorch’s JIT. It’s time to stop here; interested readers can read the code themselves for more details.