-
Couldn't load subscription status.
- Fork 5.9k
Mxnet Graph构建分析
这篇文章详细比较了Symbolic和Imperative两种神经网络编程风格的优缺点:
-
Symbolic风格效率更高。 原因在于构建完计算图之后可以做优化,比如inplace内存管理,operator合并等。而动态网络,因为中间结果可能随时会被用到,所以所有中间状态都需要保存下来。
- Symbol
- Graph
- Node
- NodeEntry
Symbol对外提供的操作自己的接口,基本都是在操作vector<NodeEntry> outputs,Graph也是操作同样的数据结构,所以他们通过这个outputs进行转换,后面会分析。
/*! \brief output entries contained in the symbol */ std::vector<NodeEntry> outputs;例如下面是创建一个VariableNode:
Symbol Symbol::CreateVariable(const std::string& name) { Symbol s; s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0}); return s; }Graph是内部逻辑表示,用户构建的Symbol会先转换成Graph,然后会有一些优化函数(optimize pass)对Graph进行优化,做一些InPlace, operator fusion之类的操作。对外最主要的成员是一个由NodeEntry组成的Vector output 内部为了优化方便,还有一个IndexedGraph
Symbol对外提供接口(包括python),帮助用户构建计算图,Graph内部使用,构建完的Symbol需要先转换成Graph,然后执行运算。 Symbol和Graph之间的构建关系,从下面可以看出,他们之间通过outputs进行沟通,而outputs都是std::vector类型的。
例如:
def test_infer_shape(): x = sym.Variable('x', shape=(4, 2)) y = sym.add(x, x, name='add1') y = sym.reshape(y, target=(2, 4), name="reshape1") g = graph.create(y) g._set_json_attr("shape_attr_key", "shape") g = g.apply('InferShape') jgraph = json.loads(g.apply('SaveJSON').json_attr('json')) jnodes = jgraph['nodes'] jnode_row_ptr = jgraph['node_row_ptr'] nindex = {n['name']: i for i, n in enumerate(jnodes)} assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4] assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]GraphExecutor初始化:
nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types);可见都是先构建Symbol,然后将基于Symbol Create一个Graph出来,再做后边的事情。
int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) { Graph* g = new Graph(); API_BEGIN(); g->outputs = static_cast<Symbol*>(symbol)->outputs; *graph = g; API_END_HANDLE_ERROR(delete g); }graph由Node组成:
* \brief Node represents an operation in a computation graph.Node中包含一个成员变量NodeAttrs,可以从NodeAttrs中获取到Op,而Op在注册的时候注册了一个AttriBute比如上边的AddKernel。于是和计算绑定在一起了。可以通过string ==> OpKernel<gpu> 找到这个对应的kernel。
Variable也是Node,特点是他的op为nullptr。
inline bool Node::is_variable() const { return this->op() == nullptr; }an entry that represents output data from a node
/*! \brief an entry that represents output data from a node */ struct NodeEntry { /*! \brief the source node of this data */ NodePtr node; /*! \brief index of output from the source. */ uint32_t index; /*! * \brief version of input Variable. * This field can only be nonzero when this->node is a Variable node. * version is increased by one each time a Variable get composed to a mutation Op. * This information can be helpful to decide order of operations when sequence of mutation happens. */ uint32_t version; };kernel也是作为一种通用attr注册进去的,
attr是一个string到any的映射,装了很多种东西:
std::unordered_map<std::string, std::unique_ptr<any> > attr;// registeration of oeprators // NOTE that the attr function can register any // additional attributes to the operator NNVM_REGISTER_OP(add) .describe("add two inputs together") .set_num_inputs(2) .set_attr<OpKernel>("OpKernel<gpu>", AddKernel) .include("ElementwiseOpAttr");set_attr实际上调用了UpdateAttrMap:
有一个全局的OpManager,负责管理key:attr对,attr的类型为
std::unordered_map<std::string, std::unique_ptr<any> > attr;// update attribute map void Op::UpdateAttrMap(const std::string& key, std::function<void(any*)> updater) { OpManager* mgr = OpManager::Global(); std::lock_guard<std::recursive_mutex>(mgr->mutex); std::unique_ptr<any>& value = mgr->attr[key]; if (value.get() == nullptr) value.reset(new any()); if (updater != nullptr) updater(value.get()); }