This project has retired. For details, please refer to its Attic page.
TestIntFloatPrimitiveMessageStores xref
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.giraph.comm.messages;
20
21 import java.io.IOException;
22 import java.util.Iterator;
23
24 import junit.framework.Assert;
25
26 import org.apache.giraph.bsp.CentralizedServiceWorker;
27 import org.apache.giraph.combiner.FloatSumMessageCombiner;
28 import org.apache.giraph.comm.messages.primitives.IdByteArrayMessageStore;
29 import org.apache.giraph.comm.messages.primitives.IntFloatMessageStore;
30 import org.apache.giraph.conf.GiraphConfiguration;
31 import org.apache.giraph.conf.GiraphConstants;
32 import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
33 import org.apache.giraph.factories.TestMessageValueFactory;
34 import org.apache.giraph.graph.BasicComputation;
35 import org.apache.giraph.graph.Vertex;
36 import org.apache.giraph.partition.Partition;
37 import org.apache.giraph.partition.PartitionStore;
38 import org.apache.giraph.utils.ByteArrayVertexIdMessages;
39 import org.apache.hadoop.io.FloatWritable;
40 import org.apache.hadoop.io.IntWritable;
41 import org.apache.hadoop.io.NullWritable;
42 import org.apache.hadoop.io.Writable;
43 import org.junit.Before;
44 import org.junit.Test;
45 import org.mockito.Mockito;
46 import org.mockito.invocation.InvocationOnMock;
47 import org.mockito.stubbing.Answer;
48
49 import com.google.common.collect.Iterables;
50 import com.google.common.collect.Lists;
51
52 public class TestIntFloatPrimitiveMessageStores {
53 private static final int NUM_PARTITIONS = 2;
54 private static CentralizedServiceWorker<IntWritable, Writable, Writable>
55 service;
56 private static ImmutableClassesGiraphConfiguration<IntWritable, Writable,
57 Writable> conf;
58
59 @Before
60 public void prepare() {
61 service = Mockito.mock(CentralizedServiceWorker.class);
62 Mockito.when(
63 service.getPartitionId(Mockito.any(IntWritable.class))).thenAnswer(
64 new Answer<Integer>() {
65 @Override
66 public Integer answer(InvocationOnMock invocation) {
67 IntWritable vertexId = (IntWritable) invocation.getArguments()[0];
68 return vertexId.get() % NUM_PARTITIONS;
69 }
70 }
71 );
72 PartitionStore partitionStore = Mockito.mock(PartitionStore.class);
73 Mockito.when(service.getPartitionStore()).thenReturn(partitionStore);
74 Mockito.when(service.getPartitionIds()).thenReturn(
75 Lists.newArrayList(0, 1));
76 Mockito.when(partitionStore.getPartitionIds()).thenReturn(
77 Lists.newArrayList(0, 1));
78 Partition partition = Mockito.mock(Partition.class);
79 Mockito.when(partition.getVertexCount()).thenReturn(Long.valueOf(1));
80 Mockito.when(partitionStore.getNextPartition()).thenReturn(partition);
81 Mockito.when(partitionStore.getNextPartition()).thenReturn(partition);
82
83 GiraphConfiguration initConf = new GiraphConfiguration();
84 initConf.setComputationClass(IntFloatNoOpComputation.class);
85 conf = new ImmutableClassesGiraphConfiguration(initConf);
86 }
87
88 private static class IntFloatNoOpComputation extends
89 BasicComputation<IntWritable, NullWritable, NullWritable,
90 FloatWritable> {
91 @Override
92 public void compute(Vertex<IntWritable, NullWritable, NullWritable> vertex,
93 Iterable<FloatWritable> messages) throws IOException {
94 }
95 }
96
97 private static ByteArrayVertexIdMessages<IntWritable, FloatWritable>
98 createIntFloatMessages() {
99 ByteArrayVertexIdMessages<IntWritable, FloatWritable> messages =
100 new ByteArrayVertexIdMessages<IntWritable, FloatWritable>(
101 new TestMessageValueFactory<FloatWritable>(FloatWritable.class));
102 messages.setConf(conf);
103 messages.initialize();
104 return messages;
105 }
106
107 private static void insertIntFloatMessages(
108 MessageStore<IntWritable, FloatWritable> messageStore) {
109 ByteArrayVertexIdMessages<IntWritable, FloatWritable> messages =
110 createIntFloatMessages();
111 messages.add(new IntWritable(0), new FloatWritable(1));
112 messages.add(new IntWritable(2), new FloatWritable(3));
113 messages.add(new IntWritable(0), new FloatWritable(4));
114 messageStore.addPartitionMessages(0, messages);
115 messages = createIntFloatMessages();
116 messages.add(new IntWritable(1), new FloatWritable(1));
117 messages.add(new IntWritable(1), new FloatWritable(3));
118 messages.add(new IntWritable(1), new FloatWritable(4));
119 messageStore.addPartitionMessages(1, messages);
120 messages = createIntFloatMessages();
121 messages.add(new IntWritable(0), new FloatWritable(5));
122 messageStore.addPartitionMessages(0, messages);
123 }
124
125 @Test
126 public void testIntFloatMessageStore() {
127 IntFloatMessageStore messageStore =
128 new IntFloatMessageStore(service, new FloatSumMessageCombiner());
129 insertIntFloatMessages(messageStore);
130
131 Iterable<FloatWritable> m0 =
132 messageStore.getVertexMessages(new IntWritable(0));
133 Assert.assertEquals(1, Iterables.size(m0));
134 Assert.assertEquals((float) 10.0, m0.iterator().next().get());
135 Iterable<FloatWritable> m1 =
136 messageStore.getVertexMessages(new IntWritable(1));
137 Assert.assertEquals(1, Iterables.size(m1));
138 Assert.assertEquals((float) 8.0, m1.iterator().next().get());
139 Iterable<FloatWritable> m2 =
140 messageStore.getVertexMessages(new IntWritable(2));
141 Assert.assertEquals(1, Iterables.size(m2));
142 Assert.assertEquals((float) 3.0, m2.iterator().next().get());
143 Assert.assertTrue(
144 Iterables.isEmpty(messageStore.getVertexMessages(new IntWritable(3))));
145 }
146
147 @Test
148 public void testIntByteArrayMessageStore() {
149 IdByteArrayMessageStore<IntWritable, FloatWritable> messageStore =
150 new IdByteArrayMessageStore<>(new
151 TestMessageValueFactory<FloatWritable>(FloatWritable.class),
152 service, conf);
153 insertIntFloatMessages(messageStore);
154
155 Iterable<FloatWritable> m0 =
156 messageStore.getVertexMessages(new IntWritable(0));
157 Assert.assertEquals(3, Iterables.size(m0));
158 Iterator<FloatWritable> i0 = m0.iterator();
159 Assert.assertEquals((float) 1.0, i0.next().get());
160 Assert.assertEquals((float) 4.0, i0.next().get());
161 Assert.assertEquals((float) 5.0, i0.next().get());
162 Iterable<FloatWritable> m1 =
163 messageStore.getVertexMessages(new IntWritable(1));
164 Assert.assertEquals(3, Iterables.size(m1));
165 Iterator<FloatWritable> i1 = m1.iterator();
166 Assert.assertEquals((float) 1.0, i1.next().get());
167 Assert.assertEquals((float) 3.0, i1.next().get());
168 Assert.assertEquals((float) 4.0, i1.next().get());
169 Iterable<FloatWritable> m2 =
170 messageStore.getVertexMessages(new IntWritable(2));
171 Assert.assertEquals(1, Iterables.size(m2));
172 Assert.assertEquals((float) 3.0, m2.iterator().next().get());
173 Assert.assertTrue(
174 Iterables.isEmpty(messageStore.getVertexMessages(new IntWritable(3))));
175 }
176
177 @Test
178 public void testIntByteArrayMessageStoreWithMessageEncoding() {
179 GiraphConstants.USE_MESSAGE_SIZE_ENCODING.set(conf, true);
180 testIntByteArrayMessageStore();
181 GiraphConstants.USE_MESSAGE_SIZE_ENCODING.set(conf, false);
182 }
183 }