添加链接
link管理
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
  • TensorFlow Datasets 数据集载入
  • Swift for TensorFlow (S4TF) (Huan)
  • TensorFlow Quantum: 混合量子-经典机器学习 *
  • 强化学习简介
  • 使用Docker部署TensorFlow环境
  • 在云端使用TensorFlow
  • 部署自己的交互式Python开发环境JupyterLab
  • 参考资料与推荐阅读
  • 术语中英对照表
  • TensorFlow概述
  • TensorFlow 安裝與環境配置
  • TensorFlow 基礎
  • TensorFlow 模型建立與訓練
  • TensorFlow常用模組
  • TensorFlow模型匯出
  • TensorFlow Serving
  • TensorFlow Lite(Jinpeng)
  • TensorFlow in JavaScript(Huan)
  • 大規模訓練與加速

  • TensorFlow分布式訓練
  • 使用TPU訓練TensorFlow模型(Huan)
  • TensorFlow Hub 模型複用(Jinpeng)
  • TensorFlow Datasets 資料集載入
  • Swift for TensorFlow (S4TF) (Huan)
  • TensorFlow Quantum: 混合量子-經典機器學習 *
  • 強化學習簡介
  • 使用Docker部署TensorFlow環境
  • 在雲端使用TensorFlow
  • 部署自己的互動式 Python 開發環境 JupyterLab
  • 參考資料與推薦閱讀
  • 專有名詞中英對照表
  • Preface

  • Preface
  • TensorFlow Overview
  • Basic

  • Installation and Environment Configuration
  • TensorFlow Basic
  • Model Construction and Training
  • Common Modules in TensorFlow
  • Deployment

  • TensorFlow Model Export
  • TensorFlow Serving
  • Large-scale Training

  • Distributed training with TensorFlow
  • Extensions

  • TensorFlow Datasets: Ready-to-use Datasets
  • TensorFlow Quantum: Hybrid Quantum-classical Machine Learning *
  • Converting models

    由于移动设备空间和计算能力受限,使用TensorFlow训练好的模型,模型太大、运行效率比较低,不能直接在移动端部署。

    故在移动端部署的时候,需要使用 tflight_convert 转化格式,其在通过pip安装TensorFlow时一起安装。 tflight_convert 会把原模型转换为FlatBuffer格式。

    在终端执行如下命令:

    tflight_convert -h
    

    输出结果如下,即该命令的使用方法:

    usage: tflite_convert [-h] --output_file OUTPUT_FILE
                          (--graph_def_file GRAPH_DEF_FILE | --saved_model_dir SAVED_MODEL_DIR | --keras_model_file KERAS_MODEL_FILE)
                          [--output_format {TFLITE,GRAPHVIZ_DOT}]
                          [--inference_type {FLOAT,QUANTIZED_UINT8}]
                          [--inference_input_type {FLOAT,QUANTIZED_UINT8}]
                          [--input_arrays INPUT_ARRAYS]
                          [--input_shapes INPUT_SHAPES]
                          [--output_arrays OUTPUT_ARRAYS]
                          [--saved_model_tag_set SAVED_MODEL_TAG_SET]
                          [--saved_model_signature_key SAVED_MODEL_SIGNATURE_KEY]
                          [--std_dev_values STD_DEV_VALUES]
                          [--mean_values MEAN_VALUES]
                          [--default_ranges_min DEFAULT_RANGES_MIN]
                          [--default_ranges_max DEFAULT_RANGES_MAX]
                          [--post_training_quantize] [--drop_control_dependency]
                          [--reorder_across_fake_quant]
                          [--change_concat_input_ranges {TRUE,FALSE}]
                          [--allow_custom_ops] [--target_ops TARGET_OPS]
                          [--dump_graphviz_dir DUMP_GRAPHVIZ_DIR]
                          [--dump_graphviz_video]
    

    模型的导出:Keras Sequential save方法中产生的模型文件,可以使用如下命令处理:

    tflite_convert --keras_model_file=./mnist_cnn.h5 --output_file=./mnist_cnn.tflite
    

    到此,我们已经得到一个可以运行的TensorFlow Lite模型了,即 mnist_cnn.tflite

    这里只介绍了keras HDF5格式模型的转换,其他模型转换建议参考:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/convert/cmdline_examples.md

    Converting models with Quantization

    还有一种quantization的转化方法,这种转化命令如下:

    tflite_convert \
      --output_file=keras_mnist_quantized_uint8.tflite \
      --keras_model_file=mnist_cnn.h5 \
      --inference_type=QUANTIZED_UINT8 \
      --mean_values=128 \
      --std_dev_values=127 \
      --default_ranges_min=0 \
      --default_ranges_max=255 \
      --input_arrays=conv2d_1_input \
      --output_arrays=dense_2/Softmax
    

    细心的读者肯定会问,上图中有很多参数是怎么来的呢?我们可以使用 tflite_convert 获得模型具体结构,命令如下:

    tflite_convert \
      --output_file=keras_mnist.dot \
      --output_format=GRAPHVIZ_DOT \
      --keras_model_file=mnist_cnn.h5
    

    dot是一种graph description language,可以用graphz的dot命令转化为pdf或png等可视化图。

    dot -Tpng -O keras_mnist.dot
    

    这样就转化为一张图了,如下:

    很明显的可以看到如下信息:

    conv2d_1_input
    Type: Float [1×28×28×1]
    MinMax: [0, 255]
    
    dense_2/Softmax
    Type: Float [1×10]
    

    因此,可以知道

    --input_arrays 就是 conv2d_1_input

    --output_arrays 就是 dense_2/Softmax

    --default_ranges_min 就是 0

    --default_ranges_max 就是 255

    关于 --mean_values--std_dev_values 的用途:

    QUANTIZED_UINT8的quantized模型期望的输入是[0,255], 需要有个跟原始的float类型输入有个对应关系。
    mean_values和std_dev_values就是为了实现这个对应关系
    mean_values对应float的float_min
    std_dev_values对应255 / (float_max - float_min)
    

    因此,可以知道

    --mean_values 就是 0

    --std_dev_values 就是 1

    Deployment on Android

    现在开始在Android环境部署,对于国内的读者,需要先给Android Studio配置proxy,因为gradle编译环境需要获取相应的资源,请大家自行解决,这里不再赘述。

    配置app/build.gradle

    新建一个Android Project,打开 app/build.gradle 添加如下信息:

    android {
        aaptOptions {
            noCompress "tflite"
    repositories {
        maven {
            url 'https://google.bintray.com/tensorflow'
    dependencies {
        implementation 'org.tensorflow:tensorflow-lite:+'
    
  • aaptOptions 设置tflite文件不压缩,确保后面tflite文件可以被Interpreter正确加载。

  • org.tensorflow:tensorflow-lite 的最新版本号可以在这里查询 https://bintray.com/google/tensorflow/tensorflow-lite

  • 设置好后,sync和build整个工程,如果build成功说明,配置成功。

    添加tflite文件到assets文件夹

    在app目录先新建assets目录,并将 mnist_cnn.tflite 文件保存到assets目录。重新编译apk,检查新编译出来的apk的assets文件夹是否有 mnist_cnn.tflite 文件。

    使用apk analyzer查看新编译出来的apk,存在如下目录即编译打包成功:

    assets
         |__mnist_cnn.tflite
    

    使用如下函数将 mnist_cnn.tflite 文件加载到memory-map中,作为Interpreter实例化的输入

    private static final String MODEL_PATH = "mnist_cnn.tflite";
    /** Memory-map the model file in Assets. */
    private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
        AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    

    实例化Interpreter,其中this为当前acitivity

    tflite = new Interpreter(loadModelFile(this));
    

    我们使用mnist test测试集中的某张图片作为输入,mnist图像大小28*28,单像素。这样我们输入的数据需要设置成如下格式

    /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */
    private ByteBuffer imgData = null;
    private static final int DIM_BATCH_SIZE = 1;
    private static final int DIM_PIXEL_SIZE = 1;
    private static final int DIM_IMG_WIDTH = 28;
    private static final int DIM_IMG_HEIGHT = 28;
    protected void onCreate() {
        imgData = ByteBuffer.allocateDirect(
            4 * DIM_BATCH_SIZE * DIM_IMG_WIDTH * DIM_IMG_HEIGHT * DIM_PIXEL_SIZE);
        imgData.order(ByteOrder.nativeOrder());
    

    将mnist图片转化成 ByteBuffer ,并保持到 imgData

    /** Preallocated buffers for storing image data in. */
    private int[] intValues = new int[DIM_IMG_WIDTH * DIM_IMG_HEIGHT];
    /** Writes Image data into a {@code ByteBuffer}. */
    private void convertBitmapToByteBuffer(Bitmap bitmap) {
        if (imgData == null) {
            return;
        // Rewinds this buffer. The position is set to zero and the mark is discarded.
        imgData.rewind();
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        // Convert the image to floating point.
        int pixel = 0;
        for (int i = 0; i < DIM_IMG_WIDTH; ++i) {
            for (int j = 0; j < DIM_IMG_HEIGHT; ++j) {
                final int val = intValues[pixel++];
                imgData.putFloat(val);
    

    convertBitmapToByteBuffer 的输出即为模型运行的输入。

    定义一个1*10的多维数组,因为我们只有1个batch和10个label(TODO:need double check),具体代码如下

    private float[][] labelProbArray = new float[1][10];
    

    运行结束后,每个二级元素都是一个label的概率。

    运行及结果处理

    开始运行模型,具体代码如下

    tflite.run(imgData, labelProbArray);
    

    针对某个图片,运行后 labelProbArray 的内容如下,也就是各个label识别的概率

    index 0 prob is 0.0
    index 1 prob is 0.0
    index 2 prob is 0.0
    index 3 prob is 1.0
    index 4 prob is 0.0
    index 6 prob is 0.0
    index 7 prob is 0.0
    index 8 prob is 0.0
    index 9 prob is 0.0
    

    接下来,我们要做的就是根据对这些概率进行排序,找出Top的label并界面呈现给用户.