1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 package org.apache.giraph.comm.aggregators;
20
21 import java.io.IOException;
22 import java.util.Map;
23
24 import org.apache.giraph.comm.GlobalCommType;
25 import org.apache.hadoop.io.LongWritable;
26 import org.apache.hadoop.io.Writable;
27
28 import com.google.common.collect.Maps;
29
30 /**
31 * Takes and serializes global communication values and keeps them grouped by
32 * owner partition id, to be sent in bulk.
33 * Includes broadcast messages, reducer registrations and special count.
34 */
35 public class SendGlobalCommCache extends CountingCache {
36 /** Map from worker partition id to global communication output stream */
37 private final Map<Integer, GlobalCommValueOutputStream> globalCommMap =
38 Maps.newHashMap();
39
40 /** whether to write Class object for values into the stream */
41 private final boolean writeClass;
42
43 /**
44 * Constructor
45 *
46 * @param writeClass boolean whether to write Class object for values
47 */
48 public SendGlobalCommCache(boolean writeClass) {
49 this.writeClass = writeClass;
50 }
51
52 /**
53 * Add global communication value to the cache
54 *
55 * @param taskId Task id of worker which owns the value
56 * @param name Name
57 * @param type Global communication type
58 * @param value Value
59 * @return Number of bytes in serialized data for this worker
60 * @throws IOException
61 */
62 public int addValue(Integer taskId, String name,
63 GlobalCommType type, Writable value) throws IOException {
64 GlobalCommValueOutputStream out = globalCommMap.get(taskId);
65 if (out == null) {
66 out = new GlobalCommValueOutputStream(writeClass);
67 globalCommMap.put(taskId, out);
68 }
69 return out.addValue(name, type, value);
70 }
71
72 /**
73 * Remove and get values for certain worker
74 *
75 * @param taskId Partition id of worker owner
76 * @return Serialized global communication data for this worker
77 */
78 public byte[] removeSerialized(Integer taskId) {
79 incrementCounter(taskId);
80 GlobalCommValueOutputStream out = globalCommMap.remove(taskId);
81 if (out == null) {
82 return new byte[4];
83 } else {
84 return out.flush();
85 }
86 }
87
88 /**
89 * Creates special value which will hold the total number of global
90 * communication requests for worker with selected task id. This should be
91 * called after all values for the worker have been added to the cache.
92 *
93 * @param taskId Destination worker's task id
94 * @throws IOException
95 */
96 public void addSpecialCount(Integer taskId) throws IOException {
97 // current number of requests, plus one for the last flush
98 long totalCount = getCount(taskId) + 1;
99 addValue(taskId, GlobalCommType.SPECIAL_COUNT.name(),
100 GlobalCommType.SPECIAL_COUNT, new LongWritable(totalCount));
101 }
102 }