알고리즘/문제풀이

[백준 10021번] Watering the Fields

Ohnim · 오님 2024. 5. 20. 23:30

문제 링크

https://www.acmicpc.net/problem/10021

 

문제 요약

농부 John은 가뭄 때문에 N개의 농지에 물을 공급하기 위한 수로를 설치하려고 한다. 각 농지는 2차원 평면 위의 점 \((x_i, y_i)\)로 나타내며, 서로 다른 두 농지 \(i\)와 \(j\)를 연결하는 수로를 건설하는 비용은 \((x_i - x_j)^2 + (y_i - y_j)^2\)으로 계산한다.

 

이 문제는 모든 농지가 서로 물을 공급받을 수 있도록 수로를 설치하는 최소 비용을 계산하는것이 목표다. 제약 사항이 하나 있다면 수로를 건설하는 업자가 수로 건설 비용이 C 이상이 되는 것만 설치한다는 것이다.

 

문제 풀이

이 문제에서 농지는 그래프의 노드로, 수로는 간선으로 모델링 할 수 있다. 이렇게 그래프로 모델링하고 나면 우리가 구해야 하는 것은 비용이 C 이상인 간선들을 이용해 모든 노드들을 연결하는 미니멈 스패닝 트리를 구하는 문제로 바뀐다.(만약 미니멈 스패닝 트리를 모른다면 이 문제를 푸는 것은 어려울 것이라 생각한다)

 

미니멈 스패닝 트리를 구하는 알고리즘은 크게 크루스칼 알고리즘과 프림 알고리즘 두 개가 있는데, 크루스칼 알고리즘을 이용해 미니멈 스패닝 트리를 구했다.(보통 크루스칼 알고리즘을 이용하는데 왠지모르게 디스조인트 셋을 구현하는 것이 기분이 좋다.)

 

답을 출력할 때에는 실제로 스패닝 트리가 만들어졌는지 확인할 필요가 있다. 0번 노드의 집합 번호를 찾은 뒤 나머지 노드가 전부 같은 집합에 속하는지 확인하는 방법으로 간단하게 확인할 수 있다.

 

소스 코드

더보기
#include <stdio.h>
#include <vector>
#include <algorithm>

using namespace std;

struct Point {
    int x, y;
    
    Point(int _x, int _y): x(_x), y(_y) {}
};

struct Edge {
    int u, v, w;
    
    Edge(int _u, int _v, int _w): u(_u), v(_v), w(_w) {}
    
    bool operator < (const Edge& r) const {
        return w < r.w;
    }
};

struct DisjointSet {
    vector<int> parents, ranks;
    
    DisjointSet(int n): parents(n), ranks(n, 1) {
        for (int i = 0 ; i < n ; ++i) {
            parents[i] = i;
        }
    }
    
    int find(int u) {
        if (u == parents[u]) {
            return u;
        }
        
        return parents[u] = find(parents[u]);
    }
    
    void merge(int u, int v) {
        u = find(u);
        v = find(v);
        
        if (u == v) {
            return;
        }
        
        if (ranks[u] > ranks[v]) {
            swap(u, v);
        }
        
        parents[u] = v;
        
        if (ranks[u] == ranks[v]) {
            ++ranks[v];
        }
    }
};

int getDist(const Point& a, const Point& b) {
    return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
}

int n, c;
vector<Point> points;
vector<Edge> edges;

int main() {
    scanf("%d%d", &n, &c);

    for (int i = 0 ; i < n ; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        points.push_back(Point(x, y));
    }
    
    for (int i = 0 ; i < n ; ++i) {
        for (int j = i + 1 ; j < n ; ++j) {
            int dist = getDist(points[i], points[j]);
            
            if (dist >= c) {
                edges.push_back(Edge(i, j, dist));
            }
        }
    }
    
    sort(edges.begin(), edges.end());
    
    DisjointSet disjointSet(n);
    
    long long ans = 0;
    
    for (const auto& edge: edges) {
        if (disjointSet.find(edge.u) == disjointSet.find(edge.v)) {
            continue;
        }
        
        disjointSet.merge(edge.u, edge.v);
        ans += edge.w;
    }
    
    int root = disjointSet.find(0);
    
    for (int i = 1 ; i < n ; ++i) {
        if (root != disjointSet.find(i)) {
            puts("-1");
            
            return 0;
        }
    }
    
    printf("%lld\n", ans);
}