首页> 中国专利> 一种跨领域层次关系的知识蒸馏方法和系统

一种跨领域层次关系的知识蒸馏方法和系统

摘要

本发明公开了一种跨领域层次关系的知识蒸馏方法和系统,为各个领域构建了一系列参考原型特征,建立多个领域关系图网络来充分地学习不同领域间的关系,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度,这样,不同领域的关系便可以同时被捕捉生成一系列领域关系系数对各个领域在知识蒸馏过程中的权重进行重分配,引导模型动态地关注更重要的领域信息,可以更方便高效地处理多领域环境下模型压缩,大幅度地提升模型的性能,解决了现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。

著录项

  • 公开/公告号CN113849641A

    专利类型发明专利

  • 公开/公告日2021-12-28

    原文格式PDF

  • 申请/专利权人 中山大学;

    申请/专利号CN202111131585.7

  • 发明设计人 董晨鹤;沈颖;

    申请日2021-09-26

  • 分类号G06F16/35(20190101);G06K9/62(20060101);

  • 代理机构11227 北京集佳知识产权代理有限公司;

  • 代理人刘思言

  • 地址 510275 广东省广州市新港西路135号

  • 入库时间 2023-06-19 13:26:15

说明书

技术领域

本发明涉及知识蒸馏技术领域,尤其涉及一种跨领域层次关系的知识蒸馏方法和系统。

背景技术

语言是人类重要的沟通和表达方式。随着互联网的发展,当今时代的信息量不断增加,信息的增长速度已经远远超越了人类的理解速度。如果计算机能够处理、理解语言,不仅为海量信息的处理提供了可能,也有助于深化对语言能力和人类智能的认识。

自然语言处理,旨在将人的语言形式转化为机器可理解的、结构化的、完整的语义表示,目的是让计算机能够理解和生成人类语言。大型预训练语言模型在大量的自然语言处理任务上取得了显著的效果,例如机器翻译、文本摘要、对话生成等任务。然而,过大的模型尺寸和过低的推理时间阻碍了其实际应用的脚步,很难在资源受限的设备上进行部署。因而,涌现了许多针对预训练语言模型的压缩技术,例如量化、权重裁剪、知识蒸馏等技术。由于知识蒸馏即插即用的特性,其在实际中得到了广泛的应用。

知识蒸馏的目的在于将知识从更大尺寸的教师模型迁移到更小尺寸的学生模型中。传统的知识蒸馏局限于单领域知识蒸馏,然而,对于人类而言,常常迁移不同领域的相关知识,例如学过钢琴的人学小提琴比别人学得快,会骑自行车的人更容易学会骑摩托车。来自不同领域的文本数据在文本、句式术语上有显著的差异,但是自然语言又具备跨领域的共性知识,如词汇、句法等,这为跨领域的知识迁移提供了可能。因而,现有的知识蒸馏技术已经从传统的单领域知识蒸馏扩展到了跨领域知识蒸馏。然而,现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力。

发明内容

本发明实施例提供了一种跨领域层次关系的知识蒸馏方法和系统,用于解决现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。

有鉴于此,本发明第一方面提供了一种跨领域层次关系的知识蒸馏方法方法,所述方法包括:

获取不同领域的训练样本;

对各领域的训练样本分别计算学生层的原型特征;

对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络;

将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数;

将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数;

根据蒸馏损失函数对学生模型进行迭代训练。

可选地,在第一层领域关系图网络中,每个节点上应用一个共享参数矩阵和注意力机制,并将节点的输出送入ELU非线性函数和多头拼接机制,在第二层领域关系图网络中,去除多头拼接机制,使用softmax对输出归一化得到领域关系系数。

可选地,学生层的原型特征计算公式为:

其中,h

可选地,还包括:

基于自注意力机制建立每个领域的参考原型特征;

将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;

将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。

可选地,参考原型特征为:

其中,

可选地,聚合原型特征为:

其中,

可选地,蒸馏损失函数为:

其中,r

本发明第二方面提供一种跨领域层次关系的知识蒸馏系统,所述系统包括:

训练样本获取模块,用于获取不同领域的训练样本;

原型特征生成模块,用于对各领域的训练样本分别计算学生层的原型特征;

领域关系网络构建模块,用于对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络;

领域关系系数获取模块,用于将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数;

蒸馏损失函数生成模块,用于将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数;

模型训练模块,用于根据蒸馏损失函数对学生模型进行迭代训练。

可选地,领域关系网络构建模块建立的两层领域关系图网络的网络结构包括:

在第一层领域关系图网络中,每个节点上应用一个共享参数矩阵和注意力机制,并将节点的输出送入ELU非线性函数和多头拼接机制,在第二层领域关系图网络中,去除多头拼接机制,使用softmax对输出归一化得到领域关系系数。

可选地,还包括:

参考原型特征生成模块,用于基于自注意力机制建立每个领域的参考原型特征;

对比聚合模块,用于将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;

领域关系系数更新模块,用于将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。

从以上技术方案可以看出,本发明实施例具有以下优点:

由于不同领域的层原型特征会有不同的偏好,因此本发明实施例中提供的跨领域层次关系的知识蒸馏方法中为各个领域构建了一系列参考原型特征,建立多个领域关系图网络来充分地学习不同领域间的关系,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度,这样,不同领域的关系便可以同时被捕捉生成一系列领域关系系数对各个领域在知识蒸馏过程中的权重进行重分配,引导模型动态地关注更重要的领域信息,可以更方便高效地处理多领域环境下模型压缩,大幅度地提升模型的性能,解决了现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。

同时,本发明中还引入了一个层次化对比-聚合机制挖掘出各个领域更具有代表性的层原型特征,进一步提升压缩语言模型的表达能力。

附图说明

图1为本发明实施例中提供的跨领域层次关系的知识蒸馏方法的一个流程示意图;

图2为本发明实施例中提供的跨领域层次关系的知识蒸馏方法的一个模型结构原理图;

图3为本发明实施例中提供的跨领域层次关系的知识蒸馏方法的令一个流程示意图;

图4为本发明实施例中提供的跨领域层次关系的知识蒸馏方法的另一个模型结构原理图;

图5为本发明实施例中提供的跨领域层次关系的知识蒸馏系统的结构示意图。

具体实施方式

为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。

基础的跨领域知识蒸馏方法中,对教师模型和学生模型的嵌入层、注意力矩阵、前馈网络层、预测概率分布进行联合蒸馏,并且运用多任务学习的训练策略进行跨领域知识蒸馏。具体地,将所有领域的嵌入层和前馈网络层的权重进行共享,但为不同的预测层赋予不同的权重。第d个领域的嵌入层损失

其中,MSE和CE分别代表均方差损失和交叉熵损失,E

第d个领域中的第m个学生层的注意力层损失

其中,h为注意力头数,

采用均匀策略对学生模型和教师模型间的模型层进行匹配。最后,总体的蒸馏损失函数可以表示为:

其中,D为总领域数,M是学生模型中的层数,γ为用来控制预测损失

上述基础跨领域知识蒸馏方法虽然可以将不同领域的学生模型进行蒸馏,但不同领域间的关系信息被忽略了,也就降低了模型的泛化能力。为了提高模型捕捉不同领域之间的关系信息方面能力,增强模型的泛化能力,本发明提供了一种新的跨领域层次关系的知识蒸馏方法。

为了便于理解,请参阅图1和图2,图1为本发明实施例中跨领域层次关系的知识蒸馏方法的一个流程示意图,如图1所示,本发明实施例中跨领域层次关系的知识蒸馏方法包括:

步骤101、获取不同领域的训练样本。

步骤102、对各领域的训练样本分别计算学生层的原型特征。

本发明中,对所有学生模型层都执行知识蒸馏。使用原型特征来反映各个领域数据的特点,对不同的学生层计算不同的原型特征。在实际中,为不同批量的训练样本计算不同的原型特征。在d领域的第m个学生层的原型特征h

其中,h

步骤103、对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络。

计算的这些领域原型特征被用来挖掘不同领域间的关系,为一次性地同时找到跨领域关系,本发明中用基于图注意力网络的领域关系网络同时处理所有领域的原型特征。

步骤104、将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数。

在图注意力网络中,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度。这样,不同领域的关系便可以同时被捕捉。如图2所示,除了预测层外,为每个学生层建立一个两层领域关系图网络,第m层图网络的输入h

在第m个学生层的第一层领域关系网络中,一个共享参数矩阵

其中,

之后,节点i最终的输出

其中,k表示头序号。

在第m个学生层的第二层领域关系图网络中,为了得到领域关系系数,将第一层图网络中用到的参数W

步骤105、将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数。

根据每个学生层生成的领域关系系数,对每个领域进行权重重分配,因而可以得到各层的损失,从而确定总体损失,得到蒸馏损失函数。

步骤106、根据蒸馏损失函数对学生模型进行迭代训练。

根据蒸馏损失函数损失函数对学生模型进行训练,即,根据蒸馏损失函数更新学生模型的参数。

由于不同领域的层原型特征会有不同的偏好,因此本发明实施例中提供的跨领域层次关系的知识蒸馏方法中为各个领域构建了一系列参考原型特征,建立多个领域关系图网络来充分地学习不同领域间的关系,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度,这样,不同领域的关系便可以同时被捕捉生成一系列领域关系系数对各个领域在知识蒸馏过程中的权重进行重分配,引导模型动态地关注更重要的领域信息,可以更方便高效地处理多领域环境下模型压缩,大幅度地提升模型的性能,解决了现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。

同时,为进一步提升压缩语言模型的表达能力,本发明中还引入了一个层次化对比-聚合机制挖掘出各个领域更具有代表性的层原型特征,为每个领域建立了一系列参考原型特征,并根据与对应参考原型特征的相似度层次化地聚合当前层和其之前层的原型特征,从而得到各个领域更具有代表性的聚合原型特征。

请参阅图3-图4,在一个实施例中,还包括:

步骤107、基于自注意力机制建立每个领域的参考原型特征;

步骤108、将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;

步骤109、将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。

对于每个学生层,当前层和其之前层原型特征的参考原型特征可以简单地设定为当前层的原始领域原型特征。然而,这种方法没有考虑到其他领域的信息,而该信息对于提高模型在不同领域的泛化性能起着很重要的作用。为此,本发明实施例中为同一层的所有领域原型特征引入了一个自注意力机制以加入不同领域的信息。具体地,第m个学生层发参考原型特征

其中,

在得到参考原型特征之后,使用一个对比-聚合机制来动态地聚合层原型特征,该过程通过将其与对应的参考原型特征进行对比完成,可以使得模型注意到每个领域中更具代表性的层原型特征。具体地,第m层第d个领域的聚合原型特征

其中,

聚合后的原型特征

最后,得到的总体损失(即最终确定的蒸馏损失函数)可以表示为:

其中,r

为了便于理解,请参阅图5,本发明中提供了一种跨领域层次关系的知识蒸馏系统的实施例,包括:

训练样本获取模块501,用于获取不同领域的训练样本;

原型特征生成模块502,用于对各领域的训练样本分别计算学生层的原型特征;

领域关系网络构建模块503,用于对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络;

领域关系系数获取模块504,用于将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数;

蒸馏损失函数生成模块505,用于将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数;

模型训练模块506,用于根据蒸馏损失函数对学生模型进行迭代训练。

领域关系网络构建模块503建立的两层领域关系图网络的网络结构包括:

在第一层领域关系图网络中,每个节点上应用一个共享参数矩阵和注意力机制,并将节点的输出送入ELU非线性函数和多头拼接机制,在第二层领域关系图网络中,去除多头拼接机制,使用softmax对输出归一化得到领域关系系数。

还包括:

参考原型特征生成模块507,用于基于自注意力机制建立每个领域的参考原型特征;

对比聚合模块508,用于将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;

领域关系系数更新模块509,用于将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。

由于不同领域的层原型特征会有不同的偏好,因此本发明实施例中提供的跨领域层次关系的知识蒸馏系统中为各个领域构建了一系列参考原型特征,建立多个领域关系图网络来充分地学习不同领域间的关系,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度,这样,不同领域的关系便可以同时被捕捉生成一系列领域关系系数对各个领域在知识蒸馏过程中的权重进行重分配,引导模型动态地关注更重要的领域信息,可以更方便高效地处理多领域环境下模型压缩,大幅度地提升模型的性能,解决了现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。

同时,本发明中还引入了一个层次化对比-聚合机制挖掘出各个领域更具有代表性的层原型特征,进一步提升压缩语言模型的表达能力。

本发明实施例中提供的跨领域层次关系的知识蒸馏系统用于执行前述跨领域层次关系的知识蒸馏方法实施例中的跨领域层次关系的知识蒸馏方法,可取得与前述跨领域层次关系的知识蒸馏方法实施例相同的技术效果,在此不再进行赘述。

以上所述,以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号