首页> 中国专利> 元协作训练范式

元协作训练范式

摘要

生成对抗模型有很多益处;然而,由于模式瓦解,这些生成器面临质量‑多样性的折衷(即,生成器模型为了提高生成质量而牺牲了生成多样性)。本文提出的是通过减速模式瓦解来提高对抗性内容生成的性能的实施例。在一个或多个实施例中,采用协作训练范式,其中第二模型与生成器协作训练,并且帮助有效地塑造生成器的数据分布以防止模式瓦解。此外,可以使用元学习机制的实施例,其中对生成器的协作更新用作高级元任务,并且其有助于确保在对抗性更新后生成器参数保持对模式瓦解的抵抗力。在实验中,经测试的使用证明了对抗文本生成器的模式瓦解的有效减慢。总体而言,实施例在生成质量和多样性两者方面以明显的优势胜过基准方法。

著录项

  • 公开/公告号CN113222105A

    专利类型发明专利

  • 公开/公告日2021-08-06

    原文格式PDF

  • 申请/专利权人 百度(美国)有限责任公司;

    申请/专利号CN202110162379.6

  • 发明设计人 李定成;尹海燕;李旭;李平;

    申请日2021-02-05

  • 分类号G06N3/04(20060101);G06N3/08(20060101);

  • 代理机构11204 北京英赛嘉华知识产权代理有限责任公司;

  • 代理人王达佐;王艳春

  • 地址 美国加利福尼亚州

  • 入库时间 2023-06-19 12:07:15

说明书

技术领域

本公开总体上涉及用于计算机学习的系统和方法,该系统和方法可以提供改进的计算机性能、特征和用途。更具体地,本公开涉及用于生成模型的对抗训练的系统和方法。

背景技术

神经网络在许多领域都取得了巨大的成功,诸如计算机视觉、自然语言处理、推荐系统等。神经网络模型的一种类型是生成模型,该生成模型用于生成内容,诸如文本和图像。训练生成模型以从训练集中学习真实的数据分布,并且能够在训练完成时生成新的数据点。近年来,它们已成功应用于广泛的应用,包括图像生成、风格化、半监督分类和自然语言生成。应用的一个领域是文本生成的新兴任务,通常将其建模为顺序的离散数据生成过程。此类任务在许多现实世界应用中扮演着关键角色,诸如机器翻译、文本摘要和对话系统。

顺序文本生成模型的训练在很大程度上依赖于在自动回归模型上应用强制教学(teacher forcing),即,以最大似然估计(MLE)进行优化。然而,用强制教学训练生成模型将遭受曝光偏差(exposure bias),即,模型在推理时间被馈送到其预测数据而不是地面真实数据,并且因此由于积累的误差而导致生成不良样本。为了解决曝光偏差问题,针对文本生成的正在进行的主要研究集中在利用对抗训练技术来推导更好的文本生成模型。通常,这种尝试可以分为以下两个方面:第一种方法将生成对抗网络(GAN)与强化学习(RL)进行组合,表示为基于RL;第二种方法仅玩双人对抗式游戏,而无需使用RL,表示为无RL。

基于RL和无RL的文本生成方法两者都遭受模式瓦解(mode collapse),这对于训练基于GAN的模型是众所周知的挑战。也就是说,随着对抗训练的进行,所生成的分布倾向于与生成用于数据的模式子集形成对比。因此,生成器输出重复的语句,并且因此不再表达性地表示数据生成分布。在最近的研究中,已经对这种效果进行了定量评估,结果表明,当从MLE训练移动到对抗训练阶段时,生成器的输出分布的熵将经历明显的下降。为了使用基于GAN的技术推导更好的文本生成模型,一项关键任务是通过有效地减慢对抗性生成器的模式瓦解来实现更好的质量-多样性折衷,即,让生成器从对抗性更新中获取丰富的梯度信息以使其输出更真实(即提高质量),同时容忍较小的模式瓦解效果(即降低多样性)。然而,有限数量的现有基于RL或无RL的方法明确考虑处理GAN训练的模式瓦解。

因此,需要明确地解决对抗训练的模式瓦解的挑战的方法,从而产生改进的文本生成模型。

发明内容

在第一方面,本公开提供了一种用于训练生成器的计算机实现的方法,其包括:

响应于尚未达到停止条件,执行步骤,所述步骤包括:

从训练数据中采样一组数据点;

使用包括一组生成器参数值的生成器模型来生成一组生成的数据点;

使用对抗训练损失函数来计算所述生成器模型的对抗损失;

使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;

使用从所述训练数据中采样的所述一组数据点作为到包括第二神经网络模型组的参数值的第二神经网络模型的输入以及到包括所述一组中间生成器参数值的所述生成器模型的输入,计算所述生成器模型的协作训练损失;

使用所述协作训练损失来确定元梯度;

使用对抗梯度来更新所述一组生成器参数值,所述对抗梯度是使用所述生成器模型的所述对抗损失和所述元梯度获得的;

使用鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及

使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值;以及

响应于已达到所述停止条件,输出所述生成器模型,所述生成器模型包括生成器参数值的最终更新的集合。

在第二方面,本公开提供了一种系统,其包括:

一个或多个处理器;以及

非暂时性计算机可读介质或媒介,包括一个或多个指令集,所述指令集在由所述一个或多个处理器中的至少一者执行时致使执行包括以下各项的步骤:

响应于尚未达到停止条件,执行步骤,所述步骤包括:

从具有第一分布的训练数据中采样一组数据点;

使用包括一组生成器参数值的生成器模型来生成一组生成的数据点;

使用对抗训练损失函数来计算所述生成器模型的对抗损失;

使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;

使用从所述训练数据中采样的所述一组数据点作为到包括第二神经网络模型组的参数值的第二神经网络模型的输入以及到包括所述一组中间生成器参数值的所述生成器模型的输入,计算所述生成器模型的协作训练损失;

使用所述生成器模型的所述协作训练损失来确定元梯度;

使用对抗梯度来更新所述一组生成器参数值,所述对抗梯度是使用所述生成器模型的所述对抗损失和所述元梯度获得的;

使用鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及

使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值;以及

响应于已达到所述停止条件,输出所述生成器模型,所述生成器模型包括生成器参数值的最终更新的集合。

在第三方面,本公开提供了一种用于训练生成器的计算机实现的方法,其包括:

响应于尚未达到停止条件,执行步骤,所述步骤包括:

使用来自真实数据的训练数据集的一组数据点和来自生成对抗系统的生成器模型来生成一组生成的数据点,所述生成对抗系统包括具有一组生成器模型参数值的所述生成器模型以及具有一组鉴别器参数值的鉴别器模型;

使用对抗训练损失函数来计算所述生成器模型的对抗损失;

使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;

使用具有所述一组中间生成器参数值的所述生成器模型和第二神经网络模型来协同训练所述生成器模型,以减速所述生成器模型的模式瓦解;

使用所述鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及

使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的一组参数值;以及

响应于已达到所述停止条件,输出所述生成器模型。

在第四方面,本公开提供了一种系统,其包括:

一个或多个处理器;以及

非暂时性计算机可读介质或媒介,包括一个或多个指令集,所述指令集在由所述一个或多个处理器中的至少一者执行时致使执行包括以下各项的步骤:

响应于尚未达到停止条件,执行步骤,所述步骤包括:

使用来自真实数据的训练数据集的一组数据点和来自生成对抗系统的生成器模型来生成一组生成的数据点,所述生成对抗系统包括具有一组生成器模型参数值的所述生成器模型以及具有一组鉴别器参数值的鉴别器模型;

使用对抗训练损失函数来计算所述生成器模型的对抗损失;

使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;

使用具有所述一组中间生成器参数值的所述生成器模型和第二神经网络模型来协同训练所述生成器模型,以减速所述生成器模型的模式瓦解;

使用所述鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及

使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的一组参数值;以及

响应于已达到所述停止条件,输出所述生成器模型。

在第四方面,本公开提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据第一方面所述的方法。

在第五方面,本公开提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据第三方面所述的方法。

附图说明

将参考本公开的实施例,它们的示例可示于附图中。这些附图旨在是说明性的而非限制性的。虽然本公开大体上在这些实施例的上下文中描述,但应理解,本公开的范围并不旨在限于这些特定实施例。附图中的项目可能未按比例绘制。

附图(“图”)1描绘了根据本公开的实施例的协作训练过程的高级概述。

图2描绘了根据本公开的实施例的示例性生成系统。

图3描绘了根据本公开的实施例的示例性鉴别器系统。

图4描绘了根据本公开的实施例的GAN系统和Meta-CoTGAN数据流方法的概述。

图5描绘了根据本公开的实施例的用于训练生成器模型的Meta-CoTGAN方法。

图6描绘了根据本公开的实施例的用于使用已经使用Meta-CoTGAN方法训练的生成器模型的方法。

图7描绘了根据本公开的实施例的在关于NLL

图8包含表2,其根据本公开的实施例呈现了在数据集上的评估结果。结果经6次运行取平均(随机种子),并且对于NLL

图9描绘了根据本公开的实施例的RelGAN和Meta-CoTGAN实施例的NLL

图10包含表3,其根据本公开的实施例呈现了在数据集2上的评估结果。结果经6次运行取平均,并且对于NLL

图11包含表4,其根据本公开的实施例呈现了在数据集1上的消融研究结果。当协作训练部分和元优化分别关闭时,评估包括Meta-CoTGAN实施例。报告的评分源自6个随机种子。

图12描绘根据本公开的实施例的计算设备/信息处理系统的简化框图。

具体实施方式

在以下描述中,出于解释目的,阐明具体细节以便提供对本公开的理解。然而,将对本领域的技术人员显而易见的是,可在没有这些细节的情况下实践本公开。此外,本领域的技术人员将认识到,下文描述的本公开的实施例可以以各种方式(例如过程、装置、系统、设备或方法)在有形的计算机可读介质上实施。

附图中示出的组件或模块是本公开实施例的示例性说明,并且意图避免使本公开不清楚。还应理解,在本论述的全文中,组件可描述为单独的功能单元(可包括子单元),但是本领域的技术人员将认识到,各种组件或其部分可划分成单独组件,或者可整合在一起(包括例如位于单个的系统或组件内)。应注意,本文论述的功能或操作可实施为组件。组件可以以软件、硬件、或它们的组合实施。

此外,附图内的组件或系统之间的连接并不旨在限于直接连接。相反,在这些组件之间的数据可由中间组件修改、重格式化、或以其它方式改变。另外,可使用另外或更少的连接。还应注意,术语“联接”、“连接”、“通信地联接”、“接合”、“接口”或其派生词中的任一个,应理解为包括直接连接、通过一个或多个中间设备来进行的间接连接、和无线连接。还应注意,任何通信(诸如信号、响应、答复、确认、消息、查询等)可包括一个或多个信息交换。

在本说明书中对“一个或多个实施例”、“优选实施例”、“实施例”、“多个实施例”等的提及表示结合实施例所描述的具体特征、结构、特性或功能包括在本公开的至少一个实施例中,以及可包括在多于一个的实施例中。另外,在本说明书的各个地方出现以上所提到的短语并不一定全都是指相同的实施例或多个相同实施例。

在本说明书的各个地方使用某些术语目的在于说明,并且不应被理解为限制。服务、功能或资源并不限于单个服务、单个功能或单个资源;这些术语的使用可指代相关服务、功能或资源的可分布或聚合的分组。术语“包括”、“包括有”、“包含”和“包含有”应理解为开放性的术语,并且其后任何列出内容都是实例,而不旨在限于所列项目。“层”可包括一个或多个操作。词“最佳”、“优化”、“最优化”等是指对结果或过程的改进,并非要求指定的结果或过程已达到“最佳”或峰值状态。存储器、数据库、信息库、数据存储、表、硬件、高速缓存等在本文中的使用,可用来指代可输入信息或以其它方式记录信息的一个或多个系统组件。

在一个或多个实施例中,停止条件可包括:(1)已执行了设定次数的迭代;(2)已达到一定量的处理时间;(3)收敛(例如,连续迭代之间的差小于第一阈值);(4)发散(例如,性能劣化);(5)已达到可接受的结果。

本领域技术人员应当认识到:(1)可以可选地执行某些步骤;(2)步骤可以不限于本文所述的特定顺序;(3)某些步骤可以以不同的顺序执行;并且(4)某些步骤可以同时进行。

本文所使用的任何标题仅是为了组织目的,并且不应被用于限制说明书或权利要求的范围。本专利文献中提到的每个参考文献/文件以其整体通过引用并入本文。

应注意,本文提供的任何实验和结果均以说明性的方式提供,并且是在特定条件下使用特定实施例进行的;因此,这些实验及其结果均不得用于限制当前专利文件的公开范围。

还应注意,尽管本文描述的实施例可能在文本生成的情景内,但是本公开的各方面不限于此。因此,本公开的各方面可应用或适用于其它情景并用于生成其它内容。

A.概述

可以生成具有足够多样性的高质量文本的训练生成模型是自然语言生成(NLG)社区的重要开放问题。最近,生成性对抗模型已广泛应用于文本生成任务,其中经对抗性训练的生成器减轻了常规最大似然方法所经历的曝光偏差,并且产生了有前途的生成质量。然而,由于对抗训练的模式瓦解的臭名昭著的缺陷,经对抗训练的生成器面临质量-多样性的折衷,即,生成器模型倾向于为了提高生成质量而严重牺牲生成多样性。

本文提出的是新颖方法的实施例,其经由有效地减速对抗训练的模式瓦解来提高对抗内容生成的性能。为此,提出了协作训练范式的实施例,其中与生成器协作地训练语言模型,并且在一个或多个实施例中,语言模型被用来有效地塑造生成器的数据分布以防止模式瓦解。此外,在一个或多个实施例中,代替原则上进行针对生成器的协作更新,制定了元学习机制,其中对生成器的协作更新充当高级元任务,凭直觉确保对抗性更新后的生成器的参数将保持一致以防止模式瓦解。在实验中,证明了实施例可以有效地减慢对抗文本生成器的模式瓦解的速度。总体而言,实施例能够在验证域中的生成质量和多样性两者方面以明显的优势胜过基准方法。

除了具有强制教学的训练语言模型的常规方法之外,当前用于文本生成的方法通常可以被分类为基于RL的方法或无RL的方法。大多数基于RL的方法将文本生成公式化为马尔可夫决策过程(MDP)。通常,生成器会通过策略梯度算法或其变型使用从GAN的鉴别器得出的奖励信号进行更新。此类方法的突出示例包括SeqGAN、RankGAN、LeakGAN和MaskGAN。从鉴别器模型得出的有噪声的奖励信号倾向于使这种基于RL的模型遭受高方差梯度,以更新生成器的参数。除了梯度的高方差之外,基于RL的方法还面临由部分序列评估、学习缓慢和敏感超参数带来的困难。考虑到针对基于RL的方法的此类挑战,可以认为实施例属于但不限于无RL方法的类别。无RL方法的突出示例包括TextGAN、FM-GAN、GSGAN和Rel-GAN。此类方法为生成器提供低方差梯度,并且经常导致更稳定的训练。

大多数对抗文本生成模型首先通过MLE进行预训练,然后在基于RL或无RL的机制下通过对抗训练不断进行优化。当从MLE训练切换到对抗训练阶段时,基于RL和无RL方法两者的生成器模型将遭受模式瓦解问题。本文的一个或多个实施例的核心直觉是利用协作训练的语言模型来减速对抗训练的模式瓦解。尽管利用语言模型以促进对抗性文本生成的类似直觉与其他著作吻合,但是还是存在明显差异。在J.Xu、X.Ren、J.Lin和X.Sun的“DP-GAN:用于生成信息性文本和多样化文本的促进多样性的生成性对抗网络(DP-GAN:Diversity-Promoting Generative Adversarial Network for Generating Informative andDiversified Text)”(获自arXiv预印本arXiv:1802.01345(2018))中,用于对抗训练的鉴别器被建模为语言模型,该模型使真实数据的概率最大化,并且使生成数据的概率最小化。此外,在基于RL的设置下,将从语言模型得出的输出用作奖励信号,以促进生成多样性。由Sidi Lu、Lantao Yu、Siyuan Feng、Yaoming Zhu和Weinan Zhang在第36届机器学习国际会议论文集的PMLR97:4164-4172(2019)(以下称为“Lu等人2019”)中的“CoT:用于离散数据的生成建模的协作训练(CoT:Cooperative training for generative modeling ofdiscrete data)”中,其中在线训练语言模型,以提供目标分布,以便使实际数据分布与生成的分布之间的詹森-香农(Jensen-Shannon)散度最小化。相比之下,一个或多个实施例可以被认为采用相似的策略来训练语言模型,但是针对生成器模型的协作训练在其他差异中大为不同。例如,实施例包括不同的元学习设置,以优化生成器的协作训练损失。

总体而言,该专利文件中提出了至少三个贡献。首先,提出了新颖的协作训练方法的实施例,其中使用语言模型来有效地塑造对抗文本生成器的输出分布。该方法的实施例有效地减慢了对抗文本生成器的模式瓦解,并且因此导致文本生成朝向更好的质量-多样性折衷。其次,为了优化生成器的协作训练损失,本文提出了新颖的元学习机制的实施例。在一个或多个实施例中,协作训练任务用作元任务,并且对抗训练用作基本任务。因此,实施例确保在对抗性更新之后的生成器参数对模式瓦解有抵抗力。第三,在合成和真实世界的数据集上进行的大量实验表明,实施例能够在质量和多样性方面产生更好的文本生成模型。

B.序言

文本生成的任务通常被建模为顺序的离散数据生成过程。让

其中每个序列x的概率以自回归方式表示:

其中y<i表示先前令牌的序列y

利用GAN进行文本生成的方法尝试在生成器G

其中,生成器G

在自动回归生成过程中,第i个令牌y

相对于生成器的参数θ变成不可微分,因为

其中,

其中,

C.方法实施例

用对抗训练机制(基于RL和无RL的方法两者)训练的语言生成器在从强制教学切换到对抗训练阶段时会遭受模式瓦解。在这一部分中,新颖的元协作训练方法的实施例用于克服此类挑战。总体而言,目标是经由降低其对抗训练的模式瓦解来为语言生成器实现更好的质量-多样性折衷。即,该方法的实施例允许生成器从对抗训练中获得丰富的梯度信息以便提高生成质量,同时在生成多样性方面牺牲很少。总体而言,在一个或多个实施例中,采用语言模型来减速生成器的输出分布的模式瓦解。在一个或多个实施例中,在对抗训练期间,语言模型与生成器G

1.协作训练制定实施例

在本节中给出了协作训练范式的实施例,该协作训练范式参与了对抗生成器G

图1描绘了根据本公开的实施例的协作训练过程的高级概述。用对抗训练训练的生成器G

在协作训练过程中,可以通过MLE损失来一致地优化语言模型。为了为生成器提供平稳变化的目标分布,在一个或多个实施例中,语言模式用来自混合分布的数据与来自真实数据和生成的数据的平衡样本进行训练,例如

用来自真实数据的样本一致地更新语言模型M

其中,y

因此,可以认为在生成器上应用协作训练的效果等同于以加权方式增加真实数据的密度。

2.元协作优化实施例

在这一部分中,提出了用于对生成器模型参数的对抗训练损失

为此,在一个或多个实施例中,将优化对抗损失

正式地,在一个或多个实施例中,可以首先通过优化基本任务损失来对生成器参数θ进行一个梯度更新:

然后,在一个或多个实施例中,从真实数据分布中获得新样本:x~p

在以下方法1中给出了用于元协作训练的示例性完整方法的实施例。

方法1—元协作训练实施例

图2描绘了根据本公开的实施例的具有相关存储器的示例性生成系统。在并入新的观察x

图3描绘了根据本公开的实施例的示例性鉴别器系统。在一个或多个实施例中,鉴别器300包括嵌入层、一个或多个卷积层、自我关注层、一个或多个卷积层、线性层和分对数输出。

图4描绘了根据本公开的实施例的GAN系统的概述,并且图5描绘了根据本公开的实施例的用于训练生成器模型的Meta-CoTGAN方法。在一个或多个实施例中,用于训练生成器的计算机实现的方法可以包括以下步骤。可以对来自训练数据405的一组数据点410进行采样(505),并且使用包括一组生成器参数值的生成器模型415,可以生成(510)一组生成的数据点(例如,伪数据点)。使用鉴别器420,该鉴别器接收真实数据点和伪数据点并且试图在两者之间进行区分,可以使用对抗训练损失函数445来计算生成器模型的对抗损失。鉴别器模型的对抗损失和生成器模型的对抗损失可以通过使用最小-最大损失函数来获得。

在一个或多个实施例中,然后可以使用(515)对抗损失和梯度下降来确定生成器模型的一组中间生成器参数值。

在一个或多个实施例中,将从训练数据中采样的一组数据点用作以下各项的输入:(1)第二神经网络模型(例如语言模型425),其包括第二神经网络模型组的参数值;以及(2)使用一组中间生成器参数值的生成器模型415,计算(520)生成器模型的协作训练损失。在一个或多个实施例中,该协作训练损失可以然后用于确定(525)元梯度。

在一个或多个实施例中,使用对抗梯度来更新(530)一组生成器参数值,该对抗梯度是使用生成器模型的对抗损失和元梯度获得的。还可以使用第二神经网络模型的协作训练损失来更新(540)第二神经网络模型的第二神经网络模型组的参数值;并且可以使用鉴别器模型的对抗损失来更新(535)鉴别器模型的一组鉴别器参数值。

在一个或多个实施例中,该处理可以重复直到达到(545)停止条件为止;否则,如果已经达到停止条件,则输出具有其生成器参数值的最终更新集合的生成器模型(550),并且可以将其用于生成。接下来参考图6(如下)讨论经训练的生成器的示例性部署。

在一个或多个实施例中,图5的过程还可以包括初始化步骤。例如,可以初始化至少生成器模型的生成器参数值的集合和鉴别器模型的鉴别器参数值的集合,并且可以使用训练数据、生成器模型和鉴别器模型对生成器模型进行预训练。在一个或多个实施例中,可以使用最小-最大对抗训练来完成预训练。

在一个或多个实施例中,如前所述,第二神经网络模型和生成器模型可以共用相同的神经网络结构。因此,在一个或多个实施例中,来自预训练生成器模型的一组生成器参数值中的至少一些可以用作第二神经网络模型的参数值。还应当注意的是,第二神经网络模型首先用不同的值进行初始化。例如,可以首先使用随机值来初始化所有模型。

在一个或多个实施例中,使用协作训练损失来更新第二神经网络模型的第二神经网络模型组的参数值的步骤可以包括使用最大似然估计(MLE)损失函数。换句话讲,使用协作训练损失来更新第二神经网络模型的第二神经网络模型组的参数值的步骤包括最小化使用从训练数据采样的一组数据点的第二神经网络模型与使用从训练数据采样的数据点和从由生成器模型生成的数据点采样的数据点的混合的第二神经网络模型之间的Kullback-Leibler散度。在一个或多个实施例中,该混合可以是来自训练数据的相等数量或近似相等数量的数据点以及由生成器模型生成的数据点。

图6描绘了根据本公开的实施例的用于使用已经使用Meta-CoTGAN方法训练的生成器模型的方法。给定已经使用Meta-CoTGAN方法实施例训练的生成器模型,可以部署(605)生成器模型以便生成内容。因此,已训练和部署的Meta-CoTGAN生成器模型可以用于(610)生成输出。

D.实验结果

为了方便起见,通常可以将

应当注意,这些实验和结果仅通过说明的方式提供并且使用一个或多个具体实施例在具体条件下执行;因此,这些实验和它们的结果都不应被用来限制本专利文献的公开的范围。

1.实施细节

实施例是在RelGAN(由Weili Nie、Nina Narodytska和Ankit Patel在2019年的国际学习表征会议(ICLR)中的“RelGAN:用于文本生成的关系生成对抗网络(RelGAN:Relational Generative Adversarial Networks For Text Generation)”中提出,该文献通过援引以其全部内容并入本文)、是最先进的方法之一的无RL的对抗文本生成模型之上实现的。应当注意,可以使用其他生成对抗网络。具体地,Rel-GAN采用关系记忆来对输入令牌之间的长距离依赖性建模,并且采用gumbel-softmax松弛来克服生成器训练中的不可微分问题。关系存储器采用1个存储器时隙、带有2头的多头注意力,并且注意密钥大小设置为512。用于协作训练的语言模型采用与生成器相同的网络体系结构,并且在进行预训练后将生成器的参数的权重分配给语言模型。鉴别器采用大小为64的多种表示。在测试实施例中,Adam被用作用于更新所有模型参数的优化算法。

2.评估指标

为了比较,同时根据样本质量和样本多样性来评估各种模型。在当今大多数文本生成工作之后,通过在数据集上进行测试时的BLEU评分指标来评估样本质量,并且在合成数据集上进行测试时通过NLL

其中,在生成器模型上评估真实数据的密度。因此,具有更好样本多样性的模型将在实际数据空间上具有更广泛的覆盖范围,并且导致更低的NLL

3.基准模型

为了评估测试实施例的效率,考虑了MLE以及基于RL的基准,包括SeqGAN、RankGAN和LeakGAN。另外,还与最相关的无RL基准RelGAN进行了比较。在评估过程中,遵循在Rel-GAN中提出的温度设置,并且此处提出了用100和1000的温度评估时所测试方法实施例的结果。

4.合成数据集

第一评估域是合成oracle数据集。该实验采用随机初始化的长短期(LSTM)模型作为目标模型以模拟真实世界的序列,并且从真实数据分布中生成数据。进行的合成实验的序列长度被设置为20。在该域中进行实验的目的是将被测试的实施例与其最接近的协作训练对应物CoT进行比较。虽然可以认为这两种模型采用相似的方式来训练语言模型,但是调查了在这两种方法中提出的在生成器模型上采用相应的协作训练损失的效率。

在图7中证明了NLL

表1:序列长度为20的合成oracle的评估结果。对于CoT,呈现它们对于NLL

5.数据集1

第二评估域使用真实世界数据集(数据集1),其涉及图像字幕。在朱耀明(YaomingZhu)、陆思迪(Sidi Lu)、郑磊(Lei Zheng)、郭佳贤(Jiaxian Guo)、张卫南(WeinanZhang)、王军(Jun Wang)和余勇(Yong Yu)在2018年6月的SIGIR‘18:第41届国际ACM SIGIR信息检索研究与发展会议(第1097 1100页)中的“Texygen:用于文本生成模型的基准测试平台(Texygen:A Benchmarking Platform for Text Generation Models)”(该文献通过援引以其全部内容并入本文)中提出预处理方法。训练和测试集分别包含大约10,000个句子。句子的最小长度为7,并且最大长度为37。词汇量大约为4,700。

在表2(在图8中)中呈现用于测量样本质量的BLEU-2至BLEU-5的评分以及用于测量样本多样性的NLL

为了进一步验证这一点,在图9中呈现用于样本多样性指标和作为代表性样本质量指标的BLEU-5的学习曲线。图9展示了测试方法实施例以及数据集1上的基准RelGAN的质量-多样性折衷。与RelGAN相比,Meta-CoTGAN实施例逐步获得了更好的BLEU-5评分,其中模式瓦解的进程明显较慢。用于RelGAN的BLEU-5被绘制到对应的NLL

已经观察到,用于RelGAN的NLL

6.数据集2

第三评估域是另一个数据集(数据集2),其大小比数据集1大得多。数据集2包含270,000个句子的训练集和10,000个句子的测试集。句子的最大长度为51,并且词汇量为大约5,250。使用数据集2的结果在表3(在图10中)中呈现。

可以看出,就所有的BLEU指标和NLL

E.消融研究

1.协作训练语言模型的影响

在该部分中,展示了使用在线更新的语言模型进行协作训练过程的实施例的影响。为此,直接比较是使用未用协作训练更新的预训练的语言模型。我们将这种基准表示为Meta-CoTGAN

2.元优化的影响

还评估了元优化设置的影响。为此,将实施例与采用协同训练损失来优化生成器参数的原理方法进行了比较,该方法以通过加权方式对对抗性损失和协同训练损失进行线性求和的形式来提出,即

F.一些结论

本文提出的是用于促进对抗生成模型的训练的元协作训练方法的实施例。实施例利用协作训练的第二模型(例如,语言模型)来经由将实际数据上的第二模型的预测输出分布提炼到对抗生成器模型来有效地减速对抗训练的模式瓦解。使用合成数据集和两个现实世界数据集(具有的序列长度在7到51的范围内)两者来评估所提出方法的实施例。因此,经测试的方法同时在样本质量指标和样本多样性指标上始终优于基准算法。该方法的实施例是通用的,并且可以与面临模式瓦解问题的基于RL或无RL的不同的对抗文本生成算法一起应用。元协作训练的实施例也可以应用于或适应于更多新兴的基于RL的/无GAN的模型。

G.计算系统实施例

在一个或多个实施例中,本专利文献的方面可涉及、可包括一个或多个信息处理系统/计算系统,或者可在一个或多个信息处理系统(或计算系统)上实现。信息处理系统/计算系统可包括可操作来计算、运算、确定、分类、处理、传输、接收、检索、发起、路由、交换、存储、显示、通信、显现、检测、记录、再现、处理或利用任何形式信息、智能或数据的任何手段或手段的组合。例如,计算系统可以是或可包括个人计算机(例如,膝上型计算机)、平板电脑、移动设备(例如,个人数字助理(PDA)、智能手机、平板手机、平板等)、智能手表、服务器(例如,刀片式服务器或机架式服务器)、网络存储设备、摄像机或任何其它合适设备,并且可在大小、形状、性能、功能和价格方面改变。计算系统可包括随机存取存储器(RAM)、一个或多个处理资源(诸如中央处理器(CPU)或硬件或软件控制逻辑)、只读存储器(ROM)和/或其它类型的存储器。计算系统的另外组件可包括一个或多个驱动器(例如,硬盘驱动器、固态驱动器或两者)、用于与外部设备通信的一个或多个网络端口、以及各种输入和输出(I/O)设备(例如键盘、鼠标、手写笔、触摸屏和/或视频显示器)。计算系统还可包括可操作为在各种硬件组件之间传输通信的一个或多个总线。

图12描绘了根据本公开的实施例的信息处理系统(或计算系统)的简化框图。应理解,计算系统可不同地配置并且包括不同组件,包括如图12中所示的更少或更多的部件,但应理解,针对系统1200所示出的功能可操作为支持计算系统的各种实施例。

如图12所示,计算系统1200包括一个或多个中央处理器(CPU)1201,CPU 1201提供计算资源并控制计算机。CPU 1201可用微处理器等实现,并且还可包括一个或多个图处理单元(GPU)1202和/或用于数学计算的浮点协处理器。在一个或多个实施例中,一个或多个GPU 1202可并入显示控制器1209内,诸如一个或多个图卡的一部分。系统1200还可包括系统存储器1219,系统存储器1219可包括随机存取存储器(RAM)、只读存储器(ROM)或两者。

如图12中所示,还可提供多个控制器和外围设备。输入控制器1203表示至各种输入设备1204的接口,例如键盘、鼠标、触摸屏和/或触笔。计算系统1200还可包括存储控制器1207,该存储控制器1207用于与一个或多个存储设备1208对接,存储设备中的每个包括存储介质(诸如磁带或盘)或光学介质(其可用于记录用于操作系统、实用工具和应用程序的指令的程序,它们可包括实施本公开的各方面的程序的实施例)。存储设备1208还可用于存储经处理的数据或是将要根据本公开处理的数据。系统1200还可包括显示控制器1209,该显示控制器1209用于为显示设备1211提供接口,显示设备1211可为阴极射线管(CRT)显示器、薄膜晶体管(TFT)显示器、有机发光二极管、电致发光面板、等离子面板或任何其它类型的显示器。计算系统1200还可包括用于一个或多个外围设备1206的一个或多个外围设备控制器或接口1205。外围设备的示例可包括一个或多个打印机、扫描仪、输入设备、输出设备、传感器等。通信控制器1214可与一个或多个通信设备1215对接,这使系统1200能够通过各种网络(包括互联网、云资源(例如以太云、经以太网的光纤通道(FCoE)/数据中心桥接(DCB)云等)、局域网(LAN)、广域网(WAN)、存储区域网络(SAN))中的任一网络,或通过任何合适电磁载波信号(包括红外信号)来连接至远程设备。如描绘的实施例中所示,计算系统1200包括一个或多个风扇或风扇托盘1218以及一个或多个冷却子系统控制器1217,其监视系统1200(或其组件)的热温度并操作风扇/风扇托盘1218以助于调节温度。

在示出的系统中,所有主要系统组件可连接至总线1216,总线1216可表示多于一个的物理总线。然而,各种系统组件可在物理上彼此接近或可不在物理上彼此接近。例如,输入数据和/或输出数据可远程地从一个物理位置传输到另一物理位置。另外,实现本公开的各方面的程序可经由网络从远程位置(例如,服务器)访问。此类数据和/或程序可通过各种机器可读介质中的任一机器可读介质来传送,机器可读介质包括例如:诸如硬盘、软盘和磁带的磁性介质;诸如光盘(CD)和全息设备的光学介质;磁光介质;以及专门配置成存储或存储并执行程序代码的硬件设备,诸如专用集合成电路(ASIC)、可编程逻辑器件(PLD)、闪存设备、其它非易失性存储器(NVM)设备(诸如基于XPoint的3D设备)、以及ROM和RAM设备。

本公开的方面可利用用于一个或多个处理器或处理单元以使步骤执行的指令在一个或多个非暂态计算机可读介质上编码。应注意,一个或多个非暂态计算机可读介质应包括易失性存储器和/或非易失性存储器。应注意,替代实现方式是可能的,其包括硬件实现方式或软件/硬件实现方式。硬件实施的功能可使用ASIC、可编程的阵列、数字信号处理电路等来实现。因此,任何权利要求中的术语“手段”旨在涵盖软件实现方式和硬件实现方式两者。类似地,如本文使用的术语“计算机可读媒介或介质”包括具有实施在其上的指令程序的软件和/或硬件或它们的组合。利用所构想的这些替代实现方式,应理解,附图以及随附描述提供本领域的技术人员编写程序代码(即,软件)和/或制造电路(即,硬件)以执行所需处理所要求的功能信息。

应注意,本公开的实施例还可涉及具有其上具有用于执行各种计算机实施的操作的计算机代码的非暂态有形计算机可读介质的计算机产品。介质和计算机代码可为出于本公开的目的而专门设计和构造的介质和计算机代码,或者它们可为相关领域中的技术人员已知或可用的。有形计算机可读介质的示例包括例如:诸如硬盘、软盘和磁带的磁性介质;诸如CD和全息设备的光学介质;磁光介质;以及专门配置成存储或存储并执行程序代码的硬件设备,诸如ASIC、可编程逻辑器件(PLD)、闪存设备、其它非易失性存储器(NVM)设备(诸如基于XPoint的3D设备)、以及ROM和RAM设备。计算机代码的示例包括机器代码(例如,编译器产生的代码)以及包含可由计算机使用解释器来执行的更高级代码的文件。本公开的实施例可整体地或部分地实施为可在由处理设备执行的程序模块中的机器可执行指令。程序模块的示例包括库、程序、例程、对象、组件和数据结构。在分布的计算环境中,程序模块可物理上定位在本地、远程或两者的设定中。

本领域的技术人员将认识到,计算系统或编程语言对本公开的实践来说均不重要。本领域的技术人员将还将认识到,多个上述元件可物理地和/或在功能上划分成模块和/或子模块或组合在一起。

本领域技术人员将理解,前文的示例和实施例是示例性的,并且不限制本公开的范围。旨在说明的是,在本领域的技术人员阅读本说明书并研究附图后将对本领域的技术人员显而易见的本公开的所有、置换、增强、等同、组合或改进包括在本公开的真实精神和范围内。还应注意,任何权利要求书的元素可不同地布置,包括具有多个从属、配置和组合。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号