欢迎光临散文网 会员登陆 & 注册

大成带你用TFL数字分类识别

2020-05-31 13:42 作者:嗨大成  | 我要投稿

1.TFL数字分类识别

在Android手机屏幕上手写0~9的数字,程序把手写内容生成图片,再通过机器学习模型mnist.tflite对图片进行数字识别;识别率有时可以达到99%。模型文件mnist.tflite由TensorFlow Lite的Interpreter加载。

2.源码

package org.tensorflow.lite.examples.digitclassifier

import android.content.Context
import android.content.res.AssetManager
import android.graphics.Bitmap
import android.util.Log
import com.google.android.gms.tasks.Task
import com.google.android.gms.tasks.Tasks.call
import java.io.FileInputStream
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.FileChannel
import java.util.concurrent.Callable
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import org.tensorflow.lite.Interpreter

/**
 * Classifies hand-drawn digits with the TensorFlow Lite model stored in the app's assets
 * ([MODEL_FILE]).
 *
 * Model loading ([initialize]), inference ([classifyAsync]) and shutdown ([close]) are all
 * submitted to a private background executor; the first two are exposed as Play Services [Task]s.
 *
 * NOTE(review): the executor is a cached thread pool, so submitted jobs may run on different
 * threads and are not serialized against each other (e.g. [close] vs. an in-flight inference).
 *
 * @param context used only to obtain the [AssetManager] that holds the model file.
 */
class DigitClassifier(private val context: Context) {
  private var interpreter: Interpreter? = null

  /**
   * `true` once the interpreter has been created and the input size has been read from the model;
   * reset to `false` by [close].
   *
   * Marked volatile because it is written on an executor thread while callers read it from
   * other threads before deciding whether to call [classifyAsync].
   */
  @Volatile
  var isInitialized = false
    private set

  /** Executor to run inference task in the background */
  private val executorService: ExecutorService = Executors.newCachedThreadPool()

  private var inputImageWidth: Int = 0 // will be inferred from TF Lite model
  private var inputImageHeight: Int = 0 // will be inferred from TF Lite model
  private var modelInputSize: Int = 0 // will be inferred from TF Lite model

  /**
   * Loads the model and creates the interpreter on the background executor.
   *
   * @return a task that completes when initialization has finished, or fails with the exception
   *   thrown while loading the model or creating the interpreter.
   */
  fun initialize(): Task<Void> {
    return call(
      executorService,
      Callable<Void> {
        initializeInterpreter()
        null
      }
    )
  }

  /**
   * Creates the interpreter from the bundled model and derives the input buffer size from the
   * shape of input tensor 0.
   *
   * @throws IOException if the model asset cannot be opened or mapped.
   * @throws IllegalStateException if the input tensor has fewer than three dimensions.
   */
  @Throws(IOException::class)
  private fun initializeInterpreter() {
    // Load the TF Lite model
    val model = loadModelFile(context.assets)

    // Initialize TF Lite Interpreter with NNAPI enabled
    val options = Interpreter.Options()
    options.setUseNNAPI(true)
    val interpreter = Interpreter(model, options)

    // Read input shape from model file
    val inputShape = interpreter.getInputTensor(0).shape()
    if (inputShape.size < 3) {
      // Do not leak the native interpreter when the model turns out to be unusable.
      interpreter.close()
      throw IllegalStateException(
        "Unexpected input tensor shape: ${inputShape.contentToString()}"
      )
    }
    // NOTE(review): image tensors are commonly laid out as [batch, height, width, ...], in which
    // case index 1 would be the height. The original assignment is kept; it only matters for
    // non-square inputs - confirm against the model.
    inputImageWidth = inputShape[1]
    inputImageHeight = inputShape[2]
    modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE

    // Finish interpreter initialization. The flag is written last so that a reader that observes
    // `isInitialized == true` also observes the fields assigned above.
    this.interpreter = interpreter
    isInitialized = true
    Log.d(TAG, "初始化TFLite推断人.")
  }

  /**
   * Memory-maps [MODEL_FILE] from the assets.
   *
   * The asset descriptor and the stream are closed before returning; a mapping stays valid after
   * the channel that created it has been closed.
   *
   * @throws IOException if the asset cannot be opened or mapped.
   */
  @Throws(IOException::class)
  private fun loadModelFile(assetManager: AssetManager): ByteBuffer {
    return assetManager.openFd(MODEL_FILE).use { fileDescriptor ->
      FileInputStream(fileDescriptor.fileDescriptor).use { inputStream ->
        inputStream.channel.map(
          FileChannel.MapMode.READ_ONLY,
          fileDescriptor.startOffset,
          fileDescriptor.declaredLength
        )
      }
    }
  }

  /**
   * Scales [bitmap] to the model's input size, runs inference and formats the best class.
   *
   * @return the text produced by [getOutputString] for the model output.
   * @throws IllegalStateException if the classifier has not been initialized (or was closed).
   */
  private fun classify(bitmap: Bitmap): String {
    check(isInitialized) { "TF Lite Interpreter is not initialized yet." }
    // Fail loudly instead of silently skipping inference and reporting an all-zero result.
    val tflite = checkNotNull(interpreter) { "TF Lite Interpreter is not initialized yet." }

    // Preprocessing: resize the input
    val preprocessingStart = System.nanoTime()
    val resizedImage = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true)
    val byteBuffer = convertBitmapToByteBuffer(resizedImage)
    val preprocessingMs = (System.nanoTime() - preprocessingStart) / NANOS_PER_MILLI
    Log.d(TAG, "预处理时间 = ${preprocessingMs}ms")

    val inferenceStart = System.nanoTime()
    val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) }
    tflite.run(byteBuffer, result)
    val inferenceMs = (System.nanoTime() - inferenceStart) / NANOS_PER_MILLI
    Log.d(TAG, "推断的时间 = ${inferenceMs}ms")

    return getOutputString(result[0])
  }

  /**
   * Runs classification of [bitmap] on the background executor.
   *
   * @return a task holding the formatted prediction, or failing with the exception thrown during
   *   preprocessing or inference.
   */
  fun classifyAsync(bitmap: Bitmap): Task<String> {
    return call(executorService, Callable<String> { classify(bitmap) })
  }

  /**
   * Schedules release of the interpreter on the background executor.
   *
   * Afterwards [isInitialized] is `false` and the interpreter reference is dropped, so later
   * classification attempts fail with a clear "not initialized" error.
   */
  fun close() {
    call(
      executorService,
      Callable<Void> {
        isInitialized = false
        interpreter?.close()
        interpreter = null
        Log.d(TAG, "关闭TF推断人.")
        null
      }
    )
  }

  /**
   * Converts [bitmap] into the float input buffer for the model: one normalized grayscale value
   * per pixel, in the order returned by [Bitmap.getPixels], using native byte order.
   *
   * The bitmap is expected to already have the model's input dimensions; [classify] scales it
   * before calling this function.
   */
  private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
    val byteBuffer = ByteBuffer.allocateDirect(modelInputSize)
    byteBuffer.order(ByteOrder.nativeOrder())

    val pixels = IntArray(inputImageWidth * inputImageHeight)
    bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)

    for (pixelValue in pixels) {
      val r = (pixelValue shr 16 and 0xFF)
      val g = (pixelValue shr 8 and 0xFF)
      val b = (pixelValue and 0xFF)

      // Convert RGB to grayscale and normalize pixel value to [0..1]
      val normalizedPixelValue = (r + g + b) / 3.0f / 255.0f
      byteBuffer.putFloat(normalizedPixelValue)
    }

    return byteBuffer
  }

  /**
   * Formats the index of the highest score in [output] together with that score.
   *
   * On ties the lowest index wins.
   *
   * @throws IllegalArgumentException if [output] is empty.
   */
  private fun getOutputString(output: FloatArray): String {
    require(output.isNotEmpty()) { "Model output is empty." }

    // Plain argmax loop: avoids indexing with a -1 fallback and does not depend on the
    // `maxBy` signature, which differs between Kotlin versions.
    var maxIndex = 0
    for (index in 1 until output.size) {
      if (output[index] > output[maxIndex]) {
        maxIndex = index
      }
    }

    // NOTE(review): "%2f" sets a minimum width of 2, not two decimal places; "%.2f" may have
    // been intended. Left unchanged because the text is user-facing.
    return "推断结果: %d\nConfidence: %2f".format(maxIndex, output[maxIndex])
  }

  companion object {
    private const val TAG = "DigitClassifier"
    private const val MODEL_FILE = "mnist.tflite"
    private const val FLOAT_TYPE_SIZE = 4
    private const val PIXEL_SIZE = 1
    private const val OUTPUT_CLASSES_COUNT = 10
    private const val NANOS_PER_MILLI = 1_000_000
  }
}

这个类的主要任务是创建并加载数字分类模型:通过调用mnist.tflite模型对手写数字进行预测,模型会给出10个类别(0~9)的预测值,取其中最高的一个作为识别结果。

3.在手机上提供一个手写输入面板,输入0-9的数字,进行识别

package org.tensorflow.lite.examples.digitclassifier

import android.annotation.SuppressLint
import android.graphics.Color
import android.os.Bundle
import androidx.appcompat.app.AppCompatActivity
import android.util.Log
import android.view.MotionEvent
import android.widget.Button
import android.widget.TextView
import com.divyanshu.draw.widget.DrawView

/**
 * Screen with a drawing canvas, a clear button and a text view for the prediction.
 *
 * Every time the user lifts the finger from the canvas, the current drawing is handed to
 * [DigitClassifier] and the returned text (or an error message) is shown in the text view.
 */
class MainActivity : AppCompatActivity() {

  private var drawView: DrawView? = null
  private var clearButton: Button? = null
  private var predictedTextView: TextView? = null

  // Never reassigned, so a read-only property is sufficient.
  private val digitClassifier = DigitClassifier(this)

  /** Inflates the layout, wires up the views and starts loading the classifier. */
  @SuppressLint("ClickableViewAccessibility")
  override fun onCreate(savedInstanceState: Bundle?) {
    super.onCreate(savedInstanceState)
    setContentView(R.layout.tfe_dc_activity_main)

    // Setup view instances: white strokes on a black canvas
    drawView = findViewById(R.id.draw_view)
    drawView?.apply {
      setStrokeWidth(50.0f)
      setColor(Color.WHITE)
      setBackgroundColor(Color.BLACK)
    }
    clearButton = findViewById(R.id.clear_button)
    predictedTextView = findViewById(R.id.predicted_text)

    // Setup clear drawing button: wipe the canvas and restore the placeholder text
    clearButton?.setOnClickListener {
      drawView?.clearCanvas()
      predictedTextView?.text = getString(R.string.tfe_dc_prediction_text_placeholder)
    }

    // Setup classification trigger so that it classifies after every stroke drawn
    drawView?.setOnTouchListener { view, event ->
      // The listener intercepts the touch, so forward the event to the view itself first;
      // otherwise the stroke would not be drawn.
      view.onTouchEvent(event)

      // Then if user finished a touch event, run classification
      if (event.action == MotionEvent.ACTION_UP) {
        classifyDrawing()
      }
      true
    }

    // Setup digit classifier; a failure is only logged
    digitClassifier
      .initialize()
      .addOnFailureListener { e -> Log.e(TAG, "Error to setting up digit classifier.", e) }
  }

  /** Asks the classifier to release its resources before the activity is destroyed. */
  override fun onDestroy() {
    digitClassifier.close()
    super.onDestroy()
  }

  /**
   * Classifies the current drawing and shows the outcome in the prediction text view.
   *
   * Does nothing when there is no bitmap to classify or the classifier is not initialized yet.
   */
  private fun classifyDrawing() {
    val bitmap = drawView?.getBitmap() ?: return
    if (!digitClassifier.isInitialized) return

    digitClassifier
      .classifyAsync(bitmap)
      .addOnSuccessListener { resultText -> predictedTextView?.text = resultText }
      .addOnFailureListener { e ->
        predictedTextView?.text = getString(
          R.string.tfe_dc_classification_error_message,
          e.localizedMessage
        )
        Log.e(TAG, "Error classifying drawing.", e)
      }
  }

  companion object {
    private const val TAG = "MainActivity"
  }
}

4.运行效果:




大成带你用TFL数字分类识别的评论 (共 条)

分享到微博请遵守国家法律