
mxnet::engine survey


Summary

mxnet::engine mainly implements the following:

  • resolving argument dependencies when executing functions in parallel
  • multi-threaded scheduling with per-device control

Beyond the concrete implementation, some design ideas worth borrowing:

  • each device gets its own task queue and thread pool, and each function is assigned to a specific device for execution
    • this makes performance scheduling more controllable
  • ordinary tasks are dispatched to a specific device by setting the device id
  • a dedicated high-priority thread pool does not distinguish devices, so the resources of all devices execute high-priority tasks first
  • CPU/GPU copy operations are split out and handled by a dedicated IO thread pool, guaranteeing concurrency with compute tasks
    • by default each device gets only 1 IO thread, because IO on the same device cannot be parallelized efficiently
  • before implementing a complex module, validate the interface and basic functionality with a naive implementation
  • the module ships a profiler to trace execution and performance, which eases manual analysis

API and functionality

This section closely follows the official documentation [1].

mxnet::engine executes multiple functions according to their dependencies, under the following rules:

  • functions with dependencies between them must execute in order
  • functions with no dependency between them execute in parallel

The main execution API is:

virtual void PushSync(Fn exec_fun, Context exec_ctx,
                      std::vector<VarHandle> const& const_vars,
                      std::vector<VarHandle> const& mutate_vars) = 0;
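As a minimal usage sketch of these rules (the vars, lambda bodies, and the Context::CPU() call are illustrative, not taken from this page): two functions sharing a var are serialized, while an unrelated one may run concurrently.

using namespace mxnet;

Engine* engine = Engine::Get();
Engine::VarHandle a = engine->NewVariable();
Engine::VarHandle b = engine->NewVariable();
Engine::VarHandle c = engine->NewVariable();
Context ctx = Context::CPU();

// f1 mutates a
engine->PushSync([](RunContext rctx) { /* produce a */ }, ctx, {}, {a});
// f2 reads a and mutates b: it shares var a with f1, so it runs after f1
engine->PushSync([](RunContext rctx) { /* a -> b */ }, ctx, {a}, {b});
// f3 only touches c: no shared vars, so it may run in parallel with f1/f2
engine->PushSync([](RunContext rctx) { /* independent */ }, ctx, {}, {c});

engine->WaitForAll();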
The main source files:

  • threaded_engine.h/.cc: the thread-pool based engine
    • threaded_engine_pooled.cc: all devices share a single task queue
    • threaded_engine_perdevice.cc: each device gets its own task queue, for CPU/GPU performance reasons
  • thread_pool.h: a simple thread pool implementation

The code is organized as follows:

  • engine.h is the public interface; it provides a singleton static Engine* Engine::Get() for obtaining the underlying concrete engine instance;
  • naive_engine.cc, threaded_engine_pooled.cc, and threaded_engine_perdevice.cc implement three kinds of engine;
  • profiler.cc implements class Profiler to trace what happens while mxnet::engine runs, which helps performance tuning and debugging.
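Which of the three engines the singleton returns is decided at startup; as far as I recall, engine.cc reads the MXNET_ENGINE_TYPE environment variable. The sketch below paraphrases that selection logic; treat the factory names as my reading of engine.cc rather than a verbatim quote.

#include <cstdlib>
#include <string>

// Paraphrase of the selection logic in engine.cc; factory names are
// from memory and may not match the source exactly.
inline Engine* CreateEngine() {
  const char* type = getenv("MXNET_ENGINE_TYPE");
  const std::string stype = (type == nullptr) ? "ThreadedEnginePerDevice" : type;
  if (stype == "NaiveEngine") return CreateNaiveEngine();
  if (stype == "ThreadedEngine") return CreateThreadedEnginePooled();
  if (stype == "ThreadedEnginePerDevice") return CreateThreadedEnginePerDevice();
  LOG(FATAL) << "Cannot find engine " << stype;
  return nullptr;
}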

The Engine interface

/*!
 * \brief Dependency engine that schedules operations.
 */
class MXNET_API Engine {
 public:
  /*! \brief callback on complete */
  typedef engine::CallbackOnComplete CallbackOnComplete;
  /*! \brief Synchronous operation to pass to engine. */
  typedef std::function<void(RunContext)> SyncFn;
  /*! \brief Asynchronous operation to pass to engine. */
  typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
  /*! \brief Variable pointer */
  typedef engine::VarHandle VarHandle;
  /*! \brief Operator pointer */
  typedef engine::OprHandle OprHandle;
  /*!
   * \brief Notify the engine about a shutdown,
   *  This can help engine to print less messages into display.
   *
   *  User do not have to call this function.
   * \return 0 when success, -1 when failure happens.
   */
  virtual void NotifyShutdown() = 0;
  /*!
   * \brief Allocate a new variable, the variable can then
   *  be used to schedule the operation concurrently via dependency patterns.
   * \return The new variable allocated.
   */
  virtual VarHandle NewVariable() = 0;
  /*!
   * \brief Create a new operator. The returned operator could be saved
   *  externally so that it could be resued for scheduling.
   * \param fn The execution function.
   * \param const_vars The variables that current operation will use but not mutate.
   * \param mutable_vars The variables that current operation will mutate.
   * \param prop Property of the function.
   * \param opr_name The operator name.
   * \return The new operator allocated.
   */
  virtual OprHandle NewOperator(AsyncFn fn,
                                std::vector<VarHandle> const& const_vars,
                                std::vector<VarHandle> const& mutable_vars,
                                FnProperty prop = FnProperty::kNormal,
                                const char* opr_name = nullptr) = 0;
  /*!
   * \brief Delete the given operator.
   * \param op The operator to delete.
   *
   *  The delete will not happen immediately, but will wait until all the
   *  operations using this operator are completed.
   */
  virtual void DeleteOperator(OprHandle op) = 0;
  /*!
   * \brief Push an operator to the engine.
   * \param op The operator to push.
   * \param exec_ctx Execution context.
   * \param priority Priority of the action, as hint to the engine.
   * \param profiling The variable indicate whether to profile this operator.
   */
  virtual void Push(OprHandle op, Context exec_ctx,
                    int priority = 0, bool profiling = false) = 0;
  /*!
   * \brief Push an asynchronous operation to the engine.
   * \param exec_fun Execution function, this function takes a parameter
   *  on_complete that must be called when the execution completes.
   * \param exec_ctx Execution context.
   * \param const_vars The variables that current operation will use but not mutate.
   * \param mutable_vars The variables that current operation will mutate.
   * \param prop Property of the function.
   * \param priority Priority of the action, as hint to the engine.
   * \param opr_name The operator name.
   */
  virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
                         std::vector<VarHandle> const& const_vars,
                         std::vector<VarHandle> const& mutable_vars,
                         FnProperty prop = FnProperty::kNormal,
                         int priority = 0,
                         const char* opr_name = nullptr) = 0;
  /*!
   * \brief Schedule the deletion of a variable.
   *
   *  The delete will not happen immediately, but will wait until all the
   *  operations depending on var are completed.
   *
   * \param delete_fn A function that will be called after the variable is deleted.
   * \param exec_ctx Execution context.
   * \param var The variable to be deleted.
   */
  virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) = 0;
  /*!
   * \brief Wait for a variable.
   * \param var The variable we should wait for. This function returns when
   *  the variable is ready.
   */
  virtual void WaitForVar(VarHandle var) = 0;
  /*!
   * \brief Wait until all the activity of engine finishes.
   */
  virtual void WaitForAll() = 0;
  /*! \brief virtual destructor */
  virtual ~Engine() noexcept(false) {}
  /*!
   * \return Engine singleton.
   */
  static Engine* Get();
  /*!
   * \brief Get shared pointer reference to engine singleton.
   *  Most user should not call this function.
   *  This function is called by another singleton X who requires
   *  engine to be destructed after X.
   * \return A shared pointer to Engine singleton.
   */
  static std::shared_ptr<Engine> _GetSharedRef();
  /*!
   * \brief Push an synchronous operation to the engine.
   * \param exec_fn Execution function that executes the operation.
   * \param exec_ctx Execution context.
   * \param const_vars The variables that current operation will use but not mutate.
   * \param mutable_vars The variables that current operation will mutate.
   * \param prop Property of the function.
   * \param priority Priority of the action, as hint to the engine.
   * \param opr_name The operator name.
   * \tparam SyncFn the synchronous function to be pushed.
   */
  inline void PushSync(SyncFn exec_fn, Context exec_ctx,
                       std::vector<VarHandle> const& const_vars,
                       std::vector<VarHandle> const& mutable_vars,
                       FnProperty prop = FnProperty::kNormal,
                       int priority = 0,
                       const char* opr_name = nullptr) {
    this->PushAsync(
        [exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
          exec_fn(ctx);
          on_complete();
        },
        exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
  }
  /*!
   * \brief factory function to create OnComplete callback.
   * \param callback th static callback function.
   * \param param the paramter passed to callback.
   */
  inline CallbackOnComplete CreateCallback(void (*callback)(Engine*, void*),
                                           void* param) {
    CallbackOnComplete ret;
    ret.callback_ = callback;
    ret.engine_ = this;
    ret.param_ = param;
    return ret;
  }
};  // class Engine
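Note how the two push flavors relate: PushSync is just PushAsync with the on_complete callback fired right after the synchronous body. For truly asynchronous work, the function itself decides when to call on_complete, e.g. from another thread. A small hedged sketch (the detached-thread body is illustrative only):

#include <thread>

// Sketch: an async op releases its vars only when on_complete() fires,
// which here happens on a worker thread after the real work finishes.
Engine* engine = Engine::Get();
Engine::VarHandle v = engine->NewVariable();
engine->PushAsync(
    [](RunContext rctx, Engine::CallbackOnComplete on_complete) {
      std::thread([on_complete]() {
        /* ... do the actual work ... */
        on_complete();  // only now may ops depending on v be scheduled
      }).detach();
    },
    Context::CPU(), {}, {v}, FnProperty::kAsync);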

The threaded_engine implementation

threaded_engine.h introduces some of the concepts used in the implementation, for example:

Var

The Var defined in engine.h manages the ordering among the multiple functions that depend on a given variable.

class ThreadedVar final : public Var, public common::ObjectPoolAllocatable<ThreadedVar>

Here ThreadedVar is a FIFO queue implemented as a linked list; each node of the list is:

/*!
 * \brief VersionedVarBlock that corresponding to a variable version.
 *  This is a basic unit of LinkedList in the ThreadedVar.
 */
struct VersionedVarBlock
    : public common::ObjectPoolAllocatable<VersionedVarBlock> {
  /*! \brief next block in the LinkedList */
  VersionedVarBlock* next{nullptr};
  /*! \brief the operation this block triggers */
  OprBlock* trigger{nullptr};
  /*! \brief whether this operation is a write(mutate) operation. */
  bool write{false};
  /*! \brief define possible debug information */
  DEFINE_ENGINE_DEBUG_INFO(VersionedVarBlock);
};  // struct VersionedVarBlock

Each VersionedVarBlock represents one function that depends on this Var; ThreadedVar manages all the VersionedVarBlocks, i.e. the dependent functions, as a linked list forming a FIFO queue.

The list structure is represented by the following member variables:

VersionedVarBlock* pending_write_{nullptr};
  • pending_write_ points to the oldest VersionedVarBlock in the queue that requests a Write operation
  • pending_write_ is effectively the HEAD of the list, because Read operations that arrive before any Write are scheduled immediately and never enter the list (see AppendReadDependency)
VersionedVarBlock* head_{nullptr};

head_ points to the tail of the queue (the name is rather misleading...); appending a new element only requires:

head_->next = new_var_block;
head_->trigger = opr_block;
head_ = new_var_block;

The process by which ThreadedVar schedules function dependencies is exactly the maintenance of this VersionedVarBlock list, done through the following 4 APIs (a sketch of how the Complete* pair gets invoked follows the list):

  • void ThreadedVar::AppendReadDependency(OprBlock* opr_block);
    • append a Read dependency
  • void ThreadedVar::AppendWriteDependency(OprBlock* opr_block);
    • append a Write dependency
  • void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher);
    • a Read dependency has completed
  • bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher);
    • a Write dependency has completed
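The Append* pair is called from ThreadedEngine::Push (shown later); the Complete* pair is called when an operation finishes. Below is a simplified sketch of that completion path, modeled on ThreadedEngine::OnComplete; treat it as a paraphrase, not the exact source.

// Paraphrased completion path: when an op finishes, walk its vars and
// release the dependencies, re-dispatching any op whose wait hits 0.
void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
  auto dispatcher = [this](OprBlock* opr_block) {
    this->PushToExecute(opr_block, false);
  };
  // release Read dependencies on the inputs
  for (auto&& i : threaded_opr->const_vars) {
    i->CompleteReadDependency(dispatcher);
  }
  // release Write dependencies on the outputs; a true return value means
  // the var was scheduled for deletion and can be freed now
  for (auto&& i : threaded_opr->mutable_vars) {
    if (i->CompleteWriteDependency(dispatcher)) {
      ThreadedVar::Delete(i);
    }
  }
}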

AppendReadDependency

The main logic for appending a Read dependency:

  • if the queue has no pending Write dependency (pending_write_ == nullptr)
    • then by the rules the function's Read dependency is satisfied immediately, via opr_block->decr_wait()
    • the opr_block does not need to enter the list at all
  • otherwise
    • it is simply appended to the tail of the queue
inline void ThreadedVar::AppendReadDependency(OprBlock* opr_block) {
  std::lock_guard<std::mutex> lock{m_};
  if (pending_write_ == nullptr) {
    // invariant: is_ready_to_read()
    CHECK_GE(num_pending_reads_, 0);
    // STATE CHANGE
    ++num_pending_reads_;
    // decrease wait counter
    opr_block->decr_wait();
  } else {
    auto&& new_var_block = VersionedVarBlock::New();
    assert(head_->next == nullptr);
    assert(head_->trigger == nullptr);
    assert(head_->write == false);
    // append things to next.
    head_->next = new_var_block;
    head_->trigger = opr_block;
    head_ = new_var_block;
  }
}

Here num_pending_reads_ is just a piece of state indicating whether Read dependencies are still outstanding; it is consulted when deciding whether this Var can be deleted.

AppendWriteDependency

Appending a Write dependency necessarily creates the Read/Write ordering conflicts described in the rules above, so it must be appended to the tail of the queue and handled in order.

inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) {
  auto&& new_var_block = VersionedVarBlock::New();
  std::lock_guard<std::mutex> lock{m_};
  // invariant.
  assert(head_->next == nullptr);
  assert(head_->trigger == nullptr);
  assert(head_->write == false);
  // attach to head.
  head_->next = new_var_block;
  head_->trigger = opr_block;
  head_->write = true;
  // check if it is ready to write
  if (pending_write_ == nullptr) {
    // invariant: is_ready_to_read()
    pending_write_ = head_;
    CHECK_GE(num_pending_reads_, 0);
    if (num_pending_reads_ == 0) {
      // STATE CHANGE
      opr_block->decr_wait();
      num_pending_reads_ = kWriteTriggered;
    }
  } else {
    CHECK_NE(num_pending_reads_, 0);
  }
  head_ = new_var_block;
}

CompleteReadDependency

When a Read dependency completes, --num_pending_reads_ is all that is needed to keep num_pending_reads_ equal to the number of Read dependencies still pending.

Once every pending Read has been satisfied, the next Write dependency is considered; if the function behind that Write has all of its argument dependencies satisfied (trigger->decr_wait() == 0), it is dispatched to the execution engine and actually run.

template <typename Dispatcher>
inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) {
  OprBlock* trigger = nullptr;
  {  // this is lock scope
    std::lock_guard<std::mutex> lock{m_};
    CHECK_GT(num_pending_reads_, 0);
    if (--num_pending_reads_ == 0) {
      if (pending_write_ != nullptr) {
        // STATE CHANGE
        trigger = pending_write_->trigger;
        num_pending_reads_ = kWriteTriggered;
      }
    }
  }
  if (trigger != nullptr && trigger->decr_wait() == 0) {
    dispatcher(trigger);
  }
}

CompleteWriteDependency

A Write dependency may be followed by several Read dependencies, so the implementation is a bit more involved:

  • traverse the list until the next Write dependency is found, marked by end_of_read_chain
  • each Read dependency found along the way increments num_pending_reads_
  • the old Write dependency is referenced by the pointer old_pending_write; everything between the two pointers is a Read dependency, and a while loop satisfies them all so they can run in parallel

For example, if the queue holds W1 → R1 → R2 → W2 and W1 completes, both R1 and R2 are dispatched (potentially concurrently) and pending_write_ advances to W2.
template <typename Dispatcher>
inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
  VersionedVarBlock *old_pending_write, *end_of_read_chain;
  OprBlock* trigger_write = nullptr;
  {  // this is lock scope
    std::lock_guard<std::mutex> lock{m_};
    // invariants
    assert(head_->next == nullptr);
    assert(pending_write_ != nullptr);
    CHECK_EQ(num_pending_reads_, kWriteTriggered);
    // delete the VersionedVarBlock of the current Write dependency
    // and return early
    if (to_delete_) {
      VersionedVarBlock* head = pending_write_->next;
      VersionedVarBlock::Delete(pending_write_);
      assert(head_ == head);
      VersionedVarBlock::Delete(head);
      return true;
    }
    // detach pending write
    old_pending_write = pending_write_;
    // search for chains to trigger
    end_of_read_chain = old_pending_write->next;
    // reset to 0 pending reads
    num_pending_reads_ = 0;
    while (end_of_read_chain != head_ &&
           end_of_read_chain->write == false) {
      ++num_pending_reads_;
      end_of_read_chain = end_of_read_chain->next;
    }
    if (end_of_read_chain == head_) {
      pending_write_ = nullptr;
    } else {
      // check if there is pending reads, if not trigger write
      assert(end_of_read_chain->write == true);
      pending_write_ = end_of_read_chain;
      if (num_pending_reads_ == 0) {
        // mark write as already activated in this var
        num_pending_reads_ = kWriteTriggered;
        trigger_write = end_of_read_chain->trigger;
      }
    }
  }
  // This is outside of lock scope.
  // Be very careful: pending_write_ and num_pending_reads_
  // can change now, do not rely on the two variables.
  // The linked list in [old_pending_write, end_of_read_chain)
  // is already detached from this Var, so it is safe to modify it.
  VersionedVarBlock* cur_head = old_pending_write->next;
  VersionedVarBlock::Delete(old_pending_write);
  // dispatch all the events
  while (cur_head != end_of_read_chain) {
    if (cur_head->trigger->decr_wait() == 0) {
      dispatcher(cur_head->trigger);
    }
    auto prev = cur_head;
    cur_head = cur_head->next;
    assert(cur_head != nullptr);
    VersionedVarBlock::Delete(prev);
  }
  if (trigger_write != nullptr && trigger_write->decr_wait() == 0) {
    dispatcher(trigger_write);
  }
  return false;
}

The Engine entry point

First, here is OprBlock, which stores a function's execution information. Note the wait field: it counts the pending dependencies of the Opr; when wait == 0, all the Vars are satisfied and the corresponding function can actually be executed by the engine.

/*!
 * \brief Operation block in the scheduler.
 *  Each OprBlock corresponds to an operation pushed to the engine.
 */
struct OprBlock : public common::ObjectPoolAllocatable<OprBlock> {
  /*!
   * \brief wait number of pending tasks this OprBlock is waiting for.
   */
  std::atomic<int> wait{0};
  /*! \brief Pointer to information on performing real operation */
  ThreadedOpr* opr{nullptr};
  /*! \brief The context this operator */
  Context ctx;
  /*! \brief priority of the function */
  int priority;
  /*! \brief indicate whether to profile this operator */
  bool profiling{false};
  /*! \brief operator execution statistics */
  OprExecStat* opr_stat;
  // define possible debug information
  DEFINE_ENGINE_DEBUG_INFO(OprBlock);
  /*!
   * \brief call this function to decrease the wait counter.
   * \return the wait counter after the decreasement.
   */
  inline int decr_wait() {
    // check invariant, avoid over trigger
    int ret = --wait;
    CHECK_GE(ret, 0);
    return ret;
  }
};  // struct OprBlock

At the top-level entry point, when a function is pushed to the Engine:

  • its Var dependencies are analyzed: AppendReadDependency is called for const_vars and AppendWriteDependency for mutate_vars, building the dependency links
  • opr_block->wait records the number of pending dependencies; it is initialized to const_vars.size() + mutable_vars.size() + 1, where the extra 1 is released by the final decr_wait() inside Push itself, which (as I read it) prevents the op from firing while its dependencies are still being appended
  • if every dependency is already satisfied, it is executed right away
    • otherwise the task lands in the engine's queues and waits asynchronously
void ThreadedEngine::Push(OprHandle op, Context exec_ctx,
                          int priority, bool profiling) {
  ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
  OprBlock* opr_block = OprBlock::New();
  opr_block->opr = threaded_opr;
  opr_block->wait.store(static_cast<int>(
      threaded_opr->const_vars.size() +
      threaded_opr->mutable_vars.size() + 1));
  opr_block->ctx = exec_ctx;
  opr_block->priority = priority;
  opr_block->profiling = profiling;
  ++pending_;
  // Add read dependencies.
  for (auto&& i : threaded_opr->const_vars) {
    i->AppendReadDependency(opr_block);
  }
  // Add write dependencies.
  for (auto&& i : threaded_opr->mutable_vars) {
    i->AppendWriteDependency(opr_block);
  }
  if (opr_block->decr_wait() == 0) {
    this->PushToExecute(opr_block, true);
  }
}

The actual execution of a function is the job of PushToExecute, which has two concrete implementations:

  • threaded_engine_pooled.cc: all devices share one pool
  • threaded_engine_perdevice.cc: a per-device engine

ThreadedEnginePooled

This implementation is somewhat simpler than ThreadedEnginePerDevice; roughly:

  • maintain 2 concurrent task queues, one for IO tasks and one for non-IO tasks
  • a function that is asynchronous and pushed from the pusher thread is executed immediately; everything else is appended to the corresponding task queue
void PushToExecute(OprBlock* opr_block, bool pusher_thread) override {
  if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) {
    DoExecute(opr_block);
  } else {
    DoPushToQueue(opr_block);
  }
}

Here, when pusher_thread is true the block may be executed on the spot, otherwise it is put on a task queue. Note how Push in the previous section makes the call:

if (opr_block->decr_wait() == 0) {
  this->PushToExecute(opr_block, true);
}

That is, an opr_block whose Var dependencies have all just been satisfied (its wait counter reached 0) is processed first, directly on the pushing thread.

Via FnProperty, defined in engine.h, mxnet divides functions into the following 5 kinds:

enum class FnProperty {
  /*! \brief Normal operation */
  kNormal,
  /*! \brief Copy operation from GPU to other devices */
  kCopyFromGPU,
  /*! \brief Copy operation from CPU to other devices */
  kCopyToGPU,
  /*! \brief Prioritized sync operation on CPU */
  kCPUPrioritized,
  /*! \brief Asynchronous function call */
  kAsync
};  // enum class FnProperty

Different task types occupy compute and IO resources differently, so different queues are responsible for executing them.

ThreadedEnginePooled splits the concurrent task queues by whether a task is IO:

  1. io_task_queue_, responsible for kCopyFromGPU and kCopyToGPU
  2. task_queue_, for all the other kinds

This gives the implementation of DoPushToQueue:

/*!
 * \brief Push the operation to the queue.
 * \param opr_block The operator block.
 */
void DoPushToQueue(OprBlock* opr_block) {
  switch (opr_block->opr->prop) {
    case FnProperty::kCopyFromGPU:
    case FnProperty::kCopyToGPU: {
      io_task_queue_.Push(opr_block);
      break;
    }
    default: {
      task_queue_.Push(opr_block);
      break;
    }
  }
}

The two task queues and the thread-pool details:

dmlc::ConcurrentBlockingQueue<OprBlock*> task_queue_;
dmlc::ConcurrentBlockingQueue<OprBlock*> io_task_queue_;
ThreadPool thread_pool_;
ThreadPool io_thread_pool_;

void ThreadWorker(dmlc::ConcurrentBlockingQueue<OprBlock*>* task_queue) {
  OprBlock* opr_block;
  while (task_queue->Pop(&opr_block)) {
    DoExecute(opr_block);
  }
}

The thread pool here is the one implemented in engine/thread_pool.h.
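For reference, that pool amounts to little more than the sketch below (condensed from memory, so take it as an approximation of engine/thread_pool.h rather than a verbatim copy): spawn n workers all running the same function, and join them on destruction. Shutdown works because SignalForKill() on the queue (see ThreadWorkerBlock later) makes Pop return false, letting each worker loop exit before the join.

#include <functional>
#include <thread>
#include <vector>

// Approximate shape of engine/thread_pool.h: n identical workers, each
// typically looping on ConcurrentBlockingQueue::Pop until it is killed.
class ThreadPool {
 public:
  explicit ThreadPool(size_t size, std::function<void()> func)
      : worker_threads_(size) {
    for (auto& worker : worker_threads_) {
      worker = std::thread(func);
    }
  }
  ~ThreadPool() {
    for (auto& worker : worker_threads_) {
      if (worker.joinable()) worker.join();
    }
  }

 private:
  std::vector<std::thread> worker_threads_;
  ThreadPool(const ThreadPool&) = delete;
  ThreadPool& operator=(const ThreadPool&) = delete;
};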

ThreadedEnginePerDevice

On top of ThreadedEngine, ThreadedEnginePerDevice adds the following:

  • a fixed number of threads per device (a GPU card / a CPU?)
  • separate task queues for IO operations and high-priority operations
  • on GPU, each thread uses its own stream, so threads do not interfere with one another

The four task queues:

common::LazyAllocArray<ThreadWorkerBlock<kWorkerQueue> > cpu_normal_workers_;
// cpu priority worker
std::unique_ptr<ThreadWorkerBlock<kPriorityQueue> > cpu_priority_worker_;
// workers doing normal works on GPU
common::LazyAllocArray<ThreadWorkerBlock<kWorkerQueue> > gpu_normal_workers_;
// workers doing copy works from/to GPU
common::LazyAllocArray<ThreadWorkerBlock<kCopyQueue> > gpu_copy_workers_;

Here gpu_copy_workers_ is the task queue for IO operations; cpu_normal_workers_, gpu_normal_workers_, and gpu_copy_workers_ each allocate a separate thread pool per device, while cpu_priority_worker_ is not tied to any device.

The point of making cpu_priority_worker_ device-agnostic is to let all CPU resources execute these high-priority tasks first (much like an ordinary multi-core CPU parallel program), whereas the other pools are per-device so that each device's resources can be tracked and fully utilized, which matters especially for GPUs.

The type ThreadWorkerBlock bundles a Queue with a ThreadPool:

template <dmlc::ConcurrentQueueType type>
struct ThreadWorkerBlock {
  // task queue on this task
  dmlc::ConcurrentBlockingQueue<OprBlock*, type> task_queue;
  // thread pool that works on this task
  std::unique_ptr<ThreadPool> pool;
  // destructor
  ~ThreadWorkerBlock() noexcept(false) {
    task_queue.SignalForKill();
  }
};

The main entry PushToExecute is similar in logic to the ThreadedEnginePooled version above:

void PushToExecute(OprBlock* opr_block, bool pusher_thread) override {
  const Context& ctx = opr_block->ctx;
  // pushed from the pusher thread: execute directly
  if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) {
    if (ctx.dev_mask() == gpu::kDevMask) {
      #if MXNET_USE_CUDA
      MSHADOW_CATCH_ERROR(mshadow::SetDevice<gpu>(ctx.dev_id));
      #endif
    }
    RunContext run_ctx;
    run_ctx.stream = nullptr;
    this->ExecuteOprBlock(run_ctx, opr_block);
  } else {
    // CPU case
    if (ctx.dev_mask() == cpu::kDevMask) {
      if (opr_block->opr->prop == FnProperty::kCPUPrioritized) {
        // high-priority task: run in cpu_priority_worker_, which is not
        // per-device, so it runs concurrently on the CPU cores
        // (idle devices pick these tasks up first)
        cpu_priority_worker_->task_queue.Push(opr_block, opr_block->priority);
      } else {
        // otherwise dispatch per device into cpu_normal_workers_
        // (does each core get its own thread pool?)
        int dev_id = ctx.dev_id;
        int nthread = cpu_worker_nthreads_;
        cpu_normal_workers_.Get(dev_id, [this, dev_id, nthread]() {
          auto blk = new ThreadWorkerBlock<kWorkerQueue>();
          blk->pool.reset(new ThreadPool(nthread, [this, blk]() {
            this->CPUWorker(blk);
          }));
          return blk;
        })->task_queue.Push(opr_block, opr_block->priority);
      }
    // GPU case
    } else {
      CHECK_EQ(ctx.dev_mask(), gpu::kDevMask);
      // GPU execution.
      FnProperty prop = opr_block->opr->prop;
      bool is_copy = (prop == FnProperty::kCopyFromGPU ||
                      prop == FnProperty::kCopyToGPU);
      int nthread = gpu_worker_nthreads_;
      int dev_id = ctx.dev_id;
      // CPU <-> GPU copies are expensive IO, so dedicated threads do them
      // asynchronously; by default each device gets only 1 IO thread,
      // since multi-threaded copies on one device do not help
      if (is_copy) {
        gpu_copy_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() {
          auto blk = new ThreadWorkerBlock<kCopyQueue>();
          blk->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, blk]() {
            this->GPUWorker(dev_id, is_copy, blk);
          }));
          return blk;
        })->task_queue.Push(opr_block, opr_block->priority);
      } else {
        // a compute task: submit it to the GPU compute queue
        gpu_normal_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() {
          auto blk = new ThreadWorkerBlock<kWorkerQueue>();
          blk->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, blk]() {
            this->GPUWorker(dev_id, is_copy, blk);
          }));
          return blk;
        })->task_queue.Push(opr_block, opr_block->priority);
      }
    }
  }
}

The rest is essentially the same as in ThreadedEnginePooled; finally, the implementation of GPUWorker:

template <dmlc::ConcurrentQueueType type>
inline void GPUWorker(int dev_id, bool is_copy_worker,
                      ThreadWorkerBlock<type>* block) {
#if MXNET_USE_CUDA
  // allocate stream
  mshadow::SetDevice<gpu>(dev_id);
  RunContext run_ctx;
  mshadow::Stream<gpu>* stream;
  // each GPUWorker gets its own stream:
  // an IO worker gets a plain stream,
  // a compute worker gets a stream with a blas handle
  // (and a cudnn handle when cudnn is enabled)
  if (is_copy_worker) {
    stream = mshadow::NewStream<gpu>(false, false);
  } else {
    stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0);
  }
  run_ctx.stream = stream;
  // execute task
  OprBlock* opr_block;
  auto* task_queue = &(block->task_queue);
  while (task_queue->Pop(&opr_block)) {
    this->ExecuteOprBlock(run_ctx, opr_block);
  }
  // Catch exception for CUDA driver shutdown
  MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(stream));
#endif
}

References

  1. Dependency Engine for Deep Learning
  2. mxnet dep engine implementation