View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *     http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  
19  package org.apache.giraph.master;
20  
21  import java.io.DataInput;
22  import java.io.DataOutput;
23  import java.io.IOException;
24  import java.util.ArrayList;
25  import java.util.HashSet;
26  
27  import junit.framework.Assert;
28  
29  import org.apache.giraph.combiner.MessageCombiner;
30  import org.apache.giraph.conf.GiraphConfiguration;
31  import org.apache.giraph.graph.AbstractComputation;
32  import org.apache.giraph.graph.Vertex;
33  import org.apache.giraph.utils.InternalVertexRunner;
34  import org.apache.giraph.utils.TestGraph;
35  import org.apache.hadoop.io.DoubleWritable;
36  import org.apache.hadoop.io.IntWritable;
37  import org.apache.hadoop.io.Writable;
38  import org.junit.Test;
39  
40  import com.google.common.collect.Lists;
41  import com.google.common.collect.Sets;
42  
43  /** Test switching Computation and MessageCombiner class during application */
44  public class TestSwitchClasses {
45    @Test
46    public void testSwitchingClasses() throws Exception {
47      GiraphConfiguration conf = new GiraphConfiguration();
48      conf.setComputationClass(Computation3.class);
49      conf.setMasterComputeClass(SwitchingClassesMasterCompute.class);
50  
51      TestGraph<IntWritable, StatusValue, IntWritable> graph =
52          new TestGraph<IntWritable, StatusValue, IntWritable>(conf);
53      IntWritable id1 = new IntWritable(1);
54      graph.addVertex(id1, new StatusValue());
55      IntWritable id2 = new IntWritable(2);
56      graph.addVertex(id2, new StatusValue());
57      graph = InternalVertexRunner.runWithInMemoryOutput(conf, graph);
58  
59      Assert.assertEquals(2, graph.getVertices().size());
60    }
61  
62    private static void checkVerticesOnFinalSuperstep(
63        Vertex<IntWritable, StatusValue, IntWritable> vertex) {
64      // Check that computations were performed in expected order
65      final ArrayList<Integer> expectedComputations =
66          Lists.newArrayList(1, 1, 2, 3, 1);
67      checkComputations(expectedComputations, vertex.getValue().computations);
68      // Check that messages were sent in the correct superstep,
69      // and combined when needed
70      switch (vertex.getId().get()) {
71        case 1:
72          ArrayList<HashSet<Double>> messages1 =
73              Lists.newArrayList(
74                  Sets.<Double>newHashSet(),
75                  Sets.<Double>newHashSet(11d),
76                  Sets.<Double>newHashSet(11d),
77                  Sets.<Double>newHashSet(101.5, 201.5),
78                  Sets.<Double>newHashSet(3002d));
79          checkMessages(messages1, vertex.getValue().messagesReceived);
80          break;
81        case 2:
82          ArrayList<HashSet<Double>> messages2 =
83              Lists.newArrayList(
84                  Sets.<Double>newHashSet(),
85                  Sets.<Double>newHashSet(12d),
86                  Sets.<Double>newHashSet(12d),
87                  Sets.<Double>newHashSet(102.5, 202.5),
88                  Sets.<Double>newHashSet(3004d));
89          checkMessages(messages2, vertex.getValue().messagesReceived);
90          break;
91        default:
92          throw new IllegalStateException("checkVertices: Illegal vertex " +
93              vertex);
94      }
95    }
96  
97    private static void checkComputations(ArrayList<Integer> expected,
98        ArrayList<Integer> actual) {
99      Assert.assertEquals("Incorrect number of supersteps",
100         expected.size(), actual.size());
101     for (int i = 0; i < expected.size(); i++) {
102       Assert.assertEquals("Incorrect computation on superstep " + i,
103           (int) expected.get(i), (int) actual.get(i));
104     }
105   }
106 
107   private static void checkMessages(ArrayList<HashSet<Double>> expected,
108       ArrayList<HashSet<Double>> actual) {
109     Assert.assertEquals(expected.size(), actual.size());
110     for (int i = 0; i < expected.size(); i++) {
111       Assert.assertEquals(expected.get(i).size(), actual.get(i).size());
112       for (Double value : expected.get(i)) {
113         Assert.assertTrue(actual.get(i).contains(value));
114       }
115     }
116   }
117 
118   public static class SwitchingClassesMasterCompute
119       extends DefaultMasterCompute {
120     @Override
121     public void compute() {
122       switch ((int) getSuperstep()) {
123         case 0:
124           setComputation(Computation1.class);
125           setMessageCombiner(MinimumMessageCombiner.class);
126           break;
127         case 1:
128           // test classes don't change
129           break;
130         case 2:
131           setComputation(Computation2.class);
132           // test combiner removed
133           setMessageCombiner(null);
134           break;
135         case 3:
136           setComputation(Computation3.class);
137           setMessageCombiner(SumMessageCombiner.class);
138           setIncomingMessage(DoubleWritable.class);
139           setOutgoingMessage(IntWritable.class);
140           break;
141         case 4:
142           setComputation(Computation1.class);
143           break;
144         default:
145           haltComputation();
146       }
147     }
148   }
149 
150   public static class Computation1 extends AbstractComputation<IntWritable,
151         StatusValue, IntWritable, IntWritable, IntWritable> {
152     @Override
153     public void compute(Vertex<IntWritable, StatusValue, IntWritable> vertex,
154         Iterable<IntWritable> messages) throws IOException {
155       vertex.getValue().computations.add(1);
156       vertex.getValue().addIntMessages(messages);
157 
158       IntWritable otherId = new IntWritable(3 - vertex.getId().get());
159       sendMessage(otherId, new IntWritable(otherId.get() + 10));
160       sendMessage(otherId, new IntWritable(otherId.get() + 20));
161       // Check the vertices on the final superstep
162       if (getSuperstep() == 4) {
163         checkVerticesOnFinalSuperstep(vertex);
164       }
165     }
166   }
167 
168   public static class Computation2 extends AbstractComputation<IntWritable,
169         StatusValue, IntWritable, IntWritable, DoubleWritable> {
170     @Override
171     public void compute(Vertex<IntWritable, StatusValue, IntWritable> vertex,
172         Iterable<IntWritable> messages) throws IOException {
173       vertex.getValue().computations.add(2);
174       vertex.getValue().addIntMessages(messages);
175 
176       IntWritable otherId = new IntWritable(3 - vertex.getId().get());
177       sendMessage(otherId, new DoubleWritable(otherId.get() + 100.5));
178       sendMessage(otherId, new DoubleWritable(otherId.get() + 200.5));
179     }
180   }
181 
182   public static class Computation3 extends AbstractComputation<IntWritable,
183         StatusValue, IntWritable, Writable, Writable> {
184     @Override
185     public void compute(
186         Vertex<IntWritable, StatusValue, IntWritable> vertex,
187         Iterable<Writable> messages) throws IOException {
188       vertex.getValue().computations.add(3);
189       vertex.getValue().addDoubleMessages(messages);
190 
191       IntWritable otherId = new IntWritable(3 - vertex.getId().get());
192       sendMessage(otherId, new IntWritable(otherId.get() + 1000));
193       sendMessage(otherId, new IntWritable(otherId.get() + 2000));
194     }
195   }
196 
197   public static class MinimumMessageCombiner
198       implements MessageCombiner<IntWritable,
199                   IntWritable> {
200     @Override
201     public void combine(IntWritable vertexIndex, IntWritable originalMessage,
202         IntWritable messageToCombine) {
203       originalMessage.set(
204           Math.min(originalMessage.get(), messageToCombine.get()));
205     }
206 
207     @Override
208     public IntWritable createInitialMessage() {
209       return new IntWritable(Integer.MAX_VALUE);
210     }
211   }
212 
213   public static class SumMessageCombiner
214       implements MessageCombiner<IntWritable, IntWritable> {
215     @Override
216     public void combine(IntWritable vertexIndex, IntWritable originalMessage,
217         IntWritable messageToCombine) {
218       originalMessage.set(originalMessage.get() + messageToCombine.get());
219     }
220 
221     @Override
222     public IntWritable createInitialMessage() {
223       return new IntWritable(0);
224     }
225   }
226 
227   public static class StatusValue implements Writable {
228     private ArrayList<Integer> computations = new ArrayList<Integer>();
229     private ArrayList<HashSet<Double>> messagesReceived =
230         new ArrayList<HashSet<Double>>();
231 
232     public StatusValue() {
233     }
234 
235     public void addIntMessages(Iterable<IntWritable> messages) {
236       HashSet<Double> messagesList = new HashSet<Double>();
237       for (IntWritable message : messages) {
238         messagesList.add((double) message.get());
239       }
240       messagesReceived.add(messagesList);
241     }
242 
243     public void addDoubleMessages(Iterable<Writable> messages) {
244       HashSet<Double> messagesList = new HashSet<Double>();
245       for (Writable message : messages) {
246         messagesList.add(((DoubleWritable)message).get());
247       }
248       messagesReceived.add(messagesList);
249     }
250 
251     @Override
252     public String toString() {
253       return "(computations=" + computations +
254           ",messagesReceived=" + messagesReceived + ")";
255     }
256 
257     @Override
258     public void write(DataOutput dataOutput) throws IOException {
259       dataOutput.writeInt(computations.size());
260       for (Integer computation : computations) {
261         dataOutput.writeInt(computation);
262       }
263       dataOutput.writeInt(messagesReceived.size());
264       for (HashSet<Double> messages : messagesReceived) {
265         dataOutput.writeInt(messages.size());
266         for (Double message : messages) {
267           dataOutput.writeDouble(message);
268         }
269       }
270     }
271 
272     @Override
273     public void readFields(DataInput dataInput) throws IOException {
274       int size = dataInput.readInt();
275       computations = new ArrayList<Integer>(size);
276       for (int i = 0; i < size; i++) {
277         computations.add(dataInput.readInt());
278       }
279       size = dataInput.readInt();
280       messagesReceived = new ArrayList<HashSet<Double>>(size);
281       for (int i = 0; i < size; i++) {
282         int size2 = dataInput.readInt();
283         HashSet<Double> messages = new HashSet<Double>(size2);
284         for (int j = 0; j < size2; j++) {
285           messages.add(dataInput.readDouble());
286         }
287         messagesReceived.add(messages);
288       }
289     }
290   }
291 }