小白简单介绍一下物品识别TFL的使用
1.小白简单介绍一下物品识别TFL的使用
调用手机系统的相机功能(通过 CameraX)拍摄照片,下面是实现分类器的主要片段。
package org.tensorflow.lite.examples.transfer;
//导入依赖库
import org.tensorflow.lite.examples.transfer.api.TransferLearningModel.Prediction;
import org.tensorflow.lite.examples.transfer.databinding.CameraFragmentBinding;
/**
*分类器的主要片段。相机功能(通过CameraX)
*/
//CameraFragment 继承Fragment的类
/**
 * Main fragment of the classifier. Hosts the camera preview (via CameraX),
 * collects training samples and runs inference on a dedicated background
 * thread, and wires the UI controls (per-class "add sample" buttons,
 * capture/inference mode chips) to {@link CameraFragmentViewModel}.
 */
public class CameraFragment extends Fragment {

  /** Mask used to extract a single 8-bit colour channel from a packed ARGB pixel. */
  private static final int LOWER_BYTE_MASK = 0xFF;

  private static final String TAG = CameraFragment.class.getSimpleName();

  /** Always use the back-facing camera. */
  private static final LensFacing LENS_FACING = LensFacing.BACK;

  private TextureView viewFinder;

  private Integer viewFinderRotation = null;

  private Size bufferDimens = new Size(0, 0);
  private Size viewFinderDimens = new Size(0, 0);

  private CameraFragmentViewModel viewModel;

  /** Wrapper around the transfer-learning TFLite model. */
  private TransferLearningModelWrapper tlModel;

  // When the user presses the "add sample" button for some class, the class
  // name is enqueued here. The queue is later drained and processed by the
  // inference thread.
  private final ConcurrentLinkedQueue<String> addSampleRequests = new ConcurrentLinkedQueue<>();

  private final LoggingBenchmark inferenceBenchmark = new LoggingBenchmark("InferenceBench");

  /**
   * Sets up a responsive preview for the view finder and binds the preview
   * and image-analysis use cases to this fragment's lifecycle.
   */
  private void startCamera() {
    viewFinderRotation = getDisplaySurfaceRotation(viewFinder.getDisplay());
    if (viewFinderRotation == null) {
      viewFinderRotation = 0;
    }

    DisplayMetrics metrics = new DisplayMetrics();
    viewFinder.getDisplay().getRealMetrics(metrics);
    Rational screenAspectRatio = new Rational(metrics.widthPixels, metrics.heightPixels);

    PreviewConfig config = new PreviewConfig.Builder()
        .setLensFacing(LENS_FACING)
        .setTargetAspectRatio(screenAspectRatio)
        .setTargetRotation(viewFinder.getDisplay().getRotation())
        .build();

    Preview preview = new Preview(config);
    preview.setOnPreviewOutputUpdateListener(previewOutput -> {
      // Re-attach the view finder so the new SurfaceTexture takes effect.
      ViewGroup parent = (ViewGroup) viewFinder.getParent();
      parent.removeView(viewFinder);
      parent.addView(viewFinder, 0);

      viewFinder.setSurfaceTexture(previewOutput.getSurfaceTexture());

      Integer rotation = getDisplaySurfaceRotation(viewFinder.getDisplay());
      updateTransform(rotation, previewOutput.getTextureSize(), viewFinderDimens);
    });

    viewFinder.addOnLayoutChangeListener((
        view, left, top, right, bottom, oldLeft, oldTop, oldRight, oldBottom) -> {
      Size newViewFinderDimens = new Size(right - left, bottom - top);
      Integer rotation = getDisplaySurfaceRotation(viewFinder.getDisplay());
      updateTransform(rotation, bufferDimens, newViewFinderDimens);
    });

    // Run frame analysis (sample collection / inference) off the main thread.
    HandlerThread inferenceThread = new HandlerThread("InferenceThread");
    inferenceThread.start();
    ImageAnalysisConfig analysisConfig = new ImageAnalysisConfig.Builder()
        .setLensFacing(LENS_FACING)
        .setCallbackHandler(new Handler(inferenceThread.getLooper()))
        .setImageReaderMode(ImageReaderMode.ACQUIRE_LATEST_IMAGE)
        .setTargetRotation(viewFinder.getDisplay().getRotation())
        .build();
    ImageAnalysis imageAnalysis = new ImageAnalysis(analysisConfig);
    imageAnalysis.setAnalyzer(inferenceAnalyzer);

    CameraX.bindToLifecycle(this, preview, imageAnalysis);
  }

  /**
   * Analyzer invoked for every camera frame: either adds the frame to the
   * model as a training sample (when an add-sample request is pending) or
   * runs inference on it and publishes the confidences to the view model.
   */
  private final ImageAnalysis.Analyzer inferenceAnalyzer =
      (imageProxy, rotationDegrees) -> {
        final String imageId = UUID.randomUUID().toString();

        inferenceBenchmark.startStage(imageId, "preprocess");
        float[] rgbImage = prepareCameraImage(yuvCameraImageToBitmap(imageProxy), rotationDegrees);
        inferenceBenchmark.endStage(imageId, "preprocess");

        // Adding samples is also handled by the inference thread / use case.
        // CameraX ImageCapture is not used here because of its high latency
        // (~650 ms on Pixel 2 XL), even with MIN_LATENCY.
        String sampleClass = addSampleRequests.poll();
        if (sampleClass != null) {
          inferenceBenchmark.startStage(imageId, "addSample");
          try {
            tlModel.addSample(rgbImage, sampleClass).get();
          } catch (ExecutionException e) {
            throw new RuntimeException("Failed to add sample to model", e.getCause());
          } catch (InterruptedException e) {
            // Restore the interrupt status instead of swallowing it, so the
            // inference thread can still observe the interruption.
            Thread.currentThread().interrupt();
          }

          viewModel.increaseNumSamples(sampleClass);
          inferenceBenchmark.endStage(imageId, "addSample");
        } else {
          // No inference while samples are being added: the app should be in
          // capture mode then, so the results would not be displayed anyway.
          inferenceBenchmark.startStage(imageId, "predict");
          Prediction[] predictions = tlModel.predict(rgbImage);
          if (predictions == null) {
            return;
          }
          inferenceBenchmark.endStage(imageId, "predict");

          for (Prediction prediction : predictions) {
            viewModel.setConfidence(prediction.getClassName(), prediction.getConfidence());
          }
        }

        inferenceBenchmark.finish(imageId);
      };

  /** Maps the four "add sample" buttons to their class names ("1".."4"). */
  public final View.OnClickListener onAddSampleClickListener = view -> {
    String className;
    if (view.getId() == R.id.class_btn_1) {
      className = "1";
    } else if (view.getId() == R.id.class_btn_2) {
      className = "2";
    } else if (view.getId() == R.id.class_btn_3) {
      className = "3";
    } else if (view.getId() == R.id.class_btn_4) {
      className = "4";
    } else {
      throw new RuntimeException("Listener called for unexpected view");
    }

    addSampleRequests.add(className);
  };

  /**
   * Fits the camera preview into the view finder.
   *
   * @param rotation view finder rotation.
   * @param newBufferDimens camera preview dimensions.
   * @param newViewFinderDimens view finder dimensions.
   */
  private void updateTransform(Integer rotation, Size newBufferDimens, Size newViewFinderDimens) {
    if (Objects.equals(rotation, viewFinderRotation)
        && Objects.equals(newBufferDimens, bufferDimens)
        && Objects.equals(newViewFinderDimens, viewFinderDimens)) {
      // Nothing changed; skip recomputing the transform.
      return;
    }

    if (rotation == null) {
      // Invalid rotation; wait for a valid value before applying a matrix.
      return;
    } else {
      viewFinderRotation = rotation;
    }

    if (newBufferDimens.getWidth() == 0 || newBufferDimens.getHeight() == 0) {
      return;
    } else {
      bufferDimens = newBufferDimens;
    }

    if (newViewFinderDimens.getWidth() == 0 || newViewFinderDimens.getHeight() == 0) {
      return;
    } else {
      viewFinderDimens = newViewFinderDimens;
    }

    Log.d(TAG, String.format("Applying output transformation.\n"
        + "View finder size: %s.\n"
        + "Preview output size: %s\n"
        + "View finder rotation: %s\n", viewFinderDimens, bufferDimens, viewFinderRotation));
    Matrix matrix = new Matrix();

    float centerX = viewFinderDimens.getWidth() / 2f;
    float centerY = viewFinderDimens.getHeight() / 2f;

    // Undo the display rotation, then scale preserving the buffer's aspect
    // ratio so the preview fills the view finder.
    matrix.postRotate(-viewFinderRotation.floatValue(), centerX, centerY);

    float bufferRatio = bufferDimens.getHeight() / (float) bufferDimens.getWidth();

    int scaledWidth;
    int scaledHeight;
    if (viewFinderDimens.getWidth() > viewFinderDimens.getHeight()) {
      scaledHeight = viewFinderDimens.getWidth();
      scaledWidth = Math.round(viewFinderDimens.getWidth() * bufferRatio);
    } else {
      scaledHeight = viewFinderDimens.getHeight();
      scaledWidth = Math.round(viewFinderDimens.getHeight() * bufferRatio);
    }

    float xScale = scaledWidth / (float) viewFinderDimens.getWidth();
    float yScale = scaledHeight / (float) viewFinderDimens.getHeight();

    matrix.preScale(xScale, yScale, centerX, centerY);

    viewFinder.setTransform(matrix);
  }

  /** Creates the transfer-learning model wrapper and the view model. */
  @Override
  public void onCreate(Bundle bundle) {
    super.onCreate(bundle);
    tlModel = new TransferLearningModelWrapper(getActivity());
    viewModel = ViewModelProviders.of(this).get(CameraFragmentViewModel.class);
    viewModel.setTrainBatchSize(tlModel.getTrainBatchSize());
  }

  @Override
  public View onCreateView(LayoutInflater inflater, ViewGroup container, Bundle bundle) {
    CameraFragmentBinding dataBinding =
        DataBindingUtil.inflate(inflater, R.layout.camera_fragment, container, false);
    dataBinding.setLifecycleOwner(getViewLifecycleOwner());
    dataBinding.setVm(viewModel);
    View rootView = dataBinding.getRoot();

    for (int buttonId : new int[] {
        R.id.class_btn_1, R.id.class_btn_2, R.id.class_btn_3, R.id.class_btn_4}) {
      rootView.findViewById(buttonId).setOnClickListener(onAddSampleClickListener);
    }

    ChipGroup chipGroup = (ChipGroup) rootView.findViewById(R.id.mode_chip_group);
    // Boolean.TRUE.equals guards against a null LiveData value, which would
    // otherwise throw an NPE on auto-unboxing.
    if (Boolean.TRUE.equals(viewModel.getCaptureMode().getValue())) {
      ((Chip) rootView.findViewById(R.id.capture_mode_chip)).setChecked(true);
    } else {
      ((Chip) rootView.findViewById(R.id.inference_mode_chip)).setChecked(true);
    }
    chipGroup.setOnCheckedChangeListener((group, checkedId) -> {
      if (checkedId == R.id.capture_mode_chip) {
        viewModel.setCaptureMode(true);
      } else if (checkedId == R.id.inference_mode_chip) {
        viewModel.setCaptureMode(false);
      }
    });

    return rootView;
  }

  @Override
  public void onViewCreated(View view, Bundle bundle) {
    super.onViewCreated(view, bundle);
    viewFinder = getActivity().findViewById(R.id.view_finder);
    // Start the camera only after the view finder has been laid out.
    viewFinder.post(this::startCamera);
  }

  /** Observes the training state to enable/disable model training. */
  @Override
  public void onActivityCreated(Bundle bundle) {
    super.onActivityCreated(bundle);
    viewModel
        .getTrainingState()
        .observe(
            getViewLifecycleOwner(),
            trainingState -> {
              switch (trainingState) {
                case STARTED:
                  tlModel.enableTraining((epoch, loss) -> viewModel.setLastLoss(loss));
                  // Show the "switch to inference" hint once per session.
                  if (!viewModel.getInferenceSnackbarWasDisplayed().getValue()) {
                    Snackbar.make(
                        getActivity().findViewById(R.id.classes_bar),
                        R.string.switch_to_inference_hint,
                        Snackbar.LENGTH_LONG)
                        .show();
                    viewModel.markInferenceSnackbarWasCalled();
                  }
                  break;
                case PAUSED:
                  tlModel.disableTraining();
                  break;
                case NOT_STARTED:
                  break;
              }
            });
  }

  /** Releases the model's native resources. */
  @Override
  public void onDestroy() {
    super.onDestroy();
    tlModel.close();
    tlModel = null;
  }

  /**
   * Returns the display's surface rotation in degrees (0/90/180/270),
   * or null if the display is unavailable or reports an unknown rotation.
   */
  private static Integer getDisplaySurfaceRotation(Display display) {
    if (display == null) {
      return null;
    }

    switch (display.getRotation()) {
      case Surface.ROTATION_0: return 0;
      case Surface.ROTATION_90: return 90;
      case Surface.ROTATION_180: return 180;
      case Surface.ROTATION_270: return 270;
      default: return null;
    }
  }

  /**
   * Converts a YUV_420_888 camera frame to an ARGB_8888 {@link Bitmap}.
   *
   * @throws IllegalArgumentException if the frame is not in YUV420 format.
   */
  private static Bitmap yuvCameraImageToBitmap(ImageProxy imageProxy) {
    if (imageProxy.getFormat() != ImageFormat.YUV_420_888) {
      throw new IllegalArgumentException(
          "Expected a YUV420 image, but got " + imageProxy.getFormat());
    }

    PlaneProxy yPlane = imageProxy.getPlanes()[0];
    PlaneProxy uPlane = imageProxy.getPlanes()[1];

    int width = imageProxy.getWidth();
    int height = imageProxy.getHeight();

    byte[][] yuvBytes = new byte[3][];
    int[] argbArray = new int[width * height];
    for (int i = 0; i < imageProxy.getPlanes().length; i++) {
      final ByteBuffer buffer = imageProxy.getPlanes()[i].getBuffer();
      yuvBytes[i] = new byte[buffer.capacity()];
      buffer.get(yuvBytes[i]);
    }

    ImageUtils.convertYUV420ToARGB8888(
        yuvBytes[0],
        yuvBytes[1],
        yuvBytes[2],
        width,
        height,
        yPlane.getRowStride(),
        uPlane.getRowStride(),
        uPlane.getPixelStride(),
        argbArray);

    return Bitmap.createBitmap(argbArray, width, height, Config.ARGB_8888);
  }

  /**
   * Normalizes a camera image to [0; 1], pads/scales it to the size expected
   * by the model, and adjusts for camera rotation. Returns interleaved RGB
   * floats of length IMAGE_SIZE * IMAGE_SIZE * 3.
   */
  private static float[] prepareCameraImage(Bitmap bitmap, int rotationDegrees) {
    int modelImageSize = TransferLearningModelWrapper.IMAGE_SIZE;

    Bitmap paddedBitmap = padToSquare(bitmap);
    Bitmap scaledBitmap = Bitmap.createScaledBitmap(
        paddedBitmap, modelImageSize, modelImageSize, true);

    Matrix rotationMatrix = new Matrix();
    rotationMatrix.postRotate(rotationDegrees);
    Bitmap rotatedBitmap = Bitmap.createBitmap(
        scaledBitmap, 0, 0, modelImageSize, modelImageSize, rotationMatrix, false);

    float[] normalizedRgb = new float[modelImageSize * modelImageSize * 3];
    int nextIdx = 0;
    for (int y = 0; y < modelImageSize; y++) {
      for (int x = 0; x < modelImageSize; x++) {
        int rgb = rotatedBitmap.getPixel(x, y);

        // Extract each channel and normalize from [0; 255] to [0; 1].
        float r = ((rgb >> 16) & LOWER_BYTE_MASK) * (1 / 255.f);
        float g = ((rgb >> 8) & LOWER_BYTE_MASK) * (1 / 255.f);
        float b = (rgb & LOWER_BYTE_MASK) * (1 / 255.f);

        normalizedRgb[nextIdx++] = r;
        normalizedRgb[nextIdx++] = g;
        normalizedRgb[nextIdx++] = b;
      }
    }

    return normalizedRgb;
  }

  /** Pads the bitmap with white on the shorter sides to make it square. */
  private static Bitmap padToSquare(Bitmap source) {
    int width = source.getWidth();
    int height = source.getHeight();

    int paddingX = width < height ? (height - width) / 2 : 0;
    int paddingY = height < width ? (width - height) / 2 : 0;
    Bitmap paddedBitmap = Bitmap.createBitmap(
        width + 2 * paddingX, height + 2 * paddingY, Config.ARGB_8888);
    Canvas canvas = new Canvas(paddedBitmap);
    canvas.drawARGB(0xFF, 0xFF, 0xFF, 0xFF);
    canvas.drawBitmap(source, paddingX, paddingY, null);
    return paddedBitmap;
  }

  // Binding adapters:

  /** Shows the sample count in capture mode, or the confidence otherwise. */
  @BindingAdapter({"captureMode", "inferenceText", "captureText"})
  public static void setClassSubtitleText(
      TextView view, boolean captureMode, Float inferenceText, Integer captureText) {
    if (captureMode) {
      view.setText(captureText != null ? Integer.toString(captureText) : "0");
    } else {
      view.setText(
          String.format(Locale.getDefault(), "%.2f", inferenceText != null ? inferenceText : 0.f));
    }
  }

  @BindingAdapter({"android:visibility"})
  public static void setViewVisibility(View view, boolean visible) {
    view.setVisibility(visible ? View.VISIBLE : View.GONE);
  }

  /** Highlights the class button whose samples are currently being added. */
  @BindingAdapter({"highlight"})
  public static void setClassButtonHighlight(View view, boolean highlight) {
    int drawableId;
    if (highlight) {
      drawableId = R.drawable.btn_default_highlight;
    } else {
      drawableId = R.drawable.btn_default;
    }
    view.setBackground(view.getContext().getDrawable(drawableId));
  }
}
2.代码实现界面

3.小结
首先调用 Android 系统自带的相机,通过拍照为每个类别采集训练样例(共 20 个),对应三角、○、×和正方形 4 个分类;训练后通过 TFL 即可识别这 4 个类。
实际运行非常卡顿,看来需要设备具备更大的内存。