AcWing 状态压缩DP相关问题 524. 愤怒的小鸟


'''
原点加上任意两个点可以确定一条抛物线,先把所有抛物线全部算出来
然后计算落到每条抛物线上的点信息,转换成一个区间覆盖问题求解
用状态压缩DP解决

'''

# 计算点是否在抛物线上
def is_match(x, y, a, b):
    val = a* x * x + b * x

    return abs(val - y) <= 1e-6


# 1的数量
def one_cnt(val):
    ans = 0
    while val:
        ans += 1
        val = val & (val - 1)
    return ans


T = int(input())
for _ in range(T):
    m, n = map(int, input().split())
    arr = []
    for i in range(m):
        a, b = map(float, input().split())
        arr.append((a, b))

    # 先计算所有可能的抛物线
    func_set = set()
    for i in range(m):
        for j in range(i + 1, m):
            x1, y1 = arr[i]
            x2, y2 = arr[j]
            if x1 == x2:
                continue

            a = (x2 * y1 - x1 * y2) / (x1 * x1 * x2 - x1 * x2 * x2)
            b = (y1 - a * x1 * x1) / x1
            if a < 0:
                func_set.add((a, b))
    func_list = list(func_set)

    # 提前先算好每个抛物线包含哪些点
    func_mask = []
    for func in func_list:
        mask = 0
        for node_idx in range(m):
            x, y = arr[node_idx]
            if is_match(x, y, func[0], func[1]):
                mask |= (1 << node_idx)
        func_mask.append(mask)

    if m <= 1:
        print(m)
    else:
        if len(func_list) == 0:
            print(m)
        else:
            # dp(i, stat)表示前i个抛物线中做选择,至少包含stat中所有点的选法中,最少的抛物线数量
            dp = [0] * (1 << 18)
            for i in range(len(func_mask)):
                for stat in range((1 << m) - 1, -1, -1):
                    if i == 0:
                        mask = func_mask[i]
                        match_cnt = 1 if mask & stat != 0 else 0
                        miss_cnt = one_cnt(stat & (~(mask & stat)))
                        dp[stat] = match_cnt + miss_cnt
                    else:
                        mask = func_mask[i]
                        if mask & stat != 0:
                            new_stat = stat & (~(stat & mask))
                            dp[stat] = min(dp[stat], 1 + dp[new_stat])

            print(dp[(1 << m) - 1])

    '''
    # 记忆化递归会超时
    # 前i个抛物线中做选择,至少包含stat中所有点的选法中,最少的抛物线数量    
    def dp(i, stat, memo):
        if (i, stat) in memo:
            return memo[(i, stat)]

        if stat == 0:
            memo[(i, stat)] = 0
            return 0

        if i == 0:
            mask = func_mask[i]
            match_cnt = 1 if mask & stat != 0 else 0
            miss_cnt = one_cnt(stat & (~(mask & stat)))
            ans = match_cnt + miss_cnt
            memo[(i, stat)] = ans
            return ans

        else:
            mask = func_mask[i]
            if mask & stat == 0:
                # 一个点都没有找到重合的
                ans = dp(i-1, stat, memo)
                memo[(i, stat)] = ans
                return ans
            else:
                new_stat = stat & (~(stat & mask))
                ans = min(dp(i-1, stat, memo), 1 + dp(i-1, new_stat, memo))
                memo[(i, stat)] = ans
                return ans

    if m <= 1:
        print(m)
    else:
        if len(func_list) == 0:
            print(m)
        else:
            print( dp(len(func_list)-1, (1<<m)-1, {}) )
    '''

猜你喜欢

转载自blog.csdn.net/xiaohaowudi/article/details/107762330