This project has been retired. For details, please refer to its
Apache Attic page.
TestLongDoublePrimitiveMessageStores xref
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.giraph.comm.messages;
20
21 import java.io.IOException;
22 import java.util.Iterator;
23
24 import junit.framework.Assert;
25
26 import org.apache.giraph.bsp.CentralizedServiceWorker;
27 import org.apache.giraph.combiner.DoubleSumMessageCombiner;
28 import org.apache.giraph.comm.messages.primitives.IdByteArrayMessageStore;
29 import org.apache.giraph.comm.messages.primitives.LongDoubleMessageStore;
30 import org.apache.giraph.conf.GiraphConfiguration;
31 import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
32 import org.apache.giraph.factories.TestMessageValueFactory;
33 import org.apache.giraph.graph.BasicComputation;
34 import org.apache.giraph.graph.Vertex;
35 import org.apache.giraph.partition.Partition;
36 import org.apache.giraph.partition.PartitionStore;
37 import org.apache.giraph.utils.ByteArrayVertexIdMessages;
38 import org.apache.hadoop.io.DoubleWritable;
39 import org.apache.hadoop.io.LongWritable;
40 import org.apache.hadoop.io.NullWritable;
41 import org.apache.hadoop.io.Writable;
42 import org.junit.Before;
43 import org.junit.Test;
44 import org.mockito.Mockito;
45 import org.mockito.invocation.InvocationOnMock;
46 import org.mockito.stubbing.Answer;
47
48 import com.google.common.collect.Iterables;
49 import com.google.common.collect.Lists;
50
51 public class TestLongDoublePrimitiveMessageStores {
52 private static final int NUM_PARTITIONS = 2;
53 private static CentralizedServiceWorker<LongWritable, Writable, Writable>
54 service;
55
56 @Before
57 public void prepare() {
58 service = Mockito.mock(CentralizedServiceWorker.class);
59 Mockito.when(
60 service.getPartitionId(Mockito.any(LongWritable.class))).thenAnswer(
61 new Answer<Integer>() {
62 @Override
63 public Integer answer(InvocationOnMock invocation) {
64 LongWritable vertexId = (LongWritable) invocation.getArguments()[0];
65 return (int) (vertexId.get() % NUM_PARTITIONS);
66 }
67 }
68 );
69 PartitionStore partitionStore = Mockito.mock(PartitionStore.class);
70 Mockito.when(service.getPartitionStore()).thenReturn(partitionStore);
71 Mockito.when(service.getPartitionIds()).thenReturn(
72 Lists.newArrayList(0, 1));
73 Mockito.when(partitionStore.getPartitionIds()).thenReturn(
74 Lists.newArrayList(0, 1));
75 Partition partition = Mockito.mock(Partition.class);
76 Mockito.when(partition.getVertexCount()).thenReturn(Long.valueOf(1));
77 Mockito.when(partitionStore.getNextPartition()).thenReturn(partition);
78 Mockito.when(partitionStore.getNextPartition()).thenReturn(partition);
79 }
80
81 private static class LongDoubleNoOpComputation extends
82 BasicComputation<LongWritable, NullWritable, NullWritable,
83 DoubleWritable> {
84 @Override
85 public void compute(Vertex<LongWritable, NullWritable, NullWritable> vertex,
86 Iterable<DoubleWritable> messages) throws IOException {
87 }
88 }
89
90 private static ImmutableClassesGiraphConfiguration<LongWritable, Writable,
91 Writable> createLongDoubleConf() {
92
93 GiraphConfiguration initConf = new GiraphConfiguration();
94 initConf.setComputationClass(LongDoubleNoOpComputation.class);
95 return new ImmutableClassesGiraphConfiguration(initConf);
96 }
97
98 private static ByteArrayVertexIdMessages<LongWritable, DoubleWritable>
99 createLongDoubleMessages() {
100 ByteArrayVertexIdMessages<LongWritable, DoubleWritable> messages =
101 new ByteArrayVertexIdMessages<LongWritable, DoubleWritable>(
102 new TestMessageValueFactory<DoubleWritable>(DoubleWritable.class));
103 messages.setConf(createLongDoubleConf());
104 messages.initialize();
105 return messages;
106 }
107
108 private static void insertLongDoubleMessages(
109 MessageStore<LongWritable, DoubleWritable> messageStore) {
110 ByteArrayVertexIdMessages<LongWritable, DoubleWritable> messages =
111 createLongDoubleMessages();
112 messages.add(new LongWritable(0), new DoubleWritable(1));
113 messages.add(new LongWritable(2), new DoubleWritable(3));
114 messages.add(new LongWritable(0), new DoubleWritable(4));
115 messageStore.addPartitionMessages(0, messages);
116 messages = createLongDoubleMessages();
117 messages.add(new LongWritable(1), new DoubleWritable(1));
118 messages.add(new LongWritable(1), new DoubleWritable(3));
119 messages.add(new LongWritable(1), new DoubleWritable(4));
120 messageStore.addPartitionMessages(1, messages);
121 messages = createLongDoubleMessages();
122 messages.add(new LongWritable(0), new DoubleWritable(5));
123 messageStore.addPartitionMessages(0, messages);
124 }
125
126 @Test
127 public void testLongDoubleMessageStore() {
128 LongDoubleMessageStore messageStore =
129 new LongDoubleMessageStore(service, new DoubleSumMessageCombiner());
130 insertLongDoubleMessages(messageStore);
131
132 Iterable<DoubleWritable> m0 =
133 messageStore.getVertexMessages(new LongWritable(0));
134 Assert.assertEquals(1, Iterables.size(m0));
135 Assert.assertEquals(10.0, m0.iterator().next().get());
136 Iterable<DoubleWritable> m1 =
137 messageStore.getVertexMessages(new LongWritable(1));
138 Assert.assertEquals(1, Iterables.size(m1));
139 Assert.assertEquals(8.0, m1.iterator().next().get());
140 Iterable<DoubleWritable> m2 =
141 messageStore.getVertexMessages(new LongWritable(2));
142 Assert.assertEquals(1, Iterables.size(m2));
143 Assert.assertEquals(3.0, m2.iterator().next().get());
144 Assert.assertTrue(
145 Iterables.isEmpty(messageStore.getVertexMessages(new LongWritable(3))));
146 }
147
148 @Test
149 public void testLongByteArrayMessageStore() {
150 IdByteArrayMessageStore<LongWritable, DoubleWritable> messageStore =
151 new IdByteArrayMessageStore<>(
152 new TestMessageValueFactory<DoubleWritable>(DoubleWritable.class),
153 service, createLongDoubleConf());
154 insertLongDoubleMessages(messageStore);
155
156 Iterable<DoubleWritable> m0 =
157 messageStore.getVertexMessages(new LongWritable(0));
158 Assert.assertEquals(3, Iterables.size(m0));
159 Iterator<DoubleWritable> i0 = m0.iterator();
160 Assert.assertEquals(1.0, i0.next().get());
161 Assert.assertEquals(4.0, i0.next().get());
162 Assert.assertEquals(5.0, i0.next().get());
163 Iterable<DoubleWritable> m1 =
164 messageStore.getVertexMessages(new LongWritable(1));
165 Assert.assertEquals(3, Iterables.size(m1));
166 Iterator<DoubleWritable> i1 = m1.iterator();
167 Assert.assertEquals(1.0, i1.next().get());
168 Assert.assertEquals(3.0, i1.next().get());
169 Assert.assertEquals(4.0, i1.next().get());
170 Iterable<DoubleWritable> m2 =
171 messageStore.getVertexMessages(new LongWritable(2));
172 Assert.assertEquals(1, Iterables.size(m2));
173 Assert.assertEquals(3.0, m2.iterator().next().get());
174 Assert.assertTrue(
175 Iterables.isEmpty(messageStore.getVertexMessages(new LongWritable(3))));
176 }
177 }