
[TVM Source Code Study Notes] 3. Converting ONNX Nodes to TVM Relay IR

Date: 2023-05-29

In GraphProto.from_onnx, we iterate over the nodes of the ONNX model and convert each node to TVM IR:

for node in graph.node:
    # Get the operator type and attributes
    op_name = node.op_type
    attr = self._parse_attr(node.attribute)
    # Create and populate input list.
    # Create an input instance for this operator
    inputs = onnx_input()
    # Collect all inputs (by name) of the current ONNX node
    for i in node.input:
        if i != "":
            inputs.append(self._nodes[self._renames.get(i, i)])
        else:
            # Some inputs are unused?
            inputs.append(None)
    i_name = self._parse_value_proto(node)
    # Get the ONNX node's outputs; these are strings, namely the output names
    node_output = self._fix_outputs(op_name, node.output)
    # Record the ONNX node's attributes
    attr["tvm_custom"] = {}
    attr["tvm_custom"]["name"] = i_name
    attr["tvm_custom"]["num_outputs"] = len(node_output)
    # Convert the ONNX operator node into its TVM representation
    op = self._convert_operator(op_name, inputs, attr, opset)

Again using the MNIST ONNX model shipped with the TVM source as an example, we can print, inside this loop, each processed ONNX node's name and outputs, along with the type and contents of the converted op.
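One way to do this is a throwaway debug print (hypothetical instrumentation, not part of the TVM source) placed right after the _convert_operator call; i_name, node_output and op are the loop variables shown above:

    # Hypothetical debug instrumentation inside the from_onnx loop
    print("#" * 48)
    print("onnx op node", i_name)
    print("output:", list(node_output))
    print("convert to tvm op:", op)

Importing the model then prints a trace like the following: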

################################################
onnx op node Times212_reshape1
output: ['Parameter193_reshape1']
convert to tvm op:
free_var %Parameter193: Tensor[(16, 4, 4, 10), float32];
reshape(%Parameter193, newshape=[256, 10])
################################################
onnx op node Convolution28
output: ['Convolution28_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5])
################################################
onnx op node Plus30
output: ['Plus30_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
add(%1, %Parameter6)
################################################
onnx op node ReLU32
output: ['ReLU32_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
nn.relu(%2)
################################################
onnx op node Pooling66
output: ['Pooling66_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0])
################################################
onnx op node Convolution110
output: ['Convolution110_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5])
################################################
onnx op node Plus112
output: ['Plus112_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
%6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]);
free_var %Parameter88: Tensor[(16, 1, 1), float32];
add(%6, %Parameter88)
################################################
onnx op node ReLU114
output: ['ReLU114_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
%6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]);
free_var %Parameter88: Tensor[(16, 1, 1), float32];
%7 = add(%6, %Parameter88);
nn.relu(%7)
################################################
onnx op node Pooling160
output: ['Pooling160_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
%6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]);
free_var %Parameter88: Tensor[(16, 1, 1), float32];
%7 = add(%6, %Parameter88);
%8 = nn.relu(%7);
nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0])
################################################
onnx op node Times212_reshape0
output: ['Pooling160_Output_0_reshape0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
%6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]);
free_var %Parameter88: Tensor[(16, 1, 1), float32];
%7 = add(%6, %Parameter88);
%8 = nn.relu(%7);
%9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]);
reshape(%9, newshape=[1, 256])
################################################
onnx op node Times212
output: ['Times212_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
%6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]);
free_var %Parameter88: Tensor[(16, 1, 1), float32];
%7 = add(%6, %Parameter88);
%8 = nn.relu(%7);
%9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]);
free_var %Parameter193: Tensor[(16, 4, 4, 10), float32];
%10 = reshape(%Parameter193, newshape=[256, 10]);
%11 = reshape(%9, newshape=[1, 256]);
%12 = transpose(%10, axes=[1, 0]);
nn.dense(%11, %12, units=None, out_dtype="float32")
################################################
onnx op node Plus214
output: ['Plus214_Output_0']
convert to tvm op:
free_var %Input3: Tensor[(1, 1, 28, 28), float32];
%0 = nn.pad(%Input3, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter5: Tensor[(8, 1, 5, 5), float32];
%1 = nn.conv2d(%0, %Parameter5, padding=[0, 0, 0, 0], channels=8, kernel_size=[5, 5]);
free_var %Parameter6: Tensor[(8, 1, 1), float32];
%2 = add(%1, %Parameter6);
%3 = nn.relu(%2);
%4 = nn.max_pool2d(%3, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]);
%5 = nn.pad(%4, 0f, pad_width=[[0, 0], [0, 0], [2, 2], [2, 2]]);
free_var %Parameter87: Tensor[(16, 8, 5, 5), float32];
%6 = nn.conv2d(%5, %Parameter87, padding=[0, 0, 0, 0], channels=16, kernel_size=[5, 5]);
free_var %Parameter88: Tensor[(16, 1, 1), float32];
%7 = add(%6, %Parameter88);
%8 = nn.relu(%7);
%9 = nn.max_pool2d(%8, pool_size=[3, 3], strides=[3, 3], padding=[0, 0, 0, 0]);
free_var %Parameter193: Tensor[(16, 4, 4, 10), float32];
%10 = reshape(%Parameter193, newshape=[256, 10]);
%11 = reshape(%9, newshape=[1, 256]);
%12 = transpose(%10, axes=[1, 0]);
%13 = nn.dense(%11, %12, units=None, out_dtype="float32");
free_var %Parameter194: Tensor[(1, 10), float32];
add(%13, %Parameter194)
################################################

We can see that the op returned for each node is not just the conversion of the current ONNX node: the expressions of its input nodes are stacked on top of it as well. Printing inputs shows that this stacked part comes from the inputs argument. The expression for the model's final output node is therefore the computation of the entire network.
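This stacking is simply how Relay expressions nest: a call expression holds its argument expressions, so printing any node prints its entire upstream subgraph. A minimal standalone illustration (independent of the converter):

    # Chaining two Relay calls and printing the later one prints the whole
    # chain, just like the trace above.
    from tvm import relay

    x = relay.var("x", shape=(1, 8))
    y = relay.nn.relu(x)                 # y holds x as its argument
    z = relay.add(y, relay.const(1.0))   # z holds y, hence also x
    print(z)  # free_var %x: ...; %0 = nn.relu(%x); add(%0, 1f)

The _convert_operator function that performs these conversions is as follows: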

def _convert_operator(self, op_name, inputs, attrs, opset):
    """Convert ONNX operator into a Relay operator.

    The converter must specify conversions explicitly for
    incompatible name, and apply handlers to operator attributes.

    Parameters
    ----------
    op_name : str
        Operator name, such as Convolution, FullyConnected
    inputs : list of tvm.relay.function.Function
        List of inputs.
    attrs : dict
        Dict of operator attributes
    opset : int
        Opset version

    Returns
    -------
    sym : tvm.relay.function.Function
        Converted relay function
    """
    # Get the ONNX-to-TVM operator mapping table
    convert_map = _get_convert_map(opset)
    # If the current ONNX operator is in the _identity_list table
    if op_name in _identity_list:
        sym = get_relay_op(op_name)(*inputs, **attrs)
    # If the operator is in the mapping table
    elif op_name in convert_map:
        # Perform the conversion
        sym = convert_map[op_name](inputs, attrs, self._params)
    else:
        raise NotImplementedError("Operator {} not implemented.".format(op_name))
    # Return the conversion result
    return sym

From this function we can see that the mapping of ONNX operators to TVM has two sources: the mapping table returned by _get_convert_map, and _identity_list. If an ONNX operator appears in neither table, it is unsupported.

Searching the code shows that _identity_list is empty when the input is an ONNX, Caffe, or TensorFlow model; it is non-empty only for MXNet input, where it mostly contains elementwise math functions. See mxnet.py:

# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
    "abs",
    "log",
    "exp",
    "erf",
    "sqrt",
    "floor",
    "ceil",
    "round",
    "trunc",
    "sign",
    "sigmoid",
    "negative",
    "reshape_like",
    "zeros_like",
    "ones_like",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "where",
]

Presumably the operators in this list need no transformation between MXNet and TVM and can be used directly. For ONNX models _identity_list is empty, so every operator mapping comes from the table returned by _get_convert_map:

# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
def _get_convert_map(opset):
    return {
        # defs/experimental
        "Identity": Renamer("copy"),
        "Affine": Affine.get_converter(opset),
        "BitShift": BitShift.get_converter(opset),
        "ThresholdedRelu": ThresholdedRelu.get_converter(opset),
        "ScaledTanh": ScaledTanh.get_converter(opset),
        "ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
        "Constant": Constant.get_converter(opset),
        "ConstantOfShape": ConstantOfShape.get_converter(opset),
        # 'GivenTensorFill'
        "FC": AttrCvt("dense", ignores=["axis", "axis_w"]),
        "Scale": Scale.get_converter(opset),
        ...
    }

The comments indicate three mapping scenarios:

1. The ONNX operator and the TVM operator differ only in name. This is a one-to-one mapping, handled by calling Renamer with the TVM operator's name (see the sketch after this list); if the attributes also need converting, AttrCvt is used instead;

2. One ONNX operator is composed of several TVM operators. This requires calling the operator's get_converter function;

3. Several ONNX operators combine into one TVM operator; this is currently not supported.
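To make scenario 1 concrete, here is the Renamer pattern in miniature (a simplified rendering of the class in tvm/relay/frontend/common.py; details may differ between TVM versions):

    # Simplified sketch: the converter ignores everything except the name and
    # forwards inputs and attributes to the Relay op with the new name.
    class Renamer(object):
        """Maps an ONNX op to a Relay op that differs only in name."""

        def __init__(self, new_name):
            self._new_name = new_name

        def __call__(self, inputs, attrs, *args):
            # The bookkeeping attribute added in from_onnx must not reach
            # the Relay call
            if "tvm_custom" in attrs:
                attrs.pop("tvm_custom")
            return get_relay_op(self._new_name)(*inputs, **attrs)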

Take the convolution operator as an example: during conversion, _convert_operator calls Conv.get_converter(opset)(inputs, attrs, self._params). The Conv class inherits from OnnxOpConverter, and get_converter is a method of OnnxOpConverter, defined as follows:

class OnnxOpConverter(object):
    """A helper class for holding onnx op converters."""

    @classmethod
    def get_converter(cls, opset):
        """Get converter matches given opset.

        Parameters
        ----------
        opset: int
            opset from model.

        Returns
        -------
        converter, which should be `_impl_vx`. Number x is the biggest
            number smaller than or equal to opset belongs to all support versions.
        """
        # dir(cls) lists the class attributes: special members, ordinary
        # members, and methods. We look through these names for the string
        # "_impl_v", strip that prefix, and convert the remainder to int.
        # Each operator's _impl_vx attribute is a conversion method, with the
        # integer x naming an opset version, so `versions` ends up holding
        # all opset versions the operator supports.
        versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]
        # opset is the version passed in; add it to the version list and sort ascending
        versions = sorted(versions + [opset])
        # The max() expression finds the index of opset in the sorted list;
        # subtracting 1 gives the element just before it. Since versions is
        # sorted, this is the largest supported version that is <= opset.
        version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1]
        # Join that version with "_impl_v" to form a method name. If the
        # operator class has that method, return its handle; otherwise report
        # the version as unsupported.
        if hasattr(cls, "_impl_v{}".format(version)):
            return getattr(cls, "_impl_v{}".format(version))
        raise NotImplementedError(
            "opset version {} of {} not implemented".format(version, cls.__name__)
        )
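A worked example of the version selection makes the index arithmetic clearer. Suppose a converter class defines _impl_v1 and _impl_v11 and the model's opset is 9:

    # Hypothetical illustration of the version-selection logic above
    versions = [1, 11]  # parsed from the _impl_vx method names
    opset = 9
    versions = sorted(versions + [opset])  # [1, 9, 11]
    version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1]
    print(version)  # 1 -> get_converter returns cls._impl_v1

If opset were 11, the sorted list would be [1, 11, 11] and the picked version would be 11, so an exact match is also handled correctly.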

Let's look at the _impl_vx method the Conv class provides:

class Conv(OnnxOpConverter):
    """Operator converter for Conv."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Use shape of input to determine convolution type.
        # Extract the input data and kernel from inputs and infer their shapes
        data = inputs[0]
        kernel = inputs[1]
        input_shape = infer_shape(data)
        ndim = len(input_shape)
        kernel_type = infer_type(inputs[1])
        kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]
        # If the ONNX Conv attributes do not give the kernel shape, use the
        # shape inferred from inputs
        if "kernel_shape" not in attr:
            attr["kernel_shape"] = kernel_shapes[0][2:]
        # If the ONNX Conv operator sets the auto_pad attribute
        if "auto_pad" in attr:
            # the corresponding TVM conv operator uses the same auto_pad value
            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
            # Pad the data according to the auto_pad value
            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                # Warning: Convolution does not yet support dynamic shapes,
                # one will need to run dynamic_to_static on this model after import
                # Pad the input data, producing the padded data
                data = autopad(
                    data,
                    attr.get("strides", [1] * (ndim - 2)),
                    attr["kernel_shape"],
                    attr.get("dilations", [1] * (ndim - 2)),
                    mode=attr["auto_pad"],
                )
            elif attr["auto_pad"] == "VALID":
                attr["pads"] = [0 for i in range(ndim - 2)]
            elif attr["auto_pad"] == "NOTSET":
                pass
            else:
                msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
                raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
            attr.pop("auto_pad")

        attr["channels"] = kernel_shapes[0][0]
        out = AttrCvt(
            # op_name here is a function that AttrCvt.__call__ invokes; based
            # on the kernel_shape attribute it resolves to the TVM
            # conv1d/conv2d/conv3d operator interface. The operator then
            # receives ([data, kernel], attr, params) and returns the
            # converted TVM expression `out`.
            op_name=dimension_picker("conv"),
            transforms={
                "kernel_shape": "kernel_size",
                "dilations": ("dilation", 1),
                "pads": ("padding", 0),
                "group": ("groups", 1),
            },
            custom_check=dimension_constraint(),
        )([data, kernel], attr, params)

        use_bias = len(inputs) == 3
        # If the inputs include a bias parameter, append a bias-add to the expression
        if use_bias:
            out = _op.nn.bias_add(out, inputs[2])
        return out

_impl_v1 first does some preliminary handling of the convolution's input data, kernel parameters, and padding, then creates an AttrCvt instance. The op_name argument passed in is a function that AttrCvt.__call__ invokes with the current convolution's attr: based on the kernel_shape attribute it decides whether this is a 1-D, 2-D, or 3-D convolution and returns the corresponding TVM operator name (conv1d/conv2d/conv3d). The transforms argument tells AttrCvt.__call__ how to convert the current attributes into the parameter form that TVM's convolution expects; the custom_check argument validates the attributes, which for convolution means checking that the dimensionality (1d/2d/3d) is legal.
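The dimension_picker helper can be sketched as follows (a simplified rendering of the helper in the ONNX frontend; its error handling differs in the real source):

    # dimension_picker returns a callable that AttrCvt.__call__ invokes with
    # attr; the callable picks conv1d/conv2d/conv3d from the length of
    # kernel_shape.
    def dimension_picker(prefix, suffix=""):
        def _impl(attr):
            kernel = attr["kernel_shape"]
            if len(kernel) in (1, 2, 3):
                # e.g. prefix="conv" with a 2-D kernel -> "conv2d"
                return "{}{}d{}".format(prefix, len(kernel), suffix)
            raise NotImplementedError("Only 1D, 2D and 3D kernels are supported.")

        return _impl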

The overall flow of AttrCvt.__call__ is: validate and transform the attributes, call get_relay_op to obtain the TVM interface function for the operator, and feed the operator's inputs and the transformed attributes into that interface, producing the TVM Relay IR corresponding to the ONNX node.
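A heavily simplified guess at that flow, as a sketch only (the real __call__ also handles ignores, excludes, disables, defaults, and the tvm_custom bookkeeping):

    # Sketch of the AttrCvt.__call__ flow described above; attr_cvt_sketch,
    # its parameter handling, and the flattened transforms lookup are
    # simplifications, not the real implementation.
    def attr_cvt_sketch(op_name, transforms, inputs, attrs, custom_check=None):
        if custom_check is not None:
            check_fn, msg = custom_check
            if not check_fn(attrs):
                raise RuntimeError(msg)
        # op_name may be a plain string or a callable such as dimension_picker("conv")
        name = op_name(attrs) if callable(op_name) else op_name
        new_attrs = {}
        for k, v in attrs.items():
            if k in transforms:
                entry = transforms[k]
                # an entry is either "new_name" or ("new_name", default_value)
                new_key = entry if isinstance(entry, str) else entry[0]
                new_attrs[new_key] = v
            else:
                new_attrs[k] = v
        # Resolve the Relay interface and apply it to the operator inputs
        return get_relay_op(name)(*inputs, **new_attrs)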

The real AttrCvt is a shared helper class, not specific to the ONNX frontend; its __call__ does considerably more than the sketch above, and I have not fully worked through it yet. The code of get_relay_op:

def get_relay_op(op_name):
    """Get the callable function from Relay based on operator name.

    Parameters
    ----------
    op_name : str
        The Relay operator name.
    """
    if "." in op_name:
        # explicit hierarchical modules
        op = _op
        try:
            for opn in op_name.split("."):
                op = getattr(op, opn)
        except AttributeError:
            op = None
    else:
        # try search op in various modules
        for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):
            op = getattr(candidate, op_name, None)
            if op is not None:
                break
    if not op:
        raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
    return op

This function handles two forms of op_name. A dot-separated name such as nn.conv2d is treated as an explicit module path and resolved attribute by attribute starting from _op; the else branch handles a bare operator name such as conv2d by searching for it in the _op, _op.nn, _op.image, _op.vision, and _op.contrib modules in turn. python/tvm/relay/op/ is the directory holding the TVM Relay operator implementations, with subdirectories nn, image, vision, contrib, and so on for the different operator categories. When the mnist-8.onnx model is parsed, the op_name passed in is the bare interface name with no separator, so the operator is found by the module search and returned directly.
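A quick illustration of both name forms resolving to the same interface (assuming a built TVM installation):

    # Both lookups should print the same Relay operator interface
    from tvm.relay.frontend.common import get_relay_op

    print(get_relay_op("conv2d"))     # bare name, found by searching _op, _op.nn, ...
    print(get_relay_op("nn.conv2d"))  # dotted hierarchical name, resolved via getattr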

The printed TVM Relay IR shows that the conv2d here is nn.conv2d, whose code is:

def conv2d(
    data,
    weight,
    strides=(1, 1),
    padding=(0, 0),
    dilation=(1, 1),
    groups=1,
    channels=None,
    kernel_size=None,
    data_layout="NCHW",
    kernel_layout="OIHW",
    out_layout="",
    out_dtype="",
):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(strides, int):
        strides = (strides, strides)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    # TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged
    # convert 2-way padding to 4-way padding
    padding = get_pad_tuple2d(padding)
    return _make.conv2d(
        data,
        weight,
        strides,
        padding,
        dilation,
        groups,
        channels,
        kernel_size,
        data_layout,
        kernel_layout,
        out_layout,
        out_dtype,
    )
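The 2-way to 4-way padding conversion means symmetric (height, width) padding is expanded to explicit per-side values. An illustration, assuming the helper's semantics as I understand them (verify against your TVM version):

    # 2-way padding becomes (pad_top, pad_left, pad_bottom, pad_right);
    # 4-way input passes through unchanged.
    from tvm.relay.op.nn.utils import get_pad_tuple2d

    print(get_pad_tuple2d((2, 2)))        # -> (2, 2, 2, 2)
    print(get_pad_tuple2d((1, 2, 3, 4)))  # -> (1, 2, 3, 4)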

Here _make is imported in python/tvm/relay/op/nn/__init__.py from _make.py in the same directory:

import tvm._ffi

tvm._ffi._init_api("relay.op.nn._make", __name__)

_init_api is defined in python/tvm/_ffi/registry.py:

def _init_api(namespace, target_module_name=None):
    """Initialize api for a given module name

    namespace : str
        The namespace of the source registry
    target_module_name : str
        The target module name if different from namespace
    """
    target_module_name = target_module_name if target_module_name else namespace
    if namespace.startswith("tvm."):
        _init_api_prefix(target_module_name, namespace[4:])
    else:
        _init_api_prefix(target_module_name, namespace)

Here the namespace argument is relay.op.nn._make, and target_module_name is the __name__ attribute of _make.py, i.e. its module path tvm.relay.op.nn._make. The arguments passed on to _init_api_prefix are therefore tvm.relay.op.nn._make and relay.op.nn._make.

def _init_api_prefix(module_name, prefix):
    module = sys.modules[module_name]
    for name in list_global_func_names():
        if not name.startswith(prefix):
            continue
        fname = name[len(prefix) + 1 :]
        target_module = module
        if fname.find(".") != -1:
            continue
        f = get_global_func(name)
        ff = _get_api(f)
        ff.__name__ = fname
        ff.__doc__ = "TVM PackedFunc %s. " % fname
        setattr(target_module, ff.__name__, ff)

module = sys.modules[module_name] obtains a handle to the tvm.relay.op.nn._make module, and list_global_func_names() returns the names of all currently registered global functions; these are functions defined in C++ and registered so they can be called from Python. Every function whose name starts with prefix (relay.op.nn._make) is wrapped and, via setattr, installed as an attribute of the tvm.relay.op.nn._make module. So the _make.conv2d called by conv2d in nn.py is really the Python-side binding of a C++ interface installed here.
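As a quick check, the registered global function can be fetched directly by its full name; it is the C++ PackedFunc that _init_api_prefix wraps and installs (illustrative snippet, requires a built TVM):

    import tvm
    from tvm.relay.op.nn import _make

    f = tvm.get_global_func("relay.op.nn._make.conv2d")
    print(f)             # a PackedFunc wrapping the C++ implementation
    print(_make.conv2d)  # the attribute installed on the module by _init_api_prefix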

This analysis only follows the convolution operator, but each operator directory under tvm/relay/op has its own _make.py that installs the corresponding C++ interface bindings on its module. When any Python-side operator interface is called, it therefore ends up in the corresponding C++ implementation.
