公开/公告号CN112766489A
专利类型发明专利
公开/公告日2021-05-07
原文格式PDF
申请/专利权人 合肥黎曼信息科技有限公司;
申请/专利号CN202110036718.6
申请日2021-01-12
分类号G06N3/08(20060101);G06N3/04(20060101);
代理机构34113 安徽省蚌埠博源专利商标事务所(普通合伙);
代理人朱恒兰
地址 230001 安徽省合肥市高新区创新大道2800号创新产业园二期J1栋A座1027室
入库时间 2023-06-19 10:54:12
技术领域
本发明涉及对抗学习领域,具体涉及一种基于对偶距离损失的生成对抗网络训练方法。
背景技术
生成对抗网络是一类神经网络,通过轮流训练判别器和生成器,令其相互对抗,来从复杂概率分布中采样,例如生成图片、文字、语音等。
如果原始的生成器和判别器是随机的,则很难确定生成器和判别器是否可以通过给定数据的训练收敛到理想结论。虽然可以证明在某些强假设下,生成器和判别器可以收敛到局部纳什均衡,但许多生成对抗网络算法不能全局收敛。
发明内容
本发明的目的是提供一种基于对偶距离损失的生成对抗网络训练方法,实现通过给定数据的训练收敛到理想结论。
为了实现以上目的,本发明采用的技术方案为:一种基于对偶距离损失的生成对抗网络训练方法,包括以下步骤:
S1:获取目标分布的数据集,以及对数据集进行预处理;
S2:设置生成器和判别器神经网络的结构和参数,以及训练过程中的学习率;
S3:根据步骤S2中的神经网络的参数计算出对偶距离损失函数;
S4:基于步骤S3中的对偶距离损失函数采用随机梯度下降法,训练生成器,使之可以生成真实分布。
进一步的,所述的对偶距离损失函数为
其中
进一步的,计算摄动点,然后由所述的摄动点确定对偶距离损失函数,以及优化方向,包括以下步骤:
步骤1:初始化,即对目标数据集进行处理,给定初始状态下的判别器f
步骤2:数据点的随机选取,即在目标数据集中选取m个点,记作{x
步骤3:摄动点的计算,即对于给定的生成器g
步骤4:优化方向计算,即考虑函数
步骤5:更新步骤,即计算
进一步的,所述的方法包括很好的泛化性,即至少以1-3δ的概率满足
本发明的技术效果在于:对偶距离损失可以作为传统度量F-distance的一个上界,并且建立了关于对偶距离损失的泛化误差界,在经过预训练之后,收敛到纳什均衡点附近,基于对偶距离损失的生成对抗网络训练方法可以通过传统的鞍点算法达到全局收敛。
附图说明
图1为本发明实施例提供的生成对抗网络训练方法的流程示意图。
图2为本发明实施例提供的生成对抗网络训练方法在CIFAR10数据集上经过20000次迭代后与WGAN-GP方法的生成结果对比图。
图3为本发明实施例提供的生成对抗网络训练方法在CIFAR10数据集上与WGAN-GP方法得到的Inception Score对比图。
具体实施方式
参照附图1-3,一种基于对偶距离损失的生成对抗网络训练方法,包括以下步骤:
S1:获取目标分布的数据集,以及对数据集进行预处理;
S2:设置生成器和判别器神经网络的结构和参数,以及训练过程中的学习率;
S3:根据步骤S2中的神经网络的参数计算出对偶距离损失函数;
S4:基于步骤S3中的对偶距离损失函数采用随机梯度下降法,训练生成器,使之可以生成真实分布。
本发明实施基于对偶距离损失的生成对抗网络训练方法,有较好的泛化性能:
具体的,给定2个域
如果真实样本X以及高斯分布样本Z有界,并且边界由B
此时,给出得到公式(1)的具体过程,可以包括:
将公式(1)等式左边化简为:
McDiarmid不等式条件为:
其中X={x
使用McDiarmid不等式可知,有至少
再次使用McDiarmid不等式可知,有至少
这里∈=(∈
所以,至少有1-δ的概率有公式(4):
由此,至少有1-3δ的概率有公式(7):
由于判别器f和生成器g均为神经网络,所以可以写成公式(8)和公式(9)的形式:
f=a
g=b
其中,a
根据上述假设,有公式(10):
令
由于
根据拉德马赫复杂度与覆盖数的关系,得到公式(13):
同理,得到公式(14)和公式(15):
假设m>>n,并结合公式(7)、公式(13)、公式(14)、公式(15),得到基于对偶损失距离的泛化误差界,即公式(1)。
下面具体说明,本发明实施例提供的基于对偶距离损失的生成对抗训练方法中,在设置好神经网络的结构和参数以及给出符合目标分布的数据集之后,对对偶距离损失函数采用梯度下降法进行求解,可以包括:
初始化步骤:对目标数据集进行处理。给定初始状态下的判别器f
数据点的随机选取:在目标数据集中选取m个点,记作{x
摄动点的计算:对于给定的生成器g
对偶距离计算:计算出对偶距离损失函数
优化方向计算:考虑函数
更新步骤:计算
本发明实施例提供的基于对偶距离损失函数的生成对抗训练方法已经在MNIST和CIFAR10上面成功实现了数据分布生成。结果表明,相对于传统的生成对抗训练方法,本发明实施例提供的基于对偶距离损失函数的生成对抗训练方法可以在相同的迭代步数情况下得到更加准确的结果,能够提高训练的质量,具有更好的成本收益。
由于在训练的过程中,每一次训练都需要生成高斯噪声,因此本实施例提供的基于对偶距离损失函数的生成对抗训练方法虽然在损失函数的设置上看起来更加复杂,但是并没有增加泛化误差,因此和传统的生成对抗训练方法有着相同的泛化性能。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。
机译: 基于三元组损失的序数分类深度学习模型的训练方法及其训练装置
机译: 一种基于探针的PCR检测方法来测量循环的脱甲基化的贝塔细胞衍生DNA的水平作为糖尿病中贝塔细胞损失的一种测量方法
机译: 一种基于弱监督的字符检测器训练方法及装置,系统和介质