TFLite: interpreter run

  /**
   * Classifies a frame from the preview stream.
   *
   * NOTE(review): the "----" line below marks code elided by the article; this
   * fragment does not compile as-is. Presumably the elided code displays
   * {@code textToShow} in the UI — confirm against the original demo source.
   */
  private void classifyFrame() {
    Bitmap bitmap = //getBitmap: capture the preview frame scaled to the classifier's input size
                              textureView.getBitmap(classifier.getImageSizeX(), classifier.getImageSizeY());
    String textToShow = classifier.classifyFrame(bitmap);
    ----
  }

imgData =                                                                                                                                                     ByteBuffer.allocateDirect(     
            DIM_BATCH_SIZE//1     
                * getImageSizeX()     
                * getImageSizeY()     
                * DIM_PIXEL_SIZE//3     
                * getNumBytesPerChannel()//1);
 

/** Writes Image data into a {@code ByteBuffer}. */
  private void convertBitmapToByteBuffer(Bitmap bitmap) {
    imgData.rewind();
    //保存到 initValues
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    // Convert the image to floating point.
    int pixel = 0;
    long startTime = SystemClock.uptimeMillis();
    for (int i = 0; i < getImageSizeX(); ++i) {
      for (int j = 0; j < getImageSizeY(); ++j) {
        final int val = intValues[pixel++];
        addPixelValue(val);
      }
    }
    long endTime = SystemClock.uptimeMillis();
    Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
  }

  pixels    int: The array to receive the bitmap’s colors
offset      int: The first index to write into pixels[]
stride     int: The number of entries in pixels[] to skip between rows (must be >= bitmap’s width). Can be negative.
x            int: The x coordinate of the first pixel to read from the bitmap
y            int: The y coordinate of the first pixel to read from the bitmap
width     int: The number of pixels to read from each row
height    int: The number of rows to read

  //赋值imageData
  protected void addPixelValue(int pixelValue) {
    imgData.put((byte) ((pixelValue >> 16) & 0xFF));
    imgData.put((byte) ((pixelValue >> 8) & 0xFF));
    imgData.put((byte) (pixelValue & 0xFF));
  }

/** Dimensions of inputs. */
  private static final int DIM_BATCH_SIZE = 1;  // one image per inference call
  private static final int DIM_PIXEL_SIZE = 3;  // R, G, B channels per pixel

  /* Preallocated buffers for storing image data in. */
  // One packed-ARGB int per pixel of the model's input (getImageSizeX() * getImageSizeY()).
  private int[] intValues = new int[getImageSizeX() * getImageSizeY()];


  /** An instance of the driver class to run model inference with Tensorflow Lite. */
  protected Interpreter tflite;

  /** Labels corresponding to the output of the vision model. */
  private List<String> labelList;

  /** A ByteBuffer to hold image data, to be fed into Tensorflow Lite as input. */
  protected ByteBuffer imgData = null;

  /** Width of this model's expected input image, in pixels. */
  @Override
  protected int getImageSizeX() {
    return 224;
  }

  /** Height of this model's expected input image, in pixels. */
  @Override
  protected int getImageSizeY() {
    return 224;
  }

tflitecamerademo/Camera2BasicFragment.java -> classifyFrame()

  /**
   * Classifies a frame from the preview stream.
   *
   * NOTE(review): "----" marks code elided by the article; this fragment does
   * not compile as-is. Presumably the elided code shows {@code textToShow} in
   * the UI — confirm against the original demo.
   */
  private void classifyFrame() {
    Bitmap bitmap = //getBitmap: capture the preview frame at the classifier's input size
        textureView.getBitmap(classifier.getImageSizeX(), classifier.getImageSizeY());
    String textToShow = classifier.classifyFrame(bitmap);
    ----
  }

  /**
   * Classifies a frame from the preview stream.
   *
   * NOTE(review): this excerpt declares a {@code String} return type but has no
   * return statement, so it does not compile as shown — the article elided the
   * part that builds the result text. Confirm against the original demo source.
   */
  String classifyFrame(Bitmap bitmap) {

    // Preprocess: copy the bitmap's pixels into the model's input ByteBuffer.
    convertBitmapToByteBuffer(bitmap);
    // Here's where the magic happens!!!
    long startTime = SystemClock.uptimeMillis();
    runInference();
    long endTime = SystemClock.uptimeMillis();
    Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));

  }

  /** Writes Image data into a {@code ByteBuffer}. */
  private void convertBitmapToByteBuffer(Bitmap bitmap) {
    imgData.rewind();
    // Copy the bitmap's packed ARGB pixels into intValues (stride == bitmap width).
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    // Push each pixel into imgData via addPixelValue. (The original comment said
    // "floating point", but addPixelValue in this article writes raw uint8 bytes.)
    int pixel = 0;
    long startTime = SystemClock.uptimeMillis();
    // pixel advances sequentially, so the i/j nesting is just a count of
    // getImageSizeX() * getImageSizeY() iterations over the flat array.
    for (int i = 0; i < getImageSizeX(); ++i) {
      for (int j = 0; j < getImageSizeY(); ++j) {
        final int val = intValues[pixel++];
        addPixelValue(val);
      }
    }
    long endTime = SystemClock.uptimeMillis();
    Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
  }


pixels    int: The array to receive the bitmap’s colors
offset    int: The first index to write into pixels[]
stride    int: The number of entries in pixels[] to skip between rows (must be >= bitmap’s width). Can be negative.
x    int: The x coordinate of the first pixel to read from the bitmap
y    int: The y coordinate of the first pixel to read from the bitmap
width    int: The number of pixels to read from each row
height    int: The number of rows to read

  // Writes one pixel's R, G, B bytes into imgData (alpha is discarded).
  protected void addPixelValue(int pixelValue) {
    imgData.put((byte) ((pixelValue >> 16) & 0xFF)); // red
    imgData.put((byte) ((pixelValue >> 8) & 0xFF));  // green
    imgData.put((byte) (pixelValue & 0xFF));         // blue
  }

  /** Runs the interpreter on the preprocessed image; results land in labelProbArray. */
  @Override
  protected void runInference() {
    tflite.run(imgData, labelProbArray);
  }

  /**
   * An array to hold inference results, to be fed to Tensorflow Lite as the output.
   * This isn't part of the super class, because we need a primitive array here.
   */
  private byte[][] labelProbArray = null;


// Build the interpreter from the model file returned by loadModelFile(activity).
tflite = new Interpreter(loadModelFile(activity));

imgData =                                                                                                                                                          ByteBuffer.allocateDirect(     
            DIM_BATCH_SIZE//1     
                * getImageSizeX()     
                * getImageSizeY()     
                * DIM_PIXEL_SIZE//3     
                * getNumBytesPerChannel()//1);

//tflite.run(imgData, labelProbArray);
  public void run(@NotNull Object input, @NotNull Object output) {
    //input是数组,output是Map
    Object[] inputs = {input};
    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, output);
    runForMultipleInputsOutputs(inputs, outputs);
  }

  public void runForMultipleInputsOutputs(
      @NotNull Object[] inputs, @NotNull Map<Integer, Object> outputs) {
    //返回的是个数组,copy to map中,为什么绕一圈?output直接使用 array?
    Tensor[] tensors = wrapper.run(inputs);
    final int size = tensors.length;
    for (Integer idx : outputs.keySet()) {
      tensors[idx].copyTo(outputs.get(idx));
    }
  }

  /**
   * Sets inputs, runs model inference and returns outputs.
   *
   * NOTE(review): only the ByteBuffer branch is shown here; the article elided
   * the else-branch, so for non-ByteBuffer inputs sizes[i] / numsOfBytes[i]
   * stay at their defaults in this excerpt — confirm against upstream.
   */
  Tensor[] run(Object[] inputs) {

    // Per-input metadata passed down to the native run(): type tag, dims, byte count.
    int[] dataTypes = new int[inputs.length];
    Object[] sizes = new Object[inputs.length];
    int[] numsOfBytes = new int[inputs.length];

    for (int i = 0; i < inputs.length; ++i) {
      DataType dataType = dataTypeOf(inputs[i]);
      dataTypes[i] = dataType.getNumber();
      if (dataType == DataType.BYTEBUFFER) {
        ByteBuffer buffer = (ByteBuffer) inputs[i];
        numsOfBytes[i] = buffer.limit();
        // Native call: returns the input tensor's dims, validating the byte count matches.
        sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]);
      } 
    }
    // Native inference; returns one opaque handle per output tensor.
    long[] outputsHandles =
        run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs);

    // Wrap each native handle in a Java-side Tensor.
    Tensor[] outputs = new Tensor[outputsHandles.length];
    for (int i = 0; i < outputsHandles.length; ++i) {
      outputs[i] = Tensor.fromHandle(outputsHandles[i]);
    }
    return outputs;
  }

  /** Returns the type of the data. */
  static DataType dataTypeOf(Object o) {
    if (o != null) {
      Class<?> c = o.getClass();
      while (c.isArray()) {
        c = c.getComponentType();
      }
      if (float.class.equals(c)) {
        return DataType.FLOAT32;
      } else if (int.class.equals(c)) {
        return DataType.INT32;
      } else if (byte.class.equals(c)) {
        return DataType.UINT8;
      } else if (long.class.equals(c)) {
        return DataType.INT64;
      } else if (ByteBuffer.class.isInstance(o)) {
        return DataType.BYTEBUFFER;
      }
    }
  }


  /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */
  int getNumber() {
    return value;
  }

// JNI: returns the dims of input #inputIdx, throwing if numBytes does not match
// the tensor's expected byte size (see the native implementation below).
private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);

// Native side of getInputDims(): looks up one input tensor's shape and checks
// the caller-declared byte count against the tensor's expected size.
JNIEXPORT jintArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
    JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) {

  // Recover the C++ Interpreter* from the Java-held opaque long handle.
  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);

  const int idx = static_cast<int>(input_idx);
  // Bounds check: inputs() is the interpreter's list of input tensor indices.
  if (input_idx >= interpreter->inputs().size()) {
    throwException(env, kIllegalArgumentException,
                   "Out of range: Failed to get %d-th input out of %d inputs",
                   input_idx, interpreter->inputs().size());
    return nullptr;
  }
  // Resolve the index to the concrete TfLiteTensor.
  TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]);
  int size = target->dims->size;
  // Expected bytes = element size * product of all dims.
  int expected_num_bytes = elementByteSize(target->type);
  for (int i = 0; i < size; ++i) {
    expected_num_bytes *= target->dims->data[i];
  }

  if (num_bytes != expected_num_bytes) { // the Java buffer must match exactly
    throwException(env, kIllegalArgumentException,
                   "Failed to get input dimensions. %d-th input should have"
                   " %d bytes, but found %d bytes.",
                   idx, expected_num_bytes, num_bytes);
    return nullptr;
  }

  // Copy the dims into a fresh Java int[] and return it.
  jintArray outputs = env->NewIntArray(size);
  env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0]));
  return outputs;
}

// A tensor in the interpreter system: a wrapper around a buffer of
// data including a dimensionality (or NULL if not currently defined).
typedef struct {
  // The data type specification for data stored in `data`. This affects
  // what member of `data` union should be used.
  TfLiteType type;
  // A union of data pointers. The appropriate type should be used for a typed
  // tensor based on `type`.
  TfLitePtrUnion data;
  // A pointer to a structure representing the dimensionality interpretation
  // that the buffer should have. NOTE: the product of elements of `dims`
  // and the element datatype size should be equal to `bytes` below.
  TfLiteIntArray* dims;
  // Quantization information.
  TfLiteQuantizationParams params;
  // How memory is mapped
  //  kTfLiteMmapRo: Memory mapped read only.
  //  i.e. weights
  //  kTfLiteArenaRw: Arena allocated read write memory
  //  (i.e. temporaries, outputs).
  TfLiteAllocationType allocation_type;
  // The number of bytes required to store the data of this Tensor. I.e.
  // (bytes of each element) * dims[0] * ... * dims[n-1].  For example, if
  // type is kTfLiteFloat32 and dims = {3, 2} then
  // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
  size_t bytes;

  // An opaque pointer to a tflite::MMapAllocation
  const void* allocation;

  // Null-terminated name of this tensor.
  // FIX: the article had two copies of unrelated Bitmap.getPixels parameter
  // notes pasted into this comment, which broke the struct definition; the
  // pasted prose has been removed.
  const char* name;
} TfLiteTensor;

// Length-prefixed int array. `data[0]` is the (pre-C99) flexible-array-member
// idiom: the struct is allocated with room for `size` ints following the header.
typedef struct {
  int size;
  int data[0];
} TfLiteIntArray;
 

猜你喜欢

转载自blog.csdn.net/u011279649/article/details/83752387
run