GraphLab: Distributed Graph-Parallel API  2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Groups Pages
sample_sort.hpp
1 #ifndef GRAPHLAB_RPC_SAMPLE_SORT_HPP
2 #define GRAPHLAB_RPC_SAMPLE_SORT_HPP
3 
4 #include <vector>
5 #include <algorithm>
6 #include <utility>
7 #include <graphlab/rpc/dc_dist_object.hpp>
8 #include <graphlab/rpc/buffered_exchange.hpp>
9 #include <graphlab/logger/assertions.hpp>
10 namespace graphlab {
11 
namespace sample_sort_impl {
  /**
   * Comparator that orders (key, value) pairs by their key alone, using
   * Key's operator<. Values are ignored, so two pairs with equal keys
   * compare as equivalent.
   *
   * operator() is const so the comparator can also be invoked through a
   * const object or const reference (e.g. when stored as a const member or
   * passed by const& into generic code).
   */
  template <typename Key, typename Value>
  struct pair_key_comparator {
    bool operator()(const std::pair<Key, Value>& k1,
                    const std::pair<Key, Value>& k2) const {
      return k1.first < k2.first;
    }
  };
} // namespace sample_sort_impl
21 
22 template <typename Key, typename Value>
23 class sample_sort {
24  private:
25  dc_dist_object<sample_sort<Key, Value> > rmi;
26 
27  typedef buffered_exchange<std::pair<Key, Value> > key_exchange_type;
28 
29  key_exchange_type key_exchange;
30  std::vector<std::pair<Key, Value> > key_values;
31  public:
32  sample_sort(distributed_control& dc): rmi(dc, this), key_exchange(dc) { }
33 
34  template <typename KeyIterator, typename ValueIterator>
35  void sort(KeyIterator kstart, KeyIterator kend,
36  ValueIterator vstart, ValueIterator vend) {
37  rmi.barrier();
38 
39  size_t num_entries = std::distance(kstart, kend);
40  ASSERT_EQ(num_entries, std::distance(vstart, vend));
41 
42  // we will sample k * p entries
43  std::vector<std::vector<Key> > sampled_keys(rmi.numprocs());
44  for (size_t i = 0;i < 100 * rmi.numprocs(); ++i) {
45  size_t idx = (rand() % num_entries);
46  sampled_keys[rmi.procid()].push_back(*(kstart + idx));
47  }
48 
49  rmi.all_gather(sampled_keys);
50  // collapse into a single array and sort
51  std::vector<Key> all_sampled_keys;
52  for (size_t i = 0;i < sampled_keys.size(); ++i) {
53  std::copy(sampled_keys[i].begin(), sampled_keys[i].end(),
54  std::inserter(all_sampled_keys, all_sampled_keys.end()));
55  }
56  // sort the sampled keys and extract the ranges
57  std::sort(all_sampled_keys.begin(), all_sampled_keys.end());
58  std::vector<Key> ranges(rmi.numprocs());
59  ranges[0] = Key();
60  for(size_t i = 1; i < rmi.numprocs(); ++i) {
61  ranges[i] = all_sampled_keys[sampled_keys[0].size() * i];
62  }
63 
64  // begin shuffle
65  KeyIterator kiter = kstart;
66  ValueIterator viter = vstart;
67  if (rmi.numprocs() < 8) {
68  while(kiter != kend) {
69  procid_t target_machine = 0;
70  while (target_machine < rmi.numprocs() - 1 &&
71  ranges[target_machine + 1] < *kiter) ++target_machine;
72  key_exchange.send(target_machine, std::make_pair(*kiter, *viter));
73  ++kiter; ++viter;
74  }
75  }
76  else {
77  while(kiter != kend) {
78  procid_t target_machine =
79  std::upper_bound(ranges.begin(), ranges.end(), *kiter)
80  - ranges.begin() - 1;
81  key_exchange.send(target_machine, std::make_pair(*kiter, *viter));
82  ++kiter; ++viter;
83  }
84  }
85  key_exchange.flush();
86 
87  // read from key exchange
88  procid_t recvid;
89  typename key_exchange_type::buffer_type buffer;
90  while(key_exchange.recv(recvid, buffer)) {
91  std::copy(buffer.begin(), buffer.end(),
92  std::inserter(key_values, key_values.end()));
93  }
94  std::sort(key_values.begin(), key_values.end(),
95  sample_sort_impl::pair_key_comparator<Key,Value>());
96 
97  rmi.barrier();
98  }
99 
100  std::vector<std::pair<Key, Value> >& result() {
101  return key_values;
102  }
103 };
104 
105 
106 } // namespace graphlab
107 
108 #endif