This project has been retired. For details, please refer to its Attic page.
TestIntFloatPrimitiveMessageStores xref
View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *     http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  
19  package org.apache.giraph.comm.messages;
20  
21  import java.io.IOException;
22  import java.util.Iterator;
23  
24  import junit.framework.Assert;
25  
26  import org.apache.giraph.bsp.CentralizedServiceWorker;
27  import org.apache.giraph.combiner.FloatSumMessageCombiner;
28  import org.apache.giraph.comm.messages.primitives.IdByteArrayMessageStore;
29  import org.apache.giraph.comm.messages.primitives.IntFloatMessageStore;
30  import org.apache.giraph.conf.GiraphConfiguration;
31  import org.apache.giraph.conf.GiraphConstants;
32  import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
33  import org.apache.giraph.factories.TestMessageValueFactory;
34  import org.apache.giraph.graph.BasicComputation;
35  import org.apache.giraph.graph.Vertex;
36  import org.apache.giraph.partition.Partition;
37  import org.apache.giraph.partition.PartitionStore;
38  import org.apache.giraph.utils.ByteArrayVertexIdMessages;
39  import org.apache.hadoop.io.FloatWritable;
40  import org.apache.hadoop.io.IntWritable;
41  import org.apache.hadoop.io.NullWritable;
42  import org.apache.hadoop.io.Writable;
43  import org.junit.Before;
44  import org.junit.Test;
45  import org.mockito.Mockito;
46  import org.mockito.invocation.InvocationOnMock;
47  import org.mockito.stubbing.Answer;
48  
49  import com.google.common.collect.Iterables;
50  import com.google.common.collect.Lists;
51  
52  public class TestIntFloatPrimitiveMessageStores {
53    private static final int NUM_PARTITIONS = 2;
54    private static CentralizedServiceWorker<IntWritable, Writable, Writable>
55      service;
56    private static ImmutableClassesGiraphConfiguration<IntWritable, Writable,
57        Writable> conf;
58  
59    @Before
60    public void prepare() {
61      service = Mockito.mock(CentralizedServiceWorker.class);
62      Mockito.when(
63          service.getPartitionId(Mockito.any(IntWritable.class))).thenAnswer(
64          new Answer<Integer>() {
65            @Override
66            public Integer answer(InvocationOnMock invocation) {
67              IntWritable vertexId = (IntWritable) invocation.getArguments()[0];
68              return vertexId.get() % NUM_PARTITIONS;
69            }
70          }
71      );
72      PartitionStore partitionStore = Mockito.mock(PartitionStore.class);
73      Mockito.when(service.getPartitionStore()).thenReturn(partitionStore);
74      Mockito.when(service.getPartitionIds()).thenReturn(
75        Lists.newArrayList(0, 1));
76      Mockito.when(partitionStore.getPartitionIds()).thenReturn(
77        Lists.newArrayList(0, 1));
78      Partition partition = Mockito.mock(Partition.class);
79      Mockito.when(partition.getVertexCount()).thenReturn(Long.valueOf(1));
80      Mockito.when(partitionStore.getNextPartition()).thenReturn(partition);
81      Mockito.when(partitionStore.getNextPartition()).thenReturn(partition);
82  
83      GiraphConfiguration initConf = new GiraphConfiguration();
84      initConf.setComputationClass(IntFloatNoOpComputation.class);
85      conf = new ImmutableClassesGiraphConfiguration(initConf);
86    }
87  
88    private static class IntFloatNoOpComputation extends
89        BasicComputation<IntWritable, NullWritable, NullWritable,
90            FloatWritable> {
91      @Override
92      public void compute(Vertex<IntWritable, NullWritable, NullWritable> vertex,
93          Iterable<FloatWritable> messages) throws IOException {
94      }
95    }
96  
97    private static ByteArrayVertexIdMessages<IntWritable, FloatWritable>
98    createIntFloatMessages() {
99      ByteArrayVertexIdMessages<IntWritable, FloatWritable> messages =
100         new ByteArrayVertexIdMessages<IntWritable, FloatWritable>(
101             new TestMessageValueFactory<FloatWritable>(FloatWritable.class));
102     messages.setConf(conf);
103     messages.initialize();
104     return messages;
105   }
106 
107   private static void insertIntFloatMessages(
108       MessageStore<IntWritable, FloatWritable> messageStore) {
109     ByteArrayVertexIdMessages<IntWritable, FloatWritable> messages =
110         createIntFloatMessages();
111     messages.add(new IntWritable(0), new FloatWritable(1));
112     messages.add(new IntWritable(2), new FloatWritable(3));
113     messages.add(new IntWritable(0), new FloatWritable(4));
114     messageStore.addPartitionMessages(0, messages);
115     messages = createIntFloatMessages();
116     messages.add(new IntWritable(1), new FloatWritable(1));
117     messages.add(new IntWritable(1), new FloatWritable(3));
118     messages.add(new IntWritable(1), new FloatWritable(4));
119     messageStore.addPartitionMessages(1, messages);
120     messages = createIntFloatMessages();
121     messages.add(new IntWritable(0), new FloatWritable(5));
122     messageStore.addPartitionMessages(0, messages);
123   }
124 
125   @Test
126   public void testIntFloatMessageStore() {
127     IntFloatMessageStore messageStore =
128         new IntFloatMessageStore(service, new FloatSumMessageCombiner());
129     insertIntFloatMessages(messageStore);
130 
131     Iterable<FloatWritable> m0 =
132         messageStore.getVertexMessages(new IntWritable(0));
133     Assert.assertEquals(1, Iterables.size(m0));
134     Assert.assertEquals((float) 10.0, m0.iterator().next().get());
135     Iterable<FloatWritable> m1 =
136         messageStore.getVertexMessages(new IntWritable(1));
137     Assert.assertEquals(1, Iterables.size(m1));
138     Assert.assertEquals((float) 8.0, m1.iterator().next().get());
139     Iterable<FloatWritable> m2 =
140         messageStore.getVertexMessages(new IntWritable(2));
141     Assert.assertEquals(1, Iterables.size(m2));
142     Assert.assertEquals((float) 3.0, m2.iterator().next().get());
143     Assert.assertTrue(
144         Iterables.isEmpty(messageStore.getVertexMessages(new IntWritable(3))));
145   }
146 
147   @Test
148   public void testIntByteArrayMessageStore() {
149     IdByteArrayMessageStore<IntWritable, FloatWritable> messageStore =
150         new IdByteArrayMessageStore<>(new
151             TestMessageValueFactory<FloatWritable>(FloatWritable.class),
152             service, conf);
153     insertIntFloatMessages(messageStore);
154 
155     Iterable<FloatWritable> m0 =
156         messageStore.getVertexMessages(new IntWritable(0));
157     Assert.assertEquals(3, Iterables.size(m0));
158     Iterator<FloatWritable> i0 = m0.iterator();
159     Assert.assertEquals((float) 1.0, i0.next().get());
160     Assert.assertEquals((float) 4.0, i0.next().get());
161     Assert.assertEquals((float) 5.0, i0.next().get());
162     Iterable<FloatWritable> m1 =
163         messageStore.getVertexMessages(new IntWritable(1));
164     Assert.assertEquals(3, Iterables.size(m1));
165     Iterator<FloatWritable> i1 = m1.iterator();
166     Assert.assertEquals((float) 1.0, i1.next().get());
167     Assert.assertEquals((float) 3.0, i1.next().get());
168     Assert.assertEquals((float) 4.0, i1.next().get());
169     Iterable<FloatWritable> m2 =
170         messageStore.getVertexMessages(new IntWritable(2));
171     Assert.assertEquals(1, Iterables.size(m2));
172     Assert.assertEquals((float) 3.0, m2.iterator().next().get());
173     Assert.assertTrue(
174         Iterables.isEmpty(messageStore.getVertexMessages(new IntWritable(3))));
175   }
176 
177   @Test
178   public void testIntByteArrayMessageStoreWithMessageEncoding() {
179     GiraphConstants.USE_MESSAGE_SIZE_ENCODING.set(conf, true);
180     testIntByteArrayMessageStore();
181     GiraphConstants.USE_MESSAGE_SIZE_ENCODING.set(conf, false);
182   }
183 }