Minimum Spanning Trees

  • Kruskal’s Algorithm
    • O(n + m lg(n))
  • Prim’s Algorithm
    • O(n lg(n) + m lg(n))
  (n = number of vertices, m = number of edges; both bounds assume a binary heap)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import heapq
from collections import namedtuple


class UnionFind(object):
    """Disjoint-set forest with path compression and union by rank.

    ``parent`` maps each element to its parent, or to ``-1`` when the element
    is the root of its set. ``rank`` is an upper bound on the height of the
    tree rooted at an element and is only meaningful for roots.

    NOTE: ``-1`` doubles as the "no parent" marker, so ``-1`` itself cannot be
    used as an element label.
    """

    def __init__(self, nodes):
        """Create one singleton set per element of ``nodes``.

        ``nodes`` is iterated exactly once, so generators and other one-shot
        iterables work. Iterating it twice would leave ``rank`` empty.
        """
        self.parent = {}
        self.rank = {}
        for node in nodes:
            self.parent[node] = -1
            self.rank[node] = 0

    def get_root(self, i):
        """Return the root (representative) of the set containing ``i``.

        Compresses the path, so every element visited now points directly at
        the root. Raises ``KeyError`` if ``i`` was never added.
        """
        root = i
        while self.parent[root] != -1:
            root = self.parent[root]
        # Second pass: re-point every element on the path at the root.
        while i != root:
            nxt = self.parent[i]
            self.parent[i] = root
            i = nxt
        return root

    def union(self, i, j):
        """Merge the sets containing ``i`` and ``j``; no-op if already joined.

        The lower-rank root is attached under the higher-rank one. On a tie,
        ``i``'s root goes under ``j``'s root and that root's rank grows by one.
        """
        i_root = self.get_root(i)
        j_root = self.get_root(j)
        if i_root == j_root:
            return
        if self.rank[i_root] > self.rank[j_root]:
            self.parent[j_root] = i_root
        else:
            self.parent[i_root] = j_root
            if self.rank[i_root] == self.rank[j_root]:
                self.rank[j_root] += 1

    def is_connected(self, i, j):
        """Return True if ``i`` and ``j`` are in the same set."""
        return self.get_root(i) == self.get_root(j)


class MyHeap(object):
    """Min-priority queue over arbitrary items, ordered by ``key(item)``.

    Thin wrapper around :mod:`heapq`. When no ``key`` is given, items are
    ordered by their ``weight`` attribute. Items with equal keys are popped in
    insertion order, so the items themselves never have to be comparable.
    """

    def __init__(self, initial=None, key=None):
        # Default ordering is by ``.weight`` (the edge/node tuples below).
        self.key = key if key else (lambda x: x.weight)
        # Monotonic insertion counter. It breaks key ties FIFO, so heapq
        # never falls back to comparing the items themselves.
        self._seq = 0
        self._data = []
        if initial:
            for item in initial:
                self._data.append((self.key(item), self._seq, item))
                self._seq += 1
            heapq.heapify(self._data)

    def push(self, item):
        """Insert ``item`` with priority ``self.key(item)``."""
        heapq.heappush(self._data, (self.key(item), self._seq, item))
        self._seq += 1

    def pop(self):
        """Remove and return the item with the smallest key.

        Raises:
            IndexError: if the heap is empty (propagated from ``heapq.heappop``).
        """
        return heapq.heappop(self._data)[-1]

    def empty(self):
        """Return True when the heap holds no items."""
        return not self._data




# KruskalMST(G):
#     DisjointSets forest
#     foreach (Vertex v : G):
#         forest.makeSet(v)

#     PriorityQueue Q    // min edge weight
#     foreach (Edge e : G):
#         Q.insert(e)

#     Graph T = (V, {})
#     while |T.edges()| < n-1:
#         Edge (u, v) = Q.removeMin()
#         if forest.find(u) != forest.find(v):
#             T.addEdge(u, v)
#             forest.union(forest.find(u), forest.find(v))
#     return T



def kruskal(graph):
    """Compute a minimum spanning tree (or forest) with Kruskal's algorithm.

    Args:
        graph: adjacency mapping ``{u: {v: weight, ...}, ...}``. An edge may
            be listed under one or both of its endpoints.

    Returns:
        ``(weight, path)``. ``path`` lists the chosen ``edge(u, v, weight)``
        namedtuples in the order they were accepted, and ``weight`` is the sum
        of their weights. A disconnected graph yields a minimum spanning
        forest instead of raising.
    """
    Edge = namedtuple("edge", ('u', 'v', 'weight'))

    # Every vertex, including ones that only appear as a neighbour, so the
    # union-find never meets an unknown key. The dict keeps first-seen order.
    vertices = dict.fromkeys(graph)
    for neighbours in graph.values():
        vertices.update(dict.fromkeys(neighbours))
    union_find = UnionFind(vertices)

    heap = MyHeap()  # default key orders edges by .weight
    for u, neighbours in graph.items():
        for v, w in neighbours.items():
            heap.push(Edge(u, v, w))

    path = []
    weight = 0
    target = len(vertices) - 1
    # Stop once a spanning tree (n-1 edges) is built. Also stop when the heap
    # runs dry, which happens on a disconnected graph.
    while len(path) < target and not heap.empty():
        e = heap.pop()
        if not union_find.is_connected(e.u, e.v):
            path.append(e)
            weight += e.weight
            union_find.union(e.u, e.v)

    return weight, path


# PrimMST(G, s):
#     foreach (Vertex v : G):
#         d[v] = +inf
#         p[v] = NULL
#     d[s] = 0
#     PriorityQueue Q    // min distance, defined by d[v]
#     Q.buildHeap(G.vertices())

#     Graph T    // "labeled set"

#     repeat n times:
#         Vertex m = Q.removeMin()
#         T.add(m)
#         foreach (Vertex v : neighbors of m not in T):
#             if cost(v, m) < d[v]:
#                 d[v] = cost(v, m)    // Q must be updated too (decrease-key)
#                 p[v] = m


def prim(graph, root):
    """Compute a minimum spanning tree with (lazy) Prim's algorithm.

    Args:
        graph: adjacency mapping ``{u: {v: weight, ...}, ...}``. Only
            ``graph[u]`` is followed from ``u``, so an undirected edge must be
            listed under both endpoints to be usable from either side.
        root: starting vertex.

    Returns:
        ``(total, path)``. ``path`` lists ``(parent, vertex, weight)`` for
        each tree edge in the order vertices join the tree, and ``total`` is
        the sum of those weights. Only the component containing ``root`` is
        spanned.
    """
    # A heap entry is a candidate edge parent -> v with its cost. Carrying the
    # parent is what lets us report the real tree edge; the previously popped
    # vertex is generally NOT v's neighbour in the tree.
    Node = namedtuple("Node", ("v", "weight", "parent"))
    path = []
    total = 0          # total cost of edges in the tree
    visited = set()    # vertices already in the tree
    heap = MyHeap()    # candidate edges ordered by cost (.weight)
    heap.push(Node(root, 0, None))

    while not heap.empty():
        cur = heap.pop()
        if cur.v in visited:
            continue  # stale entry: v already joined via an edge popped earlier
        if visited:
            # Every vertex after the root joins via the edge that reached it.
            path.append((cur.parent, cur.v, cur.weight))
        visited.add(cur.v)
        total += cur.weight
        # .get: a vertex that only appears as a neighbour has no out-edges.
        for neighbour, cost in graph.get(cur.v, {}).items():
            if neighbour not in visited:
                heap.push(Node(neighbour, cost, cur.v))

    return total, path

if __name__ == '__main__':
    # Small undirected demo graph: every edge is listed under both endpoints.
    graph_dict = {
        "v1": {"v2": 32, "v4": 17},
        "v2": {"v1": 32, "v5": 45},
        "v3": {"v7": 5, "v4": 18},
        "v4": {"v3": 18, "v1": 17, "v5": 10, "v8": 3},
        "v5": {"v4": 10, "v2": 45, "v9": 25, "v6": 28},
        "v6": {"v5": 28, "v10": 6},
        "v7": {"v3": 5, "v8": 59},
        "v8": {"v4": 3, "v7": 59, "v9": 4},
        "v9": {"v8": 4, "v5": 25, "v10": 12},
        "v10": {"v9": 12, "v6": 6},
    }

    # Run Prim (from v1), then Kruskal. Each one prints its total weight
    # followed by its edge list, and finishes printing before the next starts.
    runners = (
        lambda: prim(graph_dict, 'v1'),
        lambda: kruskal(graph_dict),
    )
    for run in runners:
        weight, path = run()
        print(weight)
        print(path)