求解四阶幻方

由于 16! = 20922789888000 规模在10^13次方级别,需要进行大量剪枝操作

利用四阶幻方的特殊性质可以进行剪枝

幻方的一般性质为:幻方每一行之和、每一列之和、两条对角线之和都相等,都等于幻和(四阶幻和为34)。

四阶幻方还有一些性质:中心四个数之和、四个顶角数字之和、边上两对行中间四个数之和、边上两对列中间四个数之和也都等于幻和(34)。

 0  1  2  3
 4  5  6  7
 8  9 10 11
12 13 14 15

12 13  3  6
 1  4 14 15
16  9  7  2
 5  8 10 11


相比三阶幻方,四阶幻方多了下面几个约束条件
1+2+13+14
4+7+8+11
5+6+9+10
0+3+12+15

利用上面性质对排列局面剪枝,可以极大减少搜索量。

另外,每搜索到一个幻方,就可以得到与其旋转反射相同的8个幻方,搜索量可以缩减为1/8

C语言跑了十几秒

#include <stdio.h>
int a[17], b[17], m;
void s(int i){/*四阶幻方全解搜索程序,C代码,运行时间7秒*/
    int n = 0, j = 0;
    while (++j < 17)
        if (!a[j]){
            a[b[i] = j] = 1;
            switch (i){
            case 1: case 2: case 3: case 5: case 6: case 7: case 9: case 10: s(i + 1); break;
            case 11: if (b[6] + b[7] + b[10] + b[11] == 34) s(12); break;
            case 4: case 8: case 12: if (b[i - 3] + b[i - 2] + b[i - 1] + b[i] == 34) s(i + 1); break;
            case 13: if (b[1] + b[5] + b[9] + b[13] == 34 && b[4] + b[7] + b[10] + b[13] == 34) s(14); break;
            case 14: case 15: if (b[i - 12] + b[i - 8] + b[i - 4] + b[i] == 34) s(i + 1); break;
            case 16: for (printf("\n"), ++m; ++n < 17; n % 4 ? 0 : printf("\n") ) printf("%2d ", b[n]);
            }
            a[j] = 0;
        }
}
int main(void){
    s(1);
    printf("四阶幻方总数:%d个(含旋转反射相同)", m);
    return 0;
}

四级幻方共7040种

Python跑了30分钟。。。

'''
三阶幻方8个
四阶幻方7040个
'''

# 幻方阶数
n = 4
# 幻和,即每行,每列和对角线的和
s = n * n * (n * n + 1) // 2 // n
# 数组
a = []
# 使用过的数字
used = set()
# 答案数组
ans = []


# 检查是否满足幻方条件
def check():
    flag = True
    for i in range(n):
        #  行的幻和
        r = sum([a[i * n + j] for j in range(n)])

        # 列的幻和
        c = sum([a[i + j * n] for j in range(n)])
        flag = flag and r == c == s
    # 主对角线和次对角线的幻和
    s1 = sum([a[i * n + i] for i in range(n)])
    s2 = sum([a[n * i - i] for i in range(1, n + 1)])
    flag = flag and s1 == s2 == s
    return flag


cnt = 0


# 搜索算法,参数表示准备放入第m个数字
def dfs(m):
    global cnt
    # 表示已经放入足够的数字了,进行检验
    if m == n * n:
        if check():
            cnt += 1
            print(show(a), cnt)
            ans.append(a.copy())
        # 弹出最后一位,继续递归,并且移除used集合中的数
        used.remove(a.pop())
        return

    # 行剪枝
    if m > 0 and m % n == 0 and sum([a[int(m / n - 1) * n + i] for i in range(n)]) != s:
        used.remove(a.pop())
        return

    # 对角线剪枝和边行剪枝
    if m >= n * (n - 1) + 1:
        if sum([a[n * i - i] for i in range(1, n + 1)]) != s or sum([a[4], a[8], a[7], a[11]]) != s:
            used.remove(a.pop())
            return

    # 中心剪枝,准备放入第11个数字时
    if m > 10:
        if sum([a[5], a[6], a[9], a[10]]) != s:
            used.remove(a.pop())
            return

    if m > n * (n - 1):
        c = m - n * (n - 1) - 1
        if sum([a[c + i * n] for i in range(n)]) != s:
            used.remove(a.pop())
            return

            # 数组长度不够时
    for i in range(1, n * n + 1):
        if i not in used:
            a.append(i)
            used.add(i)
            dfs(m + 1)
    # print(a, m)

    if len(a) > 0:
        used.remove(a.pop())


def show(arr):
    s = ''
    for index, i in enumerate(arr):
        s += str(i)
        if index % n == n - 1:
            s += '\n'

    return s


dfs(0)
print(len(ans))
for i in ans:
    print(show(i))

优化下。。。发现没有少多少。。。主要在于dfs执行次数太多了。。。

'''

三阶幻方8个
四阶幻方7040个
'''

# 幻方阶数
n = 4
# 幻和,即每行,每列和对角线的和
s = n * n * (n * n + 1) // 2 // n

# 深度搜索的数组,先分配内存,减少pop和push
a = [0 for i in range(n * n)]

# 使用过的数字
used = set()
all_num = set([i + 1 for i in range(n * n)])
# 保存答案的数组
ans = []


# 搜索算法,参数表示准备放入第m个数字
def dfs(m):
    # 表示已经放入十五个数字了,最后一位可以计算出来
    if m == n * n - 1:
        # 最后一个数字必须满足行和列以及对角线都为幻和
        t1 = s - (a[12] + a[13] + a[14])  # 行
        t2 = s - (a[3] + a[7] + a[11])  # 列
        t3 = s - sum([a[0], a[5], a[10]])
        if t1 == t2 == t3 and t1 not in used:
            a[m] = t1
            # show(a)
            ans.append(a.copy())
            print(len(ans))

        # 弹出最后一位,继续递归,并且移除used集合中的数
        used.remove(a[m - 1])
        return

    # 行剪枝
    if m in [4, 8, 12] and sum([a[int(m / n - 1) * n + i] for i in range(n)]) != s:
        used.remove(a[m - 1])
        return

    # 对角线剪枝和边行剪枝
    if m == 13:
        if sum([a[3] + a[6] + a[9] + a[12]]) != s or sum([a[4], a[8], a[7], a[11]]) != s:
            used.remove(a[m - 1])
            return

    # 中心剪枝,准备放入第11个数字时
    if m == 11:
        if sum([a[5], a[6], a[9], a[10]]) != s:
            used.remove(a[m - 1])
            return

    # 列剪枝
    if m in [13, 14]:
        if sum([a[m - 13 + i * n] for i in range(n)]) != s:
            used.remove(a[m - 1])
            return

    # 数组长度不够时
    for i in all_num - used:
        a[m] = i
        used.add(i)
        dfs(m + 1)

    if m > 0:
        used.remove(a[m - 1])


def show(arr):
    ss = ''
    for index, i in enumerate(arr):
        ss += str(i).ljust(4)
        if index % n == n - 1:
            ss += '\n'
    print(ss)


dfs(0)
print(len(ans))
# for i in ans:
#     print(show(i))

同样的逻辑c++用了17秒。。。。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>

using namespace std;
int n = 4;
int s = 34;
int a[17];
//  0表示没有使用
int used[17];
int cnt = 0;
int ans = 0;


void dfs(int m){
    cnt++;

    if (m == n * n - 1){
        int t1 = s - (a[12] + a[13] + a[14]);
        int t2 = s - (a[3] + a[7] + a[11]);// 列
        int t3 = s - (a[0] + a[5] + a[10]);
        if (t1 == t2 && t2 == t3 && !used[t1]){
            a[m] = t1;
            ans++;
            printf("%d\n", ans);
        }
        used[a[m - 1]] = 0;
        return;
    }

    // 行剪枝
    if (m > 0 && m % 4 == 0 && a[int(m / n - 1) * n] +
        a[int(m / n - 1) * n + 1] +
        a[int(m / n - 1) * n + 2] +
        a[int(m / n - 1) * n + 3] != s){
        used[a[m - 1]] = 0;
        return;
    }

    if (m == 11 && a[5] + a[6] + a[9] + a[10] != s){
        used[a[m - 1]] = 0;
        return;
    }

    if (m == 13 && (
            a[3] + a[6] + a[9] + a[12] != s ||
            a[4] + a[8] + a[7] + a[11] != s
            ) ){
        used[a[m - 1]] = 0;
        return;
    }

    if ( (m == 13 or m == 14) && a[m - 13] + a[m - 13 + 4] + a[m - 13 + 8] + a[m - 13 + 12] != s){
        used[a[m - 1]] = 0;
        return;
    }

    for (int i = 1; i <= 16; i++){
        if (!used[i]){
            a[m] = i;
            used[i] = 1;
            dfs(m + 1);
        }
    }

    if (m > 0){
        used[a[m - 1]] = 0;
    }

}

int main(){
    dfs(0);
    printf("%d\n", cnt);
    return 0;
}

取消输出用了17秒

java 跑了15秒

取消输出用了13s

package com.company;

public class Main {
    static int n = 4;
    static int s = 34;
    static int[] a = new int[17];
    //  0表示没有使用
    static int[] used = new int[17];
    static int cnt = 0;
    static int ans = 0;


    public static void dfs(int m) {
        cnt++;
        if (m == n * n - 1) {
            int t1 = s - (a[12] + a[13] + a[14]);
            int t2 = s - (a[3] + a[7] + a[11]);// 列
            int t3 = s - (a[0] + a[5] + a[10]);
            if (t1 == t2 && t2 == t3 && used[t1] == 0) {
                a[m] = t1;
                ans++;
                System.out.println(ans);
            }
            used[a[m - 1]] = 0;
            return;
        }

        // 行剪枝
        if (m > 0 && m % 4 == 0 && a[(int) (m / n - 1) * n] +
                a[(int) (m / n - 1) * n + 1] +
                a[(int) (m / n - 1) * n + 2] +
                a[(int) (m / n - 1) * n + 3] != s) {
            used[a[m - 1]] = 0;
            return;
        }

        if (m == 11 && a[5] + a[6] + a[9] + a[10] != s) {
            used[a[m - 1]] = 0;
            return;
        }

        if (m == 13 && (
                a[3] + a[6] + a[9] + a[12] != s ||
                        a[4] + a[8] + a[7] + a[11] != s
        )) {
            used[a[m - 1]] = 0;

            return;
        }

        if ((m == 13 || m == 14) && a[m - 13] + a[m - 13 + 4] + a[m - 13 + 8] + a[m - 13 + 12] != s) {
            used[a[m - 1]] = 0;

            return;
        }

        for (int i = 1; i <= 16; i++) {
            if (used[i] == 0) {
                a[m] = i;
                used[i] = 1;
                Main.dfs(m + 1);
            }
        }

        if (m > 0) {
            used[a[m - 1]] = 0;
        }

    }


    public static void main(String[] args) {
        Main.dfs(0);
        System.out.println(Main.cnt);
    }
}

综上Java真NB

g++ 开O3 。。。 6秒!!!

js 用了30s,比java和c慢,但至少是在一个级别的,比Python快很多了。。。

let n = 4;
let s = 34;
let a = Array(17)
//  0表示没有使用
let used = Array(17);
let cnt = 0;
let ans = 0;


function dfs(m) {
  cnt++;

  if (m == n * n - 1) {
    let t1 = s - (a[12] + a[13] + a[14]);
    let t2 = s - (a[3] + a[7] + a[11]);// 列
    let t3 = s - (a[0] + a[5] + a[10]);
    if (t1 == t2 && t2 == t3 && !used[t1]) {
      a[m] = t1;
      ans++;
      // console.log(ans)
    }
    used[a[m - 1]] = 0;
    return;
  }

  // 行剪枝
  if (m > 0 && m % 4 == 0 && a[Math.floor(m / n - 1) * n] +
    a[Math.floor(m / n - 1) * n + 1] +
    a[Math.floor(m / n - 1) * n + 2] +
    a[Math.floor(m / n - 1) * n + 3] != s) {
    used[a[m - 1]] = 0;
    return;
  }

  if (m == 11 && a[5] + a[6] + a[9] + a[10] != s) {
    used[a[m - 1]] = 0;
    return;
  }

  if (m == 13 && (
      a[3] + a[6] + a[9] + a[12] != s ||
      a[4] + a[8] + a[7] + a[11] != s
    )) {
    used[a[m - 1]] = 0;
    return;
  }

  if ((m == 13 || m == 14) && a[m - 13] + a[m - 13 + 4] + a[m - 13 + 8] + a[m - 13 + 12] != s) {
    used[a[m - 1]] = 0;
    return;
  }

  for (let i = 1; i <= 16; i++) {
    if (!used[i]) {
      a[m] = i;
      used[i] = 1;
      dfs(m + 1);
    }
  }

  if (m > 0) {
    used[a[m - 1]] = 0;
  }

}

let start = new Date().getTime();
dfs(0);
console.log(cnt);
console.log(new Date().getTime() - start)

猜你喜欢

转载自my.oschina.net/ahaoboy/blog/1800612