
极市 Editor's Note

This article is a detailed walkthrough of the jit module in torch. It covers usage examples of the two export methods, the form of the IR, a source-level reading of the two ways IR is exported, and a brief introduction to IR optimization.
Preface
torch.jit brings three things:

- A new intermediate representation (IR) for computation graphs, hereafter simply called IR.
- Two ways of exporting IR from Python code: trace and script.
- IR optimization, plus an interpreter for the IR (which lowers it to concrete ops).

Correspondingly, this article covers:

- A brief introduction to jit, with usage examples of the two export methods
- The form of the IR in jit
- A source-level reading of the two export paths, trace and script
- A brief introduction to IR optimization
1 A brief introduction to jit and usage examples
JIT overview
trace
```python
import torch
import torchvision.models as models

resnet = torch.jit.trace(models.resnet18(), torch.rand(1, 3, 224, 224))
output = resnet(torch.ones(1, 3, 224, 224))
print(output)
resnet.save('resnet.pt')
```
```
graph(%self.1 : __torch__.torchvision.models.resnet.___torch_mangle_194.ResNet,
      %input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
  %1472 : __torch__.torch.nn.modules.linear.___torch_mangle_193.Linear = prim::GetAttr[name="fc"](%self.1)
  %1469 : __torch__.torch.nn.modules.pooling.___torch_mangle_192.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
  %1468 : __torch__.torch.nn.modules.container.___torch_mangle_191.Sequential = prim::GetAttr[name="layer4"](%self.1)
  %1422 : __torch__.torch.nn.modules.container.___torch_mangle_175.Sequential = prim::GetAttr[name="layer3"](%self.1)
  ....
  %1556 : Tensor = prim::CallMethod[name="forward"](%1469, %1555)
  %1202 : int = prim::Constant[value=1]()
  %1203 : int = prim::Constant[value=-1]()
  %input : Float(1:512, 512:1, requires_grad=1, device=cpu) = aten::flatten(%1556, %1202, %1203)
  %1557 : Tensor = prim::CallMethod[name="forward"](%1472, %input)
  return (%1557)
```
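The saved archive is self-contained: it can be reloaded later without the original Python model definition, via the standard torch.jit.load API. A minimal sketch:

```python
import torch

# reload the serialized ScriptModule produced by resnet.save('resnet.pt');
# the torchvision ResNet class is no longer needed at this point
loaded = torch.jit.load('resnet.pt')
print(loaded(torch.ones(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
```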
torch.jit.trace takes the model you want to export plus a valid example input. It works just as its name suggests: it follows one inference pass through the model, records every operation performed on the input one by one, and maps each to the corresponding IR operation, thereby obtaining the IR of the model's original forward.
```python
def test(x):
    if x > 2.0:
        r = torch.tensor(1.0)
    else:
        r = torch.tensor(2.0)
    return r

ftrace = torch.jit.trace(test, (torch.ones(1)))
y = torch.ones(1) * 5
print(ftrace(y))
# result: tensor(2.)
# the example input only ever took the else branch, so only that branch was traced
```

script
```python
@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(foo.graph)
print(foo(torch.Tensor([0]), torch.Tensor([1])))
print(foo(torch.Tensor([1]), torch.Tensor([0])))
```

The output shows that the control flow really is captured:

```
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %3 : Tensor = aten::max(%x.1)
  %5 : Tensor = aten::max(%y.1)
  %6 : Tensor = aten::gt(%3, %5)
  %7 : bool = aten::Bool(%6)
  %r : Tensor = prim::If(%7)
    block0():
      -> (%x.1)
    block1():
      -> (%y.1)
  return (%r)

tensor([1.])
tensor([1.])
```
torch.jit.script converts code in a completely different way from trace: instead of running the model, script parses your PyTorch source directly, analyzes its syntax into a syntax tree, and then converts that tree into the IR.
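Besides .graph, a scripted function also exposes a .code attribute that pretty-prints the TorchScript recovered from the IR, with the if statement intact. The exact text varies by PyTorch version; roughly:

```python
print(foo.code)
# def foo(x: Tensor, y: Tensor) -> Tensor:
#     if bool(torch.gt(torch.max(x), torch.max(y))):
#         r = x
#     else:
#         r = y
#     return r
```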
Combining the two
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces ScriptModules for conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())
```
As a rule of thumb:

1 For straight-line code without control flow, plain tracing is enough.
2 For code with control flow (if-else, for-loop), use scripting.
3 When scripting cannot handle some syntax, either rewrite the code or combine tracing and scripting (for example, script only the code that has control flow and trace everything else; see the sketch below).
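A minimal sketch of point 3 (the helper names here are made up for illustration): keep the data-dependent branch in a scripted function and trace the surrounding straight-line code; calls into the scripted function are preserved inside the trace.

```python
import torch

@torch.jit.script
def choose(x, y):
    # data-dependent control flow: must be scripted,
    # tracing would bake in a single branch
    if x.sum() > y.sum():
        return x
    return y

def compute(x, y):
    # straight-line tensor code: safe to trace
    return choose(x * 2, y + 1).relu()

traced = torch.jit.trace(compute, (torch.rand(3), torch.rand(3)))
print(traced.graph)  # contains a call to the scripted `choose`
```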
How to extend
```cpp
TORCH_LIBRARY(my_ops, m) {
  m.def("warp_perspective", warp_perspective);
}
```

See the official tutorial for more.
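On the Python side, the registered op then appears under the torch.ops namespace. A sketch, assuming the registration above was compiled into a shared library named warp_perspective.so:

```python
import torch

# load the shared library; this runs the TORCH_LIBRARY registration
torch.ops.load_library("warp_perspective.so")

# the op is callable from eager Python, and from traced/scripted code too
out = torch.ops.my_ops.warp_perspective(torch.rand(32, 32), torch.rand(3, 3))
```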
2 The basic representation of the IR (TorchScript)
The IR is organized around a few core data structures: Graph (a whole function), Node (a single operation), Block (an ordered list of nodes, used for control flow), Value (the inputs and outputs that link nodes together), and Type (the static type carried by each Value). The annotated graph below points them out:
```
# %x.1 is a Value
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  # aten::max is a Node
  # Tensor is a Type (here a TensorType)
  %3 : Tensor = aten::max(%x.1)
  %5 : Tensor = aten::max(%y.1)
  %6 : Tensor = aten::gt(%3, %5)
  %7 : bool = aten::Bool(%6)
  %r : Tensor = prim::If(%7)
    # Blocks
    block0():
      -> (%x.1)
    block1():
      -> (%y.1)
  return (%r)
```
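These structures are also exposed to Python, so a graph can be walked programmatically; a small sketch over the scripted foo from earlier:

```python
# every Node has a kind() string such as "aten::max", and its
# inputs()/outputs() are Values with a debugName() and a type()
for node in foo.graph.nodes():
    ins = [v.debugName() for v in node.inputs()]
    outs = [v.debugName() for v in node.outputs()]
    print(node.kind(), ins, '->', outs)
```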
3 The two ways of exporting IR: trace and script
How trace is implemented
```python
def trace(
    func,
    example_inputs,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu,
):
    # func is an nn.Module instance: trace its forward
    if isinstance(func, torch.nn.Module):
        return trace_module(
            func,
            {"forward": example_inputs},
            None,
            check_trace,
            wrap_check_inputs(check_inputs),
            check_tolerance,
            strict,
            _force_outplace,
            _module_class,
        )

    # func is the bound forward method of some module instance
    if (
        hasattr(func, "__self__")
        and isinstance(func.__self__, torch.nn.Module)
        and func.__name__ == "forward"
    ):
        return trace_module(
            func.__self__,
            {"forward": example_inputs},
            None,
            check_trace,
            wrap_check_inputs(check_inputs),
            check_tolerance,
            strict,
            _force_outplace,
            _module_class,
        )

    # a helper for looking up variable names
    var_lookup_fn = _create_interpreter_name_lookup_fn(0)

    # the C++ entry point
    name = _qualified_name(func)
    traced = torch._C._create_function_from_trace(
        name, func, example_inputs, var_lookup_fn, strict, _force_outplace
    )

    # check whether the traced function diverges from the original func
    if check_trace:
        if check_inputs is not None:
            _check_trace(
                check_inputs,
                func,
                traced,
                check_tolerance,
                strict,
                _force_outplace,
                False,
                _module_class,
            )
        else:
            _check_trace(
                [example_inputs],
                func,
                traced,
                check_tolerance,
                strict,
                _force_outplace,
                False,
                _module_class,
            )

    return traced
```
The real work happens at the C++ entry point:

```python
traced = torch._C._create_function_from_trace(
    name, func, example_inputs, var_lookup_fn, strict, _force_outplace
)
```
```cpp
std::pair<std::shared_ptr<TracingState>, Stack> trace(
    Stack inputs,
    const std::function<Stack(Stack)>& traced_fn,
    std::function<std::string(const Variable&)> var_name_lookup_fn,
    bool strict,
    bool force_outplace,
    Module* self) {
  try {
    auto state = std::make_shared<TracingState>();
    // setTracingState installs `state` so that, when ops execute later,
    // they can fetch it and insert their computation into the graph;
    // `state` accumulates everything traced during the forward pass
    setTracingState(state);

    if (self) {
      Value* self_value =
          state->graph->insertInput(0, "self")->setType(self->_ivalue()->type());
      gatherParametersAndBuffers(state, self_value, *self, {"__module"});
    }
    for (IValue& input : inputs) {
      input = addInput(state, input, input.type(), state->graph->addInput());
    }
    auto graph = state->graph;

    // bind the Python-side variable-name resolution function
    getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
    getTracingState()->strict = strict;
    getTracingState()->force_outplace = force_outplace;

    // run the forward pass; each computation is recorded into state as it happens
    auto out_stack = traced_fn(inputs);

    // Exit a trace, treating 'out_stack' as the outputs of the trace. These
    // are the variables whose values will be computed upon subsequent
    // invocations of the trace.
    size_t i = 0;
    for (auto& output : out_stack) {
      // NB: The stack is in "reverse" order, so when we pass the diagnostic
      // number we need to flip it based on size.
      state->graph->registerOutput(state->getOutput(output, out_stack.size() - i));
      i++;
    }
    setTracingState(nullptr);

    if (getInlineEverythingMode()) {
      Inline(*graph);
    }
    FixupTraceScopeBlocks(graph, self);
    NormalizeOps(graph);
    return {state, out_stack};
  } catch (...) {
    tracer::abandon();
    throw;
  }
}
```
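Note that traced_fn(inputs) really does execute the Python forward once: only tensor operations get recorded into the graph, while arbitrary Python side effects simply run at trace time and vanish from the export. This is easy to verify:

```python
import torch

def f(x):
    print("python side effect")   # plain Python: runs, but is not recorded
    return x * 2

# prints once, while the trace executes f (check_trace disabled to avoid extra runs)
ft = torch.jit.trace(f, torch.rand(2), check_trace=False)
ft(torch.rand(2))   # prints nothing: the graph contains only the aten::mul
```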
So where do the recorded ops come from? Every operator is wrapped so that, while tracing is active, it first creates a Node describing itself and inserts it into the graph:

```cpp
Operator createOperatorFromC10_withTracingHandledHere(
    const c10::OperatorHandle& op) {
  return Operator(op, [op](Stack& stack) {
    const auto input_size = op.schema().arguments().size();
    const auto output_size = op.schema().returns().size();

    Node* node = nullptr;
    std::shared_ptr<jit::tracer::TracingState> tracer_state;

    // trace the input before unwrapping, otherwise we may lose
    // the input information
    if (jit::tracer::isTracing()) {
      // fetch the tracer_state installed by setTracingState
      tracer_state = jit::tracer::getTracingState();
      auto symbol = Symbol::fromQualString(op.schema().name());
      const auto& graph = tracer::getTracingState()->graph;
      node = graph->create(symbol, 0);
      tracer::recordSourceLocation(node);
      const auto& args = op.schema().arguments();
      int i = 0;
      // record the arguments
      for (auto iter = stack.end() - input_size; iter != stack.end();
           ++iter, ++i) {
        // TODO we need to refactor graph APIs (e.g., addInputs)
        // appropriately; after that, we can get rid of the giant if-else
        // block we will clean this tech debt together in the following PRs
        auto type = args[i].type();
        if (type->kind() == TypeKind::OptionalType) {
          if (iter->isNone()) {
            Value* none = graph->insertNode(graph->createNone())->output();
            node->addInput(none);
            continue;
          } else {
            type = type->expect<OptionalType>()->getElementType();
          }
        }
        if (type->isSubtypeOf(TensorType::get())) {
          AT_ASSERT(iter->isTensor());
          tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
        } else if (type->kind() == TypeKind::FloatType) {
          AT_ASSERT(iter->isDouble());
          tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
        } else if (type->kind() == TypeKind::IntType) {
          AT_ASSERT(iter->isInt());
          tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
        } else if (type->kind() == TypeKind::BoolType) {
          AT_ASSERT(iter->isBool());
          tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
        } else if (type->kind() == TypeKind::StringType) {
          AT_ASSERT(iter->isString());
          tracer::addInputs(node, args[i].name().c_str(), iter->toStringRef());
        } else if (type->kind() == TypeKind::NumberType) {
          tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
        } else if (type->kind() == TypeKind::ListType) {
          const auto& elem_type = type->expect<ListType>()->getElementType();
          if (elem_type->isSubtypeOf(TensorType::get())) {
            AT_ASSERT(iter->isTensorList());
            auto list = iter->toTensorVector();
            tracer::addInputs(node, args[i].name().c_str(), list);
          } else if (elem_type->kind() == TypeKind::FloatType) {
            AT_ASSERT(iter->isDoubleList());
            // NB: now, tracer doesn't support tracing double list. We add
            // special handling here, since in our case, we assume that all the
            // doubles in the list are constants
            auto value = iter->toDoubleVector();
            std::vector<Value*> info(value.size());
            for (size_t value_index = 0; value_index < value.size();
                 ++value_index) {
              info[value_index] = graph->insertConstant(value[value_index]);
              tracer::recordSourceLocation(info[value_index]->node());
            }
            node->addInput(
                graph->insertNode(graph->createList(jit::FloatType::get(), info))
                    ->output());
          } else if (elem_type->kind() == TypeKind::IntType) {
            AT_ASSERT(iter->isIntList());
            tracer::addInputs(node, args[i].name().c_str(), iter->toIntVector());
          } else if (elem_type->kind() == TypeKind::BoolType) {
            AT_ASSERT(iter->isBoolList());
            tracer::addInputs(node, args[i].name().c_str(), iter->toBoolList().vec());
          } else {
            throw std::runtime_error("unsupported input list type: " + elem_type->str());
          }
        } else if (iter->isObject()) {
          tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
        } else {
          throw std::runtime_error("unsupported input type: " + type->str());
        }
      }
      // insert the node into the graph
      graph->insertNode(node);

      jit::tracer::setTracingState(nullptr);
    }
```
How script is implemented
```python
def script(obj, optimize=None, _frames_up=0, _rcb=None):
    # the function branch
    if hasattr(obj, "__script_if_tracing_wrapper"):
        obj = obj.__original_fn
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    # check for overloads
    _check_directly_compile_overloaded(obj)

    # has obj been scripted before?
    maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
    if maybe_already_compiled_fn:
        return maybe_already_compiled_fn

    # build the AST
    ast = get_jit_def(obj, obj.__name__)
    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    # the C++ entry point: compile the AST into IR
    fn = torch._C._jit_script_compile(ast, _rcb, get_default_args(obj))

    # forward docstrings
    fn.__doc__ = obj.__doc__

    # cache the result
    _set_jit_function_cache(obj, fn)
    return fn
```
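The cache at the top means scripting the same function twice hands back the same compiled object, which is easy to check:

```python
import torch

def g(x):
    return x + 1

s1 = torch.jit.script(g)
s2 = torch.jit.script(g)
print(s1 is s2)   # True: the second call hits the JIT function cache
```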
The AST is produced by get_jit_def:

```python
def get_jit_def(fn, def_name, self_name=None):
    # gather information about the source
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    sourcelines = normalize_source_lines(sourcelines)
    source = ''.join(sourcelines)
    # dedent_src is the (dedented) string holding the function to script
    dedent_src = dedent(source)
    # use Python's ast package to parse the string into a Python AST
    py_ast = ast.parse(dedent_src)
    # the Python type-comment line, if any
    type_line = torch.jit.annotations.get_type_line(source)
    # ctx holds all the metadata of the function's source
    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
    fn_def = py_ast.body[0]
    # build_def converts the Python AST into the AST format torch jit uses
    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
```
```python
import ast

func_def = """
def test(a):
    a = a + 2
    return a + 1
"""
results = ast.parse(func_def)
```
In the parse result, the Assign's value is a BinOp, concretely an Add: its left is a Name node whose id is 'a', and its right is a Num with value 2. This BinOp is the parsed right-hand side of a = a + 2.
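To see this structure directly, ast.dump prints it; on Python 3.8+ the first statement of the function body dumps roughly like this:

```python
import ast

print(ast.dump(results.body[0].body[0]))
# Assign(targets=[Name(id='a', ctx=Store())],
#        value=BinOp(left=Name(id='a', ctx=Load()), op=Add(),
#                    right=Constant(value=2)))
```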
Next, let's see how build_def converts the Python AST into the AST that jit needs.

build_def
```python
def build_def(ctx, py_def, type_line, def_name, self_name=None):
    ...
    return Def(Ident(r, def_name),
               decl,
               build_stmts(ctx, body))
```
Here ctx holds all the information about the source code and body is the result of Python's AST parse, so build_stmts should contain the answer we are after. Taking a + 2 as the example, let's see how it gets converted; this part lives in frontend.py.
StmtBuilder
```python
from torch._C._jit_tree_views import (
    ClassDef, Ident, Stmt, Decl, Def, Var,
    EmptyTypeAnnotation, Param, ExprStmt, Assign,
    Delete, Return, Raise, Assert, AugAssign, While,
    For, If, Pass, Break, Continue, Apply, Dots, Select,
    TrueLiteral, FalseLiteral, NoneLiteral, Starred,
    ListLiteral, TupleLiteral, DictLiteral, Const,
    StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
    SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
    DictComp,
)
# the basic AST structures defined by jit

def build_stmts(ctx, stmts):
    # note that it calls `build_stmt`
    stmts = [build_stmt(ctx, s) for s in stmts]
    return list(filter(None, stmts))

class Builder(object):
    def __call__(self, ctx, node):
        # dispatch on the type of the parsed AST node; `a + 2` sits inside
        # an `Assign`, so build_Assign gets called for it
        method = getattr(self, 'build_' + node.__class__.__name__, None)
        if method is None:
            raise UnsupportedNodeError(ctx, node)
        return method(ctx, node)

class StmtBuilder(Builder):
    @staticmethod
    def build_Assign(ctx, stmt):
        # stmt.value is a BinOp;
        # build_expr is an instance of ExprBuilder and will call `build_BinOp`
        rhs = build_expr(ctx, stmt.value)
        lhs = [build_expr(ctx, x) for x in stmt.targets]
        return Assign(lhs, rhs)

    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        if value.__class__.__name__ == 'Str':
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

class ExprBuilder(Builder):
    binop_map = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        ast.Pow: '**',
        ast.Mod: '%',
        ast.FloorDiv: '//',
        ast.BitAnd: '&',
        ast.BitXor: '^',
        ast.BitOr: '|',
        ast.LShift: '<<',
        ast.RShift: '>>',
    }

    @staticmethod
    def build_BinOp(ctx, expr):
        # expr.left is a `Name`, so build_Name is called for it
        lhs = build_expr(ctx, expr.left)
        rhs = build_expr(ctx, expr.right)
        op = type(expr.op)
        # map the op to the agreed string token for that operation
        op_token = ExprBuilder.binop_map.get(op)
        return BinOp(op_token, lhs, rhs)

# `build_stmt` is an instance of StmtBuilder()
build_stmt = StmtBuilder()
build_expr = ExprBuilder()
```
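The dispatch pattern here is compact: Builder.__call__ routes each Python AST node to a build_<ClassName> method by name. A stripped-down, self-contained imitation (MiniBuilder is made up for illustration, not part of torch):

```python
import ast

class MiniBuilder:
    def __call__(self, node):
        # route ast.BinOp -> build_BinOp, ast.Name -> build_Name, ...
        method = getattr(self, 'build_' + node.__class__.__name__, None)
        if method is None:
            raise RuntimeError('unsupported node: ' + node.__class__.__name__)
        return method(node)

class MiniExprBuilder(MiniBuilder):
    binop_map = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*'}

    def build_BinOp(self, node):
        op = self.binop_map[type(node.op)]
        return '({} {} {})'.format(self(node.left), op, self(node.right))

    def build_Name(self, node):
        return node.id

    def build_Constant(self, node):  # numbers on Python 3.8+
        return str(node.value)

expr = ast.parse('a + 2', mode='eval').body
print(MiniExprBuilder()(expr))   # (a + 2)
```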
```
(def
  (ident test)
  (decl
    (list
      (param
        (ident a)
        (option)
        (option)
        (False)))
    (option))
  (list
    (assign
      (list (variable (ident a)))
      (option
        (+
          (variable (ident a))
          (const 2)))
      (option))
    (return
      (+
        (variable (ident a))
        (const 1)))))
```
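A dump like this can be reproduced by printing the tree that get_jit_def returns, since the _jit_tree_views objects pretty-print as s-expressions (how the listing was obtained is an assumption; the signature matches the get_jit_def shown above):

```python
from torch.jit.frontend import get_jit_def

def test(a):
    a = a + 2
    return a + 1

print(get_jit_def(test, 'test'))
```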
On the C++ side, the compiled function is produced by script_compile_function:

```cpp
static StrongFunctionPtr script_compile_function(
    const c10::QualifiedName& name,
    const Def& def,
    const FunctionDefaults& defaults,
    const ResolutionCallback& rcb) {
  auto cu = get_python_cu();
  // so the define() method of this CompilationUnit does the real work
  auto defined_functions = cu->define(
      QualifiedName(name.prefix()),
      /*properties=*/{},
      /*propResolvers=*/{},
      {def},
      {pythonResolver(rcb)},
      nullptr,
      true);
  TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
  auto& defined = defined_functions[0];
  defined->setSchema(getSchemaWithNameAndDefaults(
      def.range(), defined->getSchema(), def.name().name(), defaults));
  StrongFunctionPtr ret(std::move(cu), defined);
  didFinishEmitFunction(ret);
  return ret;
}

// get_python_cu turns out to be just a thin wrapper around CompilationUnit
inline std::shared_ptr<CompilationUnit> get_python_cu() {
  return py::module::import("torch.jit._state")
      .attr("_python_cu")
      .cast<std::shared_ptr<CompilationUnit>>();
}

// the declaration, in torch/csrc/jit/api/compilation_unit.h:
// for historic reasons, these are defined in ir_emitter.cpp
// Returns the list of Functions just defined.
std::vector<Function*> define(
    const c10::optional<c10::QualifiedName>& prefix,
    const std::vector<Property>& properties,
    const std::vector<ResolverPtr>& propResolvers,
    const std::vector<Def>& definitions,
    const std::vector<ResolverPtr>& defResolvers, /* determines how we handle
                                  free variables in each definition */
    // if non-null, the first argument to each def, is bound to this value
    const Self* self,
    // see [name mangling]
    bool shouldMangle = false);

// the implementation lives in torch/csrc/jit/frontend/ir_emitter.cpp
std::unique_ptr<Function> CompilationUnit::define(
    const c10::optional<QualifiedName>& prefix,
    const Def& def,
    const ResolverPtr& resolver,
    const Self* self,
    const std::unordered_map<std::string, Function*>& function_table,
    bool shouldMangle) const {
  auto _resolver = resolver;
  ...
  auto creator = [def, _resolver, self](Function& method) {
    ...
    // the core: to_ir
    to_ir(def, _resolver, self, method);
  };
  auto fn = torch::make_unique<GraphFunction>(
      std::move(name), std::make_shared<Graph>(), creator);
  return fn;
}
```
That brings us to struct to_ir. Among its inputs are def, which is the AST, and _resolver, the name-resolution function passed over from Python. The key parts are easy to spot inside:
```cpp
to_ir(
    const Def& def,
    ResolverPtr resolver_,
    const Self* self,
    Function& method) // method being constructed
    : method(method),
      graph(method.graph()),
      resolver(std::move(resolver_)),
      typeParser_(resolver),
      environment_stack(nullptr) {
  AT_ASSERT(resolver);
  pushFrame(graph->block(), /*starts_def=*/true);

  // emitDef calls emitStatements
  method.setSchema(emitDef(def, self, graph->block()));

  ConvertToSSA(graph);
  CanonicalizeModifiedLoops(graph);
  NormalizeOps(graph);
  runCleanupPasses(graph);
}

private:
  // in the private section of to_ir we find Function, Graph, and the other
  // IR building blocks introduced earlier
  Function& method;
  std::shared_ptr<Graph> graph;
  ResolverPtr resolver;
  std::unordered_map<int64_t, Value*> integral_constants;

  // emitDef calls emitStatements
  FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
    ...
    // body
    auto stmts_list = def.statements();
    emitStatements(stmts_list.begin(), stmts_list.end());
    ...
  }

  void emitStatements(
      List<Stmt>::const_iterator begin,
      List<Stmt>::const_iterator end) {
    for (; begin != end; ++begin) {
      auto stmt = *begin;
      ErrorReport::CallStack::update_pending_range(stmt.range());
      switch (stmt.kind()) {
        case TK_IF:
          emitIf(If(stmt));
          break;
        case TK_WHILE:
          emitWhile(While(stmt));
          break;
        case TK_FOR:
          emitFor(For(stmt));
          break;
        case TK_ASSIGN:
          emitAssignment(Assign(stmt));
          break;
        ...
        default:
          throw ErrorReport(stmt)
              << "Unrecognized statement kind " << kindToString(stmt.kind());
      }
      // Found an exit statement in this block. The remaining statements aren't
      // reachable so we don't emit them.
      if (exit_blocks.count(environment_stack->block()))
        return;
    }
  }
```

Depending on stmt.kind(), control enters the various emit* functions, and inside each of them we can find operations along the lines of graph->insertNode(graph->create(...)), which is exactly how the IR graph gets built.
Scripting a whole module (torch.jit.ScriptModule) has to solve two problems:

- Right after the module's original __init__ finishes, script the complete forward, replacing every function involved with its scripted version.
- Keep ordinary access to the module's original attributes working.
```python
class MyModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def f(self, x):
        return x * x

    @torch.jit.script_method
    def forward(self, x):
        return x + self.f(x)
```

About script_method:

```python
def script_method(fn):
    _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
    ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
    # nothing is scripted yet; just return a namedtuple carrying the ast
    return ScriptMethodStub(_rcb, ast, fn)

ScriptMethodStub = collections.namedtuple(
    'ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
```

The ScriptMeta metaclass then does two things:

1 Remove all script_method attributes (the methods decorated with @script_method), ensuring that what gets accessed is the scripted function.
2 Patch the module's __init__ so that, as soon as self.param / self.module have been initialized, all script_methods are compiled; the instance's forward has thereby been replaced.

```python
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs):  # noqa: B902
        # cls, an instance of ScriptMeta, is a class such as ScriptModule
        cls._methods: Dict[str, Any] = {}
        cls._constants_set = set(getattr(cls, "__constants__", ()))
        for base in reversed(bases):
            # remember: a traced module also carries a _methods attribute
            for k, v in getattr(base, "_methods", {}).items():
                cls._methods[k] = v
            base_constants = getattr(base, "_constants_set", set())
            cls._constants_set = cls._constants_set.union(base_constants)

        # collect every method currently decorated with @script_method into
        # _methods and delete the original attribute; they are all scripted
        # together after __init__
        for k, v in sorted(attrs.items()):
            if isinstance(v, ScriptMethodStub):
                delattr(cls, k)
                cls._methods[v.original_method.__name__] = v

        original_init = getattr(cls, "__init__", lambda self: None)

        # here is where scripting kicks off once __init__ has finished,
        # via create_script_module
        @functools.wraps(original_init)
        def init_then_script(self, *args, **kwargs):
            # here self is the instance
            num_methods = len(cls._methods)
            original_init(self, *args, **kwargs)
            added_methods_in_init = len(cls._methods) > num_methods

            if type(self) == cls:
                # select the methods that need scripting
                def make_stubs(module):
                    cls = type(module)
                    if hasattr(cls, "_methods"):
                        return [v for k, v in sorted(cls._methods.items())]
                    else:
                        # infer_methods_to_compile picks the functions to script
                        return infer_methods_to_compile(module)

                # compile all script_methods in one go into the
                # _actual_script_module attribute
                self.__dict__["_actual_script_module"] = torch.jit._recursive.create_script_module(
                    self, make_stubs, share_types=not added_methods_in_init)

                # Delete the Python attributes that now shadow the ScriptModule
                # ones, so that __getattr__ and __setattr__ will properly find
                # the scripted versions.
                concrete_type = self._actual_script_module._concrete_type
                for name in concrete_type.get_attributes():
                    delattr(self, name)
                for name, _ in concrete_type.get_modules():
                    delattr(self, name)
                for name in ("_parameters", "_buffers", "_modules"):
                    delattr(self, name)

        cls.__init__ = init_then_script  # type: ignore
        return super(ScriptMeta, cls).__init__(name, bases, attrs)


class _CachedForward(object):
    def __get__(self, obj, cls):
        return self.__getattr__("forward")  # type: ignore


class ScriptModule(with_metaclass(ScriptMeta, Module)):  # type: ignore
    def __init__(self):
        super(ScriptModule, self).__init__()

    forward = _CachedForward()

    # accessing an attribute of the module returns the attribute of
    # _actual_script_module
    def __getattr__(self, attr):
        if "_actual_script_module" not in self.__dict__:
            return super(ScriptModule, self).__getattr__(attr)
        return getattr(self._actual_script_module, attr)

    def __setattr__(self, attr, value):
        if "_actual_script_module" not in self.__dict__:
            # Unwrap torch.jit.Attribute into a regular setattr + recording
            # the provided type in __annotations__.
            #
            # This ensures that if we use the attr again in `__init__`, it
            # will look like the actual value, not an instance of Attribute.
            if isinstance(value, Attribute):
                if "__annotations__" not in self.__class__.__dict__:
                    self.__class__.__annotations__ = {}
                self.__annotations__[attr] = value.type
                value = value.value
            return super(ScriptModule, self).__setattr__(attr, value)
        setattr(self._actual_script_module, attr, value)
    ...
```
A note on __getattribute__ vs __getattr__: __getattribute__ intercepts every attribute access, whereas __getattr__ only runs after normal lookup has failed. That is why the shadowing Python attributes are deleted above: once they are gone, lookup falls through to __getattr__, which forwards to _actual_script_module.
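A small demonstration of that lookup rule:

```python
class Proxy:
    def __init__(self):
        self.x = 1

    def __getattr__(self, name):
        # reached only when normal attribute lookup fails
        return 'forwarded:' + name

p = Proxy()
print(p.x)   # 1 -- found normally, __getattr__ never runs
print(p.y)   # forwarded:y -- lookup failed, so __getattr__ handles it
```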
4 A brief introduction to IR optimization
```python
from time import time

import torch

def test(x):
    # dead code elimination removes this loop
    for i in range(1000):
        y = x + 1
    for i in range(100):
        # peephole optimization cancels the double transpose
        x = x.t()
        x = x.t()
    return x.sum()

opt_test = torch.jit.script(test)
inputs = torch.ones(4, 4).cuda()

s = time()
for i in range(10000):
    test(inputs)
print(time() - s)   # 95s

s = time()
for i in range(10000):
    opt_test(inputs)
print(time() - s)   # 0.13s

print(opt_test.graph)
print(opt_test.graph_for(inputs))
```

```
95.13823795318604
0.13010907173156738
graph(%x.1 : Tensor):
  %22 : None = prim::Constant()
  %13 : bool = prim::Constant[value=1]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
  %10 : int = prim::Constant[value=100]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:19
  %x : Tensor = prim::Loop(%10, %13, %x.1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
    block0(%i : int, %x.10 : Tensor):
      %x.4 : Tensor = aten::t(%x.10) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:11:12
      %x.7 : Tensor = aten::t(%x.4) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:12:12
      -> (%13, %x.7)
  %23 : Tensor = aten::sum(%x, %22) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
  return (%23)
graph(%x.1 : Tensor):
  %1 : None = prim::Constant()
  %2 : Tensor = aten::sum(%x.1, %1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
  return (%2)
```

In opt_test.graph the dead first loop is already gone but the transpose loop remains, while opt_test.graph_for(inputs), the graph actually optimized for the given inputs, has been reduced to a single aten::sum.
On IR graph optimization
Where does this optimization happen? In pytorch-master/torch/csrc/jit/api/method.h, the C++ counterpart of a script_method holds a GraphExecutor:
```cpp
GraphExecutor& get_executor() {
  return function_->get_executor();
}
```
```cpp
GraphExecutor::GraphExecutor(
    const std::shared_ptr<Graph>& graph,
    std::string function_name)
    : pImpl(
          IsNewExecutorEnabled()
              ? dynamic_cast<GraphExecutorImplBase*>(
                    new ProfilingGraphExecutorImpl(graph, std::move(function_name)))
              : dynamic_cast<GraphExecutorImplBase*>(
                    new GraphExecutorImpl(graph, std::move(function_name)))) {}

std::shared_ptr<Graph> GraphExecutor::graph() const {
  return pImpl->graph;
}

const ExecutionPlan& GraphExecutor::getPlanFor(
    Stack& inputs,
    size_t remaining_bailout_depth) {
  return pImpl->getPlanFor(inputs, remaining_bailout_depth);
}

std::shared_ptr<GraphExecutorImplBase> pImpl;
...

// about GraphExecutorImplBase, in torch/csrc/jit/runtime/graph_executor.cpp:
const ExecutionPlan& getOrCompile(const Stack& stack) {
  ...
  auto plan = compileSpec(spec);
}

// compileSpec returns an ExecutionPlan
ExecutionPlan compileSpec(const ArgumentSpec& spec) {
  auto opt_graph = graph->copy();
  GRAPH_DUMP("Optimizing the following function:", opt_graph);
  arg_spec_creator_.specializeTypes(*opt_graph, spec);

  // Phase 0. Inline functions, then clean up any artifacts that the inliner
  //          left in that may inhibit optimization
  ...
  runRequiredPasses(opt_graph);
  GRAPH_DEBUG("After runRequiredPasses, before ConstantPropagation\n", *opt_graph);

  // Phase 2. Propagate detailed information about the spec through the
  //          graph (enabled more specializations in later passes).
  //          Shape propagation sometimes depends on certain arguments being
  //          constants, and constant propagation doesn't need shape
  //          information anyway, so it's better to run it first.
  ConstantPropagation(opt_graph);
  GRAPH_DEBUG("After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
  PropagateInputShapes(opt_graph);
  GRAPH_DEBUG("After PropagateInputShapes, before PropagateRequiresGrad\n", *opt_graph);
  PropagateRequiresGrad(opt_graph);
  GRAPH_DEBUG("After PropagateRequiresGrad, before runOptimization\n", *opt_graph);

  // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
  //          that we can still execute using autograd).
  runOptimization(opt_graph);
  ...  // and various further optimization passes
  return ExecutionPlan(opt_graph, function_name_);
}
```

