View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *     http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   */
18  
19  package org.apache.giraph.comm.messages.primitives;
20  
21  import it.unimi.dsi.fastutil.ints.Int2FloatMap;
22  import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap;
23  import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
24  import it.unimi.dsi.fastutil.ints.IntIterator;
25  import it.unimi.dsi.fastutil.objects.ObjectIterator;
26  
27  import java.io.DataInput;
28  import java.io.DataOutput;
29  import java.io.IOException;
30  import java.util.Collections;
31  import java.util.List;
32  
33  import org.apache.giraph.bsp.CentralizedServiceWorker;
34  import org.apache.giraph.combiner.MessageCombiner;
35  import org.apache.giraph.comm.messages.MessageStore;
36  import org.apache.giraph.utils.EmptyIterable;
37  import org.apache.giraph.utils.VertexIdMessageIterator;
38  import org.apache.giraph.utils.VertexIdMessages;
39  import org.apache.hadoop.io.FloatWritable;
40  import org.apache.hadoop.io.IntWritable;
41  import org.apache.hadoop.io.Writable;
42  
43  import com.google.common.collect.Lists;
44  
45  /**
46   * Special message store to be used when ids are IntWritable and messages
47   * are FloatWritable and messageCombiner is used.
48   * Uses fastutil primitive maps in order to decrease number of objects and
49   * get better performance.
50   */
51  public class IntFloatMessageStore
52      implements MessageStore<IntWritable, FloatWritable> {
53    /** Map from partition id to map from vertex id to message */
54    private final Int2ObjectOpenHashMap<Int2FloatOpenHashMap> map;
55    /** Message messageCombiner */
56    private final
57    MessageCombiner<? super IntWritable, FloatWritable> messageCombiner;
58    /** Service worker */
59    private final CentralizedServiceWorker<IntWritable, ?, ?> service;
60  
61    /**
62     * Constructor
63     *
64     * @param service Service worker
65     * @param messageCombiner Message messageCombiner
66     */
67    public IntFloatMessageStore(
68        CentralizedServiceWorker<IntWritable, Writable, Writable> service,
69        MessageCombiner<? super IntWritable, FloatWritable> messageCombiner) {
70      this.service = service;
71      this.messageCombiner = messageCombiner;
72  
73      map = new Int2ObjectOpenHashMap<Int2FloatOpenHashMap>();
74      for (int partitionId : service.getPartitionStore().getPartitionIds()) {
75        Int2FloatOpenHashMap partitionMap = new Int2FloatOpenHashMap(
76            (int) service.getPartitionStore()
77                .getPartitionVertexCount(partitionId));
78        map.put(partitionId, partitionMap);
79      }
80    }
81  
82    @Override
83    public boolean isPointerListEncoding() {
84      return false;
85    }
86  
87    /**
88     * Get map which holds messages for partition which vertex belongs to.
89     *
90     * @param vertexId Id of the vertex
91     * @return Map which holds messages for partition which vertex belongs to.
92     */
93    private Int2FloatOpenHashMap getPartitionMap(IntWritable vertexId) {
94      return map.get(service.getPartitionId(vertexId));
95    }
96  
97    @Override
98    public void addPartitionMessages(int partitionId,
99        VertexIdMessages<IntWritable, FloatWritable> messages) {
100     IntWritable reusableVertexId = new IntWritable();
101     FloatWritable reusableMessage = new FloatWritable();
102     FloatWritable reusableCurrentMessage = new FloatWritable();
103 
104     Int2FloatOpenHashMap partitionMap = map.get(partitionId);
105     synchronized (partitionMap) {
106       VertexIdMessageIterator<IntWritable, FloatWritable>
107           iterator = messages.getVertexIdMessageIterator();
108       while (iterator.hasNext()) {
109         iterator.next();
110         int vertexId = iterator.getCurrentVertexId().get();
111         float message = iterator.getCurrentMessage().get();
112         if (partitionMap.containsKey(vertexId)) {
113           reusableVertexId.set(vertexId);
114           reusableMessage.set(message);
115           reusableCurrentMessage.set(partitionMap.get(vertexId));
116           messageCombiner.combine(reusableVertexId, reusableCurrentMessage,
117               reusableMessage);
118           message = reusableCurrentMessage.get();
119         }
120         partitionMap.put(vertexId, message);
121       }
122     }
123   }
124 
125   @Override
126   public void finalizeStore() {
127   }
128 
129   @Override
130   public void clearPartition(int partitionId) {
131     map.get(partitionId).clear();
132   }
133 
134   @Override
135   public boolean hasMessagesForVertex(IntWritable vertexId) {
136     return getPartitionMap(vertexId).containsKey(vertexId.get());
137   }
138 
139   @Override
140   public boolean hasMessagesForPartition(int partitionId) {
141     Int2FloatOpenHashMap partitionMessages = map.get(partitionId);
142     return partitionMessages != null && !partitionMessages.isEmpty();
143   }
144 
145   @Override
146   public Iterable<FloatWritable> getVertexMessages(
147       IntWritable vertexId) {
148     Int2FloatOpenHashMap partitionMap = getPartitionMap(vertexId);
149     if (!partitionMap.containsKey(vertexId.get())) {
150       return EmptyIterable.get();
151     } else {
152       return Collections.singleton(
153           new FloatWritable(partitionMap.get(vertexId.get())));
154     }
155   }
156 
157   @Override
158   public void clearVertexMessages(IntWritable vertexId) {
159     getPartitionMap(vertexId).remove(vertexId.get());
160   }
161 
162   @Override
163   public void clearAll() {
164     map.clear();
165   }
166 
167   @Override
168   public Iterable<IntWritable> getPartitionDestinationVertices(
169       int partitionId) {
170     Int2FloatOpenHashMap partitionMap = map.get(partitionId);
171     List<IntWritable> vertices =
172         Lists.newArrayListWithCapacity(partitionMap.size());
173     IntIterator iterator = partitionMap.keySet().iterator();
174     while (iterator.hasNext()) {
175       vertices.add(new IntWritable(iterator.nextInt()));
176     }
177     return vertices;
178   }
179 
180   @Override
181   public void writePartition(DataOutput out,
182       int partitionId) throws IOException {
183     Int2FloatOpenHashMap partitionMap = map.get(partitionId);
184     out.writeInt(partitionMap.size());
185     ObjectIterator<Int2FloatMap.Entry> iterator =
186         partitionMap.int2FloatEntrySet().fastIterator();
187     while (iterator.hasNext()) {
188       Int2FloatMap.Entry entry = iterator.next();
189       out.writeInt(entry.getIntKey());
190       out.writeFloat(entry.getFloatValue());
191     }
192   }
193 
194   @Override
195   public void readFieldsForPartition(DataInput in,
196       int partitionId) throws IOException {
197     int size = in.readInt();
198     Int2FloatOpenHashMap partitionMap = new Int2FloatOpenHashMap(size);
199     while (size-- > 0) {
200       int vertexId = in.readInt();
201       float message = in.readFloat();
202       partitionMap.put(vertexId, message);
203     }
204     synchronized (map) {
205       map.put(partitionId, partitionMap);
206     }
207   }
208 }