TfLite: MCU code analysis

2024-06-03 14:58

Comparing micro_framework with the lite framework target to find what they have in common

cc_library(
    name = "micro_framework",
    srcs = [ 
        "micro_error_reporter.cc",
        "micro_interpreter.cc",
        "micro_mutable_op_resolver.cc",
        "simple_tensor_allocator.cc",
    ],  
    hdrs = [ 
        "compatibility.h",
        "micro_error_reporter.h",
        "micro_interpreter.h",
        "micro_mutable_op_resolver.h",
        "simple_tensor_allocator.h",
    ],  
    deps = [ 
        "//tensorflow/lite:schema_fbs_version",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/schema:schema_fbs",

    ],  
)

cc_library(
    name = "framework",
    srcs = [ 
        "allocation.cc",
        "graph_info.cc",
        "interpreter.cc",
        "model.cc",
        "mutable_op_resolver.cc",
        "optional_debug_tools.cc",
        "stderr_reporter.cc",
    ] + select({
        "//tensorflow:android": [                                                                                                                                                                                  
            "nnapi_delegate.cc",
            "mmap_allocation.cc",
        ],  
        "//tensorflow:windows": [
            "nnapi_delegate_disabled.cc",
            "mmap_allocation_disabled.cc",
        ],  
        "//conditions:default": [
            "nnapi_delegate_disabled.cc",
            "mmap_allocation.cc",
        ],  
    }), 
    hdrs = [ 
        "allocation.h",
        "context.h",
        "context_util.h",
        "error_reporter.h",
        "graph_info.h",
        "interpreter.h",
        "model.h",
        "mutable_op_resolver.h",
        "nnapi_delegate.h",
        "op_resolver.h",
        "optional_debug_tools.h",
        "stderr_reporter.h",
    ], 

    deps = [
        ":arena_planner",
        ":graph_info",
        ":memory_planner",
        ":schema_fbs_version",
        ":simple_memory_arena",
        ":string",
        ":util",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/core/api",

        "//tensorflow/lite/kernels:eigen_support",
        "//tensorflow/lite/kernels:gemm_support",
        "//tensorflow/lite/nnapi:nnapi_lib",
        "//tensorflow/lite/profiling:profiler",
        "//tensorflow/lite/schema:schema_fbs",
    ] + select({
        ":with_tflite_flex": [
            "//tensorflow/lite/delegates/flex:delegate",
        ],
        "//conditions:default": [],
    }),
)

Relationship between TfLite for MCU and the reference TfLite


1. TfLite for MCU and TfLite are two separate framework implementations.

2. The TfLite for MCU framework is much simpler than TfLite's.

3. Both depend on

 "//tensorflow/lite/core/api" and "//tensorflow/lite/c:c_api_internal"

 

 //tensorflow/lite/core/api

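core/api provides three abstract interface classes, ErrorReporter, OpResolver and BuiltinDataAllocator, together with the functions that use them.

The file name flatbuffer_conversions.h does not hint at the BuiltinDataAllocator class declared in it, which is a little confusing: BuiltinDataAllocator provides the memory-allocation method for storing the BuiltinData parsed out of the flatbuffer.

Each concrete interpreter, TfLite as well as TfLite for MCU, implements these three base classes, which is probably why the package is called core/api.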

error_reporter.h

// A functor that reports error to supporting system. Invoked similar to
// printf.

// Subclass ErrorReporter to provide another reporting destination.
// For example, if you have a GUI program, you might redirect to a buffer
// that drives a GUI error log box.
class ErrorReporter {
 public:
  virtual ~ErrorReporter() {}
  virtual int Report(const char* format, va_list args) = 0;
  int Report(const char* format, ...);
  int ReportError(void*, const char* format, ...);
};
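To illustrate how a concrete reporter plugs into this interface, here is a minimal sketch that sends messages to an in-memory buffer, as the comment's GUI example suggests. The ErrorReporter class is a local stand-in with the same shape as the one above (ReportError omitted), and BufferErrorReporter is a hypothetical name, not a TfLite class. The `using` declaration matters: overriding Report(const char*, va_list) in a subclass otherwise hides the printf-style overload.

```cpp
#include <cassert>
#include <cstdarg>
#include <cstdio>
#include <string>

namespace sketch {

// Local stand-in with the same shape as the ErrorReporter quoted above.
class ErrorReporter {
 public:
  virtual ~ErrorReporter() {}
  virtual int Report(const char* format, std::va_list args) = 0;
  // printf-style convenience overload: packs the arguments into a va_list.
  int Report(const char* format, ...) {
    std::va_list args;
    va_start(args, format);
    int result = Report(format, args);
    va_end(args);
    return result;
  }
};

// Hypothetical reporter that appends each message to a string buffer,
// e.g. one that later feeds a GUI error log box.
class BufferErrorReporter : public ErrorReporter {
 public:
  using ErrorReporter::Report;  // keep the variadic overload visible
  int Report(const char* format, std::va_list args) override {
    char line[256];
    int written = std::vsnprintf(line, sizeof(line), format, args);
    log_ += line;
    log_ += '\n';
    return written;
  }
  const std::string& log() const { return log_; }

 private:
  std::string log_;
};

}  // namespace sketch
```

Calling r.Report("tensor %d too big", 7) on a BufferErrorReporter appends "tensor 7 too big" and a newline to r.log().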

op_resolver.h

namespace tflite {

// Abstract interface that returns TfLiteRegistrations given op codes or custom
// op names. This is the mechanism that ops being referenced in the flatbuffer
// model are mapped to executable function pointers (TfLiteRegistrations).
class OpResolver {
 public:
  // Finds the op registration for a builtin operator by enum code.
  virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, 
                                           int version) const = 0;
  // Finds the op registration of a custom operator by op name.
  virtual const TfLiteRegistration* FindOp(const char* op, 
                                           int version) const = 0;
  virtual ~OpResolver() {}
};

// Handles the logic for converting between an OperatorCode structure extracted
// from a flatbuffer and information about a registered operator implementation.
TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
                                       const OpResolver& op_resolver,
                                       ErrorReporter* error_reporter,
                                       const TfLiteRegistration** registration);

}  // namespace tflite
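GetRegistrationFromOpCode essentially has to pick the right FindOp overload: builtin ops are looked up by enum, custom ops by name, both with the version taken from the OperatorCode. The sketch below shows that dispatch under simplified assumptions: OperatorCode, TfLiteRegistration and ToyResolver are plain stand-in types (not the flatbuffer-generated or real TfLite ones), and the real function additionally takes an ErrorReporter to report failures.

```cpp
#include <cassert>
#include <cstring>

namespace sketch {

enum BuiltinOperator { BuiltinOperator_CONV_2D, BuiltinOperator_CUSTOM };
enum TfLiteStatus { kTfLiteOk, kTfLiteError };

struct TfLiteRegistration { int builtin_code; const char* custom_name; int version; };
// Stand-in for the flatbuffer OperatorCode table.
struct OperatorCode { BuiltinOperator builtin_code; const char* custom_code; int version; };

class OpResolver {
 public:
  virtual ~OpResolver() {}
  virtual const TfLiteRegistration* FindOp(BuiltinOperator op, int version) const = 0;
  virtual const TfLiteRegistration* FindOp(const char* op, int version) const = 0;
};

// Builtin ops resolve by enum, custom ops by name; both carry a version.
TfLiteStatus GetRegistrationFromOpCode(const OperatorCode& opcode,
                                       const OpResolver& resolver,
                                       const TfLiteRegistration** registration) {
  if (opcode.builtin_code != BuiltinOperator_CUSTOM) {
    *registration = resolver.FindOp(opcode.builtin_code, opcode.version);
  } else if (opcode.custom_code != nullptr) {
    *registration = resolver.FindOp(opcode.custom_code, opcode.version);
  } else {
    *registration = nullptr;  // a custom op must carry a name
  }
  return *registration != nullptr ? kTfLiteOk : kTfLiteError;
}

// Toy resolver knowing one builtin op and one custom op, version 1 only.
class ToyResolver : public OpResolver {
 public:
  const TfLiteRegistration* FindOp(BuiltinOperator op, int version) const override {
    return (op == BuiltinOperator_CONV_2D && version == 1) ? &conv_ : nullptr;
  }
  const TfLiteRegistration* FindOp(const char* op, int version) const override {
    return (std::strcmp(op, "MyOp") == 0 && version == 1) ? &custom_ : nullptr;
  }

 private:
  TfLiteRegistration conv_{BuiltinOperator_CONV_2D, nullptr, 1};
  TfLiteRegistration custom_{BuiltinOperator_CUSTOM, "MyOp", 1};
};

}  // namespace sketch
```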

flatbuffer_conversions.h

  // These functions transform codes and data structures that are defined in the
  // flatbuffer serialization format into in-memory values that are used by the
  // runtime API and interpreter.

  // Interface class for builtin data allocations.
  class BuiltinDataAllocator {
   public:
    virtual void* Allocate(size_t size) = 0;
    virtual void Deallocate(void* data) = 0;
  
    // Allocate a structure, but make sure it is a POD structure that doesn't
    // require constructors to run. The reason we do this, is that Interpreter's C
    // extension part will take ownership so destructors  will not be run during
    // deallocation.
    template <typename T>
    T* AllocatePOD() {
      static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
      return static_cast<T*>(this->Allocate(sizeof(T)));
    }
  
    virtual ~BuiltinDataAllocator() {}
  };
  
  // Parse the appropriate data out of the op.
  // e.g. stride_width, stride_height, etc. of a convolution
  // This handles builtin data explicitly as there are flatbuffer schemas.
  // If it returns kTfLiteOk, it passes the data out with `builtin_data`. The
  // calling function has to pass in an allocator object, and this allocator
  // will be called to reserve space for the output data. If the calling
  // function's allocator reserves memory on the heap, then it's the calling
  // function's responsibility to free it.
  // If it returns kTfLiteError, `builtin_data` will be `nullptr`.
  TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
                           ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator, void** builtin_data);
  
  // Converts the tensor data type used in the flat buffer to the representation
  // used by the runtime.
  TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
                                 ErrorReporter* error_reporter);
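The comments above leave memory ownership to whoever implements BuiltinDataAllocator. A minimal heap-backed implementation, assuming a target where malloc is available, might look like this; MallocDataAllocator and ConvParams are hypothetical stand-ins (ConvParams plays the role of a builtin params struct such as TfLiteConvParams). The static_assert spells out POD as trivial plus standard-layout, since std::is_pod is deprecated in C++20.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <type_traits>

namespace sketch {

class BuiltinDataAllocator {
 public:
  virtual void* Allocate(std::size_t size) = 0;
  virtual void Deallocate(void* data) = 0;
  // Only POD structs: no constructor runs, and the C side may free the memory.
  template <typename T>
  T* AllocatePOD() {
    static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
                  "Builtin data structure must be POD.");
    return static_cast<T*>(this->Allocate(sizeof(T)));
  }
  virtual ~BuiltinDataAllocator() {}
};

// Hypothetical heap-backed allocator: memory comes from malloc, and the caller
// releases it with Deallocate (i.e. free) once the op no longer needs it.
class MallocDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(std::size_t size) override { return std::malloc(size); }
  void Deallocate(void* data) override { std::free(data); }
};

// Stand-in for a builtin params struct such as TfLiteConvParams.
struct ConvParams {
  int stride_width;
  int stride_height;
};

}  // namespace sketch
```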

 

// Parse the appropriate data out of the op.
  //
  // This handles builtin data explicitly as there are flatbuffer schemas.
  // If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
  // needs to be released by calling `free`.
  // If it returns kTfLiteError, `builtin_data` will be `nullptr`.
  // Parses data from the flatbuffer model file into builtin_data.
  // The function is only a skeleton; it relies on an implementation of the
  // BuiltinDataAllocator interface.

  TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
                           ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator, void** builtin_data) {

    *builtin_data = nullptr;
    switch (op_type) {
      case BuiltinOperator_CONV_2D: {
        TfLiteConvParams* params = allocator->AllocatePOD<TfLiteConvParams>();
        if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
          params->padding = parse_padding(conv_params->padding());
          params->stride_width = conv_params->stride_w();
          params->stride_height = conv_params->stride_h();
          params->activation =
              parse_activation(conv_params->fused_activation_function());
  
          params->dilation_width_factor = conv_params->dilation_w_factor();
          params->dilation_height_factor = conv_params->dilation_h_factor();
        }
        *builtin_data = reinterpret_cast<void*>(params);
        break;
      }
      // ... other builtin operators ...
    }
    return kTfLiteOk;
  }
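Stripped of the flatbuffer types, the pattern in ParseOpData is: switch on the op type, ask the allocator for a POD params struct, copy the option fields across, and hand the pointer back through builtin_data. The sketch below reproduces that with plain stand-in structs (Conv2DOptions, ConvParams, FixedAllocator and BuiltinOperator_OTHER are hypothetical), and adds a null check on the allocation that the excerpt above does not have.

```cpp
#include <cassert>
#include <cstddef>

namespace sketch {

enum TfLiteStatus { kTfLiteOk, kTfLiteError };
enum BuiltinOperator { BuiltinOperator_CONV_2D, BuiltinOperator_OTHER };

// Stand-in for the flatbuffer Conv2DOptions table.
struct Conv2DOptions { int stride_w, stride_h, dilation_w_factor, dilation_h_factor; };
// Stand-in for TfLiteConvParams.
struct ConvParams { int stride_width, stride_height, dilation_width_factor, dilation_height_factor; };

class BuiltinDataAllocator {
 public:
  virtual ~BuiltinDataAllocator() {}
  virtual void* Allocate(std::size_t size) = 0;
  virtual void Deallocate(void* data) = 0;
  template <typename T>
  T* AllocatePOD() { return static_cast<T*>(Allocate(sizeof(T))); }
};

// Dispatch on the op type, allocate the params struct, copy the options.
TfLiteStatus ParseOpData(BuiltinOperator op_type, const Conv2DOptions* conv_options,
                         BuiltinDataAllocator* allocator, void** builtin_data) {
  *builtin_data = nullptr;
  switch (op_type) {
    case BuiltinOperator_CONV_2D: {
      ConvParams* params = allocator->AllocatePOD<ConvParams>();
      if (params == nullptr) return kTfLiteError;  // allocator ran out of space
      *params = ConvParams{};  // defaults if the model carries no options
      if (conv_options != nullptr) {
        params->stride_width = conv_options->stride_w;
        params->stride_height = conv_options->stride_h;
        params->dilation_width_factor = conv_options->dilation_w_factor;
        params->dilation_height_factor = conv_options->dilation_h_factor;
      }
      *builtin_data = params;
      break;
    }
    default:
      break;  // ops without builtin options leave builtin_data as nullptr
  }
  return kTfLiteOk;
}

// Fixed-buffer allocator for the demo, in the spirit of StackDataAllocator.
class FixedAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(std::size_t size) override { return size <= sizeof(data_) ? data_ : nullptr; }
  void Deallocate(void*) override {}  // nothing to release
 private:
  alignas(alignof(std::max_align_t)) unsigned char data_[64];
};

}  // namespace sketch
```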

//tensorflow/lite/c:c_api_internal

cc_library(
    name = "c_api_internal",
    srcs = ["c_api_internal.c"],
    hdrs = [ 
        "builtin_op_data.h",
        "c_api_internal.h",
    ],  
    visibility = [ 
        "//tensorflow/contrib/lite:__subpackages__",
        "//tensorflow/lite:__subpackages__",
    ],  
)

c_api_internal.h

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

1] c_api_internal.h defines TfLite's main data structures:
// This file defines a C API for implementing operations in tflite.
// These operations can be defined using c++ but the interface between
// the interpreter and the operations are C.
//
// Summary of abstractions
// TF_LITE_ENSURE - Self-sufficient error checking
// TfLiteStatus - Status reporting
// TfLiteIntArray - stores tensor shapes (dims),
// TfLiteContext - allows an op to access the tensors
// TfLiteTensor - tensor (a multidimensional array)
// TfLiteNode - a single node or operation
// TfLiteRegistration - the implementation of a conceptual operation.
//
// Some abstractions in this file are created and managed by Interpreter.

// Parameters for asymmetric quantization. Quantized values can be converted
// back to float using:
//    real_value = scale * (quantized_value - zero_point);
typedef struct {
  float scale;
  int32_t zero_point;
} TfLiteQuantizationParams;
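A quick worked example of the formula real_value = scale * (quantized_value - zero_point): with scale = 0.5 and zero_point = 128, the uint8 value 130 maps to 0.5 * (130 - 128) = 1.0. The sketch below implements that mapping and, for illustration, the inverse rounded and clamped to the uint8 range; the struct and function names are hypothetical, not part of c_api_internal.h.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

namespace sketch {

// Mirrors the fields of TfLiteQuantizationParams.
struct QuantizationParams {
  float scale;
  std::int32_t zero_point;
};

// real_value = scale * (quantized_value - zero_point)
inline float Dequantize(std::uint8_t q, const QuantizationParams& p) {
  return p.scale * static_cast<float>(static_cast<std::int32_t>(q) - p.zero_point);
}

// Inverse mapping: round to the nearest step, then clamp to [0, 255].
inline std::uint8_t Quantize(float real, const QuantizationParams& p) {
  long q = std::lround(real / p.scale) + p.zero_point;
  if (q < 0) q = 0;
  if (q > 255) q = 255;
  return static_cast<std::uint8_t>(q);
}

}  // namespace sketch
```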

c/c_api_internal.c
#ifndef TF_LITE_STATIC_MEMORY                                           
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#endif  // TF_LITE_STATIC_MEMORY
Does micro actually define this macro? Can it use malloc? Yes; pay attention to the TF_LITE_STATIC_MEMORY macro.

builtin_op_data.h: each op's builtin data (per-op parameter structs such as TfLiteConvParams)

 

schema_fbs

# Generic schema for inference on device.
flatbuffer_cc_library(
    name = "schema_fbs",
    srcs = ["schema.fbs"],
)

schema_fbs_version


cc_library(
    name = "schema_fbs_version",
    hdrs = ["version.h"],
)
#define TFLITE_SCHEMA_VERSION (3)

 

TfLite for MCU's implementation of core/api

1] class MicroMutableOpResolver : public OpResolver

// Adds the member functions AddBuiltin and AddCustom, plus a private
// TfLiteRegistration array and its length.
class MicroMutableOpResolver : public OpResolver {
 public:
  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, 
                                   int version) const override;
  const TfLiteRegistration* FindOp(const char* op, int version) const override;
  void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
                  int min_version = 1, int max_version = 1); 
  void AddCustom(const char* name, TfLiteRegistration* registration,
                 int min_version = 1, int max_version = 1); 

 private:
  TfLiteRegistration registrations_[TFLITE_REGISTRATIONS_MAX];
  int registrations_len_ = 0;

  TF_LITE_REMOVE_VIRTUAL_DELETE
};

The Add/Find functions are simple: they only operate on the TfLiteRegistration array registrations_. AddBuiltin, for example:
void MicroMutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, 
                                        TfLiteRegistration* registration,
                                        int min_version, int max_version) {
  for (int version = min_version; version <= max_version; ++version) {
    if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) {
      // TODO(petewarden) - Add error reporting hooks so we can report this!
      return;
    }   
    TfLiteRegistration* new_registration = &registrations_[registrations_len_];
    registrations_len_ += 1;

    *new_registration = *registration;
    new_registration->builtin_code = op; 
    new_registration->version = version;
  }
}
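FindOp is not shown above, but given the registrations_ array it can be a plain linear scan matching builtin_code and version (or the custom name for custom ops). The sketch below assumes exactly that; ArrayOpResolver, the trimmed TfLiteRegistration stand-in and kRegistrationsMax (in place of TFLITE_REGISTRATIONS_MAX) are illustrative, and unlike the real AddBuiltin it builds each entry itself instead of copying a passed-in registration.

```cpp
#include <cassert>
#include <cstring>

namespace sketch {

enum BuiltinOperator { BuiltinOperator_CONV_2D, BuiltinOperator_SOFTMAX, BuiltinOperator_CUSTOM };

// Trimmed stand-in: the real TfLiteRegistration also holds function pointers.
struct TfLiteRegistration {
  int builtin_code;
  const char* custom_name;
  int version;
};

constexpr int kRegistrationsMax = 8;  // stand-in for TFLITE_REGISTRATIONS_MAX

class ArrayOpResolver {
 public:
  // One entry per version, silently stopping when the array is full.
  void AddBuiltin(BuiltinOperator op, int min_version = 1, int max_version = 1) {
    for (int v = min_version; v <= max_version && len_ < kRegistrationsMax; ++v) {
      regs_[len_++] = TfLiteRegistration{static_cast<int>(op), nullptr, v};
    }
  }
  void AddCustom(const char* name, int min_version = 1, int max_version = 1) {
    for (int v = min_version; v <= max_version && len_ < kRegistrationsMax; ++v) {
      regs_[len_++] = TfLiteRegistration{BuiltinOperator_CUSTOM, name, v};
    }
  }
  // Linear scan: the table is tiny, so O(n) lookup is fine on an MCU.
  const TfLiteRegistration* FindOp(BuiltinOperator op, int version) const {
    for (int i = 0; i < len_; ++i) {
      if (regs_[i].builtin_code == op && regs_[i].version == version) return &regs_[i];
    }
    return nullptr;
  }
  const TfLiteRegistration* FindOp(const char* op, int version) const {
    for (int i = 0; i < len_; ++i) {
      if (regs_[i].builtin_code == BuiltinOperator_CUSTOM &&
          std::strcmp(regs_[i].custom_name, op) == 0 && regs_[i].version == version) {
        return &regs_[i];
      }
    }
    return nullptr;
  }

 private:
  TfLiteRegistration regs_[kRegistrationsMax] = {};
  int len_ = 0;
};

}  // namespace sketch
```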

2] micro_interpreter.cc: class StackDataAllocator : public BuiltinDataAllocator

// StackDataAllocator just returns the start address of its fixed-size member
// array, which holds the builtin data parsed from the flatbuffer.
const int kStackDataAllocatorSize = 128;
class StackDataAllocator : public BuiltinDataAllocator {
   public:
    void* Allocate(size_t size) override {
      if (size > kStackDataAllocatorSize) {
        return nullptr;
      } else {
        return data_;
      }   
    }
    void Deallocate(void* data) override {
      // Do nothing.
    }
  
   private:
    uint8_t data_[kStackDataAllocatorSize];
  
    TF_LITE_REMOVE_VIRTUAL_DELETE
}; 
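Two properties of StackDataAllocator are easy to miss: every Allocate call returns the same buffer, so only one builtin-data struct can be live at a time (which suits Invoke, where a fresh allocator is created per operator and ParseOpData makes one allocation), and a uint8_t array is only guaranteed 1-byte alignment although it is reinterpreted as structs such as TfLiteConvParams. The sketch below checks both; the alignas is an addition in this sketch, not something the quoted code does.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

namespace sketch {

constexpr std::size_t kStackDataAllocatorSize = 128;

class StackDataAllocator {
 public:
  // Same buffer every time; anything larger than the buffer fails.
  void* Allocate(std::size_t size) {
    return size > kStackDataAllocatorSize ? nullptr : data_;
  }
  void Deallocate(void*) {}  // nothing to release

 private:
  // Added in this sketch: align for any struct that will be placed here.
  alignas(alignof(std::max_align_t)) std::uint8_t data_[kStackDataAllocatorSize];
};

}  // namespace sketch
```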

class SimpleTensorAllocator

SimpleTensorAllocator is another memory-allocation class, separate from (not derived from) BuiltinDataAllocator:
// This allocator never frees up or reuses  any memory, even
// though we have enough information about lifetimes of the tensors to do so.
// This makes it pretty wasteful, so we should use a more intelligent method.
class SimpleTensorAllocator {
 public:
  SimpleTensorAllocator(uint8_t* buffer, int buffer_size)
      : data_size_(0), data_size_max_(buffer_size), data_(buffer) {}

  TfLiteStatus AllocateTensor(
      const tflite::Tensor& flatbuffer_tensor, int create_before,
      int destroy_after,
      const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
      ErrorReporter* error_reporter, TfLiteTensor* result);

  uint8_t* AllocateMemory(size_t size, size_t alignment); 

  int GetDataSize() const { return data_size_; }

 private:
  int data_size_;
  int data_size_max_;
  uint8_t* data_;
};

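1] The constructor SimpleTensorAllocator(uint8_t* buffer, int buffer_size) supplies the start address and size of the memory region.
2] uint8_t* AllocateMemory(size_t size, size_t alignment) carves out a block of memory and returns its start address.
3] AllocateTensor allocates memory and fills in the tensor: it reads each field from the flatbuffer and writes it into the TfLiteTensor, using memory allocated here where needed.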
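AllocateMemory behaves like a bump allocator: round the current position up to the requested alignment, hand out size bytes, and fail once the buffer is exhausted, never freeing anything (as the class comment admits). This is a minimal sketch under that assumption; BumpAllocator and AlignPointerUp are hypothetical names, and alignment is assumed to be a power of two.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

namespace sketch {

// Round p up to the next multiple of `alignment` (a power of two).
inline std::uint8_t* AlignPointerUp(std::uint8_t* p, std::size_t alignment) {
  std::uintptr_t v = reinterpret_cast<std::uintptr_t>(p);
  std::uintptr_t mask = static_cast<std::uintptr_t>(alignment) - 1;
  return reinterpret_cast<std::uint8_t*>((v + mask) & ~mask);
}

class BumpAllocator {
 public:
  BumpAllocator(std::uint8_t* buffer, std::size_t size) : data_(buffer), max_(size) {}

  std::uint8_t* AllocateMemory(std::size_t size, std::size_t alignment) {
    std::uint8_t* current = data_ + used_;
    std::uint8_t* aligned = AlignPointerUp(current, alignment);
    // Bytes consumed = alignment padding + requested size.
    std::size_t consumed = static_cast<std::size_t>(aligned - current) + size;
    if (used_ + consumed > max_) return nullptr;  // arena exhausted
    used_ += consumed;
    return aligned;
  }
  std::size_t used() const { return used_; }

 private:
  std::uint8_t* data_;
  std::size_t max_;
  std::size_t used_ = 0;
};

}  // namespace sketch
```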

AllocateTensor

TfLiteStatus SimpleTensorAllocator::AllocateTensor(
      const tflite::Tensor& flatbuffer_tensor, int create_before,
      int destroy_after,
      const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers, 
      ErrorReporter* error_reporter, TfLiteTensor* result) {
  
  // Convert a tflite::Tensor (from the model file) into a TfLiteTensor.
  //1.flatbuffer tensor type -> TfLiteTensor type  

    TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(),
                                            &result->type, error_reporter));
  //2.flatbuffer tensor is_variable
    result->is_variable = flatbuffer_tensor.is_variable();
  
    result->data.raw = nullptr;
    result->bytes = 0;
  
  //3.flatbuffer tensor buffers: buffer() is an index into the model's buffers[]
    if (auto* buffer = (*buffers)[flatbuffer_tensor.buffer()]) {
      //3.1 get the tensor data and size
      if (auto* array = buffer->data()) {
        if (size_t array_size = array->size()) {
          result->data.raw =
              const_cast<char*>(reinterpret_cast<const char*>(array->data()));
          size_t type_size;
          TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, array_size,
                                              &result->bytes, &type_size,
                                              error_reporter));
        }
      }
    }
    if (result->data.raw) {  // allocation type: read-only model data vs. arena
      result->allocation_type = kTfLiteMmapRo;
    } else {
      int data_size = 1;
  //4.0 shape size 
      for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
        data_size *= flatbuffer_tensor.shape()->Get(n);
      }
      size_t type_size;
  //4.1 type size
      TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, data_size,
                                          &result->bytes, &type_size,
                                          error_reporter));
  //4.2 Allocate memory: based on shape and type
      result->data.raw =
          reinterpret_cast<char*>(AllocateMemory(result->bytes, type_size));
      if (result->data.raw == nullptr) {
        const char* tensor_name = flatbuffer_tensor.name()
                                      ? flatbuffer_tensor.name()->c_str()
                                      : "<None>";
        error_reporter->Report(
            "Couldn't allocate memory for tensor '%s', wanted %d bytes but only "
            "%d were available",
            tensor_name, result->bytes, (data_size_max_ - data_size_));
        return kTfLiteError;
      }
      result->allocation_type = kTfLiteArenaRw;
    }
  //4.3 store tensorshape
    result->dims = reinterpret_cast<TfLiteIntArray*>(AllocateMemory(
        sizeof(int) * (flatbuffer_tensor.shape()->Length() + 1), sizeof(int)));
    result->dims->size = flatbuffer_tensor.shape()->Length();
    for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
      result->dims->data[n] = flatbuffer_tensor.shape()->Get(n);
    }
  //4.4 tensor quantization
    if (flatbuffer_tensor.quantization()) {
      result->params.scale = flatbuffer_tensor.quantization()->scale()->Get(0);
      result->params.zero_point =
          flatbuffer_tensor.quantization()->zero_point()->Get(0);
    }
    result->allocation = nullptr;
  //4.5 tensor name
    if (flatbuffer_tensor.name()) {
      result->name = flatbuffer_tensor.name()->c_str();
    } else {
      result->name = "<No name>";
    }
    result->delegate = nullptr;
    result->buffer_handle = 0;
    result->data_is_stale = false;
    return kTfLiteOk;
  } 
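For tensors without constant data, steps 4.0 to 4.2 amount to bytes = (product of the shape dimensions) * (element size). The sketch below shows that arithmetic for a few types; TensorType, TypeSize and BytesForShape are hypothetical stand-ins, and it assumes BytesRequired performs essentially this multiplication. For example, a float32 tensor of shape [1, 28, 28, 1] needs 784 * 4 = 3136 bytes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

namespace sketch {

enum TensorType { kFloat32, kInt32, kUInt8, kInt64 };

// Per-element size for a few tensor types (illustrative subset).
inline std::size_t TypeSize(TensorType t) {
  switch (t) {
    case kFloat32: return sizeof(float);
    case kInt32: return sizeof(std::int32_t);
    case kUInt8: return sizeof(std::uint8_t);
    case kInt64: return sizeof(std::int64_t);
  }
  return 0;
}

// bytes = (product of all shape dimensions) * element size.
inline std::size_t BytesForShape(const int* shape, int rank, TensorType type) {
  std::size_t count = 1;  // a rank-0 (scalar) tensor holds one element
  for (int i = 0; i < rank; ++i) count *= static_cast<std::size_t>(shape[i]);
  return count * TypeSize(type);
}

}  // namespace sketch
```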

Memory allocation for TfLiteIntArray

#ifndef TF_LITE_STATIC_MEMORY
  TfLiteIntArray* TfLiteIntArrayCreate(int size) {
    TfLiteIntArray* ret =
        (TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size));
    ret->size = size;
    return ret;
  }
#endif
In the MCU framework, allocations that TfLite makes with malloc are instead served by SimpleTensorAllocator's AllocateMemory, e.g. for the dims array:
reinterpret_cast<TfLiteIntArray*>(AllocateMemory(
        sizeof(int) * (flatbuffer_tensor.shape()->Length() + 1), sizeof(int)));
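The expression sizeof(int) * (Length() + 1) reflects the TfLiteIntArray layout: one int for size followed by size ints of data. The sketch below contrasts a malloc path, compiled only when TF_LITE_STATIC_MEMORY is not defined, with a static path that carves the block from a caller-provided int arena. It models the array as raw int slots (slot 0 holds the size) rather than the real struct, and both function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

namespace sketch {

// Bytes for a TfLiteIntArray-style block: one int for `size` plus `size` ints.
inline std::size_t IntArrayBytes(int size) {
  return sizeof(int) * (static_cast<std::size_t>(size) + 1);
}

#ifndef TF_LITE_STATIC_MEMORY
// Heap path, analogous to TfLiteIntArrayCreate.
inline int* CreateIntArrayHeap(int size) {
  int* block = static_cast<int*>(std::malloc(IntArrayBytes(size)));
  if (block != nullptr) block[0] = size;
  return block;
}
#endif

// Static-memory path: take the next size + 1 ints from a caller-owned arena.
inline int* CreateIntArrayInArena(int* arena, std::size_t arena_ints,
                                  std::size_t* used, int size) {
  std::size_t need = static_cast<std::size_t>(size) + 1;
  if (*used + need > arena_ints) return nullptr;  // arena too small
  int* block = arena + *used;
  *used += need;
  block[0] = size;
  return block;
}

}  // namespace sketch
```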

Will every model without quantization crash?

// The problem I ran into: for a model file without quantization,
// flatbuffer_tensor.quantization() was still non-null,
// but the accesses that follow were invalid and caused a crash.

if (flatbuffer_tensor.quantization()) {  // Why is this true when the model has no quantization?
      result->params.scale = flatbuffer_tensor.quantization()->scale()->Get(0);
      result->params.zero_point = flatbuffer_tensor.quantization()->zero_point()->Get(0);
}
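One plausible explanation, consistent with how later TfLite Micro code guards this access, is that the converter can emit a QuantizationParameters table for a float tensor whose scale and zero_point vectors are missing or empty: quantization() is then non-null, but scale()->Get(0) reads past the end or through a null pointer. A defensive version only reads element 0 when both vectors exist and are non-empty. The sketch models the flatbuffer vectors as std::vector pointers; the types and the ReadPerTensorQuantization name are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

namespace sketch {

// Stand-in for the flatbuffer QuantizationParameters table: either vector may
// be absent (nullptr) or present but empty.
struct QuantizationParameters {
  const std::vector<float>* scale = nullptr;
  const std::vector<std::int64_t>* zero_point = nullptr;
};

struct QuantParams {
  float scale = 0.0f;
  std::int32_t zero_point = 0;
};

// Reads per-tensor params only when both vectors exist and are non-empty.
inline bool ReadPerTensorQuantization(const QuantizationParameters* q, QuantParams* out) {
  if (q == nullptr || q->scale == nullptr || q->zero_point == nullptr ||
      q->scale->empty() || q->zero_point->empty()) {
    return false;  // treat as "not quantized"; leave *out untouched
  }
  out->scale = (*q->scale)[0];
  out->zero_point = static_cast<std::int32_t>((*q->zero_point)[0]);
  return true;
}

}  // namespace sketch
```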

micro_interpreter.cc

class MicroInterpreter

   class MicroInterpreter {
   public:
    MicroInterpreter(const Model* model, const OpResolver& op_resolver,
                     SimpleTensorAllocator* tensor_allocator,
                     ErrorReporter* error_reporter);
  
    TfLiteStatus Invoke();
  
    size_t tensors_size() const { return context_.tensors_size; }
    TfLiteTensor* tensor(int tensor_index);
  
    TfLiteTensor* input(int index);
    size_t inputs_size() const { return subgraph_->inputs()->Length(); }
  
    TfLiteTensor* output(int index);
    size_t outputs_size() const { return subgraph_->outputs()->Length(); }
  
    TfLiteStatus initialization_status() const { return initialization_status_; }
  
    ErrorReporter* error_reporter() { return error_reporter_; }
  
   private:
    const Model* model_;
    const OpResolver& op_resolver_;
    SimpleTensorAllocator* tensor_allocator_;
    ErrorReporter* error_reporter_;

  
    TfLiteStatus initialization_status_;
    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors_;
    const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators_;
    TfLiteContext context_;
  
    const SubGraph* subgraph_;           
  };

The constructor reads data from the model file and uses it to populate the TfLite environment:


MicroInterpreter::MicroInterpreter(const Model* model,
                                     const OpResolver& op_resolver,
                                   SimpleTensorAllocator* tensor_allocator,
                                     ErrorReporter* error_reporter)
      : model_(model),
        op_resolver_(op_resolver),
        tensor_allocator_(tensor_allocator),
        error_reporter_(error_reporter),

        initialization_status_(kTfLiteOk) {
  //1] get data from flatbuffers
    const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
        model->buffers();
    auto* subgraphs = model->subgraphs();
    if (subgraphs->size() != 1) {
      error_reporter->Report("Only 1 subgraph is currently supported.\n");
      initialization_status_ = kTfLiteError;
      return;
    }
    subgraph_ = (*subgraphs)[0];
    tensors_ = subgraph_->tensors();
    operators_ = subgraph_->operators();
    ---
}

Invoke: the inference function

TfLiteStatus MicroInterpreter::Invoke() {
  // 1. Get the operator codes from the flatbuffer.
  auto opcodes = model_->operator_codes();
  for (int i = 0; i < operators_->Length(); ++i) {
    const auto* op = operators_->Get(i);
    int index = op->opcode_index();
    auto opcode = (*opcodes)[index];
    const TfLiteRegistration* registration = nullptr;
    // 2.1 Look up the registration for this opcode.
    TfLiteStatus status = GetRegistrationFromOpCode(
        opcode, op_resolver_, error_reporter_, &registration);
    if (status != kTfLiteOk || registration == nullptr) {
      return kTfLiteError;
    }

    // 2.2 Get the BuiltinOperator from the registration.
    BuiltinOperator op_type =
        static_cast<BuiltinOperator>(registration->builtin_code);

    // 2.3 Get the init data: custom options for custom ops, otherwise the
    // builtin data parsed from the flatbuffer (e.g. conv strides).
    StackDataAllocator stack_data_allocator;
    unsigned char* builtin_data = nullptr;
    const char* custom_data = nullptr;
    size_t custom_data_size = 0;
    if (op->custom_options()) {
      custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
      custom_data_size = op->custom_options()->size();
    } else {
      TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
                                        &stack_data_allocator,
                                        (void**)(&builtin_data)));
    }
    const char* init_data = (op_type == BuiltinOperator_CUSTOM)
                                ? custom_data
                                : reinterpret_cast<const char*>(builtin_data);
    size_t init_data_size =
        (op_type == BuiltinOperator_CUSTOM) ? custom_data_size : 0;

    // 2.4 Call the registration's init function, which returns user data.
    void* user_data = nullptr;
    if (registration->init) {
      user_data = registration->init(&context_, init_data, init_data_size);
    }

    // 2.5 The node's inputs as a TfLiteIntArray (size, then tensor indices).
    const int kMaxInputs = 16;
    if (op->inputs()->Length() > kMaxInputs) {
      error_reporter_->Report("Too many inputs (%d)", op->inputs()->Length());
      return kTfLiteError;
    }
    int inputs_data[kMaxInputs + 1];
    TfLiteIntArray* inputs_array =
        reinterpret_cast<TfLiteIntArray*>(inputs_data);
    inputs_array->size = op->inputs()->Length();
    for (int n = 0; n < op->inputs()->Length(); ++n) {
      inputs_array->data[n] = op->inputs()->Get(n);
    }

    // 2.6 The node's outputs as a TfLiteIntArray.
    const int kMaxOutputs = 16;
    if (op->outputs()->Length() > kMaxOutputs) {
      error_reporter_->Report("Too many outputs (%d)", op->outputs()->Length());
      return kTfLiteError;
    }
    int outputs_data[kMaxOutputs + 1];
    TfLiteIntArray* outputs_array =
        reinterpret_cast<TfLiteIntArray*>(outputs_data);
    outputs_array->size = op->outputs()->Length();
    for (int n = 0; n < op->outputs()->Length(); ++n) {
      outputs_array->data[n] = op->outputs()->Get(n);
    }

    // 2.7 The node's temporaries as a TfLiteIntArray (empty here).
    const int kMaxTemporaries = 16;
    int temporaries_data[kMaxTemporaries + 1];
    TfLiteIntArray* temporaries_array =
        reinterpret_cast<TfLiteIntArray*>(temporaries_data);
    temporaries_array->size = 0;

    // 2.8.1 Everything above exists to build this TfLiteNode.
    TfLiteNode node;
    node.inputs = inputs_array;
    node.outputs = outputs_array;
    node.temporaries = temporaries_array;
    node.user_data = user_data;
    node.builtin_data = reinterpret_cast<void*>(builtin_data);
    node.custom_initial_data = custom_data;
    node.custom_initial_data_size = custom_data_size;
    node.delegate = nullptr;

    // 2.8.2 Prepare, invoke and free; stop at the first failure.
    if (registration->prepare) {
      TfLiteStatus prepare_status = registration->prepare(&context_, &node);
      if (prepare_status != kTfLiteOk) {
        error_reporter_->Report("Node %d failed to prepare", i);
        return kTfLiteError;
      }
    }

    if (registration->invoke) {
      TfLiteStatus invoke_status = registration->invoke(&context_, &node);
      if (invoke_status != kTfLiteOk) {
        error_reporter_->Report("Node %d failed to invoke", i);
        return kTfLiteError;
      }
    }

    if (registration->free) {
      registration->free(&context_, user_data);
    }
  }
  return kTfLiteOk;
}
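Per node, the loop above boils down to init, prepare, invoke, free, where each hook is optional. The sketch below isolates that lifecycle with stand-in types (Registration, Context, Node and the Demo* hooks are hypothetical) and records the call order in a trace string; it calls free even when prepare or invoke fails, so user_data is never leaked. Note that in this version init, prepare and free run again on every Invoke() call, so any per-op setup cost is paid on every inference.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

namespace sketch {

enum TfLiteStatus { kTfLiteOk, kTfLiteError };
struct Context { std::string trace; };  // records the call order for the demo
struct Node { void* user_data = nullptr; };

// Stand-in for TfLiteRegistration: every hook may be nullptr.
struct Registration {
  void* (*init)(Context*, const char*, std::size_t);
  void (*free)(Context*, void*);
  TfLiteStatus (*prepare)(Context*, Node*);
  TfLiteStatus (*invoke)(Context*, Node*);
};

// One node's lifecycle: init -> prepare -> invoke -> free.
inline TfLiteStatus RunNode(const Registration& r, Context* ctx,
                            const char* init_data, std::size_t init_size) {
  Node node;
  if (r.init) node.user_data = r.init(ctx, init_data, init_size);
  TfLiteStatus status = kTfLiteOk;
  if (r.prepare) status = r.prepare(ctx, &node);
  if (status == kTfLiteOk && r.invoke) status = r.invoke(ctx, &node);
  if (r.free) r.free(ctx, node.user_data);  // release even on failure
  return status;
}

// Demo hooks that append to the trace.
inline void* DemoInit(Context* c, const char*, std::size_t) { c->trace += "init,"; return nullptr; }
inline void DemoFree(Context* c, void*) { c->trace += "free"; }
inline TfLiteStatus DemoPrepare(Context* c, Node*) { c->trace += "prepare,"; return kTfLiteOk; }
inline TfLiteStatus DemoInvoke(Context* c, Node*) { c->trace += "invoke,"; return kTfLiteOk; }

}  // namespace sketch
```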

Ops implemented for micro

cc_library(
    name = "all_ops_resolver",  # Public interface; the kernels themselves come from micro_ops.
    srcs = [
        "all_ops_resolver.cc",
    ],
    hdrs = [
        "all_ops_resolver.h",
    ],
    copts = tflite_copts(),
    deps = [ 
        ":micro_ops",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/experimental/micro:micro_framework",
    ],  
)

 

class AllOpsResolver : public MicroMutableOpResolver

class AllOpsResolver : public MicroMutableOpResolver {
 public:
  AllOpsResolver();

 private:
  TF_LITE_REMOVE_VIRTUAL_DELETE
};

all_ops_resolver.cc

TfLiteRegistration* Register_SOFTMAX();  // No header declares this function, so it is declared here as an external function
TfLiteRegistration* Micro_Register_SOFTMAX() { return Register_SOFTMAX(); }

AllOpsResolver::AllOpsResolver() {  // the constructor registers ops via the base-class AddBuiltin
  AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
             Micro_Register_DEPTHWISE_CONV_2D());
  AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Micro_Register_FULLY_CONNECTED(),
             /* min_version */ 1,
             /* max_version */ 2); 
  AddBuiltin(BuiltinOperator_SOFTMAX, Micro_Register_SOFTMAX());
}


Implementation of Register_SOFTMAX

Every op returns a standard TfLiteRegistration.

softmax.cc

  namespace activations {
    // ... Init, Free, SoftmaxPrepare and SoftmaxEval are defined here ...
  }  // namespace activations
  
  TfLiteRegistration* Register_SOFTMAX() {
    static TfLiteRegistration r = {activations::Init, activations::Free,
                                   activations::SoftmaxPrepare,
                                   activations::SoftmaxEval};
    return &r; 
  }
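The Register_* pattern needs no heap: the TfLiteRegistration is a function-local static, initialized once, and its address is returned on every call. The resolver's AddBuiltin then copies the struct into its own array, so the static acts as a template. A small sketch with stand-in types (Registration, Register_MY_OP and the activations functions here are hypothetical):

```cpp
#include <cassert>
#include <cstddef>

namespace sketch {

enum TfLiteStatus { kTfLiteOk, kTfLiteError };
struct Context {};
struct Node {};

// Same field order as TfLiteRegistration: init, free, prepare, invoke.
struct Registration {
  void* (*init)(Context*, const char*, std::size_t);
  void (*free)(Context*, void*);
  TfLiteStatus (*prepare)(Context*, Node*);
  TfLiteStatus (*invoke)(Context*, Node*);
};

namespace activations {
inline TfLiteStatus Prepare(Context*, Node*) { return kTfLiteOk; }
inline TfLiteStatus Eval(Context*, Node*) { return kTfLiteOk; }
}  // namespace activations

// Function-local static: built once, no heap, same address on every call.
inline Registration* Register_MY_OP() {
  static Registration r = {nullptr, nullptr, activations::Prepare, activations::Eval};
  return &r;
}

}  // namespace sketch
```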
 

Implementation of the micro ops

cc_library(
    name = "micro_ops",
    srcs = [
        "depthwise_conv.cc",
        "fully_connected.cc",
        "softmax.cc",
    ],
    hdrs = [
    ],
    copts = tflite_copts(),
    deps = [
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/experimental/micro:micro_framework",
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/kernels:op_macros",
        "//tensorflow/lite/kernels:padding",
        "//tensorflow/lite/kernels/internal:quantization_util",
        "//tensorflow/lite/kernels/internal:reference_base",
        "//tensorflow/lite/kernels/internal:tensor",
    ],
)

Dependencies on TfLite kernels

1. kernel_util
cc_library(
    name = "kernel_util",
    srcs = [
        "kernel_util.cc",                                                                                                                                                                                   
    ],   
    hdrs = [
        "kernel_util.h",
    ],   
    deps = [
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/kernels/internal:round",
        "//tensorflow/lite/kernels/internal:types",
    ],   
)

kernel_util: basic helpers such as GetInput/GetOutput for a TfLiteNode
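The core idea behind these helpers is that a TfLiteNode stores tensor indices, while the TfLiteContext owns the tensors; as far as I can tell, GetInput in this version amounts to &context->tensors[node->inputs->data[index]]. The sketch below mirrors that with simplified stand-in structs (IntArray has a fixed capacity here, unlike TfLiteIntArray).

```cpp
#include <cassert>

namespace sketch {

struct Tensor { const char* name; };
// TfLiteIntArray-like: a size followed by tensor indices (fixed capacity here).
struct IntArray { int size; int data[4]; };
struct Context { Tensor* tensors; int tensors_size; };
struct Node { IntArray* inputs; IntArray* outputs; };

// The node holds indices; the context owns the tensors.
inline const Tensor* GetInput(const Context* ctx, const Node* node, int index) {
  return &ctx->tensors[node->inputs->data[index]];
}
inline Tensor* GetOutput(Context* ctx, Node* node, int index) {
  return &ctx->tensors[node->outputs->data[index]];
}
inline int NumInputs(const Node* node) { return node->inputs->size; }

}  // namespace sketch
```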

2. op_macros
cc_library(
    name = "op_macros",
    hdrs = [
        "op_macros.h",                                                                                                                                                                                      
    ],   
)

// If we're on a platform without standard IO functions, fall back to a
// non-portable function.
#ifdef TF_LITE_MCU_DEBUG_LOG

#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"

#define DEBUG_LOG(x) \
  do {               \
    DebugLog(x);     \
  } while (0)

inline void InfiniteLoop() {
  DEBUG_LOG("HALTED\n");
  while (1) {
  }
}
#define TFLITE_ASSERT_FALSE InfiniteLoop();
#define TFLITE_ABORT InfiniteLoop();

#else  // TF_LITE_MCU_DEBUG_LOG
#endif

3. padding
cc_library(
    name = "padding",
    srcs = [],
    hdrs = ["padding.h"],
    deps = [
        "//tensorflow/lite/c:c_api_internal",
    ],
)

4. internal:tensor
cc_library(
    name = "tensor",
    hdrs = [ 
        "tensor.h",
        "tensor_ctypes.h",
    ],  
    deps = [ 
        ":types",
        "//tensorflow/lite/c:c_api_internal",
    ],  
)

5. internal:reference_base
cc_library(
    name = "reference_base",
    srcs = [], 
    hdrs = [ 
        "common.h",
        "reference/depthwiseconv_float.h",
        "reference/depthwiseconv_uint8.h",
        "reference/fully_connected.h",
        "reference/reference_ops.h",
        "reference/softmax.h",
    ],  
    deps = [ 
        ":quantization_util",
        ":round",
        ":strided_slice_logic",
        ":types",
        "@gemmlowp",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/kernels:op_macros",
    ] + select({
        ":x86_64": tflite_deps_intel,
        "//conditions:default": [], 
    }), 
)

6. internal:quantization_util
cc_library(
    name = "quantization_util",
    srcs = ["quantization_util.cc"],                                                                                                                                                                        
    hdrs = [
        "compatibility.h",
        "quantization_util.h",
    ],
    deps = [
        ":round",
        ":types",
        "//tensorflow/lite/kernels:op_macros",
    ],
)
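The central job of quantization_util is turning a real-valued rescale factor into integer arithmetic: the multiplier is written as q * 2^shift with q stored as a Q31 fixed-point int32. The sketch below follows that idea using std::frexp; it is a simplified version (no range checks for negative or huge multipliers), and while the name mirrors QuantizeMultiplier in quantization_util.h, it should not be read as that function's exact code. For example, 3.0 = 0.75 * 2^2, giving q = 0.75 * 2^31 = 1610612736 and shift = 2.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

namespace sketch {

// real ~= (quantized / 2^31) * 2^shift, with quantized / 2^31 in [0.5, 1).
inline void QuantizeMultiplier(double real, std::int32_t* quantized, int* shift) {
  if (real == 0.0) {
    *quantized = 0;
    *shift = 0;
    return;
  }
  const double q = std::frexp(real, shift);  // real = q * 2^shift
  auto q_fixed = static_cast<std::int64_t>(std::llround(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // rounding reached 1.0: renormalize
    q_fixed /= 2;
    ++*shift;
  }
  *quantized = static_cast<std::int32_t>(q_fixed);
}

}  // namespace sketch
```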
 

How does the op implementation in TfLite for MCU differ from TfLite's?
