欢迎光临散文网 会员登陆 & 注册

90天学会GAN--Day2--从MNIST数据集开始

2023-06-02 02:55 作者:弱弱的小汤汤  | 我要投稿

2.2 将数据输入模型

书接上回,我们将二进制文件转换为了 .png 图片以及 label.txt, 现在我们要把这些图片输入到模型中。 

因此我们构造了一个 class customDataset 来导入数据(其实就是放到 label[] 和 img[] 里)。

首先是初始化,在初始化的部分会导入图片的路径和标签

其次是将图片通过 CV2 导入

最后是后面的函数需要的 len 函数

3. 构建模型 

3.1. 生成器 generator

此处使用的是 Multi-LayerPerception 的全连接层来链接不同的层。

由于图片是 28×28 的格式,所以最后应该是一个1×784 的层。

由此可以得出整个模型的形式 : latent dim -> 1024 -> 784

 而对于每一层,形式是:nn.linear → normalize → 激活函数

(此处为Leaky ReLU)

于是可以写出初始化的代码:

之后就是 forward 函数:

3.2. 鉴别器 discriminator 

鉴别器和生成器结构非常相似,只是反过来而已。

至此,我们已经完成了GAN的大部分内容,包括数据输入、生成器和鉴别器,下一步就是训练模型了。

90天学会GAN--Day2--从MNIST数据集开始的评论 (共 条)

分享到微博请遵守国家法律