【CANN全新升级】CANN创新MLAPO算子,DeepSeek模型推理效率倍增

news/2025/6/19 17:59:05

MoE模型中的MLA架构

DeepSeek系列模型凭借其创新性的MLA(Multi-Head Latent Attention)架构,替代了传统的MHA(Multi Head Attention),显著降低了推理时的KV Cache开销,大幅提升了推理效率,使其能够更好地适应长上下文任务并提高推理准确性。MLA的成功应用不仅推动了DeepSeek系列模型自身的技术突破,其低成本和高效率的特点也为AI行业的普及和转型提供了重要支持。

图片

创新MLAPO算子,加速MLA前处理,提升DeepSeek系列模型性能

早在2024年5月DeepSeekV2发布时,昇腾CANN针对MLA架构进行了深度适配优化,经过2个月的开发,率先完成PagedAttention算子对DeepSeek系列模型的适配,实现了高效支持。随着DeepSeek系列模型的持续演进,昇腾也在不断探索推理预处理阶段中MLA的计算加速技术,通过VV融合(多个Vector算子融合),进一步提升MLA预处理阶段的计算效率。

MLA的预处理阶段,以DeepSeekV3-671B为例,其模型结构如下图所示:

图片

初始token的HiddenSize为7K,首先Q和KV会经由两个降维矩阵分别完成降维,降维后Q的HiddenSize为1536,KV为576。Q在经过RmsNorm后,进入Q升维矩阵做矩阵乘,升维后每个token变为128个Head,每个Head的HeadDim为192。

接下来,Q与KV会分别将每个Head切分成64+128和64+512,其中64均进入Rope,K的另一半进入RmsNorm,Q的另一半则进入K升维矩阵做矩阵乘。最后,Q和KV分别把各自的Head合并,输出结果给MLA使用。

在融合算子技术设计中,VV融合是最为高效快捷的融合开发方式。如上图红框所示,通过将MLA预处理两部分计算流分别融合成2个融合算子,可以实现融合算子性能直接翻倍。将这两个融合预处理小算子实现后,当前在DeepSeekV3整网中已取得了5%+的计算性能提升。

而为了针对DeepSeekV3模型场景进一步提升性能,昇腾CANN选择将前处理过程中的13个小算子直接融合成一个超级大算子MLAPO(MlaPreprocessOperation)。

图片

MLAPO算子的完整流程可以分为以下几个步骤:

1. RmsNorm/Preload并行
2. Q+KV的降维Matmul
3. Q的RmsNorm
4. Q的升维Matmul/KV Rope&RmsNorm&ReshapeandCache并行
5. K的升维Matmul/Q Rope并行

在计算时,通过对Vector和Cube计算单元的并行处理及流水优化,基本可以将用时较短的Vector耗时完全掩盖,进一步缩短MLA前处理的时延。实现MLA预处理算子MLAPO融合后,小算子的头开销和下发开销基本可以消除。这种大融合算子能够在VV融合的基础上,实现算子性能的再次翻倍。当前在大参数DeepSeekV3模型的量化场景下,MLAPO算子的实现将计算耗时从109us缩减为45us,带来整网性能提升20%+。

图片

DeepSeekV3火爆全球的同时,针对DeepSeek系列模型的计算优化思路也在不断探索泛化中,从小融合到大融合,多流水并行以及未来更高自由度的量化方式,昇腾也将持续探索更多可能,以工程创新释放更强算力。

MLAPO算子使能指南

以上优化特性已在昇腾CANN最新版本中实现,CANN包安装过程可参考社区文档:https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1alpha001/softwareinst/instg/instg_0000.html?

Mode=PmIns&OS=Ubuntu&Software=cannToolKit

./Ascend-cann-toolkit_<version>_linux-<arch>.run --install

CANN包安装并通过环境变量使能后,可以通过调用MlaPreprocessOperation算子接口使能

./Ascend-cann-toolkit_<version>_linux-<arch>.run --install

CANN包安装并通过环境变量使能后,可以通过调用MlaPreprocessOperation算子接口使能MLAPO算子,参考示例见下。

int main(int argc, char **argv){    std::string dtypeStr;    int tokenNum = 4;    int headNum = 128;    aclDataType dtype = ACL_FLOAT16;    if (argc == 4) {        dtypeStr = argv[1];        tokenNum = std::stoi(argv[2]);        headNum = std::stoi(argv[3]);    }    if (dtypeStr == "bf16") {        dtype = ACL_BF16;    }    // 设置卡号、创建context、设置stream    atb::Context *context = nullptr;    void *stream = nullptr;    CHECK_STATUS(aclInit(nullptr));    CHECK_STATUS(aclrtSetDevice(DEVICE_ID));    CHECK_STATUS(atb::CreateContext(&context));    CHECK_STATUS(aclrtCreateStream(&stream));    context->SetExecuteStream(stream);    // 创建op    atb::Operation *mlaPreprocessOp = CreateMlaPreprocessOperation();    // 准备输入tensor    atb::VariantPack variantPack;    variantPack.inTensors = PrepareInTensor(context, stream, dtype, tokenNum, headNum);  // 放入输入tensor    // 准备输出tensor    atb::Tensor qOut0 = CreateTensor(ACL_INT8, aclFormat::ACL_FORMAT_ND, {tokenNum,headNum,512});    atb::Tensor &kvCacheOut0 = variantPack.inTensors.at(19);    atb::Tensor qOut1 = CreateTensor(dtype, aclFormat::ACL_FORMAT_ND, {tokenNum,headNum,64});    atb::Tensor &kvCacheOut1 = variantPack.inTensors.at(20);    variantPack.outTensors = {qOut0, kvCacheOut0, qOut1, kvCacheOut1};  // 放入输出tensor    uint64_t workspaceSize = 0;    // 计算workspaceSize大小    CHECK_STATUS(mlaPreprocessOp->Setup(variantPack, workspaceSize, context));    uint8_t *workspacePtr = nullptr;    if (workspaceSize > 0) {        CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));    }    for (size_t i = 0; i < 10; i++){        std::cout << "tokenNum: " << tokenNum << " headNum: " << headNum << " loop: " << i << std::endl;        // mlaPreprocess执行        mlaPreprocessOp->Execute(variantPack, workspacePtr, workspaceSize, context);        CHECK_STATUS(aclrtSynchronizeStream(stream));  // 流同步,等待device侧任务计算完成    }    // 释放资源    for (atb::Tensor &inTensor : variantPack.inTensors) {        CHECK_STATUS(aclrtFree(inTensor.deviceData));        for (atb::Tensor &outTensor : variantPack.outTensors) {            if (outTensor.deviceData == inTensor.deviceData) {                outTensor.deviceData = nullptr;            }        }        inTensor.deviceData = nullptr;    }    for (atb::Tensor &outTensor : variantPack.outTensors) {        if (outTensor.deviceData == nullptr) continue;        CHECK_STATUS(aclrtFree(outTensor.deviceData));    }    if (workspaceSize > 0) {        CHECK_STATUS(aclrtFree(workspacePtr));    }    CHECK_STATUS(atb::DestroyOperation(mlaPreprocessOp));  // operation,对象概念,先释放    CHECK_STATUS(aclrtDestroyStream(stream));    CHECK_STATUS(DestroyContext(context));  // context,全局资源,后释放    CHECK_STATUS(aclFinalize());    std::cout << "MlaPreprocess demo success!" << std::endl;    return 0;}
更多学习内容,可参考ATB算子代码开源仓:
https://gitee.com/ascend/ascend-transformer-boost


https://dhexx.cn/news/show-5537652.html

相关文章

SDC命令详解:使用set_wire_load_model命令进行约束

相关阅读 SDC命令详解https://blog.csdn.net/weixin_45791458/category_12931432.html?spm1001.2014.3001.5482 目录 指定线负载模型名 指定搜索库 指定最大、最小条件 指定对象列表 set_wire_load_model命令用于显式指定一个线负载模型&#xff08;设置了对象的wire_loa…

SQL进阶之旅 Day 22:批处理与游标优化

【SQL进阶之旅 Day 22】批处理与游标优化 文章简述&#xff08;300字左右&#xff09; 在数据库开发中&#xff0c;面对大量数据的处理任务时&#xff0c;单条SQL语句往往无法满足性能需求。本篇文章聚焦“批处理与游标优化”&#xff0c;深入探讨如何通过批量操作和游标技术提…

pymilvus

一.pymilvus介绍 &#x1f680; pymilvus 是什么&#xff1f; pymilvus 是连接和操作 Milvus 向量数据库的 Python SDK&#xff0c;用于处理大规模向量数据的存储、索引和搜索。 &#x1f3d7;️ Milvus 向量数据库 什么是 Milvus&#xff1f; &#x1f50d; 专业向量数据…

C/C++ 面试复习笔记(5)

1.用户态和内核态切换的开销来自哪里&#xff1f;如何减少这种开销&#xff1f; 主要开销&#xff1a; 上下文保存与恢复&#xff1a;需保存/恢复寄存器、堆栈等状态&#xff08;约数百CPU周期&#xff09;。 CPU 模式切换&#xff1a;从用户态到内核态的权限检查及模式切换…

CppCon 2015 学习:Time Programming Fundamentals

Civil Time 公历时间 特点&#xff1a; 共 6 个字段&#xff1a; Year&#xff08;年&#xff09;Month&#xff08;月&#xff09;Day&#xff08;日&#xff09;Hour&#xff08;小时&#xff09;Minute&#xff08;分钟&#xff09;Second&#xff08;秒&#xff09; 表示…

对比一下blender快捷键:p和alt+p

在 Blender 中&#xff0c;P 和 Alt P 虽然看起来相似&#xff0c;但它们作用在不同的上下文&#xff08;Mode&#xff09;下&#xff0c;并完成完全不同的操作&#xff1a; 何时使用哪一个&#xff1f; 想要把模型的一部分从当前网格里拆分出来**&#xff0c;就进入 Edit Mod…

【从零学习JVM|第三篇】类的生命周期(高频面试题)

前言&#xff1a; 在Java编程中&#xff0c;类的生命周期是指类从被加载到内存中开始&#xff0c;到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期&#xff0c;让读者对此有深刻印象。 目录 ​…

打开GitHub网站因为网络原因导致加载失败问题解决方案

Date: 2025.06.09 20:34:22 author: lijianzhan 在Windows系统中&#xff0c;打开GitHub网站因为网络原因导致加载失败问题解决方案 打开Windows系统下方搜索框&#xff0c;搜索Microsoft Store&#xff0c;并且双击打开 在应用里面搜索Watt Toolkit&#xff0c;并下载安装 …

Linux文件管理和输入输出重定向

文件管理 Bash执行命令 passwd passwd普通用户修改密码 passwd robinkoolroot用户管理账户密码 passwd -d robinkoolroot用户删除普通用户密码 file file /bin/filecat cat option 文件 cat -A /etc/hosts #-A选项等于-VETcat /etc/hosts /etc/fstab一次性查看多个文件…

Linux线程互斥与竞态条件解析

Linux线程互斥及相关概念解析 1. 临界资源&#xff08;Critical Resource&#xff09; 定义&#xff1a;被多个线程共享的资源&#xff08;如变量、文件、内存区域等&#xff09;&#xff0c;需通过互斥访问确保数据一致性。特点&#xff1a; 共享性&#xff1a;多个线程可能…