This project has retired. For details please refer to its Attic page.
RequestTest xref
View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *     http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  
19  package org.apache.giraph.comm;
20  
21  import org.apache.giraph.comm.netty.NettyClient;
22  import org.apache.giraph.comm.netty.NettyServer;
23  import org.apache.giraph.comm.netty.handler.AckSignalFlag;
24  import org.apache.giraph.comm.netty.handler.WorkerRequestServerHandler;
25  import org.apache.giraph.comm.requests.SendPartitionMutationsRequest;
26  import org.apache.giraph.comm.requests.SendVertexRequest;
27  import org.apache.giraph.comm.requests.SendWorkerMessagesRequest;
28  import org.apache.giraph.comm.requests.SendWorkerOneMessageToManyRequest;
29  import org.apache.giraph.conf.GiraphConfiguration;
30  import org.apache.giraph.conf.GiraphConstants;
31  import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
32  import org.apache.giraph.edge.Edge;
33  import org.apache.giraph.edge.EdgeFactory;
34  import org.apache.giraph.factories.TestMessageValueFactory;
35  import org.apache.giraph.graph.Vertex;
36  import org.apache.giraph.graph.VertexMutations;
37  import org.apache.giraph.metrics.GiraphMetrics;
38  import org.apache.giraph.partition.Partition;
39  import org.apache.giraph.partition.PartitionStore;
40  import org.apache.giraph.utils.ByteArrayOneMessageToManyIds;
41  import org.apache.giraph.utils.VertexIdMessages;
42  import org.apache.giraph.utils.ByteArrayVertexIdMessages;
43  import org.apache.giraph.utils.ExtendedDataOutput;
44  import org.apache.giraph.utils.IntNoOpComputation;
45  import org.apache.giraph.utils.MockUtils;
46  import org.apache.giraph.utils.PairList;
47  import org.apache.giraph.worker.WorkerInfo;
48  import org.apache.hadoop.io.IntWritable;
49  import org.apache.hadoop.mapreduce.Mapper.Context;
50  import org.junit.Before;
51  import org.junit.Test;
52  
53  import com.google.common.collect.Lists;
54  import com.google.common.collect.Maps;
55  
56  import java.io.IOException;
57  import java.util.Map;
58  import java.util.Map.Entry;
59  import java.util.concurrent.ConcurrentMap;
60  
61  import static org.junit.Assert.assertEquals;
62  import static org.junit.Assert.assertTrue;
63  import static org.mockito.Mockito.mock;
64  import static org.mockito.Mockito.when;
65  
66  /**
67   * Test all the different netty requests.
68   */
69  @SuppressWarnings("unchecked")
70  public class RequestTest {
71    /** Configuration */
72    private ImmutableClassesGiraphConfiguration conf;
73    /** Server data */
74    private ServerData<IntWritable, IntWritable, IntWritable> serverData;
75    /** Server */
76    private NettyServer server;
77    /** Client */
78    private NettyClient client;
79    /** Worker info */
80    private WorkerInfo workerInfo;
81  
82    @Before
83    public void setUp() {
84      // Setup the conf
85      GiraphConfiguration tmpConf = new GiraphConfiguration();
86      GiraphConstants.COMPUTATION_CLASS.set(tmpConf, IntNoOpComputation.class);
87      conf = new ImmutableClassesGiraphConfiguration(tmpConf);
88  
89      @SuppressWarnings("rawtypes")
90      Context context = mock(Context.class);
91      when(context.getConfiguration()).thenReturn(conf);
92  
93      // Start the service
94      serverData = MockUtils.createNewServerData(conf, context);
95      serverData.prepareSuperstep();
96      workerInfo = new WorkerInfo();
97      server = new NettyServer(conf,
98          new WorkerRequestServerHandler.Factory(serverData), workerInfo,
99              context, new MockExceptionHandler());
100     server.start();
101 
102     workerInfo.setInetSocketAddress(server.getMyAddress(), server.getLocalHostOrIp());
103     client = new NettyClient(context, conf, new WorkerInfo(),
104         new MockExceptionHandler());
105     server.setFlowControl(client.getFlowControl());
106     client.connectAllAddresses(
107         Lists.<WorkerInfo>newArrayList(workerInfo));
108   }
109 
110   @Test
111   public void sendVertexPartition() {
112     // Data to send
113     int partitionId = 13;
114     Partition<IntWritable, IntWritable, IntWritable> partition =
115         conf.createPartition(partitionId, null);
116     for (int i = 0; i < 10; ++i) {
117       Vertex vertex = conf.createVertex();
118       vertex.initialize(new IntWritable(i), new IntWritable(i));
119       partition.putVertex(vertex);
120     }
121 
122     // Send the request
123     SendVertexRequest<IntWritable, IntWritable, IntWritable> request =
124       new SendVertexRequest<IntWritable, IntWritable, IntWritable>(partition);
125     client.sendWritableRequest(workerInfo.getTaskId(), request);
126     client.waitAllRequests();
127 
128     // Stop the service
129     client.stop();
130     server.stop();
131 
132     // Check the output
133     PartitionStore<IntWritable, IntWritable, IntWritable> partitionStore =
134         serverData.getPartitionStore();
135     assertTrue(partitionStore.hasPartition(partitionId));
136     int total = 0;
137     Partition<IntWritable, IntWritable, IntWritable> partition2 =
138         partitionStore.removePartition(partitionId);
139     for (Vertex<IntWritable, IntWritable, IntWritable> vertex : partition2) {
140       total += vertex.getId().get();
141     }
142     partitionStore.putPartition(partition2);
143     assertEquals(total, 45);
144     partitionStore.shutdown();
145   }
146 
147   @Test
148   public void sendWorkerMessagesRequest() {
149     // Data to send
150     PairList<Integer, VertexIdMessages<IntWritable,
151             IntWritable>>
152         dataToSend = new PairList<>();
153     dataToSend.initialize();
154     int partitionId = 0;
155     ByteArrayVertexIdMessages<IntWritable,
156             IntWritable> vertexIdMessages =
157         new ByteArrayVertexIdMessages<>(
158             new TestMessageValueFactory<>(IntWritable.class));
159     vertexIdMessages.setConf(conf);
160     vertexIdMessages.initialize();
161     dataToSend.add(partitionId, vertexIdMessages);
162     for (int i = 1; i < 7; ++i) {
163       IntWritable vertexId = new IntWritable(i);
164       for (int j = 0; j < i; ++j) {
165         vertexIdMessages.add(vertexId, new IntWritable(j));
166       }
167     }
168 
169     // Send the request
170     SendWorkerMessagesRequest<IntWritable, IntWritable> request =
171       new SendWorkerMessagesRequest<>(dataToSend);
172     request.setConf(conf);
173 
174     client.sendWritableRequest(workerInfo.getTaskId(), request);
175     client.waitAllRequests();
176 
177     // Stop the service
178     client.stop();
179     server.stop();
180 
181     // Check the output
182     Iterable<IntWritable> vertices =
183         serverData.getIncomingMessageStore().getPartitionDestinationVertices(0);
184     int keySum = 0;
185     int messageSum = 0;
186     for (IntWritable vertexId : vertices) {
187       keySum += vertexId.get();
188       Iterable<IntWritable> messages =
189           serverData.<IntWritable>getIncomingMessageStore().getVertexMessages(
190               vertexId);
191       synchronized (messages) {
192         for (IntWritable message : messages) {
193           messageSum += message.get();
194         }
195       }
196     }
197     assertEquals(21, keySum);
198     assertEquals(35, messageSum);
199   }
200 
201   @Test
202   public void sendWorkerIndividualMessagesRequest() throws IOException {
203     // Data to send
204     ByteArrayOneMessageToManyIds<IntWritable, IntWritable>
205         dataToSend = new ByteArrayOneMessageToManyIds<>(new
206         TestMessageValueFactory<>(IntWritable.class));
207     dataToSend.setConf(conf);
208     dataToSend.initialize();
209     ExtendedDataOutput output = conf.createExtendedDataOutput();
210     for (int i = 1; i <= 7; ++i) {
211       IntWritable vertexId = new IntWritable(i);
212       vertexId.write(output);
213     }
214     dataToSend.add(output.getByteArray(), output.getPos(), 7, new IntWritable(1));
215 
216     // Send the request
217     SendWorkerOneMessageToManyRequest<IntWritable, IntWritable> request =
218       new SendWorkerOneMessageToManyRequest<>(dataToSend, conf);
219     client.sendWritableRequest(workerInfo.getTaskId(), request);
220     client.waitAllRequests();
221 
222     // Stop the service
223     client.stop();
224     server.stop();
225 
226     // Check the output
227     Iterable<IntWritable> vertices =
228         serverData.getIncomingMessageStore().getPartitionDestinationVertices(0);
229     int keySum = 0;
230     int messageSum = 0;
231     for (IntWritable vertexId : vertices) {
232       keySum += vertexId.get();
233       Iterable<IntWritable> messages =
234           serverData.<IntWritable>getIncomingMessageStore().getVertexMessages(
235               vertexId);
236       synchronized (messages) {
237         for (IntWritable message : messages) {
238           messageSum += message.get();
239         }
240       }
241     }
242     assertEquals(28, keySum);
243     assertEquals(7, messageSum);
244   }
245 
246   @Test
247   public void sendPartitionMutationsRequest() {
248     // Data to send
249     int partitionId = 19;
250     Map<IntWritable, VertexMutations<IntWritable, IntWritable,
251     IntWritable>> vertexIdMutations =
252         Maps.newHashMap();
253     for (int i = 0; i < 11; ++i) {
254       VertexMutations<IntWritable, IntWritable, IntWritable> mutations =
255           new VertexMutations<IntWritable, IntWritable, IntWritable>();
256       for (int j = 0; j < 3; ++j) {
257         Vertex vertex = conf.createVertex();
258         vertex.initialize(new IntWritable(i), new IntWritable(j));
259         mutations.addVertex(vertex);
260       }
261       for (int j = 0; j < 2; ++j) {
262         mutations.removeVertex();
263       }
264       for (int j = 0; j < 5; ++j) {
265         Edge<IntWritable, IntWritable> edge =
266             EdgeFactory.create(new IntWritable(i), new IntWritable(2 * j));
267         mutations.addEdge(edge);
268       }
269       for (int j = 0; j < 7; ++j) {
270         mutations.removeEdge(new IntWritable(j));
271       }
272       vertexIdMutations.put(new IntWritable(i), mutations);
273     }
274 
275     // Send the request
276     SendPartitionMutationsRequest<IntWritable, IntWritable, IntWritable>
277         request = new SendPartitionMutationsRequest<IntWritable, IntWritable,
278         IntWritable>(partitionId,
279         vertexIdMutations);
280     GiraphMetrics.init(conf);
281     client.sendWritableRequest(workerInfo.getTaskId(), request);
282     client.waitAllRequests();
283 
284     // Stop the service
285     client.stop();
286     server.stop();
287 
288     // Check the output
289     ConcurrentMap<IntWritable,
290         VertexMutations<IntWritable, IntWritable, IntWritable>>
291         inVertexIdMutations =
292         serverData.getPartitionMutations().get(partitionId);
293     int keySum = 0;
294     for (Entry<IntWritable,
295         VertexMutations<IntWritable, IntWritable, IntWritable>> entry :
296         inVertexIdMutations
297         .entrySet()) {
298       synchronized (entry.getValue()) {
299         keySum += entry.getKey().get();
300         int vertexValueSum = 0;
301         for (Vertex<IntWritable, IntWritable, IntWritable> vertex : entry
302             .getValue().getAddedVertexList()) {
303           vertexValueSum += vertex.getValue().get();
304         }
305         assertEquals(3, vertexValueSum);
306         assertEquals(2, entry.getValue().getRemovedVertexCount());
307         int removeEdgeValueSum = 0;
308         for (Edge<IntWritable, IntWritable> edge : entry.getValue()
309             .getAddedEdgeList()) {
310           removeEdgeValueSum += edge.getValue().get();
311         }
312         assertEquals(20, removeEdgeValueSum);
313       }
314     }
315     assertEquals(55, keySum);
316   }
317 }