算法导论(二)之分治策略

版权声明:欢迎大家转载,指正。 https://blog.csdn.net/yin__ren/article/details/83213895

一、最大子数组

1. 伪代码

FIND-MAXIMUM-SUBARRAY(A,low,high){
	if high == low
		return (low,high,A[low])
	else 
		mid = (low + high) / 2
		(left-low,left-high,left-sum) = FIND-MAXIMUM-SUBARRAY(A,low,mid)
		(right-low,right-high,right-sum) = FIND-MAXIMUM-SUBARRAY(A,mid + 1,high)
		(cross-low,cross-high,cross-sum) = FIND-MAX-CROSSING-SUBARRAY(A,low,mid,high)
		if left-sum >= right-sum and left-sum >= cross-sum
			return (left-low,left-high,left-sum)
		elseif right-sum >= left-sum and right-sum >= cross-sum
			return (right-low,right-high,right-sum)
		else 
			return (cross-low,cross-high,cross-sum)
}

FIND-MAX-CROSSING-SUBARRAY(A,low,mid,high){
	left-sum = 负无穷
	sum = 0
	for i = mid downto low
		sum = sum + A[i]
		if sum > left-sum
			left-sum = sum
			max-left = i
	right-sum = 负无穷
	sum = 0
	for j = mid + 1 to high
		sum = sum + A[j]
		if sum > right-sum
			right-sum = sum
			max-right = j
	return (max-left,max-right,left-sum + right-sum)
}

2. 时间复杂度

$O(n \log n)$

3. Java 代码

/**
 * Finds a maximum-sum contiguous subarray of {@code array[low..high]} (both bounds
 * inclusive) by divide and conquer.
 *
 * <p>The range is halved, each half is solved recursively, and the best subarray
 * that spans the midpoint is obtained from {@code findMaxCrossingSubArray}. On ties
 * the left candidate wins, then the right one, then the crossing one.
 *
 * @param array the input values
 * @param low   first index of the range to search
 * @param high  last index of the range to search
 * @return a three-element array {@code {startIndex, endIndex, sum}}
 * @throws IllegalArgumentException if {@code low > high}
 */
public static int[] findMaxSubArray(int[] array, int low, int high) {
    if (low > high) {
        // Without this guard an inverted range would recurse until the stack overflows.
        throw new IllegalArgumentException("low (" + low + ") must not exceed high (" + high + ")");
    }
    if (high == low) {
        // A single element is its own maximum subarray.
        return new int[] {low, high, array[low]};
    }

    // Overflow-safe midpoint; integer division already rounds down, so no Math.floor is needed.
    int mid = low + (high - low) / 2;
    int[] left = findMaxSubArray(array, low, mid);
    int[] right = findMaxSubArray(array, mid + 1, high);
    int[] cross = findMaxCrossingSubArray(array, low, mid, high);

    if (left[2] >= right[2] && left[2] >= cross[2]) {
        return left;
    } else if (right[2] >= left[2] && right[2] >= cross[2]) {
        return right;
    } else {
        return cross;
    }
}

/**
 * Finds the maximum-sum subarray of {@code array[low..high]} that crosses the
 * midpoint, i.e. that contains both {@code array[mid]} and {@code array[mid + 1]}.
 *
 * <p>Callers are expected to pass {@code low <= mid < high}, so that both scans
 * below visit at least one element.
 *
 * @param array the input values
 * @param low   first index of the range
 * @param mid   last index of the left half
 * @param high  last index of the range
 * @return a three-element array {@code {startIndex, endIndex, sum}}
 */
public static int[] findMaxCrossingSubArray(int[] array, int low, int mid, int high) {
    // Best sum of a run ending at mid and extending leftwards, never past low.
    int leftSum = Integer.MIN_VALUE;
    int maxLeft = mid;
    int running = 0;
    for (int i = mid; i >= low; i--) {
        running += array[i];
        if (running > leftSum) {
            leftSum = running;
            maxLeft = i;
        }
    }

    // Best sum of a run starting at mid + 1 and extending rightwards up to high.
    int rightSum = Integer.MIN_VALUE;
    int maxRight = mid + 1;
    running = 0;
    for (int j = mid + 1; j <= high; j++) {
        running += array[j];
        if (running > rightSum) {
            rightSum = running;
            maxRight = j;
        }
    }

    return new int[] {maxLeft, maxRight, leftSum + rightSum};
}

4. 非递归,线性时间实现

/**
 * Returns the maximum sum over all non-empty contiguous subarrays of {@code array}
 * in a single linear pass.
 *
 * <p>The running sum is extended while it is positive and restarted at the current
 * element otherwise. It starts at 0 rather than {@code Integer.MIN_VALUE} so that
 * adding a negative first element cannot overflow.
 *
 * @param array the input values
 * @return the largest subarray sum, or {@code Integer.MIN_VALUE} for an empty array
 */
public static int findMaxByLine(int[] array) {
    int best = Integer.MIN_VALUE;
    int current = 0;
    for (int value : array) {
        if (current > 0) {
            // A positive prefix can only help: keep extending it.
            current += value;
        } else {
            // A non-positive prefix never helps: restart at this element.
            current = value;
        }
        if (current > best) {
            best = current;
        }
    }
    return best;
}

二、矩阵乘法的 Strassen

Strassen算法详解

1. 数学方法

1. 伪代码

SQUARE-MATRIX-MULTIPLY(A,B){
	n = A.rows
	let C be a new nXn matrix
	for i = 1 to n
		for j = 1 to n
			C[i][j] = 0
			for k = 1 to n
				C[i][j] = C[i][j] + A[i][k] * B[k][j]
	return C
 }

2. 时间复杂度

$O(n^3)$

3. java 代码

/**
 * Multiplies {@code matrix1} by {@code matrix2} with the classical triple loop,
 * writing each entry into {@code matrixResult}.
 *
 * @return {@code matrixResult}, after its first {@code NUM} x {@code NUM} entries
 *         have been overwritten with the product
 */
public int[][] mathMethod() {
    for (int row = 0; row < NUM; row++) {
        for (int col = 0; col < NUM; col++) {
            int[] target = matrixResult[row];
            // Reset the cell, then accumulate the row-by-column dot product into it.
            target[col] = 0;
            for (int inner = 0; inner < NUM; inner++) {
                target[col] += matrix1[row][inner] * matrix2[inner][col];
            }
        }
    }
    return matrixResult;
}

2. 分治法

1. 伪代码

在这里插入图片描述

2. 时间复杂度

$T(n) = 7T(n/2) + \Theta(n^2)$

3. Java 代码实现

/**
 * Multiplies two {@code num} x {@code num} integer matrices with Strassen's
 * divide-and-conquer scheme (seven recursive products per level).
 *
 * <p>Halving only works for even sizes, so an odd {@code num} greater than 1 is
 * multiplied with the classical triple loop instead of silently dropping the last
 * row and column. A size of 0 yields an empty matrix.
 *
 * @param strMatrix1 left operand, at least {@code num} x {@code num}
 * @param strMatrix2 right operand, at least {@code num} x {@code num}
 * @param num        the dimension of the square matrices to multiply
 * @return a new {@code num} x {@code num} matrix holding the product
 */
public int[][] strassenMethod(int[][] strMatrix1, int[][] strMatrix2, int num) {
    // Java zero-initialises new arrays, so no explicit clearing is required.
    int[][] result = new int[num][num];
    if (num == 0) {
        return result;
    }
    if (num == 1) {
        result[0][0] = strMatrix1[0][0] * strMatrix2[0][0];
        return result;
    }
    if (num % 2 != 0) {
        // Odd dimension: cannot be split into equal quadrants, use the classical product.
        for (int i = 0; i < num; i++) {
            for (int j = 0; j < num; j++) {
                for (int k = 0; k < num; k++) {
                    result[i][j] += strMatrix1[i][k] * strMatrix2[k][j];
                }
            }
        }
        return result;
    }

    int half = num / 2;

    // Split both operands into four half x half quadrants: [blockRow][blockCol][i][j].
    int[][][][] a = new int[2][2][half][half];
    int[][][][] b = new int[2][2][half][half];
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            a[0][0][i][j] = strMatrix1[i][j];
            a[0][1][i][j] = strMatrix1[i][j + half];
            a[1][0][i][j] = strMatrix1[i + half][j];
            a[1][1][i][j] = strMatrix1[i + half][j + half];

            b[0][0][i][j] = strMatrix2[i][j];
            b[0][1][i][j] = strMatrix2[i][j + half];
            b[1][0][i][j] = strMatrix2[i + half][j];
            b[1][1][i][j] = strMatrix2[i + half][j + half];
        }
    }

    // The ten sums/differences S1..S10 (stored zero-based as S[0]..S[9]).
    int[][][] S = new int[10][][];
    S[0] = subtractMatrix(b[0][1], b[1][1], half);
    S[1] = addMatrix(a[0][0], a[0][1], half);
    S[2] = addMatrix(a[1][0], a[1][1], half);
    S[3] = subtractMatrix(b[1][0], b[0][0], half);
    S[4] = addMatrix(a[0][0], a[1][1], half);
    S[5] = addMatrix(b[0][0], b[1][1], half);
    S[6] = subtractMatrix(a[0][1], a[1][1], half);
    S[7] = addMatrix(b[1][0], b[1][1], half);
    S[8] = subtractMatrix(a[0][0], a[1][0], half);
    S[9] = addMatrix(b[0][0], b[0][1], half);

    // The seven recursive products P1..P7 (stored as P[0]..P[6]).
    int[][][] P = new int[7][][];
    P[0] = strassenMethod(a[0][0], S[0], half);
    P[1] = strassenMethod(S[1], b[1][1], half);
    P[2] = strassenMethod(S[2], b[0][0], half);
    P[3] = strassenMethod(a[1][1], S[3], half);
    P[4] = strassenMethod(S[4], S[5], half);
    P[5] = strassenMethod(S[6], S[7], half);
    // P7 = S9 * S10, i.e. (A11 - A21) * (B11 + B12).
    P[6] = strassenMethod(S[8], S[9], half);

    // Combine the products into the four quadrants of the result.
    int[][] c11 = addMatrix(subtractMatrix(addMatrix(P[4], P[3], half), P[1], half), P[5], half);
    int[][] c12 = addMatrix(P[0], P[1], half);
    int[][] c21 = addMatrix(P[2], P[3], half);
    int[][] c22 = subtractMatrix(subtractMatrix(addMatrix(P[4], P[0], half), P[2], half), P[6], half);

    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            result[i][j] = c11[i][j];
            result[i][j + half] = c12[i][j];
            result[i + half][j] = c21[i][j];
            result[i + half][j + half] = c22[i][j];
        }
    }
    return result;
}

/**
 * Adds two {@code num} x {@code num} matrices element-wise.
 *
 * @return a new matrix holding {@code addMatrix1 + addMatrix2}
 */
private int[][] addMatrix(int[][] addMatrix1, int[][] addMatrix2, int num) {
    int[][] result = new int[num][num];
    for (int i = 0; i < num; i++) {
        for (int j = 0; j < num; j++) {
            result[i][j] = addMatrix1[i][j] + addMatrix2[i][j];
        }
    }
    return result;
}

/**
 * Subtracts two {@code num} x {@code num} matrices element-wise.
 *
 * @return a new matrix holding {@code addMatrix1 - addMatrix2}
 */
private int[][] subtractMatrix(int[][] addMatrix1, int[][] addMatrix2, int num) {
    int[][] result = new int[num][num];
    for (int i = 0; i < num; i++) {
        for (int j = 0; j < num; j++) {
            result[i][j] = addMatrix1[i][j] - addMatrix2[i][j];
        }
    }
    return result;
}

三、大整数相乘

大数相乘问题推荐参考

1. 适用场景

  • 密码学和网络安全
  • 加密和解密

2. 分治法

1. 伪代码

MULT(X,Y){
	if |X| = |Y| = 1 
		then do return XY
	else 
		return MULT(a, c)*2^n + (MULT(a, d) + MULT(b, c))*2^(n / 2) + MULT(b, d)
}

2. 时间复杂度

$T(n) = \begin{cases} 1 & \text{if } n = 1 \\ 4T(n/2) + \Theta(n) & \text{if } n > 1 \end{cases}$
$T(n) = \Theta(n^2)$

3. Java 代码

// Operands of at most this many digits are multiplied directly with int arithmetic.
    private final static int SIZE = 4;

    /**
     * Multiplies two decimal digit strings by splitting each into a high and a low
     * part and combining the four partial products.
     *
     * <p>Callers must pass {@code len} as the larger of the two operand lengths;
     * both operands are left-padded with zeros to that length before splitting.
     *
     * @param X   first factor as a string of decimal digits
     * @param Y   second factor as a string of decimal digits
     * @param len the maximum of the lengths of {@code X} and {@code Y}
     * @return the product as a digit string
     */
    private static String bigIntMultiply(String X, String Y, int len) {
        String paddedX = formatNumber(X, len);
        String paddedY = formatNumber(Y, len);

        // Small enough: compute the product directly.
        if (len <= SIZE) {
            return String.valueOf(Integer.parseInt(paddedX) * Integer.parseInt(paddedY));
        }

        // Split each operand into a high part (highLen digits) and a low part (lowLen digits).
        int highLen = len / 2;
        int lowLen = len - highLen;
        String A = paddedX.substring(0, highLen);
        String B = paddedX.substring(highLen);
        String C = paddedY.substring(0, highLen);
        String D = paddedY.substring(highLen);

        // The four partial products.
        int crossLen = Math.max(highLen, lowLen);
        String highProduct = bigIntMultiply(A, C, highLen);
        String productAD = bigIntMultiply(A, D, crossLen);
        String productBC = bigIntMultiply(B, C, crossLen);
        String lowProduct = bigIntMultiply(B, D, lowLen);

        // Keep the low lowLen digits of BD in place; the rest carries into AD + BC.
        String[] lowSplit = dealString(lowProduct, lowLen);
        String middle = addition(productAD, productBC);
        if (!"0".equals(lowSplit[1])) {
            middle = addition(middle, lowSplit[1]);
        }

        // Likewise, whatever exceeds crossLen digits in the middle term carries into AC.
        String[] middleSplit = dealString(middle, crossLen);
        highProduct = addition(highProduct, middleSplit[1]);

        // Concatenate the three digit groups, most significant first.
        StringBuilder product = new StringBuilder();
        product.append(highProduct).append(middleSplit[0]).append(lowSplit[0]);
        return product.toString();
    }

    /**
     * Adds two non-negative decimal digit strings digit by digit.
     *
     * <p>The shorter operand is left-padded with zeros via {@code formatNumber}. The
     * result is built in a {@code StringBuilder} (least significant digit first, then
     * reversed) so the work stays linear in the operand length.
     *
     * @param ad first addend as a string of decimal digits
     * @param bc second addend as a string of decimal digits
     * @return the sum as a digit string, one digit longer than the operands when the
     *         top position produces a carry
     * @throws NumberFormatException if either operand contains a non-digit character
     */
    private static String addition(String ad, String bc) {
        // Both operands must have the same length.
        int width = Math.max(ad.length(), bc.length());
        ad = formatNumber(ad, width);
        bc = formatNumber(bc, width);

        StringBuilder reversed = new StringBuilder(width + 1);
        int carry = 0;
        // Walk from the least significant digit to the most significant one.
        for (int i = width - 1; i >= 0; i--) {
            int left = Character.digit(ad.charAt(i), 10);
            int right = Character.digit(bc.charAt(i), 10);
            if (left < 0 || right < 0) {
                throw new NumberFormatException(
                        "Non-digit character at index " + i + " in \"" + ad + "\" or \"" + bc + "\"");
            }
            int sum = carry + left + right;
            // Keep only the units digit in this position; the rest is carried over.
            if (sum > 9) {
                carry = 1;
                sum -= 10;
            } else {
                carry = 0;
            }
            reversed.append(sum);
        }
        if (carry != 0) {
            reversed.append(carry);
        }
        return reversed.reverse().toString();
    }

    /**
     * Splits a digit string into the digits that stay in place and the carry.
     *
     * @param ac   the digit string to split
     * @param len1 the number of low-order digits that stay in place
     * @return a two-element array: index 0 holds the last {@code len1} digits
     *         (left-padded with zeros when {@code ac} is shorter), index 1 holds the
     *         leading digits that overflow, or {@code "0"} when there is no overflow
     */
    private static String[] dealString(String ac, int len1) {
        if (ac.length() > len1) {
            // Everything in front of the last len1 digits is the carry.
            int carryDigits = ac.length() - len1;
            return new String[] {ac.substring(carryDigits), ac.substring(0, carryDigits)};
        }

        // No carry: make sure the in-place part is exactly len1 digits long.
        String inPlace = ac;
        while (inPlace.length() < len1) {
            inPlace = "0" + inPlace;
        }
        return new String[] {inPlace, "0"};
    }

    /**
     * Left-pads a digit string with zeros up to the requested length.
     *
     * @param x   the digit string to pad
     * @param len the desired minimum length
     * @return {@code x} itself when it already has at least {@code len} characters,
     *         otherwise {@code x} prefixed with enough {@code '0'} characters
     */
    private static String formatNumber(String x, int len) {
        if (len <= x.length()) {
            return x;
        }
        StringBuilder padded = new StringBuilder(len);
        for (int missing = len - x.length(); missing > 0; missing--) {
            padded.append('0');
        }
        return padded.append(x).toString();
    }

3. Karatsuba 算法

1. 伪代码

MULT(X,Y){
	if |X| = |Y| = 1 
		then do return X * Y
	else
		A1 = MULT(a,c);
		A2 = MULT(b,d);
		A3 = MULT(a+b, c+d);
		return A1*2^n + (A3 - A1 - A2)*2^(n / 2) + A2
}

2. 时间复杂度

$T(n) = \begin{cases} 1 & \text{if } n = 1 \\ 3T(n/2) + \Theta(n) & \text{if } n > 1 \end{cases}$
$T(n) = \Theta(n^{\log_2 3}) \approx \Theta(n^{1.58})$

3. Java 代码

/**
 * Multiplies two numbers with the Karatsuba scheme: three recursive products
 * instead of four.
 *
 * <p>Operands are split arithmetically at {@code 10^halfN}, so operands with
 * different digit counts are handled (a missing high part is simply 0). All
 * arithmetic stays in {@code long}; no floating point is involved. As with plain
 * {@code long} multiplication, a product outside the {@code long} range wraps around.
 *
 * @param num1 first factor
 * @param num2 second factor
 * @return {@code num1 * num2}
 */
public static long karatsuba(long num1, long num2) {
    // Recursion ends as soon as either factor is a single digit (or negative).
    if (num1 < 10 || num2 < 10) {
        return num1 * num2;
    }

    // Split position: half the digit count of the longer operand.
    int size1 = String.valueOf(num1).length();
    int size2 = String.valueOf(num2).length();
    int halfN = Math.max(size1, size2) / 2;

    // shift = 10^halfN; halfN is at most 9 for a long, so shift * shift cannot overflow.
    long shift = 1;
    for (int i = 0; i < halfN; i++) {
        shift *= 10;
    }

    // num1 = a * shift + b, num2 = c * shift + d
    long a = num1 / shift;
    long b = num1 % shift;
    long c = num2 / shift;
    long d = num2 % shift;

    // z1 = a*d + b*c, obtained from a single product (a+b)(c+d) minus z0 and z2.
    long z2 = karatsuba(a, c);
    long z0 = karatsuba(b, d);
    long z1 = karatsuba(a + b, c + d) - z0 - z2;

    return z2 * shift * shift + z1 * shift + z0;
}

四、最近点对问题

1. 算法

Closest-Pair(P)

  • Preprocessing: (预处理,递增排列)
    • Construct Px and Py as sorted-list by x- and y-coordinates
  • Divide (分割)
    • Construct L, Lx , Ly and R, Rx , Ry
  • Conquer (治理)
    • Let A1= Closest-Pair(L, Lx , Ly )
    • Let A2= Closest-Pair(R, Rx , Ry )
  • Combination (整合)
    • Let A = min(A1 , A2 )
    • Construct S and Sy (按y坐标排序)
    • For each point in Sy, check each of its next 7 points down the list(检查与后续7个点之间的距离)
    • If the distance is less than A , update the A as this smaller distance(如果存在小于A的距离,则为距离最短点对)

2. 时间复杂度

$T(n) = O(n \log n)$

3. Java 代码实现

public class NPointPair {
    /** Inputs with at most this many points are solved by comparing every pair directly. */
    private static final int BRUTE_FORCE_LIMIT = 20;

    /**
     * Finds the two points of {@code S} with the smallest Euclidean distance using
     * divide and conquer.
     *
     * <p>Inputs larger than the brute-force limit are sorted by x coordinate in
     * place, so the order of the caller's array may change.
     *
     * @param S the points to search
     * @return a two-element array holding the closest pair; both entries are
     *         {@code null} when {@code S} has fewer than two points
     */
    public static DcPoint[] closestPoint(DcPoint[] S) {
        // 0. Small inputs are solved directly.
        if (S.length <= BRUTE_FORCE_LIMIT) {
            return bruteForce(S);
        }

        // 1. Sort by x so the array can be cut at its median element.
        Arrays.sort(S, new Comparator<DcPoint>() {
            @Override
            public int compare(DcPoint o1, DcPoint o2) {
                return Integer.compare(o1.getX(), o2.getX());
            }
        });

        // 2. Split at the median index. Both halves hold at least two points because
        //    S has more than BRUTE_FORCE_LIMIT elements, so neither result has nulls.
        int mid = S.length / 2;
        int midX = S[mid].getX();
        DcPoint[] leftPair = closestPoint(Arrays.copyOfRange(S, 0, mid));
        DcPoint[] rightPair = closestPoint(Arrays.copyOfRange(S, mid, S.length));

        // 3. Keep the better of the two half results.
        double leftDist = distance(leftPair[0], leftPair[1]);
        double rightDist = distance(rightPair[0], rightPair[1]);
        DcPoint[] result = new DcPoint[2];
        double best;
        if (leftDist < rightDist) {
            best = leftDist;
            result[0] = leftPair[0];
            result[1] = leftPair[1];
        } else {
            best = rightDist;
            result[0] = rightPair[0];
            result[1] = rightPair[1];
        }

        // 4. Only points closer than best to the dividing line can form a better pair.
        ArrayList<DcPoint> stripList = new ArrayList<DcPoint>();
        for (int i = 0; i < S.length; i++) {
            if (Math.abs((double) S[i].getX() - midX) < best) {
                stripList.add(S[i]);
            }
        }
        DcPoint[] strip = stripList.toArray(new DcPoint[stripList.size()]);
        mergeSort(strip, "y");

        // 5. Scan the strip in y order; stop the inner scan once the vertical gap alone
        //    reaches best. A pair replaces the result only if it beats best.
        for (int i = 0; i < strip.length; i++) {
            for (int j = i + 1; j < strip.length
                    && (double) strip[j].getY() - strip[i].getY() < best; j++) {
                double candidate = distance(strip[i], strip[j]);
                if (candidate < best) {
                    best = candidate;
                    result[0] = strip[i];
                    result[1] = strip[j];
                }
            }
        }
        return result;
    }

    /** Compares every pair of points and returns the closest one (nulls if fewer than two points). */
    private static DcPoint[] bruteForce(DcPoint[] points) {
        DcPoint[] result = new DcPoint[2];
        double best = Double.POSITIVE_INFINITY;
        for (int i = 0; i < points.length; i++) {
            for (int j = i + 1; j < points.length; j++) {
                double candidate = distance(points[i], points[j]);
                if (candidate < best) {
                    best = candidate;
                    result[0] = points[i];
                    result[1] = points[j];
                }
            }
        }
        return result;
    }

    /** Returns the Euclidean distance between two points. */
    private static double distance(DcPoint p, DcPoint q) {
        // Widen to double before subtracting so the difference cannot overflow.
        double dx = (double) p.getX() - q.getX();
        double dy = (double) p.getY() - q.getY();
        return Math.sqrt(dx * dx + dy * dy);
    }

    /**
     * Sorts {@code a} in ascending order of the given coordinate with merge sort.
     *
     * @param property {@code "x"} or {@code "y"}
     * @throws IllegalArgumentException if {@code property} is neither of the two
     */
    private static void mergeSort(DcPoint[] a, String property) {
        if (!"x".equals(property) && !"y".equals(property)) {
            throw new IllegalArgumentException("Unknown sort property: " + property);
        }
        DcPoint[] tempArray = new DcPoint[a.length];
        mergeSort(a, tempArray, 0, a.length - 1, property);
    }

    /** Recursively sorts {@code a[left..right]}, using {@code tempArray} as scratch space. */
    private static void mergeSort(DcPoint[] a, DcPoint[] tempArray, int left, int right, String property) {
        if (left < right) {
            int center = (left + right) >>> 1;
            // Divide
            mergeSort(a, tempArray, left, center, property);
            mergeSort(a, tempArray, center + 1, right, property);
            // Combine
            merge(a, tempArray, left, center + 1, right, property);
        }
    }

    /**
     * Merges the sorted runs {@code a[leftPos..rightPos - 1]} and
     * {@code a[rightPos..rightEnd]} into one sorted run at the same position.
     */
    private static void merge(DcPoint[] a, DcPoint[] tempArray, int leftPos, int rightPos, int rightEnd, String property) {
        int leftEnd = rightPos - 1;
        int start = leftPos;
        int tmpPos = leftPos;
        boolean byX = property.equals("x");

        while (leftPos <= leftEnd && rightPos <= rightEnd) {
            int leftKey = byX ? a[leftPos].getX() : a[leftPos].getY();
            int rightKey = byX ? a[rightPos].getX() : a[rightPos].getY();
            // <= keeps equal keys in their original order.
            if (leftKey <= rightKey) {
                tempArray[tmpPos++] = a[leftPos++];
            } else {
                tempArray[tmpPos++] = a[rightPos++];
            }
        }
        while (leftPos <= leftEnd) {
            tempArray[tmpPos++] = a[leftPos++];
        }
        while (rightPos <= rightEnd) {
            tempArray[tmpPos++] = a[rightPos++];
        }
        // Copy the merged run back exactly once, after both inputs are fully consumed.
        System.arraycopy(tempArray, start, a, start, rightEnd - start + 1);
    }

    /** Demo: prints 100 random points (minus any the TreeSet treats as duplicates) and their closest pair. */
    public static void main(String[] args) {
        Set<DcPoint> testData = new TreeSet<DcPoint>();
        Random random = new Random();
        for (int i = 0; i < 100; i++) {
            int x = random.nextInt(500);
            int y = random.nextInt(500);
            testData.add(new DcPoint(x, y));
        }
        DcPoint[] S = testData.toArray(new DcPoint[testData.size()]);
        for (int i = 0; i < S.length; i++) {
            System.out.println("(" + S[i].getX() + ", " + S[i].getY() + ")");
        }
        System.out.println(testData.size());
        DcPoint[] result = closestPoint(S);
        System.out.println("最近的两点分别是(" + result[0].getX() + ", " + result[0].getY() + ") 和 (" + result[1].getX() + ", " + result[1].getY() + "), 最近距离为:" + Math.sqrt(Math.pow(result[0].getX() - result[1].getX(), 2) + Math.pow(result[0].getY() - result[1].getY(), 2)));
    }
}

DcPoint 类:

public class DcPoint implements Cloneable, Comparable<DcPoint> {
    public DcPoint() {
        x = 0;
        y = 0;
    }

    public DcPoint(int x, int y) {
        this.x = x;
        this.y = y;
    }

    public void setX(int x) {
        this.x = x;
    }

    public void setY(int y) {
        this.y = y;
    }

    public int getX() {
        return x;
    }

    public int getY() {
        return y;
    }

    private int x;
    private int y;

    @Override
    public int compareTo(DcPoint o) {
        if (x == o.getX() && y == o.getY()) {
            return 0;
        } else {
            return 1;
        }
    }
}

猜你喜欢

转载自blog.csdn.net/yin__ren/article/details/83213895
今日推荐