This project has retired. For details please refer to its Attic page.
AggregatorsTestComputation xref
View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *     http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  
19  package org.apache.giraph.examples;
20  
21  import org.apache.giraph.aggregators.LongSumAggregator;
22  import org.apache.giraph.bsp.BspInputSplit;
23  import org.apache.giraph.edge.Edge;
24  import org.apache.giraph.edge.EdgeFactory;
25  import org.apache.giraph.graph.BasicComputation;
26  import org.apache.giraph.master.DefaultMasterCompute;
27  import org.apache.giraph.graph.Vertex;
28  import org.apache.giraph.io.EdgeInputFormat;
29  import org.apache.giraph.io.EdgeReader;
30  import org.apache.giraph.io.VertexReader;
31  import org.apache.giraph.io.formats.GeneratedVertexInputFormat;
32  import org.apache.hadoop.conf.Configuration;
33  import org.apache.hadoop.io.DoubleWritable;
34  import org.apache.hadoop.io.FloatWritable;
35  import org.apache.hadoop.io.LongWritable;
36  import org.apache.hadoop.mapreduce.InputSplit;
37  import org.apache.hadoop.mapreduce.JobContext;
38  import org.apache.hadoop.mapreduce.TaskAttemptContext;
39  import org.apache.log4j.Logger;
40  
41  import com.google.common.collect.Lists;
42  
43  import java.io.IOException;
44  import java.util.ArrayList;
45  import java.util.List;
46  
47  /** Computation which uses aggrergators. To be used for testing. */
48  public class AggregatorsTestComputation extends
49      BasicComputation<LongWritable, DoubleWritable, FloatWritable,
50          DoubleWritable> {
51  
52    /** Name of regular aggregator */
53    private static final String REGULAR_AGG = "regular";
54    /** Name of persistent aggregator */
55    private static final String PERSISTENT_AGG = "persistent";
56    /** Name of input super step persistent aggregator */
57    private static final String INPUT_VERTEX_PERSISTENT_AGG
58      = "input_super_step_vertex_agg";
59    /** Name of input super step persistent aggregator */
60    private static final String INPUT_EDGE_PERSISTENT_AGG
61      = "input_super_step_edge_agg";
62    /** Name of master overwriting aggregator */
63    private static final String MASTER_WRITE_AGG = "master";
64    /** Value which master compute will use */
65    private static final long MASTER_VALUE = 12345;
66    /** Prefix for name of aggregators in array */
67    private static final String ARRAY_PREFIX_AGG = "array";
68    /** Number of aggregators to use in array */
69    private static final int NUM_OF_AGGREGATORS_IN_ARRAY = 100;
70  
71    @Override
72    public void compute(
73        Vertex<LongWritable, DoubleWritable, FloatWritable> vertex,
74        Iterable<DoubleWritable> messages) throws IOException {
75      long superstep = getSuperstep();
76  
77      LongWritable myValue = new LongWritable(1L << superstep);
78      aggregate(REGULAR_AGG, myValue);
79      aggregate(PERSISTENT_AGG, myValue);
80  
81      long nv = getTotalNumVertices();
82      if (superstep > 0) {
83        assertEquals(nv * (1L << (superstep - 1)),
84            ((LongWritable) getAggregatedValue(REGULAR_AGG)).get());
85      } else {
86        assertEquals(0,
87            ((LongWritable) getAggregatedValue(REGULAR_AGG)).get());
88      }
89      assertEquals(nv * ((1L << superstep) - 1),
90          ((LongWritable) getAggregatedValue(PERSISTENT_AGG)).get());
91      assertEquals(MASTER_VALUE * (1L << superstep),
92          ((LongWritable) getAggregatedValue(MASTER_WRITE_AGG)).get());
93  
94      for (int i = 0; i < NUM_OF_AGGREGATORS_IN_ARRAY; i++) {
95        aggregate(ARRAY_PREFIX_AGG + i, new LongWritable((superstep + 1) * i));
96        assertEquals(superstep * getTotalNumVertices() * i,
97            ((LongWritable) getAggregatedValue(ARRAY_PREFIX_AGG + i)).get());
98      }
99  
100     if (getSuperstep() == 10) {
101       vertex.voteToHalt();
102     }
103   }
104 
105   /** Master compute which uses aggregators. To be used for testing. */
106   public static class AggregatorsTestMasterCompute extends
107       DefaultMasterCompute {
108     @Override
109     public void compute() {
110       long superstep = getSuperstep();
111 
112       LongWritable myValue =
113           new LongWritable(MASTER_VALUE * (1L << superstep));
114       setAggregatedValue(MASTER_WRITE_AGG, myValue);
115 
116       long nv = getTotalNumVertices();
117       if (superstep >= 0) {
118         assertEquals(100, ((LongWritable)
119           getAggregatedValue(INPUT_VERTEX_PERSISTENT_AGG)).get());
120       }
121       if (superstep >= 0) {
122         assertEquals(4500, ((LongWritable)
123           getAggregatedValue(INPUT_EDGE_PERSISTENT_AGG)).get());
124       }
125       if (superstep > 0) {
126         assertEquals(nv * (1L << (superstep - 1)),
127             ((LongWritable) getAggregatedValue(REGULAR_AGG)).get());
128       } else {
129         assertEquals(0,
130             ((LongWritable) getAggregatedValue(REGULAR_AGG)).get());
131       }
132       assertEquals(nv * ((1L << superstep) - 1),
133           ((LongWritable) getAggregatedValue(PERSISTENT_AGG)).get());
134 
135       for (int i = 0; i < NUM_OF_AGGREGATORS_IN_ARRAY; i++) {
136         assertEquals(superstep * getTotalNumVertices() * i,
137             ((LongWritable) getAggregatedValue(ARRAY_PREFIX_AGG + i)).get());
138       }
139     }
140 
141     @Override
142     public void initialize() throws InstantiationException,
143         IllegalAccessException {
144       registerPersistentAggregator(
145           INPUT_VERTEX_PERSISTENT_AGG, LongSumAggregator.class);
146       registerPersistentAggregator(
147           INPUT_EDGE_PERSISTENT_AGG, LongSumAggregator.class);
148       registerAggregator(REGULAR_AGG, LongSumAggregator.class);
149       registerPersistentAggregator(PERSISTENT_AGG,
150           LongSumAggregator.class);
151       registerAggregator(MASTER_WRITE_AGG, LongSumAggregator.class);
152 
153       for (int i = 0; i < NUM_OF_AGGREGATORS_IN_ARRAY; i++) {
154         registerAggregator(ARRAY_PREFIX_AGG + i, LongSumAggregator.class);
155       }
156     }
157   }
158 
159   /**
160    * Throws exception if values are not equal.
161    *
162    * @param expected Expected value
163    * @param actual   Actual value
164    */
165   private static void assertEquals(long expected, long actual) {
166     if (expected != actual) {
167       throw new RuntimeException("expected: " + expected +
168           ", actual: " + actual);
169     }
170   }
171 
172   /**
173    * Simple VertexReader
174    */
175   public static class SimpleVertexReader extends
176       GeneratedVertexReader<LongWritable, DoubleWritable, FloatWritable> {
177     /** Class logger */
178     private static final Logger LOG =
179         Logger.getLogger(SimpleVertexReader.class);
180 
181     @Override
182     public boolean nextVertex() {
183       return totalRecords > recordsRead;
184     }
185 
186     @Override
187     public Vertex<LongWritable, DoubleWritable,
188         FloatWritable> getCurrentVertex() throws IOException {
189       Vertex<LongWritable, DoubleWritable, FloatWritable> vertex =
190           getConf().createVertex();
191       LongWritable vertexId = new LongWritable(
192           (inputSplit.getSplitIndex() * totalRecords) + recordsRead);
193       DoubleWritable vertexValue = new DoubleWritable(vertexId.get() * 10d);
194       long targetVertexId =
195           (vertexId.get() + 1) %
196           (inputSplit.getNumSplits() * totalRecords);
197       float edgeValue = vertexId.get() * 100f;
198       List<Edge<LongWritable, FloatWritable>> edges = Lists.newLinkedList();
199       edges.add(EdgeFactory.create(new LongWritable(targetVertexId),
200           new FloatWritable(edgeValue)));
201       vertex.initialize(vertexId, vertexValue, edges);
202       ++recordsRead;
203       if (LOG.isInfoEnabled()) {
204         LOG.info("next vertex: Return vertexId=" + vertex.getId().get() +
205             ", vertexValue=" + vertex.getValue() +
206             ", targetVertexId=" + targetVertexId + ", edgeValue=" + edgeValue);
207       }
208       aggregate(INPUT_VERTEX_PERSISTENT_AGG,
209         new LongWritable((long) vertex.getValue().get()));
210       return vertex;
211     }
212   }
213 
214   /**
215    * Simple VertexInputFormat
216    */
217   public static class SimpleVertexInputFormat extends
218     GeneratedVertexInputFormat<LongWritable, DoubleWritable, FloatWritable> {
219     @Override
220     public VertexReader<LongWritable, DoubleWritable,
221     FloatWritable> createVertexReader(InputSplit split,
222       TaskAttemptContext context)
223       throws IOException {
224       return new SimpleVertexReader();
225     }
226   }
227 
228   /**
229    * Simple Edge Reader
230    */
231   public static class SimpleEdgeReader extends
232     GeneratedEdgeReader<LongWritable, FloatWritable> {
233     /** Class logger */
234     private static final Logger LOG = Logger.getLogger(SimpleEdgeReader.class);
235 
236     @Override
237     public boolean nextEdge() {
238       return totalRecords > recordsRead;
239     }
240 
241     @Override
242     public Edge<LongWritable, FloatWritable> getCurrentEdge()
243       throws IOException {
244       LongWritable vertexId = new LongWritable(
245         (inputSplit.getSplitIndex() * totalRecords) + recordsRead);
246       long targetVertexId = (vertexId.get() + 1) %
247         (inputSplit.getNumSplits() * totalRecords);
248       float edgeValue = vertexId.get() * 100f;
249       Edge<LongWritable, FloatWritable> edge = EdgeFactory.create(
250         new LongWritable(targetVertexId), new FloatWritable(edgeValue));
251       ++recordsRead;
252       if (LOG.isInfoEnabled()) {
253         LOG.info("next edge: Return targetVertexId=" + targetVertexId +
254           ", edgeValue=" + edgeValue);
255       }
256       aggregate(INPUT_EDGE_PERSISTENT_AGG, new LongWritable((long) edge
257         .getValue().get()));
258       return edge;
259     }
260 
261     @Override
262     public LongWritable getCurrentSourceId() throws IOException,
263       InterruptedException {
264       LongWritable vertexId = new LongWritable(
265         (inputSplit.getSplitIndex() * totalRecords) + recordsRead);
266       return vertexId;
267     }
268   }
269 
270   /**
271    * Simple VertexInputFormat
272    */
273   public static class SimpleEdgeInputFormat extends
274     EdgeInputFormat<LongWritable, FloatWritable> {
275     @Override public void checkInputSpecs(Configuration conf) { }
276 
277     @Override
278     public EdgeReader<LongWritable, FloatWritable> createEdgeReader(
279       InputSplit split, TaskAttemptContext context) throws IOException {
280       return new SimpleEdgeReader();
281     }
282 
283     @Override
284     public List<InputSplit> getSplits(JobContext context, int minSplitCountHint)
285       throws IOException, InterruptedException {
286       List<InputSplit> inputSplitList = new ArrayList<InputSplit>();
287       for (int i = 0; i < minSplitCountHint; ++i) {
288         inputSplitList.add(new BspInputSplit(i, minSplitCountHint));
289       }
290       return inputSplitList;
291     }
292   }
293 }