mxnet::engine survey
The mxnet::engine implementation mainly addresses:

- the argument-dependency problem when functions execute in parallel
- multi-threaded scheduling control that is precise down to the device

Beyond the concrete implementation, several design ideas are worth borrowing:

- each device gets its own task queue and thread pool, and functions are dispatched to a specific device for execution
  - this makes performance scheduling more controllable
  - ordinary tasks are assigned to a concrete device by setting a device id
- a dedicated high-priority thread pool that does not distinguish devices, so that all device resources execute high-priority tasks first
- CPU/GPU copy operations are split out and handled by a dedicated IO thread pool, guaranteeing concurrency with compute tasks
  - each device gets only 1 IO thread by default, because IO on one device cannot be parallelized effectively
- before implementing a complex module, validate the interface and basic functionality with a naive implementation
- the module ships a profiler to trace execution and performance, which makes manual analysis easier
The content here closely follows the official documentation [1].

The job of mxnet::engine is to execute multiple functions according to their dependency relations, under two rules:

- functions with dependencies between them must execute sequentially
- functions without dependencies between them execute in parallel

The main API for execution is:
```cpp
virtual void PushSync(Fn exec_fun, Context exec_ctx,
                      std::vector<VarHandle> const& const_vars,
                      std::vector<VarHandle> const& mutate_vars) = 0;
```
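As a hedged usage sketch (the Vars, the lambda bodies, and `Context::CPU()` are illustrative, not taken from the survey), functions touching the same Var get serialized, while unrelated ones may run in parallel:

```cpp
// Sketch under the PushSync API above: f2 reads `a`, which f1 mutates,
// so the engine runs f1 before f2; f3 depends on nothing and may run
// concurrently with both.
Engine* engine = Engine::Get();
Engine::VarHandle a = engine->NewVariable();
Engine::VarHandle b = engine->NewVariable();
Context exec_ctx = Context::CPU();

engine->PushSync([](RunContext ctx) { /* f1: write a */ }, exec_ctx, {}, {a});
engine->PushSync([](RunContext ctx) { /* f2: read a, write b */ }, exec_ctx, {a}, {b});
engine->PushSync([](RunContext ctx) { /* f3: independent work */ }, exec_ctx, {}, {});
engine->WaitForAll();  // block until everything pushed so far has run
```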
The relevant source files:

- `threaded_engine.h/.cc` — the engine built on a thread pool
- `threaded_engine_pooled.cc` — the implementation where all devices share one task queue
- `threaded_engine_perdevice.h/.cc` — the implementation with a separate task queue per device, motivated by CPU/GPU performance considerations
- `thread_pool.h` — a simple thread pool implementation
The code is organized as follows:

- `engine.h` is the external interface; it provides a singleton `static Engine* Engine::Get()` for obtaining the concrete underlying engine instance;
- `naive_engine.cc`, `threaded_engine_pooled.cc`, and `threaded_engine_perdevice.cc` implement three engines;
- `profiler.cc` implements `class Profiler`, which traces what `mxnet::engine` does at runtime, helping performance tuning and debugging.
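A quick usage note (hedged: the env-var mechanism is from MXNet's documentation, not from this survey): which of the three engines backs the singleton is chosen at process startup, controlled by the `MXNET_ENGINE_TYPE` environment variable.

```cpp
// e.g. run with MXNET_ENGINE_TYPE=NaiveEngine to debug scheduling issues;
// the default is ThreadedEnginePerDevice.
Engine* engine = Engine::Get();  // process-wide singleton of the chosen engine
```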
The full Engine interface (`engine.h`):

```cpp
/*!
 * \brief Dependency engine that schedules operations.
 */
class MXNET_API Engine {
 public:
  /*! \brief callback on complete */
  typedef engine::CallbackOnComplete CallbackOnComplete;
  /*! \brief Synchronous operation to pass to engine. */
  typedef std::function<void(RunContext)> SyncFn;
  /*! \brief Asynchronous operation to pass to engine. */
  typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
  /*! \brief Variable pointer */
  typedef engine::VarHandle VarHandle;
  /*! \brief Operator pointer */
  typedef engine::OprHandle OprHandle;
  /*!
   * \brief Notify the engine about a shutdown.
   *        This can help the engine to print fewer messages to the display.
   *        Users do not have to call this function.
   * \return 0 when success, -1 when failure happens.
   */
  virtual void NotifyShutdown() = 0;
  /*!
   * \brief Allocate a new variable. The variable can then be used to
   *        schedule operations concurrently via dependency patterns.
   * \return The new variable allocated.
   */
  virtual VarHandle NewVariable() = 0;
  /*!
   * \brief Create a new operator. The returned operator could be saved
   *        externally so that it could be reused for scheduling.
   * \param fn The execution function.
   * \param const_vars The variables that current operation will use but not mutate.
   * \param mutable_vars The variables that current operation will mutate.
   * \param prop Property of the function.
   * \param opr_name The operator name.
   * \return The new operator allocated.
   */
  virtual OprHandle NewOperator(AsyncFn fn,
                                std::vector<VarHandle> const& const_vars,
                                std::vector<VarHandle> const& mutable_vars,
                                FnProperty prop = FnProperty::kNormal,
                                const char* opr_name = nullptr) = 0;
  /*!
   * \brief Delete the given operator.
   * \param op The operator to delete.
   *
   * The delete will not happen immediately, but will wait until all the
   * operations using this operator are completed.
   */
  virtual void DeleteOperator(OprHandle op) = 0;
  /*!
   * \brief Push an operator to the engine.
   * \param op The operator to push.
   * \param exec_ctx Execution context.
   * \param priority Priority of the action, as hint to the engine.
   * \param profiling Indicates whether to profile this operator.
   */
  virtual void Push(OprHandle op, Context exec_ctx,
                    int priority = 0, bool profiling = false) = 0;
  /*!
   * \brief Push an asynchronous operation to the engine.
   * \param exec_fun Execution function. This function takes a parameter
   *        on_complete that must be called when the execution completes.
   * \param exec_ctx Execution context.
   * \param const_vars The variables that current operation will use but not mutate.
   * \param mutable_vars The variables that current operation will mutate.
   * \param prop Property of the function.
   * \param priority Priority of the action, as hint to the engine.
   * \param opr_name The operator name.
   */
  virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
                         std::vector<VarHandle> const& const_vars,
                         std::vector<VarHandle> const& mutable_vars,
                         FnProperty prop = FnProperty::kNormal,
                         int priority = 0,
                         const char* opr_name = nullptr) = 0;
  /*!
   * \brief Schedule the deletion of a variable.
   *
   * The delete will not happen immediately, but will wait until all the
   * operations depending on var are completed.
   *
   * \param delete_fn A function that will be called after the variable is deleted.
   * \param exec_ctx Execution context.
   * \param var The variable to be deleted.
   */
  virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) = 0;
  /*!
   * \brief Wait for a variable.
   * \param var The variable we should wait for. This function returns when
   *        the variable is ready.
   */
  virtual void WaitForVar(VarHandle var) = 0;
  /*! \brief Wait until all the activity of the engine finishes. */
  virtual void WaitForAll() = 0;
  /*! \brief virtual destructor */
  virtual ~Engine() noexcept(false) {}
  /*! \return Engine singleton. */
  static Engine* Get();
  /*!
   * \brief Get shared pointer reference to engine singleton.
   *        Most users should not call this function. It is called by another
   *        singleton X that requires the engine to be destructed after X.
   * \return A shared pointer to Engine singleton.
   */
  static std::shared_ptr<Engine> _GetSharedRef();
  /*!
   * \brief Push a synchronous operation to the engine.
   * \param exec_fn Execution function that executes the operation.
   * \param exec_ctx Execution context.
   * \param const_vars The variables that current operation will use but not mutate.
   * \param mutable_vars The variables that current operation will mutate.
   * \param prop Property of the function.
   * \param priority Priority of the action, as hint to the engine.
   * \param opr_name The operator name.
   */
  inline void PushSync(SyncFn exec_fn, Context exec_ctx,
                       std::vector<VarHandle> const& const_vars,
                       std::vector<VarHandle> const& mutable_vars,
                       FnProperty prop = FnProperty::kNormal,
                       int priority = 0,
                       const char* opr_name = nullptr) {
    this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
        exec_fn(ctx);
        on_complete();
      }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
  }
  /*!
   * \brief Factory function to create an OnComplete callback.
   * \param callback The static callback function.
   * \param param The parameter passed to the callback.
   */
  inline CallbackOnComplete CreateCallback(void (*callback)(Engine*, void*), void* param) {
    CallbackOnComplete ret;
    ret.callback_ = callback;
    ret.engine_ = this;
    ret.param_ = param;
    return ret;
  }
};  // class Engine
```

`threaded_engine.h` contains several of the concepts used in the implementation.
The Var defined in `engine.h` manages the ordering of the multiple functions that depend on a given variable; its concrete implementation is ThreadedVar:
```cpp
class ThreadedVar final : public Var,
                          public common::ObjectPoolAllocatable<ThreadedVar>
```

Here, ThreadedVar is a FIFO queue implemented as a linked list; each node of the list is a:
```cpp
/*!
 * \brief VersionedVarBlock that corresponds to a variable version.
 *        This is a basic unit of the LinkedList in ThreadedVar.
 */
struct VersionedVarBlock : public common::ObjectPoolAllocatable<VersionedVarBlock> {
  /*! \brief next block in the LinkedList */
  VersionedVarBlock* next{nullptr};
  /*! \brief the operation this block triggers */
  OprBlock* trigger{nullptr};
  /*! \brief whether this operation is a write (mutate) operation */
  bool write{false};
  /*! \brief define possible debug information */
  DEFINE_ENGINE_DEBUG_INFO(VersionedVarBlock);
};  // struct VersionedVarBlock
```

Each VersionedVarBlock represents one function that depends on this Var; ThreadedVar keeps all the VersionedVarBlocks, i.e. all the dependent functions, in a linked list that serves as the FIFO queue.
The structure of the list is expressed through the following member variables:
```cpp
VersionedVarBlock* pending_write_{nullptr};
```

- `pending_write_` points to the frontmost (oldest) `VersionedVarBlock` in the queue that requests a Write operation
- `pending_write_` is in fact the HEAD of the list, because Read operations arriving before any Write are scheduled for execution directly and never enter the list (see `AppendReadDependency`)
```cpp
VersionedVarBlock* head_{nullptr};
```

`head_` points to the tail of the list queue (the name is quite misleading…); appending a new element takes only:

```cpp
head_->next = new_var_block;
head_->trigger = opr_block;
head_ = new_var_block;
```

The way ThreadedVar schedules function dependencies is precisely by maintaining this linked list of VersionedVarBlocks, through the following 4 APIs:
- `void ThreadedVar::AppendReadDependency(OprBlock* opr_block)` — add a Read dependency
- `void ThreadedVar::AppendWriteDependency(OprBlock* opr_block)` — add a Write dependency
- `void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher)` — a Read dependency completed
- `bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher)` — a Write dependency completed
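To make the pointer roles concrete, here is a hedged illustration (my own, not from the docs) of one Var's queue after pushing W1, R1, R2, W2 in that order onto a fresh Var:

```cpp
// W1 found no pending reads, so it was triggered on push, but its node stays
// in the list as pending_write_ until CompleteWriteDependency removes it;
// R1/R2/W2 wait behind it.
//
//   pending_write_                               head_
//         |                                        |
//         v                                        v
//       [W1] ---> [R1] ---> [R2] ---> [W2] ---> [empty tail]
//
// head_ always points to an empty sentinel at the tail: appending fills the
// current sentinel with (trigger, write) and allocates a fresh sentinel.
```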
The main logic of adding a Read dependency:

- if the queue has no pending Write dependency (`pending_write_ == nullptr`):
  - by the rules above, the function's Read dependency is satisfied directly, via `opr_block->decr_wait()`
  - the `opr_block` does not need to enter the linked-list queue at all
- otherwise:
  - it is simply appended to the tail of the queue
```cpp
inline void ThreadedVar::AppendReadDependency(OprBlock* opr_block) {
  std::lock_guard<std::mutex> lock{m_};
  if (pending_write_ == nullptr) {
    // invariant: is_ready_to_read()
    CHECK_GE(num_pending_reads_, 0);
    // STATE CHANGE
    ++num_pending_reads_;
    // decrease wait counter
    opr_block->decr_wait();
  } else {
    auto&& new_var_block = VersionedVarBlock::New();
    assert(head_->next == nullptr);
    assert(head_->trigger == nullptr);
    assert(head_->write == false);
    // append things to next.
    head_->next = new_var_block;
    head_->trigger = opr_block;
    head_ = new_var_block;
  }
}
```

Here `num_pending_reads_` is just state recording how many Read dependencies are still outstanding; it is consulted when deciding whether this Var can be deleted.
Adding a Write dependency necessarily creates the Read/Write ordering problem described in the rules above, so the block must always be appended to the tail of the queue and executed in order.
```cpp
inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) {
  auto&& new_var_block = VersionedVarBlock::New();
  std::lock_guard<std::mutex> lock{m_};
  // invariant.
  assert(head_->next == nullptr);
  assert(head_->trigger == nullptr);
  assert(head_->write == false);
  // attach to head.
  head_->next = new_var_block;
  head_->trigger = opr_block;
  head_->write = true;
  // check if it is ready to write
  if (pending_write_ == nullptr) {
    // invariant: is_ready_to_read()
    pending_write_ = head_;
    CHECK_GE(num_pending_reads_, 0);
    if (num_pending_reads_ == 0) {
      // STATE CHANGE
      opr_block->decr_wait();
      num_pending_reads_ = kWriteTriggered;
    }
  } else {
    CHECK_NE(num_pending_reads_, 0);
  }
  head_ = new_var_block;
}
```

When a Read dependency completes, it suffices to run `--num_pending_reads_` so that `num_pending_reads_` always reflects the up-to-date number of pending Read operations.
Once all pending Read operations have been satisfied, the next Write dependency is satisfied in turn; if all argument dependencies of that Write's function are now met (`trigger->decr_wait() == 0`), the function is dispatched to the execution engine to actually run.
```cpp
template <typename Dispatcher>
inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) {
  OprBlock* trigger = nullptr;
  {  // this is lock scope
    std::lock_guard<std::mutex> lock{m_};
    CHECK_GT(num_pending_reads_, 0);
    if (--num_pending_reads_ == 0) {
      if (pending_write_ != nullptr) {
        // STATE CHANGE
        trigger = pending_write_->trigger;
        num_pending_reads_ = kWriteTriggered;
      }
    }
  }
  if (trigger != nullptr && trigger->decr_wait() == 0) {
    dispatcher(trigger);
  }
}
```

Because a Write dependency may be followed by multiple Read dependencies, the implementation of CompleteWriteDependency is somewhat more involved:
- traverse the list until the next Write dependency is found, marked by `end_of_read_chain`
- for every Read dependency found along the way, increment `num_pending_reads_`
- the old Write dependency is pointed to by `old_pending_write`; everything between the two pointers is Read dependencies, and a while loop satisfies (dispatches) them all so they can run in parallel
```cpp
template <typename Dispatcher>
inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
  VersionedVarBlock *old_pending_write, *end_of_read_chain;
  OprBlock* trigger_write = nullptr;
  {  // this is lock scope
    std::lock_guard<std::mutex> lock{m_};
    // invariants
    assert(head_->next == nullptr);
    assert(pending_write_ != nullptr);
    CHECK_EQ(num_pending_reads_, kWriteTriggered);
    // delete the VersionedVarBlock of the current Write dependency, return early
    if (to_delete_) {
      VersionedVarBlock* head = pending_write_->next;
      VersionedVarBlock::Delete(pending_write_);
      assert(head_ == head);
      VersionedVarBlock::Delete(head);
      return true;
    }
    // detach pending write
    old_pending_write = pending_write_;
    // search for chains to trigger
    end_of_read_chain = old_pending_write->next;
    // reset to 0 pending reads
    num_pending_reads_ = 0;
    while (end_of_read_chain != head_ && end_of_read_chain->write == false) {
      ++num_pending_reads_;
      end_of_read_chain = end_of_read_chain->next;
    }
    if (end_of_read_chain == head_) {
      pending_write_ = nullptr;
    } else {
      // check if there are pending reads; if not, trigger the write
      assert(end_of_read_chain->write == true);
      pending_write_ = end_of_read_chain;
      if (num_pending_reads_ == 0) {
        // mark write as already activated in this var
        num_pending_reads_ = kWriteTriggered;
        trigger_write = end_of_read_chain->trigger;
      }
    }
  }
  // This is outside of lock scope.
  // Be very careful: pending_write_ and num_pending_reads_
  // can change now, do not rely on the two variables.
  // The linked list in [old_pending_write, end_of_read_chain)
  // is already detached from this Var, so it is safe to modify it.
  VersionedVarBlock* cur_head = old_pending_write->next;
  VersionedVarBlock::Delete(old_pending_write);
  // dispatch all the events
  while (cur_head != end_of_read_chain) {
    if (cur_head->trigger->decr_wait() == 0) {
      dispatcher(cur_head->trigger);
    }
    auto prev = cur_head;
    cur_head = cur_head->next;
    assert(cur_head != nullptr);
    VersionedVarBlock::Delete(prev);
  }
  if (trigger_write != nullptr && trigger_write->decr_wait() == 0) {
    dispatcher(trigger_write);
  }
  return false;
}
```

Next, here is OprBlock, which stores a function's execution information. Note the `wait` field: it counts the Vars this Opr is still waiting on. When `wait == 0`, all Vars are satisfied and the corresponding function can actually be executed by the engine.
```cpp
/*!
 * \brief Operation block in the scheduler.
 *        Each OprBlock corresponds to an operation pushed to the engine.
 */
struct OprBlock : public common::ObjectPoolAllocatable<OprBlock> {
  /*! \brief number of pending tasks this OprBlock is waiting for */
  std::atomic<int> wait{0};
  /*! \brief Pointer to information on performing real operation */
  ThreadedOpr* opr{nullptr};
  /*! \brief The context of this operator */
  Context ctx;
  /*! \brief priority of the function */
  int priority;
  /*! \brief indicate whether to profile this operator */
  bool profiling{false};
  /*! \brief operator execution statistics */
  OprExecStat* opr_stat;
  // define possible debug information
  DEFINE_ENGINE_DEBUG_INFO(OprBlock);
  /*!
   * \brief call this function to decrease the wait counter.
   * \return the wait counter after the decrement.
   */
  inline int decr_wait() {
    // check invariant, avoid over-trigger
    int ret = --wait;
    CHECK_GE(ret, 0);
    return ret;
  }
};  // struct OprBlock
```

The top-level entry point: when a function is pushed to the Engine,
- its Var dependencies are analyzed: `AppendReadDependency` is called on the `const_vars` and `AppendWriteDependency` on the `mutate_vars` to build the dependency relations
- `opr_block->wait` records the number of pending dependencies
- if the dependencies are satisfied immediately, the function is executed
- otherwise the task is thrown into the engine's queue to wait asynchronously
```cpp
void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
  ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
  OprBlock* opr_block = OprBlock::New();
  opr_block->opr = threaded_opr;
  opr_block->wait.store(static_cast<int>(threaded_opr->const_vars.size() +
                                         threaded_opr->mutable_vars.size() + 1));
  opr_block->ctx = exec_ctx;
  opr_block->priority = priority;
  opr_block->profiling = profiling;
  ++pending_;
  // Add read dependencies.
  for (auto&& i : threaded_opr->const_vars) {
    i->AppendReadDependency(opr_block);
  }
  // Add write dependencies.
  for (auto&& i : threaded_opr->mutable_vars) {
    i->AppendWriteDependency(opr_block);
  }
  if (opr_block->decr_wait() == 0) {
    this->PushToExecute(opr_block, true);
  }
}
```

The actual execution of a function is the job of PushToExecute, which has two concrete implementations:
- `threaded_engine_pooled.cc` — all devices share one pool
- `threaded_engine_perdevice.cc` — the engine that distinguishes devices
The ThreadedEnginePooled implementation is somewhat simpler than ThreadedEnginePerDevice; the rough logic is:
- maintain 2 concurrent task queues, one for IO tasks and one for non-IO tasks
- a function on the `pusher_thread` is executed immediately; anything else is added to the corresponding task queue
```cpp
void PushToExecute(OprBlock* opr_block, bool pusher_thread) override {
  if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) {
    DoExecute(opr_block);
  } else {
    DoPushToQueue(opr_block);
  }
}
```

Regarding `pusher_thread`: when it is true the block is executed immediately, otherwise it is appended to a task queue. Note how Push in the engine above (previous subsection) makes the call:
```cpp
if (opr_block->decr_wait() == 0) {
  this->PushToExecute(opr_block, true);
}
```

that is, an `opr_block` whose Var dependencies are already satisfied is handled right away (checking whether the dependencies are met, and so on).
Via the FnProperty defined in `engine.h`, mxnet classifies functions into the following 5 kinds:
```cpp
enum class FnProperty {
  /*! \brief Normal operation */
  kNormal,
  /*! \brief Copy operation from GPU to other devices */
  kCopyFromGPU,
  /*! \brief Copy operation from CPU to other devices */
  kCopyToGPU,
  /*! \brief Prioritized sync operation on CPU */
  kCPUPrioritized,
  /*! \brief Asynchronous function call */
  kAsync
};  // enum class FnProperty
```

Different task types occupy compute/IO resources differently, so different queues are responsible for executing them.
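For instance, a host-to-device copy would be pushed with `kCopyToGPU` so it lands in the IO queue rather than the compute queue. A hedged sketch (the Vars and the copy body are placeholders, not from the survey):

```cpp
// Sketch only: a CPU -> GPU copy pushed as an IO task.
Engine* engine = Engine::Get();
Engine::VarHandle src_var = engine->NewVariable();
Engine::VarHandle dst_var = engine->NewVariable();
engine->PushAsync(
    [](RunContext ctx, Engine::CallbackOnComplete on_complete) {
      // ... enqueue the CPU -> GPU copy on ctx.stream ...
      on_complete();  // signal completion of the async work
    },
    Context::GPU(0), {src_var}, {dst_var}, FnProperty::kCopyToGPU);
```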
ThreadedEnginePooled splits the concurrent task queue in two, according to whether the task is IO:
- `io_task_queue_` — handles kCopyFromGPU and kCopyToGPU
- `task_queue_` — all other types
Hence the implementation of DoPushToQueue:
```cpp
/*!
 * \brief Push the operation to the queue.
 * \param opr_block The operator block.
 */
void DoPushToQueue(OprBlock* opr_block) {
  switch (opr_block->opr->prop) {
    case FnProperty::kCopyFromGPU:
    case FnProperty::kCopyToGPU: {
      io_task_queue_.Push(opr_block);
      break;
    }
    default: {
      task_queue_.Push(opr_block);
      break;
    }
  }
}
```

The two task queues and the thread-pool details look like this:
```cpp
dmlc::ConcurrentBlockingQueue<OprBlock*> task_queue_;
dmlc::ConcurrentBlockingQueue<OprBlock*> io_task_queue_;
ThreadPool thread_pool_;
ThreadPool io_thread_pool_;

void ThreadWorker(dmlc::ConcurrentBlockingQueue<OprBlock*>* task_queue) {
  OprBlock* opr_block;
  while (task_queue->Pop(&opr_block)) {
    DoExecute(opr_block);
  }
}
```

The thread pool here is the one implemented in `engine/thread_pool.h`.
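The survey does not quote `thread_pool.h`, but its shape is simple. A minimal self-contained sketch in the same spirit (my reconstruction, not the actual MXNet class):

```cpp
#include <functional>
#include <thread>
#include <vector>

// Launch `size` threads that all run the same worker function, and join them
// on destruction. Each worker typically loops popping a concurrent queue
// (as ThreadWorker above does) until the queue is killed.
class SimpleThreadPool {
 public:
  SimpleThreadPool(size_t size, std::function<void()> func) {
    for (size_t i = 0; i < size; ++i) {
      workers_.emplace_back(func);
    }
  }
  ~SimpleThreadPool() {
    for (std::thread& t : workers_) {
      t.join();
    }
  }

 private:
  std::vector<std::thread> workers_;
};
```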
ThreadedEnginePerDevice supports the following on top of ThreadedEngine:

- a fixed number of threads per device (GPU card / CPU core?)
- separate task queues for IO operations and high-priority operations
- on GPU, each thread uses its own stream, so they do not interfere with one another
Four task queues:
```cpp
common::LazyAllocArray<ThreadWorkerBlock<kWorkerQueue> > cpu_normal_workers_;
// cpu priority worker
std::unique_ptr<ThreadWorkerBlock<kPriorityQueue> > cpu_priority_worker_;
// workers doing normal works on GPU
common::LazyAllocArray<ThreadWorkerBlock<kWorkerQueue> > gpu_normal_workers_;
// workers doing copy works from/to GPU
common::LazyAllocArray<ThreadWorkerBlock<kCopyQueue> > gpu_copy_workers_;
```

Here `gpu_copy_workers_` is the task queue for IO operations; `cpu_normal_workers_`, `gpu_normal_workers_`, and `gpu_copy_workers_` each allocate a separate thread pool per device, while `cpu_priority_worker_` does not distinguish devices.
The point of `cpu_priority_worker_` not distinguishing devices is to let all CPU resources execute these high-priority tasks first (like an ordinary multi-core CPU-parallel program), whereas the other pools distinguish devices in order to track and fully utilize each device's resources, which matters especially for GPUs.
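The per-device workers are allocated lazily through `common::LazyAllocArray::Get(dev_id, factory)`, as the PushToExecute code below shows. A hedged, self-contained sketch of that pattern (my simplification, not MXNet's actual class):

```cpp
#include <memory>
#include <mutex>
#include <vector>

// The first Get() for an index runs the factory to build the element (here:
// a worker block with its queue and thread pool); later Get() calls return
// the already-created one.
template <typename T>
class LazyAllocArraySketch {
 public:
  template <typename Factory>
  T* Get(size_t index, Factory factory) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (index >= slots_.size()) {
      slots_.resize(index + 1);
    }
    if (!slots_[index]) {
      slots_[index].reset(factory());
    }
    return slots_[index].get();
  }

 private:
  std::mutex mutex_;
  std::vector<std::unique_ptr<T>> slots_;
};
```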
The ThreadWorkerBlock type bundles the Queue and the ThreadPool together:
```cpp
template <dmlc::ConcurrentQueueType type>
struct ThreadWorkerBlock {
  // task queue on this task
  dmlc::ConcurrentBlockingQueue<OprBlock*, type> task_queue;
  // thread pool that works on this task
  std::unique_ptr<ThreadPool> pool;
  // destructor
  ~ThreadWorkerBlock() noexcept(false) {
    task_queue.SignalForKill();
  }
};
```

The main interface PushToExecute follows the same overall logic as the implementation in ThreadedEnginePooled:
```cpp
void PushToExecute(OprBlock* opr_block, bool pusher_thread) override {
  const Context& ctx = opr_block->ctx;
  // async ops from the pusher thread execute directly
  if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) {
    if (ctx.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
      MSHADOW_CATCH_ERROR(mshadow::SetDevice<gpu>(ctx.dev_id));
#endif
    }
    RunContext run_ctx;
    run_ctx.stream = nullptr;
    this->ExecuteOprBlock(run_ctx, opr_block);
  } else {
    // CPU case
    if (ctx.dev_mask() == cpu::kDevMask) {
      // high-priority tasks run in cpu_priority_worker_; that queue does not
      // distinguish devices and runs concurrently on the CPU cores
      // (idle devices pick these tasks up first)
      if (opr_block->opr->prop == FnProperty::kCPUPrioritized) {
        cpu_priority_worker_->task_queue.Push(opr_block, opr_block->priority);
      } else {
        // otherwise dispatch per device through cpu_normal_workers_
        // (each core gets its own thread pool?)
        int dev_id = ctx.dev_id;
        int nthread = cpu_worker_nthreads_;
        cpu_normal_workers_.Get(dev_id, [this, dev_id, nthread]() {
          auto blk = new ThreadWorkerBlock<kWorkerQueue>();
          blk->pool.reset(new ThreadPool(nthread, [this, blk]() {
            this->CPUWorker(blk);
          }));
          return blk;
        })->task_queue.Push(opr_block, opr_block->priority);
      }
    // GPU case
    } else {
      CHECK_EQ(ctx.dev_mask(), gpu::kDevMask);
      // GPU execution.
      FnProperty prop = opr_block->opr->prop;
      bool is_copy = (prop == FnProperty::kCopyFromGPU || prop == FnProperty::kCopyToGPU);
      int nthread = gpu_worker_nthreads_;
      int dev_id = ctx.dev_id;
      // CPU <-> GPU copies are expensive IO and run asynchronously on
      // dedicated threads; by default only 1 IO thread per device, since
      // multi-threaded copying brings no benefit here
      if (is_copy) {
        gpu_copy_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() {
          auto blk = new ThreadWorkerBlock<kCopyQueue>();
          blk->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, blk]() {
            this->GPUWorker(dev_id, is_copy, blk);
          }));
          return blk;
        })->task_queue.Push(opr_block, opr_block->priority);
      } else {
        // compute tasks are submitted to the GPU compute queue
        gpu_normal_workers_.Get(dev_id, [this, dev_id, is_copy, nthread]() {
          auto blk = new ThreadWorkerBlock<kWorkerQueue>();
          blk->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, blk]() {
            this->GPUWorker(dev_id, is_copy, blk);
          }));
          return blk;
        })->task_queue.Push(opr_block, opr_block->priority);
      }
    }
  }
}
```

The rest is essentially the same as in ThreadedEnginePooled. Lastly, the GPUWorker implementation:
```cpp
template <dmlc::ConcurrentQueueType type>
inline void GPUWorker(int dev_id, bool is_copy_worker, ThreadWorkerBlock<type>* block) {
#if MXNET_USE_CUDA
  // allocate stream
  mshadow::SetDevice<gpu>(dev_id);
  RunContext run_ctx;
  mshadow::Stream<gpu>* stream;
  // each GPUWorker allocates its own stream:
  // a copy worker gets a bare stream, while a compute worker gets a stream
  // with a BLAS handle (and a cuDNN handle, if enabled) attached
  if (is_copy_worker) {
    stream = mshadow::NewStream<gpu>(false, false);
  } else {
    stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0);
  }
  run_ctx.stream = stream;
  // execute tasks
  OprBlock* opr_block;
  auto* task_queue = &(block->task_queue);
  while (task_queue->Pop(&opr_block)) {
    this->ExecuteOprBlock(run_ctx, opr_block);
  }
  // Catch exception for CUDA driver shutdown
  MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(stream));
#endif
}
```
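The matching CPUWorker is not quoted above; by analogy with GPUWorker it would simply drain the queue without allocating a stream. A hedged reconstruction (not quoted from the source):

```cpp
// CPU workers need no GPU stream, so run_ctx.stream stays null and the loop
// just drains this worker block's queue, executing each ready OprBlock.
template <dmlc::ConcurrentQueueType type>
inline void CPUWorker(ThreadWorkerBlock<type>* block) {
  auto* task_queue = &(block->task_queue);
  RunContext run_ctx;
  run_ctx.stream = nullptr;
  OprBlock* opr_block;
  while (task_queue->Pop(&opr_block)) {
    this->ExecuteOprBlock(run_ctx, opr_block);
  }
}
```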