本文主要是介绍NCCL源码解析: proxy 线程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
- 前言
- 概括
- 详解
- 1. 用到的变量
- 2. proxy 线程创建
- 2.1 ncclProxyService()
- 2.2 proxyServiceInitOp()
- 2.2 proxyProgressAsync()
- 4. ncclProxyConnect()
- 4.1 ncclProxyCallBlocking()
- 4.2 ncclPollProxyResponse()
前言
NCCL 源码解析总目录
我尽量在每个函数之前介绍每个函数的作用,建议先不要投入到函数内部实现,先把函数作用搞清楚,有了整体框架,再回归到细节。
笔记习惯:为了便于快速理解,函数调用关系通过缩进表示,也可能直接把被调用的函数展开,视情况而定。
如下
// Call site of proxyConnInit:
NCCLCHECK(proxyConnInit(peer, connectionPool, proxyState, (ncclProxyInitReq*) op->reqBuff, (ncclProxyInitResp*) op->respBuff, &op->connection));
// The same function "expanded" to its full signature, so each argument above can be matched to its parameter:
static ncclResult_t proxyConnInit(struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclProxyState* proxyState, ncclProxyInitReq* req, ncclProxyInitResp* resp, struct ncclProxyConnection** connection);
如有问题,请留言指正。
图后面再补;
有些遗漏之处,还没涉及,后面补;
闲话后面再补。
概括
每个GPU对应一个管理线程或者进程,在卡与卡之间建立通信的时候,会额外创建一个代理线程去完成这件事,代理线程是被动的,该做什么事还是由GPU对应的管理线程去通过TCP下发。
代理线程的主要工作有:
- 监听TCP端口
- 调用 ncclTransportComm 的 proxySharedInit、proxySetup、proxyConnect
- 关闭TCP链接
详解
1. 用到的变量
主要关注 comm->proxyState 的初始化,它后面会作为代理线程的参数使用,用到的时候再回来看也行。
// Initialization call trace (indented lines run inside the call above them; abbreviated excerpt)
commAlloc()
  NCCLCHECK(ncclCalloc(&sharedRes, 1));
bootstrapInit()
  // proxy is aborted through a message; don't set abortFlag
  // Allocate the proxy listening socket object
  NCCLCHECK(ncclCalloc(&proxySocket, 1));
  // Create the socket -> proxySocket
  NCCLCHECK(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeProxy, comm->abortFlag));
  // Put it into the listening state
  NCCLCHECK(ncclSocketListen(proxySocket));
  // Store this rank's listening address (IP + port) at state->peerProxyAddresses + rank
  NCCLCHECK(ncclSocketGetAddr(proxySocket, state->peerProxyAddresses+rank));
  // NOTE(review): in this abbreviated excerpt `state` appears before its declaration line below.
  struct bootstrapState* state;
  comm->bootstrap = state;
  // All-gather across ranks: afterwards state->peerProxyAddresses holds every rank's proxy address
  NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress)));
  // Allocate and initialize comm->proxyState
  NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses));
    NCCLCHECK(ncclCalloc(&comm->sharedRes->proxyState, 1));
    comm->proxyState = comm->sharedRes->proxyState;
    comm->proxyState->refCount = 1;
    comm->proxyState->listenSock = proxySocket;
    comm->proxyState->peerAddresses = state->peerProxyAddresses;
2. proxy 线程创建
主要通过 ncclProxyCreate() 完成 proxyState 对象的属性初始化;NCCL 初始化时会由它创建线程 ncclProxyService。
// Simplified excerpt of ncclProxyCreate: copies the communicator settings the proxy service thread
// needs into comm->proxyState, then starts ncclProxyService with comm->proxyState as its argument.
ncclProxyCreate(comm)
{
  // proxyState is comm->proxyState
  struct ncclProxyState* proxyState = comm->proxyState;
  // Copy per-communicator settings into proxyState; each field is explained where it is used.
  proxyState->tpRank = comm->rank;
  proxyState->tpnRanks = comm->nRanks;
  proxyState->tpLocalnRanks = comm->localRanks;
  proxyState->cudaDev = comm->cudaDev;
  proxyState->abortFlag = comm->abortFlag;
  proxyState->p2pnChannels = comm->p2pnChannels;
  proxyState->p2pChunkSize = comm->p2pChunkSize;
  proxyState->nChannels = comm->nChannels;
  proxyState->allocP2pNetLLBuffers = comm->allocP2pNetLLBuffers;
  proxyState->dmaBufSupport = comm->dmaBufSupport;
  proxyState->ncclNet = comm->ncclNet;
  proxyState->ncclCollNet = comm->ncclCollNet;
  memcpy(proxyState->buffSizes, comm->buffSizes, sizeof(comm->buffSizes));
  // Start the proxy service thread.
  // NOTE(review): the pthread_create return value is not checked in this excerpt, so a failed
  // thread creation would go unnoticed here.
  pthread_create(&comm->proxyState->thread, NULL, ncclProxyService, comm->proxyState);
}
2.1 ncclProxyService()
proxy 服务线程代码:一个设备起一个 proxy 线程,线程名为 NCCL Service %rank。
线程主要做三件事:
- 建立TCP连接
- 根据每个卡对应的线程(即客户端)发来的命令 type 做相应处理
- 关闭TCP连接
type 定义如下:
/*
 * Message types a local rank sends to its proxy service thread. Every request begins with one of
 * these values sent as an int, so the explicit numbering is part of the wire format.
 */
enum ncclProxyMsgType {
  ncclProxyMsgInit       = 1, /* set up a new proxy connection (proxyConnInit) */
  ncclProxyMsgSharedInit = 2, /* proxy calls the transport's proxySharedInit */
  ncclProxyMsgSetup      = 3, /* proxy calls the transport's proxySetup */
  ncclProxyMsgConnect    = 4, /* proxy calls the transport's proxyConnect */
  ncclProxyMsgStart      = 5, /* not used by the code covered in this article */
  ncclProxyMsgClose      = 6, /* close this peer's TCP connection to the proxy */
  ncclProxyMsgAbort      = 7, /* not used by the code covered in this article */
  ncclProxyMsgStop       = 8, /* stop; the service thread exits once no peer connection remains */
  ncclProxyMsgConvertFd  = 9, /* cuMem API support (UDS) */
};
线程中主要的处理函数是 proxyServiceInitOp()
// Thread argument:
args = comm->proxyState
// Proxy service thread. It accepts connections from local ranks on proxyState->listenSock,
// receives a message type from each connected peer and dispatches it, progresses outstanding async
// ops, and exits once a stop was requested (ncclProxyMsgStop or abortFlag) and npeers has dropped
// to 0. It then tears down progress state, sockets, connections and ops.
void* ncclProxyService(void* _args) {
  struct ncclProxyState* proxyState = (struct ncclProxyState*) _args;
  // Prepare poll descriptor
  struct ncclProxyConnectionPool connectionPool;
  connectionPool.pools = NULL;
  connectionPool.banks = 0;
  connectionPool.offset = NCCL_PROXY_CONN_POOL_SIZE;

  // pollfds[0..NCCL_MAX_LOCAL_RANKS-1] are peer sockets; the extra last slot is the listening socket.
  struct pollfd pollfds[NCCL_MAX_LOCAL_RANKS+1];
  struct ncclProxyLocalPeer peers[NCCL_MAX_LOCAL_RANKS];
  memset(&peers, 0, sizeof(struct ncclProxyLocalPeer)*NCCL_MAX_LOCAL_RANKS);
  for (int s=0; s<NCCL_MAX_LOCAL_RANKS; s++) {
    pollfds[s].fd = -1;
    pollfds[s].events = POLLHUP|POLLIN;
  }
  if (ncclSocketGetFd(proxyState->listenSock, &pollfds[NCCL_MAX_LOCAL_RANKS].fd) != ncclSuccess) {
    WARN("[Proxy Service] Get listenSock fd fails");
    return NULL;
  };
  // Watch the listening socket for incoming connections
  pollfds[NCCL_MAX_LOCAL_RANKS].events = POLLIN;

  int maxnpeers = 0;
  int npeers = 0;
  int stop = 0;
  int asyncOpCount = 0;
  while (stop == 0 || (stop == 1 && npeers > 0)) {
    /* Even if local comm aborts, we cannot let proxy thread exit if we still have peer
     * connections. Need to wait until all other related comms call abort and safely exit
     * together, or we could face segmentation fault. */
    if (*proxyState->abortFlag != 0) stop = 1;
    /* never let proxy service thread blocks in poll, or it cannot receive abortFlag. */
    int ret;
    do {
      // Zero timeout while async ops are pending (keep progressing them), else wait up to 500 ms.
      ret = poll(pollfds, NCCL_MAX_LOCAL_RANKS+1, asyncOpCount ? 0 : 500);
    } while (ret < 0 && errno == EINTR);
    if (ret < 0) {
      WARN("[Proxy Service] Poll failed: %s", strerror(errno));
      return NULL;
    }
    if (pollfds[NCCL_MAX_LOCAL_RANKS].revents) {
      // Incoming connection: use the first slot whose fd is unused (< 0).
      int s = 0;
      while (s < NCCL_MAX_LOCAL_RANKS && pollfds[s].fd >= 0) s++;
      if (s == NCCL_MAX_LOCAL_RANKS) {
        WARN("[Proxy service] Too many connections (%d max)", NCCL_MAX_LOCAL_RANKS);
        return NULL;
      }
      if (maxnpeers < s+1) maxnpeers = s+1;
      // Initialize the peer socket
      if (ncclSocketInit(&peers[s].sock) != ncclSuccess) {
        WARN("[Service thread] Initialize peers[%d].sock fails", s);
        return NULL;
      }
      // Accept the connection
      if (ncclSocketAccept(&peers[s].sock, proxyState->listenSock) != ncclSuccess) {
        WARN("[Service thread] Accept failed %s", strerror(errno));
      } else {
        // Add the accepted fd to pollfds
        if (ncclSocketGetFd(&peers[s].sock, &pollfds[s].fd) != ncclSuccess) {
          WARN("[Service thread] Get peers[%d].sock fd fails", s);
          return NULL;
        }
        npeers++;
        peers[s].tpLocalRank = -1;
      }
    }
    for (int s=0; s<maxnpeers; s++) {
      struct ncclProxyLocalPeer* peer = peers+s;
      struct ncclSocket* sock = &peer->sock;
      int closeConn = 0;
      int type = 0;
      ncclResult_t res = ncclSuccess;
      if (pollfds[s].fd == -1) continue;

      // Progress all ops for this ncclProxyLocalPeer
      ncclProxyAsyncOp* op = peer->asyncOps;
      while (op != nullptr) {
        ncclProxyAsyncOp* opnext = op->next; /* in case op is freed in proxyProgressAsync */
        type = op->type;
        res = proxyProgressAsync(op, proxyState, &asyncOpCount, peer, &connectionPool);
        if (res == ncclSuccess || res == ncclInProgress) {
          op = opnext;
        } else {
          // Res is a bad result
          closeConn = 1;
          WARN("[Service thread] Error encountered progressing operation=%s, res=%d, closing connection", ncclProxyMsgTypeStr[type], res);
          break;
        }
      }

      // Check for additional ops coming in
      if (pollfds[s].revents & POLLIN) {
        int closed;
        // Each request starts with its message type
        res = ncclSocketTryRecv(sock, &type, sizeof(int), &closed, false /*blocking*/);
        if (res != ncclSuccess && res != ncclInProgress) {
          WARN("[Service thread] Could not receive type from localRank %d, res=%u, closed=%d", peer->tpLocalRank, res, closed);
          closeConn = 1;
        } else if (closed) {
          INFO(NCCL_INIT|NCCL_NET|NCCL_PROXY, "[Service thread] Connection closed by localRank %d", peer->tpLocalRank);
          closeConn = 1;
        } else if (res == ncclSuccess) { // We received something from the sock
          // Dispatch on the received type
          if (type == ncclProxyMsgStop) {
            // Request the thread to stop and close this connection
            stop = 1;
            closeConn = 1;
          } else if (type == ncclProxyMsgClose) {
            // Close this connection
            closeConn = 1;
          } else if (proxyMatchOpType(type)) {
            // A request from the client (a local rank's thread): read it and start processing it
            res = proxyServiceInitOp(type, peers+s, &connectionPool, proxyState, &asyncOpCount);
          } else {
            // Unknown command: close the connection
            WARN("[Service thread] Unknown command %d from localRank %d", type, peer->tpLocalRank);
            closeConn = 1;
          }

          // NOTE(review): this line also runs after the unknown-command branch, indexing
          // ncclProxyMsgTypeStr with an unvalidated `type`; confirm the table's bounds.
          INFO(NCCL_PROXY, "Received and initiated operation=%s res=%d", ncclProxyMsgTypeStr[type], res);
        }
      } else if (pollfds[s].revents & POLLHUP) {
        // Peer hung up
        closeConn = 1;
      }
      if (res != ncclSuccess && res != ncclInProgress) {
        WARN("[Proxy Service %d] Failed to execute operation %s from rank %d, retcode %d", proxyState->tpRank, ncclProxyMsgTypeStr[type], peer->tpRank, res);
        closeConn = 1;
      }

      if (closeConn) {
        ncclSocketClose(sock);

        // Only the op that failed above (if any) is dequeued here.
        // NOTE(review): other ops still queued on peer->asyncOps are not released or subtracted
        // from asyncOpCount on this path; confirm this is intended.
        if (op != nullptr) {
          asyncProxyOpDequeue(peer, op);
          asyncOpCount--;
        }
        pollfds[s].fd = -1;
        npeers--;
      }
    }
  }

  // Shutdown
  // Wait for all operations to complete and stop progress thread before freeing any resource
  if (ncclProxyProgressDestroy(proxyState) != ncclSuccess) {
    WARN("[Proxy Service] proxyDestroy failed");
  }
  for (int s=0; s<maxnpeers; s++) {
    ncclSocketClose(&peers[s].sock);
  }
  ncclProxyFreeConnections(&connectionPool, proxyState);
  ncclSocketClose(proxyState->listenSock);
  free(proxyState->listenSock);
  proxyOpsFree(proxyState);
  return NULL;
}
2.2 proxyServiceInitOp()
线程中的主要处理函数。由于客户端是按固定顺序依次发送各个字段的,服务端也按相同的顺序接收,然后调用 proxyProgressAsync 进行处理;
// proxyState: this (local) rank's proxy state
// peers: server-side records, each describing one connected client
// peer (peers+s): the client being served
res = proxyServiceInitOp(type, peers+s, &connectionPool, proxyState, &asyncOpCount);
// Reads one request from a local rank (client) whose message type has already been received,
// then starts processing it. Fields are received in the order the client sends them:
//   connection pointer, reqSize, respSize, reqSize bytes of request data, opId.
// The op is then appended to peer->asyncOps, *asyncOpCount is incremented, and a first
// proxyProgressAsync() attempt is made.
// If receiving fails, or the client sends a negative size, everything allocated for the op is
// freed and an error is returned. The caller then closes the connection.
static ncclResult_t proxyServiceInitOp(int type, struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclProxyState* proxyState, int* asyncOpCount) {
  ncclResult_t ret = ncclSuccess;
  // Server-side socket connected to this client
  struct ncclSocket* sock = &peer->sock;

  struct ncclProxyAsyncOp* asyncOp = NULL;
  NCCLCHECK(ncclCalloc(&asyncOp, 1));
  asyncOp->type = type;

  // Address of the proxy-side connection object (as returned to the client by ncclProxyMsgInit)
  NCCLCHECKGOTO(ncclSocketRecv(sock, &asyncOp->connection, sizeof(void*)), ret, fail);
  // Size of the request payload that follows
  NCCLCHECKGOTO(ncclSocketRecv(sock, &asyncOp->reqSize, sizeof(int)), ret, fail);
  // Size of the response the client expects back
  NCCLCHECKGOTO(ncclSocketRecv(sock, &asyncOp->respSize, sizeof(int)), ret, fail);
  // Both sizes come straight off the socket; reject negative values before allocating with them.
  if (asyncOp->reqSize < 0 || asyncOp->respSize < 0) {
    WARN("[Service thread] Invalid request from localRank %d: reqSize=%d respSize=%d", peer->tpLocalRank, asyncOp->reqSize, asyncOp->respSize);
    ret = ncclInternalError;
    goto fail;
  }
  if (asyncOp->reqSize) {
    // The client sends reqSize bytes of request data: allocate, then receive it.
    NCCLCHECKGOTO(ncclCalloc(&asyncOp->reqBuff, asyncOp->reqSize), ret, fail);
    NCCLCHECKGOTO(ncclSocketRecv(sock, asyncOp->reqBuff, asyncOp->reqSize), ret, fail);
  }

  // Store opId for completion response
  NCCLCHECKGOTO(ncclSocketRecv(sock, &asyncOp->opId, sizeof(asyncOp->opId)), ret, fail);

  // Buffer the response is written into and later sent back from
  if (asyncOp->respSize) {
    NCCLCHECKGOTO(ncclCalloc(&asyncOp->respBuff, asyncOp->respSize), ret, fail);
  }

  // From here on the op is owned by peer->asyncOps.
  asyncProxyOpEnqueue(peer, asyncOp);
  (*asyncOpCount)++;
  NCCLCHECK(proxyProgressAsync(asyncOp, proxyState, asyncOpCount, peer, connectionPool));
  return ncclSuccess;

fail:
  // The op was never queued, so release everything allocated for it (free(NULL) is a no-op).
  free(asyncOp->reqBuff);
  free(asyncOp->respBuff);
  free(asyncOp);
  return ret;
}
2.2 proxyProgressAsync()
处理请求的函数:根据参数 type 进行不同的逻辑处理,然后按照固定的顺序返回数据
// Call site (from proxyServiceInitOp):
NCCLCHECK(proxyProgressAsync(asyncOp, proxyState, asyncOpCount, peer, connectionPool));
// Performs, or continues, the request described by op according to op->type. Once it is done,
// the reply is sent to the client in this order: opId, respSize, then respSize bytes of respBuff.
// The op is then removed from peer->asyncOps.
// Returns ncclInProgress while a setup/connect step reports !done, and ncclInternalError for an
// unknown type or when abortFlag is set while the op is still in progress.
// NOTE: the ncclProxyMsgInit branch shows proxyConnInit expanded inline for reading. C/C++ does
// not allow nested function definitions, so this is an illustration, not literal source.
static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclProxyState* proxyState, int* asyncOpCount, struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool) {
  int done = 1;
  if (op->type == ncclProxyMsgSetup) {
    // Call the transport's proxySetup
    TRACE(NCCL_PROXY, "proxyProgressAsync::proxySetup() opId=%p", op->opId);
    NCCLCHECK(op->connection->tcomm->proxySetup(op->connection, proxyState, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done));
  } else if (op->type == ncclProxyMsgConnect) {
    // Call the transport's proxyConnect
    TRACE(NCCL_PROXY, "proxyProgressAsync::proxyConnect() opId=%p op.reqBuff=%p", op->opId, op->reqBuff);
    NCCLCHECK(op->connection->tcomm->proxyConnect(op->connection, proxyState, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done));
  } else if (op->type == ncclProxyMsgSharedInit) {
    // NOTE(review): this dereferences reqBuff once before converting to int; if reqBuff is a char
    // buffer, only the first byte of the request is used as nChannels. Confirm the element type.
    int nChannels = (int) *op->reqBuff;
    // Call the transport's proxySharedInit, if it provides one
    TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgSharedInit opId=%p op.reqBuff=%p nChannels=%d", op->opId, op->reqBuff, nChannels);
    if (op->connection->tcomm->proxySharedInit) NCCLCHECK(op->connection->tcomm->proxySharedInit(op->connection, proxyState, nChannels));
    __atomic_store_n(&op->connection->state, connSharedInitialized, __ATOMIC_RELEASE);
  } else if (op->type == ncclProxyMsgConvertFd) {
    int fd = *(int *)op->reqBuff;
    TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgConvertFd opId=%p op.reqBuff=%p fd=%d", op->opId, op->reqBuff, fd);
    NCCLCHECK(proxyConvertFd(peer, op->opId, proxyState, fd)); // cuMem API support
  } else if (op->type == ncclProxyMsgInit) {
    // TRACE(NCCL_PROXY, "proxyProgressAsync::ncclProxyMsgInit opId=%p op.reqBuff=%p", op->opId, op->reqBuff);
    NCCLCHECK(proxyConnInit(peer, connectionPool, proxyState, (ncclProxyInitReq*) op->reqBuff, (ncclProxyInitResp*) op->respBuff, &op->connection));
    // ---- proxyConnInit, expanded inline for reading ----
    static ncclResult_t proxyConnInit(struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclProxyState* proxyState, ncclProxyInitReq* req, ncclProxyInitResp* resp, struct ncclProxyConnection** connection) {
      int id;
      // Reserve a slot in connectionPool, allocating connectionPool->pools as needed, and advance
      // connectionPool->offset. The id is
      //   id = ((pool->banks-1) << NCCL_PROXY_CONN_POOL_SIZE_POW2) + pool->offset;
      // and one bank holds (1 << 7) connections.
      NCCLCHECK(ncclProxyNewConnection(connectionPool, &id));
      // Split id back into bank and offset and get the address of that ncclProxyConnection
      NCCLCHECK(ncclProxyGetConnection(connectionPool, id, connection));

      // Fill in the connection from the client's request
      (*connection)->sock = &peer->sock;
      (*connection)->transport = req->transport;
      (*connection)->send = req->send;
      (*connection)->tpLocalRank = req->tpLocalRank;
      (*connection)->sameProcess = req->sameProcess;
      peer->tpLocalRank = req->tpLocalRank;
      peer->tpRank = req->tpRank;
      // Return the proxy-side connection address to the client
      resp->connection = *connection;

      (*connection)->tcomm = (*connection)->send ? &ncclTransports[(*connection)->transport]->send : &ncclTransports[(*connection)->transport]->recv;
      // If we need proxy progress, let's allocate ops and start the thread
      if ((*connection)->tcomm->proxyProgress) {
        NCCLCHECK(proxyProgressInit(proxyState));
        struct ncclProxyProgressState* state = &proxyState->progressState;
        strncpy(resp->devShmPath, state->opsPoolShmSuffix, sizeof(resp->devShmPath));
      }
      INFO(NCCL_NET|NCCL_PROXY, "New proxy %s connection %d from local rank %d, transport %d", (*connection)->send ? "send":"recv", id, (*connection)->tpLocalRank, (*connection)->transport);
      __atomic_store_n(&(*connection)->state, connInitialized, __ATOMIC_RELEASE);
      return ncclSuccess;
    }
    // ---- end of expansion ----
  } else return ncclInternalError;
  if (done) {
    INFO(NCCL_PROXY, "proxyProgressAsync opId=%p op.type=%d op.reqBuff=%p op.respSize=%d done", op->opId, op->type, op->reqBuff, op->respSize);
    if (op->type == ncclProxyMsgSetup)
      __atomic_store_n(&op->connection->state, connSetupDone, __ATOMIC_RELEASE);
    else if (op->type == ncclProxyMsgConnect)
      __atomic_store_n(&op->connection->state, connConnected, __ATOMIC_RELEASE);
    /* if setup or connect is done, we should not return any error at this point since
     * ncclSocketSend might already send the respBuff to the requester. If we still choose
     * to abort and close the connection, it can cause segfault if the requester is using
     * the respBuff. */

    // Send the opId for referencing async operation
    NCCLCHECK(ncclSocketSend(op->connection->sock, &op->opId, sizeof(op->opId)));

    // Send the response size
    NCCLCHECK(ncclSocketSend(op->connection->sock, &op->respSize, sizeof(op->respSize)));
    if (op->respSize) {
      // Send the response
      NCCLCHECK(ncclSocketSend(op->connection->sock, op->respBuff, op->respSize));
    }

    // Request finished: remove op from the peer's list
    asyncProxyOpDequeue(peer, op);
    (*asyncOpCount)--;
    return ncclSuccess;

  } else if (*proxyState->abortFlag != 0) {
    return ncclInternalError;
  }

  return ncclInProgress;
}
4. ncclProxyConnect()
以建立连接为例:如果要使用代理,首先要建立连接——客户端通过 type 为 ncclProxyMsgInit 的消息告诉代理要建立连接,代理线程会 accept 并建立 socket,然后返回所建立的 ncclProxyConnection 对象(connection)的首地址。
连接流程如下,主要关注传输的内容:有的传数据本身,有的传首地址:
// p2p send connector
// The rank's GPU-side thread connects to the proxy's TCP service; the service records the
// connection and allocates the resources needed for communication.
struct ncclConnector* send
NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, tpProxyRank, &send->proxyConn));
// Connects this rank to the proxy service thread of the local rank whose top-parent rank is
// tpProxyRank, and asks that proxy (ncclProxyMsgInit) to create a proxy-side connection for the
// given transport and direction (send != 0 means send).
// On success proxyConn->connection holds the proxy-side connection address returned by the
// service thread. If the transport needs proxy progress, the proxy's shared-memory ops pool is
// mapped as well.
// Returns ncclInternalError if tpProxyRank does not match any local rank of comm.
ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, int tpProxyRank, struct ncclProxyConnector* proxyConn) {
  struct ncclSocket* sock;
  int ready, proxyRank = -1;
  struct ncclProxyState* sharedProxyState = comm->proxyState;

  // Keep one connection per local rank
  for (int i = 0; i < comm->localRanks; ++i) {
    /* find the proxy rank in comm. */
    if (comm->topParentRanks[comm->localRankToRank[i]] == tpProxyRank) {
      proxyRank = comm->localRankToRank[i];
      break;
    }
  }
  // Without this check an unmatched tpProxyRank would read comm->peerInfo[-1] below.
  if (proxyRank == -1) {
    WARN("tpProxyRank %d does not match any local rank of comm %p", tpProxyRank, comm);
    return ncclInternalError;
  }
  proxyConn->sameProcess = comm->peerInfo[proxyRank].pidHash == comm->peerInfo[comm->rank].pidHash ? 1 : 0;
  // Keep one connection per local rank
  proxyConn->connection = NULL;
  proxyConn->tpRank = tpProxyRank;
  // First use: allocate the per-local-rank socket / ops / device-memory arrays in the shared proxy state
  if (sharedProxyState->peerSocks == NULL) {
    NCCLCHECK(ncclCalloc(&sharedProxyState->peerSocks, comm->sharedRes->tpNLocalRanks));
    NCCLCHECK(ncclCalloc(&sharedProxyState->proxyOps, comm->sharedRes->tpNLocalRanks));
    NCCLCHECK(ncclCalloc(&sharedProxyState->sharedDevMems, comm->sharedRes->tpNLocalRanks));
    for (int i = 0; i < comm->sharedRes->tpNLocalRanks; ++i) {
      NCCLCHECK(ncclSocketSetFd(-1, &sharedProxyState->peerSocks[i]));
    }
  }

  proxyConn->tpLocalRank = comm->sharedRes->tpRankToLocalRank[proxyConn->tpRank];
  sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank;
  NCCLCHECK(ncclSocketReady(sock, &ready));
  if (!ready) {
    // Initialize the socket and connect to the port the proxy service thread listens on
    NCCLCHECK(ncclSocketInit(sock, sharedProxyState->peerAddresses+proxyConn->tpRank, comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag));
    NCCLCHECK(ncclSocketConnect(sock));
  }

  struct ncclProxyInitReq req = {0};
  req.transport = transport;
  req.send = send;
  req.tpLocalRank = comm->topParentLocalRanks[comm->localRank];
  req.tpRank = comm->topParentRanks[comm->rank];
  req.sameProcess = proxyConn->sameProcess;

  struct ncclProxyInitResp resp = {0};
  // This usually sends proxyConn->connection to identify which connection this is.
  // However, this is part of the response and therefore is ignored
  // Ask the proxy to create the proxy-side connection and wait for its reply
  NCCLCHECK(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgInit, &req, sizeof(req), &resp, sizeof(resp)));
  // resp.connection is the address of the connection object on the proxy side
  proxyConn->connection = resp.connection;

  // If we need proxy progress, map progress ops
  struct ncclTransportComm* tcomm = send ? &ncclTransports[transport]->send : &ncclTransports[transport]->recv;
  if (tcomm->proxyProgress) {
    // Overwrite the "XXXXXX" placeholder with at most 6 characters of the suffix returned by the
    // proxy; the array's own terminating NUL is left in place.
    char poolPath[] = "/dev/shm/nccl-XXXXXX";
    strncpy(poolPath+sizeof("/dev/shm/nccl-")-1, resp.devShmPath, sizeof("XXXXXX")-1);
    struct ncclProxyOps* proxyOps = sharedProxyState->proxyOps + proxyConn->tpLocalRank;
    if (proxyOps->pool == NULL) {
      NCCLCHECK(ncclShmOpen(poolPath, sizeof(struct ncclProxyOpsPool), (void**)(&proxyOps->pool), NULL, 0, &proxyOps->handle));
      proxyOps->nextOps = proxyOps->nextOpsEnd = proxyOps->freeOp = -1;
    }
  }
  INFO(NCCL_NET|NCCL_PROXY, "Connection to proxy localRank %d -> connection %p", proxyConn->tpLocalRank, proxyConn->connection);
  return ncclSuccess;
}
4.1 ncclProxyCallBlocking()
调用代理线程接口:发送命令,并阻塞等待返回结果。
// The client asks the proxy service to handle a request; the service acts according to type.
// ncclProxyMsgInit asks the service to initialize a proxy connection.
NCCLCHECK(ncclProxyCallBlocking(comm, proxyConn, ncclProxyMsgInit, &req, sizeof(req), &resp, sizeof(resp)));
// Sends one request to the proxy service thread and polls until the matching response has been
// received into respBuff. opId is a 1-byte heap allocation whose address serves as the request's
// id; it is freed on the exit path.
// NOTE: the two brace blocks below are the author's inline expansions of callee code
// (expectedProxyResponseEnqueue and the dequeue step of ncclPollProxyResponse), added for reading.
// `sock` and `state` are not declared in this excerpt.
ncclResult_t ncclProxyCallBlocking(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize) {
  // Alloc some memory to act as a handle
  ncclResult_t res = ncclSuccess;
  // NOTE(review): the malloc(1) result is not checked for NULL.
  void* opId = malloc(1);

  // ncclProxyCallAsync() sends, in order:
  //   type
  //   the address held in proxyConn->connection
  //   reqSize
  //   respSize
  //   reqSize bytes of reqBuff, if reqSize > 0
  //   opId (the pointer value)
  NCCLCHECKGOTO(ncclProxyCallAsync(comm, proxyConn, type, reqBuff, reqSize, respSize, opId), res, fail);
  struct ncclProxyState* sharedProxyState = comm->proxyState;
  sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank;
  // Append this request to the expected-response list state->expectedResponses.
  // NOTE(review): if NCCLCHECK returns on error (as its use elsewhere suggests), this path and the
  // one in the loop below skip free(opId).
  NCCLCHECK(expectedProxyResponseEnqueue(sharedProxyState, opId, respSize));
  // (expansion of expectedProxyResponseEnqueue)
  {
    struct ncclExpectedProxyResponse* ex;
    NCCLCHECK(ncclCalloc(&ex, 1));
    ex->opId = opId;

    // Pre-alloc response buffer
    ex->respBuff = malloc(respSize);
    ex->respSize = respSize;
    ex->done = false;

    struct ncclExpectedProxyResponse* list = state->expectedResponses;
    if (list == NULL) {
      state->expectedResponses = ex;
      return ncclSuccess;
    }
    while (list->next) list = list->next;
    list->next = ex;
  }

  // Poll until the response for opId has been received
  do {
    res = ncclPollProxyResponse(comm, proxyConn, respBuff, opId);
    // (expansion of the dequeue step inside ncclPollProxyResponse)
    {
      int found = 0;
      // If opId is in the list and already marked done, copy its data into respBuff and set found = 1
      NCCLCHECK(expectedProxyResponseDequeue(sharedProxyState, opId, respBuff, &found));
    }
  } while (res == ncclInProgress);

exit:
  free(opId);
  return res;
fail:
  goto exit;
}
4.2 ncclPollProxyResponse()
发送的时候会带上 opId 作为此次通信的标识,代理线程返回数据时也会把这个 opId 带回来。
所以接收的时候要比较 opId:如果与本次发送的 opId 一样,那么就接收成功;
如果不一样,那么把接收到的数据放入缓冲区,继续接收。
// Poll for the response data that carries opId
res = ncclPollProxyResponse(comm, proxyConn, respBuff, opId);
// Makes one attempt to obtain the proxy's response for opId.
//  - If an earlier call already received and cached that response (its expectedResponses entry
//    is marked done), the data is copied into respBuff and ncclSuccess is returned.
//  - Otherwise it tries to read a response header (an opId) from the proxy socket. If nothing has
//    arrived yet it returns ncclInProgress. Once part of a header has arrived, the rest of the
//    response (opId, size, payload) is received with blocking reads.
//  - A response for opId is written into respBuff, its list entry is removed, and ncclSuccess is
//    returned. A response for another opId is read into a temporary buffer and stored for its
//    owner, and ncclInProgress is returned so the caller polls again.
// Returns ncclInternalError if the communicator is aborting, the proxy sockets were never set up,
// the socket receive fails, or a buffer for an unexpected response cannot be allocated.
ncclResult_t ncclPollProxyResponse(struct ncclComm* comm, struct ncclProxyConnector* proxyConn, void* respBuff, void* opId) {
  struct ncclProxyState* sharedProxyState = comm->proxyState;
  // Receive the connection pointer from the Proxy
  if (*comm->abortFlag) {
    WARN("Comm %p is in abort state", comm);
    return ncclInternalError;
  }
  if (sharedProxyState->peerSocks == NULL) return ncclInternalError;

  // Check response queue
  int found = 0;
  NCCLCHECK(expectedProxyResponseDequeue(sharedProxyState, opId, respBuff, &found));
  if (found == 0) {
    // Not cached yet: attempt to read in a new response header from the proxy thread
    struct ncclSocket* sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank;

    void* recvOpId;
    int offset = 0;
    if (ncclSuccess != ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), &offset)) {
      WARN("Socket recv failed while polling for opId=%p", opId);
      return ncclInternalError;
    }

    if (offset == 0) {
      // Nothing received yet
      return ncclInProgress;
    // If we've returned a partial response, block to receive the rest of it
    } else if (offset < sizeof(recvOpId)) {
      while (offset < sizeof(recvOpId))
        NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), &offset));
    }

    INFO(NCCL_PROXY, "ncclPollProxyResponse Received new opId=%p", recvOpId);

    // Now do a blocking recv of the response size
    int respSize = 0;
    NCCLCHECK(ncclSocketRecv(sock, &respSize, sizeof(respSize)));

    // If there's a respSize to recv
    if (respSize > 0) {
      if (recvOpId != opId) {
        // Unexpected response, need to buffer the socket data
        respBuff = malloc(respSize);
        // Checked explicitly: the assert below is compiled out under NDEBUG, which would let a
        // failed allocation reach ncclSocketRecv as a NULL destination.
        if (respBuff == NULL) {
          WARN("Failed to allocate %d bytes to buffer response for opId=%p", respSize, recvOpId);
          return ncclInternalError;
        }
      }
      assert(respBuff != NULL);
      NCCLCHECK(ncclSocketRecv(sock, respBuff, respSize));
    }

    if (recvOpId == opId) {
      // Our response: drop its entry from the expected-response list
      INFO(NCCL_PROXY, "recvOpId=%p matches expected opId=%p", recvOpId, opId);
      NCCLCHECK(expectedProxyResponseRemove(sharedProxyState, recvOpId));
      return ncclSuccess;
    } else {
      INFO(NCCL_PROXY, "Queuing opId=%p respBuff=%p respSize=%d", recvOpId, respBuff, respSize);
      // Store the result and mark response as completed
      NCCLCHECK(expectedProxyResponseStore(sharedProxyState, recvOpId, respBuff, respSize));
      // Still waiting for our own opId
      return ncclInProgress;
    }
  } else {
    INFO(NCCL_PROXY, "ncclPollProxyResponse Dequeued cached opId=%p", opId);
  }

  return ncclSuccess;
}
这篇关于NCCL源码解析: proxy 线程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!