/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.giraph.comm.netty;

import org.apache.giraph.comm.flow_control.CreditBasedFlowControl;
import org.apache.giraph.comm.flow_control.FlowControl;
import org.apache.giraph.comm.flow_control.NoOpFlowControl;
import org.apache.giraph.comm.flow_control.StaticFlowControl;
import org.apache.giraph.comm.netty.handler.AckSignalFlag;
import org.apache.giraph.comm.netty.handler.TaskRequestIdGenerator;
import org.apache.giraph.comm.netty.handler.ClientRequestId;
import org.apache.giraph.comm.netty.handler.RequestEncoder;
import org.apache.giraph.comm.netty.handler.RequestInfo;
import org.apache.giraph.comm.netty.handler.RequestServerHandler;
import org.apache.giraph.comm.netty.handler.ResponseClientHandler;
/*if_not[HADOOP_NON_SECURE]*/
import org.apache.giraph.comm.netty.handler.SaslClientHandler;
import org.apache.giraph.comm.requests.RequestType;
import org.apache.giraph.comm.requests.SaslTokenMessageRequest;
/*end[HADOOP_NON_SECURE]*/
import org.apache.giraph.comm.requests.WritableRequest;
import org.apache.giraph.conf.BooleanConfOption;
import org.apache.giraph.conf.GiraphConstants;
import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
import org.apache.giraph.function.Predicate;
import org.apache.giraph.graph.TaskInfo;
import org.apache.giraph.master.MasterInfo;
import org.apache.giraph.utils.PipelineUtils;
import org.apache.giraph.utils.ProgressableUtils;
import org.apache.giraph.utils.ThreadUtils;
import org.apache.giraph.utils.TimedLogger;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.log4j.Logger;

import com.google.common.collect.Lists;
import com.google.common.collect.MapMaker;
import com.google.common.collect.Maps;

/*if_not[HADOOP_NON_SECURE]*/
import java.io.IOException;
/*end[HADOOP_NON_SECURE]*/
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.FixedLengthFrameDecoder;
/*if_not[HADOOP_NON_SECURE]*/
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.AttributeKey;
/*end[HADOOP_NON_SECURE]*/
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.EventExecutorGroup;

import static com.google.common.base.Preconditions.checkState;
import static org.apache.giraph.conf.GiraphConstants.CLIENT_RECEIVE_BUFFER_SIZE;
import static org.apache.giraph.conf.GiraphConstants.CLIENT_SEND_BUFFER_SIZE;
import static org.apache.giraph.conf.GiraphConstants.MAX_REQUEST_MILLISECONDS;
import static org.apache.giraph.conf.GiraphConstants.MAX_RESOLVE_ADDRESS_ATTEMPTS;
import static org.apache.giraph.conf.GiraphConstants.NETTY_CLIENT_EXECUTION_AFTER_HANDLER;
import static org.apache.giraph.conf.GiraphConstants.NETTY_CLIENT_EXECUTION_THREADS;
import static org.apache.giraph.conf.GiraphConstants.NETTY_CLIENT_USE_EXECUTION_HANDLER;
import static org.apache.giraph.conf.GiraphConstants.NETTY_MAX_CONNECTION_FAILURES;
import static org.apache.giraph.conf.GiraphConstants.WAIT_TIME_BETWEEN_CONNECTION_RETRIES_MS;
import static org.apache.giraph.conf.GiraphConstants.WAITING_REQUEST_MSECS;

/**
 * Netty client for sending requests.  Thread-safe.
 */
public class NettyClient {
  /** Do we have a limit on number of open requests we can have */
  public static final BooleanConfOption LIMIT_NUMBER_OF_OPEN_REQUESTS =
      new BooleanConfOption("giraph.waitForRequestsConfirmation", false,
          "Whether to have a limit on number of open requests or not");
  /**
   * Do we have a limit on number of open requests we can have for each worker.
   * Note that if this option is enabled, Netty will not keep more than a
   * certain number of requests open for each other worker in the job. If there
   * are more requests generated for a worker, Netty will not actually send the
   * surplus requests; instead, it caches the requests in a local buffer. The
   * maximum number of these unsent requests in the cache is another
   * user-defined parameter (MAX_NUM_OF_OPEN_REQUESTS_PER_WORKER).
   */
  public static final BooleanConfOption LIMIT_OPEN_REQUESTS_PER_WORKER =
      new BooleanConfOption("giraph.waitForPerWorkerRequests", false,
          "Whether to have a limit on number of open requests for each worker" +
              " or not");
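  // The two options above determine the flow control policy chosen in the
  // constructor below: giraph.waitForRequestsConfirmation=true selects
  // StaticFlowControl, giraph.waitForPerWorkerRequests=true selects
  // CreditBasedFlowControl, and leaving both false (the default) selects
  // NoOpFlowControl. As a rough illustration (how options are passed depends
  // on the job launcher; GiraphRunner's custom-argument flag is assumed
  // here), a job might enable per-worker limiting with:
  //   -ca giraph.waitForPerWorkerRequests=true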
  /** Maximum number of requests to list (for debugging) */
  public static final int MAX_REQUESTS_TO_LIST = 10;
  /**
   * Maximum number of destination task ids with open requests to list
   * (for debugging)
   */
  public static final int MAX_DESTINATION_TASK_IDS_TO_LIST = 10;
  /** 30 seconds to connect by default */
  public static final int MAX_CONNECTION_MILLISECONDS_DEFAULT = 30 * 1000;
/*if_not[HADOOP_NON_SECURE]*/
  /** Used to authenticate with other workers acting as servers */
  public static final AttributeKey<SaslNettyClient> SASL =
      AttributeKey.valueOf("saslNettyClient");
/*end[HADOOP_NON_SECURE]*/
  /** Class logger */
  private static final Logger LOG = Logger.getLogger(NettyClient.class);
  /** Context used to report progress */
  private final Mapper<?, ?, ?, ?>.Context context;
  /** Client bootstrap */
  private final Bootstrap bootstrap;
  /**
   * Map of the peer connections, mapping from remote socket address to client
   * meta data
   */
  private final ConcurrentMap<InetSocketAddress, ChannelRotater>
  addressChannelMap = new MapMaker().makeMap();
  /**
   * Map from task id to address of its server
   */
  private final Map<Integer, InetSocketAddress> taskIdAddressMap =
      new MapMaker().makeMap();
  /**
   * Request map of client request ids to request information.
   */
  private final ConcurrentMap<ClientRequestId, RequestInfo>
  clientRequestIdRequestInfoMap;
  /** Number of channels per server */
  private final int channelsPerServer;
  /** Inbound byte counter for this client */
  private final InboundByteCounter inboundByteCounter = new
      InboundByteCounter();
  /** Outbound byte counter for this client */
  private final OutboundByteCounter outboundByteCounter = new
      OutboundByteCounter();
  /** Send buffer size */
  private final int sendBufferSize;
  /** Receive buffer size */
  private final int receiveBufferSize;
  /** Warn if request size is bigger than the buffer size by this factor */
  private final float requestSizeWarningThreshold;
  /** Maximum number of connection failures */
  private final int maxConnectionFailures;
  /** How long to wait before trying to reconnect failed connections */
  private final long waitTimeBetweenConnectionRetriesMs;
  /** Maximum number of milliseconds for a request */
  private final int maxRequestMilliseconds;
  /** Waiting interval (msecs) for checking outstanding requests */
  private final int waitingRequestMsecs;
  /** Timed logger for printing request debugging */
  private final TimedLogger requestLogger;
  /** Worker executor group */
  private final EventLoopGroup workerGroup;
  /** Task request id generator */
  private final TaskRequestIdGenerator taskRequestIdGenerator =
      new TaskRequestIdGenerator();
  /** Task info */
  private final TaskInfo myTaskInfo;
  /** Maximum thread pool size */
  private final int maxPoolSize;
  /** Maximum number of attempts to resolve an address */
  private final int maxResolveAddressAttempts;
  /** Use execution handler? */
  private final boolean useExecutionGroup;
  /** EventExecutor Group (if used) */
  private final EventExecutorGroup executionGroup;
  /** Name of the handler to use execution group for (if used) */
  private final String handlerToUseExecutionGroup;
  /** When was the last time we checked if we should resend some requests */
  private final AtomicLong lastTimeCheckedRequestsForProblems =
      new AtomicLong(0);
  /**
   * Logger used to dump stack traces for every exception that happens
   * in netty client threads.
   */
  private final LogOnErrorChannelFutureListener logErrorListener =
      new LogOnErrorChannelFutureListener();
  /** Flow control policy used */
  private final FlowControl flowControl;

  /**
   * Only constructor
   *
   * @param context Context for progress
   * @param conf Configuration
   * @param myTaskInfo Current task info
   * @param exceptionHandler handler for uncaught exception. Will
   *                         terminate job.
   */
  public NettyClient(Mapper<?, ?, ?, ?>.Context context,
                     final ImmutableClassesGiraphConfiguration conf,
                     TaskInfo myTaskInfo,
                     final Thread.UncaughtExceptionHandler exceptionHandler) {
    this.context = context;
    this.myTaskInfo = myTaskInfo;
    this.channelsPerServer = GiraphConstants.CHANNELS_PER_SERVER.get(conf);
    sendBufferSize = CLIENT_SEND_BUFFER_SIZE.get(conf);
    receiveBufferSize = CLIENT_RECEIVE_BUFFER_SIZE.get(conf);
    this.requestSizeWarningThreshold =
        GiraphConstants.REQUEST_SIZE_WARNING_THRESHOLD.get(conf);

    boolean limitNumberOfOpenRequests = LIMIT_NUMBER_OF_OPEN_REQUESTS.get(conf);
    boolean limitOpenRequestsPerWorker =
        LIMIT_OPEN_REQUESTS_PER_WORKER.get(conf);
    checkState(!limitNumberOfOpenRequests || !limitOpenRequestsPerWorker,
        "NettyClient: it is not allowed to have both limitations on the " +
            "number of total open requests, and on the number of open " +
            "requests per worker!");
    if (limitNumberOfOpenRequests) {
      flowControl = new StaticFlowControl(conf, this);
    } else if (limitOpenRequestsPerWorker) {
      flowControl = new CreditBasedFlowControl(conf, this);
    } else {
      flowControl = new NoOpFlowControl(this);
    }

    maxRequestMilliseconds = MAX_REQUEST_MILLISECONDS.get(conf);
    maxConnectionFailures = NETTY_MAX_CONNECTION_FAILURES.get(conf);
    waitTimeBetweenConnectionRetriesMs =
        WAIT_TIME_BETWEEN_CONNECTION_RETRIES_MS.get(conf);
    waitingRequestMsecs = WAITING_REQUEST_MSECS.get(conf);
    requestLogger = new TimedLogger(waitingRequestMsecs, LOG);
    maxPoolSize = GiraphConstants.NETTY_CLIENT_THREADS.get(conf);
    maxResolveAddressAttempts = MAX_RESOLVE_ADDRESS_ATTEMPTS.get(conf);

    clientRequestIdRequestInfoMap =
        new MapMaker().concurrencyLevel(maxPoolSize).makeMap();

    handlerToUseExecutionGroup =
        NETTY_CLIENT_EXECUTION_AFTER_HANDLER.get(conf);
    useExecutionGroup = NETTY_CLIENT_USE_EXECUTION_HANDLER.get(conf);
    if (useExecutionGroup) {
      int executionThreads = NETTY_CLIENT_EXECUTION_THREADS.get(conf);
      executionGroup = new DefaultEventExecutorGroup(executionThreads,
          ThreadUtils.createThreadFactory(
              "netty-client-exec-%d", exceptionHandler));
      if (LOG.isInfoEnabled()) {
        LOG.info("NettyClient: Using execution handler with " +
            executionThreads + " threads after " +
            handlerToUseExecutionGroup + ".");
      }
    } else {
      executionGroup = null;
    }

    workerGroup = new NioEventLoopGroup(maxPoolSize,
        ThreadUtils.createThreadFactory(
            "netty-client-worker-%d", exceptionHandler));

    bootstrap = new Bootstrap();
    bootstrap.group(workerGroup)
        .channel(NioSocketChannel.class)
        .option(ChannelOption.CONNECT_TIMEOUT_MILLIS,
            MAX_CONNECTION_MILLISECONDS_DEFAULT)
        .option(ChannelOption.TCP_NODELAY, true)
        .option(ChannelOption.SO_KEEPALIVE, true)
        .option(ChannelOption.SO_SNDBUF, sendBufferSize)
        .option(ChannelOption.SO_RCVBUF, receiveBufferSize)
        .option(ChannelOption.ALLOCATOR, conf.getNettyAllocator())
        .handler(new ChannelInitializer<SocketChannel>() {
          @Override
          protected void initChannel(SocketChannel ch) throws Exception {
/*if_not[HADOOP_NON_SECURE]*/
            if (conf.authenticate()) {
              LOG.info("Using Netty with authentication.");

              // Our pipeline starts with just byteCounter, and then we use
              // addLast() to incrementally add pipeline elements, so that we
              // can name them for identification for removal or replacement
              // after client is authenticated by server.
              // After authentication is complete, the pipeline's SASL-specific
              // functionality is removed, restoring the pipeline to exactly the
              // same configuration as it would be without authentication.
              PipelineUtils.addLastWithExecutorCheck("clientInboundByteCounter",
                  inboundByteCounter, handlerToUseExecutionGroup,
                  executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionDecoder",
                    conf.getNettyCompressionDecoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              PipelineUtils.addLastWithExecutorCheck(
                  "clientOutboundByteCounter",
                  outboundByteCounter, handlerToUseExecutionGroup,
                  executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionEncoder",
                    conf.getNettyCompressionEncoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              // The following pipeline component is needed to decode the
              // server's SASL tokens. It is replaced with a
              // FixedLengthFrameDecoder (same as used with the
              // non-authenticated pipeline) after authentication
              // completes (as in non-auth pipeline below).
              PipelineUtils.addLastWithExecutorCheck(
                  "length-field-based-frame-decoder",
                  new LengthFieldBasedFrameDecoder(1024, 0, 4, 0, 4),
                  handlerToUseExecutionGroup, executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("request-encoder",
                  new RequestEncoder(conf), handlerToUseExecutionGroup,
                  executionGroup, ch);
              // The following pipeline component responds to the server's SASL
              // tokens with its own responses. Both client and server share the
              // same Hadoop Job token, which is used to create the SASL
              // tokens to authenticate with each other.
              // After authentication finishes, this pipeline component
              // is removed.
              PipelineUtils.addLastWithExecutorCheck("sasl-client-handler",
                  new SaslClientHandler(conf), handlerToUseExecutionGroup,
                  executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("response-handler",
                  new ResponseClientHandler(NettyClient.this, conf),
                  handlerToUseExecutionGroup, executionGroup, ch);
            } else {
              LOG.info("Using Netty without authentication.");
/*end[HADOOP_NON_SECURE]*/
              PipelineUtils.addLastWithExecutorCheck("clientInboundByteCounter",
                  inboundByteCounter, handlerToUseExecutionGroup,
                  executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionDecoder",
                    conf.getNettyCompressionDecoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              PipelineUtils.addLastWithExecutorCheck(
                  "clientOutboundByteCounter",
                  outboundByteCounter, handlerToUseExecutionGroup,
                  executionGroup, ch);
              if (conf.doCompression()) {
                PipelineUtils.addLastWithExecutorCheck("compressionEncoder",
                    conf.getNettyCompressionEncoder(),
                    handlerToUseExecutionGroup, executionGroup, ch);
              }
              PipelineUtils.addLastWithExecutorCheck(
                  "fixed-length-frame-decoder",
                  new FixedLengthFrameDecoder(
                      RequestServerHandler.RESPONSE_BYTES),
                  handlerToUseExecutionGroup, executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("request-encoder",
                  new RequestEncoder(conf), handlerToUseExecutionGroup,
                  executionGroup, ch);
              PipelineUtils.addLastWithExecutorCheck("response-handler",
                  new ResponseClientHandler(NettyClient.this, conf),
                  handlerToUseExecutionGroup, executionGroup, ch);

/*if_not[HADOOP_NON_SECURE]*/
            }
/*end[HADOOP_NON_SECURE]*/
          }
        });
  }

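  // Typical call sequence for this client, as used elsewhere in Giraph (a
  // rough sketch rather than a strict contract): after construction, call
  // connectAllAddresses() with the known tasks, authenticate() when SASL
  // authentication is enabled, then any number of sendWritableRequest()
  // calls followed by waitAllRequests() to flush them, and finally stop()
  // to release channels and thread pools.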
  /**
   * Whether master task is involved in the communication with a given client
   *
   * @param clientId id of the task on the other end of the communication
   * @return true if master is on one end of the communication
   */
  public boolean masterInvolved(int clientId) {
    return myTaskInfo.getTaskId() == MasterInfo.MASTER_TASK_ID ||
        clientId == MasterInfo.MASTER_TASK_ID;
  }

  /**
   * Pair object for connectAllAddresses().
   */
  private static class ChannelFutureAddress {
    /** Future object */
    private final ChannelFuture future;
    /** Address of the future */
    private final InetSocketAddress address;
    /** Task id */
    private final Integer taskId;

    /**
     * Constructor.
     *
     * @param future Immutable future
     * @param address Immutable address
     * @param taskId Immutable taskId
     */
    ChannelFutureAddress(
        ChannelFuture future, InetSocketAddress address, Integer taskId) {
      this.future = future;
      this.address = address;
      this.taskId = taskId;
    }

    @Override
    public String toString() {
      return "(future=" + future + ",address=" + address + ",taskId=" +
          taskId + ")";
    }
  }

  /**
   * Connect to a collection of task servers
   *
   * @param tasks Tasks to connect to (if not already connected)
   */
  public void connectAllAddresses(Collection<? extends TaskInfo> tasks) {
    List<ChannelFutureAddress> waitingConnectionList =
        Lists.newArrayListWithCapacity(tasks.size() * channelsPerServer);
    for (TaskInfo taskInfo : tasks) {
      context.progress();
      int taskId = taskInfo.getTaskId();
      InetSocketAddress address = taskIdAddressMap.get(taskId);
      if (address == null ||
          !address.getHostName().equals(taskInfo.getHostname()) ||
          address.getPort() != taskInfo.getPort()) {
        address = resolveAddress(maxResolveAddressAttempts,
            taskInfo.getHostOrIp(), taskInfo.getPort());
        taskIdAddressMap.put(taskId, address);
      }
      if (address == null || address.getHostName() == null ||
          address.getHostName().isEmpty()) {
        throw new IllegalStateException("connectAllAddresses: Null address " +
            "in addresses " + tasks);
      }
      if (address.isUnresolved()) {
        throw new IllegalStateException("connectAllAddresses: Unresolved " +
            "address " + address);
      }

      if (addressChannelMap.containsKey(address)) {
        continue;
      }

      // Start connecting to the remote server, opening channelsPerServer
      // channels per server
      for (int i = 0; i < channelsPerServer; ++i) {
        ChannelFuture connectionFuture = bootstrap.connect(address);

        waitingConnectionList.add(
            new ChannelFutureAddress(
                connectionFuture, address, taskId));
      }
    }

    // Wait for all the connections to succeed, allowing up to
    // maxConnectionFailures failed attempts in total
    int failures = 0;
    int connected = 0;
    while (failures < maxConnectionFailures) {
      List<ChannelFutureAddress> nextCheckFutures = Lists.newArrayList();
      boolean isFirstFailure = true;
      for (ChannelFutureAddress waitingConnection : waitingConnectionList) {
        context.progress();
        ChannelFuture future = waitingConnection.future;
        ProgressableUtils.awaitChannelFuture(future, context);
        if (!future.isSuccess() || !future.channel().isOpen()) {
          // Make a short pause before trying to reconnect failed addresses
          // again, but do it only once per pass through the channels
          if (isFirstFailure) {
            isFirstFailure = false;
            try {
              Thread.sleep(waitTimeBetweenConnectionRetriesMs);
            } catch (InterruptedException e) {
              throw new IllegalStateException(
                  "connectAllAddresses: InterruptedException occurred", e);
            }
          }

          LOG.warn("connectAllAddresses: Future failed " +
              "to connect with " + waitingConnection.address + " with " +
              failures + " failures because of " + future.cause());

          ChannelFuture connectionFuture =
              bootstrap.connect(waitingConnection.address);
          nextCheckFutures.add(new ChannelFutureAddress(connectionFuture,
              waitingConnection.address, waitingConnection.taskId));
          ++failures;
        } else {
          Channel channel = future.channel();
          if (LOG.isDebugEnabled()) {
            LOG.debug("connectAllAddresses: Connected to " +
                channel.remoteAddress() + ", open = " + channel.isOpen());
          }

          if (channel.remoteAddress() == null) {
            throw new IllegalStateException(
                "connectAllAddresses: Null remote address!");
          }

          ChannelRotater rotater =
              addressChannelMap.get(waitingConnection.address);
          if (rotater == null) {
            ChannelRotater newRotater =
                new ChannelRotater(waitingConnection.taskId);
            rotater = addressChannelMap.putIfAbsent(
                waitingConnection.address, newRotater);
            if (rotater == null) {
              rotater = newRotater;
            }
          }
          rotater.addChannel(future.channel());
          ++connected;
        }
      }
      LOG.info("connectAllAddresses: Successfully added " +
          (waitingConnectionList.size() - nextCheckFutures.size()) +
          " connections, (" + connected + " total connected) " +
          nextCheckFutures.size() + " failed, " +
          failures + " failures total.");
      if (nextCheckFutures.isEmpty()) {
        break;
      }
      waitingConnectionList = nextCheckFutures;
    }
    if (failures >= maxConnectionFailures) {
      throw new IllegalStateException(
          "connectAllAddresses: Too many failures (" + failures + ").");
    }
  }
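  // Note that the failure counter above is cumulative across all addresses
  // and retry passes, so NETTY_MAX_CONNECTION_FAILURES bounds the total
  // number of failed connection attempts observed during this call, not the
  // number of failures per address.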

/*if_not[HADOOP_NON_SECURE]*/
  /**
   * Authenticate all servers in addressChannelMap.
   */
  public void authenticate() {
    LOG.info("authenticate: NettyClient starting authentication with " +
        "servers.");
    for (InetSocketAddress address: addressChannelMap.keySet()) {
      if (LOG.isDebugEnabled()) {
        LOG.debug("authenticate: Authenticating with address:" + address);
      }
      ChannelRotater channelRotater = addressChannelMap.get(address);
      for (Channel channel: channelRotater.getChannels()) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("authenticate: Authenticating with server on channel: " +
              channel);
        }
        authenticateOnChannel(channelRotater.getTaskId(), channel);
      }
    }
    if (LOG.isInfoEnabled()) {
      LOG.info("authenticate: NettyClient successfully authenticated with " +
          addressChannelMap.size() + " server" +
          ((addressChannelMap.size() != 1) ? "s" : "") +
          " - continuing with normal work.");
    }
  }

  /**
   * Authenticate with server connected at given channel.
   *
   * @param taskId Task id of the channel
   * @param channel Connection to server to authenticate with.
   */
  private void authenticateOnChannel(Integer taskId, Channel channel) {
    try {
      SaslNettyClient saslNettyClient = channel.attr(SASL).get();
      if (channel.attr(SASL).get() == null) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("authenticateOnChannel: Creating saslNettyClient now " +
              "for channel: " + channel);
        }
        saslNettyClient = new SaslNettyClient();
        channel.attr(SASL).set(saslNettyClient);
      }
      if (!saslNettyClient.isComplete()) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("authenticateOnChannel: Waiting for authentication " +
              "to complete..");
        }
        SaslTokenMessageRequest saslTokenMessage = saslNettyClient.firstToken();
        sendWritableRequest(taskId, saslTokenMessage);
        // We now wait for Netty's thread pool to communicate over this
        // channel to authenticate with another worker acting as a server.
        try {
          synchronized (saslNettyClient.getAuthenticated()) {
            while (!saslNettyClient.isComplete()) {
              saslNettyClient.getAuthenticated().wait();
            }
          }
        } catch (InterruptedException e) {
          LOG.error("authenticateOnChannel: Interrupted while waiting for " +
              "authentication.");
        }
      }
      if (LOG.isDebugEnabled()) {
        LOG.debug("authenticateOnChannel: Authentication on channel: " +
            channel + " has completed successfully.");
      }
    } catch (IOException e) {
      LOG.error("authenticateOnChannel: Failed to authenticate with server " +
          "due to error: " + e);
    }
    return;
  }
/*end[HADOOP_NON_SECURE]*/

  /**
   * Stop the client.
   */
  public void stop() {
    if (LOG.isInfoEnabled()) {
      LOG.info("stop: Halting netty client");
    }
    flowControl.shutdown();
    // Close connections asynchronously, in a Netty-approved
    // way, without cleaning up thread pools until all channels
    // in addressChannelMap are closed (success or failure)
    int channelCount = 0;
    for (ChannelRotater channelRotater : addressChannelMap.values()) {
      channelCount += channelRotater.size();
    }
    final int done = channelCount;
    final AtomicInteger count = new AtomicInteger(0);
    for (ChannelRotater channelRotater : addressChannelMap.values()) {
      channelRotater.closeChannels(new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture cf) {
          context.progress();
          if (count.incrementAndGet() == done) {
            if (LOG.isInfoEnabled()) {
              LOG.info("stop: reached wait threshold, " +
                  done + " connections closed, releasing " +
                  "resources now.");
            }
            workerGroup.shutdownGracefully();
            if (executionGroup != null) {
              executionGroup.shutdownGracefully();
            }
          }
        }
      });
    }
    ProgressableUtils.awaitTerminationFuture(workerGroup, context);
    if (executionGroup != null) {
      ProgressableUtils.awaitTerminationFuture(executionGroup, context);
    }
    if (LOG.isInfoEnabled()) {
      LOG.info("stop: Netty client halted");
    }
  }

  /**
   * Get the next available channel, reconnecting if necessary
   *
   * @param remoteServer Remote server to get a channel for
   * @return Available channel for this remote server
   */
  private Channel getNextChannel(InetSocketAddress remoteServer) {
    Channel channel = addressChannelMap.get(remoteServer).nextChannel();
    if (channel == null) {
      throw new IllegalStateException(
          "getNextChannel: No channel exists for " + remoteServer);
    }

    // Return this channel if it is connected
    if (channel.isActive()) {
      return channel;
    }

    // Get rid of the failed channel
    if (addressChannelMap.get(remoteServer).removeChannel(channel)) {
      LOG.warn("getNextChannel: Unlikely event that the channel " +
          channel + " was already removed!");
    }
    if (LOG.isInfoEnabled()) {
      LOG.info("getNextChannel: Fixing disconnected channel to " +
          remoteServer + ", open = " + channel.isOpen() + ", " +
          "bound = " + channel.isRegistered());
    }
    int reconnectFailures = 0;
    while (reconnectFailures < maxConnectionFailures) {
      ChannelFuture connectionFuture = bootstrap.connect(remoteServer);
      ProgressableUtils.awaitChannelFuture(connectionFuture, context);
      if (connectionFuture.isSuccess()) {
        if (LOG.isInfoEnabled()) {
          LOG.info("getNextChannel: Connected to " + remoteServer + "!");
        }
        addressChannelMap.get(remoteServer).addChannel(
            connectionFuture.channel());
        return connectionFuture.channel();
      }
      ++reconnectFailures;
      LOG.warn("getNextChannel: Failed to reconnect to " + remoteServer +
          " on attempt " + reconnectFailures + " out of " +
          maxConnectionFailures + " max attempts, sleeping for 5 secs",
          connectionFuture.cause());
      try {
        Thread.sleep(5000);
      } catch (InterruptedException e) {
        LOG.warn("getNextChannel: Unexpected interrupted exception", e);
      }
    }
    throw new IllegalStateException("getNextChannel: Failed to connect " +
        "to " + remoteServer + " in " + reconnectFailures +
        " connect attempts");
  }

  /**
   * Send a request to a remote server honoring the flow control mechanism
   * (should be already connected)
   *
   * @param destTaskId Destination task id
   * @param request Request to send
   */
  public void sendWritableRequest(int destTaskId, WritableRequest request) {
    flowControl.sendRequest(destTaskId, request);
  }

  /**
   * Actual send of a request.
   *
   * @param destTaskId destination to send the request to
   * @param request request itself
   * @return request id generated for sending the request
   */
  public Long doSend(int destTaskId, WritableRequest request) {
    InetSocketAddress remoteServer = taskIdAddressMap.get(destTaskId);
    if (clientRequestIdRequestInfoMap.isEmpty()) {
      inboundByteCounter.resetAll();
      outboundByteCounter.resetAll();
    }
    boolean registerRequest = true;
    Long requestId = null;
/*if_not[HADOOP_NON_SECURE]*/
    if (request.getType() == RequestType.SASL_TOKEN_MESSAGE_REQUEST) {
      registerRequest = false;
    }
/*end[HADOOP_NON_SECURE]*/

    Channel channel = getNextChannel(remoteServer);
    RequestInfo newRequestInfo = new RequestInfo(remoteServer, request);
    if (registerRequest) {
      request.setClientId(myTaskInfo.getTaskId());
      requestId = taskRequestIdGenerator.getNextRequestId(destTaskId);
      request.setRequestId(requestId);
      ClientRequestId clientRequestId =
        new ClientRequestId(destTaskId, request.getRequestId());
      RequestInfo oldRequestInfo = clientRequestIdRequestInfoMap.putIfAbsent(
        clientRequestId, newRequestInfo);
      if (oldRequestInfo != null) {
        throw new IllegalStateException("sendWritableRequest: Impossible to " +
          "have a previous request id = " + request.getRequestId() + ", " +
          "request info of " + oldRequestInfo);
      }
    }
    if (request.getSerializedSize() >
        requestSizeWarningThreshold * sendBufferSize) {
      LOG.warn("Creating large request of type " + request.getClass() +
        ", size " + request.getSerializedSize() +
        " bytes. Check netty buffer size.");
    }
    ChannelFuture writeFuture = channel.write(request);
    newRequestInfo.setWriteFuture(writeFuture);
    writeFuture.addListener(logErrorListener);
    return requestId;
  }
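  // Lifecycle of a registered request, as implemented in this class: doSend()
  // records a RequestInfo in clientRequestIdRequestInfoMap and writes the
  // request on a channel; the server's ACK arrives in messageReceived(),
  // which removes the entry and notifies waitAllRequests(); if the ACK is
  // late or the channel fails, checkRequestsForProblems() (or
  // checkRequestsAfterChannelFailure()) re-sends the request through
  // resendRequestsWhenNeeded().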

  /**
   * Handle receipt of a message. Called by response handler.
   *
   * @param senderId Id of sender of the message
   * @param requestId Id of the request
   * @param response Actual response
   * @param shouldDrop Drop the message?
   */
  public void messageReceived(int senderId, long requestId, int response,
      boolean shouldDrop) {
    if (shouldDrop) {
      synchronized (clientRequestIdRequestInfoMap) {
        clientRequestIdRequestInfoMap.notifyAll();
      }
      return;
    }
    AckSignalFlag responseFlag = flowControl.getAckSignalFlag(response);
    if (responseFlag == AckSignalFlag.DUPLICATE_REQUEST) {
      LOG.info("messageReceived: Already completed request (taskId = " +
          senderId + ", requestId = " + requestId + ")");
    } else if (responseFlag != AckSignalFlag.NEW_REQUEST) {
      throw new IllegalStateException(
          "messageReceived: Got illegal response " + response);
    }
    RequestInfo requestInfo = clientRequestIdRequestInfoMap
        .remove(new ClientRequestId(senderId, requestId));
    if (requestInfo == null) {
      LOG.info("messageReceived: Already received response for (taskId = " +
          senderId + ", requestId = " + requestId + ")");
    } else {
      if (LOG.isDebugEnabled()) {
        LOG.debug("messageReceived: Completed (taskId = " + senderId + ")" +
            requestInfo + ".  Waiting on " +
            clientRequestIdRequestInfoMap.size() + " requests");
      }
      flowControl.messageAckReceived(senderId, requestId, response);
      // Help #waitAllRequests() to finish faster
      synchronized (clientRequestIdRequestInfoMap) {
        clientRequestIdRequestInfoMap.notifyAll();
      }
    }
  }

  /**
   * Ensure all the requests sent so far are complete. Periodically check the
   * state of current open requests. If there is an issue in any of them,
   * re-send the request.
   */
  public void waitAllRequests() {
    flowControl.waitAllRequests();
    checkState(flowControl.getNumberOfUnsentRequests() == 0);
    while (clientRequestIdRequestInfoMap.size() > 0) {
      // Wait for requests to complete for some time
      synchronized (clientRequestIdRequestInfoMap) {
        if (clientRequestIdRequestInfoMap.size() == 0) {
          break;
        }
        try {
          clientRequestIdRequestInfoMap.wait(waitingRequestMsecs);
        } catch (InterruptedException e) {
          throw new IllegalStateException("waitAllRequests: Got unexpected " +
              "InterruptedException", e);
        }
      }
      logAndSanityCheck();
    }
    if (LOG.isInfoEnabled()) {
      LOG.info("waitAllRequests: Finished all requests. " +
          inboundByteCounter.getMetrics() + "\n" + outboundByteCounter
          .getMetrics());
    }
  }

  /**
   * Log information about the requests and check for problems in requests
   */
  public void logAndSanityCheck() {
    logInfoAboutOpenRequests();
    // Make sure that waiting doesn't kill the job
    context.progress();
    checkRequestsForProblems();
  }

  /**
   * Log the status of open requests.
   */
  private void logInfoAboutOpenRequests() {
    if (LOG.isInfoEnabled() && requestLogger.isPrintable()) {
      LOG.info("logInfoAboutOpenRequests: Waiting interval of " +
          waitingRequestMsecs + " msecs, " +
          clientRequestIdRequestInfoMap.size() +
          " open requests, " + inboundByteCounter.getMetrics() + "\n" +
          outboundByteCounter.getMetrics());

      if (clientRequestIdRequestInfoMap.size() < MAX_REQUESTS_TO_LIST) {
        for (Map.Entry<ClientRequestId, RequestInfo> entry :
            clientRequestIdRequestInfoMap.entrySet()) {
          LOG.info("logInfoAboutOpenRequests: Waiting for request " +
              entry.getKey() + " - " + entry.getValue());
        }
      }

      // Count how many open requests each task has
      Map<Integer, Integer> openRequestCounts = Maps.newHashMap();
      for (ClientRequestId clientRequestId :
          clientRequestIdRequestInfoMap.keySet()) {
        int taskId = clientRequestId.getDestinationTaskId();
        Integer currentCount = openRequestCounts.get(taskId);
        openRequestCounts.put(taskId,
            (currentCount == null ? 0 : currentCount) + 1);
      }
      // Sort it in decreasing order of number of open requests
      List<Map.Entry<Integer, Integer>> sorted =
          Lists.newArrayList(openRequestCounts.entrySet());
      Collections.sort(sorted, new Comparator<Map.Entry<Integer, Integer>>() {
        @Override
        public int compare(Map.Entry<Integer, Integer> entry1,
            Map.Entry<Integer, Integer> entry2) {
          int value1 = entry1.getValue();
          int value2 = entry2.getValue();
          return (value1 < value2) ? 1 : ((value1 == value2) ? 0 : -1);
        }
      });
      // Print task ids which have the most open requests
      StringBuilder message = new StringBuilder();
      message.append("logInfoAboutOpenRequests: ");
      int itemsToPrint =
          Math.min(MAX_DESTINATION_TASK_IDS_TO_LIST, sorted.size());
      for (int i = 0; i < itemsToPrint; i++) {
        message.append(sorted.get(i).getValue())
            .append(" requests for taskId=")
            .append(sorted.get(i).getKey())
            .append(", ");
      }
      LOG.info(message);
      flowControl.logInfo();
    }
  }

  /**
   * Check if there are some open requests which have been sent a long time
   * ago, and if so resend them.
   */
  private void checkRequestsForProblems() {
    long lastTimeChecked = lastTimeCheckedRequestsForProblems.get();
    // If not enough time passed from the previous check, return
    if (System.currentTimeMillis() < lastTimeChecked + waitingRequestMsecs) {
      return;
    }
    // If another thread did the check already, return
    if (!lastTimeCheckedRequestsForProblems.compareAndSet(lastTimeChecked,
        System.currentTimeMillis())) {
      return;
    }
    resendRequestsWhenNeeded(new Predicate<RequestInfo>() {
      @Override
      public boolean apply(RequestInfo requestInfo) {
        ChannelFuture writeFuture = requestInfo.getWriteFuture();
        // If not connected anymore, request failed, or the request is taking
        // too long, re-establish and resend
        return !writeFuture.channel().isActive() ||
          (writeFuture.isDone() && !writeFuture.isSuccess()) ||
          (requestInfo.getElapsedMsecs() > maxRequestMilliseconds);
      }
    });
  }

  /**
   * Resend requests which satisfy predicate
   *
   * @param shouldResendRequestPredicate Predicate to use to check whether
   *                                     request should be resent
   */
  private void resendRequestsWhenNeeded(
      Predicate<RequestInfo> shouldResendRequestPredicate) {
    // Check if there are open requests which have been sent a long time ago,
    // and if so, resend them.
    List<ClientRequestId> addedRequestIds = Lists.newArrayList();
    List<RequestInfo> addedRequestInfos = Lists.newArrayList();
    // Check all the requests for problems
    for (Map.Entry<ClientRequestId, RequestInfo> entry :
        clientRequestIdRequestInfoMap.entrySet()) {
      RequestInfo requestInfo = entry.getValue();
      ChannelFuture writeFuture = requestInfo.getWriteFuture();
      // Request wasn't sent yet
      if (writeFuture == null) {
        continue;
      }
      // If request should be resent
      if (shouldResendRequestPredicate.apply(requestInfo)) {
        LOG.warn("checkRequestsForProblems: Problem with request id " +
            entry.getKey() + " connected = " +
            writeFuture.channel().isActive() +
            ", future done = " + writeFuture.isDone() + ", " +
            "success = " + writeFuture.isSuccess() + ", " +
            "cause = " + writeFuture.cause() + ", " +
            "elapsed time = " + requestInfo.getElapsedMsecs() + ", " +
            "destination = " + writeFuture.channel().remoteAddress() +
            " " + requestInfo);
        addedRequestIds.add(entry.getKey());
        addedRequestInfos.add(new RequestInfo(
            requestInfo.getDestinationAddress(), requestInfo.getRequest()));
      }
    }

    // Add any new requests to the system, connect if necessary, and re-send
    for (int i = 0; i < addedRequestIds.size(); ++i) {
      ClientRequestId requestId = addedRequestIds.get(i);
      RequestInfo requestInfo = addedRequestInfos.get(i);

      if (clientRequestIdRequestInfoMap.put(requestId, requestInfo) == null) {
        LOG.warn("checkRequestsForProblems: Request " + requestId +
            " completed prior to sending the next request");
        clientRequestIdRequestInfoMap.remove(requestId);
      }
      InetSocketAddress remoteServer = requestInfo.getDestinationAddress();
      Channel channel = getNextChannel(remoteServer);
      if (LOG.isInfoEnabled()) {
        LOG.info("checkRequestsForProblems: Re-issuing request " + requestInfo);
      }
      ChannelFuture writeFuture = channel.write(requestInfo.getRequest());
      requestInfo.setWriteFuture(writeFuture);
      writeFuture.addListener(logErrorListener);
    }
    addedRequestIds.clear();
    addedRequestInfos.clear();
  }

  /**
   * Utility method for resolving addresses
   *
   * @param maxResolveAddressAttempts Maximum number of attempts to resolve the
   *        address
   * @param hostOrIp Known IP or host name
   * @param port Target port number
   * @return The successfully resolved address.
   * @throws IllegalStateException if the address is not resolved
   *         in <code>maxResolveAddressAttempts</code> tries.
   */
  private static InetSocketAddress resolveAddress(
      int maxResolveAddressAttempts, String hostOrIp, int port) {
    int resolveAttempts = 0;
    InetSocketAddress address = new InetSocketAddress(hostOrIp, port);
    while (address.isUnresolved() &&
        resolveAttempts < maxResolveAddressAttempts) {
      ++resolveAttempts;
      LOG.warn("resolveAddress: Failed to resolve " + address +
          " on attempt " + resolveAttempts + " of " +
          maxResolveAddressAttempts + " attempts, sleeping for 5 seconds");
      try {
        Thread.sleep(5000);
      } catch (InterruptedException e) {
        LOG.warn("resolveAddress: Interrupted.", e);
      }
      address = new InetSocketAddress(hostOrIp,
          address.getPort());
    }
    if (resolveAttempts >= maxResolveAddressAttempts) {
      throw new IllegalStateException("resolveAddress: Couldn't " +
          "resolve " + address + " in " + resolveAttempts + " tries.");
    }
    return address;
  }
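  // A new InetSocketAddress is constructed on each retry above because
  // InetSocketAddress attempts host name resolution once, at construction
  // time; re-using the same unresolved instance would never pick up a
  // later successful DNS lookup.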

  /**
   * @return the flow control policy used by this client
   */
  public FlowControl getFlowControl() {
    return flowControl;
  }

  /**
   * Generate and get the next request id to be used for a given worker
   *
   * @param taskId id of the worker to generate the next request id
   * @return request id
   */
  public Long getNextRequestId(int taskId) {
    return taskRequestIdGenerator.getNextRequestId(taskId);
  }

  /**
   * @return number of open requests
   */
  public int getNumberOfOpenRequests() {
    return clientRequestIdRequestInfoMap.size();
  }

  /**
   * Resend requests related to channel which failed
   *
   * @param future ChannelFuture of the failed channel
   */
  private void checkRequestsAfterChannelFailure(final ChannelFuture future) {
    resendRequestsWhenNeeded(new Predicate<RequestInfo>() {
      @Override
      public boolean apply(RequestInfo requestInfo) {
        return requestInfo.getWriteFuture() == future;
      }
    });
  }

  /**
   * This listener class just dumps exception stack traces if
   * something happens.
   */
  private class LogOnErrorChannelFutureListener
      implements ChannelFutureListener {

    @Override
    public void operationComplete(ChannelFuture future) throws Exception {
      if (future.isDone() && !future.isSuccess()) {
        LOG.error("Request failed", future.cause());
        checkRequestsAfterChannelFailure(future);
      }
    }
  }
}