이 문제에서 달라진 점은 각 행렬에 들어갈 수 있는 값이 0또는 1이 아니라 내가 정할 수 있다는 점이다. 만약 내가 원본 행렬에 존재하는 최댓값을 \(x\)라고 정했다고 하자. 그럼 각 행들을 대표하는 정점들과 각 열들을 대표하는 정점들을 이어줄 때 \(capacity\)를 \(x\)로 설정해주면 된다. 그렇게 모든 \(x\)에 대해서 가능한지 확인한다면 답을 구할 수 있다.
하지만 모든 \(x\)에 대해서 네트워크 플로우를 돌리기에는 시간이 부족하다. 우리는 여기서 한가지 아이디어를 생각할 수 있다. \(x\)는 커지면 커질수록 원본 행렬을 만들 수 있는 확률이 높아진다는 것이다. 즉 단조증가 그래프가 그려지기 때문에 \(binary\ search\)를 적용할 수 있다.
\(binary\ search\)를 사용하면 원본 행렬에 존재하는 값의 최댓값중의 최솟값을 빠른 시간내로 찾을 수 있고 역추적 과정을 통해 원본 행렬을 복구할 수 있다.
소스 코드
#include <stdio.h>
#include <string.h>
#include <queue>
#include <algorithm>
using namespace std;
const int INF = 987654321;
int n, tot, N, src, sink, RS[60], CS[60], C[110][110], dist[110], iter[110];
void makeGraph(int x) {
memset(C, 0, sizeof(C));
for (int i = 0 ; i < n ; ++i) {
C[src][i] = RS[i];
C[i + n][sink] = CS[i];
for (int j = 0 ; j < n ; ++j) {
C[i][j + n] = x;
}
}
}
bool bfs() {
memset(dist, -1, sizeof(dist));
dist[src] = 0;
queue<int> q;
q.push(src);
while (!q.empty()) {
int here = q.front();
q.pop();
for (int there = 0 ; there < N ; ++there) {
if (C[here][there] && dist[there] == -1) {
dist[there] = dist[here] + 1;
q.push(there);
}
}
}
return dist[sink] != -1;
}
int dfs(int here, int flow) {
if (here == sink) {
return flow;
}
for (int& there = iter[here] ; there < N ; ++there) {
if (C[here][there] && dist[here] < dist[there]) {
int minFlow = dfs(there, min(flow, C[here][there]));
if (minFlow) {
C[here][there] -= minFlow;
C[there][here] += minFlow;
return minFlow;
}
}
}
return 0;
}
int maxFlow() {
int ret = 0;
while (bfs()) {
memset(iter, 0, sizeof(iter));
int flow;
while ((flow = dfs(src, INF))) {
ret += flow;
}
}
return ret;
}
bool solve(int x) {
makeGraph(x);
return maxFlow() == tot;
}
int main() {
scanf("%d", &n);
for (int i = 0 ; i < n ; ++i) {
scanf("%d", &RS[i]);
tot += RS[i];
}
for (int i = 0 ; i < n ; ++i) {
scanf("%d", &CS[i]);
}
N = n * 2 + 2;
src = n * 2;
sink = n * 2 + 1;
int lo = 0, hi = 1000000;
while (lo < hi) {
int mid = (lo + hi) >> 1;
if (solve(mid)) {
hi = mid;
} else {
lo = mid + 1;
}
}
solve(lo);
printf("%d\n", lo);
for (int i = 0 ; i < n ; ++i) {
for (int j = 0 ; j < n ; ++j) {
printf("%d ", lo - C[i][j + n]);
}
puts("");
}
}