This project has been retired. For details, please refer to its
Attic page.
TestSwitchClasses xref
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.giraph.master;
20
21 import java.io.DataInput;
22 import java.io.DataOutput;
23 import java.io.IOException;
24 import java.util.ArrayList;
25 import java.util.HashSet;
26
27 import junit.framework.Assert;
28
29 import org.apache.giraph.combiner.MessageCombiner;
30 import org.apache.giraph.conf.GiraphConfiguration;
31 import org.apache.giraph.graph.AbstractComputation;
32 import org.apache.giraph.graph.Vertex;
33 import org.apache.giraph.utils.InternalVertexRunner;
34 import org.apache.giraph.utils.TestGraph;
35 import org.apache.hadoop.io.DoubleWritable;
36 import org.apache.hadoop.io.IntWritable;
37 import org.apache.hadoop.io.Writable;
38 import org.junit.Test;
39
40 import com.google.common.collect.Lists;
41 import com.google.common.collect.Sets;
42
43
44 public class TestSwitchClasses {
45 @Test
46 public void testSwitchingClasses() throws Exception {
47 GiraphConfiguration conf = new GiraphConfiguration();
48 conf.setComputationClass(Computation3.class);
49 conf.setMasterComputeClass(SwitchingClassesMasterCompute.class);
50
51 TestGraph<IntWritable, StatusValue, IntWritable> graph =
52 new TestGraph<IntWritable, StatusValue, IntWritable>(conf);
53 IntWritable id1 = new IntWritable(1);
54 graph.addVertex(id1, new StatusValue());
55 IntWritable id2 = new IntWritable(2);
56 graph.addVertex(id2, new StatusValue());
57 graph = InternalVertexRunner.runWithInMemoryOutput(conf, graph);
58
59 Assert.assertEquals(2, graph.getVertexCount());
60 }
61
62 private static void checkVerticesOnFinalSuperstep(
63 Vertex<IntWritable, StatusValue, IntWritable> vertex) {
64
65 final ArrayList<Integer> expectedComputations =
66 Lists.newArrayList(1, 1, 2, 3, 1);
67 checkComputations(expectedComputations, vertex.getValue().computations);
68
69
70 switch (vertex.getId().get()) {
71 case 1:
72 ArrayList<HashSet<Double>> messages1 =
73 Lists.newArrayList(
74 Sets.<Double>newHashSet(),
75 Sets.<Double>newHashSet(11d),
76 Sets.<Double>newHashSet(11d),
77 Sets.<Double>newHashSet(101.5, 201.5),
78 Sets.<Double>newHashSet(3002d));
79 checkMessages(messages1, vertex.getValue().messagesReceived);
80 break;
81 case 2:
82 ArrayList<HashSet<Double>> messages2 =
83 Lists.newArrayList(
84 Sets.<Double>newHashSet(),
85 Sets.<Double>newHashSet(12d),
86 Sets.<Double>newHashSet(12d),
87 Sets.<Double>newHashSet(102.5, 202.5),
88 Sets.<Double>newHashSet(3004d));
89 checkMessages(messages2, vertex.getValue().messagesReceived);
90 break;
91 default:
92 throw new IllegalStateException("checkVertices: Illegal vertex " +
93 vertex);
94 }
95 }
96
97 private static void checkComputations(ArrayList<Integer> expected,
98 ArrayList<Integer> actual) {
99 Assert.assertEquals("Incorrect number of supersteps",
100 expected.size(), actual.size());
101 for (int i = 0; i < expected.size(); i++) {
102 Assert.assertEquals("Incorrect computation on superstep " + i,
103 (int) expected.get(i), (int) actual.get(i));
104 }
105 }
106
107 private static void checkMessages(ArrayList<HashSet<Double>> expected,
108 ArrayList<HashSet<Double>> actual) {
109 Assert.assertEquals(expected.size(), actual.size());
110 for (int i = 0; i < expected.size(); i++) {
111 Assert.assertEquals(expected.get(i).size(), actual.get(i).size());
112 for (Double value : expected.get(i)) {
113 Assert.assertTrue(actual.get(i).contains(value));
114 }
115 }
116 }
117
118 public static class SwitchingClassesMasterCompute
119 extends DefaultMasterCompute {
120 @Override
121 public void compute() {
122 switch ((int) getSuperstep()) {
123 case 0:
124 setComputation(Computation1.class);
125 setMessageCombiner(MinimumMessageCombiner.class);
126 break;
127 case 1:
128
129 break;
130 case 2:
131 setComputation(Computation2.class);
132
133 setMessageCombiner(null);
134 break;
135 case 3:
136 setComputation(Computation3.class);
137 setMessageCombiner(SumMessageCombiner.class);
138 setIncomingMessage(DoubleWritable.class);
139 setOutgoingMessage(IntWritable.class);
140 break;
141 case 4:
142 setComputation(Computation1.class);
143 break;
144 default:
145 haltComputation();
146 }
147 }
148 }
149
150 public static class Computation1 extends AbstractComputation<IntWritable,
151 StatusValue, IntWritable, IntWritable, IntWritable> {
152 @Override
153 public void compute(Vertex<IntWritable, StatusValue, IntWritable> vertex,
154 Iterable<IntWritable> messages) throws IOException {
155 vertex.getValue().computations.add(1);
156 vertex.getValue().addIntMessages(messages);
157
158 IntWritable otherId = new IntWritable(3 - vertex.getId().get());
159 sendMessage(otherId, new IntWritable(otherId.get() + 10));
160 sendMessage(otherId, new IntWritable(otherId.get() + 20));
161
162 if (getSuperstep() == 4) {
163 checkVerticesOnFinalSuperstep(vertex);
164 }
165 }
166 }
167
168 public static class Computation2 extends AbstractComputation<IntWritable,
169 StatusValue, IntWritable, IntWritable, DoubleWritable> {
170 @Override
171 public void compute(Vertex<IntWritable, StatusValue, IntWritable> vertex,
172 Iterable<IntWritable> messages) throws IOException {
173 vertex.getValue().computations.add(2);
174 vertex.getValue().addIntMessages(messages);
175
176 IntWritable otherId = new IntWritable(3 - vertex.getId().get());
177 sendMessage(otherId, new DoubleWritable(otherId.get() + 100.5));
178 sendMessage(otherId, new DoubleWritable(otherId.get() + 200.5));
179 }
180 }
181
182 public static class Computation3 extends AbstractComputation<IntWritable,
183 StatusValue, IntWritable, Writable, Writable> {
184 @Override
185 public void compute(
186 Vertex<IntWritable, StatusValue, IntWritable> vertex,
187 Iterable<Writable> messages) throws IOException {
188 vertex.getValue().computations.add(3);
189 vertex.getValue().addDoubleMessages(messages);
190
191 IntWritable otherId = new IntWritable(3 - vertex.getId().get());
192 sendMessage(otherId, new IntWritable(otherId.get() + 1000));
193 sendMessage(otherId, new IntWritable(otherId.get() + 2000));
194 }
195 }
196
197 public static class MinimumMessageCombiner
198 implements MessageCombiner<IntWritable,
199 IntWritable> {
200 @Override
201 public void combine(IntWritable vertexIndex, IntWritable originalMessage,
202 IntWritable messageToCombine) {
203 originalMessage.set(
204 Math.min(originalMessage.get(), messageToCombine.get()));
205 }
206
207 @Override
208 public IntWritable createInitialMessage() {
209 return new IntWritable(Integer.MAX_VALUE);
210 }
211 }
212
213 public static class SumMessageCombiner
214 implements MessageCombiner<IntWritable, IntWritable> {
215 @Override
216 public void combine(IntWritable vertexIndex, IntWritable originalMessage,
217 IntWritable messageToCombine) {
218 originalMessage.set(originalMessage.get() + messageToCombine.get());
219 }
220
221 @Override
222 public IntWritable createInitialMessage() {
223 return new IntWritable(0);
224 }
225 }
226
227 public static class StatusValue implements Writable {
228 private ArrayList<Integer> computations = new ArrayList<Integer>();
229 private ArrayList<HashSet<Double>> messagesReceived =
230 new ArrayList<HashSet<Double>>();
231
232 public StatusValue() {
233 }
234
235 public void addIntMessages(Iterable<IntWritable> messages) {
236 HashSet<Double> messagesList = new HashSet<Double>();
237 for (IntWritable message : messages) {
238 messagesList.add((double) message.get());
239 }
240 messagesReceived.add(messagesList);
241 }
242
243 public void addDoubleMessages(Iterable<Writable> messages) {
244 HashSet<Double> messagesList = new HashSet<Double>();
245 for (Writable message : messages) {
246 messagesList.add(((DoubleWritable)message).get());
247 }
248 messagesReceived.add(messagesList);
249 }
250
251 @Override
252 public String toString() {
253 return "(computations=" + computations +
254 ",messagesReceived=" + messagesReceived + ")";
255 }
256
257 @Override
258 public void write(DataOutput dataOutput) throws IOException {
259 dataOutput.writeInt(computations.size());
260 for (Integer computation : computations) {
261 dataOutput.writeInt(computation);
262 }
263 dataOutput.writeInt(messagesReceived.size());
264 for (HashSet<Double> messages : messagesReceived) {
265 dataOutput.writeInt(messages.size());
266 for (Double message : messages) {
267 dataOutput.writeDouble(message);
268 }
269 }
270 }
271
272 @Override
273 public void readFields(DataInput dataInput) throws IOException {
274 int size = dataInput.readInt();
275 computations = new ArrayList<Integer>(size);
276 for (int i = 0; i < size; i++) {
277 computations.add(dataInput.readInt());
278 }
279 size = dataInput.readInt();
280 messagesReceived = new ArrayList<HashSet<Double>>(size);
281 for (int i = 0; i < size; i++) {
282 int size2 = dataInput.readInt();
283 HashSet<Double> messages = new HashSet<Double>(size2);
284 for (int j = 0; j < size2; j++) {
285 messages.add(dataInput.readDouble());
286 }
287 messagesReceived.add(messages);
288 }
289 }
290 }
291 }