View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *     http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  
19  package org.apache.giraph.comm.netty;
20  
21  import org.apache.giraph.comm.flow_control.CreditBasedFlowControl;
22  import org.apache.giraph.comm.flow_control.FlowControl;
23  import org.apache.giraph.comm.flow_control.NoOpFlowControl;
24  import org.apache.giraph.comm.flow_control.StaticFlowControl;
25  import org.apache.giraph.comm.netty.handler.AckSignalFlag;
26  import org.apache.giraph.comm.netty.handler.TaskRequestIdGenerator;
27  import org.apache.giraph.comm.netty.handler.ClientRequestId;
28  import org.apache.giraph.comm.netty.handler.RequestEncoder;
29  import org.apache.giraph.comm.netty.handler.RequestInfo;
30  import org.apache.giraph.comm.netty.handler.RequestServerHandler;
31  import org.apache.giraph.comm.netty.handler.ResponseClientHandler;
32  /*if_not[HADOOP_NON_SECURE]*/
33  import org.apache.giraph.comm.netty.handler.SaslClientHandler;
34  import org.apache.giraph.comm.requests.RequestType;
35  import org.apache.giraph.comm.requests.SaslTokenMessageRequest;
36  /*end[HADOOP_NON_SECURE]*/
37  import org.apache.giraph.comm.requests.WritableRequest;
38  import org.apache.giraph.conf.BooleanConfOption;
39  import org.apache.giraph.conf.GiraphConstants;
40  import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
41  import org.apache.giraph.function.Predicate;
42  import org.apache.giraph.graph.TaskInfo;
43  import org.apache.giraph.master.MasterInfo;
44  import org.apache.giraph.utils.PipelineUtils;
45  import org.apache.giraph.utils.ProgressableUtils;
46  import org.apache.giraph.utils.ThreadUtils;
47  import org.apache.giraph.utils.TimedLogger;
48  import org.apache.hadoop.mapreduce.Mapper;
49  import org.apache.log4j.Logger;
50  
51  import com.google.common.collect.Lists;
52  import com.google.common.collect.MapMaker;
53  import com.google.common.collect.Maps;
54  
55  /*if_not[HADOOP_NON_SECURE]*/
56  import java.io.IOException;
57  /*end[HADOOP_NON_SECURE]*/
58  import java.net.InetSocketAddress;
59  import java.util.Collection;
60  import java.util.Collections;
61  import java.util.Comparator;
62  import java.util.List;
63  import java.util.Map;
64  import java.util.concurrent.ConcurrentMap;
65  import java.util.concurrent.atomic.AtomicInteger;
66  import java.util.concurrent.atomic.AtomicLong;
67  
68  import io.netty.bootstrap.Bootstrap;
69  import io.netty.channel.Channel;
70  import io.netty.channel.ChannelFuture;
71  import io.netty.channel.ChannelFutureListener;
72  import io.netty.channel.ChannelHandlerContext;
73  import io.netty.channel.ChannelInitializer;
74  import io.netty.channel.ChannelOption;
75  import io.netty.channel.EventLoopGroup;
76  import io.netty.channel.nio.NioEventLoopGroup;
77  import io.netty.channel.socket.SocketChannel;
78  import io.netty.channel.socket.nio.NioSocketChannel;
79  import io.netty.handler.codec.FixedLengthFrameDecoder;
80  /*if_not[HADOOP_NON_SECURE]*/
81  import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
82  import io.netty.util.AttributeKey;
83  /*end[HADOOP_NON_SECURE]*/
84  import io.netty.util.concurrent.DefaultEventExecutorGroup;
85  import io.netty.util.concurrent.EventExecutorGroup;
86  
87  import static com.google.common.base.Preconditions.checkState;
88  import static org.apache.giraph.conf.GiraphConstants.CLIENT_RECEIVE_BUFFER_SIZE;
89  import static org.apache.giraph.conf.GiraphConstants.CLIENT_SEND_BUFFER_SIZE;
90  import static org.apache.giraph.conf.GiraphConstants.MAX_REQUEST_MILLISECONDS;
91  import static org.apache.giraph.conf.GiraphConstants.MAX_RESOLVE_ADDRESS_ATTEMPTS;
92  import static org.apache.giraph.conf.GiraphConstants.NETTY_CLIENT_EXECUTION_AFTER_HANDLER;
93  import static org.apache.giraph.conf.GiraphConstants.NETTY_CLIENT_EXECUTION_THREADS;
94  import static org.apache.giraph.conf.GiraphConstants.NETTY_CLIENT_USE_EXECUTION_HANDLER;
95  import static org.apache.giraph.conf.GiraphConstants.NETTY_MAX_CONNECTION_FAILURES;
96  import static org.apache.giraph.conf.GiraphConstants.WAIT_TIME_BETWEEN_CONNECTION_RETRIES_MS;
97  import static org.apache.giraph.conf.GiraphConstants.WAITING_REQUEST_MSECS;
98  
/**
 * Netty client for sending requests.  Thread-safe.
 */
public class NettyClient {
  /** Do we have a limit on number of open requests we can have */
  public static final BooleanConfOption LIMIT_NUMBER_OF_OPEN_REQUESTS =
      new BooleanConfOption("giraph.waitForRequestsConfirmation", false,
          "Whether to have a limit on number of open requests or not");
  /**
   * Do we have a limit on number of open requests we can have for each worker.
   * Note that if this option is enabled, Netty will not keep more than a
   * certain number of requests open for each other worker in the job. If there
   * are more requests generated for a worker, Netty will not actually send the
   * surplus requests, instead, it caches the requests in a local buffer. The
   * maximum number of these unsent requests in the cache is another
   * user-defined parameter (MAX_NUM_OF_OPEN_REQUESTS_PER_WORKER).
   */
  // Mutually exclusive with LIMIT_NUMBER_OF_OPEN_REQUESTS (see constructor).
  public static final BooleanConfOption LIMIT_OPEN_REQUESTS_PER_WORKER =
      new BooleanConfOption("giraph.waitForPerWorkerRequests", false,
          // NOTE(review): description reads "...each workeror not" at runtime;
          // a space seems to be missing between "worker" and "or" — confirm
          // before changing, since this string is user-visible.
          "Whether to have a limit on number of open requests for each worker" +
              "or not");
  /** Maximum number of requests to list (for debugging) */
  public static final int MAX_REQUESTS_TO_LIST = 10;
  /**
   * Maximum number of destination task ids with open requests to list
   * (for debugging)
   */
  public static final int MAX_DESTINATION_TASK_IDS_TO_LIST = 10;
  /** 30 seconds to connect by default */
  public static final int MAX_CONNECTION_MILLISECONDS_DEFAULT = 30 * 1000;
/*if_not[HADOOP_NON_SECURE]*/
  /**
   * Used to authenticate with other workers acting as servers.
   * Stored as a per-channel Netty attribute; see authenticateOnChannel().
   */
  public static final AttributeKey<SaslNettyClient> SASL =
      AttributeKey.valueOf("saslNettyClient");
/*end[HADOOP_NON_SECURE]*/
  /** Class logger */
  private static final Logger LOG = Logger.getLogger(NettyClient.class);
  /** Context used to report progress */
  private final Mapper<?, ?, ?, ?>.Context context;
  /** Client bootstrap */
  private final Bootstrap bootstrap;
  /**
   * Map of the peer connections, mapping from remote socket address to client
   * meta data
   */
  private final ConcurrentMap<InetSocketAddress, ChannelRotater>
  addressChannelMap = new MapMaker().makeMap();
  /**
   * Map from task id to address of its server
   */
  private final Map<Integer, InetSocketAddress> taskIdAddressMap =
      new MapMaker().makeMap();
  /**
   * Request map of client request ids to request information.
   */
  private final ConcurrentMap<ClientRequestId, RequestInfo>
  clientRequestIdRequestInfoMap;
  /** Number of channels per server */
  private final int channelsPerServer;
  /** Inbound byte counter for this client */
  private final InboundByteCounter inboundByteCounter = new
      InboundByteCounter();
  /** Outbound byte counter for this client */
  private final OutboundByteCounter outboundByteCounter = new
      OutboundByteCounter();
  /** Send buffer size */
  private final int sendBufferSize;
  /** Receive buffer size */
  private final int receiveBufferSize;
  /** Warn if request size is bigger than the buffer size by this factor */
  private final float requestSizeWarningThreshold;
  /** Maximum number of connection failures */
  private final int maxConnectionFailures;
  /** How long to wait before trying to reconnect failed connections */
  private final long waitTimeBetweenConnectionRetriesMs;
  /** Maximum number of milliseconds for a request */
  private final int maxRequestMilliseconds;
  /** Waiting interval for checking outstanding requests msecs */
  private final int waitingRequestMsecs;
  /** Timed logger for printing request debugging */
  private final TimedLogger requestLogger;
  /** Worker executor group */
  private final EventLoopGroup workerGroup;
  /** Task request id generator */
  private final TaskRequestIdGenerator taskRequestIdGenerator =
      new TaskRequestIdGenerator();
  /** Task info */
  private final TaskInfo myTaskInfo;
  /** Maximum thread pool size */
  private final int maxPoolSize;
  /** Maximum number of attempts to resolve an address */
  private final int maxResolveAddressAttempts;
  /** Use execution handler? */
  private final boolean useExecutionGroup;
  /** EventExecutor Group (if used) */
  private final EventExecutorGroup executionGroup;
  /** Name of the handler to use execution group for (if used) */
  private final String handlerToUseExecutionGroup;
  /** When was the last time we checked if we should resend some requests */
  private final AtomicLong lastTimeCheckedRequestsForProblems =
      new AtomicLong(0);
  /**
   * Logger used to dump stack traces for every exception that happens
   * in netty client threads.
   */
  private final LogOnErrorChannelFutureListener logErrorListener =
      new LogOnErrorChannelFutureListener();
  /** Flow control policy used */
  private final FlowControl flowControl;
208 
  /**
   * Only constructor
   *
   * Reads all tuning knobs from the configuration, selects a flow control
   * policy, builds the Netty {@link Bootstrap} (including the full channel
   * pipeline, with and without SASL authentication), and starts a background
   * thread that periodically checks open requests for problems.
   *
   * @param context Context for progress
   * @param conf Configuration
   * @param myTaskInfo Current task info
   * @param exceptionHandler handler for uncaught exception. Will
   *                         terminate job.
   */
  public NettyClient(Mapper<?, ?, ?, ?>.Context context,
                     final ImmutableClassesGiraphConfiguration conf,
                     TaskInfo myTaskInfo,
                     final Thread.UncaughtExceptionHandler exceptionHandler) {
    this.context = context;
    this.myTaskInfo = myTaskInfo;
    this.channelsPerServer = GiraphConstants.CHANNELS_PER_SERVER.get(conf);
    sendBufferSize = CLIENT_SEND_BUFFER_SIZE.get(conf);
    receiveBufferSize = CLIENT_RECEIVE_BUFFER_SIZE.get(conf);
    this.requestSizeWarningThreshold =
        GiraphConstants.REQUEST_SIZE_WARNING_THRESHOLD.get(conf);

    // The two open-request limiting schemes are mutually exclusive:
    // a global cap (StaticFlowControl) vs. a per-worker credit scheme
    // (CreditBasedFlowControl). Neither enabled => NoOpFlowControl.
    boolean limitNumberOfOpenRequests = LIMIT_NUMBER_OF_OPEN_REQUESTS.get(conf);
    boolean limitOpenRequestsPerWorker =
        LIMIT_OPEN_REQUESTS_PER_WORKER.get(conf);
    checkState(!limitNumberOfOpenRequests || !limitOpenRequestsPerWorker,
        "NettyClient: it is not allowed to have both limitations on the " +
            "number of total open requests, and on the number of open " +
            "requests per worker!");
    if (limitNumberOfOpenRequests) {
      flowControl = new StaticFlowControl(conf, this);
    } else if (limitOpenRequestsPerWorker) {
      flowControl = new CreditBasedFlowControl(conf, this, exceptionHandler);
    } else {
      flowControl = new NoOpFlowControl(this);
    }

    maxRequestMilliseconds = MAX_REQUEST_MILLISECONDS.get(conf);
    maxConnectionFailures = NETTY_MAX_CONNECTION_FAILURES.get(conf);
    waitTimeBetweenConnectionRetriesMs =
        WAIT_TIME_BETWEEN_CONNECTION_RETRIES_MS.get(conf);
    waitingRequestMsecs = WAITING_REQUEST_MSECS.get(conf);
    requestLogger = new TimedLogger(waitingRequestMsecs, LOG);
    maxPoolSize = GiraphConstants.NETTY_CLIENT_THREADS.get(conf);
    maxResolveAddressAttempts = MAX_RESOLVE_ADDRESS_ATTEMPTS.get(conf);

    // Concurrency level sized to the client thread pool since those threads
    // are the writers of this map.
    clientRequestIdRequestInfoMap =
        new MapMaker().concurrencyLevel(maxPoolSize).makeMap();

    // Optionally offload pipeline handlers (after a named handler) to a
    // separate executor group, so heavy work leaves the I/O event loop.
    handlerToUseExecutionGroup =
        NETTY_CLIENT_EXECUTION_AFTER_HANDLER.get(conf);
    useExecutionGroup = NETTY_CLIENT_USE_EXECUTION_HANDLER.get(conf);
    if (useExecutionGroup) {
      int executionThreads = NETTY_CLIENT_EXECUTION_THREADS.get(conf);
      executionGroup = new DefaultEventExecutorGroup(executionThreads,
          ThreadUtils.createThreadFactory(
              "netty-client-exec-%d", exceptionHandler));
      if (LOG.isInfoEnabled()) {
        LOG.info("NettyClient: Using execution handler with " +
            executionThreads + " threads after " +
            handlerToUseExecutionGroup + ".");
      }
    } else {
      executionGroup = null;
    }

    workerGroup = new NioEventLoopGroup(maxPoolSize,
        ThreadUtils.createThreadFactory(
            "netty-client-worker-%d", exceptionHandler));

    bootstrap = new Bootstrap();
    bootstrap.group(workerGroup)
        .channel(NioSocketChannel.class)
        .option(ChannelOption.CONNECT_TIMEOUT_MILLIS,
            MAX_CONNECTION_MILLISECONDS_DEFAULT)
        .option(ChannelOption.TCP_NODELAY, true)
        .option(ChannelOption.SO_KEEPALIVE, true)
        .option(ChannelOption.SO_SNDBUF, sendBufferSize)
        .option(ChannelOption.SO_RCVBUF, receiveBufferSize)
        .option(ChannelOption.ALLOCATOR, conf.getNettyAllocator())
        .handler(new ChannelInitializer<SocketChannel>() {
          @Override
          protected void initChannel(SocketChannel ch) throws Exception {
/*if_not[HADOOP_NON_SECURE]*/
            if (conf.authenticate()) {
              LOG.info("Using Netty with authentication.");

              // Our pipeline starts with just byteCounter, and then we use
              // addLast() to incrementally add pipeline elements, so that we
              // can name them for identification for removal or replacement
              // after client is authenticated by server.
              // After authentication is complete, the pipeline's SASL-specific
              // functionality is removed, restoring the pipeline to exactly the
              // same configuration as it would be without authentication.
              PipelineUtils.addLastWithExecutorCheck("clientInboundByteCounter",
                  inboundByteCounter, handlerToUseExecutionGroup,
                  executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionDecoder",
                    conf.getNettyCompressionDecoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              PipelineUtils.addLastWithExecutorCheck(
                  "clientOutboundByteCounter",
                  outboundByteCounter, handlerToUseExecutionGroup,
                  executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionEncoder",
                    conf.getNettyCompressionEncoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              // The following pipeline component is needed to decode the
              // server's SASL tokens. It is replaced with a
              // FixedLengthFrameDecoder (same as used with the
              // non-authenticated pipeline) after authentication
              // completes (as in non-auth pipeline below).
              PipelineUtils.addLastWithExecutorCheck(
                  "length-field-based-frame-decoder",
                  new LengthFieldBasedFrameDecoder(1024, 0, 4, 0, 4),
                  handlerToUseExecutionGroup, executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("request-encoder",
                  new RequestEncoder(conf), handlerToUseExecutionGroup,
                  executionGroup, ch);
              // The following pipeline component responds to the server's SASL
              // tokens with its own responses. Both client and server share the
              // same Hadoop Job token, which is used to create the SASL
              // tokens to authenticate with each other.
              // After authentication finishes, this pipeline component
              // is removed.
              PipelineUtils.addLastWithExecutorCheck("sasl-client-handler",
                  new SaslClientHandler(conf), handlerToUseExecutionGroup,
                  executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("response-handler",
                  new ResponseClientHandler(NettyClient.this, conf),
                  handlerToUseExecutionGroup, executionGroup, ch);
            } else {
              LOG.info("Using Netty without authentication.");
/*end[HADOOP_NON_SECURE]*/
              // Non-authenticated pipeline: byte counters around optional
              // compression codecs, fixed-length response framing, request
              // encoding, and the response handler at the tail.
              PipelineUtils.addLastWithExecutorCheck("clientInboundByteCounter",
                  inboundByteCounter, handlerToUseExecutionGroup,
                  executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionDecoder",
                    conf.getNettyCompressionDecoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              PipelineUtils.addLastWithExecutorCheck(
                  "clientOutboundByteCounter",
                  outboundByteCounter, handlerToUseExecutionGroup,
                  executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionEncoder",
                    conf.getNettyCompressionEncoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              PipelineUtils.addLastWithExecutorCheck(
                  "fixed-length-frame-decoder",
                  new FixedLengthFrameDecoder(
                      RequestServerHandler.RESPONSE_BYTES),
                  handlerToUseExecutionGroup, executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("request-encoder",
                  new RequestEncoder(conf), handlerToUseExecutionGroup,
                  executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("response-handler",
                  new ResponseClientHandler(NettyClient.this, conf),
                  handlerToUseExecutionGroup, executionGroup, ch);

/*if_not[HADOOP_NON_SECURE]*/
            }
/*end[HADOOP_NON_SECURE]*/
          }

          // Invoked when a channel leaves its event loop (i.e. died);
          // resend bookkeeping happens in checkRequestsAfterChannelFailure.
          @Override
          public void channelUnregistered(ChannelHandlerContext ctx) throws
              Exception {
            super.channelUnregistered(ctx);
            LOG.error("Channel failed " + ctx.channel());
            checkRequestsAfterChannelFailure(ctx.channel());
          }
        });

    // Start a thread which will observe if there are any problems with open
    // requests
    ThreadUtils.startThread(new Runnable() {
      @Override
      public void run() {
        while (true) {
          ThreadUtils.trySleep(waitingRequestMsecs);
          checkRequestsForProblems();
        }
      }
    }, "open-requests-observer");
  }
401 
402   /**
403    * Whether master task is involved in the communication with a given client
404    *
405    * @param clientId id of the communication (on the end of the communication)
406    * @return true if master is on one end of the communication
407    */
408   public boolean masterInvolved(int clientId) {
409     return myTaskInfo.getTaskId() == MasterInfo.MASTER_TASK_ID ||
410         clientId == MasterInfo.MASTER_TASK_ID;
411   }
412 
413   /**
414    * Pair object for connectAllAddresses().
415    */
416   private static class ChannelFutureAddress {
417     /** Future object */
418     private final ChannelFuture future;
419     /** Address of the future */
420     private final InetSocketAddress address;
421     /** Task id */
422     private final Integer taskId;
423 
424     /**
425      * Constructor.
426      *
427      * @param future Immutable future
428      * @param address Immutable address
429      * @param taskId Immutable taskId
430      */
431     ChannelFutureAddress(
432         ChannelFuture future, InetSocketAddress address, Integer taskId) {
433       this.future = future;
434       this.address = address;
435       this.taskId = taskId;
436     }
437 
438     @Override
439     public String toString() {
440       return "(future=" + future + ",address=" + address + ",taskId=" +
441           taskId + ")";
442     }
443   }
444 
  /**
   * Connect to a collection of tasks servers
   *
   * Resolves each task's address (caching it in taskIdAddressMap), opens
   * channelsPerServer channels per new server, then waits for all connection
   * futures, retrying failed ones until they all succeed or
   * maxConnectionFailures total failures accumulate.
   *
   * @param tasks Tasks to connect to (if haven't already connected)
   */
  public void connectAllAddresses(Collection<? extends TaskInfo> tasks) {
    List<ChannelFutureAddress> waitingConnectionList =
        Lists.newArrayListWithCapacity(tasks.size() * channelsPerServer);
    for (TaskInfo taskInfo : tasks) {
      context.progress();
      int taskId = taskInfo.getTaskId();
      InetSocketAddress address = taskIdAddressMap.get(taskId);
      // (Re)resolve if we have no cached address or the task moved to a
      // different host/port (e.g. after a restart).
      if (address == null ||
          !address.getHostName().equals(taskInfo.getHostname()) ||
          address.getPort() != taskInfo.getPort()) {
        address = resolveAddress(maxResolveAddressAttempts,
            taskInfo.getHostOrIp(), taskInfo.getPort());
        taskIdAddressMap.put(taskId, address);
      }
      if (address == null || address.getHostName() == null ||
          address.getHostName().isEmpty()) {
        throw new IllegalStateException("connectAllAddresses: Null address " +
            "in addresses " + tasks);
      }
      if (address.isUnresolved()) {
        throw new IllegalStateException("connectAllAddresses: Unresolved " +
            "address " + address);
      }

      // Already have channels for this server; nothing to do.
      if (addressChannelMap.containsKey(address)) {
        continue;
      }

      // Start connecting to the remote server up to n time
      for (int i = 0; i < channelsPerServer; ++i) {
        ChannelFuture connectionFuture = bootstrap.connect(address);

        waitingConnectionList.add(
            new ChannelFutureAddress(
                connectionFuture, address, taskId));
      }
    }

    // Wait for all the connections to succeed up to n tries
    int failures = 0;
    int connected = 0;
    while (failures < maxConnectionFailures) {
      List<ChannelFutureAddress> nextCheckFutures = Lists.newArrayList();
      boolean isFirstFailure = true;
      for (ChannelFutureAddress waitingConnection : waitingConnectionList) {
        context.progress();
        ChannelFuture future = waitingConnection.future;
        ProgressableUtils.awaitChannelFuture(future, context);
        if (!future.isSuccess() || !future.channel().isOpen()) {
          // Make a short pause before trying to reconnect failed addresses
          // again, but to do it just once per iterating through channels
          if (isFirstFailure) {
            isFirstFailure = false;
            try {
              Thread.sleep(waitTimeBetweenConnectionRetriesMs);
            } catch (InterruptedException e) {
              throw new IllegalStateException(
                  "connectAllAddresses: InterruptedException occurred", e);
            }
          }

          LOG.warn("connectAllAddresses: Future failed " +
              "to connect with " + waitingConnection.address + " with " +
              failures + " failures because of " + future.cause());

          // Queue a fresh connection attempt for the next pass.
          ChannelFuture connectionFuture =
              bootstrap.connect(waitingConnection.address);
          nextCheckFutures.add(new ChannelFutureAddress(connectionFuture,
              waitingConnection.address, waitingConnection.taskId));
          ++failures;
        } else {
          Channel channel = future.channel();
          if (LOG.isDebugEnabled()) {
            LOG.debug("connectAllAddresses: Connected to " +
                channel.remoteAddress() + ", open = " + channel.isOpen());
          }

          if (channel.remoteAddress() == null) {
            throw new IllegalStateException(
                "connectAllAddresses: Null remote address!");
          }

          // Register the channel under its server's rotater, creating the
          // rotater with putIfAbsent so concurrent creators don't clobber
          // each other (loser adopts the winner's rotater).
          ChannelRotater rotater =
              addressChannelMap.get(waitingConnection.address);
          if (rotater == null) {
            ChannelRotater newRotater =
                new ChannelRotater(waitingConnection.taskId,
                    waitingConnection.address);
            rotater = addressChannelMap.putIfAbsent(
                waitingConnection.address, newRotater);
            if (rotater == null) {
              rotater = newRotater;
            }
          }
          rotater.addChannel(future.channel());
          ++connected;
        }
      }
      LOG.info("connectAllAddresses: Successfully added " +
          (waitingConnectionList.size() - nextCheckFutures.size()) +
          " connections, (" + connected + " total connected) " +
          nextCheckFutures.size() + " failed, " +
          failures + " failures total.");
      if (nextCheckFutures.isEmpty()) {
        break;
      }
      waitingConnectionList = nextCheckFutures;
    }
    if (failures >= maxConnectionFailures) {
      throw new IllegalStateException(
          "connectAllAddresses: Too many failures (" + failures + ").");
    }
  }
563 
564 /*if_not[HADOOP_NON_SECURE]*/
565   /**
566    * Authenticate all servers in addressChannelMap.
567    */
568   public void authenticate() {
569     LOG.info("authenticate: NettyClient starting authentication with " +
570         "servers.");
571     for (InetSocketAddress address: addressChannelMap.keySet()) {
572       if (LOG.isDebugEnabled()) {
573         LOG.debug("authenticate: Authenticating with address:" + address);
574       }
575       ChannelRotater channelRotater = addressChannelMap.get(address);
576       for (Channel channel: channelRotater.getChannels()) {
577         if (LOG.isDebugEnabled()) {
578           LOG.debug("authenticate: Authenticating with server on channel: " +
579               channel);
580         }
581         authenticateOnChannel(channelRotater.getTaskId(), channel);
582       }
583     }
584     if (LOG.isInfoEnabled()) {
585       LOG.info("authenticate: NettyClient successfully authenticated with " +
586           addressChannelMap.size() + " server" +
587           ((addressChannelMap.size() != 1) ? "s" : "") +
588           " - continuing with normal work.");
589     }
590   }
591 
592   /**
593    * Authenticate with server connected at given channel.
594    *
595    * @param taskId Task id of the channel
596    * @param channel Connection to server to authenticate with.
597    */
598   private void authenticateOnChannel(Integer taskId, Channel channel) {
599     try {
600       SaslNettyClient saslNettyClient = channel.attr(SASL).get();
601       if (channel.attr(SASL).get() == null) {
602         if (LOG.isDebugEnabled()) {
603           LOG.debug("authenticateOnChannel: Creating saslNettyClient now " +
604               "for channel: " + channel);
605         }
606         saslNettyClient = new SaslNettyClient();
607         channel.attr(SASL).set(saslNettyClient);
608       }
609       if (!saslNettyClient.isComplete()) {
610         if (LOG.isDebugEnabled()) {
611           LOG.debug("authenticateOnChannel: Waiting for authentication " +
612               "to complete..");
613         }
614         SaslTokenMessageRequest saslTokenMessage = saslNettyClient.firstToken();
615         sendWritableRequest(taskId, saslTokenMessage);
616         // We now wait for Netty's thread pool to communicate over this
617         // channel to authenticate with another worker acting as a server.
618         try {
619           synchronized (saslNettyClient.getAuthenticated()) {
620             while (!saslNettyClient.isComplete()) {
621               saslNettyClient.getAuthenticated().wait();
622             }
623           }
624         } catch (InterruptedException e) {
625           LOG.error("authenticateOnChannel: Interrupted while waiting for " +
626               "authentication.");
627         }
628       }
629       if (LOG.isDebugEnabled()) {
630         LOG.debug("authenticateOnChannel: Authentication on channel: " +
631             channel + " has completed successfully.");
632       }
633     } catch (IOException e) {
634       LOG.error("authenticateOnChannel: Failed to authenticate with server " +
635           "due to error: " + e);
636     }
637     return;
638   }
639 /*end[HADOOP_NON_SECURE]*/
640 
641   /**
642    * Stop the client.
643    */
644   public void stop() {
645     if (LOG.isInfoEnabled()) {
646       LOG.info("stop: Halting netty client");
647     }
648     // Close connections asynchronously, in a Netty-approved
649     // way, without cleaning up thread pools until all channels
650     // in addressChannelMap are closed (success or failure)
651     int channelCount = 0;
652     for (ChannelRotater channelRotater : addressChannelMap.values()) {
653       channelCount += channelRotater.size();
654     }
655     final int done = channelCount;
656     final AtomicInteger count = new AtomicInteger(0);
657     for (ChannelRotater channelRotater : addressChannelMap.values()) {
658       channelRotater.closeChannels(new ChannelFutureListener() {
659         @Override
660         public void operationComplete(ChannelFuture cf) {
661           context.progress();
662           if (count.incrementAndGet() == done) {
663             if (LOG.isInfoEnabled()) {
664               LOG.info("stop: reached wait threshold, " +
665                   done + " connections closed, releasing " +
666                   "resources now.");
667             }
668             workerGroup.shutdownGracefully();
669             if (executionGroup != null) {
670               executionGroup.shutdownGracefully();
671             }
672           }
673         }
674       });
675     }
676     ProgressableUtils.awaitTerminationFuture(workerGroup, context);
677     if (executionGroup != null) {
678       ProgressableUtils.awaitTerminationFuture(executionGroup, context);
679     }
680     if (LOG.isInfoEnabled()) {
681       LOG.info("stop: Netty client halted");
682     }
683   }
684 
  /**
   * Get the next available channel, reconnecting if necessary
   *
   * If the rotater's next channel is no longer active, it is removed and a
   * new connection is attempted up to maxConnectionFailures times (sleeping
   * 5 seconds between attempts) before giving up.
   *
   * @param remoteServer Remote server to get a channel for
   * @return Available channel for this remote server
   */
  private Channel getNextChannel(InetSocketAddress remoteServer) {
    Channel channel = addressChannelMap.get(remoteServer).nextChannel();
    if (channel == null) {
      throw new IllegalStateException(
          "getNextChannel: No channel exists for " + remoteServer);
    }

    // Return this channel if it is connected
    if (channel.isActive()) {
      return channel;
    }

    // Get rid of the failed channel
    // NOTE(review): the warning below fires when removeChannel() returns
    // true, yet its text says the channel "was already removed" — verify
    // ChannelRotater.removeChannel()'s return-value contract; the polarity
    // may be inverted.
    if (addressChannelMap.get(remoteServer).removeChannel(channel)) {
      LOG.warn("getNextChannel: Unlikely event that the channel " +
          channel + " was already removed!");
    }
    if (LOG.isInfoEnabled()) {
      LOG.info("getNextChannel: Fixing disconnected channel to " +
          remoteServer + ", open = " + channel.isOpen() + ", " +
          "bound = " + channel.isRegistered());
    }
    // Synchronously reconnect, with bounded retries and a fixed backoff.
    int reconnectFailures = 0;
    while (reconnectFailures < maxConnectionFailures) {
      ChannelFuture connectionFuture = bootstrap.connect(remoteServer);
      ProgressableUtils.awaitChannelFuture(connectionFuture, context);
      if (connectionFuture.isSuccess()) {
        if (LOG.isInfoEnabled()) {
          LOG.info("getNextChannel: Connected to " + remoteServer + "!");
        }
        // Register the fresh channel so future calls can rotate onto it.
        addressChannelMap.get(remoteServer).addChannel(
            connectionFuture.channel());
        return connectionFuture.channel();
      }
      ++reconnectFailures;
      LOG.warn("getNextChannel: Failed to reconnect to " +  remoteServer +
          " on attempt " + reconnectFailures + " out of " +
          maxConnectionFailures + " max attempts, sleeping for 5 secs",
          connectionFuture.cause());
      ThreadUtils.trySleep(5000);
    }
    throw new IllegalStateException("getNextChannel: Failed to connect " +
        "to " + remoteServer + " in " + reconnectFailures +
        " connect attempts");
  }
736 
  /**
   * Send a request to a remote server honoring the flow control mechanism
   * (should be already connected)
   *
   * Delegates to the configured flow control policy, which may delay the
   * actual send (see {@link FlowControl#sendRequest}).
   *
   * @param destTaskId Destination task id
   * @param request Request to send
   */
  public void sendWritableRequest(int destTaskId, WritableRequest request) {
    flowControl.sendRequest(destTaskId, request);
  }
747 
  /**
   * Actual send of a request.
   *
   * Registers the request for ack-tracking (unless it is a SASL handshake
   * message), then writes it to a channel for its destination.
   *
   * @param destTaskId destination to send the request to
   * @param request request itself
   * @return request id generated for sending the request, or null when the
   *         request was not registered (SASL token messages)
   */
  public Long doSend(int destTaskId, WritableRequest request) {
    InetSocketAddress remoteServer = taskIdAddressMap.get(destTaskId);
    // No requests in flight: reset the byte counters so their metrics
    // reflect only the upcoming batch of requests
    if (clientRequestIdRequestInfoMap.isEmpty()) {
      inboundByteCounter.resetAll();
      outboundByteCounter.resetAll();
    }
    boolean registerRequest = true;
    Long requestId = null;
/*if_not[HADOOP_NON_SECURE]*/
    // SASL token exchanges are not tracked/resent like normal requests
    if (request.getType() == RequestType.SASL_TOKEN_MESSAGE_REQUEST) {
      registerRequest = false;
    }
/*end[HADOOP_NON_SECURE]*/

    RequestInfo newRequestInfo = new RequestInfo(remoteServer, request);
    if (registerRequest) {
      request.setClientId(myTaskInfo.getTaskId());
      requestId = taskRequestIdGenerator.getNextRequestId(destTaskId);
      request.setRequestId(requestId);
      ClientRequestId clientRequestId =
        new ClientRequestId(destTaskId, request.getRequestId());
      // Register BEFORE writing to the channel so the ack handled in
      // messageReceived() is guaranteed to find the entry. A pre-existing
      // entry for a freshly generated id indicates a bookkeeping bug.
      RequestInfo oldRequestInfo = clientRequestIdRequestInfoMap.putIfAbsent(
        clientRequestId, newRequestInfo);
      if (oldRequestInfo != null) {
        throw new IllegalStateException("sendWritableRequest: Impossible to " +
          "have a previous request id = " + request.getRequestId() + ", " +
          "request info of " + oldRequestInfo);
      }
    }
    // Warn about requests large relative to the configured send buffer
    if (request.getSerializedSize() >
        requestSizeWarningThreshold * sendBufferSize) {
      LOG.warn("Creating large request of type " + request.getClass() +
        ", size " + request.getSerializedSize() +
        " bytes. Check netty buffer size.");
    }
    writeRequestToChannel(newRequestInfo);
    return requestId;
  }
793 
  /**
   * Write request to a channel for its destination
   *
   * @param requestInfo Request info
   */
  private void writeRequestToChannel(RequestInfo requestInfo) {
    // May transparently reconnect if the cached channel went down
    Channel channel = getNextChannel(requestInfo.getDestinationAddress());
    ChannelFuture writeFuture = channel.write(requestInfo.getRequest());
    // Remember the write future so checkRequestsForProblems() can detect
    // failed or stalled writes and trigger a resend
    requestInfo.setWriteFuture(writeFuture);
    writeFuture.addListener(logErrorListener);
  }
805 
806   /**
807    * Handle receipt of a message. Called by response handler.
808    *
809    * @param senderId Id of sender of the message
810    * @param requestId Id of the request
811    * @param response Actual response
812    * @param shouldDrop Drop the message?
813    */
814   public void messageReceived(int senderId, long requestId, int response,
815       boolean shouldDrop) {
816     if (shouldDrop) {
817       synchronized (clientRequestIdRequestInfoMap) {
818         clientRequestIdRequestInfoMap.notifyAll();
819       }
820       return;
821     }
822     AckSignalFlag responseFlag = flowControl.getAckSignalFlag(response);
823     if (responseFlag == AckSignalFlag.DUPLICATE_REQUEST) {
824       LOG.info("messageReceived: Already completed request (taskId = " +
825           senderId + ", requestId = " + requestId + ")");
826     } else if (responseFlag != AckSignalFlag.NEW_REQUEST) {
827       throw new IllegalStateException(
828           "messageReceived: Got illegal response " + response);
829     }
830     RequestInfo requestInfo = clientRequestIdRequestInfoMap
831         .remove(new ClientRequestId(senderId, requestId));
832     if (requestInfo == null) {
833       LOG.info("messageReceived: Already received response for (taskId = " +
834           senderId + ", requestId = " + requestId + ")");
835     } else {
836       if (LOG.isDebugEnabled()) {
837         LOG.debug("messageReceived: Completed (taskId = " + senderId + ")" +
838             requestInfo + ".  Waiting on " +
839             clientRequestIdRequestInfoMap.size() + " requests");
840       }
841       flowControl.messageAckReceived(senderId, requestId, response);
842       // Help #waitAllRequests() to finish faster
843       synchronized (clientRequestIdRequestInfoMap) {
844         clientRequestIdRequestInfoMap.notifyAll();
845       }
846     }
847   }
848 
  /**
   * Ensure all the request sent so far are complete. Periodically check the
   * state of current open requests. If there is an issue in any of them,
   * re-send the request.
   */
  public void waitAllRequests() {
    flowControl.waitAllRequests();
    // Flow control must have handed every queued request over for sending
    checkState(flowControl.getNumberOfUnsentRequests() == 0);
    while (clientRequestIdRequestInfoMap.size() > 0) {
      // Wait for requests to complete for some time
      synchronized (clientRequestIdRequestInfoMap) {
        // Re-check under the monitor so a notifyAll() from
        // messageReceived() between the outer check and wait() is not lost
        if (clientRequestIdRequestInfoMap.size() == 0) {
          break;
        }
        try {
          // Timed wait: messageReceived() notifies on completion, but we
          // also wake up periodically to log and report progress
          clientRequestIdRequestInfoMap.wait(waitingRequestMsecs);
        } catch (InterruptedException e) {
          throw new IllegalStateException("waitAllRequests: Got unexpected " +
              "InterruptedException", e);
        }
      }
      // Log open requests and call context.progress() so the framework
      // does not consider this task dead while waiting
      logAndSanityCheck();
    }
    if (LOG.isInfoEnabled()) {
      LOG.info("waitAllRequests: Finished all requests. " +
          inboundByteCounter.getMetrics() + "\n" + outboundByteCounter
          .getMetrics());
    }
  }
878 
  /**
   * Log information about the requests and check for problems in requests
   *
   * NOTE(review): despite the javadoc, this method only logs and reports
   * progress; actual problem detection/resending lives in
   * checkRequestsForProblems() - confirm callers don't expect more here.
   */
  public void logAndSanityCheck() {
    logInfoAboutOpenRequests();
    // Make sure that waiting doesn't kill the job
    context.progress();
  }
887 
888   /**
889    * Log the status of open requests.
890    */
891   private void logInfoAboutOpenRequests() {
892     if (LOG.isInfoEnabled() && requestLogger.isPrintable()) {
893       LOG.info("logInfoAboutOpenRequests: Waiting interval of " +
894           waitingRequestMsecs + " msecs, " +
895           clientRequestIdRequestInfoMap.size() +
896           " open requests, " + inboundByteCounter.getMetrics() + "\n" +
897           outboundByteCounter.getMetrics());
898 
899       if (clientRequestIdRequestInfoMap.size() < MAX_REQUESTS_TO_LIST) {
900         for (Map.Entry<ClientRequestId, RequestInfo> entry :
901             clientRequestIdRequestInfoMap.entrySet()) {
902           LOG.info("logInfoAboutOpenRequests: Waiting for request " +
903               entry.getKey() + " - " + entry.getValue());
904         }
905       }
906 
907       // Count how many open requests each task has
908       Map<Integer, Integer> openRequestCounts = Maps.newHashMap();
909       for (ClientRequestId clientRequestId :
910           clientRequestIdRequestInfoMap.keySet()) {
911         int taskId = clientRequestId.getDestinationTaskId();
912         Integer currentCount = openRequestCounts.get(taskId);
913         openRequestCounts.put(taskId,
914             (currentCount == null ? 0 : currentCount) + 1);
915       }
916       // Sort it in decreasing order of number of open requests
917       List<Map.Entry<Integer, Integer>> sorted =
918           Lists.newArrayList(openRequestCounts.entrySet());
919       Collections.sort(sorted, new Comparator<Map.Entry<Integer, Integer>>() {
920         @Override
921         public int compare(Map.Entry<Integer, Integer> entry1,
922             Map.Entry<Integer, Integer> entry2) {
923           int value1 = entry1.getValue();
924           int value2 = entry2.getValue();
925           return (value1 < value2) ? 1 : ((value1 == value2) ? 0 : -1);
926         }
927       });
928       // Print task ids which have the most open requests
929       StringBuilder message = new StringBuilder();
930       message.append("logInfoAboutOpenRequests: ");
931       int itemsToPrint =
932           Math.min(MAX_DESTINATION_TASK_IDS_TO_LIST, sorted.size());
933       for (int i = 0; i < itemsToPrint; i++) {
934         message.append(sorted.get(i).getValue())
935             .append(" requests for taskId=")
936             .append(sorted.get(i).getKey())
937             .append(", ");
938       }
939       LOG.info(message);
940       flowControl.logInfo();
941     }
942   }
943 
944   /**
945    * Check if there are some open requests which have been sent a long time
946    * ago, and if so resend them.
947    */
948   private void checkRequestsForProblems() {
949     long lastTimeChecked = lastTimeCheckedRequestsForProblems.get();
950     // If not enough time passed from the previous check, return
951     if (System.currentTimeMillis() < lastTimeChecked + waitingRequestMsecs) {
952       return;
953     }
954     // If another thread did the check already, return
955     if (!lastTimeCheckedRequestsForProblems.compareAndSet(lastTimeChecked,
956         System.currentTimeMillis())) {
957       return;
958     }
959     resendRequestsWhenNeeded(new Predicate<RequestInfo>() {
960       @Override
961       public boolean apply(RequestInfo requestInfo) {
962         ChannelFuture writeFuture = requestInfo.getWriteFuture();
963         // If not connected anymore, request failed, or the request is taking
964         // too long, re-establish and resend
965         return (writeFuture != null && (!writeFuture.channel().isActive() ||
966             (writeFuture.isDone() && !writeFuture.isSuccess()))) ||
967             (requestInfo.getElapsedMsecs() > maxRequestMilliseconds);
968       }
969     });
970   }
971 
  /**
   * Resend requests which satisfy predicate
   *
   * Scans the open-request map, collects matching requests, re-registers
   * fresh copies of them, and writes each one to a channel again.
   *
   * @param shouldResendRequestPredicate Predicate to use to check whether
   *                                     request should be resent
   */
  private void resendRequestsWhenNeeded(
      Predicate<RequestInfo> shouldResendRequestPredicate) {
    // Check if there are open requests which have been sent a long time ago,
    // and if so, resend them.
    List<ClientRequestId> addedRequestIds = Lists.newArrayList();
    List<RequestInfo> addedRequestInfos = Lists.newArrayList();
    // Check all the requests for problems
    for (Map.Entry<ClientRequestId, RequestInfo> entry :
        clientRequestIdRequestInfoMap.entrySet()) {
      RequestInfo requestInfo = entry.getValue();
      // If request should be resent
      if (shouldResendRequestPredicate.apply(requestInfo)) {
        ChannelFuture writeFuture = requestInfo.getWriteFuture();
        String logMessage;
        if (writeFuture == null) {
          // Registered but never handed to a channel
          logMessage = "wasn't sent successfully";
        } else {
          logMessage = "connected = " +
              writeFuture.channel().isActive() +
              ", future done = " + writeFuture.isDone() + ", " +
              "success = " + writeFuture.isSuccess() + ", " +
              "cause = " + writeFuture.cause();
        }
        LOG.warn("checkRequestsForProblems: Problem with request id " +
            entry.getKey() + ", " + logMessage + ", " +
            "elapsed time = " + requestInfo.getElapsedMsecs() + ", " +
            "destination = " + requestInfo.getDestinationAddress() +
            " " + requestInfo);
        addedRequestIds.add(entry.getKey());
        // Fresh RequestInfo so the elapsed-time clock and write future
        // start over for the resent copy
        addedRequestInfos.add(new RequestInfo(
            requestInfo.getDestinationAddress(), requestInfo.getRequest()));
      }
    }

    // Add any new requests to the system, connect if necessary, and re-send
    for (int i = 0; i < addedRequestIds.size(); ++i) {
      ClientRequestId requestId = addedRequestIds.get(i);
      RequestInfo requestInfo = addedRequestInfos.get(i);

      // put() returning null means the ack arrived (and the entry was
      // removed) while we were scanning; undo the re-registration
      if (clientRequestIdRequestInfoMap.put(requestId, requestInfo) == null) {
        LOG.warn("checkRequestsForProblems: Request " + requestId +
            " completed prior to sending the next request");
        clientRequestIdRequestInfoMap.remove(requestId);
      }
      // NOTE(review): even in the already-completed case above, the request
      // is still re-written to the channel below; presumably the server's
      // duplicate detection (DUPLICATE_REQUEST ack) absorbs it - confirm
      // this is intentional rather than adding a skip here.
      if (LOG.isInfoEnabled()) {
        LOG.info("checkRequestsForProblems: Re-issuing request " + requestInfo);
      }
      writeRequestToChannel(requestInfo);
    }
    addedRequestIds.clear();
    addedRequestInfos.clear();
  }
1030 
1031   /**
1032    * Utility method for resolving addresses
1033    *
1034    * @param maxResolveAddressAttempts Maximum number of attempts to resolve the
1035    *        address
1036    * @param hostOrIp Known IP or host name
1037    * @param port Target port number
1038    * @return The successfully resolved address.
1039    * @throws IllegalStateException if the address is not resolved
1040    *         in <code>maxResolveAddressAttempts</code> tries.
1041    */
1042   private static InetSocketAddress resolveAddress(
1043       int maxResolveAddressAttempts, String hostOrIp, int port) {
1044     int resolveAttempts = 0;
1045     InetSocketAddress address = new InetSocketAddress(hostOrIp, port);
1046     while (address.isUnresolved() &&
1047         resolveAttempts < maxResolveAddressAttempts) {
1048       ++resolveAttempts;
1049       LOG.warn("resolveAddress: Failed to resolve " + address +
1050           " on attempt " + resolveAttempts + " of " +
1051           maxResolveAddressAttempts + " attempts, sleeping for 5 seconds");
1052       ThreadUtils.trySleep(5000);
1053       address = new InetSocketAddress(hostOrIp,
1054           address.getPort());
1055     }
1056     if (resolveAttempts >= maxResolveAddressAttempts) {
1057       throw new IllegalStateException("resolveAddress: Couldn't " +
1058           "resolve " + address + " in " +  resolveAttempts + " tries.");
1059     }
1060     return address;
1061   }
1062 
  /**
   * @return the flow control policy this client sends requests through
   */
  public FlowControl getFlowControl() {
    return flowControl;
  }
1066 
  /**
   * Generate and get the next request id to be used for a given worker
   *
   * @param taskId id of the worker to generate the next request id
   * @return request id
   */
  public Long getNextRequestId(int taskId) {
    // Same generator used by doSend(), so ids stay unique per destination
    return taskRequestIdGenerator.getNextRequestId(taskId);
  }
1076 
  /**
   * @return number of open requests, i.e. registered requests whose ack
   *         has not yet been processed by messageReceived()
   */
  public int getNumberOfOpenRequests() {
    return clientRequestIdRequestInfoMap.size();
  }
1083 
1084   /**
1085    * Resend requests related to channel which failed
1086    *
1087    * @param channel Channel which failed
1088    */
1089   private void checkRequestsAfterChannelFailure(final Channel channel) {
1090     resendRequestsWhenNeeded(new Predicate<RequestInfo>() {
1091       @Override
1092       public boolean apply(RequestInfo requestInfo) {
1093         return requestInfo.getDestinationAddress().equals(
1094             channel.remoteAddress());
1095       }
1096     });
1097   }
1098 
1099   /**
1100    * This listener class just dumps exception stack traces if
1101    * something happens.
1102    */
1103   private class LogOnErrorChannelFutureListener
1104       implements ChannelFutureListener {
1105 
1106     @Override
1107     public void operationComplete(ChannelFuture future) throws Exception {
1108       if (future.isDone() && !future.isSuccess()) {
1109         LOG.error("Request failed", future.cause());
1110         checkRequestsAfterChannelFailure(future.channel());
1111       }
1112     }
1113   }
1114 }