注意力机制(四)(多头注意力机制)

​🌈 个人主页十二月的猫-CSDN博客
🔥 系列专栏 🏀《深度学习基础知识》

      相关专栏: 《机器学习基础知识》

                         🏐《机器学习项目实战》
                         🥎《深度学习项目实战(pytorch)》

💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 

目录

回顾

注意力机制与RNN、LSTM的对比 

总论

RNN

LSTM

注意力机制

多头注意力机制 

核心思想介绍

如何运用多头注意力机制

1、定义多组参数矩阵W(一般是八组),生成多组Q、K、V

 2、利用多组参数分别训练得到多组结果Z(上下文向量)

3、将多组输出拼接后乘以矩阵W0降低维度

 多头流程图

试图解释

Pytorch代码实现

代码解释 

两种实现思想对比

总结


回顾

在上一篇注意力机制(三)(不同注意力机制对比)-CSDN博客,重点讲了针对QKV来源不同制造的注意力机制的一些变体,包括交叉注意力、自注意力等。这里再对注意力机制理解中的核心要点进行归纳整理

1、注意力机制规定的是对QKV的处理,并不指定QKV的来源

2、注意力机制和RNN、LSTM本身是同级的,都可以用于独立解决时间序列的问题

3、注意力机制相比于RNN、LSTM来说能够学习句子内部的句法特征和语义特征

4、注意力机制能够进行并行运算,解决了RNN、LSTM串行运算的问题。但是也存在注意力机制运算量非常大的问题

5、注意力机制解决了RNN以及LSTM中有由于梯度消失梯度爆炸造成的长期依赖问题

注意力机制与RNN、LSTM的对比 

总论

RNN(递归神经网络):它能够处理序列数据,因为它具有循环的结构,可以保留之前的信息。这种结构模拟了人类思考时的持续性,即我们对当前事物的理解是建立在之前信息的基础上的。然而,传统的RNN在处理长距离依赖时会遇到困难,因为随着时间的推移,梯度可能会消失或爆炸,导致网络难以学习和记住长期的信息。

LSTM(长短时记忆网络):它是RNN的一种改进型,专门设计来解决长期依赖问题。LSTM通过引入三个门(遗忘门、输入门、输出门)的结构来控制信息的流动,从而有效地保存长期的信息。这种结构使得LSTM能够更好地查询较长一段时间内的信息,因为它可以通过“门”结构来决定哪些信息需要被记住或遗忘。

注意力机制:它是一种允许模型在处理序列时动态地关注不同部分信息的方法。注意力机制可以与RNN和LSTM结合使用,也可以独立使用。它的优点是可以提高模型对序列中重要信息的敏感度,从而提高模型的性能。注意力机制通过计算注意力权重来分配对不同时间步的关注,这使得模型在每个时间步都能够考虑到整个序列的信息。

RNN

        由于RNN在反向传播过程中涉及到矩阵的连乘,这可能导致梯度指数级减小(梯度消失)或增大(梯度爆炸)。梯度消失会使得网络难以学习和传递长期的依赖关系,因为梯度变得太小,以至于反向传播时权重更新几乎停滞(遗忘,不再考虑这个信息)。相反,梯度爆炸会导致梯度过大,使得网络训练不稳定甚至发散

LSTM

        LSTM就是应对RNN长期依赖问题而产生的。LSTM利用遗忘门、输出输入门以及特殊结构——记忆单元使得长期的信息能够被模型记忆并传递下来从而一定程度上解决了RNN因梯度消失、爆炸出现的长期依赖问题

        但是LSTM在序列很长时仍然会出现梯度消失梯度爆炸的问题,并且LSTM并不能实现并行运算。同时,LSTM在应对并不能很好的捕捉一个序列后面的信息(因为因果卷积的问题),所有导致LSTM并不能很好的理解句子的语法和句法特征

注意力机制

        注意力机制相比于前两个模型的特点在于:1、能够进行并行运算;2、能够完美解决长期依赖问题;3、由于对句子有全面的相似度计算,能够更好理解句子的句法特征和语义特征

上图体现了注意力机制对句子句法特征的理解

 上图体现了注意力机制对句子语法特征的理解

多头注意力机制 

核心思想介绍

前文,我们介绍了自注意力机制:自注意力的QKV是同源的。同源的好处就是更容易发现序列内部的信息,但是也存在一些可以改进的地方。

例如:对于一个待分析的序列矩阵,它存在许多方面的特征。此时我们要用一个参数矩阵Wq、Wk去分析并学习出序列中的这么多特征。由于参数矩阵的维度是有限的,所以一次性学习多特征的信息必然会造成信息学习的模糊性,所以作者又提出了多头注意力机制

下图为多头注意力机制模型图:

多头注意力机制在以下两个方面提升注意力机制的性能:

  • 它为注意力机制提供了多个投射子空间的可能。它利用多头机制提供了多组的参数矩阵,每组参数矩阵能够通过线性变化将词向量X放入不同的向量空间,从而反映出词向量X的不同特征。多组参数矩阵映射不同向量空间,再将不同向量空间的结果进行整合,如此比单组参数矩阵表示向量特征更加准确
  • 它拓展了模型关注不同位置的能力。多头机制不仅在词向量维度上能够挖掘更多词向量的特征,在词数上也能够同时关注更多的词信息

通过多头注意力机制,我们会为每一个都单独配置QKV的权重矩阵,从而在模型训练中产生不同的QKV矩阵。对于每个头来说,其核心思想和训练方式和自注意力机制是相似的

如何运用多头注意力机制

1、定义多组参数矩阵W(一般是八组),生成多组Q、K、V

多头注意力机制的每一头的处理方式和自注意力机制是相同的,也就是利用输入向量X分别乘上从而得到对应的q、k、v。然后每个头的参数矩阵会在不同初始值的情况下,各自训练自己的参数,最后分别生成不同的Q、K、V(初始值不同最后学习的结果也不同,可以参考梯度下降中的局部最优理解)

下图举了两个头的注意力机制的示意图:

 2、利用多组参数分别训练得到多组结果Z(上下文向量)

 将多组训练得到的参数与V经过mulmat融合得到多组Z。此时的上下文向量Z不仅包含原始的信息也包括对文本上下文注意力的信息,并且这个注意力信息利用多头考虑了多个维度的特征信息

3、将多组输出拼接后乘以矩阵W0降低维度

通过第二步得到多组的Z,为了全面的利用所有Z的信息,我们将Zconcat(拼接在一起),这将得到一个非常长的向量矩阵。由于输出结果Z和输入结果X的向量矩阵维度应该要相同,所以我们利用矩阵W0对结果进行变化降维,得到最终结果

 多头流程图

借用其他大佬翻译的流程图版本

试图解释

深度学习模型都是黑盒子模型,所以没有一个很严谨的解释。这里,我也只能给一个非常模糊且不透彻的解释,希望能帮助大家的理解

假设我们有一句话“The animal didn’t cross the street because it was too tired”

  • 图中绿线和橙线表示两个不同的头
  • 可以看到绿线重点关注的是tired单词,橙线重点关注animal单词。这表明it在高维度上某些特征和animal相似,另外一些特征和tired相似
  • 经过注意力机制调整,it中将包含tire和animal两个单词的信息。模型在分析时,对于it单词也将重点关心tired和animal两个单词

说明上面这句翻译的英语句子中it和animal以及tired的关系度相对较大。我们自己分析这个句子时结果也是这样,因为it就指代animal,同时全句子的重点也就是在animal很tired上

一旦注意力头更多之后,整个模型的解释会变得更难,因此我们不再展开

Pytorch代码实现

import torch
import torch.nn.functional as F
import torch.nn as nn
import math


def self_attention(query, key, value, dropout=None, mask=None):
    """
    前置参数(自注意力):
    输入矩阵X形状为(batch_size, seq_len, d_model)
    Q = torch.matmul(X, W_Q)
    K = torch.matmul(X, W_K)
    V = torch.matmul(X, W_V)
    自注意力计算:
    :param query: Q
    :param key: K
    :param value: V
    :param dropout: drop比率
    :param mask: 是否mask
    :return: 经自注意力机制计算后的值
    """
    # d_k指降维后待查询的词向量维度
    d_k = query.size(-1)  # 防止softmax未来求梯度消失时的d_k
    # Q,K相似度计算公式:\frac{Q^TK}{\sqrt{d_k}},score的维度就是词数*词数(每两个词语间的相似度)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # Q,K相似度计算
    # 判断是否要mask,注:mask的操作在QK之后,softmax之前
    if mask is not None:
        """
        scores.masked_fill默认是按照传入的mask条件中为1的元素所在的索引,
        在scores中相同的的索引处替换为value,替换值为-1e9,即-(10^9)
        """
        mask.cuda()  # 将mask放入GPU运算
        #score此时就是一个torch.tensor对象,可以直接用masked_fill函数
        scores = scores.masked_fill(mask == 0, -1e9)

    self_attn_softmax = F.softmax(scores, dim=-1)  # 进行softmax
    # 判断是否要对相似概率分布进行dropout操作
    if dropout is not None:
        self_attn_softmax = dropout(self_attn_softmax)

    # 注意:返回经自注意力计算后的值,以及进行softmax后的相似度(即相似概率分布)
    # 词数*词数 * 词数*词向量维度 = 全新的 词数*词向量维度。如果是多头的:头数*词数*词向量维度
    return torch.matmul(self_attn_softmax, value), self_attn_softmax


class MultiHeadAttention(nn.Module):  # 继承nn.module类
    """
    多头注意力计算
    """

    def __init__(self, head, d_model, dropout=0.1):
        """
        :param head: 头数
        :param d_model: 词向量的维度,必须是head的整数倍
        :param dropout: drop比率
        """
        super(MultiHeadAttention, self).__init__() # 先初始化父类的属性(子类要用父类的属性和方法)
        assert (d_model % head == 0)  # 确保词向量维度是头数的整数倍
        self.d_k = d_model // head  # 被拆分为多头后的某一头词向量的维度,和自注意力降维后维度是相同的
        self.head = head
        self.d_model = d_model

        """
        由于多头注意力机制是针对多组Q、K、V,因此有了下面这四行代码,具体作用是,
        针对未来每一次输入的Q、K、V,都给予参数进行构建
        其中linear_out是针对多头汇总时给予的参数
        """
        self.linear_query = nn.Linear(d_model, d_model)  # 进行一个普通的全连接层变化,但不修改维度
        self.linear_key = nn.Linear(d_model, d_model)
        self.linear_value = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout)
        self.attn_softmax = None  # attn_softmax是能量分数, 即句子中某一个词与所有词的相关性分数, softmax(QK^T)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            """
            多头注意力机制的线性变换层是4维,是把query[batch, frame_num, d_model]变成[batch, -1, head, d_k]
            再1,2维交换变成[batch, head, -1, d_k], 所以mask要在第二维(head维)添加一维,与后面的self_attention计算维度一样
            具体点将,就是:
            因为mask的作用是未来传入self_attention这个函数的时候,作为masked_fill需要mask哪些信息的依据
            针对多head的数据,Q、K、V的形状维度中,只有head是通过view计算出来的,是多余的,为了保证mask和
            view变换之后的Q、K、V的形状一直,mask就得在head这个维度添加一个维度出来,进而做到对正确信息的mask
            """
            mask = mask.unsqueeze(1)

        n_batch = query.size(0)  # batch_size大小,假设query的维度是:[10, 32, 512],其中10是batch_size的大小

        """
        下列三行代码都在做类似的事情,对Q、K、V三个矩阵做处理
        其中view函数是对Linear层的输出做一个形状的重构,其中-1是自适应(自主计算)
        这里本质用的是对词向量的维度进行了拆分,将不同维度放入不同自注意力模型去训练
        transopose(1,2)是对前形状的两个维度(索引从0开始)做一个交换,这里处理后的quary(batch,head,词数,词维度)
        假设Linear成的输出维度是:[10, 32, 512],其中10是batch_size的大小
        注:这里解释了为什么d_model // head == d_k,如若不是,则view函数做形状重构的时候会出现异常
        """
        query = self.linear_query(query).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)  # [b, 8, 32, 64],head=8
        key = self.linear_key(key).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)   # [b, 8, 32, 64],head=8
        value = self.linear_value(value).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)  # [b, 8, 32, 64],head=8

        # x是通过自注意力机制计算出来的值, self.attn_softmax是相似概率分布
        x, self.attn_softmax = self_attention(query, key, value, dropout=self.dropout, mask=mask)

        """
        首先,交换“head数”和“词数”,这两个维度,结果为(batch, 词数, head数, d_model/head数)
        对应代码为:`x.transpose(1, 2).contiguous()`
        然后将“head数”和“d_model/head数”这两个维度合并,结果为(batch, 词数,d_model)
        contiguous()是重新开辟一块内存后存储x,然后才可以使用.view方法,否则直接使用.view方法会报错
        """
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.head * self.d_k)
        return self.linear_out(x)

代码解释 

代码带有注释细节方面就不在这里解释了,这里重点来看看整体的思想

代码实现的思想和上面说的多头注意力机制存在一点不同。两者核心的思想是相同的,但是具体实现的方式存在区别

两种实现思想对比

第一种思路:上文说QKV是同时将X映射到不同的维度好几份,将X中的词向量维度通过映射在QKV中得到降维目的,降低计算难度。并且让不同份的QKV参数矩阵能够关注X不同的特征,从而其让X的特征值的寻找能够更加细致彻底(如下图)

这里介绍另外一种思路(代码实现采用的,两者核心思路是一样的) :

第二种思路:这里我们不再将X利用参数矩阵直接映射成的QKV,从而实现降维以及多头的效果。而是将词向量维度分为头数*新词向量维度(即 :词向量维度=头数*新词向量维度),此时也就实现了降维、多头两个效果

如此分离之后,有一种更好的理解方式:将词向量特征分为不同的组,将不同的组分给不同的注意力机制模型学习,从而让模型专注学习每个组对应的词向量特征,从而使得模型学习效果更好

(下图将词向量在特征维度上分为四头)

在代码具体实现时,考虑到两者最终的效果差不多,但是上面的这个算法实现起来效率会差很多(参数计算量更大了),所以我们采用的策略会是第二种思路

总结

撰写文章不易,如果文章能帮助到大家,大家可以点点赞、收收藏呀~

十二月的猫在这里祝大家学业有成、事业顺利、情到财来

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/583301.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python | Leetcode Python题解之第55题跳跃游戏

题目&#xff1a; 题解&#xff1a; class Solution:def canJump(self, nums: List[int]) -> bool:n, rightmost len(nums), 0for i in range(n):if i < rightmost:rightmost max(rightmost, i nums[i])if rightmost > n - 1:return Truereturn False

闲话 Asp.Net Core 数据校验(三)EF Core 集成 FluentValidation 校验数据例子

前言 一个在实际应用中 EF Core 集成 FluentValidation 进行数据校验的例子。 Step By Step 步骤 创建一个 Asp.Net Core WebApi 项目 引用以下 Nuget 包 FluentValidation.AspNetCore Microsoft.AspNetCore.Identity.EntityFrameworkCore Microsoft.EntityFrameworkCore.Re…

Unity 合并子物体获得简化Mesh

合并子物体获得简化Mesh &#x1f959;环境&#x1f96a;Demo &#x1f959;环境 PackageManager安装Editor Coroutines 导入插件&#x1f448; &#x1f96a;Demo 生成参数微调&#xff1a;Assets/EasyColliderEditor/Scripts/VHACDSettings/VHACDSettings.asset

TDengine高可用架构之TDengine+Keepalived

之前在《TDengine高可用探讨》提到过&#xff0c;TDengine通过多副本和多节点能够保证数据库集群的高可用。单对于应用端来说&#xff0c;如果使用原生连接方式&#xff08;taosc&#xff09;还好&#xff0c;当一个节点下线&#xff0c;应用不会受到影响&#xff1b;但如果使用…

Kafka 3.x.x 入门到精通(03)——Kafka基础生产消息

Kafka 3.x.x 入门到精通&#xff08;03&#xff09;——对标尚硅谷Kafka教程 2. Kafka基础2.1 集群部署2.2 集群启动2.3 创建主题2.4 生产消息2.4.1 生产消息的基本步骤2.4.2 生产消息的基本代码2.4.3 发送消息2.4.3.1 拦截器2.4.3.1.1 增加拦截器类2.4.3.1.2 配置拦截器 2.4.3…

Mysql事务—隔离级别—脏读、不可重复读、幻读-遥遥领先版

事务的基本概念 事务就是一组原子性的操作&#xff0c;这些操作要么全部发生&#xff0c;要么全部不发生。事务把数据库从一种一致性状态转换成另一种一致性状态。 事务最经典也经常被拿出来说例子就是转账了。 假如小明要给小红转账1000元&#xff0c;这个转账会涉及到两个…

Linux进程——进程的概念(PCB的理解)

前言&#xff1a;在了解完冯诺依曼体系结构和操作系统之后&#xff0c;我们进入了Linux的下一篇章Linux进程&#xff0c;但在学习Linux进程之前&#xff0c;一定要阅读理解上一篇内容&#xff0c;理解“先描述&#xff0c;再组织”才能更好的理解进程的含义。 Linux进程学习基…

【中级软件设计师】上午题12-软件工程(3):项目活动图、软件风险、软件评审、软件项目估算

【中级软件设计师】上午题12-软件工程&#xff08;3&#xff09; 1 软件项目估算1.1 COCOMO估算模型1.2 COCOMOⅡ模型 2 进度管理2.1 gantt甘特图2.2 pert图2.3 项目活动图2.3.1 画项目图 3 软件配置管理4 软件风险4.1 风险管理4.2 风险识别4.3 风险预测4.4 风险评估4.5 风险控…

二叉树遍历递归法迭代法实现

一.递归法实现二叉树遍历 前序遍历 创建一个节点类 属性是val,左节点&#xff0c;右节点 public class TreeNode { int val; TreeNode left; TreeNode right; TreeNode(int x) { val x; } } 前序遍历 class Solution {public List<Integer> preorderTraversa…

微服务启动慢,看我如何消灭这些憨憨怪!

Hello&#xff0c;我是大都督周瑜&#xff0c;最近在公司做微服务启动速度的优化&#xff0c;我们有些微服务启动要花5-6分钟&#xff08;就问你夸不夸张&#xff09;&#xff0c;直接导致打工人们有了更多的划水时间&#xff0c;领导表示不开心&#xff0c;要求我将微服务的启…

python监听html click教程

&#x1f47d;发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【点击进入巨牛的人工智能学习网站】。 Python实现监听HTML点击事件 在Web开发中&#xff0c;经常需要在用户与页面交互时执行一些…

乐观锁悲观锁

视频&#xff1a;什么是乐观锁&#xff1f;什么是悲观锁&#xff1f;_哔哩哔哩_bilibili

如何在电脑桌面上显示每天的待办事项?

对于上班族来说&#xff0c;每天面临的任务繁杂&#xff0c;很容易遗漏或忘记某些重要事项。因此&#xff0c;在电脑桌面上直接显示每天的待办事项显得尤为重要。例如&#xff0c;当你忙于处理邮件或编写报告时&#xff0c;桌面的待办事项提醒能够让你一目了然地掌握接下来的工…

C语言进阶|链表经典OJ题

✈移除链表元素 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 方法一&#xff1a; 遍历链表找到所有等于val的节点&#xff0c;再执行删除操作删除这些节点。 方法二&#xff1a; …

Flask 数据库前后端交互案例-1

Flask 数据库前后端交互案例 目录结构templates目录base.htmlheader.htmlleft.html首页职员管理页面添加员工界面员工编辑页面员工详情界面 后台main.pyapp.pymodels.pyviews.py 数据库数据position.sqlperson.sqlpermission.sqldepartment.sql 目录结构 静态文件链接&#xff…

Linux工具篇 之 vim概念 操作 及基础指令讲解

学校不大 创造神话 讲桌两旁 陨落的王 临时抱佛脚 佛踹我一脚 书山有路勤为径 游戏玩的很起劲 想要计算机学的好&#xff0c;我的博客列表是个宝 –❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀-正文开始-❀–❀–❀–❀–❀–❀–❀–❀…

OceanBase开发者大会实录-杨传辉:携手开发者打造一体化数据库

本文来自2024 OceanBase开发者大会&#xff0c;OceanBase CTO 杨传辉的演讲实录—《携手开发者打造一体化数据库》。完整视频回看&#xff0c;请点击这里&#xff1e;> 各位 OceanBase 的开发者&#xff0c;大家上午好&#xff01;今天非常高兴能够在上海与大家再次相聚&…

Springboot+Vue项目-基于Java+MySQL的校园外卖服务系统(附源码+演示视频+LW)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;Java毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计 &…

自动驾驶中的深度学习和计算机视觉

书籍&#xff1a;Applied Deep Learning and Computer Vision for Self-Driving Cars: Build autonomous vehicles using deep neural networks and behavior-cloning techniques 作者&#xff1a;Sumit Ranjan&#xff0c;Dr. S. Senthamilarasu 出版&#xff1a;Packt 书籍…

【GitHub】如何在github上提交PR(Pull Request) + 多个pr同时提交、互不干扰

【GitHub】如何在github上提交PR(Pull Request 写在最前面1. 准备工作1.1 注册 GitHub 账号1.2 了解 Git 基础1.3 找到一个项目 2. 创建你的 PR2.1 Fork 和克隆仓库2.2 创建一个新的分支2.3 进行更改2.4 推送更改到 GitHub2.5 创建 Pull Request 3. 优化你的 PR3.1 保持提交清晰…
最新文章