プリムのアルゴリズム

ここのお話のやや続き。id:naoya さんがクラスカルアルゴリズムを書かれていたのに触発されてプリムのアルゴリズム書いてみました。前提のグラフは同じです。私は優先度付き min キューの実装は、Python の heap queue にのっかってラップする感じでの手抜きではありますが。。。

ちょっとスパゲッティ感漂っているのはご容赦頂きたい所なのですが、こうして実装して比較してみてみますと、クラスカルが Edge を単位として見ているのに対して、プリムはどちらかというと Vertice を単位としてみているような雰囲気があって、その違いも面白いものです。

# -*- coding: utf-8 -*-

import sys
from heapq import heappush,heappop

class minheap:    
    
    def __init__(self,elems):
        self.heap = []
        [heappush(self.heap,e) for e in elems]
        
    def extract_min(self):
        return heappop(self.heap)
    
    def decrease_key(self,i,key):
        if key > self.heap[i] :
            raise Exception, "new key is larger than current key"
        self.heap[i] = key
        while i > 0 and self.heap[i/2] > self.heap[i] :
            pindex = i/2
            self.heap[i], self.heap[pindex] = self.heap[pindex], self.heap[i]
            i = pindex
            
    def index(self,key):
        return self.heap.index(key)
    
    def is_empty(self):
        return len(self.heap) == 0
    
class MinPriorityQueue:
    
    def __init__(self,elems):
        keys = map(lambda x : x.key, elems)
        self.minheap = minheap(keys)
        self.elem_key_handles = dict([(e,e.key) for e in elems])
        self.key_elem_handles = dict([(k,[]) for k in keys])
        [self._update_key_elem_handles(e) for e in elems]
        
    def is_phi(self):
        return self.minheap.is_empty()
    
    def extract_min(self):
        min_key = self.minheap.extract_min()    
        elem = self.key_elem_handles[min_key].pop(0)
        del self.elem_key_handles[elem]
        return elem
    
    def contains(self,elem):
        return self.elem_key_handles.has_key(elem)
    
    def update(self,elem,key):        
        oldkey = elem.key
        i = self.minheap.index(oldkey)
        self.minheap.decrease_key(i, key)
        self.key_elem_handles[oldkey].remove(elem)
        elem.key = key
        self._update_key_elem_handles(elem)
    
    def _update_key_elem_handles(self,elem):
        key = elem.key
        if not self.key_elem_handles.has_key(key) :
            self.key_elem_handles[key] = []
        self.key_elem_handles[key].append(elem)

class Vertice:    
    
    def __init__(self,name):
        self.name = name
        self.adjs = dict()        
        
    def add(self,vertice,weight,stop=False):
        if not self.adjs.has_key(vertice) :
            self.adjs[vertice] = weight
        if not stop :
            vertice.add(self,weight,True)
        return self
    
    def w(self,vertice):
        return self.adjs.get(vertice,sys.maxint)
    
    def adj(self):
        return self.adjs.keys()    

    def __repr__(self):
        return self.name
    
    __str__ = __repr__    
                    
# vertices
a,b,c,d,e,f,g,h,i = Vertice('a'), Vertice('b'),Vertice('c'),Vertice('d'),Vertice('e'),Vertice('f'),Vertice('g'),Vertice('h'),Vertice('i')
vertices = [a,b,c,d,e,f,g,h,i]

# edges
a.add(b,4).add(h,8)
b.add(c,8).add(h,11)
c.add(d,7).add(f,4).add(i,2)
d.add(e,9).add(f,14)
e.add(f,10)
f.add(g,2)
g.add(h,1).add(i,6)
h.add(i,7)

# MST-PRIM
for u in vertices :
    u.key = sys.maxint
    u.pi = None        
r = vertices[0]
r.key = 0
q = MinPriorityQueue(vertices)
while not q.is_phi():
    u = q.extract_min()
    for v in u.adj() :
        if q.contains(v) and u.w(v) < v.key :
            v.pi = u
            q.update(v, u.w(v))

print sum(map(lambda x : x.key, vertices))