MapReduce Basic Algorithms: Matrix Multiplication

1. Principle of matrix multiplication and implementation approach

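For any two matrices $M$ and $N$, if the number of columns of $M$ equals the number of rows of $N$, the product $P = MN$ is defined. Let $m_{ij}$ denote the element in row $i$, column $j$ of $M$, and $n_{jk}$ the element in row $j$, column $k$ of $N$. Each element of the product matrix $P$ is then given by: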

$$p_{ik} = (MN)_{ik} = \sum_{j} m_{ij} n_{jk}$$
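It may help to write the summation bounds out explicitly. The dimension symbols $a$, $b$, $c$ below are introduced here for illustration only and do not appear in the original text: assume $M$ is $a \times b$ and $N$ is $b \times c$.

```latex
p_{ik} = \sum_{j=1}^{b} m_{ij}\, n_{jk},
\qquad 1 \le i \le a, \quad 1 \le k \le c
```

Under that assumption $P$ has $a$ rows and $c$ columns, each $p_{ik}$ is a sum of $b$ products, and every $m_{ij}$ takes part in $c$ different output elements (one per column of $N$), just as every $n_{jk}$ takes part in $a$ of them.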

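It follows that the position of $p_{ik}$ in the result is determined by $i$ and $k$, so the pair $(i, k)$ can be used as the input key of the Reducer. To compute the products $m_{ij} n_{jk}$, we need $m_{ij}$ and $n_{jk}$ individually. For $m_{ij}$, the required attributes are the matrix name $M$, its row index $i$, its column index $j$, and its numeric value $m_{ij}$. Likewise, for $n_{jk}$ the required attributes are the matrix name $N$, its row index $j$, its column index $k$, and its numeric value $n_{jk}$. These attributes are produced by the Mapper. The basic approach is as follows:
Map function: for each element $m_{ij}$ of matrix $M$, emit a series of key-value pairs $<(i, k), (M, j, m_{ij})>$ for $k = 1, 2, \dots$ up to the number of columns of $N$. For each element $n_{jk}$ of matrix $N$, emit a series of key-value pairs $<(i, k), (N, j, n_{jk})>$ for $i = 1, 2, \dots$ up to the number of rows of $M$.
Reduce function: for each key $(i, k)$, the associated values are of the form $(M, j, m_{ij})$ and $(N, j, n_{jk})$. Store the $m_{ij}$ and the $n_{jk}$ in two separate arrays indexed by $j$, multiply the $j$-th elements of the two arrays pairwise, and sum the products to obtain $p_{ik}$.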
Example:

Let $M = \begin{bmatrix} 1 & 2 \end{bmatrix}$ and $N = \begin{bmatrix} 2 & 1 & 3 \\ 0 & 2 & 4 \end{bmatrix}$, so that $i = 1$, $j = 1, 2$, and $k = 1, 2, 3$. The map() function then produces the following output:

$<(1,1), (M,1,m_{11})>$ $<(1,1), (N,1,n_{11})>$
$<(1,1), (M,2,m_{12})>$ $<(1,1), (N,2,n_{21})>$
$<(1,2), (M,1,m_{11})>$ $<(1,2), (N,1,n_{12})>$
$<(1,2), (M,2,m_{12})>$ $<(1,2), (N,2,n_{22})>$
$<(1,3), (M,1,m_{11})>$ $<(1,3), (N,1,n_{13})>$
$<(1,3), (M,2,m_{12})>$ $<(1,3), (N,2,n_{23})>$

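For each input key $(i, k)$, the Reduce function pairs up the elements $m_{ij}$ and $n_{jk}$ that share the same $j$, multiplies them, and accumulates the products.
For key (1,1): $m_{11} \times n_{11} + m_{12} \times n_{21} = 1 \times 2 + 2 \times 0 = 2$,

For key (1,2): $m_{11} \times n_{12} + m_{12} \times n_{22} = 1 \times 1 + 2 \times 2 = 5$,

For key (1,3): $m_{11} \times n_{13} + m_{12} \times n_{23} = 1 \times 3 + 2 \times 4 = 11$.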
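The map and reduce steps described above can be simulated in memory to check the worked example. This is only a sketch without Hadoop: the class `MatrixMapReduceSketch` and its methods are hypothetical names introduced here, a `TreeMap` stands in for the shuffle that groups values by key, and keys and values are encoded as comma-separated strings.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** In-memory sketch of the map and reduce steps; no Hadoop involved. */
final class MatrixMapReduceSketch {

    /** Map phase: emit "i,k" -> "M,j,value" or "N,j,value" for every element. */
    static Map<String, List<String>> mapPhase(int[][] m, int[][] n) {
        int rowM = m.length, columnM = n.length, columnN = n[0].length;
        Map<String, List<String>> grouped = new TreeMap<>();
        for (int i = 1; i <= rowM; i++) {
            for (int j = 1; j <= columnM; j++) {
                // m_ij contributes to every p_ik, k = 1..columnN
                for (int k = 1; k <= columnN; k++) {
                    grouped.computeIfAbsent(i + "," + k, x -> new ArrayList<>())
                           .add("M," + j + "," + m[i - 1][j - 1]);
                }
            }
        }
        for (int j = 1; j <= columnM; j++) {
            for (int k = 1; k <= columnN; k++) {
                // n_jk contributes to every p_ik, i = 1..rowM
                for (int i = 1; i <= rowM; i++) {
                    grouped.computeIfAbsent(i + "," + k, x -> new ArrayList<>())
                           .add("N," + j + "," + n[j - 1][k - 1]);
                }
            }
        }
        return grouped;
    }

    /** Reduce phase for one key: align values by j, multiply pairwise, sum. */
    static int reduce(List<String> values, int columnM) {
        int[] mRow = new int[columnM + 1];
        int[] nCol = new int[columnM + 1];
        for (String v : values) {
            String[] t = v.split(",");
            int j = Integer.parseInt(t[1]);
            int val = Integer.parseInt(t[2]);
            if (t[0].equals("M")) {
                mRow[j] = val;
            } else {
                nCol[j] = val;
            }
        }
        int sum = 0;
        for (int j = 1; j <= columnM; j++) {
            sum += mRow[j] * nCol[j];
        }
        return sum;
    }

    /** Runs both phases and returns key "i,k" -> p_ik. */
    static Map<String, Integer> multiply(int[][] m, int[][] n) {
        Map<String, Integer> result = new TreeMap<>();
        mapPhase(m, n).forEach((key, values) -> result.put(key, reduce(values, n.length)));
        return result;
    }

    public static void main(String[] args) {
        int[][] m = {{1, 2}};
        int[][] n = {{2, 1, 3}, {0, 2, 4}};
        System.out.println(multiply(m, n)); // prints {1,1=2, 1,2=5, 1,3=11}
    }
}
```

The printed values match the three sums computed by hand for keys (1,1), (1,2), and (1,3).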

2. MapReduce implementation of matrix multiplication

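Approach: there are two input text files, holding the elements of matrices $M$ and $N$ respectively. Each line of a file has the form "row,column\tvalue". On the map side, each element of $M$ is emitted as $<(i, k), (M, j, m_{ij})>$, and the elements of $N$ are handled in the same way. On the reduce side, two arrays hold the values that arrive from the map side under the same key: one for the row of $M$ and one for the column of $N$.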
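To make the line format concrete, here is a small sketch that parses one "row,column\tvalue" line into its three integers. The class `MatrixLineParser` is a hypothetical helper introduced for illustration; it assumes well-formed input and does no error handling.

```java
/** Hypothetical helper: parses one "row,column\tvalue" line. */
final class MatrixLineParser {

    /** Returns {row, column, value} for a line such as "1,2\t37". */
    static int[] parse(String line) {
        // Split off the value first, then split the coordinates.
        String[] coordsAndValue = line.split("\t");
        String[] coords = coordsAndValue[0].split(",");
        return new int[] {
            Integer.parseInt(coords[0].trim()),
            Integer.parseInt(coords[1].trim()),
            Integer.parseInt(coordsAndValue[1].trim())
        };
    }

    public static void main(String[] args) {
        int[] e = parse("1,2\t37");
        System.out.println(e[0] + " " + e[1] + " " + e[2]); // prints 1 2 37
    }
}
```

Splitting on the tab before the comma keeps the two separators independent, so the order of the fields is explicit in the code.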

Shell script for generating the test data:

#!/bin/bash
if [ $# -ne 3 ]
then
  echo "there must be 3 arguments to generate the two matrix files!"
  exit 1
fi
cat /dev/null > M_$1_$2
cat /dev/null > N_$2_$3
for i in `seq 1 $1`
do
    for j in `seq 1 $2`
    do
        s=$((RANDOM%100))
        echo -e "$i,$j\t$s" >>M_$1_$2
    done
done
echo "we have built the matrix file M"

for i in `seq 1 $2`
do
    for j in `seq 1 $3`
    do
        s=$((RANDOM%100))
        echo -e "$i,$j\t$s" >>N_$2_$3
    done
done
echo "we have built the matrix file N"

MapReduce code:

package cn.zzuli.zcs0;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;

import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

import java.io.IOException;

/**
 * Created by 张超帅 on 2018/8/15.
 */
public class MatrixMultiply {
    public static int rowM = 0;
    public static int columnM = 0;
    public static int columnN = 0;
    public static class MatrixMapper extends Mapper<Object, Text, Text, Text> {
        private Text map_key = new Text();
        private Text map_value = new Text();

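        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration conf = context.getConfiguration();
            // The mapper needs N's column count and M's row count to know how many keys to emit.
            columnN = Integer.parseInt(conf.get("columnN"));
            rowM = Integer.parseInt(conf.get("rowM"));
        }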

        @Override
        protected void map(Object key, Text value, Context context) throws IOException, InterruptedException {
            FileSplit fileSplit = (FileSplit) context.getInputSplit();
            String filename = fileSplit.getPath().getName();

            if(filename.contains("M")) {
                String[] tuple = value.toString().split(",");
                int i = Integer.parseInt(tuple[0]);
                String[] tuples = tuple[1].split("\t");
                int j = Integer.parseInt(tuples[0]);
                int Mij = Integer.parseInt(tuples[1]);

                for(int k = 1; k < columnN + 1; k ++) {
                    map_key.set(i + "," + k);
                    map_value.set("M" + "," + j + "," + Mij);
                    context.write(map_key,map_value);
                }
            }
            else if(filename.contains("N")) {
                String[] tuple = value.toString().split(",");
                int j = Integer.parseInt(tuple[0]);
                String[] tuples = tuple[1].split("\t");
                int k = Integer.parseInt(tuples[0]);
                int Njk = Integer.parseInt(tuples[1]);

                for(int i = 1; i < rowM + 1; i ++) {
                    map_key.set(i + "," + k);
                    map_value.set("N" + "," + j + "," + Njk);
                    context.write(map_key, map_value);
                }
            }
        }
    }

    public  static class MatrixReducer extends Reducer<Text, Text, Text, Text> {
        private int sum = 0;

        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration conf = context.getConfiguration();
            columnM = Integer.parseInt(conf.get("columnM"));

        }

        @Override
        protected void reduce(Text key, Iterable<Text> values, Context context) throws IOException, InterruptedException {
            int[] M = new int[columnM + 1];
            int[] N = new int[columnM + 1];
            for(Text val : values) {
                String[] tuple = val.toString().split(",");
                if(tuple[0].equals("M")) {
                    M[Integer.parseInt(tuple[1])] = Integer.parseInt(tuple[2]);
                } else
                    N[Integer.parseInt(tuple[1])] = Integer.parseInt(tuple[2]);
            }
            for(int j = 1; j < columnM + 1; j ++) {
                sum += M[j] * N[j];
            }
            context.write(key, new Text(Integer.toString(sum)));
            sum = 0;
        }
    }

    public static void main(String[] args) throws  Exception{
        if(args.length != 3) {
            System.err.println("Usage: MatrixMultiply <inputPathM> <inputPathN> <outputPath>");
            System.exit(2);
        } else {
            String[] infoTupleM = args[0].split("_");
            rowM = Integer.parseInt(infoTupleM[1]);
            columnM = Integer.parseInt(infoTupleM[2]);
            String[] infoTupleN = args[1].split("_");
            columnN = Integer.parseInt(infoTupleN[2]);
        }

        Configuration conf = new Configuration();
        conf.setInt("rowM",rowM);
        conf.setInt("columnM", columnM);
        conf.setInt("columnN", columnN);

        Job job = Job.getInstance(conf, "MatrixMultiply");
        job.setJarByClass(MatrixMultiply.class);
        job.setMapperClass(MatrixMapper.class);
        job.setReducerClass(MatrixReducer.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(Text.class);
        FileInputFormat.setInputPaths(job,new Path(args[0]), new Path(args[1]));
        FileOutputFormat.setOutputPath(job,new Path(args[2]));

        System.exit(job.waitForCompletion(true)? 0: 1);


    }
}
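The driver's main method reads the matrix dimensions from file names of the form M_rows_columns by splitting the whole argument on "_". If the argument is a path whose directory part also contains an underscore, the indices shift and parsing fails. A minimal sketch of one way to guard against that, assuming a hypothetical helper `MatrixFileName.dims` (not part of the original program) and '/'-separated paths:

```java
/** Hypothetical helper: extracts dimensions from a name like "M_3_4". */
final class MatrixFileName {

    /** Returns {rows, columns}, looking only at the last path component. */
    static int[] dims(String path) {
        // Drop any directory prefix so underscores in it cannot interfere.
        String name = path.substring(path.lastIndexOf('/') + 1);
        String[] parts = name.split("_");
        return new int[] {Integer.parseInt(parts[1]), Integer.parseInt(parts[2])};
    }

    public static void main(String[] args) {
        int[] d = dims("/user/data_dir/M_3_4");
        System.out.println(d[0] + " x " + d[1]); // prints 3 x 4
    }
}
```

With such a helper, main could call `dims(args[0])` and `dims(args[1])` instead of splitting the raw arguments.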

Reposted from blog.csdn.net/qq_38386316/article/details/81701286