\(\quad\) 想了挺久没想出来,一看题解恍然大悟。一个数对应一行和一列,二分答案,凡是小于等于答案的就连边。如果满足能够取出 \(n - k + 1\) 个不比二分中点 \(mid\) 大的数,那么r = mid
,不然l = mid + 1
。
\(\quad\) 为什么是要求 \(n - k + 1\) 个不大于答案的数,而不是 \(k\) 个不小于答案的数呢?因为后者无法保证能够取出 \(n\) 个数。比方说:
样例
3 4 2
1 5 6 6
8 3 4 3
6 8 6 3
\(\quad\) \(1\) 满足能够取出 \(k\) 个不小于它的数,但如果把它作为第 \(k\) 大的数,就无法找到 \(n - k\) 个比它小的数一并取出。
\(\quad\) 换言之,如果能够取出 \(n - k - 1\) 个不大于它的数,就一定能够在它取最小值时再取出 \(k\) 个大于等于它的数;如果能够取出 \(k\) 个大于等于它的数,却不一定能够在它取最小值时再取出 \(n - k - 1\) 个小于等于它的数。
#include <cstdio>
#include <cstring>
#include <queue>
inline int min(const int& a, const int& b){
return a < b ? a : b;
}
inline int max(const int& a, const int& b){
return a > b ? a : b;
}
const int MAXN = 5e2 + 19, MAXM = 2.5e5 + 19;
struct Edge{
int to, next, c;
}edge[MAXM << 1];
int cnt = -1, head[MAXN];
inline void add(int from, int to, int c){
edge[++cnt].to = to;
edge[cnt].c = c;
edge[cnt].next = head[from];
head[from] = cnt;
}
int n, m, k;
int a[MAXN][MAXN];
int dep[MAXN];
int bfs(void){
std::queue<int>q; q.push(0);
std::memset(dep, 0, sizeof dep); dep[0] = 1;
while(!q.empty()){
int node = q.front(); q.pop();
for(int i = head[node]; i != -1; i = edge[i].next)
if(!dep[edge[i].to] && edge[i].c)
dep[edge[i].to] = dep[node] + 1, q.push(edge[i].to);
}
return dep[n + m + 1];
}
int dfs(int node, int flow){
if(node == n + m + 1 || !flow)
return flow;
int stream = 0, f;
for(int i = head[node]; i != -1; i = edge[i].next)
if(dep[edge[i].to] == dep[node] + 1 && (f = dfs(edge[i].to, min(flow, edge[i].c)))){
flow -= f, stream += f;
edge[i].c -= f, edge[i ^ 1].c += f;
if(!flow)
break;
}
return stream;
}
int dinic(void){
int flow = 0;
while(bfs())
flow += dfs(0, 0x3f3f3f3f);
return flow;
}
int main(){
int l = 0x3f3f3f3f, r = 0;
std::scanf("%d%d%d", &n, &m, &k);
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j){
std::scanf("%d", &a[i][j]);
l = min(l, a[i][j]), r = max(r, a[i][j]);
}
while(l < r){
int mid = (l + r) >> 1;
std::memset(head, -1, sizeof head), cnt = -1;
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j)
if(a[i][j] <= mid)
add(i, j + n, 1), add(j + n, i, 0);
for(int i = 1; i <= n; ++i)
add(0, i, 1), add(i, 0, 0);
for(int j = 1; j <= m; ++j)
add(j + n, n + m + 1, 1), add(n + m + 1, j + n, 0);
if(dinic() >= n - k + 1)
r = mid;
else
l = mid + 1;
}
std::printf("%d\n", l);
return 0;
}