View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *     http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  
19  package org.apache.giraph.comm.messages;
20  
21  import java.io.IOException;
22  import java.util.Iterator;
23  
24  import junit.framework.Assert;
25  
26  import org.apache.giraph.bsp.CentralizedServiceWorker;
27  import org.apache.giraph.combiner.FloatSumMessageCombiner;
28  import org.apache.giraph.comm.messages.primitives.IntByteArrayMessageStore;
29  import org.apache.giraph.comm.messages.primitives.IntFloatMessageStore;
30  import org.apache.giraph.conf.GiraphConfiguration;
31  import org.apache.giraph.conf.GiraphConstants;
32  import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
33  import org.apache.giraph.factories.TestMessageValueFactory;
34  import org.apache.giraph.graph.BasicComputation;
35  import org.apache.giraph.graph.Vertex;
36  import org.apache.giraph.partition.Partition;
37  import org.apache.giraph.partition.PartitionStore;
38  import org.apache.giraph.utils.ByteArrayVertexIdMessages;
39  import org.apache.hadoop.io.FloatWritable;
40  import org.apache.hadoop.io.IntWritable;
41  import org.apache.hadoop.io.NullWritable;
42  import org.apache.hadoop.io.Writable;
43  import org.junit.Before;
44  import org.junit.Test;
45  import org.mockito.Mockito;
46  import org.mockito.invocation.InvocationOnMock;
47  import org.mockito.stubbing.Answer;
48  
49  import com.google.common.collect.Iterables;
50  import com.google.common.collect.Lists;
51  
52  public class TestIntFloatPrimitiveMessageStores {
53    private static final int NUM_PARTITIONS = 2;
54    private static CentralizedServiceWorker<IntWritable, Writable, Writable>
55      service;
56    private static ImmutableClassesGiraphConfiguration<IntWritable, Writable,
57        Writable> conf;
58  
59    @Before
60    public void prepare() {
61      service = Mockito.mock(CentralizedServiceWorker.class);
62      Mockito.when(
63          service.getPartitionId(Mockito.any(IntWritable.class))).thenAnswer(
64          new Answer<Integer>() {
65            @Override
66            public Integer answer(InvocationOnMock invocation) {
67              IntWritable vertexId = (IntWritable) invocation.getArguments()[0];
68              return vertexId.get() % NUM_PARTITIONS;
69            }
70          }
71      );
72      PartitionStore partitionStore = Mockito.mock(PartitionStore.class);
73      Mockito.when(service.getPartitionStore()).thenReturn(partitionStore);
74      Mockito.when(partitionStore.getPartitionIds()).thenReturn(
75          Lists.newArrayList(0, 1));
76      Partition partition = Mockito.mock(Partition.class);
77      Mockito.when(partition.getVertexCount()).thenReturn(Long.valueOf(1));
78      Mockito.when(partitionStore.getNextPartition()).thenReturn(partition);
79      Mockito.when(partitionStore.getNextPartition()).thenReturn(partition);
80  
81      GiraphConfiguration initConf = new GiraphConfiguration();
82      initConf.setComputationClass(IntFloatNoOpComputation.class);
83      conf = new ImmutableClassesGiraphConfiguration(initConf);
84    }
85  
86    private static class IntFloatNoOpComputation extends
87        BasicComputation<IntWritable, NullWritable, NullWritable,
88            FloatWritable> {
89      @Override
90      public void compute(Vertex<IntWritable, NullWritable, NullWritable> vertex,
91          Iterable<FloatWritable> messages) throws IOException {
92      }
93    }
94  
95    private static ByteArrayVertexIdMessages<IntWritable, FloatWritable>
96    createIntFloatMessages() {
97      ByteArrayVertexIdMessages<IntWritable, FloatWritable> messages =
98          new ByteArrayVertexIdMessages<IntWritable, FloatWritable>(
99              new TestMessageValueFactory<FloatWritable>(FloatWritable.class));
100     messages.setConf(conf);
101     messages.initialize();
102     return messages;
103   }
104 
105   private static void insertIntFloatMessages(
106       MessageStore<IntWritable, FloatWritable> messageStore) {
107     ByteArrayVertexIdMessages<IntWritable, FloatWritable> messages =
108         createIntFloatMessages();
109     messages.add(new IntWritable(0), new FloatWritable(1));
110     messages.add(new IntWritable(2), new FloatWritable(3));
111     messages.add(new IntWritable(0), new FloatWritable(4));
112     messageStore.addPartitionMessages(0, messages);
113     messages = createIntFloatMessages();
114     messages.add(new IntWritable(1), new FloatWritable(1));
115     messages.add(new IntWritable(1), new FloatWritable(3));
116     messages.add(new IntWritable(1), new FloatWritable(4));
117     messageStore.addPartitionMessages(1, messages);
118     messages = createIntFloatMessages();
119     messages.add(new IntWritable(0), new FloatWritable(5));
120     messageStore.addPartitionMessages(0, messages);
121   }
122 
123   @Test
124   public void testIntFloatMessageStore() {
125     IntFloatMessageStore messageStore =
126         new IntFloatMessageStore(service, new FloatSumMessageCombiner());
127     insertIntFloatMessages(messageStore);
128 
129     Iterable<FloatWritable> m0 =
130         messageStore.getVertexMessages(new IntWritable(0));
131     Assert.assertEquals(1, Iterables.size(m0));
132     Assert.assertEquals((float) 10.0, m0.iterator().next().get());
133     Iterable<FloatWritable> m1 =
134         messageStore.getVertexMessages(new IntWritable(1));
135     Assert.assertEquals(1, Iterables.size(m1));
136     Assert.assertEquals((float) 8.0, m1.iterator().next().get());
137     Iterable<FloatWritable> m2 =
138         messageStore.getVertexMessages(new IntWritable(2));
139     Assert.assertEquals(1, Iterables.size(m2));
140     Assert.assertEquals((float) 3.0, m2.iterator().next().get());
141     Assert.assertTrue(
142         Iterables.isEmpty(messageStore.getVertexMessages(new IntWritable(3))));
143   }
144 
145   @Test
146   public void testIntByteArrayMessageStore() {
147     IntByteArrayMessageStore<FloatWritable> messageStore =
148         new IntByteArrayMessageStore<FloatWritable>(new
149             TestMessageValueFactory<FloatWritable>(FloatWritable.class),
150             service, conf);
151     insertIntFloatMessages(messageStore);
152 
153     Iterable<FloatWritable> m0 =
154         messageStore.getVertexMessages(new IntWritable(0));
155     Assert.assertEquals(3, Iterables.size(m0));
156     Iterator<FloatWritable> i0 = m0.iterator();
157     Assert.assertEquals((float) 1.0, i0.next().get());
158     Assert.assertEquals((float) 4.0, i0.next().get());
159     Assert.assertEquals((float) 5.0, i0.next().get());
160     Iterable<FloatWritable> m1 =
161         messageStore.getVertexMessages(new IntWritable(1));
162     Assert.assertEquals(3, Iterables.size(m1));
163     Iterator<FloatWritable> i1 = m1.iterator();
164     Assert.assertEquals((float) 1.0, i1.next().get());
165     Assert.assertEquals((float) 3.0, i1.next().get());
166     Assert.assertEquals((float) 4.0, i1.next().get());
167     Iterable<FloatWritable> m2 =
168         messageStore.getVertexMessages(new IntWritable(2));
169     Assert.assertEquals(1, Iterables.size(m2));
170     Assert.assertEquals((float) 3.0, m2.iterator().next().get());
171     Assert.assertTrue(
172         Iterables.isEmpty(messageStore.getVertexMessages(new IntWritable(3))));
173   }
174 
175   @Test
176   public void testIntByteArrayMessageStoreWithMessageEncoding() {
177     GiraphConstants.USE_MESSAGE_SIZE_ENCODING.set(conf, true);
178     testIntByteArrayMessageStore();
179     GiraphConstants.USE_MESSAGE_SIZE_ENCODING.set(conf, false);
180   }
181 }