本文主要基于MXNet1.6.0版本进行分析。

上一篇文章中,我们分析了MXNet中KVStore的进程内通信机制。在这篇文章中,我们主要分析KVStore如何进行多节点分布式通信。

在KVStore的实现中,KVStoreDistKVStoreDistServer分别对应参数服务器中的worker节点与server节点。KVStoreDist继承自KVStoreLocal,通过封装PS-Lite中的KVWorker实现了PushPull等接口,从而向server发送各类请求;而KVStoreDistServer则封装了PS-Lite中的KVServer,用来处理并响应worker发来的各类请求。

worker端执行逻辑

worker创建

KVStoreDist的构造函数为每个worker节点创建一个ps::KVWorker<char>类型的对象。如果当前worker节点不是一个recovery的节点,那么就阻塞到所有的worker和server启动。

explicit KVStoreDist(bool use_device_comm)
: KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) {
if (IsWorkerNode()) {
int new_customer_id = GetNewCustomerId();
ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
ps::StartAsync(new_customer_id, "mxnet\0");
if (!ps::Postoffice::Get()->is_recovery()) {
ps::Postoffice::Get()->Barrier(
new_customer_id,
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
}
}
bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000);
log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}

worker的初始化过程

在初始化时,每个worker首先检查key的唯一性,随后调用comm_->Init为每个key初始化进行本地通信的资源。本地初始化完成后,worker0把自己本地的权重发送给所有的server。worker0在其push操作完成后,会将数据写入到comm_buf_compr_buf_这两个缓冲区中。

void InitImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values) override {
CheckUnique(keys);
for (size_t i = 0; i < keys.size(); ++i) {
comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
}
if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() == 0) {
Push_(keys, values, 0, false);
// wait until the push is finished
for (const int key : keys) {
comm_buf_[key].WaitToWrite();
compr_buf_[key].WaitToWrite();
}
} else {
// do nothing
}
if (!ps::Postoffice::Get()->is_recovery()) {
Barrier();
}
}

worker发送控制消息

worker端通过SendCommandToServers函数向server端发送控制消息。例如,在KVStoreDist的析构函数中有如下代码,用来从worker0节点向所有server节点发送一个终止的命令。

if (get_rank() == 0 && ps_worker_->get_customer()->customer_id() == 0) {
// stop the executor at servers
SendCommandToServers(static_cast<int>(CommandType::kStopServer), "");
}

worker发送数据消息

worker会调用Push_函数向server发送数据请求,它的核心逻辑如下所示(省略部分代码)。与之前提到的本地通信类似,在向server节点发送数据之前,会先调用GroupPairsPush把具有相同key的value汇总到一个vector中。对于每个key,先在本地进行一次Reduce操作聚合所有设备上的梯度,并将结果存放到comm_buf中。随后,通过EncodeDefaultKey把key和value编码成PS-Lite支持的数据结构,再调用PushDefault把对应的数据发送出去。

void KVStoreDist::Push_(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority,
bool do_merge) {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray>> grouped_val;
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_val, false); for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
const auto& vals = grouped_vals[i];
NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0]; auto &comm_buf = comm_buf_[key];
if (merged.ctx().dev_mask() == cpu::kDevMask) {
// Start of a push doesn't guarantee that the previous pushes are completed.
// This shouldn't affect training of networks though because training involves
// a sequence of push, pull, then push. This imposes ordering that the
// second push happens after the first pull, and the pull happens after first push.
comm_buf = merged; // avoid memory copy
} else {
if (comm_buf.is_none()) {
comm_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
}
CopyFromTo(merged, &comm_buf);
}
const int dtype = merged.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), num_bytes);
PushDefault(key, comm_buf, pskv, priority);
}
}

PushDefault会调用ps_worker_->ZPush来完成梯度的发送,梯度发送以及发送之前的一些准备操作都被封装到一个lambda表达式中,这个lambda表达式随后被压入到MXNet后端的依赖引擎中等待执行。

void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) {
auto push_to_servers =
[this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
const int dtype = send_buf.dtype();
// convert to ps keys
const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
char* data = static_cast<char *>(send_buf.data().dptr_);
// do push. false means no delete
ps::SArray<char> vals(data, size, false);
int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
CHECK_NOTNULL(ps_worker_)->ZPush(
pskv.keys, vals, pskv.lens,
cmd, [cb]() { cb(); });
};
Engine::Get()->PushAsync(
push_to_servers,
pinned_ctx_,
{send_buf.var()},
{},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultPush");
}

Pull操作的过程如下所示。在准备工作完成后,调用ps_server_->ZPull完成权重的拉取,最后在本地执行Broadcast把从server端拉回的权重广播到所有设备上。

void PullImpl(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority, bool ignore_sparse) override {
CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False";
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, true); for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
// use the same array for merging to guarantee that pull always happens
// after the previous push on this key
auto& recv_buf = comm_buf_[key];
const auto storage_type = grouped_vals[i][0]->storage_type();
CHECK_EQ(storage_type, kDefaultStorage)
<< "Expected stype of value to be kDefaultStorage";
if (recv_buf.is_none()) {
// it may happen for the first time a no-rank-0 worker pull the weight.
recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_,
true, grouped_vals[i][0]->dtype());
}
auto pull_from_servers = [this, key, recv_buf](
RunContext rctx, Engine::CallbackOnComplete cb) {
// convert to ps keys
size_t size = recv_buf.shape().Size();
const int dtype = recv_buf.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
PSKV& pskv = EncodeDefaultKey(key, size, num_bytes) :
char* data = static_cast<char*> (recv_buf.data().dptr_);
// false means not to delete data when SArray is deleted
auto vals = new ps::SArray<char>(data, size * num_bytes, false);
// issue pull
RequestType mode = RequestType::kDefaultPushPull;
const int cmd = GetCommandType(mode, dtype);
CHECK_NOTNULL(ps_worker_)->ZPull(
pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
}; CHECK_NOTNULL(Engine::Get())->PushAsync(
pull_from_servers,
pinned_ctx_,
{},
{recv_buf.var()},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultStoragePull"); comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
}
}

server端执行逻辑

server的创建以及启动

首先在KVStoreDistServer的构造函数中为ps_server_绑定处理命令请求的CommandHandle以及处理数据请求的DataHandleEx。注意到在绑定CommandHandle时,ps_server_被向上转型成ps::SimpleApp*类型。这是因为ps::SimpleApp中实现的set_request_handle只能接收包含两个形参的函数对象,而ps::KVServer继承了ps::SimpleApp并且重载了set_request_handle,使之可以接收包含三个形参的函数对象。这样一来,就完成了对控制请求和数据请求的分开处理。

KVStoreDistServer() {
using namespace std::placeholders;
ps_server_ = new ps::KVServer<char>(0);
static_cast<ps::SimpleApp*>(ps_server_)->set_request_handle(
std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
ps_server_->set_request_handle(
std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));
sync_mode_ = false;
gradient_compression_ = std::make_shared<GradientCompression>();
log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}

处理控制请求

server接收到worker0发来的命令后,会根据命令的类型,执行不同的操作。例如,当worker发来StopServer的命令后,server就会被停止。相应的命令执行完毕后,server会发送一个响应给worker0。注意这里负责发送响应的不是ps::KVWorker<char>类型的对象,而是ps::SimpleApp类型的对象。

void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
CommandType recved_type = static_cast<CommandType>(recved.head);
switch (recved_type) {
case CommandType::kStopServer:
exec_.Stop();
break;
case CommandType::kSyncMode:
sync_mode_ = true;
break;
case CommandType::kSetGradientCompression:
gradient_compression_->DecodeParams(recved.body);
break;
case CommandType::kSetProfilerParams:
// last char is the type of profiler command
ProcessServerProfilerCommands(static_cast<KVStoreServerProfilerCommand>
(recved.body.back() - '0'),
recved.body);
break;
case CommandType::kSetMultiPrecision:
// uses value 1 for message id from frontend
if (!multi_precision_) {
multi_precision_ = true;
CreateMultiPrecisionCopies();
}
break;
case CommandType::kController:
// this uses value 0 for message id from frontend
// let the main thread to execute ctrl, which is necessary for python
exec_.Exec([this, recved]() {
CHECK(controller_);
controller_(recved.head, recved.body);
});
break;
}
app->Response(recved);
}

处理数据请求

前面提到,DataHandleEx被注册为处理数据请求的函数,它会根据数据请求类型去调用不同的处理函数。默认情况下会调用DataHandleDefalut,该函数会对worker发来的push和pull请求分开处理。当worker节点push梯度到server时,如果某个key是第一次被push,那么server会为相应的key申请内存空间;否则会根据sync_mode_的值分别进行处理。在sync_mode_ == true(即同步训练模式)的情况下,所有worker上的梯度会被聚合到update_buf_[key].merged中;而在异步训练模式下,server把从某个worker接收的梯度放在update_buf_[key].temp_array中。随后,worker发来的push请求信息会被记录到update_buf_[key].request中。待上面的工作完成后,会调用ApplyUpdates函数去更新key对应的模型参数。当worker节点向server节点发送pull请求时,server会直接调用DefaultStorageResponse把server节点最新的模型参数发送给worker。

void DataHandleDefault(const DataHandleType type, const ps::KVMeta& req_meta,
const ps::KVPairs<char>& req_data, ps::KVServer<char>* server) {
int key = DecodeKey(req_data.keys[0]);
auto& stored = store_[key];
if (req_meta.push) { // push operation
size_t ds[] = {(size_t) req_data.lens[0] / mshadow::mshadow_sizeof(type.dtype)};
mxnet::TShape dshape(ds, ds + 1);
TBlob recv_blob;
MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), dshape, cpu::kDevMask);
})
NDArray recved = NDArray(recv_blob, 0);
if (stored.is_none()) { // the first push request
// initialization
stored = NDArray(dshape, Context(), false, type.dtype);
CopyFromTo(recved, &stored, 0);
server->Response(req_meta);
stored.WaitToRead();
} else {
auto& updates = update_buf_[key];
if (sync_mode_ && updates.merged.is_none() {
updates.merged = NDArray(dshape, Context(), false, type.dtype);
}
if (updates.request.empty()) { // the first
if (sync_mode_) {
CopyFromTo(recvd, updates.merged);
} else { // async training
updates.temp_array = recved;
}
} else {
updates.merged += recved;
}
updates.request.push_back(req_meta);
ApplyUpdates(type, key, req_data, &updates, server);
} else { // pull operation
DefaultStorageResponse(type, key, req_meta, req_data, server);
}
}

函数ApplyUpdates实现了模型权重更新的核心逻辑。如果是异步训练模式,或者当前的update_buf中的push请求数量等于worker的数量(意味着server收到了所有worker上的梯度),那么就会执行参数的更新过程;否则就不进行更新,直接调用server->Response给worker发一个不带任何数据的响应消息,表示收到了相关的数据。如果server端设置了更新器updater_,那么就会在server端执行更新操作;否则,server只对梯度进行聚合。如下代码的7~16行描述了这一过程,更新或聚合的结果会被存放到store_[key]中。由于update_buf_[key].request中保存的请求既有可能是push,也有可能是pushpull(唯独不可能是pull,因为我们只在req_meta.push==true时才把req_meta加入到update_buf_[key].request中),因此我们还要额外处理pushpull这类请求。对于update_buf_[key].request中的每个请求,如果该请求req.pull==true,那么就调用DefaultStorageResponse把模型权重传输给worker。在更新过程完成后,update_buf_[key].request就会被清空,以等待下一次更新。

inline void ApplyUpdates(const DataHandleType type, const int key,
const ps::KVPairs<char>& req_data, UpdateBuf *update_buf,
ps::KVServer<char>* server) {
if (!sync_mode_ || update_buf->request.size() == (size_t) ps::NumWorkers()) {
// let the main thread to execute updater_, which is necessary for python
auto& stored = store_[key];
auto& update = sync_mode_ ? update_buf->merged : update_buf->temp_array;
if (updater_) { // update_on_kvstore == True
exec_.Exec([this, key, &update, &stored](){
CHECK(updater_);
updater_(key, update, &stored);
});
} else { // update_on_kvstore == False, only support for sync mode
CHECK(sync_mode_) << "Updater needs to be set for async mode";
// if no updater, just copy
CopyFromTo(update_buf->merged, &stored);
}
/**
* Request can be for either push or pushpull
* If pull flag is set, respond immediately with the updated values
* Otherwise, only send the notification
*/
bool has_pull = false;
for (const auto& req : update_buf->request) {
has_pull = has_pull || req.pull;
}
if (has_pull) {
// if there is a pull request, perform WaitToRead() once before DefaultStorageResponse
stored.WaitToRead();
for (const auto& req : update_buf->request) {
if (req.pull) {
DefaultStorageResponse(type, key, req, req_data, server);
}
}
update_buf->request.clear();
} else {
// otherwise, send response directly
for (const auto& req : update_buf->request) {
server->Response(req);
}
update_buf->request.clear();
stored.WaitToRead();
}
} else { // donot perform update operation
update_buf->merged.WaitToRead();
}
}

DefaultStorageResponse会根据传入的req_metareq_data这两个参数针对worker的push请求构建出对应的带数据的响应消息。响应是一个ps::KVPairs<char>类型的对象,其中的数据部分拷贝自store_[key]。响应对象构建完成后,同样会调用server->Response将消息发回对应的worker。

void DefaultStorageResponse(const DataHandleType type,
const int key,
const ps::KVMeta& req_meta,
const ps::KVPairs<char> &req_data,
ps::KVServer<char>* server) {
ps::KVPairs<char> response;
const NDArray& stored = store_[key];
CHECK(!stored.is_none()) << "init " << key << " first"; auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype());
response.keys = req_data.keys;
response.lens = {len};
// TODO(mli) try to remove this CopyFrom
response.vals.CopyFrom(static_cast<const char*>(stored.data().dptr_), len);
server->Response(req_meta, response);
}

最新文章

  1. Android 急速发布项目到 JitPack
  2. 【学习笔记】JAva编程思想之多态
  3. 【腾讯GAD暑期训练营游戏程序班】游戏中的特效系统作业说明文档
  4. sas编程-日期相差计算函数 intnx
  5. SpringMVC学习笔记(三)
  6. 关于ListView中notifyDataSetChanged()刷新数据不更新原因
  7. 关于HTML5代码总结。
  8. spring3.2.8+quartz2.2.0(比较全,对比quartz1.x的配置)
  9. 一起学习 微服务(MicroServices)-笔记
  10. NK 1137: 石子合并问题
  11. RHEL7虚拟机中不重启的情况下加新硬盘及扩展根分区容量
  12. 接上一篇博客(解决-Dmaven.multiModuleProjectDirectory system property is not set. Check $M2_HOME environment variable and mvn script match. )
  13. MySQL 修改最大连接数
  14. RE:考勤系统的复盘
  15. Python全栈之路----常用模块----logging模块
  16. spring boot整合 springmvc+mybatis
  17. db mysql / mysql cluster 5.7.19 / my.cnf / thread_pool_stall_limit
  18. TCP/IP 笔记 - 地址解析协议
  19. c# Point不能输入小数
  20. [实战演练]Intel面试题目 - 进栈出栈顺序问题

热门文章

  1. SGU140. Integer Sequences
  2. 使用.NET 6开发TodoList应用(25)——实现RefreshToken
  3. 缓存一致性性协议MESI笔记
  4. php伪协议总结
  5. 常用Cron表达式范例
  6. day1 三位数各个位上的数字和
  7. java基础04-数据类型扩展及面试题
  8. 《剑指offer》面试题39. 数组中出现次数超过一半的数字
  9. 《剑指offer》面试题58 - I. 翻转单词顺序
  10. 《剑指offer》面试题64. 求1+2+…+n