欢迎光临散文网 会员登陆 & 注册

在Pytorch上用MNIST数据集训练和实现变分自动编码器

2021-09-02 19:16 作者:深度之眼官方账号  | 我要投稿

历时一个多月Pytorch构建深度学习模型系列指南终于最后一篇啦!完结撒花!虽然战线拉得有点长,但是有利于大家消化吸收里面的知识。



前几期教程传送门:


1.Pytorch初级教程

2.认识深度学习模型中的张量维度

3.CNN和特征可视化

4.使用Optuna调整超参数

5.K折交叉验证

6.学习自编码器

7.去噪自动编码器


本篇是本系列教程的最后一篇,本篇要讲解的是变分自动编码器的训练和实现,这块以前没理解的同学可以试试这篇能否帮助你掌握这个知识点。



0变分自动编码器(下面称变分自编码器)


标准自动编码器可能存在潜在空间可能不规则的问题 [1]。这意味着潜在空间的闭合点可以在可见单元上产生不同的无意义的图案。


作为自动编码器,变分自动编码器由编码器和解码器两个神经网络架构组成。这个问题的一个解决方案是引入变分自动编码器;但是编码-解码过程有一个修改,简单讲解一下步骤:

  • 我们将输入编码为潜在空间上的分布,而不是将其视为单个点。该编码分布被选择为正态分布,以便可以训练编码器返回均值矩阵和协方差矩阵。

  • 第二步,我们从编码分布中采样一个点。

  • 之后,我们可以解码采样点并计算重构误差。

  • 我们通过网络反向传播重建误差。由于采样过程是一个离散过程,所以它不是连续的,我们需要应用重新参数化技巧来实现反向传播工作:

02 VAE 损失函数


VAE 的损失由两项组成:

  • 第一项是重构项,它是通过比较输入及其对应的重构得到的。

  • 另一个术语是正则化项,也称为编码器返回的分布与标准正态分布之间的Kullback-Leibler 散度[3]。这个术语在潜在空间中起着正则化的作用,可能会使编码器返回的分布接近标准正态分布。


03 使用Pytorch实现


与前面的教程一样,变分自动编码器是在MNIST数据集上实现和训练的。


导入库和数据集



定义一个 VariationalAutoencoder 类,它结合Encoder类和Decoder类 [3]。


编码器和解码器网络包含三个卷积层和两个全链接层。添加了一些批处理法线层,使其在潜在空间中具有更强的特性。与标准自动编码器不同,编码器返回均值和方差矩阵,我们使用它们来获得采样的潜在向量。在VariationalEncoder类中获得Kullback-Leibler项



定义Decoder类之后,它与本系列第五篇教程中展示的那个类保持一致。



下面定义合并编码器和解码器的类


我们在代码中初始化VariationalAutoencoder类、优化器和使用GPU的设备



定义用于训练和评估变分自动编码器的函数:



上述理论中所描述的是损失由两个项组成;重构项是输入与其重构之间的差的平方和。其他一些版本使用BCE损失而不是MSE损失,但我更喜欢这种方式,因为它更有意义。


想要在VAE模型的训练过程中看到每个epoch中的输入及其相应的重构,定义一个函数来实现这些可视化:



最终训练VAE并在验证集中进行评估:



这些是50个epoch后获得的结果,我们看到输入图像与其重构之间有很高的相似性,即便仍然存在一些缺陷。


为了估计变分自动编码器的学习能力,我们还可以从潜在代码生成新图像:




大多数生成的样本看起来像数字,所以变分自编码器似乎已经从潜在空间中学到了稳健的模式。


我们可以将变分解码器学习的潜在代码可视化,并按十类数字进行着色:



潜在空间的取值范围更小、更集中,整体分布似乎接近高斯分布。


应用t-SNE(一种降维方法)可以获得更好的可视化效果,使用两个组件,可视化潜在代码:




由此产生的潜在代码似乎将不同组中的数字聚集在一起,不同数字之间也有轻微的重叠。


总结:

本篇文章用实战代码讲解了使用Pytorch实现和训练变分自动编码器。它是自动编码器的扩展,唯一的区别是它将输入编码为潜在空间上的分布。文末参考文献可以帮助大家更深入地研究 VAE。


本文代码:

https://github.com/eugeniaring/Pytorch-tutorial/blob/main/VAE_mnist.ipynb


参考文献:


[1]

https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73


[2]

https://atcold.github.io/pytorch-Deep-Learning/en/week08/08-3/


[3]

https://avandekleut.github.io/vae/


原文链接:

https://medium.com/dataseries/variational-autoencoder-with-pytorch-2d359cbf027b


每天18:30分更新

关注学姐公众号+星标+在看

不迷路看好文



在Pytorch上用MNIST数据集训练和实现变分自动编码器的评论 (共 条)

分享到微博请遵守国家法律