These are my notes from reading PyTorch’s JIT source. We begin by looking at torch.jit.script and torch.jit.script_method to find the frontend that compiles Python code into PyTorch’s tree views, and the backend that compiles tree views into graphs. We also examine the structure of PyTorch’s internal graph representation. Finally, we go to the graph executor to see how the computation graph is further compiled into instructions, and how the actions of those instructions are defined and executed.

# Starting point: script and script_method

In PyTorch, a Python function can be just-in-time compiled by doing something like:

Here torch.jit.script is a decorator applied to your function f. If you are unfamiliar with Python’s decorators, please refer to this article.
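If decorators are new to you, the mechanism can be shown without torch at all: a decorator is just a callable that receives the function object and returns whatever should be bound to the function’s name. The `trace_calls` decorator below is made up purely for illustration:

```python
def trace_calls(fn):
    # A decorator is an ordinary function: it receives the function
    # object `fn` and returns a replacement for it.
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return fn(*args, **kwargs)
    wrapper.calls = 0
    return wrapper

@trace_calls  # equivalent to: f = trace_calls(f)
def f(x):
    return x * x

print(f(3), f.calls)  # 9 1
```

torch.jit.script plays the same role: it receives the function object, compiles its source, and the compiled result is bound to the original name.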

It is also possible to create a module with its method JIT compiled by doing something like:

## Scripting a function

We will start by looking at torch.jit.script, which lives in torch/jit/__init__.py. To quickly locate it, search for def script in your editor, and you will immediately find it:

At the beginning, createResolutionCallback is called. This function is defined in the same file. The source code tells us that it just returns a function that maps names to their values in the scope of the caller of script; this is later used from C++ to read values from Python.
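As a sketch of the idea (not PyTorch’s exact implementation), such a callback can be built with the standard inspect module by capturing the caller’s frame:

```python
import inspect

def create_resolution_callback(frames_up=0):
    # Capture the caller's frame (plus `frames_up` extra levels above it),
    # so that names can later be resolved in that scope.
    frame = inspect.currentframe()
    for _ in range(frames_up + 1):
        frame = frame.f_back

    def env(name):
        # Resolve the way Python itself would: locals, then globals,
        # then builtins.
        if name in frame.f_locals:
            return frame.f_locals[name]
        if name in frame.f_globals:
            return frame.f_globals[name]
        return frame.f_builtins.get(name)

    return env

offset = 100

def caller():
    local_offset = 2
    return create_resolution_callback()

cb = caller()
print(cb("local_offset"), cb("offset"))  # 2 100
```

The compiler can then call `cb("torch")` or `cb("some_helper")` while emitting code, without ever holding a reference to the whole Python scope.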

The get_jit_ast on the next line is imported from torch.jit.frontend. From the name of this function and its owning module, we can tell that this is the frontend of PyTorch’s JIT compiler, which compiles the source code of the scripted function into an abstract syntax tree (AST).

The next line uses _jit_script_compile to compile the AST obtained in the previous step into a computation graph. Searching for _jit_script_compile, we find torch._C._jit_script_compile, which tells us that it is implemented in C++.

The next couple of lines basically create a ScriptModule whose forward method is the compiled graph.

## Scripting a module

We start by looking at torch.jit.script_method:

This is similar to script, but instead of creating and returning a module with the compiled function as its forward method, it simply uses a named tuple to store the resolution callback, the AST, and the original function.

This cannot be the end of the story, because a named tuple can never be called to do the computation. So there must be some magic somewhere that replaces the named tuples with something that actually does the job. For readers familiar with Python’s class metaprogramming, it’s not hard to imagine how the magic happens. For those who are not, I would refer them to the book Fluent Python, but I will explain a bit of the detail here:

In Python, everything is an object, and classes themselves are no exception. Classes in Python are objects of a special kind of class called a metaclass. At import time, when Python sees the following code:

It will execute the body of the class definition, that is: compile return x * x, create a function object with that compiled code, pass this function object to torch.jit.script_method, and set the returned named tuple as f. It then does the same thing for forward. After that, Python has a map of attribute names to values for the class to be constructed. This map is then passed to the metaclass of MyModule to actually construct MyModule as an instance of that metaclass.
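A tiny torch-free metaclass demonstrates the mechanism; all names here are invented for illustration:

```python
class RecordingMeta(type):
    # The metaclass's __new__ receives the attribute map built from the
    # class body -- exactly the map described above.
    def __new__(mcls, name, bases, attrs):
        cls = super().__new__(mcls, name, bases, attrs)
        # The metaclass can inspect (or replace) anything the class body
        # defined; PyTorch's ScriptMeta uses this to turn the stored
        # named tuples into real compiled methods.
        cls.recorded = [k for k in attrs if not k.startswith("__")]
        return cls

class MyModule(metaclass=RecordingMeta):
    def f(self, x):
        return x * x

    def forward(self, x):
        return self.f(x)

print(MyModule.recorded)  # ['f', 'forward']
```

Because the attribute map preserves definition order, the metaclass sees `f` and `forward` exactly as they appeared in the class body, and is free to substitute something callable before the class object is finalized.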

To know in detail how this is achieved in PyTorch, we should take a look at ScriptMeta and ScriptModule. These two classes are lengthy, so I will not copy their full code here, but will use pseudocode to show what is done:

In the above pseudocode, _create_methods, _has_method, and _get_method are inherited from torch._C.ScriptModule. A natural question to ask is then: what does torch._C.ScriptModule do? Before answering this question, let’s first take a look at the frontend.

# The frontend

A good starting point of the frontend is the get_jit_ast we just saw. This function is defined at torch/jit/frontend.py. The code is:

The first four lines of the function body just use the standard tools provided by Python, dedent, inspect, and ast, to construct the Python AST, and perform a check to make sure the thing being compiled is “a single top-level function”.
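Those standard tools can be exercised on their own. In PyTorch the source text comes from inspect.getsource(fn); the sketch below starts from a string literal so it is self-contained:

```python
import ast
from textwrap import dedent

# inspect.getsource returns source with its original indentation, so it
# must be dedented before ast.parse will accept it.
source = dedent("""\
    def square(x):
        return x * x
""")

py_ast = ast.parse(source)
# The "single top-level function" check:
assert len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.FunctionDef)

fn_def = py_ast.body[0]
print(fn_def.name)                            # square
print([arg.arg for arg in fn_def.args.args])  # ['x']
print(type(fn_def.body[0]).__name__)          # Return
```

The `fn_def` object is the Python AST that the rest of the frontend walks over.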

The following line, type_line = torch.jit.annotations.get_type_line(source), is interesting. After looking at torch/jit/annotations.py, we can see that PyTorch’s JIT allows the user to specify the types of arguments and the return value by writing something like # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor].
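A rough sketch of what extracting such a line involves (the real parsing in torch/jit/annotations.py is more involved and also validates the annotation):

```python
import re

def get_type_line(source):
    # Scan the source for the first `# type: ...` comment and return its
    # payload; return None when no annotation is present.
    for line in source.split("\n"):
        match = re.search(r"#\s*type:\s*(.*)", line)
        if match:
            return match.group(1)
    return None

src = """
def f(x, y):
    # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]
    return x, y
"""
print(get_type_line(src))  # (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]
```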

In the next line, ctx = SourceContext(source, _uses_true_division(fn)), the _uses_true_division is defined in the same file to handle the different behaviors of / in Python 2 with and without from __future__ import division (see PEP 238 for the difference). SourceContext is also defined in the same file. It is a subclass of SourceRangeFactory with an additional field storing whether the division is true division. SourceRangeFactory is imported by from torch._C._jit_tree_views import *. After reading its definition in torch/csrc/jit/script/python_tree_views.cpp, we can see that this is basically a class designed to store a range of source code, e.g. where in the source code a token is located.

The core is the build_def in the last line, so we move on:

Reading through this, we can see that what it basically does is convert Python’s AST into the internal representation. Names like Decl, Def, and Ident are all imported by from torch._C._jit_tree_views import *. In the last line, we can see that the function body is constructed by build_stmts, so let’s go further and read build_stmts:

This is a very simple function: call build_stmt for each item and filter out those that are not needed. But what is build_stmt? It is defined as build_stmt = StmtBuilder(). The definition of StmtBuilder looks like:

We can see that this is a class with many static methods that define what to do for each type of Python AST node. I will not go deep into how each type is handled, since at this point readers should be able to work out by themselves how each node type is dealt with. So we will stop our frontend reading right here.
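The dispatch pattern itself is worth a sketch. This is not PyTorch’s actual StmtBuilder, just the shape of it: a builder object looks up a static method named after the AST node’s class:

```python
import ast

class StmtBuilder:
    # Dispatch on the AST node's class name: look up a method called
    # build_<NodeType> and apply it.
    def __call__(self, node):
        method = getattr(self, "build_" + type(node).__name__, None)
        if method is None:
            raise RuntimeError("unsupported statement: " + type(node).__name__)
        return method(node)

    @staticmethod
    def build_Return(node):
        return ("return", node.value.id)

    @staticmethod
    def build_Pass(node):
        return None  # filtered out below

build_stmt = StmtBuilder()

def build_stmts(stmts):
    # Build each statement and drop those that produce nothing.
    return [b for b in (build_stmt(s) for s in stmts) if b is not None]

tree = ast.parse("def f(x):\n    pass\n    return x")
print(build_stmts(tree.body[0].body))  # [('return', 'x')]
```

Adding support for a new statement type is then just a matter of adding one more `build_*` static method.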

# ScriptModule and ScriptMethod

To find where ScriptModule in C++ is defined, run grep 'ScriptModule' -r torch/csrc/ and you will locate it at torch/csrc/jit/script/init.cpp:

We can see that ScriptModule is basically a binding for the C++ class Module. By skimming through the list of methods defined here, we can see that it has methods for adding, getting, and checking the existence of methods, parameters, submodules, buffers, etc. The class for methods is Method, which is bound to Python as ScriptMethod. Methods in modules are created by defineMethodsInModule and invoked by invokeScriptMethodFromPython. defineMethodsInModule is a bit complicated, so we will postpone reading it until the backend-compiler part of this article. But invokeScriptMethodFromPython is very simple. Searching with grep, we can easily find its definition in torch/csrc/jit/pybind_utils.h:

We can easily tell that it just creates a stack from the input parameters, invokes Method::run to consume elements on the stack as input and leave the output of the graph on the stack, and finally converts the elements on the stack into Python objects.

Now let’s move on to Module and Method. It’s easy to guess from the names that these classes are defined in torch/csrc/jit/script/module.{h,cpp}. Reading through these two files, we see that Module is just a container of things: it uses ordered dictionaries to store methods, parameters, and submodules, and provides methods to access or run them.

What Method does is more interesting. One important thing the designer of Method must worry about is that methods have access not only to their arguments, but also to other class members of the same object, so there must be a mechanism for that kind of access. We will see how this is handled very soon. From its constructor, we can see that a method can be created either directly from a graph and initial class members, or from a method creator. The method creator is invoked lazily, i.e. it is not invoked inside the constructor, but waits until someone calls ensure_defined. The following member functions of Method define how an object of Method is run:

By looking at the types of the names appearing in the above code, we can see that graph is an object of Graph, and the virtual machine that executes the graph is an object of GraphExecutor. GraphExecutor operates on the data type IValue, and its stack is a vector of that type. To run a method, one needs to first push the arguments onto the stack and invoke Method::run, which further pushes the other member inputs onto the stack and invokes GraphExecutor::run to run the graph. The graph executor leaves its output on the stack.
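That calling convention can be condensed into a toy sketch (a stand-in for Method::run and GraphExecutor::run, not the real C++):

```python
def executor_run(graph, stack):
    # Stand-in for GraphExecutor::run: a "graph" consumes its inputs
    # from the stack and leaves its outputs on the stack.
    n_inputs, fn = graph
    args = [stack.pop() for _ in range(n_inputs)][::-1]
    stack.extend(fn(*args))

def method_run(graph, member_inputs, stack):
    # Stand-in for Method::run: the caller pushed the explicit
    # arguments; the method appends the captured members, then hands
    # everything to the executor.
    stack.extend(member_inputs)
    executor_run(graph, stack)

weight = 10
graph = (2, lambda x, w: [x * w])  # "forward(x, weight) -> x * weight"

stack = [3]                        # explicit argument pushed by the caller
method_run(graph, [weight], stack)
print(stack)  # [30]
```

The important property is that arguments, member inputs, and outputs all travel through the same stack, which is exactly what makes the pybind glue in invokeScriptMethodFromPython so short.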

At this point, we still don’t know how things like Graph and GraphExecutor work, but before looking deep into that, let’s pause a little bit to take a look at the backend compiler.

# From Python AST to PyTorch IR: part 1

Now let’s move on to read _jit_script_compile. To find where it is located, simply run the command grep _jit_script_compile -r .. We will find something like:

So torch/csrc/jit/script/init.cpp would be a good starting point. The complete definition of _jit_script_compile is:

So, let’s move on to compileFunction. Using grep to search, we find its definition in torch/csrc/jit/script/compiler.cpp:

We see the defineMethodsInModule that we saw before in the definition of the Python bindings for Module. Moving on to defineMethodsInModule, in the same file:

Less important parts of the code are omitted. From the above, we can find that the core of compiling an AST into a compute graph is done in to_ir. Skimming through to_ir, we find that it is a struct of ~1000 lines of code, with member functions that handle different cases of the Python AST. Without knowing PyTorch’s IR, it’s not easy to understand what to_ir does, so let’s pause a little bit to take a look at the PyTorch IR and come back later.

# The PyTorch IR

A good starting point is the class Graph, located in torch/csrc/jit/ir.h. Skimming through this file, as well as to_ir, we keep seeing things like aten::mul and prim::Constant. What are they? They seem very relevant; actually, they seem to be the nodes in the graph. By doing some grep searching, we find good documentation for them in torch/csrc/jit/interned_strings.h:

This explains very well what those things are: they are instances of Symbol that represent operators. Knowing this level of detail about them is enough for us, so let’s go back to the IR.

The beginning of file torch/csrc/jit/ir.h very well explains what things are:

Reading through the whole file, we can summarize how it works:

A Graph object owns all Nodes, Values, and Blocks. The internal structure is not maintained by the Graph object, but inside Nodes, Values, and Blocks.

Each Node keeps pointers to its input and output Values. It also maintains pointers to its siblings in a doubly-linked list of Nodes. This doubly-linked list is a topological sort of the Nodes in the Graph. Each Node has a NodeKind, which is an object of Symbol. Nodes also maintain a pointer to the Block owning the Node, as well as pointers to subblocks.

Each Value must be the output of some Node, and it keeps a pointer to the Node that produces it. It also has a list of Uses recording where this Value is consumed as an input.

Each Block maintains pointers to its input and output Nodes, as well as the Node owning this Block.
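The ownership and linking just described can be condensed into a toy sketch, keeping only the fields discussed (the real classes in ir.h carry much more, and this omits Blocks entirely):

```python
class Value:
    def __init__(self, node):
        self.node = node   # the Node that produces this Value
        self.uses = []     # (user_node, input_index) pairs

class Node:
    def __init__(self, graph, kind):
        self.graph = graph
        self.kind = kind              # e.g. "prim::Constant", "aten::mul"
        self.inputs = []
        self.outputs = []
        self.prev = self.next = None  # doubly-linked topological order

    def add_input(self, value):
        value.uses.append((self, len(self.inputs)))
        self.inputs.append(value)

    def add_output(self):
        v = Value(self)
        self.outputs.append(v)
        return v

class Graph:
    def __init__(self):
        self.nodes = []  # the Graph owns every Node it creates

    def create(self, kind, inputs=()):
        node = Node(self, kind)
        for v in inputs:
            node.add_input(v)
        if self.nodes:  # append at the end of the topological order
            self.nodes[-1].next, node.prev = node, self.nodes[-1]
        self.nodes.append(node)
        return node

g = Graph()
c = g.create("prim::Constant").add_output()
m = g.create("aten::mul", inputs=[c, c])
print([n.kind for n in g.nodes])  # ['prim::Constant', 'aten::mul']
print(len(c.uses))                # 2
```

Note how the graph structure lives entirely in the Nodes and Values: the Graph object only owns them, matching the summary above.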

# From Python AST to PyTorch IR: part 2

With the knowledge of IR, let’s go back to read the backend compiler.

In the code in torch/csrc/jit/script/compiler.cpp, we have been seeing SugaredValue many times. What SugaredValue does is explained in torch/csrc/jit/script/compiler.h:

From the comments above, together with what we see when skimming through the code, we can tell that SugaredValue is a superclass of different types of values. These values might be first-class values like tensors or integers, a ScriptModule such as self, a Python module like torch, or a builtin function like print. Different types of values are handled by different subclasses: SimpleValue for first-class values, BuiltinFunction for operators like aten::relu, BuiltinModule for something like torch, NoneValue for None, PrintValue for print, and CastValue for types like int, float, etc. The subclasses listed above are all defined in torch/csrc/jit/script/compiler.{cpp,h}.

Now let’s move on to read the constructor of the struct to_ir. It basically:

1. Read the information about the parameters from the Python AST, and set them up in the graph.
2. Call emitStatements to emit IR for the function body.
3. Set up the output values of the graph based on the return statement at the end of the function body (compiling functions with a return statement anywhere other than the end is not supported).

In step 1, there is a little bit of trouble: for functions that are methods of some module, the first parameter is always a reference to the object owning the method (the so-called “self”), so a bit of special-casing is required when checking against the schema. Also, we need to add the identifier of the first parameter to the symbol table (here the symbol table is Environment::value_table, an object of ValueTable). The inputs to the graph are not only those appearing explicitly in the argument list, but also the members accessed inside the function body. Recall that when we read the code of Method::run, there was a step that pushed members onto the stack. This issue is not handled here, and we will see how it is handled later.

In step 2, things start to get complicated. In emitStatements, code emitting is dispatched to different specialized private methods of the struct according to the statement type:

There are many specialized emitters, and I will not go over them in detail one by one. I will only go deep into emitSugaredExpr as an example. emitSugaredExpr is defined as follows:

What it does is basically this: for cases that are guaranteed to produce a SimpleValue, we just call emitSimpleExpr to emit the code; otherwise the expression must have one of the following three forms: foo, foo.bar, or foo(bar). For the foo case, we just look up foo in the symbol table. For the foo.bar case, we first emit foo and then look up its attribute bar. For the foo(bar) case, depending on whether foo is an identifier or an expression, we invoke emitApplyIdent or emitApplyExpr accordingly to do the code emitting.
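The three-way case analysis is easy to mirror on Python’s own AST; this classifier is only an illustration of the dispatch, not PyTorch code:

```python
import ast

def classify(expr_src):
    # Mirror the shapes emitSugaredExpr distinguishes:
    # a bare name, an attribute access, or a call.
    node = ast.parse(expr_src, mode="eval").body
    if isinstance(node, ast.Name):
        return "lookup in symbol table"
    if isinstance(node, ast.Attribute):
        return "emit value, then look up attribute " + node.attr
    if isinstance(node, ast.Call):
        if isinstance(node.func, ast.Name):
            return "emitApplyIdent"
        return "emitApplyExpr"
    return "emitSimpleExpr"

print(classify("foo"))          # lookup in symbol table
print(classify("foo.bar"))      # emit value, then look up attribute bar
print(classify("foo(bar)"))     # emitApplyIdent
print(classify("(f or g)(x)"))  # emitApplyExpr
```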

The self argument of a method is handled a bit differently: there is a subclass of SugaredValue called ModuleValue, defined in torch/csrc/jit/script/init.cpp. In its overridden method attr, we see:

where get_or_add_parameter, defined in torch/csrc/jit/script/module.h, reads:

This tells us that adding members as parameters of the graph actually happens during code emitting of self.bar, where the attr of ModuleValue is called.

# The Graph Executor

Now that we have seen how the compilation is done and what PyTorch JIT’s IR looks like, what is left is how the IR is executed. From the above we already know that the executor is obtained by invoking Method::get_executor and run by invoking GraphExecutor::run. Let’s first take a look at Method::get_executor:

We see that a graph executor is created from a graph, and performs optimization if asked to. It’s not hard to guess from the name that GraphExecutor is defined in torch/csrc/jit/graph_executor.{h,cpp}.

The constructor and run tell us that GraphExecutor is just a wrapper around GraphExecutorImpl:

So let’s move on to GraphExecutorImpl:

We see that the graph is compiled the first time it runs, yielding an execution plan. The run method of the execution plan is then called to run the graph. Compilation of the graph into an execution plan is done by getOrCompile or getOrCompileFallback, depending on whether optimization is enabled. These two methods are copied below:

This code explains itself well: if optimization is turned off, then we only run the required passes and cache the result. Otherwise, depending on the characteristics of the inputs (ArgumentSpec), we run the full optimization and cache the generated plan for each different ArgumentSpec. The plan is created by the constructor of ExecutionPlan.
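The caching scheme amounts to memoizing compilation on a key derived from the inputs. As a sketch, with a deliberately simplified spec (the real ArgumentSpec records dtype, device, requires_grad, dimensionality, and more):

```python
def argument_spec(inputs):
    # Stand-in for ArgumentSpec: capture the input characteristics that
    # the optimizer specializes on. Here we only record Python types.
    return tuple(type(x).__name__ for x in inputs)

class GraphExecutorImpl:
    def __init__(self, graph, optimize=True):
        self.graph = graph
        self.optimize = optimize
        self.plan_cache = {}   # ArgumentSpec -> plan
        self.fallback = None   # single unoptimized plan

    def get_or_compile(self, inputs):
        if not self.optimize:
            # getOrCompileFallback: required passes only, cached once.
            if self.fallback is None:
                self.fallback = ("required-passes-only", self.graph)
            return self.fallback
        # getOrCompile: one fully optimized plan per ArgumentSpec.
        spec = argument_spec(inputs)
        if spec not in self.plan_cache:
            self.plan_cache[spec] = ("optimized-for", spec)
        return self.plan_cache[spec]

ex = GraphExecutorImpl(graph="g")
ex.get_or_compile([1, 2])
ex.get_or_compile([3, 4])    # same spec: cache hit
ex.get_or_compile([1.0, 2])  # different spec: compiled again
print(len(ex.plan_cache))  # 2
```

The point of keying on the spec rather than the values is that a plan optimized for one shape/dtype combination can be reused for every later call with the same combination.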

It is worth a look at which passes are called:

I will not go deep into these passes here; interested readers can read them in torch/csrc/jit/passes/.

Now it’s time to look at ExecutionPlan:

It just converts the graph into an object of Code, and the running is done by InterpreterState.

# Compiling to Interpreter Instructions

Code and InterpreterState are defined in torch/csrc/jit/interpreter.{h,cpp}. These two classes are just wrappers around their implementations:

CodeImpl is a long struct, but quite logical. A selected list of its fields is shown below:

Its constructor is:

Clearly we can see what it does is: 1. preprocess the graph, and then 2. emit instructions for interpreter.

The preprocessing of the graph is very well explained at the beginning of the file:

as well as in its definition

The instructions are emitted by insertNodesFromBlock, which is also quite self-explanatory:

Since the nodes are topologically sorted, we just need to iterate over the linked list and generate code for each node.

# The Virtual Machine

InterpreterStateImpl is the virtual machine that executes instructions.

There is nothing special here; it just mimics the behavior of a processor. We can easily tell from the above code that the actions are defined by Instruction::callback, and branching is implemented by returning a non-zero value from that callback function. Some of the callbacks are defined inside CodeImpl, such as:
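The execution model can be condensed into a toy loop: each instruction is a callback operating on the stack, and its return value redirects the program counter (a sketch only; the real Instruction in interpreter.cpp also carries operand and output registers):

```python
def run(instructions, stack):
    # Mimic the interpreter loop: execute callbacks in order; a
    # callback's return value is a jump offset added to the program
    # counter, so returning 0 means "fall through".
    pc = 0
    while pc < len(instructions):
        offset = instructions[pc](stack)
        pc = pc + 1 + offset

def push(value):
    def callback(stack):
        stack.append(value)
        return 0
    return callback

def mul(stack):
    b, a = stack.pop(), stack.pop()
    stack.append(a * b)
    return 0

def jump_if_zero(offset):
    # A branching instruction: consume the top of the stack and jump by
    # `offset` when it is zero (non-zero return value = branch taken).
    def callback(stack):
        return offset if stack.pop() == 0 else 0
    return callback

stack = []
run([push(6), push(7), mul], stack)
print(stack)  # [42]

stack = []
run([push(0), jump_if_zero(1), push(111), push(222)], stack)
print(stack)  # [222]
```

In the second run, jump_if_zero skips over push(111), which is exactly the branching-by-return-value trick described above.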

while others are defined by their node kind:

where getOperation is defined in torch/csrc/jit/operator.{h,cpp}. Reading further through these two files, we can see that operations are registered by calling registerOperator, which is done through RegisterOperators. Using grep RegisterOperators -r torch/csrc/, we can locate the definitions of all operations:

At this point, we have the whole big picture of PyTorch’s JIT. It’s time to stop here; interested readers can read the code by themselves for more details.