diff --git a/data_structures/disjoint_set/disjoint_set.py b/data_structures/disjoint_set/disjoint_set.py new file mode 100644 index 000000000..a93b89621 --- /dev/null +++ b/data_structures/disjoint_set/disjoint_set.py @@ -0,0 +1,79 @@ +""" + disjoint set + Reference: https://en.wikipedia.org/wiki/Disjoint-set_data_structure +""" + + +class Node: + def __init__(self, data): + self.data = data + + +def make_set(x): + """ + make x as a set. + """ + # rank is the distance from x to its' parent + # root's rank is 0 + x.rank = 0 + x.parent = x + + +def union_set(x, y): + """ + union two sets. + set with bigger rank should be parent, so that the + disjoint set tree will be more flat. + """ + x, y = find_set(x), find_set(y) + if x.rank > y.rank: + y.parent = x + else: + x.parent = y + if x.rank == y.rank: + y.rank += 1 + + +def find_set(x): + """ + return the parent of x + """ + if x != x.parent: + x.parent = find_set(x.parent) + return x.parent + + +def find_python_set(node: Node) -> set: + """ + Return a Python Standard Library set that contains i. + """ + sets = ({0, 1, 2}, {3, 4, 5}) + for s in sets: + if node.data in s: + return s + raise ValueError(f"{node.data} is not in {sets}") + + +def test_disjoint_set(): + """ + >>> test_disjoint_set() + """ + vertex = [Node(i) for i in range(6)] + for v in vertex: + make_set(v) + + union_set(vertex[0], vertex[1]) + union_set(vertex[1], vertex[2]) + union_set(vertex[3], vertex[4]) + union_set(vertex[3], vertex[5]) + + for node0 in vertex: + for node1 in vertex: + if find_python_set(node0).isdisjoint(find_python_set(node1)): + assert find_set(node0) != find_set(node1) + else: + assert find_set(node0) == find_set(node1) + + +if __name__ == "__main__": + test_disjoint_set()