这篇是阅读PyTorch源代码整理的笔记,方便以后翻阅。这里主要是想知道PyTorch的operators的定义都是怎么组织的,以及如果要添加新的operator的话,该怎么做。
由于pytorch开发非常活跃,代码也在不断地更改,本文内容跟读者实际看到的最新的代码肯定也有所区别。如果读者想要查看作者写文时候的代码,可以在pytorch仓库中:
git checkout 14cbd9adb8efafbd51444fdd88d6a34e6438b1c5
__init__.py
跟setup.py
比较不错的着手点是torch
这个模块的__init__.py
跟安装用的setup.py
。 把两个文件都浏览一遍有个大体的概念。然后在__init__.py
里面搜__all__
,通过观察这些operators是怎么被添加到__all__
里面去的,就能知道我们用的所有的那些个operators是怎么来的了。
搜__all__
发现的关键的一段是:
from torch._C import *
__all__ += [name for name in dir(_C)
if name[0] != '_' and
not name.endswith('Base')]
这段干的事情就是把torch._C
中定义的各种东西按需加入__all__
里面去,所以要想知道哪些operators是怎么来的,还是需要去看torch._C
。从名字一看就知道,torch._C
这个东西,是PyTorch的用C/C++之类的语言写的那一部分。这就涉及到这一部分是怎么构建的了,这个要从setup.py
里面翻找,发现的相关的代码段如下:
main_sources = [
"torch/csrc/PtrWrapper.cpp",
#......
]
#.....
extensions = []
packages = find_packages(exclude=('tools', 'tools.*', 'caffe2', 'caffe2.*', 'caffe', 'caffe.*'))
C = Extension("torch._C",
libraries=main_libraries,
sources=main_sources,
language='c++',
extra_compile_args=main_compile_args + extra_compile_args,
include_dirs=include_dirs,
library_dirs=library_dirs,
extra_link_args=extra_link_args + main_link_args + [make_relative_rpath('lib')],
)
extensions.append(C)
基本上就可以断定torch._C
这个东西就是从torch/csrc/
这个目录里面的一大堆文件编译出来的了。torch/csrc/
里面文件一大堆,从哪里着手是个问题。因为torch._C
是python的一个模块,那么肯定得有地方通过python的C-binding创建这个模块才是,这就是个不错的着手点。要找到这个模块是从哪里创建的,这时候就要祭出grep大法了,在torch/csrc/
这个目录里面运行这样一条命令:
grep 'torch._C' -r .
就可以得到所有有关的代码行了,发现的最像的一行是:
./Module.cpp: ASSERT_TRUE(module = Py_InitModule("torch._C", methods.data()));
这一行一看就是在初始化这个模块,而且文件名叫做Module.cpp
也很符合,那下一步就从这里开始好了。__init__.py
跟setup.py
也可以退出我们的历史舞台了。
另外要注意,在施展grep大法之前,一定要先把PyTorch给编译一遍,因为PyTorch的很多代码是编译的时候根据其他文件生成的,带着生成的文件一起查找比较好。
Module.cpp
跟autograd
把这个文件从头到尾浏览一遍,基本上就可以断定初始化模块是在initModule
里面完成的了,这个函数里面初始化了一大堆东西,主要还是找找具体的哪一行是负责初始化那堆operators的。注意到initModule
里面有好多类似
#ifdef WITH_CUDA
// do something ....
#endif
这种,这种就可以直接跳过不读了,因为不管你是否启用了这一堆的feature,那些个operators都是存在的,那就当你没启用这堆好了。还有一堆东西,看名字就不相关,就直接不用搭理了。所以到最后基本上筛选出来的看起来可能是的也只有:
THPUtils_addPyMethodDefs(methods, TorchMethods);
THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
//....
ASSERT_TRUE(THPVariable_initModule(module));
ASSERT_TRUE(THPFunction_initModule(module));
那就先从TorchMethods
看起,这个东西就定义在Module.cpp
里面,看一眼就会知道跟operators没啥关系。torch::autograd::python_functions()
这个东西,使用grep大法,可以发现他的定义位于torch/csrc/autograd/init.cpp
,也不是啥想要的,继续看THPVariable_initModule
,使用grep大法,就会发现这个函数是在torch/csrc/autograd/python_variable.cpp
里面定义的,翻看一下这个函数的定义,注意到下面这行:
THPUtils_addPyMethodDefs(methods, torch::autograd::variable_methods);
整个函数定义里面,也只有这一行最像是定义operators的了,于是继续深挖,继续使用grep大法:
grep variable_methods -r .
就会发现这个变量是定义在torch/csrc/
下的autograd/generated/python_variable_methods.cpp
里面的,打开看看,发现如下的内容:
PyMethodDef variable_methods[] = {
{"__add__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL},
//......
}
这就是一个长长的列表,列举了所有的的operators,每个operator对应一个THPVariable_
开头的函数定义在同一个文件里面,而这个文件的开头,则说明了这个文件是从tools/autograd/templates/python_variable_methods.cpp
这个模板生成而来。
浏览一下所有的THPVariable_
开头的函数,就会发现所有的不同的这些个函数都大同小异,基本上核心部分只有下面的内容:
wrap(dispatch_xxxxx(...));
其中dispatch_xxxxx
应该就是xxxxx
这个operator的核心实现部分。继续动用grep大法挖,只需要随便挑选一个operator搜索就行了,例如:
grep dispatch_acos -r .
搜了就会发现这些个dispatch_
开头的函数,是定义在同目录以下的python_variable_methods_dispatch.h
文件里面的。翻开这个文件,浏览一下这些dispatch函数的定义,都大同小异,下面摘录其中一个:
inline Tensor dispatch_add(Tensor & self, Scalar alpha, const Tensor & other) {
AutoNoGIL no_gil;
AutoGPU auto_gpu(self);
return self.add(other, alpha);
}
从代码发现,这些个operator,实际上是Tensor
这个类的成员函数,所以我们就知道,下一步应该挖的,就是Tensor
这个类了。除此以外,还有一个很重要的东西就是搞明白代码生成的原理,这样就能知道代码生成器是怎样找到这些operators的定义,进而生成这些函数的了。
Tensor这个类的出处在python_variable_methods_dispatch.h
文件的头部可以找到:
using at::Tensor;
由此可见Tensor是ATen里面定义的,由此看来autograd也基本要退出我们的历史舞台了,轮到ATen登场了。
ATen
要学习ATen其实非常简单,在aten
目录里面乱扒乱翻一通,挨个文件夹都点开瞅两眼,把所有的README.md
都读一遍,就会发现,实际上ATen的算符是怎么定义的,实际上,已经在aten/src/ATen/native/README.md
文件中,进行了非常详细的说明。
综合各个README.md
的信息,并简单总结一下,就是:PyTorch的算符,都是定义在ATen里面的,而ATen里面的算符的实现,一部分是从老的Lua Torch继承而来,这一部分的代码,位于aten/src/TH*
这些个目录里面,这些都是历史遗留的遗产,继承过来直接用,并不是PyTorch最终想要的operator的实现方式。最终“好”的实现方式,是在aten/src/ATen/native/
目录里面。很多算符,也已经在这个目录下,被重新实现了一遍。这些老的算符的列表,是在aten/src/ATen/Declarations.cwrap
中定义的。而新的算符的列表的定义,是在aten/src/ATen/native/native_functions.yaml
中。本文只去探讨新的算符的实现方式。
到了这里,想要继续扒一下新的算符实现的,就需要去扒一下native_functions.yaml
这个文件是怎么被读取的了。在PyTorch的根目录下,继续用grep大法,搜索关键字native_functions.yaml
,得到的结果中,一条看起来很像是我们想要的结果的是:
./aten/src/ATen/gen.py: native_files = filter_by_extension(options.files, 'native_functions.yaml')
打开gen.py
这个文件查看,就会发现下面比较有趣的代码:
TEMPLATE_PATH = options.source_path + "/templates"
# ......
TENSOR_DERIVED_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TensorDerived.h")
TENSOR_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Tensor.h")
TENSOR_METHODS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TensorMethods.h")
以及:
def generate_outputs():
cwrap_files = filter_by_extension(options.files, '.cwrap')
nn_files = filter_by_extension(options.files, 'nn.yaml', '.h')
native_files = filter_by_extension(options.files, 'native_functions.yaml')
declarations = [d
for file in cwrap_files
for d in cwrap_parser.parse(file)]
declarations += nn_parse.run(nn_files)
declarations += native_parse.run(native_files)
declarations = preprocess_declarations.run(declarations)
# ......
继续在这附近翻找上下文,就会发现,ATen的代码生成,是通过gen.py
等的Python脚本,解析之前说过的那若干个列表文件,然后根据aten/src/ATen/templates/
目录下的文件生成的。这些文件,都是模板,不长,全都浏览一遍就是了。阅读的过程,就会发现非常多的重要信息,比如在Tensor.h
中有:
namespace at {
// ......
struct Tensor : public detail::TensorBase {
// ......
${tensor_method_declarations}
// ......
};
// ......
} // namespace at
以及TensorMethods.h
中的:
namespace at {
inline Tensor & Tensor::operator=(Tensor const & rhs) && {
return copy_(rhs);
}
// ......
// all static inline to allow for inlining of the non-dynamic part of dispatch
${tensor_method_definitions}
// ......
} // namespace at
基本上,上面的看完,就已经知道,Tensor类是怎么定义的了。最后一步,就是看一下tensor_method_definitions
是怎么填充的了。
继续动用grep大法扒,在PyTorch的根目录用grep搜索tensor_method_definitions
,会得到下面有意思的结果:
./aten/src/ATen/function_wrapper.py: top_env['tensor_method_definitions'].append(
看了这个,直接跳到function_wrapper.py
文件去扒,发现下面一段:
if is_method:
top_env['tensor_method_declarations'].append(
TENSOR_METHOD_DECLARATION.substitute(env))
top_env['tensor_method_definitions'].append(
TENSOR_METHOD_DEFINITION.substitute(env))
method_of.append('Tensor')
而上面代码中的TENSOR_METHOD_DECLARATION
跟TENSOR_METHOD_DEFINITION
,就定义在同一个文件中:
# add non-virtual declaration to Tensor.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults})${const_mark};
""")
# add non-virtual declaration to Tensor.cpp
TENSOR_METHOD_DEFINITION = CodeTemplate("""\
inline ${return_type} Tensor::${api_name}(${method_formals})${const_mark} {
return type().${api_name}(${method_actuals});
}
""")
至此,基本上,整个ATen的代码生成,基本上算是扒完了,具体做事情的时候,再去返回这些涉及到的文件查看即可。
本文完结