tensorflow:添加新OP和新设备

13 分钟阅读

  • REGISTER_OP,注册一个OP,其实也是声明一个OP
  • REGISTER_KERNEL_BUILDER,注册一个Kernel,其实就是对OP的实现
  • REGISTER_LOCAL_DEVICE_FACTORY,添加设备工厂

一、注册OP

REGISTER_OP("Concat")
    .Input("concat_dim: int32")
    .Input("values: N * T")
    .Output("output: T")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::ConcatShape(c, c->num_inputs() - 1);
    });

以Concat为例,它描述维度合并算法,比如[1,2,3]与[4,2,3]在0维度合并,则输出为[5,2,3]。

输入输出

.Input.Output 代表输入输出,形式如<name> : <io-type-expr>

  • <type>, 基本类型比如float,int32,string等等

  • <attr-type>, 属性类型,比如:

    REGISTER_OP("PolymorphicSingleInput")
          .Attr("T: type")
          .Input("in: T);
    REGISTER_OP("RestrictedPolymorphicSingleInput")
          .Attr("T: {int32, int64}")
          .Input("in: T);
    REGISTER_OP("ArbitraryTensorSequenceExample")
          .Attr("T: list(type)")
          .Input("in: T")
          .Output("out: T");
    REGISTER_OP("RestrictedTensorSequenceExample")
          .Attr("T: list({int32, int64})")
          .Input("in: T")
          .Output("out: T");
    
  • <number> * <type>,一组相同类型的tensor

属性

.Attr 描述属性,形式如<name>: <attr-type-expr>,比如上面的N与T,可以在Kernel构造函数中这样获得:

OP_REQUIRES_OK(context, context->GetAttr("N", &N_));
OP_REQUIRES_OK(context, context->GetAttr("T", &T_));

属性类型:

  • string, int, float, bool
  • type : DataType中的非引用类型
  • shape : TensorShapeProto
  • tensor :TensorProto
  • list(<type>) :以上类型的列表

约束条件:

  • {<type1>, <type2>},必须是type1或type2中的一种类型

  • {'<string1>', '<string2>'},必须是字符串,且是string1或者string2中的一个

  • numbertype,数字类型

  • realnumbertype, 不支持复杂类型的numbertype

  • quantizedtype, 只支持量化数值类型

  • int >= 2,必须是int类型,取值大于等于2

  • list(<type>) >= 2, 列表长度必须大于等于2,比如:

    REGISTER_OP("TypeListExample")
          .Attr("a: list({int32, float}) >= 3");
    
  • = <default>,设置默认值,比如:

    REGISTER_OP("AttrDefaultExampleForAllTypes")
       .Attr("s: string = 'foo'")
       .Attr("i: int = 0")
       .Attr("f: float = 1.0")
       .Attr("b: bool = true")
       .Attr("ty: type = DT_INT32")
       .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
       .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
       .Attr("l_empty: list(int) = []")
       .Attr("l_int: list(int) = [2, 3, 5, 7]");
    

多态:

  • 比如输入和输出必须是同类型,且是float或者int32,如下:

    REGISTER_OP("ZeroOut")
        .Attr("T: {float, int32} = DT_INT32")
        .Input("to_zero: T")
        .Output("zeroed: T");
    
  • 输出类型自动推断,如下:

    REGISTER_OP("StringToNumber")
         .Input("string_tensor: string")
         .Output("output: out_type")
         .Attr("out_type: {float, int32}");
    

Shape回调函数

.SetShapeFn 用来配置shape的回调函数,也就是输出的shape与输入之间的算法关系。比如ConcatShape函数如下:

Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 1 /* start_value_index */,
                           1 + num_inputs_to_concat /* end_value_index */,
                           0 /* dim_index */);
}

二、注册Kernel

// for cpu
#define REGISTER_CONCAT(type)                            \
  REGISTER_KERNEL_BUILDER(Name("Concat")                 \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("concat_dim"), \
                               ConcatOp<CPUDevice, type>)
REGISTER_CONCAT(quint8);
REGISTER_CONCAT(qint8);
REGISTER_CONCAT(quint16);
REGISTER_CONCAT(qint16);
REGISTER_CONCAT(qint32);

// for GPU CUDA
#define REGISTER_GPU(type)                           \
  REGISTER_KERNEL_BUILDER(Name("Concat")             \
                          .Device(DEVICE_GPU)        \
                          .TypeConstraint<type>("T") \
                          .HostMemory("concat_dim"), \
                          ConcatOp<GPUDevice, type>)
REGISTER_GPU(bfloat16);

以上注册了各种类型的基于CPU的kernel,以及基于CUDE的float16的kernel。

  • Name("Concat"),对应OP的名称

  • .Device(DEVICE_GPU),对应执行OP的设备

  • .TypeContraint<bfloat16>("T"),对应OP的数据类型

  • .HostMemory("concat_dim"),将concat_dim标记为系统内存输入

  • ConcatOp<GPUDevice, bfloat16>,是一个类,继承OpKernel,定义类似如下:

    template <typename Device, typename T, AxisArgumentName AxisArgName>
    class ConcatBaseOp : public OpKernel {
     public:
      typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
          ConstMatrixVector;
    
      explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
    
      void Compute(OpKernelContext* c) override {
        const Tensor* concat_dim_tensor;
        const char* axis_attribute_name = "concat_dim";
        OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
        ...
      }
    };
    template <typename Device, typename T>
    using ConcatOp = ConcatBaseOp<Device, T, NAME_IS_CONCAT_DIM>;
    

    这个类需要实现构造函数,和Compute方法

三、注册新设备

定义设备名称

namespace tensorflow {
const char* const DEVICE_GPU = "GPU";
}

创建设备类,继承LocalDevice

class BaseGPUDevice : public LocalDevice {
 public:
  BaseGPUDevice(const SessionOptions& options, const string& name,
                Bytes memory_limit, const DeviceLocality& locality,
                TfGpuId tf_gpu_id, const string& physical_device_desc,
                Allocator* gpu_allocator, Allocator* cpu_allocator,
                bool sync_every_op, int32 max_streams);
  ~BaseGPUDevice() override;
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  Status Sync() override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status FillContextMap(const Graph* graph,
                        DeviceContextMap* device_context_map) override;
  ......
 private:
  std::vector<GPUDeviceContext*> device_contexts_;
};

class GPUDevice : public BaseGPUDevice {
 public:
  GPUDevice(const SessionOptions& options, const string& name,
            Bytes memory_limit, const DeviceLocality& locality,
            TfGpuId tf_gpu_id, const string& physical_device_desc,
            Allocator* gpu_allocator, Allocator* cpu_allocator);

  Allocator* GetAllocator(AllocatorAttributes attr) override;
};

创建设备工厂类,继承DeviceFactory

class BaseGPUDeviceFactory : public DeviceFactory {
 public:
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};

class GPUDeviceFactory : public BaseGPUDeviceFactory {
 private:
  std::unique_ptr<BaseGPUDevice> CreateGPUDevice(
      const SessionOptions& options, const string& name, Bytes memory_limit,
      const DeviceLocality& locality, TfGpuId tf_gpu_id,
      const string& physical_device_desc, Allocator* gpu_allocator,
      Allocator* cpu_allocator) override {
    return absl::make_unique<GPUDevice>(options, name, memory_limit, locality,
                                        tf_gpu_id, physical_device_desc,
                                        gpu_allocator, cpu_allocator);
  }
};

REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory, 210);

创建上下文,继承DeviceContext

class GPUDeviceContext : public DeviceContext {
 public:
  // Does not take ownership of streams.
  GPUDeviceContext(int stream_id, se::Stream* stream,
                   se::Stream* host_to_device_stream,
                   se::Stream* device_to_host_stream,
                   gtl::InlinedVector<se::Stream*, 4> device_to_device_stream);
  ~GPUDeviceContext() override {}
  void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                             Tensor* device_tensor,
                             StatusCallback done) const override;

  void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name,
                             Device* device, Tensor* cpu_tensor,
                             StatusCallback done) override;

  void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
                              Tensor* output_tensor,
                              StatusCallback done) const override;
  ......
};

该context可以在OpKernelContext中得到,如下:

void BaseGPUDevice::ComputeHelper(OpKernel* op_kernel,
                                  OpKernelContext* context) {
  GPUDeviceContext* gpu_device_context = device_contexts_[0];
  if (context->op_device_context() != nullptr) {
    gpu_device_context =
        static_cast<GPUDeviceContext*>(context->op_device_context());
  }
  ......
}

通常在LocalDevice::FillContextMap中填入,如下:

Status BaseGPUDevice::FillContextMap(const Graph* graph,
                                     DeviceContextMap* device_context_map) {
  ......
  for (Node* n : graph->nodes()) {
    auto mapped_stream = node_to_stream_id[n->id()];
    CHECK_LE(mapped_stream, num_streams);
    auto ctx = device_contexts_[mapped_stream];
    VLOG(3) << "Assigned stream " << node_to_stream_id[n->id()]
            << " ==> stream[" << ctx->stream_id() << "] for node id " << n->id()
            << " " << n->type_string() << " " << n->name();
    ctx->Ref();
    (*device_context_map)[n->id()] = ctx;
  }
  ......
}

评论