プリムのアルゴリズム
ここのお話のやや続き。id:naoya さんがクラスカルのアルゴリズムを書かれていたのに触発されてプリムのアルゴリズム書いてみました。前提のグラフは同じです。私は優先度付き min キューの実装は、Python の heap queue にのっかってラップする感じでの手抜きではありますが。。。
ちょっとスパゲッティ感漂っているのはご容赦頂きたい所なのですが、こうして実装して比較してみてみますと、クラスカルが Edge を単位として見ているのに対して、プリムはどちらかというと Vertice を単位としてみているような雰囲気があって、その違いも面白いものです。
# -*- coding: utf-8 -*- import sys from heapq import heappush,heappop class minheap: def __init__(self,elems): self.heap = [] [heappush(self.heap,e) for e in elems] def extract_min(self): return heappop(self.heap) def decrease_key(self,i,key): if key > self.heap[i] : raise Exception, "new key is larger than current key" self.heap[i] = key while i > 0 and self.heap[i/2] > self.heap[i] : pindex = i/2 self.heap[i], self.heap[pindex] = self.heap[pindex], self.heap[i] i = pindex def index(self,key): return self.heap.index(key) def is_empty(self): return len(self.heap) == 0 class MinPriorityQueue: def __init__(self,elems): keys = map(lambda x : x.key, elems) self.minheap = minheap(keys) self.elem_key_handles = dict([(e,e.key) for e in elems]) self.key_elem_handles = dict([(k,[]) for k in keys]) [self._update_key_elem_handles(e) for e in elems] def is_phi(self): return self.minheap.is_empty() def extract_min(self): min_key = self.minheap.extract_min() elem = self.key_elem_handles[min_key].pop(0) del self.elem_key_handles[elem] return elem def contains(self,elem): return self.elem_key_handles.has_key(elem) def update(self,elem,key): oldkey = elem.key i = self.minheap.index(oldkey) self.minheap.decrease_key(i, key) self.key_elem_handles[oldkey].remove(elem) elem.key = key self._update_key_elem_handles(elem) def _update_key_elem_handles(self,elem): key = elem.key if not self.key_elem_handles.has_key(key) : self.key_elem_handles[key] = [] self.key_elem_handles[key].append(elem) class Vertice: def __init__(self,name): self.name = name self.adjs = dict() def add(self,vertice,weight,stop=False): if not self.adjs.has_key(vertice) : self.adjs[vertice] = weight if not stop : vertice.add(self,weight,True) return self def w(self,vertice): return self.adjs.get(vertice,sys.maxint) def adj(self): return self.adjs.keys() def __repr__(self): return self.name __str__ = __repr__ # vertices a,b,c,d,e,f,g,h,i = Vertice('a'), Vertice('b'),Vertice('c'),Vertice('d'),Vertice('e'),Vertice('f'),Vertice('g'),Vertice('h'),Vertice('i') vertices = [a,b,c,d,e,f,g,h,i] # edges a.add(b,4).add(h,8) b.add(c,8).add(h,11) c.add(d,7).add(f,4).add(i,2) d.add(e,9).add(f,14) e.add(f,10) f.add(g,2) g.add(h,1).add(i,6) h.add(i,7) # MST-PRIM for u in vertices : u.key = sys.maxint u.pi = None r = vertices[0] r.key = 0 q = MinPriorityQueue(vertices) while not q.is_phi(): u = q.extract_min() for v in u.adj() : if q.contains(v) and u.w(v) < v.key : v.pi = u q.update(v, u.w(v)) print sum(map(lambda x : x.key, vertices))