欢迎光临散文网 会员登陆 & 注册

90天学会GAN--Day3--从MNIST数据集开始

2023-06-02 03:12 作者:弱弱的小汤汤  | 我要投稿

4. 模型训练

4.1 损失函数

GAN的训练优化目标其实就是如下函数:

可以看到,这里有两个loss:一个是训练鉴别器时使用的 D_loss, 另一个是训练生成器时使用的 G_loss。

而这个模型的目标是要最小化 G_loss, 以及最大化 D_loss。 

这里我们使用了Adam优化策略和BCE loss 来优化这两个。 于是可以写出:

4.2 模型迭代 

在模型迭代的过程中,我们会做如下步骤: 

  1. 我们会读取图像和标签(暂时没用)

  2. 然后生成一个随机的噪声z 并放入生成器生成一张假的图片,称为fake_img

  3. 之后将fake _ image 放入鉴别器得出 fake _ image 的评分

  4. 将这个评分与 1 比较得到 G_loss

  5. 再将输入的图像和fake_image 加上真假标签后放入鉴别器中得到D _ loss

  6. 循环以上过程 opt.epoch 次

由此,我们可以得到这部分的代码:

至此,模型已经训练完毕。

5.保存图片以及模型 

这里我们使用 torchvision.utils 库中的 save_image函数来存储图片,用法如下:

注:path为你想要存图片的路径

我们使用torch.save来保存模型即其中的参数,实际上需要保存的其实就是 generator 和 discriminator 这两个东西,用法如下: 

注:path为你想要存图片的路径

然后使用的时候就只需要load一下就行了:

注:path为你想要存图片的路径

之后就像之前一样使用generator和discriminator就可以了。

这样做的好处是:validate的时候就不需要重新跑一次所有的程序了,只需要把之前的模型 load 出来用就行了

90天学会GAN--Day3--从MNIST数据集开始的评论 (共 条)

分享到微博请遵守国家法律