首页> 中国专利> 用于检测对话策略学习中模拟用户经验质量的方法和系统

用于检测对话策略学习中模拟用户经验质量的方法和系统

摘要

本发明提供了一种用于检测对话策略学习中模拟用户经验质量的方法和系统,其方法包括以下步骤:S1.由世界模型生成模拟经验;S2.通过基于KL散度的质量检测器对所述的模拟经验进行质量检测;S3.将质量检测合格的模拟经验进行保存以用于对话策略模型训练。本方案引入了基于KL散度的质量检测器,能够更轻松有效地评估模拟经验的质量,并在确保对话策略的鲁棒性和有效性的同时大大提高计算效率,实现有效控制模拟经验质量的目的。

著录项

  • 公开/公告号CN112989016A

    专利类型发明专利

  • 公开/公告日2021-06-18

    原文格式PDF

  • 申请/专利权人 南湖实验室;

    申请/专利号CN202110532470.2

  • 申请日2021-05-17

  • 分类号G06F16/332(20190101);G06F16/36(20190101);G06N3/00(20060101);G06N20/00(20190101);

  • 代理机构33233 浙江永鼎律师事务所;

  • 代理人陆永强;张晓英

  • 地址 314001 浙江省嘉兴市南湖区七星街道香湖别墅29幢

  • 入库时间 2023-06-19 11:29:13

说明书

技术领域

本发明属于机器学习技术领域,尤其是涉及一种用于检测对话策略学习中模拟用户经验质量的方法和系统。

背景技术

任务完成型对话策略学习旨在构建一个以完成任务为目标的对话系统,该系统可以通过几轮自然语言交互来帮助用户完成特定的单个任务或多域任务。它已广泛应用于聊天机器人和个人语音助手,例如苹果的Siri和微软的Cortana。

近年来,强化学习逐渐成为了对话策略学习的主流方法。基于强化学习,对话系统可以通过与用户进行自然语言交互来逐步调整、优化策略,以提高性能。但是,原始强化学习方法在获得可用的对话策略之前需要进行大量人机对话交互,这不仅增加了训练成本,而且还恶化了早期训练阶段的用户体验。

为了解决上述问题并加速对话策略的学习过程,研究者们在Dyna-Q框架的基础上,提出了Deep Dyna-Q(DDQ)框架。DDQ框架引入了世界模型,为了使其与真实用户更相似,该模型使用真实用户经验进行训练,用以在动态环境中生成模拟用户经验,以下简称模拟经验。在对话策略的学习过程中,使用从实际交互中收集的真实经验和从与世界模型交互中收集的模拟经验共同训练对话智能体。借助引进世界模型,只需要使用少量的真实用户交互,可以显著提升对话策略的学习效率,然而,DDQ在进一步优化基于有限对话交互的对话策略学习方面,还面临着一些难题,例如,世界模型产生的模拟经验并不一定会改善性能,低质量的模拟经验甚至会对性能造成严重的负面影响。近来的一些研究为了解决这个问题,尝试使用生成式对抗网络(GAN)来区分低质量经验以控制模拟经验的质量。但是,对GAN进行训练存在极大的不稳定性问题,其在很大概率上会导致对话策略学习不收敛,并且对超参数的选择高度敏感,使对话学习性能受到严重制约。因此,如何有效筛去除对话策略学习过程中的低质量经验,这个问题仍有待解决,且十分重要。

发明内容

本发明的目的是针对上述问题,提供一种用于检测对话策略学习中模拟用户经验质量的方法及其系统。

为达到上述目的,本发明采用了下列技术方案:

一种用于检测对话策略学习中模拟用户经验质量的方法,包括以下步骤:

S1.由世界模型生成模拟经验;

S2.通过基于KL散度的质量检测器对所述的模拟经验进行质量检测;

S3.将质量检测合格的模拟经验进行保存以用于对话策略模型训练。

在上述的用于检测对话策略学习中模拟用户经验质量的方法中,在步骤S2中,基于KL散度的质量检测器通过对比模拟经验与真实经验来进行模拟经验的质量检测。

在上述的用于检测对话策略学习中模拟用户经验质量的方法中,在步骤S3中,将质量检测合格的模拟经验存储至缓冲器以用于对话策略模型训练。

在上述的用于检测对话策略学习中模拟用户经验质量的方法中,在步骤S2中,根据世界模型生成的模拟经验更新词库world-dict,根据真实用户生成的真实经验更新词库real-dict,并通过KL散度衡量词库world-dict与词库real-dict的相似度以进行模拟经验的质量检测。

在上述的用于检测对话策略学习中模拟用户经验质量的方法中,词库world-dict的主键为世界模型生成的用户动作,主键对应值为用户动作对应的频率;

词库real-dict的主键为真实用户生成的用户动作,主键对应值为用户动作对应的频率。

在上述的用于检测对话策略学习中模拟用户经验质量的方法中,在步骤S2中,通过事先定义的变量KL

在上述的用于检测对话策略学习中模拟用户经验质量的方法中,在步骤S2中,词库real-dict与词库world-dict的交集主键在两个词库中的频率值被存储在事先建立的词库same-dict中,并基于词库same-dict计算当前的KL散度,若当前KL散度小于或等于KL

在上述的用于检测对话策略学习中模拟用户经验质量的方法中,在步骤S2中,当词库same-dict的长度小于常量C时判断当前经验为合格经验。

一种用于检测对话策略学习中模拟用户经验质量的系统,包括连接于世界模型、真实用户经验库和对话策略模型的质量检测器,且所述的质量检测器包括KL散度检测器,KL散度检测器用于根据真实用户生成的真实经验检测世界模型生成的模拟经验的质量。

在上述的用于检测对话策略学习中模拟用户经验质量的系统中,所述的质量检测器包括用于存储真实经验的词库real-dict,用于存储模拟经验的词库world-dict和用于保存词库real-dict与词库world-dict的交集主键在两个词库中的频率值的词库same-dict。

本发明的优点在于:通过引入KL散度来检查经验的分布,不需要额外工作来设计和训练复杂的质量检测器,从而更轻松的评估模拟经验的质量,并在确保对话策略的鲁棒性和有效性的同时大大提高计算效率,能够有效控制模拟经验质量。

附图说明

图1为本发明对话学习方法的架构图;

图2为本发明对话学习方法中KL散度计算流程图;

图3为各类智能体在不同K参数下的学习曲线图,其中,

(a)为各类智能体在K=20时的学习曲线图;

(b)为各类智能体在K=30时的学习曲线图。

具体实施方式

下面结合附图和具体实施方式对本发明做进一步详细的说明。

实施例一

如图1所示,本方案提出一种用于检测对话策略学习中模拟用户经验质量的方法,其基本方法与现有技术一致,如使用人类会话数据来初始化对话策略模型和世界模型,并依此来启动对话策略学习。对话策略模型的对话策略学习主要包括直接强化学习和间接强化学习(也叫规划)两部分。直接强化学习,采用Deep Q-Network(DQN)根据真实经验改进对话策略,对话策略模型与用户User交互,在每一步中,对话策略模型根据观察到的对话状态s,通过最大化价值函数Q,选择要执行的动作a。然后,对话策略模型接收奖励r,真实用户的动作a

最大化价值函数Q(s,a;θ

间接强化学习期间,对话策略模型通过与世界模型进行交互来改善其对话策略,以减少训练成本,规划的频率由参数K控制,这意味着计划在直接强化学习的每一步中执行K步。当世界模型能够准确捕获真实环境的特征时,K的值往往会很大。在规划的每个步骤中,世界模型都会根据当前状态s来响应动作a

特别地,本方案在上述现有技术的基础上,采用了基于KL散度(KL divergence)的质量检测器对世界模型生成的模拟经验进行质量检测,并将质量检测合格的模拟经验保存至缓冲器以用于对话策略模型训练,从而保证模拟经验的质量,避免低质量模拟经验影响学习性能。

具体地,如图2所示,基于KL散度的质量检测器通过对比模拟经验与真实经验来进行模拟经验的质量检测,具体方法如下:

根据世界模型生成的模拟经验更新词库world-dict,根据真实用户生成的真实经验更新词库real-dict,词库world-dict和词库real-dict的主键分别为世界模型和真实用户生成的用户动作a

词库real-dict与词库world-dict的交集主键在两个词库中的频率值被存储在事先建立的词库same-dict中,并由KL散度衡量词库world-dict与词库real-dict的相似度以进行模拟经验的质量检测;

衡量相似度的方式为事先定义一个变量KL

为了展示本方案的有效性和优越性,通过实验组将本方法与其他算法进行比对,表1中,D3Q(10)*为基于GAN质量检测器的智能体,DDQ(M,K,N )为不使用质量检测器的智能体;GPDDQ(M,K,N)为使用GP世界模型且不使用质量检测器的智能体;UN-GPDDQ(5000,20,4)为使用GP世界模型且不使用质量检测器,同时考虑GP模型的不确定性的智能体;KL-GPDDQ(M,K,N)为在UN-GPDDQ基础上使用本方法KL质量检测器的智能体;其中M表示缓冲器大小,K表示规划步数、N表示批次大小:

表1:缓冲器大小为5000的不同智能体训练迭代{100,200,300}次,K=20,N=4时的实验结果;

上表中,Su(Success,成功率),Tu(Turns,对话回合),Re(Reward,奖励)。

从表1可以发现DDQ方法在全部5个当中仍是性能最差的。从GPDDQ,UN-GPDDQ,和KL-GPDDQ智能体的运行结果中,可以很明显地看出,本方案KL散度检查对性能提升很有帮助,并且其对于成功率和奖励都有显著提升,与DDQ对比,本方案方法能够在用户交互更少的情况下还能够提升成功率。

另外,由图3还可以看出,本方案提出的方法的学习速度远远高于DDQ和D3Q。需要注意的是D3Q的曲线起伏很大,很不稳定,尤其是当K=30时,D3Q甚至不能收敛到最优值,所以即使D3Q能够剔除低质量经验,其仍然很难在现实中实现,因为GAN太不稳定了。

由上述实验组,我们能够看到相对于现有技术基于DDQ框架的方法,本方案具有明显的优越性,并且相对于现有技术使用的GAN质量检测器同样具有明显的优势。本方案通过引入KL散度来检查经验的分布,不需要对质量检测器进行更多训练,从而可以更轻松地评估现实中模拟体验的质量,并在确保对话策略的鲁棒性和有效性的同时大大提高计算效率。

实施例二

本实施例与实施例一类似,不同之处在于,本实施例考虑到在初始阶段,词库world-dict中只有有限的动作(行为),因此词库same-dict长度也很小,为了预热世界模型,优选在词库same-dict长度小于常量C时,将模拟经验视作合格。常量C由本领域技术人员根据具体情况确定,这里不进行限制。

此时,只有当词库same-dict长度达到一定值时,即大于或等于常量C时,才通过事先定义的变量KL

实施例三

本实施例提供了一种用于检测对话策略学习中模拟用户经验质量的系统,用于执行实施例一或实施例二中的方法,包括连接于世界模型、真实用户经验库和对话策略模型的质量检测器,且所述的质量检测器包括KL散度检测器,KL散度检测器用于根据真实用户生成的真实经验检测世界模型生成的模拟经验的质量。

具体地,质量检测器包括用于存储真实经验的词库real-dict,用于存储模拟经验的词库world-dict和用于保存词库real-dict与词库world-dict的交集主键在两个词库中的频率值的词库same-dict。

本文中所描述的具体实施例仅仅是对本发明精神作举例说明。本发明所属技术领域的技术人员可以对所描述的具体实施例做各种各样的修改或补充或采用类似的方式替代,但并不会偏离本发明的精神或者超越所附权利要求书所定义的范围。

尽管本文较多地使用了模拟经验、真实经验、质量检测器、人类会话数据、世界模型、缓冲器、对话策略模型、真实用户经验库等术语,但并不排除使用其它术语的可能性。使用这些术语仅仅是为了更方便地描述和解释本发明的本质;把它们解释成任何一种附加的限制都是与本发明精神相违背的。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号