A beginner-friendly introduction to object recognition with TFL (TensorFlow Lite)

1. A beginner-friendly introduction to object recognition with TFL (TensorFlow Lite)

The fragment below is the main piece of the classifier: it drives the phone's camera (via CameraX) and feeds the preview frames to the model for on-device sample collection, training, and inference.
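Before reading the full fragment, here is a minimal sketch of the call flow it drives on TransferLearningModelWrapper. The wrapper methods and the Prediction type are the ones actually used in the fragment below; the DemoFlow class and runDemo method are only illustrative names, not part of the sample app.

package org.tensorflow.lite.examples.transfer;

import android.content.Context;
import org.tensorflow.lite.examples.transfer.api.TransferLearningModel.Prediction;

/** Illustrative sketch only; not part of the sample app. */
class DemoFlow {

  void runDemo(Context context, float[] rgbImage) throws Exception {
    // rgbImage: IMAGE_SIZE x IMAGE_SIZE x 3 values normalized to [0, 1]
    // (see prepareCameraImage() in the fragment below).

    // The wrapper is created with an Android Context and closed when no longer needed.
    TransferLearningModelWrapper tlModel = new TransferLearningModelWrapper(context);

    // Capture mode: store a normalized RGB image under a class label.
    // addSample() returns a Future, so get() blocks until the sample is added.
    tlModel.addSample(rgbImage, "1").get();

    // Train on the collected samples; the callback receives (epoch, loss).
    tlModel.enableTraining((epoch, loss) -> System.out.println("loss: " + loss));

    // Pause training before switching to inference mode.
    tlModel.disableTraining();

    // Inference mode: predict() returns one Prediction per class (or null).
    Prediction[] predictions = tlModel.predict(rgbImage);
    if (predictions != null) {
      for (Prediction p : predictions) {
        System.out.println(p.getClassName() + ": " + p.getConfidence());
      }
    }

    tlModel.close();
  }
}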

package org.tensorflow.lite.examples.transfer;
import org.tensorflow.lite.examples.transfer.api.TransferLearningModel.Prediction;
import org.tensorflow.lite.examples.transfer.databinding.CameraFragmentBinding;
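// Additional imports (CameraX, Android graphics and UI classes, data binding,
// Material components, java.util and java.util.concurrent) are omitted here.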



/**
 * Main fragment for the classifier.
 *
 * Camera functionality is provided via CameraX.
 */

// CameraFragment extends Fragment.
public class CameraFragment extends Fragment {

  // Mask for extracting the lowest byte of a packed ARGB pixel.
  private static final int LOWER_BYTE_MASK = 0xFF;
  // Log tag for this class.
  private static final String TAG = CameraFragment.class.getSimpleName();

  private static final LensFacing LENS_FACING = LensFacing.BACK;

  private TextureView viewFinder;

  private Integer viewFinderRotation = null;

  private Size bufferDimens = new Size(0, 0);
  private Size viewFinderDimens = new Size(0, 0);

  private CameraFragmentViewModel viewModel;
  private TransferLearningModelWrapper tlModel;

  // When the user presses the "Add sample" button for a class, that class name is
  // added to this queue. It is later polled and processed by the inference thread.
  private final ConcurrentLinkedQueue<String> addSampleRequests = new ConcurrentLinkedQueue<>();

  private final LoggingBenchmark inferenceBenchmark = new LoggingBenchmark("InferenceBench");

  /**
   * Sets up a responsive preview for the view finder.
   */
  private void startCamera() {
    viewFinderRotation = getDisplaySurfaceRotation(viewFinder.getDisplay());
    if (viewFinderRotation == null) {
      viewFinderRotation = 0;
    }

    DisplayMetrics metrics = new DisplayMetrics();
    viewFinder.getDisplay().getRealMetrics(metrics);
    Rational screenAspectRatio = new Rational(metrics.widthPixels, metrics.heightPixels);

    PreviewConfig config = new PreviewConfig.Builder()
        .setLensFacing(LENS_FACING)
        .setTargetAspectRatio(screenAspectRatio)
        .setTargetRotation(viewFinder.getDisplay().getRotation())
        .build();

    Preview preview = new Preview(config);

    preview.setOnPreviewOutputUpdateListener(previewOutput -> {
      ViewGroup parent = (ViewGroup) viewFinder.getParent();
      parent.removeView(viewFinder);
      parent.addView(viewFinder, 0);

      viewFinder.setSurfaceTexture(previewOutput.getSurfaceTexture());

      Integer rotation = getDisplaySurfaceRotation(viewFinder.getDisplay());
      updateTransform(rotation, previewOutput.getTextureSize(), viewFinderDimens);
    });

    viewFinder.addOnLayoutChangeListener((
        view, left, top, right, bottom, oldLeft, oldTop, oldRight, oldBottom) -> {
      Size newViewFinderDimens = new Size(right - left, bottom - top);
      Integer rotation = getDisplaySurfaceRotation(viewFinder.getDisplay());
      updateTransform(rotation, bufferDimens, newViewFinderDimens);
    });

    HandlerThread inferenceThread = new HandlerThread("InferenceThread");
    inferenceThread.start();
    ImageAnalysisConfig analysisConfig = new ImageAnalysisConfig.Builder()
        .setLensFacing(LENS_FACING)
        .setCallbackHandler(new Handler(inferenceThread.getLooper()))
        .setImageReaderMode(ImageReaderMode.ACQUIRE_LATEST_IMAGE)
        .setTargetRotation(viewFinder.getDisplay().getRotation())
        .build();

    ImageAnalysis imageAnalysis = new ImageAnalysis(analysisConfig);
    imageAnalysis.setAnalyzer(inferenceAnalyzer);

    CameraX.bindToLifecycle(this, preview, imageAnalysis);
  }
  // Analyzer run on the inference thread: adds samples or runs prediction on each frame.
  private final ImageAnalysis.Analyzer inferenceAnalyzer =
      (imageProxy, rotationDegrees) -> {
        final String imageId = UUID.randomUUID().toString();

        inferenceBenchmark.startStage(imageId, "preprocess");
        // Convert the YUV camera frame to a normalized float RGB array.
        float[] rgbImage = prepareCameraImage(yuvCameraImageToBitmap(imageProxy), rotationDegrees);
        inferenceBenchmark.endStage(imageId, "preprocess");

        // Adding samples is also handled by the inference thread / use case. We don't
        // use CameraX ImageCapture because it has very high latency (~650 ms on a
        // Pixel 2 XL) even with .MIN_LATENCY.

        String sampleClass = addSampleRequests.poll();
        if (sampleClass != null) {
          inferenceBenchmark.startStage(imageId, "addSample");
          try {
            tlModel.addSample(rgbImage, sampleClass).get();
          } catch (ExecutionException e) {
            throw new RuntimeException("Failed to add sample to model", e.getCause());
          } catch (InterruptedException e) {
            // no-op
          }

          viewModel.increaseNumSamples(sampleClass);
          inferenceBenchmark.endStage(imageId, "addSample");

        } else {
          // We don't perform inference when adding samples, since we should be in
          // capture mode at that point, so the results would not be displayed anyway.
          inferenceBenchmark.startStage(imageId, "predict");
          Prediction[] predictions = tlModel.predict(rgbImage);
          if (predictions == null) {
            return;
          }
          inferenceBenchmark.endStage(imageId, "predict");

          for (Prediction prediction : predictions) {
            viewModel.setConfidence(prediction.getClassName(), prediction.getConfidence());
          }
        }

        inferenceBenchmark.finish(imageId);
      };

  // Click listener for the four class buttons; maps them to class names "1" through "4".
  public final View.OnClickListener onAddSampleClickListener = view -> {
    String className;
    if (view.getId() == R.id.class_btn_1) {
      className = "1";
    } else if (view.getId() == R.id.class_btn_2) {
      className = "2";
    } else if (view.getId() == R.id.class_btn_3) {
      className = "3";
    } else if (view.getId() == R.id.class_btn_4) {
      className = "4";
    } else {
      throw new RuntimeException("Listener called for unexpected view");
    }

    addSampleRequests.add(className);
  };

  /**
   * Fits the camera preview into [viewFinder].
   *
   * @param rotation view finder rotation.
   * @param newBufferDimens camera preview dimensions.
   * @param newViewFinderDimens view finder dimensions.
   */
  private void updateTransform(Integer rotation, Size newBufferDimens, Size newViewFinderDimens) {
    if (Objects.equals(rotation, viewFinderRotation)
      && Objects.equals(newBufferDimens, bufferDimens)
      && Objects.equals(newViewFinderDimens, viewFinderDimens)) {
      return;
    }

    if (rotation == null) {
      return;
    } else {
      viewFinderRotation = rotation;
    }

    if (newBufferDimens.getWidth() == 0 || newBufferDimens.getHeight() == 0) {
      return;
    } else {
      bufferDimens = newBufferDimens;
    }

    if (newViewFinderDimens.getWidth() == 0 || newViewFinderDimens.getHeight() == 0) {
      return;
    } else {
      viewFinderDimens = newViewFinderDimens;
    }
    // Log the transformation that is about to be applied.

    Log.d(TAG, String.format("Applying output transformation.\n"
        + "View finder size: %s.\n"
        + "Preview output size: %s\n"
        + "View finder rotation: %s\n", viewFinderDimens, bufferDimens, viewFinderRotation));

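    // Rotate the preview around the view finder's center to undo the display rotation,
    // then scale it so the camera buffer's aspect ratio is preserved in the view finder.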
    Matrix matrix = new Matrix();

    float centerX = viewFinderDimens.getWidth() / 2f;
    float centerY = viewFinderDimens.getHeight() / 2f;

    matrix.postRotate(-viewFinderRotation.floatValue(), centerX, centerY);

    float bufferRatio = bufferDimens.getHeight() / (float) bufferDimens.getWidth();

    int scaledWidth;
    int scaledHeight;
    if (viewFinderDimens.getWidth() > viewFinderDimens.getHeight()) {
      scaledHeight = viewFinderDimens.getWidth();
      scaledWidth = Math.round(viewFinderDimens.getWidth() * bufferRatio);
    } else {
      scaledHeight = viewFinderDimens.getHeight();
      scaledWidth = Math.round(viewFinderDimens.getHeight() * bufferRatio);
    }

    float xScale = scaledWidth / (float) viewFinderDimens.getWidth();
    float yScale = scaledHeight / (float) viewFinderDimens.getHeight();

    matrix.preScale(xScale, yScale, centerX, centerY);

    viewFinder.setTransform(matrix);
  }
   
  // Create the transfer-learning model wrapper and the view model.
  @Override
  public void onCreate(Bundle bundle) {
    super.onCreate(bundle);

    tlModel = new TransferLearningModelWrapper(getActivity());
    viewModel = ViewModelProviders.of(this).get(CameraFragmentViewModel.class);
    viewModel.setTrainBatchSize(tlModel.getTrainBatchSize());
  }

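  // Inflate the data-binding layout, hook up the four "add sample" buttons, and wire
  // the capture/inference mode chips to the view model.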
  @Override
  public View onCreateView(LayoutInflater inflater, ViewGroup container, Bundle bundle) {
    CameraFragmentBinding dataBinding =
        DataBindingUtil.inflate(inflater, R.layout.camera_fragment, container, false);
    dataBinding.setLifecycleOwner(getViewLifecycleOwner());
    dataBinding.setVm(viewModel);
    View rootView = dataBinding.getRoot();

    for (int buttonId : new int[] {
        R.id.class_btn_1, R.id.class_btn_2, R.id.class_btn_3, R.id.class_btn_4}) {
      rootView.findViewById(buttonId).setOnClickListener(onAddSampleClickListener);
    }

    ChipGroup chipGroup = (ChipGroup) rootView.findViewById(R.id.mode_chip_group);
    if (viewModel.getCaptureMode().getValue()) {
      ((Chip) rootView.findViewById(R.id.capture_mode_chip)).setChecked(true);
    } else {
      ((Chip) rootView.findViewById(R.id.inference_mode_chip)).setChecked(true);
    }

    chipGroup.setOnCheckedChangeListener((group, checkedId) -> {
      if (checkedId == R.id.capture_mode_chip) {
        viewModel.setCaptureMode(true);
      } else if (checkedId == R.id.inference_mode_chip) {
        viewModel.setCaptureMode(false);
      }
    });

    return dataBinding.getRoot();
  }

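  // Once the view hierarchy exists, grab the view finder and start the camera.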
  @Override
  public void onViewCreated(View view, Bundle bundle) {
    super.onViewCreated(view, bundle);

    viewFinder = getActivity().findViewById(R.id.view_finder);
    viewFinder.post(this::startCamera);
  }
  // Observe the training state once the activity is created.
  @Override
  public void onActivityCreated(Bundle bundle) {
    super.onActivityCreated(bundle);

    viewModel
        .getTrainingState()
        .observe(
            getViewLifecycleOwner(),
            // Handle training state changes: started, paused, or not started.
            trainingState -> {
              switch (trainingState) {
                case STARTED:
                  tlModel.enableTraining((epoch, loss) -> viewModel.setLastLoss(loss));
                  if (!viewModel.getInferenceSnackbarWasDisplayed().getValue()) {
                    Snackbar.make(
                            getActivity().findViewById(R.id.classes_bar),
                            R.string.switch_to_inference_hint,
                            Snackbar.LENGTH_LONG)
                        .show();
                    viewModel.markInferenceSnackbarWasCalled();
                  }
                  break;
                case PAUSED:
                  tlModel.disableTraining();
                  break;
                case NOT_STARTED:
                  break;
              }
            });
  }
  // Release model resources.

  @Override
  public void onDestroy() {
    super.onDestroy();
    tlModel.close();
    tlModel = null;
  }
  // Returns the display's surface rotation in degrees, or null if unavailable.
  private static Integer getDisplaySurfaceRotation(Display display) {
    if (display == null) {
      return null;
    }

    switch (display.getRotation()) {
      case Surface.ROTATION_0: return 0;
      case Surface.ROTATION_90: return 90;
      case Surface.ROTATION_180: return 180;
      case Surface.ROTATION_270: return 270;
      default: return null;
    }
  }
  // Converts a YUV_420_888 camera image into an ARGB_8888 Bitmap.
  private static Bitmap yuvCameraImageToBitmap(ImageProxy imageProxy) {
    if (imageProxy.getFormat() != ImageFormat.YUV_420_888) {
      throw new IllegalArgumentException(
          "Expected a YUV420 image, but got " + imageProxy.getFormat());
    }

    PlaneProxy yPlane = imageProxy.getPlanes()[0];
    PlaneProxy uPlane = imageProxy.getPlanes()[1];

    int width = imageProxy.getWidth();
    int height = imageProxy.getHeight();

    byte[][] yuvBytes = new byte[3][];
    int[] argbArray = new int[width * height];
    for (int i = 0; i < imageProxy.getPlanes().length; i++) {
      final ByteBuffer buffer = imageProxy.getPlanes()[i].getBuffer();
      yuvBytes[i] = new byte[buffer.capacity()];
      buffer.get(yuvBytes[i]);
    }

    ImageUtils.convertYUV420ToARGB8888(
        yuvBytes[0],
        yuvBytes[1],
        yuvBytes[2],
        width,
        height,
        yPlane.getRowStride(),
        uPlane.getRowStride(),
        uPlane.getPixelStride(),
        argbArray);

    return Bitmap.createBitmap(argbArray, width, height, Config.ARGB_8888);
  }

  /**
   * Normalizes the camera image to [0; 1], resizes it to the input size expected by
   * the model, and adjusts for the camera rotation.
   */
  private static float[] prepareCameraImage(Bitmap bitmap, int rotationDegrees)  {
    int modelImageSize = TransferLearningModelWrapper.IMAGE_SIZE;

    Bitmap paddedBitmap = padToSquare(bitmap);
    Bitmap scaledBitmap = Bitmap.createScaledBitmap(
        paddedBitmap, modelImageSize, modelImageSize, true);

    Matrix rotationMatrix = new Matrix();
    rotationMatrix.postRotate(rotationDegrees);
    Bitmap rotatedBitmap = Bitmap.createBitmap(
        scaledBitmap, 0, 0, modelImageSize, modelImageSize, rotationMatrix, false);

    float[] normalizedRgb = new float[modelImageSize * modelImageSize * 3];
    int nextIdx = 0;
    for (int y = 0; y < modelImageSize; y++) {
      for (int x = 0; x < modelImageSize; x++) {
        int rgb = rotatedBitmap.getPixel(x, y);

        float r = ((rgb >> 16) & LOWER_BYTE_MASK) * (1 / 255.f);
        float g = ((rgb >> 8) & LOWER_BYTE_MASK) * (1 / 255.f);
        float b = (rgb & LOWER_BYTE_MASK) * (1 / 255.f);

        normalizedRgb[nextIdx++] = r;
        normalizedRgb[nextIdx++] = g;
        normalizedRgb[nextIdx++] = b;
      }
    }

    return normalizedRgb;
  }
  // Pads the bitmap to a square with a white background.
  private static Bitmap padToSquare(Bitmap source) {
    int width = source.getWidth();
    int height = source.getHeight();

    int paddingX = width < height ? (height - width) / 2 : 0;
    int paddingY = height < width ? (width - height) / 2 : 0;
    Bitmap paddedBitmap = Bitmap.createBitmap(
        width + 2 * paddingX, height + 2 * paddingY, Config.ARGB_8888);
    Canvas canvas = new Canvas(paddedBitmap);
    canvas.drawARGB(0xFF, 0xFF, 0xFF, 0xFF);
    canvas.drawBitmap(source, paddingX, paddingY, null);
    return paddedBitmap;
  }

  // Binding adapters:

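  // Shows the collected sample count in capture mode, or the prediction confidence
  // in inference mode.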
  @BindingAdapter({"captureMode", "inferenceText", "captureText"})
  public static void setClassSubtitleText(
      TextView view, boolean captureMode, Float inferenceText, Integer captureText) {
    if (captureMode) {
      view.setText(captureText != null ? Integer.toString(captureText) : "0");
    } else {
      view.setText(
          String.format(Locale.getDefault(), "%.2f", inferenceText != null ? inferenceText : 0.f));
    }
  }

  @BindingAdapter({"android:visibility"})
  public static void setViewVisibility(View view, boolean visible) {
    view.setVisibility(visible ? View.VISIBLE : View.GONE);
  }

  @BindingAdapter({"highlight"})
  public static void setClassButtonHighlight(View view, boolean highlight) {
    int drawableId;
    if (highlight) {
      drawableId = R.drawable.btn_default_highlight;
    } else {
      drawableId = R.drawable.btn_default;
    }
    view.setBackground(view.getContext().getDrawable(drawableId));
  }
}

2. The UI implemented by the code
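Based on the view IDs referenced in the fragment, the screen (R.layout.camera_fragment, bound through CameraFragmentBinding) consists of a TextureView preview (R.id.view_finder), a classes bar with the four "add sample" buttons (class_btn_1 through class_btn_4), and a ChipGroup (mode_chip_group) whose two chips switch between capture mode and inference mode. In capture mode each button's subtitle shows the number of samples collected for that class; in inference mode it shows the model's confidence for that class, via the binding adapters above.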


Reposted from blog.csdn.net/keny88888/article/details/106505943