View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *     http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  
19  package org.apache.giraph.comm;
20  
21  import java.io.IOException;
22  import java.util.Arrays;
23  import java.util.Iterator;
24  
25  import javax.annotation.concurrent.NotThreadSafe;
26  
27  import org.apache.giraph.bsp.CentralizedServiceWorker;
28  import org.apache.giraph.comm.netty.NettyWorkerClientRequestProcessor;
29  import org.apache.giraph.comm.requests.SendWorkerMessagesRequest;
30  import org.apache.giraph.comm.requests.SendWorkerOneMessageToManyRequest;
31  import org.apache.giraph.comm.requests.WritableRequest;
32  import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
33  import org.apache.giraph.partition.PartitionOwner;
34  import org.apache.giraph.utils.ByteArrayOneMessageToManyIds;
35  import org.apache.giraph.utils.ExtendedDataOutput;
36  import org.apache.giraph.utils.PairList;
37  import org.apache.giraph.utils.VertexIdMessages;
38  import org.apache.giraph.worker.WorkerInfo;
39  import org.apache.hadoop.io.Writable;
40  import org.apache.hadoop.io.WritableComparable;
41  import org.apache.log4j.Logger;
42  
43  /**
44   * Aggregates the messages to be sent to workers so they can be sent
45   * in bulk.
46   *
47   * @param <I> Vertex id
48   * @param <M> Message data
49   */
50  @NotThreadSafe
51  @SuppressWarnings("unchecked")
52  public class SendOneMessageToManyCache<I extends WritableComparable,
53    M extends Writable> extends SendMessageCache<I, M> {
54    /** Class logger */
55    private static final Logger LOG =
56        Logger.getLogger(SendOneMessageToManyCache.class);
57    /** Cache serialized one to many messages for each worker */
58    private final ByteArrayOneMessageToManyIds<I, M>[] msgVidsCache;
59    /** Tracking message-vertexIds sizes for each worker */
60    private final int[] msgVidsSizes;
61    /** Reused byte array to serialize target ids on each worker */
62    private final ExtendedDataOutput[] idSerializer;
63    /** Reused int array to count target id distribution */
64    private final int[] idCounter;
65    /**
66     * Reused int array to record the partition id
67     * of the first target vertex id found on the worker.
68     */
69    private final int[] firstPartitionMap;
70    /** The WorkerInfo list */
71    private final WorkerInfo[] workerInfoList;
72  
73    /**
74     * Constructor
75     *
76     * @param conf Giraph configuration
77     * @param serviceWorker Service worker
78     * @param processor NettyWorkerClientRequestProcessor
79     * @param maxMsgSize Max message size sent to a worker
80     */
81    public SendOneMessageToManyCache(ImmutableClassesGiraphConfiguration conf,
82      CentralizedServiceWorker<?, ?, ?> serviceWorker,
83      NettyWorkerClientRequestProcessor<I, ?, ?> processor,
84      int maxMsgSize) {
85      super(conf, serviceWorker, processor, maxMsgSize);
86      int numWorkers = getNumWorkers();
87      msgVidsCache = new ByteArrayOneMessageToManyIds[numWorkers];
88      msgVidsSizes = new int[numWorkers];
89      idSerializer = new ExtendedDataOutput[numWorkers];
90      // InitialBufferSizes is alo initialized based on the number of workers.
91      // As a result, initialBufferSizes is the same as idSerializer in length
92      int initialBufferSize = 0;
93      for (int i = 0; i < this.idSerializer.length; i++) {
94        initialBufferSize = getSendWorkerInitialBufferSize(i);
95        if (initialBufferSize > 0) {
96          // InitialBufferSizes is from super class.
97          // Each element is for one worker.
98          idSerializer[i] = conf.createExtendedDataOutput(initialBufferSize);
99        }
100     }
101     idCounter = new int[numWorkers];
102     firstPartitionMap = new int[numWorkers];
103     // Get worker info list.
104     workerInfoList = new WorkerInfo[numWorkers];
105     // Remember there could be null in the array.
106     for (WorkerInfo workerInfo : serviceWorker.getWorkerInfoList()) {
107       workerInfoList[workerInfo.getTaskId()] = workerInfo;
108     }
109   }
110 
111   /**
112    * Reset ExtendedDataOutput array for id serialization
113    * in next message-Vids encoding
114    */
115   private void resetIdSerializers() {
116     for (int i = 0; i < this.idSerializer.length; i++) {
117       if (idSerializer[i] != null) {
118         idSerializer[i].reset();
119       }
120     }
121   }
122 
123   /**
124    * Reset id counter for next message-vertexIds encoding
125    */
126   private void resetIdCounter() {
127     Arrays.fill(idCounter, 0);
128   }
129 
130   /**
131    * Add message with multiple target ids to message cache.
132    *
133    * @param workerInfo The remote worker destination
134    * @param ids A byte array to hold serialized vertex ids
135    * @param idPos The end position of ids
136    *              information in the byte array above
137    * @param count The number of target ids
138    * @param message Message to send to remote worker
139    * @return The size of messages for the worker.
140    */
141   private int addOneToManyMessage(
142     WorkerInfo workerInfo, byte[] ids, int idPos, int count, M message) {
143     // Get the data collection
144     ByteArrayOneMessageToManyIds<I, M> workerData =
145       msgVidsCache[workerInfo.getTaskId()];
146     if (workerData == null) {
147       workerData = new ByteArrayOneMessageToManyIds<I, M>(
148           getConf().<M>createOutgoingMessageValueFactory());
149       workerData.setConf(getConf());
150       workerData.initialize(getSendWorkerInitialBufferSize(
151         workerInfo.getTaskId()));
152       msgVidsCache[workerInfo.getTaskId()] = workerData;
153     }
154     workerData.add(ids, idPos, count, message);
155     // Update the size of cached, outgoing data per worker
156     msgVidsSizes[workerInfo.getTaskId()] =
157       workerData.getSize();
158     return msgVidsSizes[workerInfo.getTaskId()];
159   }
160 
161   /**
162    * Gets the messages + vertexIds for a worker and removes it from the cache.
163    * Here the {@link org.apache.giraph.utils.ByteArrayOneMessageToManyIds}
164    * returned could be null.But when invoking this method, we also check if
165    * the data size sent to this worker is above the threshold.
166    * Therefore, it doesn't matter if the result is null or not.
167    *
168    * @param workerInfo Target worker to which one messages - many ids are sent
169    * @return {@link org.apache.giraph.utils.ByteArrayOneMessageToManyIds}
170    *         that belong to the workerInfo
171    */
172   private ByteArrayOneMessageToManyIds<I, M>
173   removeWorkerMsgVids(WorkerInfo workerInfo) {
174     ByteArrayOneMessageToManyIds<I, M> workerData =
175       msgVidsCache[workerInfo.getTaskId()];
176     if (workerData != null) {
177       msgVidsCache[workerInfo.getTaskId()] = null;
178       msgVidsSizes[workerInfo.getTaskId()] = 0;
179     }
180     return workerData;
181   }
182 
183   /**
184    * Gets all messages - vertexIds and removes them from the cache.
185    *
186    * @return All vertex messages for all workers
187    */
188   private PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>>
189   removeAllMsgVids() {
190     PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>> allData =
191       new PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>>();
192     allData.initialize(msgVidsCache.length);
193     for (WorkerInfo workerInfo : getWorkerPartitions().keySet()) {
194       ByteArrayOneMessageToManyIds<I, M> workerData =
195         removeWorkerMsgVids(workerInfo);
196       if (workerData != null && !workerData.isEmpty()) {
197         allData.add(workerInfo, workerData);
198       }
199     }
200     return allData;
201   }
202 
203   @Override
204   public void sendMessageToAllRequest(Iterator<I> vertexIdIterator, M message) {
205     // This is going to be reused through every message sending
206     resetIdSerializers();
207     resetIdCounter();
208     // Count messages
209     int currentMachineId = 0;
210     PartitionOwner owner = null;
211     WorkerInfo workerInfo = null;
212     I vertexId = null;
213     while (vertexIdIterator.hasNext()) {
214       vertexId = vertexIdIterator.next();
215       owner = getServiceWorker().getVertexPartitionOwner(vertexId);
216       workerInfo = owner.getWorkerInfo();
217       currentMachineId = workerInfo.getTaskId();
218       // Serialize this target vertex id
219       try {
220         vertexId.write(idSerializer[currentMachineId]);
221       } catch (IOException e) {
222         throw new IllegalStateException(
223           "Failed to serialize the target vertex id.");
224       }
225       idCounter[currentMachineId]++;
226       // Record the first partition id in the worker which message send to.
227       // If idCounter shows there is only one target on this worker
228       // then this is the partition number of the target vertex.
229       if (idCounter[currentMachineId] == 1) {
230         firstPartitionMap[currentMachineId] = owner.getPartitionId();
231       }
232     }
233     // Add the message to the cache
234     int idSerializerPos = 0;
235     int workerMessageSize = 0;
236     byte[] serializedId  = null;
237     WritableRequest writableRequest = null;
238     for (int i = 0; i < idCounter.length; i++) {
239       if (idCounter[i] == 1) {
240         serializedId = idSerializer[i].getByteArray();
241         idSerializerPos = idSerializer[i].getPos();
242         // Add the message to the cache
243         workerMessageSize = addMessage(workerInfoList[i],
244           firstPartitionMap[i], serializedId, idSerializerPos, message);
245 
246         if (LOG.isTraceEnabled()) {
247           LOG.trace("sendMessageToAllRequest: Send bytes (" +
248             message.toString() + ") to one target in  worker " +
249             workerInfoList[i]);
250         }
251         ++totalMsgsSentInSuperstep;
252         if (workerMessageSize >= maxMessagesSizePerWorker) {
253           PairList<Integer, VertexIdMessages<I, M>>
254             workerMessages = removeWorkerMessages(workerInfoList[i]);
255           writableRequest = new SendWorkerMessagesRequest<>(workerMessages);
256           totalMsgBytesSentInSuperstep += writableRequest.getSerializedSize();
257           clientProcessor.doRequest(workerInfoList[i], writableRequest);
258           // Notify sending
259           getServiceWorker().getGraphTaskManager().notifySentMessages();
260         }
261       } else if (idCounter[i] > 1) {
262         serializedId = idSerializer[i].getByteArray();
263         idSerializerPos = idSerializer[i].getPos();
264         workerMessageSize = addOneToManyMessage(
265             workerInfoList[i], serializedId, idSerializerPos, idCounter[i],
266             message);
267 
268         if (LOG.isTraceEnabled()) {
269           LOG.trace("sendMessageToAllRequest: Send bytes (" +
270             message.toString() + ") to all targets in worker" +
271             workerInfoList[i]);
272         }
273         totalMsgsSentInSuperstep += idCounter[i];
274         if (workerMessageSize >= maxMessagesSizePerWorker) {
275           ByteArrayOneMessageToManyIds<I, M> workerMsgVids =
276             removeWorkerMsgVids(workerInfoList[i]);
277           writableRequest =  new SendWorkerOneMessageToManyRequest<>(
278             workerMsgVids, getConf());
279           totalMsgBytesSentInSuperstep += writableRequest.getSerializedSize();
280           clientProcessor.doRequest(workerInfoList[i], writableRequest);
281           // Notify sending
282           getServiceWorker().getGraphTaskManager().notifySentMessages();
283         }
284       }
285     }
286   }
287 
288   @Override
289   public void flush() {
290     super.flush();
291     PairList<WorkerInfo, ByteArrayOneMessageToManyIds<I, M>>
292     remainingMsgVidsCache = removeAllMsgVids();
293     PairList<WorkerInfo,
294         ByteArrayOneMessageToManyIds<I, M>>.Iterator
295     msgIdsIterator = remainingMsgVidsCache.getIterator();
296     while (msgIdsIterator.hasNext()) {
297       msgIdsIterator.next();
298       WritableRequest writableRequest =
299         new SendWorkerOneMessageToManyRequest<>(
300             msgIdsIterator.getCurrentSecond(), getConf());
301       totalMsgBytesSentInSuperstep += writableRequest.getSerializedSize();
302       clientProcessor.doRequest(
303         msgIdsIterator.getCurrentFirst(), writableRequest);
304     }
305   }
306 }