Android APP 集成 Unet进行图像语义分割【tensorflow】

版权声明:我是南七小僧,微信: to_my_love ,2020年硕士毕业,寻找 自然语言处理,图像处理,软件开发等相关工作,欢迎交流思想碰撞。 https://blog.csdn.net/qq_25439417/article/details/86564354

环境:

WIN7 64 + Android Studio3.3

Python 3.6

Keras 2.3

TF 1.9

概述:

1.先用Keras训练网络,保存为h5文件【model.save('xxx.h5')】

2.用Keras2pb.py 把h5文件转成Tf的pb文件

3.在Android src/main下新建assets文件夹(注意目录名为小写的assets),把pb文件放到里面

4.Android gradle【app】里implementation一个Tf包

implementation 'org.tensorflow:tensorflow-android:+'

5.调用TF JAVA接口feed【input】 run fetch【output】

     

 Keras转PB:

# coding=utf-8
import sys

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. Any iterable of
                        names is accepted; it is never modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        # Collect the variable op names once; they are needed both for the
        # freeze list and for the output list.
        variable_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(variable_names).difference(keep_var_names or []))
        # Build a fresh list instead of using ``+=`` so the caller's
        # ``output_names`` argument is left untouched (and tuples work too).
        all_output_names = list(output_names or []) + variable_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        return convert_variables_to_constants(session, input_graph_def,
                                              all_output_names, freeze_var_names)


# Conversion script: load a trained Keras .h5 model, freeze its session and
# write the result as a binary TensorFlow GraphDef (.pb).
from tensorflow.python.framework import graph_io

# File names of the Keras checkpoint to read and of the frozen graph to write.
weight_file = 'best_unet_mask_shishang.h5'
output_graph_name = 'tensor_model.pb'
output_fld = '.'

# Make sure the output folder exists before anything is written to it.
if not os.path.isdir(output_fld):
    os.mkdir(output_fld)
weight_file_path = osp.join('.', weight_file)

# Learning phase 0 puts Keras in inference (test) mode; it is set before the
# model is loaded.
K.set_learning_phase(0)
net_model = load_model(weight_file_path)

# These tensor names are what the inference code must feed and fetch.
print('input is :', net_model.input.name)
print('output is:', net_model.output.name)

# Freeze the session Keras loaded the model into, keeping the model output.
sess = K.get_session()
frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

PB测试预测,用TF Graph跑测试:

import tensorflow as tf
import numpy as np
import PIL.Image as Image
import cv2


def recognize(jpg_path, pb_file_path, input_name="input_10:0",
              output_name="conv2d_215/Sigmoid:0", input_size=(256, 144)):
    """
    Runs a frozen graph on a single image and returns the raw network output.

    @param jpg_path Path of the image file to read with cv2.imread.
    @param pb_file_path Path of the frozen GraphDef (.pb) file.
    @param input_name Name of the input tensor; use the ``input.name`` printed
                      by the conversion script for your model.
    @param output_name Name of the output tensor; use the ``output.name``
                       printed by the conversion script for your model.
    @param input_size (width, height) the image is resized to before it is fed.
    @return The array returned by ``sess.run`` for the output tensor.
    @raise IOError If the image cannot be read.
    """
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        # name="" keeps the original node names, so the tensors can be looked
        # up by the names printed at conversion time.
        tf.import_graph_def(graph_def, name="")

        # No variable initializer is run: importing a GraphDef registers no
        # global variables, so there is nothing to initialize.
        with tf.Session() as sess:
            input_x = sess.graph.get_tensor_by_name(input_name)
            out_softmax = sess.graph.get_tensor_by_name(output_name)

            img = cv2.imread(jpg_path)
            if img is None:
                # cv2.imread signals failure by returning None; report it
                # here instead of failing later inside cv2.resize.
                raise IOError("could not read image: %s" % jpg_path)

            width, height = input_size
            img = cv2.resize(img, (width, height))
            # One batch in (N, height, width, 3) layout, scaled to [0, 1].
            batch = np.array(img).reshape((-1, height, width, 3)) / 255.0
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x: batch})

            print(img_out_softmax)
            return img_out_softmax


# Run the frozen graph on one sample image and show the binarized mask.
pb_path = 'tensor_model.pb'
img = 'heikai.png'
a = recognize(img, pb_path)

# Keep the first image of the batch and its first output channel.
a = a[0, :, :, 0]

# Threshold once with ">= 0.5" so the foreground count and the white pixels
# agree (a value of exactly 0.5 used to be counted but drawn black). This
# also matches the ">= 0.5" test used by the Android code.
foreground = a >= 0.5
l = int(np.count_nonzero(foreground))  # number of foreground pixels
a = np.where(foreground, 255, 0).astype(np.uint8)
Image.fromarray(a).show()

 效果:(U-Net人眼标注)


Android上开发

图像处理库:【这是我写的一个适合我的算法库,不一定适合你们】

package com.keraseye.xkk.keraseye;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.os.Environment;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import org.json.*;

/**
 * Utility class for manipulating images.
 **/
public class ImageUtils {
    /**
     * Returns a transformation matrix from one reference frame into another.
     * Handles cropping (if maintaining aspect ratio is desired) and rotation.
     *
     * @param srcWidth Width of source frame.
     * @param srcHeight Height of source frame.
     * @param dstWidth Width of destination frame.
     * @param dstHeight Height of destination frame.
     * @param applyRotation Amount of rotation to apply from one frame to another.
     *  Must be a multiple of 90.
     * @param maintainAspectRatio If true, will ensure that scaling in x and y remains constant,
     * cropping the image if necessary.
     * @return The transformation fulfilling the desired requirements.
     */
    public static Matrix getTransformationMatrix(
            final int srcWidth,
            final int srcHeight,
            final int dstWidth,
            final int dstHeight,
            final int applyRotation,
            final boolean maintainAspectRatio) {
        final Matrix matrix = new Matrix();

        if (applyRotation != 0) {
            // Translate so center of image is at origin.
            matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f);

            // Rotate around origin.
            matrix.postRotate(applyRotation);
        }

        // Account for the already applied rotation, if any, and then determine how
        // much scaling is needed for each axis.
        final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0;

        final int inWidth = transpose ? srcHeight : srcWidth;
        final int inHeight = transpose ? srcWidth : srcHeight;

        // Apply scaling if necessary.
        if (inWidth != dstWidth || inHeight != dstHeight) {
            final float scaleFactorX = dstWidth / (float) inWidth;
            final float scaleFactorY = dstHeight / (float) inHeight;

            if (maintainAspectRatio) {
                // Scale by minimum factor so that dst is filled completely while
                // maintaining the aspect ratio. Some image may fall off the edge.
                final float scaleFactor = Math.max(scaleFactorX, scaleFactorY);
                matrix.postScale(scaleFactor, scaleFactor);
            } else {
                // Scale exactly to fill dst from src.
                matrix.postScale(scaleFactorX, scaleFactorY);
            }
        }

        if (applyRotation != 0) {
            // Translate back from origin centered reference to destination frame.
            matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f);
        }

        return matrix;
    }


    public static Bitmap processBitmap(Bitmap source,int wsize,int hsize){

        int image_height = source.getHeight();
        int image_width = source.getWidth();

        Bitmap croppedBitmap = Bitmap.createBitmap(wsize, hsize, Bitmap.Config.ARGB_8888);

        Matrix frameToCropTransformations = getTransformationMatrix(image_width,image_height,wsize, hsize,0,false);
        Matrix cropToFrameTransformations = new Matrix();
        frameToCropTransformations.invert(cropToFrameTransformations);

        final Canvas canvas = new Canvas(croppedBitmap);
        canvas.drawBitmap(source, frameToCropTransformations, null);

        return croppedBitmap;


    }

    public static float[] normalizeBitmap(Bitmap source,int wsize, int hsize,float mean,float std){

        float[] output = new float[wsize * hsize * 3];

        int[] intValues = new int[source.getHeight() * source.getWidth()];
//        System.out.println("hh"+source.getHeight());
//        System.out.println("ww"+source.getWidth());

        source.getPixels(intValues, 0, source.getWidth(), 0, 0, source.getWidth(), source.getHeight());
        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            output[i * 3 + 2] = ((val >> 16) & 0xFF)/255f;
//            System.out.println(output[i * 3] );
            output[i * 3 + 1] = ((val >> 8) & 0xFF)/255f;
            output[i * 3 + 0] = (val & 0xFF)/255f;
        }

        return output;

    }

    public static Object[] argmax(float[] array){


        int best = -1;
        float best_confidence = 0.0f;

        for(int i = 0;i < array.length;i++){

            float value = array[i];

            if (value > best_confidence){

                best_confidence = value;
                best = i;
            }
        }



        return new Object[]{best,best_confidence};


    }


    public static String getLabel( InputStream jsonStream,int index){
        String label = "";
        try {

            byte[] jsonData = new byte[jsonStream.available()];
            jsonStream.read(jsonData);
            jsonStream.close();

            String jsonString = new String(jsonData,"utf-8");

            JSONObject object = new JSONObject(jsonString);

            label = object.getString(String.valueOf(index));



        }
        catch (Exception e){


        }
        return label;
    }

    public static Bitmap floatToBitmap(Bitmap bitmap,float[] img_float_array){

//        float[] output = new float[wsize * hsize * 3];

        int[] intValues = new int[bitmap.getHeight() * bitmap.getWidth()];

        int c=0;

        //这样写有很重的问题,因为TF的返回shape不是想象中的那样,需要自己构造
//        int c= 0;
//        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
//        for(int i = 0 ; i < bitmap.getHeight();i++){
//            for(int j =0 ;j<bitmap.getWidth();j++){
////                System.out.println(img_float_array.length);
//                if(img_float_array[j + i*bitmap.getHeight()]>=0.5){
//                    System.out.println(img_float_array[j + i*bitmap.getHeight()]);
//                    bitmap.setPixel(j,i,Color.WHITE);
//                c++;}
//                else{
//                    bitmap.setPixel(j,i,Color.BLACK);
//                }
//            }
//        }
//        System.out.println(c+"个");
        for (int i = 0; i < img_float_array.length; ++i) {

            if(img_float_array[i]>=0.5){c++;bitmap.setPixel(i%bitmap.getWidth(),i/bitmap.getWidth(),Color.WHITE);}
            else{bitmap.setPixel(i%bitmap.getWidth(),i/bitmap.getWidth(),Color.BLACK);}
        }
        System.out.println(c);

        return bitmap;

//
//
//        for (int i = 0; i < intValues.length; ++i) {
//            final int val = intValues[i];
//            output[i * 3] = ((val >> 16) & 0xFF)/255f;
////            System.out.println(output[i * 3] );
//            output[i * 3 + 1] = ((val >> 8) & 0xFF)/255f;
//            output[i * 3 + 2] = (val & 0xFF)/255f;
//        }
//
//        return output;
    }


}

MainActivity主控程序:【Activity可以参考我前面的Android开发来学习】

package com.keraseye.xkk.keraseye;

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.AsyncTask;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.AsyncTask;
import android.os.Bundle;
import android.renderscript.ScriptGroup;
import android.support.v7.app.AppCompatActivity;
import android.support.v7.widget.Toolbar;
import android.util.JsonReader;
import android.view.View;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;
import org.json.*;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.FileInputStream;
import java.io.InputStream;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.InputStream;
import java.nio.IntBuffer;

public class MainActivity extends AppCompatActivity {
    /*
     * 在需要调用TensoFlow的地方,加载so库“System.loadLibrary("tensorflow_inference");
     * 并”import org.tensorflow.contrib.android.TensorFlowInferenceInterface;就可以使用了
     * */
    //Load the tensorflow inference library
    //static{}(即static块),会在类被加载的时候执行且仅会被执行一次,一般用来初始化静态变量和调用静态方法。
    static {
        System.loadLibrary("tensorflow_inference");
    }

    //PATH TO OUR MODEL FILE AND NAMES OF THE INPUT AND OUTPUT NODES
    //各节点名称
    private String MODEL_PATH = "file:///android_asset/tensor_model.pb";
    private String INPUT_NAME = "input_15:0";
    private String OUTPUT_NAME = "conv2d_310/Sigmoid:0";
    private TensorFlowInferenceInterface tf;

    //ARRAY TO HOLD THE PREDICTIONS AND FLOAT VALUES TO HOLD THE IMAGE DATA
    //保存图片和图片尺寸的
    float[] PREDICTIONS = new float[36864];
    private float[] floatValues;
    private int[] INPUT_SIZE = {224,224,3};

    ImageView imageView;
    TextView resultView;
    Button buttonSub;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        tf = new TensorFlowInferenceInterface(getAssets(),MODEL_PATH);

        imageView=(ImageView)findViewById(R.id.imageView1);
        resultView=(TextView)findViewById(R.id.text_show);
        buttonSub=(Button)findViewById(R.id.button1);

        buttonSub.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                try{
                    System.out.println("哈哈哈");
                    InputStream imageStream = getAssets().open("222.jpg");
                    Bitmap bitmap = BitmapFactory.decodeStream(imageStream);
                    imageView.setImageBitmap(bitmap);

                    predict(bitmap);

                }catch(Exception e){

                }


            }
        });

    }

    //FUNCTION TO COMPUTE THE MAXIMUM PREDICTION AND ITS CONFIDENCE
    public Object[] argmax(float[] array){

        int best = -1;
        float best_confidence = 0.0f;
        for(int i = 0;i < array.length;i++){
            float value = array[i];
            if (value > best_confidence){
                best_confidence = value;
                best = i;
            }
        }
        return new Object[]{best,best_confidence};
    }



    public void predict(final Bitmap bitmap){

        //Runs inference in background thread
        new AsyncTask<Integer,Integer,Integer>(){

            @Override
            protected Integer doInBackground(Integer ...params){
                //Resize the image into 224 x 224
                Bitmap resized_image = ImageUtils.processBitmap(bitmap,256,144);
                final Bitmap a = resized_image;
                runOnUiThread(new Runnable() {
                    @Override
                    public void run() {

                        imageView.setImageBitmap(a);
                    }
                });
                long startTime=System.currentTimeMillis();   //获取开始时间
//                doSomeThing();  //测试的代码段
                //Normalize the pixels
                floatValues = ImageUtils.normalizeBitmap(resized_image,256, 144,127.5f,1.0f);

                //Pass input into the tensorflow
                tf.feed(INPUT_NAME,floatValues,1,144,256,3);

                //compute predictions
                tf.run(new String[]{OUTPUT_NAME});


                //copy the output into the PREDICTIONS array
                tf.fetch(OUTPUT_NAME,PREDICTIONS);

                long endTime=System.currentTimeMillis(); //获取结束时间
                System.out.println("程序运行时间: "+(endTime-startTime));
                System.out.println(PREDICTIONS);
//                final Bitmap Img = Bitmap.createBitmap(256, 144, Bitmap.Config.RGB_565);
//                Img.copyPixelsFromBuffer(IntBuffer.wrap(ImageUtils.floatToBitmap(PREDICTIONS)));
                final Bitmap bwimg = ImageUtils.floatToBitmap(resized_image,PREDICTIONS);
//                Img.setPixels(ia, 0, 256, 0, 0, 256, 144);

                runOnUiThread(new Runnable() {
                    @Override
                    public void run() {

//                        imageView.setImageBitmap(bwimg);
                    }
                });
                //Obtained highest prediction
//                Object[] results = argmax(PREDICTIONS);
//                System.out.println(ia[2]);
//                int class_index = (Integer) results[0];
//                float confidence = (Float) results[1];

                try{
//                    final String conf = String.valueOf(confidence * 100).substring(0,5);
                    //Convert predicted class index into actual label name
//                    final String label = ImageUtils.getLabel(getAssets().open("labels.json"),class_index);
                    //Display result on UI
                    runOnUiThread(new Runnable() {
                        @Override
                        public void run() {
                            resultView.setText(PREDICTIONS[2] + " : " + PREDICTIONS[29232]);
                        }
                    });
                } catch (Exception e){
                }

                return 0;
            }

        }.execute(0);

    }
}

我对Unet做了大量优化,把网络模型降到了1.5MB,在电脑上跑能达到100FPS,在Android上5FPS左右,200MS左右一张语义分割。

速度仍然较慢,可能和TF底层的优化有关,接下来我会去研究下TF JAVA底层的实现。 

 坑点:

1.

注意网络Input和Output层名,这里我在Keras2pb.py中做了输出,记得替换成自己模型对应的名称

2.BitMap和TF输出的形状不一样,要自己转换

 最终效果

 

因为设备资源有限,我就训练了几个Epoch就达到了不错的效果。

我自己写了一套图像增强算法,对数据集进行扩充。

猜你喜欢

转载自blog.csdn.net/qq_25439417/article/details/86564354