
Ascend C Operator Development

An Ascend C operator implementation consists of two main parts:
● Host-side Tiling implementation
The on-chip memory of an AI Core inside the NPU cannot hold all of an operator's input and output data at once, so each iteration moves a portion of the input in, computes on it, moves the result out, and then moves the next portion in. This process is called Tiling, and the algorithm that decides how to partition the data is called the Tiling algorithm or Tiling strategy. The program that computes the partitioning parameters (for example, the block size moved per iteration and the total number of loop iterations) from the operator's shape and other information is called the Tiling implementation, also known as the Tiling Function. Because a Tiling implementation performs only scalar computation, which the AI Core is not good at, it is split out and executed on the Host-side CPU (a minimal sketch follows this list).
● Device-side Kernel implementation
The Kernel implementation is the operator's kernel function. Inside the kernel, the Tiling struct passed in from the Host side is parsed to obtain the Tiling information, which drives how data is moved into and out of Local Memory; the operator logic itself is implemented by calling the compute, data-movement, memory-management, and task-synchronization APIs. Since this core logic is essentially compute-intensive, it runs on the NPU (a kernel skeleton appears after the binding code below).
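To make the Host-side part concrete, here is a minimal sketch of a TilingData definition and a Tiling Function, closely following the public Ascend C AddCustom sample. The struct name, field names, core count, and fixed tile count are illustrative assumptions; they are not part of the Kattention operator discussed in this post.

#include "register/tilingdata_base.h"
#include "register/op_def_registry.h"

namespace optiling {
// TilingData: the struct serialized on the Host and parsed on the Device.
BEGIN_TILING_DATA_DEF(AddCustomTilingData)
  TILING_DATA_FIELD_DEF(uint32_t, totalLength);  // total element count
  TILING_DATA_FIELD_DEF(uint32_t, tileNum);      // tiles per AI Core
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(AddCustom, AddCustomTilingData)

// Tiling Function: scalar-only computation, so it runs on the Host CPU.
static ge::graphStatus TilingFunc(gert::TilingContext* context) {
  AddCustomTilingData tiling;
  uint32_t totalLength =
      context->GetInputShape(0)->GetStorageShape().GetShapeSize();
  context->SetBlockDim(8);   // illustrative: launch on 8 AI Cores
  tiling.set_totalLength(totalLength);
  tiling.set_tileNum(8);     // illustrative: fixed tile count per core
  tiling.SaveToBuffer(context->GetRawTilingData()->GetData(),
                      context->GetRawTilingData()->GetCapacity());
  context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
  return ge::GRAPH_SUCCESS;
}
} // namespace optiling

The remainder of this post lists the PyTorch C++ extension glue for the Kattention operator, starting with the binding layer: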

#include <torch/extension.h>
#include "acl/acl.h"
#include <vector>

// Ascend forward declarations

std::vector<torch::Tensor> Kattention_ascend_forward(
    torch::Tensor input,
    torch::Tensor Kernel_Full_4DTensor,
    torch::Tensor output,
    int step);

std::vector<torch::Tensor> Kattention_ascend_backward(
    torch::Tensor grad_output,
    torch::Tensor input,
    torch::Tensor Kernel_Full_4DTensor,
    int step);

// C++ interface

#define CHECK_ASCEND(x) TORCH_CHECK(x.device().is_npu(), #x " must be an Ascend tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_ASCEND(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> Kattention_forward(
    torch::Tensor input,
    torch::Tensor Kernel_Full_4DTensor,
    torch::Tensor output,
    int step
) {
  CHECK_INPUT(input);
  CHECK_INPUT(Kernel_Full_4DTensor);
  CHECK_INPUT(output);

  TORCH_CHECK(step > 0, "step ", step, " must be positive");

  return Kattention_ascend_forward(input, Kernel_Full_4DTensor, output, step);
}

std::vector<torch::Tensor> Kattention_backward(
    torch::Tensor grad_output,
    torch::Tensor input,
    torch::Tensor Kernel_Full_4DTensor,
    int step) {
  CHECK_INPUT(grad_output);
  CHECK_INPUT(input);
  CHECK_INPUT(Kernel_Full_4DTensor);

  TORCH_CHECK(step > 0, "step ", step, " must be positive");
  return Kattention_ascend_backward(
      grad_output,
      input,
      Kernel_Full_4DTensor,
      step);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &Kattention_forward, "Kattention forward (Ascend)");
  m.def("backward", &Kattention_backward, "Kattention backward (Ascend)");
}
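For context, Kattention_ascend_forward (declared above and implemented elsewhere) would ultimately launch an Ascend C kernel on the Device. The Kattention kernel itself is not shown in this post; as a hedged illustration of the Device-side pattern described earlier, below is a minimal element-wise add skeleton following the public Ascend C Add sample, with a CopyIn/Compute/CopyOut pipeline over VECIN/VECOUT queues. Names, the half dtype, and double buffering are taken from that sample, not from the Kattention operator.

#include "kernel_operator.h"

constexpr int32_t BUFFER_NUM = 2;  // double buffering

class KernelAdd {
public:
    __aicore__ inline KernelAdd() {}
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z,
                                uint32_t totalLength, uint32_t tileNum) {
        // Each AI Core works on its own contiguous slice of global memory.
        this->blockLength = totalLength / AscendC::GetBlockNum();
        this->tileNum = tileNum;
        this->tileLength = this->blockLength / tileNum / BUFFER_NUM;
        xGm.SetGlobalBuffer((__gm__ half*)x + this->blockLength * AscendC::GetBlockIdx(), this->blockLength);
        yGm.SetGlobalBuffer((__gm__ half*)y + this->blockLength * AscendC::GetBlockIdx(), this->blockLength);
        zGm.SetGlobalBuffer((__gm__ half*)z + this->blockLength * AscendC::GetBlockIdx(), this->blockLength);
        pipe.InitBuffer(inQueueX, BUFFER_NUM, this->tileLength * sizeof(half));
        pipe.InitBuffer(inQueueY, BUFFER_NUM, this->tileLength * sizeof(half));
        pipe.InitBuffer(outQueueZ, BUFFER_NUM, this->tileLength * sizeof(half));
    }
    __aicore__ inline void Process() {
        int32_t loopCount = this->tileNum * BUFFER_NUM;
        for (int32_t i = 0; i < loopCount; i++) {
            CopyIn(i);    // GM -> Local Memory (VECIN)
            Compute(i);   // vector add on Local Memory
            CopyOut(i);   // Local Memory (VECOUT) -> GM
        }
    }
private:
    __aicore__ inline void CopyIn(int32_t progress) {
        AscendC::LocalTensor<half> xLocal = inQueueX.AllocTensor<half>();
        AscendC::LocalTensor<half> yLocal = inQueueY.AllocTensor<half>();
        AscendC::DataCopy(xLocal, xGm[progress * this->tileLength], this->tileLength);
        AscendC::DataCopy(yLocal, yGm[progress * this->tileLength], this->tileLength);
        inQueueX.EnQue(xLocal);
        inQueueY.EnQue(yLocal);
    }
    __aicore__ inline void Compute(int32_t progress) {
        AscendC::LocalTensor<half> xLocal = inQueueX.DeQue<half>();
        AscendC::LocalTensor<half> yLocal = inQueueY.DeQue<half>();
        AscendC::LocalTensor<half> zLocal = outQueueZ.AllocTensor<half>();
        AscendC::Add(zLocal, xLocal, yLocal, this->tileLength);
        outQueueZ.EnQue<half>(zLocal);
        inQueueX.FreeTensor(xLocal);
        inQueueY.FreeTensor(yLocal);
    }
    __aicore__ inline void CopyOut(int32_t progress) {
        AscendC::LocalTensor<half> zLocal = outQueueZ.DeQue<half>();
        AscendC::DataCopy(zGm[progress * this->tileLength], zLocal, this->tileLength);
        outQueueZ.FreeTensor(zLocal);
    }
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX, inQueueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueZ;
    AscendC::GlobalTensor<half> xGm, yGm, zGm;
    uint32_t blockLength, tileNum, tileLength;
};

extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z,
                                                 GM_ADDR workspace, GM_ADDR tiling) {
    GET_TILING_DATA(tilingData, tiling);  // parse the Host-side TilingData
    KernelAdd op;
    op.Init(x, y, z, tilingData.totalLength, tilingData.tileNum);
    op.Process();
}

The second listing below contains the host-side scaffolding that the binding delegates to: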
#include <torch/extension.h>
#include "acl/acl.h"
#include <vector>
#include <iostream> // for debug
#include <chrono> // for time record
#include <ctime> // for performance test
#define T 1024        // threads per block (inherited from the CUDA version; unused in the Ascend build)
#define B 1073741824  // max blocks per grid dimension (inherited from the CUDA version; unused in the Ascend build)
#define FORWARD_NAME "kattention_forward_ascend"
#define BACKWARD_INPUT_NAME "kattention_backward_grad_input_ascend"
#define BACKWARD_KERNEL_NAME "kattention_backward_grad_kernel_ascend"
#include <cmath>

// function header
template <typename scalar_t>
void kattention_forward_kernel (
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> kernel,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> output,
    int *parser,
    int kernel_len,
    int c_size,
    int step
);

template <typename scalar_t>
void kattention_backward_grad_kernel(
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_kernel,
    int b_size, int iseq_pos, int *parser, int step
);

template <typename scalar_t>
void kattention_backward_grad_input(
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> kernel,
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> grad_input,
    size_t b_size,
    size_t c_size,
    size_t iseq_len,
    size_t kernel_len,
    size_t kernel_num,
    int step
);

// utils functions
int get_global_index() {
  // Placeholder: the CUDA version derived this from blockIdx/threadIdx.
  // An Ascend port would use a different indexing model (see the note below).
  return 0;
}
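// Note (assumption, following Ascend C conventions rather than this project's
// code): in an actual Ascend C device kernel, per-core indexing comes from
// AscendC::GetBlockIdx() (this core's index) and AscendC::GetBlockNum() (the
// number of cores launched); each AI Core then derives the offset of its own
// data slice, e.g. offset = GetBlockIdx() * blockLength. The kernel skeleton
// after the binding code above shows this pattern in context.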

template <typename scalar_t>
void kattention_forward_kernel (
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> kernel,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> output,
    int *parser,
    int kernel_len,
    int c_size,
    int step) {
  // Replace with Ascend-compatible kernel logic
}

template <typename scalar_t>
void kattention_backward_grad_kernel(
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_kernel,
    int b_size, int iseq_pos, int *parser, int step) {
  // Replace with Ascend-compatible kernel logic
}

template <typename scalar_t>
void kattention_backward_grad_input(
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> kernel,
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> grad_input,
    size_t b_size,
    size_t c_size,
    size_t iseq_len,
    size_t kernel_len,
    size_t kernel_num,
    int step) {
  // Replace with Ascend-compatible kernel logic
}

void launch_kattention_forward_kernel(torch::Tensor input, torch::Tensor kernel, torch::Tensor output) {
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kattention_forward", ([&] {
    kattention_forward_kernel<scalar_t>(
        input.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
        kernel.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
        output.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
        nullptr, kernel.size(1), input.size(1), 1);
  }));
}

void launch_kattention_backward_kernel(torch::Tensor grad_output, torch::Tensor input, torch::Tensor kernel, torch::Tensor grad_input, torch::Tensor grad_kernel) {
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kattention_backward", ([&] {
    kattention_backward_grad_kernel<scalar_t>(
        grad_output.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
        input.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
        grad_kernel.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
        grad_output.size(0), 0, nullptr, 1);
    kattention_backward_grad_input<scalar_t>(
        grad_output.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
        input.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
        kernel.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
        grad_input.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
        grad_output.size(0), input.size(1), input.size(2), kernel.size(1), kernel.size(0), 1);
  }));
}
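As a side note on the dispatch macro used above: AT_DISPATCH_FLOATING_TYPES switches on the tensor's runtime dtype and instantiates the lambda once per supported floating type, with scalar_t bound to the matching C++ type. A simplified sketch of what it expands to:

// Simplified illustration of what AT_DISPATCH_FLOATING_TYPES(...) generates:
switch (input.scalar_type()) {
  case torch::kFloat:  { using scalar_t = float;  /* lambda body */ break; }
  case torch::kDouble: { using scalar_t = double; /* lambda body */ break; }
  default: TORCH_CHECK(false, "kattention_forward: unsupported dtype");
}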