从头开始阅读PyTorch代码 -- Operators篇

这篇是阅读PyTorch源代码整理的笔记,方便以后翻阅。这里主要是想知道PyTorch的operators的定义都是怎么组织的,以及如果要添加新的operator的话,该怎么做。

由于pytorch开发非常活跃,代码也在不断地更改,本文内容跟读者实际看到的最新的代码肯定也有所区别。如果读者想要查看作者写文时候的代码,可以在pytorch仓库中:

1
git checkout 14cbd9adb8efafbd51444fdd88d6a34e6438b1c5

__init__.pysetup.py

比较不错的着手点是torch这个模块的__init__.py跟安装用的setup.py。 把两个文件都浏览一遍有个大体的概念。然后在__init__.py里面搜__all__,通过观察这些operators是怎么被添加到__all__里面去的,就能知道我们用的所有的那些个operators是怎么来的了。

__all__发现的关键的一段是:

1
2
3
4
5
from torch._C import *
__all__ += [name for name in dir(_C)
if name[0] != '_' and
not name.endswith('Base')]

这段干的事情就是把torch._C中定义的各种东西按需加入__all__里面去,所以要想知道哪些operators是怎么来的,还是需要去看torch._C。从名字一看就知道,torch._C这个东西,是PyTorch的用C/C++之类的语言写的那一部分。这就涉及到这一部分是怎么构建的了,这个要从setup.py里面翻找,发现的相关的代码段如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
main_sources = [
"torch/csrc/PtrWrapper.cpp",
#......
]
#.....
extensions = []
packages = find_packages(exclude=('tools', 'tools.*', 'caffe2', 'caffe2.*', 'caffe', 'caffe.*'))
C = Extension("torch._C",
libraries=main_libraries,
sources=main_sources,
language='c++',
extra_compile_args=main_compile_args + extra_compile_args,
include_dirs=include_dirs,
library_dirs=library_dirs,
extra_link_args=extra_link_args + main_link_args + [make_relative_rpath('lib')],
)
extensions.append(C)

基本上就可以断定torch._C这个东西就是从torch/csrc/这个目录里面的一大堆文件编译出来的了。torch/csrc/里面文件一大堆,从哪里着手是个问题。因为torch._C是python的一个模块,那么肯定得有地方通过python的C-binding创建这个模块才是,这就是个不错的着手点。要找到这个模块是从哪里创建的,这时候就要祭出grep大法了,在torch/csrc/这个目录里面运行这样一条命令:

1
grep 'torch._C' -r .

就可以得到所有有关的代码行了,发现的最像的一行是:

1
./Module.cpp: ASSERT_TRUE(module = Py_InitModule("torch._C", methods.data()));

这一行一看就是在初始化这个模块,而且文件名叫做Module.cpp也很符合,那下一步就从这里开始好了。__init__.pysetup.py也可以退出我们的历史舞台了。

另外要注意,在施展grep大法之前,一定要先把PyTorch给编译一遍,因为PyTorch的很多代码是编译的时候根据其他文件生成的,带着生成的文件一起查找比较好。

Module.cpp跟autograd

把这个文件从头到尾浏览一遍,基本上就可以断定初始化模块是在initModule里面完成的了,这个函数里面初始化了一大堆东西,主要还是找找具体的哪一行是负责初始化那堆operators的。注意到initModule里面有好多类似

1
2
3
#ifdef WITH_CUDA
// do something ....
#endif

这种,这种就可以直接跳过不读了,因为不管你是否启用了这一堆的feature,那些个operators都是存在的,那就当你没启用这堆好了。还有一堆东西,看名字就不相关,就直接不用搭理了。所以到最后基本上筛选出来的看起来可能是的也只有:

1
2
3
4
5
THPUtils_addPyMethodDefs(methods, TorchMethods);
THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
//....
ASSERT_TRUE(THPVariable_initModule(module));
ASSERT_TRUE(THPFunction_initModule(module));

那就先从TorchMethods看起,这个东西就定义在Module.cpp里面,看一眼就会知道跟operators没啥关系。torch::autograd::python_functions()这个东西,使用grep大法,可以发现他的定义位于torch/csrc/autograd/init.cpp,也不是啥想要的,继续看THPVariable_initModule,使用grep大法,就会发现这个函数是在torch/csrc/autograd/python_variable.cpp里面定义的,翻看一下这个函数的定义,注意到下面这行:

1
THPUtils_addPyMethodDefs(methods, torch::autograd::variable_methods);

整个函数定义里面,也只有这一行最像是定义operators的了,于是继续深挖,继续使用grep大法:

1
grep variable_methods -r .

就会发现这个变量是定义在torch/csrc/下的autograd/generated/python_variable_methods.cpp里面的,打开看看,发现如下的内容:

1
2
3
4
PyMethodDef variable_methods[] = {
{"__add__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
//......
}

这就是一个长长的列表,列举了所有的的operators,每个operator对应一个THPVariable_开头的函数定义在同一个文件里面,而这个文件的开头,则说明了这个文件是从tools/autograd/templates/python_variable_methods.cpp这个模板生成而来。

浏览一下所有的THPVariable_开头的函数,就会发现所有的不同的这些个函数都大同小异,基本上核心部分只有下面的内容:

1
wrap(dispatch_xxxxx(...));

其中dispatch_xxxxx应该就是xxxxx这个operator的核心实现部分。继续动用grep大法挖,只需要随便挑选一个operator搜索就行了,例如:

1
grep dispatch_acos -r .

搜了就会发现这些个dispatch_开头的函数,是定义在同目录以下的python_variable_methods_dispatch.h文件里面的。翻开这个文件,浏览一下这些dispatch函数的定义,都大同小异,下面摘录其中一个:

1
2
3
4
5
6
inline Tensor dispatch_add(Tensor & self, Scalar alpha, const Tensor & other) {
AutoNoGIL no_gil;
AutoGPU auto_gpu(self);
return self.add(other, alpha);
}

从代码发现,这些个operator,实际上是Tensor这个类的成员函数,所以我们就知道,下一步应该挖的,就是Tensor这个类了。除此以外,还有一个很重要的东西就是搞明白代码生成的原理,这样就能知道代码生成器是怎样找到这些operators的定义,进而生成这些函数的了。

Tensor这个类的出处在python_variable_methods_dispatch.h文件的头部可以找到:

1
using at::Tensor;

由此可见Tensor是ATen里面定义的,由此看来autograd也基本要退出我们的历史舞台了,轮到ATen登场了。

ATen

要学习ATen其实非常简单,在aten目录里面乱扒乱翻一通,挨个文件夹都点开瞅两眼,把所有的README.md都读一遍,就会发现,实际上ATen的算符是怎么定义的,实际上,已经在aten/src/ATen/native/README.md文件中,进行了非常详细的说明。

综合各个README.md的信息,并简单总结一下,就是:PyTorch的算符,都是定义在ATen里面的,而ATen里面的算符的实现,一部分是从老的Lua Torch继承而来,这一部分的代码,位于aten/src/TH*这些个目录里面,这些都是历史遗留的遗产,继承过来直接用,并不是PyTorch最终想要的operator的实现方式。最终“好”的实现方式,是在aten/src/ATen/native/目录里面。很多算符,也已经在这个目录下,被重新实现了一遍。这些老的算符的列表,是在aten/src/ATen/Declarations.cwrap中定义的。而新的算符的列表的定义,是在aten/src/ATen/native/native_functions.yaml中。本文只去探讨新的算符的实现方式。

到了这里,想要继续扒一下新的算符实现的,就需要去扒一下native_functions.yaml这个文件是怎么被读取的了。在PyTorch的根目录下,继续用grep大法,搜索关键字native_functions.yaml,得到的结果中,一条看起来很像是我们想要的结果的是:

1
./aten/src/ATen/gen.py: native_files = filter_by_extension(options.files, 'native_functions.yaml')

打开gen.py这个文件查看,就会发现下面比较有趣的代码:

1
2
3
4
5
TEMPLATE_PATH = options.source_path + "/templates"
# ......
TENSOR_DERIVED_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TensorDerived.h")
TENSOR_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Tensor.h")
TENSOR_METHODS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TensorMethods.h")

以及:

1
2
3
4
5
6
7
8
9
10
11
12
13
def generate_outputs():
cwrap_files = filter_by_extension(options.files, '.cwrap')
nn_files = filter_by_extension(options.files, 'nn.yaml', '.h')
native_files = filter_by_extension(options.files, 'native_functions.yaml')
declarations = [d
for file in cwrap_files
for d in cwrap_parser.parse(file)]
declarations += nn_parse.run(nn_files)
declarations += native_parse.run(native_files)
declarations = preprocess_declarations.run(declarations)
# ......

继续在这附近翻找上下文,就会发现,ATen的代码生成,是通过gen.py等的Python脚本,解析之前说过的那若干个列表文件,然后根据aten/src/ATen/templates/目录下的文件生成的。这些文件,都是模板,不长,全都浏览一遍就是了。阅读的过程,就会发现非常多的重要信息,比如在Tensor.h中有:

1
2
3
4
5
6
7
8
9
namespace at {
// ......
struct Tensor : public detail::TensorBase {
// ......
${tensor_method_declarations}
// ......
};
// ......
} // namespace at

以及TensorMethods.h中的:

1
2
3
4
5
6
7
8
9
10
namespace at {
inline Tensor & Tensor::operator=(Tensor const & rhs) && {
return copy_(rhs);
}
// ......
// all static inline to allow for inlining of the non-dynamic part of dispatch
${tensor_method_definitions}
// ......
} // namespace at

基本上,上面的看完,就已经知道,Tensor类是怎么定义的了。最后一步,就是看一下tensor_method_definitions是怎么填充的了。

继续动用grep大法扒,在PyTorch的根目录用grep搜索tensor_method_definitions,会得到下面有意思的结果:

1
./aten/src/ATen/function_wrapper.py: top_env['tensor_method_definitions'].append(

看了这个,直接跳到function_wrapper.py文件去扒,发现下面一段:

1
2
3
4
5
6
if is_method:
top_env['tensor_method_declarations'].append(
TENSOR_METHOD_DECLARATION.substitute(env))
top_env['tensor_method_definitions'].append(
TENSOR_METHOD_DEFINITION.substitute(env))
method_of.append('Tensor')

而上面代码中的TENSOR_METHOD_DECLARATIONTENSOR_METHOD_DEFINITION,就定义在同一个文件中:

1
2
3
4
5
6
7
8
9
10
# add non-virtual declaration to Tensor.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults})${const_mark};
""")
# add non-virtual declaration to Tensor.cpp
TENSOR_METHOD_DEFINITION = CodeTemplate("""\
inline ${return_type} Tensor::${api_name}(${method_formals})${const_mark} {
return type().${api_name}(${method_actuals});
}
""")

至此,基本上,整个ATen的代码生成,基本上算是扒完了,具体做事情的时候,再去返回这些涉及到的文件查看即可。

本文完结