[Source Code Analysis] TensorFlow Distributed Environment (3): Worker Static Logic

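Before examining TensorFlow's various distributed Strategies in detail, we first need to look at the foundation they rest on: the distributed environment. A solid grasp of this foundation removes most obstacles from the later analysis. This article covers the static architecture of the Worker and its related concepts.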

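Other articles in this series:

[Translation] TensorFlow Distributed, Paper Series: “TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems”

[Translation] TensorFlow Distributed, Paper Series: “Implementation of Control Flow in TensorFlow”

[Source Code Analysis] TensorFlow Distributed Environment (1): Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2): Master Static Logic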

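1. Inheritance Hierarchy

1.1 Role

The TensorFlow Worker class is the entity that performs computation. Its main responsibilities are:

  • Receive requests from the Master.
  • Manage WorkerSession objects.
  • Process registered subgraphs, for example by splitting a subgraph a second time according to the devices available on its own node.
  • Run the registered subgraphs on each device.
  • Support worker-to-worker tensor transfer. How a transfer is carried out depends on where the two endpoints sit relative to each other: cudaMemcpyAsync between CPU and GPU, DMA between local GPUs, and gRPC or RDMA between remote workers.
  • After execution completes, fetch the results from the sink (terminal) node of the computation graph.

See protobuf/worker_service.proto for more details on each method.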

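1.2 Interface

WorkerService is accessed through WorkerInterface. WorkerInterface is the worker's interface class, the interface for interacting with the TensorFlow Worker service. It mainly:

  • Defines a number of asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These virtual functions correspond one-to-one with the GrpcWorkerMethod values supported by GrpcWorkerService, and with the Protobuf definitions.
  • Defines a number of synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through a call like CallAndWait(&ME::CreateWorkerSessionAsync, request, response).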

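1.3 WorkerInterface Derived Classes

As shown in the figure below, WorkerInterface has three implementations.

  • Worker: this class can be subclassed to provide specialized implementations of particular methods for different transport mechanisms. For example, GrpcWorker specializes the RecvTensorAsync() method to support a more efficient gRPC data structure for handling large binary data.
  • GrpcWorker: derived in turn from Worker, it plays the Worker role in local mode. If the Master and Worker are both local, calls can be made directly, without RPC network transfer.
  • GrpcRemoteWorker: in distributed mode the Worker is remote, and the local side uses GrpcRemoteWorker to access it.
    • GrpcRemoteWorker is a gRPC client that accesses the GrpcWorkerService on the remote Worker through a stub.
    • GrpcWorkerService implements all the interfaces defined by WorkerService, but forwards the actual work to the local GrpcWorker.

The relationships are illustrated below:

Figure 1: Worker logical relationships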

2. GrpcRemoteWorker

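GrpcRemoteWorker acts as a local proxy for a remote Worker.

  • The local Master partitions the computation graph and then, depending on whether a partition is local or remote, calls either the local Worker or a GrpcRemoteWorker to execute the partitioned subgraph.
  • The local GrpcRemoteWorker is created in GetOrCreateWorker, in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc.
  • GrpcRemoteWorker sends gRPC requests to the remote side through IssueRequest.
  • The remote GrpcWorkerService daemon receives the request, calls its local Worker to handle it, and returns the result when done.

2.1 Definition

The GrpcRemoteWorker code is shown below. Some code is omitted, such as the implementation of the DeleteWorkerSessionAsync method.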

class GrpcRemoteWorker : public WorkerInterface {
 public:
  explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                            ::grpc::CompletionQueue* completion_queue,
                            thread::ThreadPool* callback_threadpool,
                            WorkerCacheLogger* logger, const string& target)
      : channel_(std::move(channel)),
        stub_(channel_),
        cq_(completion_queue),
        callback_threadpool_(callback_threadpool),
        getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
        createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
        deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
        registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
        deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
        rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
        cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
        cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
        recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
        recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
        logging_(Method(GrpcWorkerMethod::kLogging)),
        tracing_(Method(GrpcWorkerMethod::kTracing)),
        completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
        instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
        getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
        markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
        logger_(logger),
        target_(target) {}

  ~GrpcRemoteWorker() override {}

  void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                CreateWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, createworkersession_, std::move(done));
  }

  void RegisterGraphAsync(const RegisterGraphRequest* request,
                          RegisterGraphResponse* response,
                          StatusCallback done) override {
    IssueRequest(request, response, registergraph_, std::move(done));
  }

  void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
                     RunGraphResponse* response, StatusCallback done) override {
    IssueRequest(request, response, rungraph_, std::move(done), call_opts);
  }
  void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
                     MutableRunGraphResponseWrapper* response,
                     StatusCallback done) override {
    IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
                 rungraph_, std::move(done), call_opts);
  }

 private:
  // Utility method for issuing a generic asynchronous request. The
  // given callback, done, will be called when the RPC completes.
  void IssueRequest(const protobuf::Message* request,
                    protobuf::Message* response, const ::grpc::string& method,
                    StatusCallback done, CallOptions* call_opts = nullptr,
                    bool fail_fast = true) {
    new RPCState<protobuf::Message>(
        &stub_, cq_, method, *request, response, std::move(done), call_opts,
        callback_threadpool_, MaxRetries(), fail_fast, &target_);
  }

  void IssueRequest(const protobuf::Message* request, TensorResponse* response,
                    const ::grpc::string& method, StatusCallback done,
                    CallOptions* call_opts = nullptr) {
    new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
                                 std::move(done), call_opts,
                                 callback_threadpool_, MaxRetries(),
                                 /*fail_fast=*/true, &target_);
  }

  // Helper function for initializing the RpcMethod objects below.
  const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }

  // Helper function for configuring max GRPC retries. Defaults to 0 (no
  // retries).
  const int64_t MaxRetries() {
    int64_t max_retries = -1;
    TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
    return max_retries;
  }

  SharedGrpcChannelPtr channel_;
  ::grpc::GenericStub stub_;
  ::grpc::CompletionQueue* cq_;
  thread::ThreadPool* callback_threadpool_;

  const ::grpc::string getstatus_;
  const ::grpc::string createworkersession_;
  const ::grpc::string deleteworkersession_;
  const ::grpc::string registergraph_;
  const ::grpc::string deregistergraph_;
  const ::grpc::string rungraph_;
  const ::grpc::string cleanupgraph_;
  const ::grpc::string cleanupall_;
  const ::grpc::string recvtensor_;
  const ::grpc::string recvbuf_;
  const ::grpc::string logging_;
  const ::grpc::string tracing_;
  const ::grpc::string completegroup_;
  const ::grpc::string instancesource_;
  const ::grpc::string getstepsequence_;
  const ::grpc::string markrecvfinished_;

  // Support for logging.
  WorkerCacheLogger* logger_;
  const string target_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
};

2.2 Creation

The creation code is as follows:

WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                     ::grpc::CompletionQueue* completion_queue,
                                     thread::ThreadPool* callback_threadpool,
                                     WorkerCacheLogger* logger,
                                     const string& target) {
  return new GrpcRemoteWorker(std::move(channel), completion_queue,
                              callback_threadpool, logger, target);
}

The actual call is made in the cache, in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc, which decides which kind of Worker to return based on its argument.

WorkerInterface* GetOrCreateWorker(const string& target) override {
  if (target == local_target_) {
    return local_worker_;
  } else {
    SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
    if (!channel) {
      return nullptr;
    }
    size_t index = AssignWorkerToThread(target);
    return NewGrpcRemoteWorker(
        channel, worker_env_->GetCompletionQueue(index),
        worker_env_->GetThreadPool(), &logger_, target);
  }
}
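
The local-versus-remote decision above can be illustrated with a stripped-down sketch. Everything here is a hypothetical stand-in and not TensorFlow code: SketchWorker replaces WorkerInterface, SketchRemoteProxy replaces GrpcRemoteWorker, and a set of known targets replaces the channel cache lookup. The release rule (delete proxies, never the local worker) is an assumption made for this sketch.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <utility>

// Hypothetical stand-in for WorkerInterface.
class SketchWorker {
 public:
  virtual ~SketchWorker() = default;
  virtual std::string Describe() const = 0;
};

// Stand-in for the in-process worker.
class SketchLocalWorker : public SketchWorker {
 public:
  std::string Describe() const override { return "local"; }
};

// Stand-in for GrpcRemoteWorker: only remembers which target it proxies.
class SketchRemoteProxy : public SketchWorker {
 public:
  explicit SketchRemoteProxy(std::string target) : target_(std::move(target)) {}
  std::string Describe() const override { return "remote:" + target_; }

 private:
  std::string target_;
};

// Stand-in for the worker cache. remote_targets plays the role of the
// channel cache: a target without an entry yields nullptr.
class SketchWorkerCache {
 public:
  SketchWorkerCache(std::string local_target, SketchWorker* local_worker,
                    std::set<std::string> remote_targets)
      : local_target_(std::move(local_target)),
        local_worker_(local_worker),
        remote_targets_(std::move(remote_targets)) {}

  // Returns the shared local worker, a freshly created proxy, or nullptr.
  SketchWorker* GetOrCreateWorker(const std::string& target) {
    if (target == local_target_) return local_worker_;
    if (remote_targets_.count(target) == 0) return nullptr;
    return new SketchRemoteProxy(target);
  }

  // Assumption of this sketch: proxies are deleted on release, while the
  // local worker is owned elsewhere and left alone.
  void ReleaseWorker(SketchWorker* worker) {
    assert(worker != nullptr);
    if (worker != local_worker_) delete worker;
  }

 private:
  const std::string local_target_;
  SketchWorker* const local_worker_;
  const std::set<std::string> remote_targets_;
};
```

A caller holds only a SketchWorker pointer, so the code that drives a partition does not need to know whether the work runs in-process or behind a proxy.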

2.3 Sending Requests

Next we look at how a request is sent. CreateWorkerSessionAsync actually sends the request that corresponds to the string createworkersession_.

  void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                CreateWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, createworkersession_, std::move(done));
  }

IssueRequest appears in the definition above and is repeated here. It calls the remote method named by method, which in our case is createworkersession_.

void IssueRequest(const protobuf::Message* request,
                  protobuf::Message* response, const ::grpc::string& method,
                  StatusCallback done, CallOptions* call_opts = nullptr,
                  bool fail_fast = true) {
  new RPCState<protobuf::Message>(
      &stub_, cq_, method, *request, response, std::move(done), call_opts,
      callback_threadpool_, MaxRetries(), fail_fast, &target_);
}

createworkersession_ is set in the constructor.

explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                          ::grpc::CompletionQueue* completion_queue,
                          thread::ThreadPool* callback_threadpool,
                          WorkerCacheLogger* logger, const string& target)
    : channel_(std::move(channel)),
      createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)), // set here

GrpcWorkerMethodName is defined in tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc. It returns the concrete strings, that is, the names of the methods on the remote GrpcWorker. As shown, CreateWorkerSessionAsync actually invokes "/tensorflow.WorkerService/CreateWorkerSession".

// Names of worker methods.
enum class GrpcWorkerMethod {
  kGetStatus,
  kCreateWorkerSession,
  kDeleteWorkerSession,
  kRegisterGraph,
  kDeregisterGraph,
  kRunGraph,
  kCleanupGraph,
  kCleanupAll,
  kRecvTensor,
  kRecvBuf,
  kLogging,
  kTracing,
  kCompleteGroup,
  kCompleteInstance,
  kGetStepSequence,
  kMarkRecvFinished,
};

const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
  switch (id) {
    case GrpcWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case GrpcWorkerMethod::kCreateWorkerSession:
      return "/tensorflow.WorkerService/CreateWorkerSession";
    case GrpcWorkerMethod::kDeleteWorkerSession:
      return "/tensorflow.WorkerService/DeleteWorkerSession";
    case GrpcWorkerMethod::kRegisterGraph:
      return "/tensorflow.WorkerService/RegisterGraph";
    case GrpcWorkerMethod::kDeregisterGraph:
      return "/tensorflow.WorkerService/DeregisterGraph";
    case GrpcWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case GrpcWorkerMethod::kCleanupGraph:
      return "/tensorflow.WorkerService/CleanupGraph";
    case GrpcWorkerMethod::kCleanupAll:
      return "/tensorflow.WorkerService/CleanupAll";
    case GrpcWorkerMethod::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
    case GrpcWorkerMethod::kRecvBuf:
      return "/tensorflow.WorkerService/RecvBuf";
    case GrpcWorkerMethod::kLogging:
      return "/tensorflow.WorkerService/Logging";
    case GrpcWorkerMethod::kTracing:
      return "/tensorflow.WorkerService/Tracing";
    case GrpcWorkerMethod::kCompleteGroup:
      return "/tensorflow.WorkerService/CompleteGroup";
    case GrpcWorkerMethod::kCompleteInstance:
      return "/tensorflow.WorkerService/CompleteInstance";
    case GrpcWorkerMethod::kGetStepSequence:
      return "/tensorflow.WorkerService/GetStepSequence";
    case GrpcWorkerMethod::kMarkRecvFinished:
      return "/tensorflow.WorkerService/MarkRecvFinished";
  }
  // Shouldn't be reached.
  LOG(FATAL) << "Invalid id: this line shouldn't be reached.";
  return "invalid id";
}
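
The strings in the switch above all follow the gRPC convention "/<package>.<Service>/<Method>". As a minimal sketch of that convention, the helper below composes such a path; FullMethodName is a hypothetical name introduced for illustration and is not part of TensorFlow or gRPC.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not a TensorFlow or gRPC API): builds a gRPC method
// path of the form "/<package>.<Service>/<Method>". An empty package yields
// "/<Service>/<Method>".
std::string FullMethodName(const std::string& package,
                           const std::string& service,
                           const std::string& method) {
  assert(!service.empty() && !method.empty());
  std::string path = "/";
  if (!package.empty()) path += package + ".";
  return path + service + "/" + method;
}
```

For example, FullMethodName("tensorflow", "WorkerService", "RunGraph") produces the same string that GrpcWorkerMethodName returns for kRunGraph.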

3. Worker Service

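WorkerService is a gRPC service that defines a TensorFlow service. On behalf of the MasterService, a WorkerService executes dataflow graphs on a set of local devices. A WorkerService keeps track of multiple "registered graphs". Each registered graph is a subgraph of the client's graph, containing only the nodes that should execute on this worker (plus any additional nodes needed for inter-process communication via the RecvTensor method).

The Master uses the ClusterSpec to find the other Server instances in the cluster and treats them as Workers. It then distributes subgraphs to these Worker nodes and arranges for them to carry out the computation of each subgraph. If Workers have data dependencies on each other, they exchange data through inter-process communication. Whether the Master is calling a Worker or Workers are calling each other, they must follow the interface specification defined by WorkerService. All WorkerService interfaces are defined in worker_service.proto.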

service WorkerService {
  // See worker.proto for details.
  rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);

  // See worker.proto for details.
  rpc CreateWorkerSession(CreateWorkerSessionRequest)
      returns (CreateWorkerSessionResponse);

  // See worker.proto for details.
  rpc DeleteWorkerSession(DeleteWorkerSessionRequest)
      returns (DeleteWorkerSessionResponse);

  // See worker.proto for details.
  rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);

  // See worker.proto for details.
  rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);

  // See worker.proto for details.
  rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);

  // See worker.proto for details.
  rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);

  // See worker.proto for details.
  rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);

  // See worker.proto for details.
  rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
    // RecvTensor Method
  }

  // See worker.proto for details.
  rpc Logging(LoggingRequest) returns (LoggingResponse);

  // See worker.proto for details.
  rpc Tracing(TracingRequest) returns (TracingResponse);

  // See worker.proto for details.
  rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) {}

  // See worker.proto for details.
  rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);

  // See worker.proto for details.
  rpc CompleteGroup(CompleteGroupRequest) returns (CompleteGroupResponse);

  // See worker.proto for details.
  rpc CompleteInstance(CompleteInstanceRequest)
      returns (CompleteInstanceResponse);
}

3.3.1 WorkerInterface

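As with MasterService, WorkerService is accessed through WorkerInterface. WorkerInterface is the worker's interface class, the interface for interacting with the TensorFlow Worker service. It mainly:

  • Defines a number of asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These virtual functions correspond one-to-one with the GrpcWorkerMethod values supported by GrpcWorkerService, and with the Protobuf definitions.
  • Defines a number of synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through a call like CallAndWait(&ME::CreateWorkerSessionAsync, request, response).

We first list its asynchronous interfaces.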

// Interface for talking with the TensorFlow Worker service.
class WorkerInterface {
 public:
  virtual void GetStatusAsync(CallOptions* opts,
                              const GetStatusRequest* request,
                              GetStatusResponse* response, bool fail_fast,
                              StatusCallback done) = 0;

  virtual void CreateWorkerSessionAsync(
      const CreateWorkerSessionRequest* request,
      CreateWorkerSessionResponse* response, StatusCallback done) = 0;

  virtual void DeleteWorkerSessionAsync(
      CallOptions* opts, const DeleteWorkerSessionRequest* request,
      DeleteWorkerSessionResponse* response, StatusCallback done) = 0;

  virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
                                  RegisterGraphResponse* response,
                                  StatusCallback done) = 0;

  virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                    DeregisterGraphResponse* response,
                                    StatusCallback done) = 0;

  virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                             MutableRunGraphResponseWrapper* response,
                             StatusCallback done) = 0;

  virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
                             RunGraphResponse* response, StatusCallback done) {
    RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
    MutableRunGraphResponseWrapper* wrapped_response =
        new NonOwnedProtoRunGraphResponse(response);
    RunGraphAsync(opts, wrapped_request, wrapped_response,
                  [wrapped_request, wrapped_response,
                   done = std::move(done)](const Status& s) {
                    done(s);
                    delete wrapped_request;
                    delete wrapped_response;
                  });
  }

  virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
                                 CleanupGraphResponse* response,
                                 StatusCallback done) = 0;

  virtual void CleanupAllAsync(const CleanupAllRequest* request,
                               CleanupAllResponse* response,
                               StatusCallback done) = 0;

  virtual void RecvTensorAsync(CallOptions* opts,
                               const RecvTensorRequest* request,
                               TensorResponse* response,
                               StatusCallback done) = 0;

  virtual void LoggingAsync(const LoggingRequest* request,
                            LoggingResponse* response, StatusCallback done) = 0;

  virtual void TracingAsync(const TracingRequest* request,
                            TracingResponse* response, StatusCallback done) = 0;

  virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                            RecvBufResponse* response, StatusCallback done) = 0;

  virtual void CompleteGroupAsync(CallOptions* opts,
                                  const CompleteGroupRequest* request,
                                  CompleteGroupResponse* response,
                                  StatusCallback done) = 0;

  virtual void CompleteInstanceAsync(CallOptions* ops,
                                     const CompleteInstanceRequest* request,
                                     CompleteInstanceResponse* response,
                                     StatusCallback done) = 0;

  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                    GetStepSequenceResponse* response,
                                    StatusCallback done) = 0;
};

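WorkerInterface also provides synchronous interfaces, so that a Master or Worker can call remote WorkerService methods as if they were local functions. The synchronous interfaces are built on top of the asynchronous ones, using the CallAndWait adapter to wrap the asynchronous call. In addition, to prevent outside code from deleting a WorkerInterface instance illegally, some restrictions are in place: the destructor is protected, WorkerCacheInterface is declared a friend, and WorkerCacheInterface::ReleaseWorker is responsible for deleting WorkerInterface instances. The synchronous interfaces, along with some basic functions and member variables, are shown below.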

// Interface for talking with the TensorFlow Worker service.
class WorkerInterface {
 public:

  virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
    return new MutableProtoRunGraphRequest;
  }

  virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
    return new OwnedProtoRunGraphResponse;
  }

  Status GetStatus(const GetStatusRequest* request,
                   GetStatusResponse* response) {
    Status ret;
    Notification n;
    GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                   [&ret, &n](const Status& s) {
                     ret = s;
                     n.Notify();
                   });
    n.WaitForNotification();
    return ret;
  }

  Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
                             CreateWorkerSessionResponse* response) {
    return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
  }

  Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
                             DeleteWorkerSessionResponse* response) {
    return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request,
                                  response);
  }

  Status RegisterGraph(const RegisterGraphRequest* request,
                       RegisterGraphResponse* response) {
    return CallAndWait(&ME::RegisterGraphAsync, request, response);
  }

  Status DeregisterGraph(const DeregisterGraphRequest* request,
                         DeregisterGraphResponse* response) {
    return CallAndWait(&ME::DeregisterGraphAsync, request, response);
  }

  Status CleanupGraph(const CleanupGraphRequest* request,
                      CleanupGraphResponse* response) {
    return CallAndWait(&ME::CleanupGraphAsync, request, response);
  }

  Status CleanupAll(const CleanupAllRequest* request,
                    CleanupAllResponse* response) {
    return CallAndWait(&ME::CleanupAllAsync, request, response);
  }

  Status Logging(const LoggingRequest* request, LoggingResponse* response) {
    return CallAndWait(&ME::LoggingAsync, request, response);
  }

  Status Tracing(const TracingRequest* request, TracingResponse* response) {
    return CallAndWait(&ME::TracingAsync, request, response);
  }

  Status GetStepSequence(const GetStepSequenceRequest* request,
                         GetStepSequenceResponse* response) {
    return CallAndWait(&ME::GetStepSequenceAsync, request, response);
  }

 protected:
  // Instances of WorkerInterface must be deleted by a call to
  // WorkerCacheInterface::ReleaseWorker().
  virtual ~WorkerInterface() {}
  friend class WorkerCacheInterface;

  // NOTE: This should only be called by implementations of this
  // interface whose CreateRunGraphResponse() method returns a
  // proto-based wrappers for the RunGraphResponse message.
  RunGraphResponse* get_proto_from_wrapper(
      MutableRunGraphResponseWrapper* wrapper) {
    return wrapper->get_proto();
  }

 private:
  typedef WorkerInterface ME;

  template <typename Method, typename Req, typename Resp>
  Status CallAndWait(Method func, const Req* req, Resp* resp) {
    Status ret;
    Notification n;
    (this->*func)(req, resp, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }

  template <typename Method, typename Req, typename Resp>
  Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) {
    CallOptions call_opts;
    Status ret;
    Notification n;
    (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};
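
The CallAndWait adapter above turns a callback-based call into a blocking one. Here is a self-contained sketch of the same idea, with std::promise and std::future standing in for tensorflow::Notification. SimpleStatus, SimpleCallback and AsyncEcho are hypothetical names invented for this example, and the "RPC" is just a thread that doubles an integer.

```cpp
#include <functional>
#include <future>
#include <memory>
#include <thread>
#include <utility>

// Simplified stand-ins for tensorflow::Status and StatusCallback.
struct SimpleStatus {
  int code = 0;  // 0 means OK in this sketch.
  bool ok() const { return code == 0; }
};
using SimpleCallback = std::function<void(const SimpleStatus&)>;

// Hypothetical class with one asynchronous method and a synchronous wrapper.
class AsyncEcho {
 public:
  ~AsyncEcho() {
    if (worker_.joinable()) worker_.join();
  }

  // Asynchronous form: computes on another thread, then invokes done.
  void DoubleAsync(const int* request, int* response, SimpleCallback done) {
    if (worker_.joinable()) worker_.join();  // Finish any previous call.
    worker_ = std::thread([request, response, done = std::move(done)] {
      *response = *request * 2;
      done(SimpleStatus{});
    });
  }

  // Synchronous form: blocks until the asynchronous call has completed.
  SimpleStatus Double(const int* request, int* response) {
    return CallAndWait(&AsyncEcho::DoubleAsync, request, response);
  }

 private:
  template <typename Method, typename Req, typename Resp>
  SimpleStatus CallAndWait(Method func, const Req* req, Resp* resp) {
    // The promise is shared with the callback so that it outlives this frame.
    auto promise = std::make_shared<std::promise<SimpleStatus>>();
    std::future<SimpleStatus> result = promise->get_future();
    (this->*func)(req, resp,
                  [promise](const SimpleStatus& s) { promise->set_value(s); });
    return result.get();  // Blocks until the callback has run.
  }

  std::thread worker_;
};
```

Calling Double(&req, &resp) returns only after the callback has fired, which is what lets a caller treat the asynchronous method like an ordinary function.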

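3.3.2 Concepts

The WorkerService interface involves many concepts, so we need to sort them out carefully.

As mentioned earlier, Client and Master cooperate through the session_handle / MasterSession pair, while Master and Worker cooperate through MasterSession and WorkerSession; a MasterSession manages all the WorkerSessions that belong to it. The relationships among the following concepts need to be made clear:

  • session_handle: its purpose is to let a MasterSession manage its multiple WorkerSessions in a unified way. It corresponds one-to-one with a MasterSession and is generated when the MasterSession is created. It is returned to the Client in CreateSessionResponse and sent to the Worker in CreateWorkerSessionRequest, so the whole path from Client to Master to Worker is uniquely identified by session_handle.
  • graph_handle: generated by GraphMgr::Register when a subgraph is registered, and returned to the Master in RegisterGraphResponse. The subgraph is identified by this graph_handle. Within the cluster, the pair (session_handle, graph_handle) uniquely identifies a subgraph.
  • step_id: because the Master has multiple Workers execute concurrently, it broadcasts RunGraph to all of them. To distinguish steps, the Master generates a globally unique step_id for each RunStep and carries it to the Workers in the RunGraphRequest message.

Let us trace graph_handle. GraphMgr::Register generates it.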

Status GraphMgr::Register(
    const string& handle, const GraphDef& gdef, WorkerSession* session,
    const GraphOptions& graph_options, const DebugOptions& debug_options,
    const ConfigProto& config_proto, int64_t collective_graph_key,
    DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
  Item* item = new Item;
  Status s = InitItem(handle, gdef, session, graph_options, debug_options,
                      config_proto, collective_graph_key, cluster_flr, item);
  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
    *graph_handle =
        strings::Printf("%016llx", static_cast<long long>(++next_id_));
    item->handle = *graph_handle;
    CHECK(table_.insert({*graph_handle, item}).second);
  }
  return Status::OK();
}
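
The handle generation in GraphMgr::Register comes down to a counter formatted as 16 hexadecimal digits and used as a table key. The sketch below shows just that part; SketchGraphRegistry is a hypothetical class, and a plain string stands in for the registered Item.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical registry that hands out graph handles the same way the
// excerpt does: a monotonically increasing id printed as "%016llx".
class SketchGraphRegistry {
 public:
  // Stores the description and returns its newly generated handle.
  std::string Register(const std::string& graph_description) {
    std::lock_guard<std::mutex> lock(mu_);
    char buf[17];  // 16 hex digits plus the terminating NUL.
    std::snprintf(buf, sizeof(buf), "%016llx",
                  static_cast<unsigned long long>(++next_id_));
    const std::string handle(buf);
    const bool inserted = table_.emplace(handle, graph_description).second;
    assert(inserted);  // Ids never repeat, so the handle must be new.
    (void)inserted;
    return handle;
  }

  bool Contains(const std::string& handle) const {
    std::lock_guard<std::mutex> lock(mu_);
    return table_.count(handle) > 0;
  }

  // Returns true if the handle was registered and has now been removed.
  bool Deregister(const std::string& handle) {
    std::lock_guard<std::mutex> lock(mu_);
    return table_.erase(handle) > 0;
  }

 private:
  mutable std::mutex mu_;
  std::uint64_t next_id_ = 0;
  std::unordered_map<std::string, std::string> table_;
};
```

Because the counter is local to one registry, such a handle is unique only within that worker, which is consistent with the text's point that a subgraph is identified cluster-wide by the pair (session_handle, graph_handle).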

RegisterGraphResponse returns graph_handle to the Master.

message RegisterGraphResponse {
  // If the registration succeeds, returns an opaque graph_handle to
  // the master. The master calls RunGraph with graph_handle to
  // compute different steps.
  string graph_handle = 1;
}

Each partitioned subgraph holds a graph_handle.

// Graph partitioned into per-location subgraphs.
struct Part {
  // Worker name.
  string name;

  // Maps feed names to rendezvous keys. Empty most of the time.
  std::unordered_map<string, string> feed_key;

  // Maps rendezvous keys to fetch names. Empty most of the time.
  std::unordered_map<string, string> key_fetch;

  // The interface to the worker. Owned.
  WorkerInterface* worker = nullptr;

  // After registration with the worker, graph_handle identifies
  // this partition on the worker.
  string graph_handle;

  Part() : feed_key(3), key_fetch(3) {}
};

When registration returns, graph_handle is set on each subgraph.

Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
    const PartitionOptions& popts,
    std::unordered_map<string, GraphDef> graph_partitions) {
  partitions_.reserve(graph_partitions.size());
  Status s;
  for (auto& name_def : graph_partitions) {
    partitions_.emplace_back();
    Part* part = &partitions_.back();
    part->name = name_def.first;
    TrackFeedsAndFetches(part, name_def.second, popts);
    part->worker = worker_cache_->GetOrCreateWorker(part->name);
    if (part->worker == nullptr) {
      s = errors::NotFound("worker ", part->name);
      break;
    }
  }
  if (!s.ok()) {
    for (Part& part : partitions_) {
      worker_cache_->ReleaseWorker(part.name, part.worker);
      part.worker = nullptr;
    }
    return s;
  }
  struct Call {
    RegisterGraphRequest req;
    RegisterGraphResponse resp;
    Status status;
  };
  const int num = partitions_.size();
  gtl::InlinedVector<Call, 4> calls(num);
  BlockingCounter done(num);
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    Call* c = &calls[i];
    c->req.set_session_handle(session_handle_);
    c->req.set_create_worker_session_called(!should_deregister_);
    c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
    StripDefaultAttributes(*OpRegistry::Global(),
                           c->req.mutable_graph_def()->mutable_node());
    *c->req.mutable_config_proto() = session_opts_.config;
    *c->req.mutable_graph_options() = session_opts_.config.graph_options();
    *c->req.mutable_debug_options() =
        callable_opts_.run_options().debug_options();
    c->req.set_collective_graph_key(collective_graph_key_);

    auto cb = [c, &done](const Status& s) {
      c->status = s;
      done.DecrementCount();
    };
    part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
  }
  done.Wait();
  for (int i = 0; i < num; ++i) {
    Call* c = &calls[i];
    s.Update(c->status);
    partitions_[i].graph_handle = c->resp.graph_handle();
  }
  return s;
}
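
DoRegisterPartitions issues one asynchronous RegisterGraph call per partition and then blocks on a BlockingCounter until every callback has fired. The sketch below reproduces that fan-out and fan-in shape with the standard library only. SketchBlockingCounter and RegisterAll are hypothetical, and the "registration" just fabricates a handle string on a separate thread.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Stand-in for tensorflow::BlockingCounter: Wait() returns once
// DecrementCount() has been called `count` times.
class SketchBlockingCounter {
 public:
  explicit SketchBlockingCounter(int count) : count_(count) {}

  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    assert(count_ > 0);
    if (--count_ == 0) cv_.notify_all();
  }

  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Hypothetical fan-out: "registers" every partition concurrently and
// collects one handle per partition, in partition order.
std::vector<std::string> RegisterAll(
    const std::vector<std::string>& partitions) {
  const int num = static_cast<int>(partitions.size());
  std::vector<std::string> handles(num);
  SketchBlockingCounter done(num);
  std::vector<std::thread> threads;
  for (int i = 0; i < num; ++i) {
    threads.emplace_back([&, i] {
      handles[i] = "handle-for-" + partitions[i];  // Each thread owns slot i.
      done.DecrementCount();
    });
  }
  done.Wait();                       // Fan-in: all callbacks have run.
  for (auto& t : threads) t.join();  // Threads must end before `done` dies.
  return handles;
}
```

Each call writes only to its own slot, as in the excerpt where every callback stores into its own Call object, so the results need no extra locking once Wait() has returned.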

At use time, graph_handle uniquely identifies a subgraph.

// Asynchronously deregisters subgraphs on the workers, without waiting for the
// result.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
  struct Call {
    DeregisterGraphRequest req;
    DeregisterGraphResponse resp;
  };
  for (Part& part : partitions_) {
    // The graph handle may be empty if we failed during partition registration.
    if (!part.graph_handle.empty()) {
      Call* c = new Call;
      c->req.set_session_handle(session_handle_);
      c->req.set_create_worker_session_called(!should_deregister_);
      c->req.set_graph_handle(part.graph_handle);
      // NOTE(mrry): We must capture worker_cache_ since this
      // could be deleted before the callback is called.
      WorkerCacheInterface* worker_cache = worker_cache_;
      const string name = part.name;
      WorkerInterface* w = part.worker;
      CHECK_NOTNULL(w);
      auto cb = [worker_cache, c, name, w](const Status& s) {
        delete c;
        worker_cache->ReleaseWorker(name, w);
      };
      w->DeregisterGraphAsync(&c->req, &c->resp, cb);
    }
  }
}

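3.3.4 WorkerInterface Derived Classes

As shown in the figure below, WorkerInterface has two implementations.

  • GrpcWorker: the Worker role in local mode. If the Master and Worker are both local, calls can be made directly, without RPC network transfer.
  • GrpcRemoteWorker: in distributed mode the Worker is remote, and the local side uses GrpcRemoteWorker to access it.
    • GrpcRemoteWorker is a gRPC client that accesses the GrpcWorkerService on the remote Worker through a stub.
    • GrpcWorkerService implements all the interfaces defined by WorkerService, but forwards the actual work to the local GrpcWorker.

The relationships are illustrated below:

Figure 1: WorkerInterface derived classes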

3.3.5 Usage

When the Server is initialized, the Worker Service is set up with the following code.

  // Create the GrpcWorker and its corresponding GrpcWorkerService.
  worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                  : NewGrpcWorker(&worker_env_, config);
  worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
                                         opts.worker_service_options)


This simply returns a GrpcWorkerService.

// Returns an implementation of WorkerService rpc service.
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
    GrpcWorker* worker, ::grpc::ServerBuilder* builder,
    GrpcWorkerServiceOptions options) {
  return std::unique_ptr<AsyncServiceInterface>(
      new GrpcWorkerService(worker, builder, options));
}


Within GrpcServer, the worker_thread_ thread runs the HandleRPCsLoop method of GrpcWorkerService.

worker_thread_.reset(
    env_->StartThread(ThreadOptions(), "TF_worker_service",
                      [this] { worker_service_->HandleRPCsLoop(); }));


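3.3.6 Definition

GrpcWorkerService is defined as follows. Because it has to act as a daemon that handles incoming gRPC requests, its constructor creates a number of threads to serve requests; HandleRPCsLoop then starts these threads and joins them.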

class GrpcWorkerService : public AsyncServiceInterface {
 public:
  GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                    GrpcWorkerServiceOptions options)
      : is_shutdown_(false) {
    builder->RegisterService(&worker_service_);

    for (int i = 0; i < options.num_serving_threads; i++) {
      threads_.emplace_back(
          new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
                                      cache_.get(), &worker_service_));
    }
  }

  // This method blocks forever handling requests from the completion queue.
  void HandleRPCsLoop() override {
    for (auto& worker_thread : threads_) {
      worker_thread->Start();
    }
    for (auto& worker_thread : threads_) {
      worker_thread->Join();
    }
  }

 private:
  grpc::WorkerService::AsyncService worker_service_;
  std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;

  std::unique_ptr<GrpcResponseCache> cache_;
  mutex service_shutdown_mu_;
  bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};

3.3.7 Threads

The actual looping and request handling happen inside the threads; cq_ is the gRPC completion queue.

// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
// requests.  Each thread operates on an independent completion queue.
class GrpcWorkerServiceThread {
 public:
  explicit GrpcWorkerServiceThread(
      GrpcWorker* worker, ::grpc::ServerBuilder* builder,
      std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
      grpc::WorkerService::AsyncService* worker_service)
      : worker_(worker),
        queue_depth_(queue_depth),
        cache_(cache),
        worker_service_(worker_service),
        is_shutdown_(false) {
    cq_ = builder->AddCompletionQueue();
  }

  void Start() {
    thread_.reset(
        worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
                                         [this]() { HandleRPCsLoop(); }));
  }
};

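Main loop

GrpcWorkerServiceThread::HandleRPCsLoop is the thread's main loop, similar to the one in the master service. It first enqueues pending entries for a number of gRPC calls; these calls correspond one-to-one with the GrpcWorkerMethod values. The handler code for each method is covered later.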

// Add one or more completion queue entries for each worker method, then
// begin servicing requests from the completion queue.
void GrpcWorkerServiceThread::HandleRPCsLoop() {
  // TODO(ncteisen): This may require performance engineering. We can
  // change the number of threads, the number of handlers per thread,
  // or even decide to specialize certain threads to certain methods.
  SETUP_FOR_REQUEST(GetStatus, 1, false);
  SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
  SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
  SETUP_FOR_REQUEST(CleanupAll, 1, false);
  SETUP_FOR_REQUEST(RegisterGraph, 1, false);
  SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
  SETUP_FOR_REQUEST(Logging, 1, false);
  SETUP_FOR_REQUEST(Tracing, 1, false);
  SETUP_FOR_REQUEST(CompleteGroup, 10, true);
  SETUP_FOR_REQUEST(CompleteInstance, 10, true);
  SETUP_FOR_REQUEST(GetStepSequence, 10, true);
  SETUP_FOR_REQUEST(RecvBuf, 500, true);
  SETUP_FOR_REQUEST(RunGraph, 100, true);
  SETUP_FOR_REQUEST(CleanupGraph, 100, false);
  SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);

  // TODO(ncteisen): Determine a better policy for enqueuing the
  // appropriate number of each request type.
  for (int i = 0;
       i < gtl::FindWithDefault(
               queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
               1000);
       ++i) {
    EnqueueRecvTensorRequestRaw();
  }

  void* tag;
  bool ok;

  while (cq_->Next(&tag, &ok)) {
    UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
        static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
    CHECK(callback_tag);
    callback_tag->OnCompleted(this, ok);
  }
}
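
The while loop above pulls opaque tags off the completion queue and calls OnCompleted on each one. The following is a minimal single-threaded sketch of that dispatch pattern in standard C++. FakeCompletionQueue, Tag and DrainQueue are hypothetical names invented for illustration, not gRPC or TensorFlow APIs. Note also that a real completion queue blocks in Next until an event arrives, whereas this stand-in returns false as soon as it is empty.

```cpp
#include <cassert>
#include <deque>
#include <functional>
#include <utility>

// Hypothetical stand-in for UntypedCall<...>::Tag: an object that knows how to
// continue processing once the queue reports its event.
struct Tag {
  std::function<void(bool ok)> on_completed;
  void OnCompleted(bool ok) { on_completed(ok); }
};

// Hypothetical stand-in for a gRPC completion queue. Unlike the real queue,
// Next() returns false as soon as the queue is empty instead of blocking.
class FakeCompletionQueue {
 public:
  void Push(Tag* tag, bool ok) { events_.emplace_back(tag, ok); }

  bool Next(void** tag, bool* ok) {
    if (events_.empty()) return false;
    *tag = events_.front().first;
    *ok = events_.front().second;
    events_.pop_front();
    return true;
  }

 private:
  std::deque<std::pair<Tag*, bool>> events_;
};

// Same shape as the loop in HandleRPCsLoop: fetch a tag, cast it back to its
// real type, and let it run its callback. Returns the number of events handled.
inline int DrainQueue(FakeCompletionQueue* cq) {
  void* tag = nullptr;
  bool ok = false;
  int handled = 0;
  while (cq->Next(&tag, &ok)) {
    assert(tag != nullptr);  // mirrors CHECK(callback_tag)
    static_cast<Tag*>(tag)->OnCompleted(ok);
    ++handled;
  }
  return handled;
}
```

The point of the void* tag is that the queue itself does not need to know which RPC method an event belongs to; the tag object carries that knowledge.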

gRPC request

Requests are handled much as in the master. Each request is routed to a business handler, which the macro below names GrpcWorkerServiceThread::method##Handler.

#define ENQUEUE_REQUEST(method, supports_cancel)                             \
  do {                                                                       \
    mutex_lock l(shutdown_mu_);                                              \
    if (!is_shutdown_) {                                                     \
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,       \
           method##Request, method##Response>::                              \
          EnqueueRequestForMethod(                                           \
              worker_service_, cq_.get(),                                    \
              static_cast<int>(GrpcWorkerMethod::k##method),                 \
              &GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
    }                                                                        \
  } while (0)

#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)              \
  for (int i = 0;                                                              \
       i < gtl::FindWithDefault(queue_depth_,                                  \
                                static_cast<int>(GrpcWorkerMethod::k##method), \
                                default_depth);                                \
       ++i) {                                                                  \
    ENQUEUE_REQUEST(method, supports_cancel);                                  \
  }
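
SETUP_FOR_REQUEST looks up a per-method queue depth and falls back to a default when none is configured. A rough sketch of that lookup follows, assuming a reduced hypothetical enum (Method) and hypothetical helper names (QueueDepthFor, EnqueueAll); the real code uses gtl::FindWithDefault and enqueues an actual call in each iteration.

```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical subset of GrpcWorkerMethod, for illustration only.
enum class Method { kGetStatus, kRunGraph, kRecvBuf };

// Mirrors the lookup in SETUP_FOR_REQUEST: use the configured depth for a
// method if present, otherwise the default.
inline int QueueDepthFor(const std::unordered_map<int, int>& queue_depth,
                         Method method, int default_depth) {
  assert(default_depth >= 0);
  auto it = queue_depth.find(static_cast<int>(method));
  return it == queue_depth.end() ? default_depth : it->second;
}

// Counts how many pending calls would be enqueued for a method. In the real
// macro each iteration calls ENQUEUE_REQUEST; here we only count.
inline int EnqueueAll(const std::unordered_map<int, int>& queue_depth,
                      Method method, int default_depth) {
  int enqueued = 0;
  for (int i = 0; i < QueueDepthFor(queue_depth, method, default_depth); ++i) {
    ++enqueued;
  }
  return enqueued;
}
```

With an empty map every method gets its default depth, which matches the constants passed in HandleRPCsLoop (for example 100 for RunGraph and 500 for RecvBuf).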

Each RPC method must also be registered as an asynchronous method, which is done with gRPC's own AddMethod and MarkMethodAsync interfaces.

WorkerService::AsyncService::AsyncService() {
  for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
    AddMethod(new ::grpc::internal::RpcServiceMethod(
        GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
        ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
    ::grpc::Service::MarkMethodAsync(i);
  }
}
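
Handler & thread pool

The handlers themselves are generated by the macro below. Each handler wraps the call into the Worker in a closure and, depending on may_block_on_compute_pool, schedules it either with env->SchedClosure or on the compute_pool thread pool via compute_pool->Schedule. Both come from the modules bundled in the worker env.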

  // Handle all non-cancellable simple methods with a standard wrapper.
  // The boolean may_block_on_compute_pool indicates whether or not the
  // operation may block on activities (such as op execution) that run on the
  // compute pool.
#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
  HANDLE_CALL(CreateWorkerSession, false);
  HANDLE_CALL(DeleteWorkerSession, true);
  HANDLE_CALL(CleanupAll, false);
  HANDLE_CALL(RegisterGraph, false);
  HANDLE_CALL(DeregisterGraph, false);
  HANDLE_CALL(CleanupGraph, false);
  HANDLE_CALL(Logging, false);
  HANDLE_CALL(Tracing, false);

#undef HANDLE_CALL
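
The branch inside HANDLE_CALL can be sketched in isolation. RecordingExecutor and Dispatcher below are hypothetical types invented for this example: the two executors only record closures, standing in for env->SchedClosure and compute_pool->Schedule. The reason for the split is presumably that a closure which waits on compute-pool work should not occupy a compute-pool thread itself; the source comment only states that the flag marks operations that may block on such work.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

using Closure = std::function<void()>;

// Hypothetical executor that only records closures; a real pool would run
// them on its own threads.
struct RecordingExecutor {
  std::vector<Closure> pending;
  void Schedule(Closure c) { pending.push_back(std::move(c)); }
  void RunAll() {
    std::vector<Closure> batch;
    batch.swap(pending);
    for (auto& c : batch) c();
  }
};

// Hypothetical handler mirroring the branch in HANDLE_CALL: work that may
// block on the compute pool is kept off that pool.
struct Dispatcher {
  RecordingExecutor* general;       // stands in for env->SchedClosure
  RecordingExecutor* compute_pool;  // stands in for compute_pool->Schedule

  void Handle(bool may_block_on_compute_pool, Closure work) {
    assert(general != nullptr && compute_pool != nullptr);
    if (may_block_on_compute_pool) {
      general->Schedule(std::move(work));
    } else {
      compute_pool->Schedule(std::move(work));
    }
  }
};
```

In the listing above only DeleteWorkerSession is declared with the flag set to true, so it is the only one of these handlers that avoids the compute pool.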

Messages & methods

GrpcWorkerMethod enumerates the methods a worker exposes.

// Names of worker methods.
enum class GrpcWorkerMethod {
  kGetStatus,
  kCreateWorkerSession,
  kDeleteWorkerSession,
  kRegisterGraph,
  kDeregisterGraph,
  kRunGraph,
  kCleanupGraph,
  kCleanupAll,
  kRecvTensor,
  kRecvBuf,
  kLogging,
  kTracing,
  kCompleteGroup,
  kCompleteInstance,
  kGetStepSequence,
  kMarkRecvFinished,
};

GrpcWorkerMethodName maps each enum value to its RPC method name.

const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
  switch (id) {
    case GrpcWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case GrpcWorkerMethod::kCreateWorkerSession:
      return "/tensorflow.WorkerService/CreateWorkerSession";
    case GrpcWorkerMethod::kDeleteWorkerSession:
      return "/tensorflow.WorkerService/DeleteWorkerSession";
    case GrpcWorkerMethod::kRegisterGraph:
      return "/tensorflow.WorkerService/RegisterGraph";
    case GrpcWorkerMethod::kDeregisterGraph:
      return "/tensorflow.WorkerService/DeregisterGraph";
    case GrpcWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case GrpcWorkerMethod::kCleanupGraph:
      return "/tensorflow.WorkerService/CleanupGraph";
    case GrpcWorkerMethod::kCleanupAll:
      return "/tensorflow.WorkerService/CleanupAll";
    case GrpcWorkerMethod::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
    case GrpcWorkerMethod::kRecvBuf:
      return "/tensorflow.WorkerService/RecvBuf";
    case GrpcWorkerMethod::kLogging:
      return "/tensorflow.WorkerService/Logging";
    case GrpcWorkerMethod::kTracing:
      return "/tensorflow.WorkerService/Tracing";
    case GrpcWorkerMethod::kCompleteGroup:
      return "/tensorflow.WorkerService/CompleteGroup";
    case GrpcWorkerMethod::kCompleteInstance:
      return "/tensorflow.WorkerService/CompleteInstance";
    case GrpcWorkerMethod::kGetStepSequence:
      return "/tensorflow.WorkerService/GetStepSequence";
    case GrpcWorkerMethod::kMarkRecvFinished:
      return "/tensorflow.WorkerService/MarkRecvFinished";
  }
  // Shouldn't be reached.
  return "invalid id";
}

As shown earlier, the AsyncService constructor calls GrpcWorkerMethodName to register each method with gRPC:

WorkerService::AsyncService::AsyncService() {
  for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
    AddMethod(new ::grpc::internal::RpcServiceMethod(
        GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
        ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
    ::grpc::Service::MarkMethodAsync(i);
  }
}
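
The registration loop relies on enum values being dense indices: method i in gRPC's table is the method whose enum value is i, which is why ENQUEUE_REQUEST can pass static_cast<int>(GrpcWorkerMethod::k##method) as the method index. Below is a small sketch of that idea with a hypothetical three-value enum. Deriving the count from the last enumerator is an assumption about how kGrpcNumWorkerMethods is defined, and RegisterAll is an invented stand-in for the constructor.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical three-method subset of GrpcWorkerMethod.
enum class Method { kGetStatus, kRunGraph, kRecvTensor };

// Assumption: the method count is derived from the last enumerator, so the
// enum values can double as dense indices.
constexpr int kNumMethods = static_cast<int>(Method::kRecvTensor) + 1;

inline const char* MethodName(Method id) {
  switch (id) {
    case Method::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case Method::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case Method::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
  }
  return "invalid id";
}

// Stand-in for the AsyncService constructor: registers one entry per enum
// value, so index i in the table is the method whose enum value is i.
inline std::vector<std::string> RegisterAll() {
  std::vector<std::string> table;
  for (int i = 0; i < kNumMethods; ++i) {
    table.emplace_back(MethodName(static_cast<Method>(i)));
  }
  assert(static_cast<int>(table.size()) == kNumMethods);
  return table;
}
```

One consequence of this layout is that adding a method means touching both the enum and the name switch; a mismatch would silently shift indices.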

Business handling

The actual work is delegated to the Worker, for example:

void GetStepSequenceHandler(
    WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
  Schedule([this, call]() {
    worker_->GetStepSequenceAsync(
        &call->request, &call->response, [call](const Status& s) {
          call->SendResponse(ToGrpcStatus(s));
        });
  });
  ENQUEUE_REQUEST(GetStepSequence, true);
}

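From a threading point of view the logic so far looks as follows, assuming three threads in total. The Server's worker_thread_ runs GrpcWorkerService::HandleRPCsLoop(), whose job is to start two GrpcWorkerServiceThread instances and then wait for them. Each GrpcWorkerServiceThread serves gRPC requests and runs the business handlers inside GrpcWorkerServiceThread::HandleRPCsLoop. Note that both GrpcWorkerService and GrpcWorkerServiceThread have a method named HandleRPCsLoop.

Figure 2: Thread view

3.3.8 Business logic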

CreateWorkerSession

The CreateWorkerSessionRequest message carries the session_handle of the MasterSession. On receiving it, the Worker creates a WorkerSession. Within a cluster, a MasterSession passes its own session_handle whenever it creates a WorkerSession, so each WorkerSession knows which MasterSession it belongs to, and the MasterSession can manage all of its WorkerSessions together.

GrpcWorker manages WorkerSessions through SessionMgr, which can locate a WorkerSession either by master task name or by session_handle.

class SessionMgr {

  WorkerEnv* const worker_env_;  // Not owned.
  std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
  std::shared_ptr<WorkerSession> legacy_session_;
  const WorkerCacheFactory worker_cache_factory_;

  // A map from session identifier to internal session structure.
  std::map<string, std::shared_ptr<WorkerSession>> sessions_ TF_GUARDED_BY(mu_);

  // Incarnation and WorkerSession handle associated with a master task.
  struct MasterAssociatedSession {
    const int64_t master_incarnation;
    const string session_handle;
  };
  // A map from master task name to its associated worker sessions.
  std::unordered_multimap<string, MasterAssociatedSession>
      master_to_associated_sessions_ TF_GUARDED_BY(mu_);
};

The messages are shown below. Note that CreateWorkerSessionResponse carries no fields:

message CreateWorkerSessionRequest {
  // Sessions are identified by a given handle.
  string session_handle = 1;

  // Defines the configuration of a TensorFlow worker.
  ServerDef server_def = 2;

  // If true, any resources such as Variables used in the session will not be
  // shared with other sessions.
  bool isolate_session_state = 3;

  // The device attributes of all the devices in the cluster.
  repeated DeviceAttributes cluster_device_attributes = 4;

  // The master task name from which the request is sent.
  string master_task = 5;

  // The incarnation ID of the master task local CPU device.
  // If the target worker already has a WorkerSession created previously with
  // the same master task name but a different incarnation, it usually indicates
  // that the previous master failed before deleting the WorkerSession on the
  // worker. To prevent memory leaks, the worker should garbage collect the old
  // WorkerSessions.
  int64 master_incarnation = 6;
}

message CreateWorkerSessionResponse {}
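
The comment on master_incarnation describes a garbage-collection rule: sessions left behind by an earlier incarnation of the same master task should be dropped. Combined with the two maps in SessionMgr, that rule might look roughly like the sketch below. MiniSessionMgr and Session are hypothetical, heavily reduced types, and the logic is an interpretation of the proto comment rather than a copy of the real SessionMgr code.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical, much-reduced WorkerSession.
struct Session {
  std::string handle;
};

// Hypothetical sketch of the two indexes kept by SessionMgr.
class MiniSessionMgr {
 private:
  struct Associated {
    int64_t incarnation;
    std::string session_handle;
  };

 public:
  // Creates a session. If the same master task previously created sessions
  // under a different incarnation, those are treated as leaked and dropped.
  void Create(const std::string& session_handle,
              const std::string& master_task, int64_t master_incarnation) {
    assert(!session_handle.empty());
    auto range = by_master_.equal_range(master_task);
    for (auto it = range.first; it != range.second;) {
      if (it->second.incarnation != master_incarnation) {
        sessions_.erase(it->second.session_handle);
        it = by_master_.erase(it);
      } else {
        ++it;
      }
    }
    sessions_[session_handle] =
        std::make_shared<Session>(Session{session_handle});
    by_master_.emplace(master_task,
                       Associated{master_incarnation, session_handle});
  }

  // Looks a session up by handle; returns nullptr if it does not exist.
  std::shared_ptr<Session> Find(const std::string& session_handle) const {
    auto it = sessions_.find(session_handle);
    if (it == sessions_.end()) return nullptr;
    return it->second;
  }

 private:
  std::map<std::string, std::shared_ptr<Session>> sessions_;
  std::unordered_multimap<std::string, Associated> by_master_;
};
```

A multimap is used for the master index because one master task can own several sessions at once, as long as they share the same incarnation.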

Figure 3: CreateWorkerSession

As described above, the GrpcWorkerServiceThread handlers for these messages are generated by the HANDLE_CALL macro:

#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
  HANDLE_CALL(CreateWorkerSession, false);
  HANDLE_CALL(DeleteWorkerSession, true);
  HANDLE_CALL(CleanupAll, false);
  HANDLE_CALL(RegisterGraph, false);
  HANDLE_CALL(DeregisterGraph, false);
  HANDLE_CALL(CleanupGraph, false);
  HANDLE_CALL(Logging, false);
  HANDLE_CALL(Tracing, false);

RegisterGraph

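RegisterGraphRequest carries the session_handle of the MasterSession and the subgraph graph_def. Once the Worker has received the message and finished registering and initializing the subgraph, it returns the subgraph's graph_handle to the Master.

For each session, after the master has placed every node on a device, it partitions the whole graph into many subgraphs. All nodes of a subgraph live on the same worker, but possibly on several devices owned by that worker (for example cpu0 plus gpu0, gpu1, ..., gpu7). Before running any step, the master registers the subgraphs with the worker. A successful registration returns a graph handle for use in later RunGraph requests.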

////////////////////////////////////////////////////////////////////////////////
//
// RegisterGraph method request/response messages
//
// For each session, after the master placed every node on a device,
// it partitions the whole graph into many subgraphs. All the nodes in
// a subgraph were in the same worker, but potentially on many devices
// owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The
// master registers subgraphs for a worker before running any steps. A
// successful registration returns a graph handle to be used in latter
// RunGraph requests.
//
////////////////////////////////////////////////////////////////////////////////

message RegisterGraphRequest {
  // Subgraphs are scoped within one session.
  string session_handle = 1;

  // Set to true if CreateWorkerSession was called for session_handle.
  bool create_worker_session_called = 6;

  // "graph_def" has the subgraph of nodes for this worker, with each node
  // having its device_name filled in.
  GraphDef graph_def = 2;

  // True iff the graph (before partitioning) contains control flow nodes.
  //
  // As of 01/11/2015, this is no longer set by clients.
  bool has_control_flow = 3 [deprecated = true];

  // Configuration options for the session in which this graph was created.
  GraphOptions graph_options = 4;

  // Field(s) used by TensorFlow Debugger (tfdbg).
  DebugOptions debug_options = 5;

  // If graph_def contains any collective ops this must be a positive
  // integer used to coordinate execution with other graphs.  All
  // graphs in a distributed execution with the same
  // collective_graph_key will coordinate to use the same step_id
  // concurrently so that BufRendezvous entries will make the correct
  // values accessible.
  int64 collective_graph_key = 7;

  // ConfigProto from the session in which this graph was created.
  // Contains additional parameters beyond graph_options, including
  // the name of the requested executor.
  ConfigProto config_proto = 8;
}

message RegisterGraphResponse {
  // If the registration succeeds, returns an opaque graph_handle to
  // the master. The master calls RunGraph with graph_handle to
  // compute different steps.
  string graph_handle = 1;
}

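Figure 4: RegisterGraph

DeregisterGraph

When a graph is no longer needed (for example, the overall graph is re-scheduled and its nodes are re-placed), the Master deregisters it using the corresponding graph_handle. To cope with master restarts, the Worker also deregisters a graph_handle automatically according to a TTL-based policy.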

////////////////////////////////////////////////////////////////////////////////
//
// DeregisterGraph method request/response messages
//
// The master deregisters the given graph_handle when the graph is no
// longer needed (e.g., the overall graph is re-scheduled and nodes
// are re-placed).
//
// The worker deregisters a graph_handle automatically according to on
// a TTL-base policy in case of master restarts.
//
////////////////////////////////////////////////////////////////////////////////

message DeregisterGraphRequest {
  // The session_handle used when registering the graph. If session_handle is
  // empty, a single global namespace is used.
  string session_handle = 2;

  // Set to true if CreateWorkerSession was called for session_handle.
  bool create_worker_session_called = 3;

  // REQUIRED: graph_handle must be returned by a RegisterGraph call
  // to the same WorkerService.
  string graph_handle = 1;
}

message DeregisterGraphResponse {
  // TODO(mrry): Optionally add summary stats for the graph.
}
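
The register/deregister pair amounts to a handle-keyed table on the worker. A minimal sketch follows, assuming a hypothetical MiniGraphRegistry in which a string stands in for GraphDef. The handle format is invented for the example; the proto only promises that the handle is opaque. The TTL-based cleanup mentioned in the comment is not modelled here.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical registry keyed by an opaque handle. The handle format below is
// invented for illustration; callers must not parse it.
class MiniGraphRegistry {
 public:
  // Stores the serialized subgraph and returns a fresh opaque handle.
  std::string Register(const std::string& graph_def) {
    std::string handle = "graph_" + std::to_string(next_id_++);
    assert(graphs_.count(handle) == 0);  // handles are never reused
    graphs_[handle] = graph_def;
    return handle;
  }

  // Returns true if the handle was registered and has now been removed.
  bool Deregister(const std::string& handle) {
    return graphs_.erase(handle) == 1;
  }

  // Returns the registered subgraph, or nullptr for an unknown handle. This
  // is the lookup a RunGraph request would perform with its graph_handle.
  const std::string* Lookup(const std::string& handle) const {
    auto it = graphs_.find(handle);
    return it == graphs_.end() ? nullptr : &it->second;
  }

 private:
  int next_id_ = 0;
  std::map<std::string, std::string> graphs_;
};
```

Because the handle is all the master keeps, it can run the same subgraph for many steps without resending graph_def.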

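Figure 5: DeregisterGraph

RunGraph

The Master uses RunGraphRequest to execute all subgraphs registered under graph_handle.

The Master generates a globally unique step_id to distinguish different runs of the graph computation. Subgraphs communicate with each other (for example through send/recv ops) using step_id to tell apart tensors produced by different runs.

In RunGraphRequest, send holds the tensors fed into the subgraph and recv_key names the tensors to fetch from it. RunGraphResponse returns the list of tensors corresponding to recv_key.

Figure 6: RunGraph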

////////////////////////////////////////////////////////////////////////////////
//
// RunGraph request / response messages
//
// The worker executes all subgraphs registered under graph_handle.
// RunGraph returns after the execution finishes or an error is
// encountered.
// A sequence of RunGraphRequests with is_partial may be sent to RunGraph for
// partial graph execution.
//
////////////////////////////////////////////////////////////////////////////////

// Options specific to the execution of a single step.
message ExecutorOpts {
  bool record_costs = 1;
  bool record_timeline = 3;
  bool record_partition_graphs = 4;
  bool report_tensor_allocations_upon_oom = 5;
}

message RunGraphRequest {
  // session_handle is the master-generated unique id for this session.
  // If session_handle is non-empty, it must be the same as used when
  // registering the graph. If it is empty, a single global namespace is used to
  // search for the graph_handle.
  string session_handle = 8;

  // Set to true if CreateWorkerSession was called for session_handle.
  bool create_worker_session_called = 10;

  // REQUIRED: graph_handle must be returned by a RegisterGraph call
  // to the same WorkerService.
  string graph_handle = 1;

  // A unique ID to distinguish different runs of the same graph.
  //
  // The master generates a global unique step_id to distinguish
  // different runs of the graph computation. Subgraphs communicate
  // (e.g., send/recv ops) with each other using step_id to
  // distinguish tensors generated by different runs.
  int64 step_id = 2;

  // Options for this step.
  ExecutorOpts exec_opts = 5;

  // Runs the graph.
  //
  // Sends the tensors in "send" into the graph before the run and
  // fetches the keys into RunGraphResponse.recv after the run.
  repeated NamedTensorProto send = 3;
  repeated string recv_key = 4;

  // True if the RunGraphRequest is a partial run request.
  bool is_partial = 6;
  // True if this is the last partial run request in a sequence of requests.
  bool is_last_partial_run = 7;

  // If true then some errors, e.g., execution errors that have long
  // error messages, may return an OK RunGraphResponse with the actual
  // error saved in the status_code/status_error_message fields of the
  // response body. This is a workaround since the RPC subsystem may
  // truncate long metadata messages.
  bool store_errors_in_response_body = 9;

  // Unique identifier for this request. Every RunGraphRequest must have a
  // unique request_id, and retried RunGraphRequests must have the same
  // request_id. If request_id is zero, retry detection is disabled.
  //
  // Retried RunGraphRequests are problematic because they may issue a
  // RecvTensor that will have no corresponding sender and will wait forever.
  // Workers use request_ids to reject retried RunGraph requests instead of
  // waiting forever.
  int64 request_id = 11;

  // Next: 12
}

message RunGraphResponse {
  // A list of tensors corresponding to those requested by
  // RunGraphRequest.recv_key.
  repeated NamedTensorProto recv = 1;

  // If the request asked for execution stats, the cost graph, or the partition
  // graphs, these are returned here.
  // TODO(suharshs): Package these in a RunMetadata instead.
  StepStats step_stats = 2;
  CostGraphDef cost_graph = 3;
  repeated GraphDef partition_graph = 4;

  // If store_errors_in_response_body is true in the request, then
  // optionally the server may return an OK status for the RPC and
  // fill the true status into the fields below, to allow for messages
  // that are too long to fit in metadata.
  error.Code status_code = 5;
  string status_error_message = 6;
}
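
The request_id comment describes retry detection: a repeated id is rejected, and an id of zero switches detection off. One way to sketch this with bounded memory is shown below. MiniRecentRequestIds is a hypothetical class written for this example, and the fixed-size ring that forgets the oldest id is an assumption about how "recent" could be bounded, not a description of TensorFlow's RecentRequestIds.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Hypothetical bounded tracker of recently seen request ids.
class MiniRecentRequestIds {
 public:
  explicit MiniRecentRequestIds(std::size_t capacity) : ring_(capacity, 0) {
    assert(capacity > 0);
  }

  // Returns true if the request should be processed, false if it is a retry
  // of a request that is still remembered. An id of zero disables detection.
  bool Track(int64_t request_id) {
    if (request_id == 0) return true;
    if (!seen_.insert(request_id).second) return false;
    // Forget the oldest remembered id to keep memory bounded.
    seen_.erase(ring_[next_]);
    ring_[next_] = request_id;
    next_ = (next_ + 1) % ring_.size();
    return true;
  }

 private:
  std::vector<int64_t> ring_;  // 0 marks an empty slot
  std::size_t next_ = 0;
  std::unordered_set<int64_t> seen_;
};
```

The trade-off of a bounded window is that a retry arriving after its id has been forgotten is accepted again, so the capacity has to be large relative to the number of requests in flight.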

RecvTensor

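During execution two Workers may need to exchange data. The producer simply places the finished tensor into its rendezvous, and the consumer actively issues a RecvTensorRequest. In the request, step_id identifies the step and rendezvous_key identifies the channel from which the tensor is to be received.

A single RecvTensor request retrieves one tensor from the channel; multiple tensors can be sent and received over the same channel with multiple RecvTensor requests. The producer's tensor is finally returned to the consumer in RecvTensorResponse.

Figure 7: RecvTensor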

////////////////////////////////////////////////////////////////////////////////
//
// RecvTensor method request/response messages
//
////////////////////////////////////////////////////////////////////////////////

message RecvTensorRequest {
  // The step in which the tensor will be produced.
  //
  // REQUIRED: This must eventually correspond to the step_id passed
  // into a RunGraph call on the same WorkerService.
  int64 step_id = 1;

  // A key identifying the channel to receive tensors from. A RecvTensor request
  // retrieves one tensor from the channel, but multiple tensors can be sent and
  // received over the same channel with multiple RecvTensor requests. See
  // rendezvous.h for details.
  string rendezvous_key = 2;

  // If true, use an out-of-band DMA mechanism to transfer the
  // received tensor.
  bool dma_ok = 3;

  // Optional information on client-side device locality.
  DeviceLocality client_locality = 4;

  // Optional information on server-side device locality.
  DeviceLocality server_locality = 5;

  // Optional information needed by the RPC subsystem.
  google.protobuf.Any transport_options = 6;

  // Unique identifier for this request. Every RecvTensorRequest must have a
  // unique request_id, and retried RecvTensorRequests must have the same
  // request_id. If request_id is zero, retry detection and response cache
  // are disabled.
  //
  // Retried RecvTensorRequests are problematic because a RecvTensor with no
  // corresponding sender will wait forever, and the tensor may have been
  // delivered to a previous retry. Workers use request_ids to reject retried
  // RecvTensor requests instead of waiting forever.
  int64 request_id = 7;
}

message RecvTensorResponse {
  // The tensor as a proto.
  TensorProto tensor = 1;

  // If true, this tensor was the output of a dead node, and the
  // content is invalid.
  bool is_dead = 2;

  // The time at which tensor was available and started to be returned.
  int64 send_start_micros = 3;

  // Optional additional information about how to receive the tensor,
  // e.g. in the event that RecvTensorRequest.dma_ok was true.
  google.protobuf.Any transport_options = 4;

  // Whether the receiver should send a MarkRecvFinishedRequest to the sender
  // to ack the message.
  bool require_ack = 5;
}
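
The producer/consumer handshake behind RecvTensor can be sketched as a table keyed by (step_id, rendezvous_key): whichever side arrives first is parked until the other shows up. MiniRendezvous below is a hypothetical, single-threaded simplification written for this article. A string stands in for the tensor, there is no locking, and at most one value may be parked per channel at a time.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical single-threaded rendezvous table. A channel is identified by
// (step_id, rendezvous_key); the payload is a string standing in for a tensor.
class MiniRendezvous {
 public:
  using Key = std::pair<int64_t, std::string>;
  using DoneCallback = std::function<void(const std::string& tensor)>;

  // Producer side: hand the value to a waiting consumer, or park it.
  void Send(int64_t step_id, const std::string& key, std::string tensor) {
    Key k(step_id, key);
    auto it = waiting_.find(k);
    if (it != waiting_.end()) {
      DoneCallback done = std::move(it->second);
      waiting_.erase(it);
      done(tensor);
      return;
    }
    assert(ready_.count(k) == 0);  // one parked value per channel here
    ready_[k] = std::move(tensor);
  }

  // Consumer side: roughly what serving a RecvTensor request involves. Runs
  // `done` immediately if the value is ready, otherwise parks the callback.
  void RecvAsync(int64_t step_id, const std::string& key, DoneCallback done) {
    Key k(step_id, key);
    auto it = ready_.find(k);
    if (it != ready_.end()) {
      std::string tensor = std::move(it->second);
      ready_.erase(it);
      done(tensor);
      return;
    }
    waiting_[k] = std::move(done);
  }

 private:
  std::map<Key, std::string> ready_;
  std::map<Key, DoneCallback> waiting_;
};
```

Including step_id in the key is what keeps a tensor from one run from being delivered to a receiver that belongs to another run.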

4. Worker

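The Worker class mainly holds a WorkerEnv and a PartialRunMgr. It can be subclassed to provide specialized implementations of particular methods for different transport mechanisms. For example, GrpcWorker specializes the RecvTensorAsync method to support a more efficient gRPC data structure for handling large binary data.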

class Worker : public WorkerInterface {
 protected:
  WorkerEnv* const env_;  // Not owned.
  RecentRequestIds recent_request_ids_;

 private:
  PartialRunMgr partial_run_mgr_;

  CancellationManager cancellation_manager_;

  TF_DISALLOW_COPY_AND_ASSIGN(Worker);
};

Here is one method as an example; the others are discussed later as they come up.

void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(Status::OK());
}

5. GrpcWorker

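GrpcWorker is the remote-side Worker that a GrpcRemoteWorker talks to. It is also the object that GrpcWorkerService calls into, and it implements the business logic. Its definition is shown below; it adds GrpcRecvTensorAsync and overrides a few Worker methods.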

class GrpcWorker : public Worker {
 public:
  GrpcWorker(WorkerEnv* env, const ConfigProto& config);

  // Specialized version of RecvTensor for gRPC, which avoids a copy.
  virtual void GrpcRecvTensorAsync(CallOptions* opts,
                                   const RecvTensorRequest* request,
                                   ::grpc::ByteBuffer* response,
                                   StatusCallback done);

  void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
                    StatusCallback done) override;

  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                    RecvBufResponse* response, StatusCallback done) override;

  void CleanupGraphAsync(const CleanupGraphRequest* request,
                         CleanupGraphResponse* response,
                         StatusCallback done) override;

  WorkerEnv* env();

  void EnableResponseCache();

  void RemoveCacheEntryForId(int64 request_id);

 private:
  std::unique_ptr<GrpcResponseCache> response_cache_;
  const int32 recv_buf_max_chunk_;
};
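
The relationship between Worker and GrpcWorker is ordinary virtual dispatch over callback-style methods. The sketch below illustrates only the pattern. MiniWorker and MiniGrpcWorker are hypothetical classes, a string stands in for Status (empty meaning OK), and the "transport cleanup" counter is an invented placeholder, not what GrpcWorker::CleanupGraphAsync actually does.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in: a status is a string, and "" means OK.
using StatusCallback = std::function<void(const std::string& status)>;

class MiniWorker {
 public:
  virtual ~MiniWorker() = default;

  // Transport-independent default implementation.
  virtual void CleanupGraphAsync(int64_t step_id, StatusCallback done) {
    cleaned_steps_.push_back(step_id);
    done("");
  }

  const std::vector<int64_t>& cleaned_steps() const { return cleaned_steps_; }

 private:
  std::vector<int64_t> cleaned_steps_;
};

// Transport-specific subclass: adds its own bookkeeping, then reuses the
// base-class behaviour instead of duplicating it.
class MiniGrpcWorker : public MiniWorker {
 public:
  void CleanupGraphAsync(int64_t step_id, StatusCallback done) override {
    assert(done != nullptr);
    ++transport_cleanups_;  // placeholder for transport-level cleanup
    MiniWorker::CleanupGraphAsync(step_id, std::move(done));
  }

  int transport_cleanups() const { return transport_cleanups_; }

 private:
  int transport_cleanups_ = 0;
};
```

Callers such as the service thread only hold a base-class pointer, so swapping the transport does not change the handler code.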

This concludes the static structure of the Worker. Its runtime behavior is covered in the later articles on Session.
