欢迎光临散文网 会员登陆 & 注册

小白简单介绍一下物品识别TFL的使用

2020-06-02 22:00 作者:嗨大成  | 我要投稿

1.小白简单介绍一下物品识别TFL的使用

调用手机系统的相机功能(通过 CameraX),下面是实现分类器的主要片段。

package org.tensorflow.lite.examples.transfer;

//导入依赖库

import org.tensorflow.lite.examples.transfer.api.TransferLearningModel.Prediction;

import org.tensorflow.lite.examples.transfer.databinding.CameraFragmentBinding; 

/**

*分类器的主要片段。相机功能(通过CameraX)

*/

 

//CameraFragment 继承Fragment的类

public class CameraFragment extends Fragment {

 

  //定义低字节掩码

  private static final int LOWER_BYTE_MASK = 0xFF;

  //定义一个TAG

  private static final String TAG = CameraFragment.class.getSimpleName();

  private static final LensFacing LENS_FACING = LensFacing.BACK;

  private TextureView viewFinder;

  private Integer viewFinderRotation = null;

  private Size bufferDimens = new Size(0, 0);

  private Size viewFinderDimens = new Size(0, 0);

  private CameraFragmentViewModel viewModel;

//定义机器学习的Model

  private TransferLearningModelWrapper tlModel;

 

//当用户按下某个类的“添加示例”按钮时,

//该类将被添加到此队列中。稍后由

//推理线程和处理。

  private final ConcurrentLinkedQueue<String> addSampleRequests = new ConcurrentLinkedQueue<>();

  private final LoggingBenchmark inferenceBenchmark = new LoggingBenchmark("InferenceBench");

 

/**

*为取景器设置响应预览。

*/

  private void startCamera() {

    viewFinderRotation = getDisplaySurfaceRotation(viewFinder.getDisplay());

    if (viewFinderRotation == null) {

      viewFinderRotation = 0;

    }

 

    DisplayMetrics metrics = new DisplayMetrics();

    viewFinder.getDisplay().getRealMetrics(metrics);

    Rational screenAspectRatio = new Rational(metrics.widthPixels, metrics.heightPixels);

 

    PreviewConfig config = new PreviewConfig.Builder()

        .setLensFacing(LENS_FACING)

        .setTargetAspectRatio(screenAspectRatio)

        .setTargetRotation(viewFinder.getDisplay().getRotation())

        .build();

 

    Preview preview = new Preview(config);

 

    preview.setOnPreviewOutputUpdateListener(previewOutput -> {

      ViewGroup parent = (ViewGroup) viewFinder.getParent();

      parent.removeView(viewFinder);

      parent.addView(viewFinder, 0);

 

      viewFinder.setSurfaceTexture(previewOutput.getSurfaceTexture());

 

      Integer rotation = getDisplaySurfaceRotation(viewFinder.getDisplay());

      updateTransform(rotation, previewOutput.getTextureSize(), viewFinderDimens);

    });

 

    viewFinder.addOnLayoutChangeListener((

        view, left, top, right, bottom, oldLeft, oldTop, oldRight, oldBottom) -> {

      Size newViewFinderDimens = new Size(right - left, bottom - top);

      Integer rotation = getDisplaySurfaceRotation(viewFinder.getDisplay());

      updateTransform(rotation, bufferDimens, newViewFinderDimens);

    });

 

    HandlerThread inferenceThread = new HandlerThread("InferenceThread");

    inferenceThread.start();

    ImageAnalysisConfig analysisConfig = new ImageAnalysisConfig.Builder()

        .setLensFacing(LENS_FACING)

        .setCallbackHandler(new Handler(inferenceThread.getLooper()))

        .setImageReaderMode(ImageReaderMode.ACQUIRE_LATEST_IMAGE)

        .setTargetRotation(viewFinder.getDisplay().getRotation())

        .build();

 

    ImageAnalysis imageAnalysis = new ImageAnalysis(analysisConfig);

    imageAnalysis.setAnalyzer(inferenceAnalyzer);

 

    CameraX.bindToLifecycle(this, preview, imageAnalysis);

  }

  //图片推理分析器

  private final ImageAnalysis.Analyzer inferenceAnalyzer =

      (imageProxy, rotationDegrees) -> {

        final String imageId = UUID.randomUUID().toString();

 

        inferenceBenchmark.startStage(imageId, "preprocess");

        //rgbImage定义为float的数组

        float[] rgbImage = prepareCameraImage(yuvCameraImageToBitmap(imageProxy), rotationDegrees);

        inferenceBenchmark.endStage(imageId, "preprocess");

 

//添加示例也由推理线程/用例处理。

//我们不使用CameraX ImageCapture,因为它具有很高的延迟(像素2 XL上约650ms)

//即使使用.MIN_延迟。

 

        String sampleClass = addSampleRequests.poll();

        if (sampleClass != null) {

          inferenceBenchmark.startStage(imageId, "addSample");

          try {

            tlModel.addSample(rgbImage, sampleClass).get();

          } catch (ExecutionException e) {

            throw new RuntimeException("Failed to add sample to model", e.getCause());

          } catch (InterruptedException e) {

            // no-op

          }

 

          viewModel.increaseNumSamples(sampleClass);

          inferenceBenchmark.endStage(imageId, "addSample");

 

        } else {

//我们在添加样本时不执行推断,因为我们应该处于捕获模式

//当时,所以推理结果实际上并没有显示出来。

          inferenceBenchmark.startStage(imageId, "predict");

          Prediction[] predictions = tlModel.predict(rgbImage);

          if (predictions == null) {

            return;

          }

          inferenceBenchmark.endStage(imageId, "predict");

 

          for (Prediction prediction : predictions) {

            viewModel.setConfidence(prediction.getClassName(), prediction.getConfidence());

          }

        }

        inferenceBenchmark.finish(imageId);

      };

 

   //定义4个类名分别为1,2,3,4类

  public final View.OnClickListener onAddSampleClickListener = view -> {

    String className;

    if (view.getId() == R.id.class_btn_1) {

      className = "1";

    } else if (view.getId() == R.id.class_btn_2) {

      className = "2";

    } else if (view.getId() == R.id.class_btn_3) {

      className = "3";

    } else if (view.getId() == R.id.class_btn_4) {

      className = "4";

    } else {

      throw new RuntimeException("Listener called for unexpected view");

    }

 

    addSampleRequests.add(className);

  };

 

  /**

   * 将相机预览调整为[viewFinder].

   *

   * @param rotation view finder rotation.

   * @param newBufferDimens camera preview dimensions.

   * @param newViewFinderDimens view finder dimensions.

   */

  private void updateTransform(Integer rotation, Size newBufferDimens, Size newViewFinderDimens) {

    if (Objects.equals(rotation, viewFinderRotation)

      && Objects.equals(newBufferDimens, bufferDimens)

      && Objects.equals(newViewFinderDimens, viewFinderDimens)) {

      return;

    }

    if (rotation == null) {

      return;

    } else {

      viewFinderRotation = rotation;

    }

    if (newBufferDimens.getWidth() == 0 || newBufferDimens.getHeight() == 0) {

      return;

    } else {

      bufferDimens = newBufferDimens;

    }

 

    if (newViewFinderDimens.getWidth() == 0 || newViewFinderDimens.getHeight() == 0) {

      return;

    } else {

      viewFinderDimens = newViewFinderDimens;

    }

    //输出日志格式化日志

/*

对数d(标记,字符串格式(“正在应用输出转换。\n”

+“取景器大小:%s。\n”

+“预览输出大小:%s\n”

+“取景器旋转:%s\n”,viewFinderDimens,bufferDimens,viewFinderRotation));

*/

    Log.d(TAG, String.format("Applying output transformation.\n"

        + "View finder size: %s.\n"

        + "Preview output size: %s\n"

        + "View finder rotation: %s\n", viewFinderDimens, bufferDimens, viewFinderRotation));

 

    Matrix matrix = new Matrix();

 

    float centerX = viewFinderDimens.getWidth() / 2f;

    float centerY = viewFinderDimens.getHeight() / 2f;

 

    matrix.postRotate(-viewFinderRotation.floatValue(), centerX, centerY);

 

    float bufferRatio = bufferDimens.getHeight() / (float) bufferDimens.getWidth();

 

    int scaledWidth;

    int scaledHeight;

    if (viewFinderDimens.getWidth() > viewFinderDimens.getHeight()) {

      scaledHeight = viewFinderDimens.getWidth();

      scaledWidth = Math.round(viewFinderDimens.getWidth() * bufferRatio);

    } else {

      scaledHeight = viewFinderDimens.getHeight();

      scaledWidth = Math.round(viewFinderDimens.getHeight() * bufferRatio);

    }

 

    float xScale = scaledWidth / (float) viewFinderDimens.getWidth();

    float yScale = scaledHeight / (float) viewFinderDimens.getHeight();

 

    matrix.preScale(xScale, yScale, centerX, centerY);

 

    viewFinder.setTransform(matrix);

  }

   

  //创建,tlModel,viewModel

  @Override

  public void onCreate(Bundle bundle) {

    super.onCreate(bundle);

 

    tlModel = new TransferLearningModelWrapper(getActivity());

    viewModel = ViewModelProviders.of(this).get(CameraFragmentViewModel.class);

    viewModel.setTrainBatchSize(tlModel.getTrainBatchSize());

  }

 

  @Override

  public View onCreateView(LayoutInflater inflater, ViewGroup container, Bundle bundle) {

    CameraFragmentBinding dataBinding =

        DataBindingUtil.inflate(inflater, R.layout.camera_fragment, container, false);

    dataBinding.setLifecycleOwner(getViewLifecycleOwner());

    dataBinding.setVm(viewModel);

    View rootView = dataBinding.getRoot();

 

    for (int buttonId : new int[] {

        R.id.class_btn_1, R.id.class_btn_2, R.id.class_btn_3, R.id.class_btn_4}) {

      rootView.findViewById(buttonId).setOnClickListener(onAddSampleClickListener);

    }

 

    ChipGroup chipGroup = (ChipGroup) rootView.findViewById(R.id.mode_chip_group);

    if (viewModel.getCaptureMode().getValue()) {

      ((Chip) rootView.findViewById(R.id.capture_mode_chip)).setChecked(true);

    } else {

      ((Chip) rootView.findViewById(R.id.inference_mode_chip)).setChecked(true);

    }

 

    chipGroup.setOnCheckedChangeListener((group, checkedId) -> {

      if (checkedId == R.id.capture_mode_chip) {

        viewModel.setCaptureMode(true);

      } else if (checkedId == R.id.inference_mode_chip) {

        viewModel.setCaptureMode(false);

      }

    });

 

    return dataBinding.getRoot();

  }

 

  @Override

  public void onViewCreated(View view, Bundle bundle) {

    super.onViewCreated(view, bundle);

 

    viewFinder = getActivity().findViewById(R.id.view_finder);

    viewFinder.post(this::startCamera);

  }

  //重写已创建活动

  @Override

  public void onActivityCreated(Bundle bundle) {

    super.onActivityCreated(bundle);

 

    viewModel

        .getTrainingState()

        .observe(

            getViewLifecycleOwner(),

            //训练状态,开始和暂停

            trainingState -> {

              switch (trainingState) {

                case STARTED:

                  tlModel.enableTraining((epoch, loss) -> viewModel.setLastLoss(loss));

                  if (!viewModel.getInferenceSnackbarWasDisplayed().getValue()) {

                    Snackbar.make(

                            getActivity().findViewById(R.id.classes_bar),

                            R.string.switch_to_inference_hint,

                            Snackbar.LENGTH_LONG)

                        .show();

                    viewModel.markInferenceSnackbarWasCalled();

                  }

                  break;

                case PAUSED:

                  tlModel.disableTraining();

                  break;

                case NOT_STARTED:

                  break;

              }

            });

  }

  //释放资源

 

  @Override

  public void onDestroy() {

    super.onDestroy();

    tlModel.close();

    tlModel = null;

  }

  //获取显示面旋转

  private static Integer getDisplaySurfaceRotation(Display display) {

    if (display == null) {

      return null;

    }

 

    switch (display.getRotation()) {

      case Surface.ROTATION_0: return 0;

      case Surface.ROTATION_90: return 90;

      case Surface.ROTATION_180: return 180;

      case Surface.ROTATION_270: return 270;

      default: return null;

    }

  }

  //拍摄的照片变为bitmap格式

  private static Bitmap yuvCameraImageToBitmap(ImageProxy imageProxy) {

    if (imageProxy.getFormat() != ImageFormat.YUV_420_888) {

      throw new IllegalArgumentException(

          "Expected a YUV420 image, but got " + imageProxy.getFormat());

    }

 

    PlaneProxy yPlane = imageProxy.getPlanes()[0];

    PlaneProxy uPlane = imageProxy.getPlanes()[1];

 

    int width = imageProxy.getWidth();

    int height = imageProxy.getHeight();

 

    byte[][] yuvBytes = new byte[3][];

    int[] argbArray = new int[width * height];

    for (int i = 0; i < imageProxy.getPlanes().length; i++) {

      final ByteBuffer buffer = imageProxy.getPlanes()[i].getBuffer();

      yuvBytes[i] = new byte[buffer.capacity()];

      buffer.get(yuvBytes[i]);

    }

 

    ImageUtils.convertYUV420ToARGB8888(

        yuvBytes[0],

        yuvBytes[1],

        yuvBytes[2],

        width,

        height,

        yPlane.getRowStride(),

        uPlane.getRowStride(),

        uPlane.getPixelStride(),

        argbArray);

 

    return Bitmap.createBitmap(argbArray, width, height, Config.ARGB_8888);

  }

 

/**

*将相机图像规格化为[0;1],将其剪切

*调整模型所需的大小并调整相机旋转。

*/

  private static float[] prepareCameraImage(Bitmap bitmap, int rotationDegrees)  {

    int modelImageSize = TransferLearningModelWrapper.IMAGE_SIZE;

 

    Bitmap paddedBitmap = padToSquare(bitmap);

    Bitmap scaledBitmap = Bitmap.createScaledBitmap(

        paddedBitmap, modelImageSize, modelImageSize, true);

 

    Matrix rotationMatrix = new Matrix();

    rotationMatrix.postRotate(rotationDegrees);

    Bitmap rotatedBitmap = Bitmap.createBitmap(

        scaledBitmap, 0, 0, modelImageSize, modelImageSize, rotationMatrix, false);

 

    float[] normalizedRgb = new float[modelImageSize * modelImageSize * 3];

    int nextIdx = 0;

    for (int y = 0; y < modelImageSize; y++) {

      for (int x = 0; x < modelImageSize; x++) {

        int rgb = rotatedBitmap.getPixel(x, y);

 

        float r = ((rgb >> 16) & LOWER_BYTE_MASK) * (1 / 255.f);

        float g = ((rgb >> 8) & LOWER_BYTE_MASK) * (1 / 255.f);

        float b = (rgb & LOWER_BYTE_MASK) * (1 / 255.f);

 

        normalizedRgb[nextIdx++] = r;

        normalizedRgb[nextIdx++] = g;

        normalizedRgb[nextIdx++] = b;

      }

    }

 

    return normalizedRgb;

  }

  //平铺到广角

  private static Bitmap padToSquare(Bitmap source) {

    int width = source.getWidth();

    int height = source.getHeight();

 

    int paddingX = width < height ? (height - width) / 2 : 0;

    int paddingY = height < width ? (width - height) / 2 : 0;

    Bitmap paddedBitmap = Bitmap.createBitmap(

        width + 2 * paddingX, height + 2 * paddingY, Config.ARGB_8888);

    Canvas canvas = new Canvas(paddedBitmap);

    canvas.drawARGB(0xFF, 0xFF, 0xFF, 0xFF);

    canvas.drawBitmap(source, paddingX, paddingY, null);

    return paddedBitmap;

  }

 

//绑定适配器:

 

  @BindingAdapter({"captureMode", "inferenceText", "captureText"})

  public static void setClassSubtitleText(

      TextView view, boolean captureMode, Float inferenceText, Integer captureText) {

    if (captureMode) {

      view.setText(captureText != null ? Integer.toString(captureText) : "0");

    } else {

      view.setText(

          String.format(Locale.getDefault(), "%.2f", inferenceText != null ? inferenceText : 0.f));

    }

  }

 

  @BindingAdapter({"android:visibility"})

  public static void setViewVisibility(View view, boolean visible) {

    view.setVisibility(visible ? View.VISIBLE : View.GONE);

  }

 

  @BindingAdapter({"highlight"})

  public static void setClassButtonHighlight(View view, boolean highlight) {

    int drawableId;

    if (highlight) {

      drawableId = R.drawable.btn_default_highlight;

    } else {

      drawableId = R.drawable.btn_default;

    }

    view.setBackground(view.getContext().getDrawable(drawableId));

  }

}

2.代码实现界面

TFL 20个训练


3.小结

首先定义了 Android 系统自带的相机,通过拍照获取 20 个训练样例,分别对应 4 个类型的训练分类:三角形、○、× 和正方形。然后通过 TFL 识别出这 4 个类。

运行非常卡。看来需要扩充内存。



小白简单介绍一下物品识别TFL的使用的评论 (共 条)

分享到微博请遵守国家法律