An Ascend C operator implementation consists of two main parts:
● Host-side Tiling implementation
Because the internal memory of an AI Core on the NPU cannot hold all of an operator's input and output data at once, the data has to be processed piece by piece: move one portion of the input in, compute on it, move the result out, then move the next portion in. This process is called Tiling, and the algorithm that splits the data is called the tiling algorithm or tiling strategy. The program that derives the parameters of this splitting from the operator's shape and other information (for example, the block size moved per copy and the total number of loop iterations) is called the Tiling implementation, also known as the Tiling function. Since the Tiling implementation consists entirely of scalar computation, which the AI Core is not good at, it is split out and executed on the Host-side CPU (a minimal sketch of such a tiling computation follows this list).
● Device-side Kernel implementation
The Kernel implementation is the operator's kernel function. Inside the kernel, the Tiling struct passed in from the Host side is parsed to obtain the tiling information, which then controls how data is moved into and out of Local Memory; the operator logic itself is implemented by calling the compute, data-movement, memory-management, and task-synchronization APIs (a schematic of this per-tile loop is shown after the list, right after the tiling sketch). The core logic is essentially compute-intensive and therefore runs on the NPU.
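To make the Host-side role concrete, here is a minimal sketch of a tiling computation. The struct and names (KattentionTilingData, ComputeTiling, the 64 KB Local Memory budget) are illustrative assumptions, not the actual CANN tiling interface; a real Ascend C operator registers its tiling struct and fills it in through the framework's tiling mechanism.

#include <algorithm>
#include <cstdint>

// Hypothetical tiling struct: the parameters the Host side hands to the kernel.
struct KattentionTilingData {
    uint32_t totalLength;  // total number of elements to process
    uint32_t tileLength;   // elements moved into Local Memory per copy
    uint32_t tileNum;      // number of full tiles, i.e. the loop count
    uint32_t tailLength;   // leftover elements handled in a final partial tile
};

// Hypothetical Host-side tiling function: derives the block size per copy and
// the loop count from the operator's shape, under an assumed Local Memory budget.
inline KattentionTilingData ComputeTiling(uint32_t totalLength,
                                          uint32_t bytesPerElement,
                                          uint32_t ubBudgetBytes = 64 * 1024) {
    KattentionTilingData t{};
    t.totalLength = totalLength;
    t.tileLength  = std::max<uint32_t>(1, ubBudgetBytes / bytesPerElement);
    t.tileNum     = totalLength / t.tileLength;
    t.tailLength  = totalLength % t.tileLength;
    return t;
}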
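And a schematic of the Device-side loop that consumes this tiling information, reusing the KattentionTilingData struct from the sketch above. CopyIn, Compute and CopyOut are hypothetical placeholders standing in for the real data-movement, compute and synchronization APIs; the point is only the control flow that the tiling parameters drive.

// Hypothetical stand-ins for the real Ascend C data-movement / compute /
// synchronization APIs; they only mark where each stage would go.
static void CopyIn (uint32_t /*tileIdx*/, uint32_t /*len*/) { /* GM -> Local Memory   */ }
static void Compute(uint32_t /*tileIdx*/, uint32_t /*len*/) { /* per-tile computation */ }
static void CopyOut(uint32_t /*tileIdx*/, uint32_t /*len*/) { /* Local Memory -> GM   */ }

// Schematic kernel body: loop over the full tiles, then handle the tail tile.
void KernelBodySketch(const KattentionTilingData& t) {
    for (uint32_t i = 0; i < t.tileNum; ++i) {
        CopyIn(i, t.tileLength);
        Compute(i, t.tileLength);
        CopyOut(i, t.tileLength);
    }
    if (t.tailLength > 0) {
        CopyIn(t.tileNum, t.tailLength);
        Compute(t.tileNum, t.tailLength);
        CopyOut(t.tileNum, t.tailLength);
    }
}

With that context, the extension code below follows the usual PyTorch C++ extension layout: first the binding file with input checks and pybind11 registration, then the launch file with the kernel stubs and type-dispatched launchers.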
// Host-side binding file: input checks and pybind11 registration for the Kattention Ascend extension.
#include <torch/extension.h>
#include "acl/acl.h"

#include <vector>

// Ascend forward declarations
std::vector<torch::Tensor> Kattention_ascend_forward(
    torch::Tensor input,
    torch::Tensor Kernel_Full_4DTensor,
    torch::Tensor output,
    int step);

std::vector<torch::Tensor> Kattention_ascend_backward(
    torch::Tensor grad_output,
    torch::Tensor input,
    torch::Tensor Kernel_Full_4DTensor,
    int step);

// C++ interface
#define CHECK_ASCEND(x) TORCH_CHECK(x.device().is_npu(), #x " must be an Ascend tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_ASCEND(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> Kattention_forward(
    torch::Tensor input,
    torch::Tensor Kernel_Full_4DTensor,
    torch::Tensor output,
    int step) {
  CHECK_INPUT(input);
  CHECK_INPUT(Kernel_Full_4DTensor);
  CHECK_INPUT(output);
  TORCH_CHECK(step > 0, "step " + std::to_string(step) + " must be positive");
  return Kattention_ascend_forward(input, Kernel_Full_4DTensor, output, step);
}

std::vector<torch::Tensor> Kattention_backward(
    torch::Tensor grad_output,
    torch::Tensor input,
    torch::Tensor Kernel_Full_4DTensor,
    int step) {
  CHECK_INPUT(grad_output);
  CHECK_INPUT(input);
  CHECK_INPUT(Kernel_Full_4DTensor);
  TORCH_CHECK(step > 0, "step " + std::to_string(step) + " must be positive");
  return Kattention_ascend_backward(grad_output, input, Kernel_Full_4DTensor, step);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &Kattention_forward, "Kattention forward (Ascend)");
  m.def("backward", &Kattention_backward, "Kattention backward (Ascend)");
}
// Launch file: kernel declarations, placeholder Ascend implementations, and the
// type-dispatched launchers that the binding file's forward/backward entry points call into.
#include <torch/extension.h>
#include "acl/acl.h"

#include <vector>
#include <iostream>  // for debug
#include <chrono>    // for time record
#include <ctime>     // for performance test
#include <cmath>

#define T 1024        // threads
#define B 1073741824  // 65535 // max number of each dim in block
#define FORWARD_NAME "kattention_forward_ascend"
#define BACKWARD_INPUT_NAME "kattention_backward_grad_input_ascend"
#define BACKWARD_KERNEL_NAME "kattention_backward_grad_kernel_ascend"

// function headers
template <typename scalar_t>
void kattention_forward_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> kernel,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> output,
    int *parser, int kernel_len, int c_size, int step);

template <typename scalar_t>
void kattention_backward_grad_kernel(
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> grad_kernel,
    int b_size, int iseq_pos, int *parser, int step);

template <typename scalar_t>
void kattention_backward_grad_input(
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> kernel,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> grad_input,
    size_t b_size, size_t c_size, size_t iseq_len, size_t kernel_len,
    size_t kernel_num, int step);

// utils functions
int get_global_index() {
  // Replace CUDA-specific indexing logic with Ascend-compatible logic if needed
  return 0;  // Placeholder implementation
}

template <typename scalar_t>
void kattention_forward_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> kernel,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> output,
    int *parser, int kernel_len, int c_size, int step) {
  // Replace with Ascend-compatible kernel logic
}

template <typename scalar_t>
void kattention_backward_grad_kernel(
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> grad_kernel,
    int b_size, int iseq_pos, int *parser, int step) {
  // Replace with Ascend-compatible kernel logic
}

template <typename scalar_t>
void kattention_backward_grad_input(
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> grad_output,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> kernel,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> grad_input,
    size_t b_size, size_t c_size, size_t iseq_len, size_t kernel_len,
    size_t kernel_num, int step) {
  // Replace with Ascend-compatible kernel logic
}

void launch_kattention_forward_kernel(torch::Tensor input, torch::Tensor kernel,
                                      torch::Tensor output) {
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kattention_forward", ([&] {
    kattention_forward_kernel<scalar_t>(
        input.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
        kernel.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        output.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        nullptr, kernel.size(1), input.size(1), 1);
  }));
}

void launch_kattention_backward_kernel(torch::Tensor grad_output, torch::Tensor input,
                                       torch::Tensor kernel, torch::Tensor grad_input,
                                       torch::Tensor grad_kernel) {
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kattention_backward", ([&] {
    kattention_backward_grad_kernel<scalar_t>(
        grad_output.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        input.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
        grad_kernel.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        grad_output.size(0), 0, nullptr, 1);
    kattention_backward_grad_input<scalar_t>(
        grad_output.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        input.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
        kernel.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        grad_input.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
        grad_output.size(0), input.size(1), input.size(2), kernel.size(1),
        kernel.size(0), 1);
  }));
}