Programming

[Algorithm] Minimum Spanning Tree(MST)

minigb 2021. 6. 22. 02:30

학교 과제로 MST를 STL 없이 짜야 했다.
그래서 이전에 짜놨던 min_heap을 사용해서 짰다.
그때 다형성을 사용하고 메서드 이름을 STL과 동일하게 지어 놓은 덕분에 이번 과제는 금방 끝낼 수 있었다.
C++는 정말 재밌다.
다른 언어들도 배워야지.

#include <iostream>
#include <fstream>
#include <vector>
#include <algorithm>
#include <ctime>
// Line terminator used for output; '\n' instead of std::endl avoids a flush per line.
constexpr char kEndl = '\n';
using ll = long long;
using ull = unsigned long long;
// NOTE(review): file-scope `using namespace std;` is kept because every block below relies on it.
using namespace std;

template<typename T>
class Heap {
public:
	Heap<T>() {
		end = 1;
		tree.resize(1);
	}

	virtual void push(T value) = 0;
	virtual void pop() = 0;

	T top() {
		return tree[1];
	}

	inline bool empty() {
		return end == 1;
	}

	inline ull size() {
		return end;
	}

protected:
	vector<T> tree;
	ull end;
};

/// Binary min-heap: top() is always a smallest element according to T's
/// `<` and `>` operators.
template<typename T>
class MinHeap : public Heap<T> {
public:
	/// Appends value and sifts it up until its parent is not greater than it.
	void push(T value) override {
		// All slots are in use (slot 0 is the placeholder): double the storage.
		if (this->end == this->tree.size()) {
			this->tree.resize(this->tree.size() * 2);
		}
		this->tree[this->end++] = value;

		// Slot 1 is the root, so stop once the new element reaches it.
		for (ull child = this->end - 1; child > 1; child /= 2) {
			ull parent = child / 2;
			if (this->tree[parent] > this->tree[child]) {
				swap(this->tree[parent], this->tree[child]);
			}
			else {
				break;
			}
		}
	}

	/// Removes the root. Popping an empty heap is a no-op; without this guard
	/// it would touch tree[1] past the placeholder slot and drive `end` to 0.
	void pop() override {
		if (this->empty()) {
			return;
		}

		// Move the last element into the root slot, then sift it down.
		this->tree[1] = this->tree[this->end - 1];
		this->end--;

		for (ull parent = 1; parent * 2 < this->end; ) {
			ull left = parent * 2;
			ull smallerChild = left;
			// Prefer the right child unless the left one is strictly smaller.
			if (left + 1 < this->end && !(this->tree[left] < this->tree[left + 1])) {
				smallerChild = left + 1;
			}

			if (this->tree[smallerChild] < this->tree[parent]) {
				swap(this->tree[parent], this->tree[smallerChild]);
				parent = smallerChild;
			}
			else {
				break;
			}
		}
	}
};

/// Disjoint-set forest over the elements 0 .. n-1, using path compression and
/// union by rank. parent[x] == -1 marks x as the root (representative) of its set.
class UnionFind {
public:
	/// (Re)initializes the structure so that each of 0 .. n-1 is its own singleton set.
	void Init(int n) {
		this->n = n;
		// assign (not resize) so that a repeated Init resets every existing element.
		parent.assign(n, -1);
		ranks.assign(n, 0);
	}

	/// Returns the representative of a's set and points every node on the
	/// path directly at it. Iterative so a long chain cannot overflow the stack.
	int GetParent(int a) {
		int root = a;
		while (parent[root] != -1) {
			root = parent[root];
		}
		while (a != root) {
			int next = parent[a];
			parent[a] = root;
			a = next;
		}
		return root;
	}

	/// Unites the sets containing a and b. The lower-rank root is attached
	/// under the higher-rank one, keeping trees shallow.
	void Merge(int a, int b) {
		a = GetParent(a);
		b = GetParent(b);
		if (a == b) {
			return;
		}
		if (ranks[a] > ranks[b]) {
			parent[b] = a;
		}
		else {
			parent[a] = b;
			if (ranks[a] == ranks[b]) {
				++ranks[b];
			}
		}
	}

protected:
	int n = 0;
	std::vector<int> parent;
	std::vector<int> ranks;  // upper bound on tree height; meaningful only for roots
};

typedef struct edge {
	int node1, node2;
	ll cost;

	edge() {}
	edge(int node1, int node2, ll cost) {
		this->node1 = node1;
		this->node2 = node2;
		this->cost = cost;
	}

	inline bool operator< (edge comp) const {
		return cost < comp.cost;
	}
	inline bool operator> (edge comp) const {
		return cost > comp.cost;
	}
} Edge;

/// Minimum spanning tree (forest) via Kruskal's algorithm.
///
/// Edges are buffered in a min-heap by MakeEdge and consumed on the first
/// GetMST() call. The union-find is sized n + 1, so node ids 0 .. n are valid
/// (which accommodates 1-based input).
class MST {
public:
	/// @param n number of nodes.
	/// @param e number of edges (stored only; not used by the algorithm).
	MST(int n, int e) {
		this->n = n;
		this->e = e;
		uf.Init(n + 1);
	}

	/// Queues an undirected edge for consideration by GetMST().
	void MakeEdge(int node1, int node2, ll cost) {
		min_heap.push(Edge(node1, node2, cost));
	}

	/// Runs Kruskal on the first call and caches the result; later calls
	/// return the cached edges. The result is in nondecreasing cost order.
	const vector<Edge>& GetMST() {
		if (computed) {
			return mst_edges;
		}
		computed = true;

		// A spanning tree on n nodes has exactly n - 1 edges; stop once found.
		while (!min_heap.empty() && count < n - 1) {
			Edge current = min_heap.top();
			min_heap.pop();

			// Accept the edge only if it joins two different components.
			if (uf.GetParent(current.node1) != uf.GetParent(current.node2)) {
				mst_edges.push_back(current);
				uf.Merge(current.node1, current.node2);
				total_cost += current.cost;
				++count;
			}
		}
		// Edges were popped from a min-heap, so mst_edges is already sorted by cost.
		connected = (count == n - 1);

		return mst_edges;
	}

	/// Sum of the chosen edge costs; 0 before GetMST() is called.
	inline ll GetTotalCost() const {
		return total_cost;
	}

	/// True if GetMST() found n - 1 edges; false before GetMST() is called.
	inline bool IsConnected() const {
		return connected;
	}

protected:
	int n = 0, e = 0; // number of nodes, edges
	ll total_cost = 0;
	int count = 0;    // edges accepted so far
	UnionFind uf; // for union-find
	MinHeap<Edge> min_heap; // candidate edges, cheapest first
	vector<Edge> mst_edges;
	bool connected = false;
	bool computed = false; // GetMST() has already run
};

/// Reads "n e" followed by e lines "node1 node2 cost" from stdin and prints
/// the total cost of the minimum spanning tree/forest.
/// Returns 1 on malformed input instead of continuing with unread values.
int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);

	int n = 0, e = 0;
	// When an extraction fails, later operands in the chain are left
	// unmodified, so validate the stream before trusting n and e.
	if (!(cin >> n >> e) || n < 0 || e < 0) {
		return 1;
	}

	MST mst(n, e);
	for (int i = 0; i < e; ++i) {
		int node1 = 0, node2 = 0;
		ll cost = 0;
		if (!(cin >> node1 >> node2 >> cost)) {
			return 1;
		}
		// MST sizes its union-find as n + 1, so valid node ids are 0 .. n.
		if (node1 < 0 || node1 > n || node2 < 0 || node2 > n) {
			return 1;
		}
		mst.MakeEdge(node1, node2, cost);
	}

	mst.GetMST();
	cout << mst.GetTotalCost() << kEndl;
}