GraphLab: Distributed Graph-Parallel API  2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Groups Pages
union_find.hpp
1 /*
2  * Copyright (c) 2009 Carnegie Mellon University.
3  * All rights reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing,
12  * software distributed under the License is distributed on an "AS
13  * IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
14  * express or implied. See the License for the specific language
15  * governing permissions and limitations under the License.
16  *
17  * For more about this software visit:
18  *
19  * http://www.graphlab.ml.cmu.edu
20  *
21  */
22 
23 
24 #ifndef GRAPHLAB_UTIL_UNION_FIND_HPP
25 #define GRAPHLAB_UTIL_UNION_FIND_HPP
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>
#include <graphlab/parallel/atomic.hpp>
29 
30 namespace graphlab {
// IDType must be an integer type and its maximum
// value must be larger than the length of the sequence.
// RankType must be an arithmetic type with a std::numeric_limits
// specialization; ranks saturate at its maximum value instead of overflowing.

/**
 * Sequential disjoint-set (union-find) structure over the ids [0, s)
 * created by init().
 *
 * merge() uses union by rank and find() applies full path compression.
 * No synchronization is performed.
 */
template <typename IDType, typename RankType>
class union_find {
 private:
  // setid[i].first  : parent of element i (i itself when i is a root).
  // setid[i].second : rank of element i (only consulted for roots).
  std::vector<std::pair<IDType, RankType> > setid;

  bool is_root(IDType i) {
    return setid[i].first == i;
  }

 public:
  union_find() { }

  /// Resets the structure to s singleton sets {0}, {1}, ..., {s-1}.
  void init(IDType s) {
    setid.resize((size_t)s);
    for (size_t i = 0; i < setid.size(); ++i) {
      setid[i].first = (IDType)(i);
      setid[i].second = 0;
    }
  }

  /**
   * Unions the sets containing i and j; a no-op if they are already
   * in the same set.
   *
   * The lower-ranked root is linked under the higher-ranked one. On a
   * tie, j's root goes under i's root and i's root gains one rank.
   */
  void merge(IDType i, IDType j) {
    IDType iroot = find(i);
    IDType jroot = find(j);
    if (iroot == jroot) return;
    else if (setid[iroot].second < setid[jroot].second) {
      setid[iroot].first = jroot;
    }
    else if (setid[iroot].second > setid[jroot].second) {
      setid[jroot].first = iroot;
    }
    else {
      setid[jroot].first = iroot;
      // Saturate rather than overflow. The check must compare against the
      // type's maximum: `rank + 1 > rank` is undefined behaviour for signed
      // RankType, and is always true for types narrower than int because of
      // integer promotion (the subsequent store would then wrap).
      if (setid[iroot].second < std::numeric_limits<RankType>::max()) {
        setid[iroot].second = static_cast<RankType>(setid[iroot].second + 1);
      }
    }
  }

  /**
   * Returns the representative (root) of the set containing i.
   *
   * Every node on the path from i to the root is re-pointed directly at
   * the root (path compression). Ranks are not modified.
   */
  IDType find(IDType i) {
    // First pass: follow parent links up to the root.
    IDType root = i;
    while (!is_root(root)) { root = setid[root].first; }

    // Second pass: compress the path so later finds are O(1).
    IDType cur = i;
    while (cur != root) {
      IDType parent = setid[cur].first;
      setid[cur].first = root;
      cur = parent;
    }
    return root;
  }
};
89 
90 
91 class concurrent_union_find {
92  private:
93  union elem{
94  struct {
95  uint32_t next;
96  uint32_t rank;
97  } d;
98  uint64_t val;
99  };
100 
101  std::vector<elem> setid;
102 
103  bool is_root(uint32_t i) {
104  return setid[i].d.next == i;
105  }
106 
107  bool updateroot(uint32_t x, uint32_t oldrank,
108  uint32_t y, uint32_t newrank) {
109  elem old; old.d.next = x; old.d.rank = oldrank;
110  elem newval; newval.d.next = y; newval.d.rank = newrank;
111  return atomic_compare_and_swap(setid[x].val, old.val, newval.val);
112  }
113 
114  public:
115  concurrent_union_find() { }
116  void init(uint32_t s) {
117  setid.resize((size_t)s);
118  for (size_t i = 0; i < setid.size() ;++i) {
119  setid[i].d.next = (uint32_t)(i);
120  setid[i].d.rank = 0;
121  }
122  }
123 
124  void merge(uint32_t x, uint32_t y) {
125 
126  uint32_t xr, yr;
127  while(1) {
128  x = find(x);
129  y = find(y);
130  if (x == y) return;
131  xr = setid[x].d.rank;
132  yr = setid[y].d.rank;
133 
134  if (xr > yr || (xr == yr && x > y)) {
135  std::swap(x,y); std::swap(xr, yr);
136  }
137 
138  if (updateroot(x, xr, y, xr)) break;
139  }
140  if (xr == yr) {
141  __sync_add_and_fetch(&(setid[y].d.rank), 1);
142  }
143  }
144 
145  uint32_t find(uint32_t x) {
146  if (is_root(x)) return x;
147 
148  uint32_t y = x;
149  // get the id of the root element
150  while (!is_root(x)) { x = setid[x].d.next; }
151 
152  // update the parents and ranks all the way up
153  while (setid[y].d.rank < setid[x].d.rank) {
154  uint32_t t = setid[y].d.next;
155  atomic_compare_and_swap(setid[y].d.next, t, x);
156  y = setid[t].d.next;
157  }
158  return x;
159  }
160 };
161 }
162 #endif