4. Spark Source Code Analysis (YARN Cluster Mode) - YARN Container Startup - CoarseGrainedExecutorBackend

This series is based on spark-2.4.6.

From the analysis in the previous section, we know that the class launched inside the YARN container is org.apache.spark.executor.CoarseGrainedExecutorBackend. Let's now look at its implementation.
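For context, the container launch command assembled on the YARN side invokes this class with the executor's parameters as command-line flags. The values below are purely illustrative; only the flag names are taken from the argument parser shown next:

java -server -Xmx2048m org.apache.spark.executor.CoarseGrainedExecutorBackend \
  --driver-url spark://CoarseGrainedScheduler@driver-host:43210 \
  --executor-id 1 \
  --hostname nm-host-01 \
  --cores 2 \
  --app-id application_1600000000000_0001 \
  --user-class-path file:$PWD/__app__.jar

The main method parses exactly these flags: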

def main(args: Array[String]) {
    var driverUrl: String = null
    var executorId: String = null
    var hostname: String = null
    var cores: Int = 0
    var appId: String = null
    var workerUrl: Option[String] = None
    val userClassPath = new mutable.ListBuffer[URL]()

    var argv = args.toList
    while (!argv.isEmpty) {
      argv match {
        case ("--driver-url") :: value :: tail =>
          driverUrl = value
          argv = tail
        case ("--executor-id") :: value :: tail =>
          executorId = value
          argv = tail
        case ("--hostname") :: value :: tail =>
          hostname = value
          argv = tail
        case ("--cores") :: value :: tail =>
          cores = value.toInt
          argv = tail
        case ("--app-id") :: value :: tail =>
          appId = value
          argv = tail
        case ("--worker-url") :: value :: tail =>
          // Worker url is used in spark standalone mode to enforce fate-sharing with worker
          workerUrl = Some(value)
          argv = tail
        case ("--user-class-path") :: value :: tail =>
          userClassPath += new URL(value)
          argv = tail
        case Nil =>
        case tail =>
          // scalastyle:off println
          System.err.println(s"Unrecognized options: ${tail.mkString(" ")}")
          // scalastyle:on println
          printUsageAndExit()
      }
    }

    if (hostname == null) {
      hostname = Utils.localHostName()
      log.info(s"Executor hostname is not provided, will use '$hostname' to advertise itself")
    }

    if (driverUrl == null || executorId == null || cores <= 0 || appId == null) {
      printUsageAndExit()
    }

    run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath)
    System.exit(0)
  }

Once main has parsed and validated these arguments, it calls the run method directly:


  private def run(
      driverUrl: String,
      executorId: String,
      hostname: String,
      cores: Int,
      appId: String,
      workerUrl: Option[String],
      userClassPath: Seq[URL]) {

    Utils.initDaemon(log)

    SparkHadoopUtil.get.runAsSparkUser { () =>
      // Debug code
      Utils.checkHost(hostname)

      // Bootstrap to fetch the driver's Spark properties.
      val executorConf = new SparkConf
      val fetcher = RpcEnv.create(
        "driverPropsFetcher",
        hostname,
        -1,
        executorConf,
        new SecurityManager(executorConf),
        clientMode = true)
      val driver = fetcher.setupEndpointRefByURI(driverUrl)
      val cfg = driver.askSync[SparkAppConfig](RetrieveSparkAppConfig)
      val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId))
      fetcher.shutdown()

      // Create SparkEnv using properties we fetched from the driver.
      val driverConf = new SparkConf()
      for ((key, value) <- props) {
        // this is required for SSL in standalone mode
        if (SparkConf.isExecutorStartupConf(key)) {
          driverConf.setIfMissing(key, value)
        } else {
          driverConf.set(key, value)
        }
      }

      cfg.hadoopDelegationCreds.foreach { tokens =>
        SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf)
      }

      val env = SparkEnv.createExecutorEnv(
        driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = false)

      env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
        env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))
      workerUrl.foreach { url =>
        env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
      }
      env.rpcEnv.awaitTermination()
    }
  }

Here the executor first establishes communication with the Driver: it creates a temporary RPC environment ("driverPropsFetcher"), obtains an RPC endpoint reference to the Driver, and sends it a RetrieveSparkAppConfig message to fetch the Spark configuration. It then creates the SparkEnv, which sets up the MemoryManager, ShuffleManager, and so on.
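For reference, the request and reply exchanged in this bootstrap step are defined in CoarseGrainedClusterMessages; in 2.4.x they look roughly like this (simplified):

case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage

case class SparkAppConfig(
    sparkProperties: Seq[(String, String)],
    ioEncryptionKey: Option[Array[Byte]],
    hadoopDelegationCreds: Option[Array[Byte]])
  extends CoarseGrainedClusterMessage

The fields match what run reads from cfg above. The SparkEnv itself is built via SparkEnv.createExecutorEnv: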

private[spark] def createExecutorEnv(
      conf: SparkConf,
      executorId: String,
      hostname: String,
      numCores: Int,
      ioEncryptionKey: Option[Array[Byte]],
      isLocal: Boolean): SparkEnv = {
    val env = create(
      conf,
      executorId,
      hostname,
      hostname,
      None,
      isLocal,
      numCores,
      ioEncryptionKey
    )
    SparkEnv.set(env)
    env
  }

In the run method, the Executor RPC endpoint is registered, along with a WorkerWatcher for monitoring (only when a worker URL is provided, i.e. in standalone mode):

 env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
        env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))
      workerUrl.foreach { url =>
        env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
      }

setupEndpoint is implemented by NettyRpcEnv:

 override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
    dispatcher.registerRpcEndpoint(name, endpoint)
  }

The dispatcher here is responsible for routing the messages this node receives to the right endpoint; it is implemented by the Dispatcher class:

private val threadpool: ThreadPoolExecutor = {
    val availableCores =
      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
    val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
      math.max(2, availableCores))
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }

  private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            val data = receivers.take()
            if (data == PoisonPill) {
              receivers.offer(PoisonPill)
              return
            }
            data.inbox.process(Dispatcher.this)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case _: InterruptedException => // exit
        case t: Throwable =>
          try {
            threadpool.execute(new MessageLoop)
          } finally {
            throw t
          }
      }
    }
  }

The actual message handling happens in the Inbox:

def process(dispatcher: Dispatcher): Unit = {
    var message: InboxMessage = null
    inbox.synchronized {
      if (!enableConcurrent && numActiveThreads != 0) {
        return
      }
      message = messages.poll()
      if (message != null) {
        numActiveThreads += 1
      } else {
        return
      }
    }
    while (true) {
      safelyCall(endpoint) {
        message match {
          case RpcMessage(_sender, content, context) =>
            try {
              endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
                throw new SparkException(s"Unsupported message $message from ${_sender}")
              })
            } catch {
              case e: Throwable =>
                context.sendFailure(e)
                // Throw the exception -- it will be caught by safelyCall and passed to onError.
                throw e
            }

          case OneWayMessage(_sender, content) =>
            endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })

          case OnStart =>
            endpoint.onStart()
            if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
              inbox.synchronized {
                if (!stopped) {
                  enableConcurrent = true
                }
              }
            }

          case OnStop =>
            val activeThreads = inbox.synchronized { inbox.numActiveThreads }
            assert(activeThreads == 1,
              s"There should be only a single active thread but found $activeThreads threads.")
            dispatcher.removeRpcEndpointRef(endpoint)
            endpoint.onStop()
            assert(isEmpty, "OnStop should be the last message")

          case RemoteProcessConnected(remoteAddress) =>
            endpoint.onConnected(remoteAddress)

          case RemoteProcessDisconnected(remoteAddress) =>
            endpoint.onDisconnected(remoteAddress)

          case RemoteProcessConnectionError(cause, remoteAddress) =>
            endpoint.onNetworkError(cause, remoteAddress)
        }
      }

      inbox.synchronized {
        if (!enableConcurrent && numActiveThreads != 1) {
          numActiveThreads -= 1
          return
        }
        message = messages.poll()
        if (message == null) {
          numActiveThreads -= 1
          return
        }
      }
    }
  }

This is where the different message types are dispatched to the endpoint's callbacks.
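To make that dispatch concrete, here is a minimal, hypothetical endpoint (illustration only, not Spark code) showing the callbacks the Inbox invokes: receive for one-way send messages, receiveAndReply for ask messages, and onStart for the OnStart message:

// Hypothetical endpoint; Spark's RpcEndpoint trait is private[spark], so this is only a
// sketch of the contract the Inbox drives.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  // Invoked for OneWayMessage (messages sent via RpcEndpointRef.send)
  override def receive: PartialFunction[Any, Unit] = {
    case s: String => println(s"one-way: $s")
  }
  // Invoked for RpcMessage (messages sent via RpcEndpointRef.ask); the reply completes the caller's Future
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case s: String => context.reply(s.toUpperCase)
  }
  // Driven by the OnStart message enqueued when the endpoint is registered
  override def onStart(): Unit = println("endpoint started")
}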
Additionally, when the SparkEnv is created, an RpcEnv is created as well; in this version that is a NettyRpcEnv, built by NettyRpcEnvFactory as follows:

private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {

  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
    // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
    val javaSerializerInstance =
      new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    val nettyEnv =
      new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
        config.securityManager, config.numUsableCores)
    if (!config.clientMode) {
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
        nettyEnv.startServer(config.bindAddress, actualPort)
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      } catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
    nettyEnv
  }
}
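Where does clientMode come from for the executor's main RpcEnv? Roughly, SparkEnv.create builds it like this (a simplified sketch of the 2.4.x code; clientMode is the negation of isDriver, so it is true on executors):

// Simplified: executors pass isDriver = false, so the RpcEnv is created in client mode
val systemName = if (isDriver) driverSystemName else executorSystemName
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf,
  securityManager, numUsableCores, clientMode = !isDriver)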

So on the executor, clientMode = true and startServer is never called here. NettyRpcEnv also holds the following field:

  private val transportContext = new TransportContext(transportConf,
    new NettyRpcHandler(dispatcher, this, streamManager))

As we saw, the endpoints are registered at startup:

env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
        env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))
      workerUrl.foreach { url =>
        env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
      }

def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
    dispatcher.registerRpcEndpoint(name, endpoint)
  }

def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
    val addr = RpcEndpointAddress(nettyEnv.address, name)
    val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
    synchronized {
      if (stopped) {
        throw new IllegalStateException("RpcEnv has been stopped")
      }
      if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
        throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
      }
      val data = endpoints.get(name)
      endpointRefs.put(data.endpoint, data.ref)
      receivers.offer(data)  // for the OnStart message
    }
    endpointRef
  }

Notice that during registration an EndpointData (which wraps the endpoint and its Inbox) is offered to the receivers queue, so a dispatcher thread will pick it up.

private class EndpointData(
      val name: String,
      val endpoint: RpcEndpoint,
      val ref: NettyRpcEndpointRef) {
    val inbox = new Inbox(ref, endpoint)
  }

// In the Inbox constructor, OnStart is queued as the first message to process:
  inbox.synchronized {
    messages.add(OnStart)
  }

When the Inbox is constructed it enqueues an OnStart message, which the endpoint then processes itself. CoarseGrainedExecutorBackend handles OnStart by sending a RegisterExecutor message to the Driver to register this executor. Sending that message is also what starts the NettyRpcEnv client and connects it to the Driver's Netty server.
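In 2.4.x the onStart implementation looks roughly like this (slightly trimmed):

override def onStart() {
    logInfo("Connecting to driver: " + driverUrl)
    rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
      // This is a very fast action so we can use "ThreadUtils.sameThread"
      driver = Some(ref)
      ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls))
    }(ThreadUtils.sameThread).onComplete {
      case Success(_) => // the driver always replies true; nothing more to do here
      case Failure(e) =>
        exitExecutor(1, s"Cannot register with driver: $driverUrl", e, notifyDriver = false)
    }(ThreadUtils.sameThread)
  }

The ask on the driver endpoint reference is what goes through the Netty RPC layer: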

def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout)

override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
    nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
  }
  //-----------------------------------
private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
    val promise = Promise[Any]()
    val remoteAddr = message.receiver.address

    def onFailure(e: Throwable): Unit = {
      if (!promise.tryFailure(e)) {
        e match {
          case e : RpcEnvStoppedException => logDebug (s"Ignored failure: $e")
          case _ => logWarning(s"Ignored failure: $e")
        }
      }
    }

    def onSuccess(reply: Any): Unit = reply match {
      case RpcFailure(e) => onFailure(e)
      case rpcReply =>
        if (!promise.trySuccess(rpcReply)) {
          logWarning(s"Ignored message: $reply")
        }
    }

    try {
      if (remoteAddr == address) {
        val p = Promise[Any]()
        p.future.onComplete {
          case Success(response) => onSuccess(response)
          case Failure(e) => onFailure(e)
        }(ThreadUtils.sameThread)
        dispatcher.postLocalMessage(message, p)
      } else {
        val rpcMessage = RpcOutboxMessage(message.serialize(this),
          onFailure,
          (client, response) => onSuccess(deserialize[Any](client, response)))
        postToOutbox(message.receiver, rpcMessage)
        promise.future.failed.foreach {
          case _: TimeoutException => rpcMessage.onTimeout()
          case _ =>
        }(ThreadUtils.sameThread)
      }

      val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
        override def run(): Unit = {
          onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " +
            s"in ${timeout.duration}"))
        }
      }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
      promise.future.onComplete { v =>
        timeoutCancelable.cancel(true)
      }(ThreadUtils.sameThread)
    } catch {
      case NonFatal(e) =>
        onFailure(e)
    }
    promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
  }
//-----------------------
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
    if (receiver.client != null) {
      message.sendWith(receiver.client)
    } else {
      require(receiver.address != null,
        "Cannot send message to client endpoint with no listen address.")
      val targetOutbox = {
        val outbox = outboxes.get(receiver.address)
        if (outbox == null) {
          val newOutbox = new Outbox(this, receiver.address)
          val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox)
          if (oldOutbox == null) {
            newOutbox
          } else {
            oldOutbox
          }
        } else {
          outbox
        }
      }
      if (stopped.get) {
        // It's possible that we put `targetOutbox` after stopping. So we need to clean it.
        outboxes.remove(receiver.address)
        targetOutbox.stop()
      } else {
        targetOutbox.send(message)
      }
    }
  }
  //------------------------------
  def send(message: OutboxMessage): Unit = {
    val dropped = synchronized {
      if (stopped) {
        true
      } else {
        messages.add(message)
        false
      }
    }
    if (dropped) {
      message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
    } else {
      drainOutbox()
    }
  }
  //----------------------------
  private def drainOutbox(): Unit = {
    var message: OutboxMessage = null
    synchronized {
      if (stopped) {
        return
      }
      if (connectFuture != null) {
        // We are connecting to the remote address, so just exit
        return
      }
      if (client == null) {
        // There is no connect task but client is null, so we need to launch the connect task.
        launchConnectTask()
        return
      }
      if (draining) {
        // There is some thread draining, so just exit
        return
      }
      message = messages.poll()
      if (message == null) {
        return
      }
      draining = true
    }
    while (true) {
      try {
        val _client = synchronized { client }
        if (_client != null) {
          message.sendWith(_client)
        } else {
          assert(stopped == true)
        }
      } catch {
        case NonFatal(e) =>
          handleNetworkFailure(e)
          return
      }
      synchronized {
        if (stopped) {
          return
        }
        message = messages.poll()
        if (message == null) {
          draining = false
          return
        }
      }
    }
  }
//-----------------
private def launchConnectTask(): Unit = {
    connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] {

      override def call(): Unit = {
        try {
          val _client = nettyEnv.createClient(address)
          outbox.synchronized {
            client = _client
            if (stopped) {
              closeClient()
            }
          }
        } catch {
          case ie: InterruptedException =>
            // exit
            return
          case NonFatal(e) =>
            outbox.synchronized { connectFuture = null }
            handleNetworkFailure(e)
            return
        }
        outbox.synchronized { connectFuture = null }
        // It's possible that no thread is draining now. If we don't drain here, we cannot send the
        // messages until the next message arrives.
        drainOutbox()
      }
    })
  }
//-------------
 private[netty] def createClient(address: RpcAddress): TransportClient = {
    clientFactory.createClient(address.host, address.port)
  }
//-------------
public TransportClient createClient(String remoteHost, int remotePort)
      throws IOException, InterruptedException {
    // Get connection from the connection pool first.
    // If it is not found or not active, create a new one.
    // Use unresolved address here to avoid DNS resolution each time we creates a client.
    final InetSocketAddress unresolvedAddress =
      InetSocketAddress.createUnresolved(remoteHost, remotePort);

    // Create the ClientPool if we don't have it yet.
    ClientPool clientPool = connectionPool.get(unresolvedAddress);
    if (clientPool == null) {
      connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
      clientPool = connectionPool.get(unresolvedAddress);
    }

    int clientIndex = rand.nextInt(numConnectionsPerPeer);
    TransportClient cachedClient = clientPool.clients[clientIndex];

    if (cachedClient != null && cachedClient.isActive()) {
      // Make sure that the channel will not timeout by updating the last use time of the
      // handler. Then check that the client is still alive, in case it timed out before
      // this code was able to update things.
      TransportChannelHandler handler = cachedClient.getChannel().pipeline()
        .get(TransportChannelHandler.class);
      synchronized (handler) {
        handler.getResponseHandler().updateTimeOfLastRequest();
      }

      if (cachedClient.isActive()) {
        logger.trace("Returning cached connection to {}: {}",
          cachedClient.getSocketAddress(), cachedClient);
        return cachedClient;
      }
    }

    // If we reach here, we don't have an existing connection open. Let's create a new one.
    // Multiple threads might race here to create new connections. Keep only one of them active.
    final long preResolveHost = System.nanoTime();
    final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
    final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
    if (hostResolveTimeMs > 2000) {
      logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
    } else {
      logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
    }

    synchronized (clientPool.locks[clientIndex]) {
      cachedClient = clientPool.clients[clientIndex];

      if (cachedClient != null) {
        if (cachedClient.isActive()) {
          logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
          return cachedClient;
        } else {
          logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
        }
      }
      clientPool.clients[clientIndex] = createClient(resolvedAddress);
      return clientPool.clients[clientIndex];
    }
  }
  public TransportClient createUnmanagedClient(String remoteHost, int remotePort)
      throws IOException, InterruptedException {
    final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
    return createClient(address);
  }

  /** Create a completely new {@link TransportClient} to the remote address. */
  private TransportClient createClient(InetSocketAddress address)
      throws IOException, InterruptedException {
    logger.debug("Creating new connection to {}", address);

    Bootstrap bootstrap = new Bootstrap();
    bootstrap.group(workerGroup)
      .channel(socketChannelClass)
      // Disable Nagle's Algorithm since we don't want packets to wait
      .option(ChannelOption.TCP_NODELAY, true)
      .option(ChannelOption.SO_KEEPALIVE, true)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
      .option(ChannelOption.ALLOCATOR, pooledAllocator);

    if (conf.receiveBuf() > 0) {
      bootstrap.option(ChannelOption.SO_RCVBUF, conf.receiveBuf());
    }

    if (conf.sendBuf() > 0) {
      bootstrap.option(ChannelOption.SO_SNDBUF, conf.sendBuf());
    }

    final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
    final AtomicReference<Channel> channelRef = new AtomicReference<>();

    bootstrap.handler(new ChannelInitializer<SocketChannel>() {
      @Override
      public void initChannel(SocketChannel ch) {
        TransportChannelHandler clientHandler = context.initializePipeline(ch);
        clientRef.set(clientHandler.getClient());
        channelRef.set(ch);
      }
    });

    // Connect to the remote server
    long preConnect = System.nanoTime();
    ChannelFuture cf = bootstrap.connect(address);
    if (!cf.await(conf.connectionTimeoutMs())) {
      throw new IOException(
        String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
    } else if (cf.cause() != null) {
      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
    }

    TransportClient client = clientRef.get();
    Channel channel = channelRef.get();
    assert client != null : "Channel future completed successfully with null client";

    // Execute any client bootstraps synchronously before marking the Client as successful.
    long preBootstrap = System.nanoTime();
    logger.debug("Connection to {} successful, running bootstraps...", address);
    try {
      for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
        clientBootstrap.doBootstrap(client, channel);
      }
    } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
      long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
      logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
      client.close();
      throw Throwables.propagate(e);
    }
    long postBootstrap = System.nanoTime();

    logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
      address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);

    return client;
  }

As we saw in the earlier Netty analysis, this is where the Netty connection is actually established. Note also that TransportClientFactory keeps a pool of connections per remote peer, so the number of channels between the Executor and the Driver can be controlled via io.numConnectionsPerPeer; the default is 1.

That is as far as we will go on how the Executor starts up; next we will analyze the startup of SparkContext.

Reposted from blog.csdn.net/LeoHan163/article/details/120871024