欢迎光临散文网 会员登陆 & 注册

5.6 Dropout

2023-02-13 10:39 作者:梗直哥丶  | 我要投稿

上一节我们着重讲了L2正则化,这一节介绍另一种防止过拟合的常用方法Dropout。

5.6.1 什么是Dropout

Dropout是一种常见的正则化技术,用于减少神经网络中的过拟合。它是由 Geoffrey Hinton 和他的团队在 2012 年提出的,并在他们的论文 《Improving neural networks by preventing co-adaptation of feature detectors》 中进行了描述。

其核心思想是在训练过程中随机“删除”(即将其权重设为零)一些神经元,从而使模型不能完全依赖于某些特定的特征。这样可以防止神经网络对训练集过于依赖,从而使模型更加泛化,也就是更好地适用于新的、未见过的数据。

在训练期间随机将输入特征的一部分“删除”,而在测试期间则使用全部特征。通常在测试期间会使用所有可用的信息来得到最好的结果。因此,在测试期间一般不使用 Dropout。

Dropout 在神经网络训练中取得了很大的成功,并且现在已经成为了一种标准的正则化技术。它在解决过拟合问题方面非常有效,因此经常用于解决深度学习中的过拟合问题。

5.6.2 Dropout的工作原理

如上图所示,Dropout的工作流程主要可以分为以下几步:


5.6.3 Dropout的代码实现

为了更直观的理解Dropout的作用,我们用一个例子演示在有无Dropout情况下,模型的效果对比。

首先导入相关库,定义超参数。然后生成两组线性数据集,分别是训练集和测试集,并引入部分噪声。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 随机数种子
torch.manual_seed(2333)

# 定义超参数
num_samples = 20 # 样本数
hidden_size = 200 # 隐藏层大小
num_epochs = 500  # 训练轮数

# 生成训练集和测试集
x_train = torch.unsqueeze(torch.linspace(-1, 1, num_samples), 1)
y_train = x_train + 0.3 * torch.randn(num_samples, 1)
x_test = torch.unsqueeze(torch.linspace(-1, 1, num_samples), 1)
y_test = x_test + 0.3 *  torch.randn(num_samples, 1)

接下来定义两个模型,一个是可能会过拟合的网络,其中包含3个隐藏层,激活函数使用ReLU函数。另一个网络结构一模一样,区别在于加入了Dropout层,其中概率p设置为0.5。

# 定义一个可能会过拟合的网络
net_overfitting = torch.nn.Sequential(
    torch.nn.Linear(1, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 1),
)

# 定义一个包含 Dropout 的网络
net_dropout = torch.nn.Sequential(
    torch.nn.Linear(1, hidden_size),
    torch.nn.Dropout(0.5),  # p=0.5
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.Dropout(0.5),  # p=0.5
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 1),
)

接下来,分别训练这两个模型,这里使用均方误差(MSE)作为损失函数。

# 定义优化器和损失函数
optimizer_overfitting = torch.optim.Adam(net_overfitting.parameters(), lr=0.01)
optimizer_dropout = torch.optim.Adam(net_dropout.parameters(), lr=0.01)
criterion = nn.MSELoss()

# 训练网络
for i in range(num_epochs):
    pred_overfitting = net_overfitting(x_train)
    pred_dropout = net_dropout(x_train)
    
    loss_overfitting = criterion(pred_overfitting, y_train)
    loss_dropout = criterion(pred_dropout, y_train)

    optimizer_overfitting.zero_grad()
    optimizer_dropout.zero_grad()
    
    loss_overfitting.backward()
    loss_dropout.backward()
    
    optimizer_overfitting.step()
    optimizer_dropout.step()

最后,使用 Matplotlib 绘制训练集和测试集,以及两个模型的拟合曲线。

# 在测试过程中不使用 Dropout
net_overfitting.eval()
net_dropout.eval()

# 预测
test_pred_overfitting = net_overfitting(x_test)
test_pred_dropout = net_dropout(x_test)

# 绘制拟合效果
plt.scatter(x_train.data.numpy(), y_train.data.numpy(), c='r', alpha=0.3, label='train')
plt.scatter(x_test.data.numpy(), y_test.data.numpy(), c='b', alpha=0.3, label='test')
plt.plot(x_test.data.numpy(), test_pred_overfitting.data.numpy(), 'r-', lw=2, label='overfitting')
plt.plot(x_test.data.numpy(), test_pred_dropout.data.numpy(), 'b--', lw=2, label='dropout')
plt.legend(loc='upper left')
plt.ylim((-2, 2))
plt.show()


在这个例子中,我们定义了两个网络,一个是可能会产生过拟合现象的网络,另一个是一模一样的结构,唯一的区别就是增加了Dropout层。然后分别进行训练,用Matplotlib绘制出数据集和模型拟合的效果。从上图中可以看到,第一个网络在训练集上明显产生了过拟合现象,而增加了Dropout层的网络更为稳定。由此可以看出,Dropout可以有效降低过拟合的风险。

5.6.4 Dropout的另一种理解



Dropout还可以被类比成集成大量深层神经网络的一种Bagging方法。每做一次丢弃,相当于从原始的网络中采样得到一个子网络,那么,最终的网络可以近似看作集成了指数级个不同网络的集成学习模型。

如上图所示,Dropout训练是由所有子网络组成的集合,其中子网络通过从基本网络中删除一部分神经元得到。我们从具有两个可见单元和两个隐藏单元的基本网络开始。这四个单元有十六个可能的子集。右图展示了从原始网络中丢弃不同的神经元子集而形成的所有十六个子网络。在这个小例子中,所得到的大部分网络已经失去从输入连接到输出的路径。但当层较宽时,丢弃所有从输入到输出的可能路径的概率变小,所以这个问题不太可能出现在层较宽的网络中。

5.6.5 Dropout的优缺点

优点

可以有效地减少过拟合的风险。Dropout 在训练过程中随机地清零一定比例的输入特征,从而使得模型不能依赖于任何一个特定的输入特征。这样可以使得模型更加稳健,并且在新的数据上的性能也更好。

Dropout 相对来说比较简单,并且在实际应用中也比较方便,适用于多种模型。

有研究显示,Dropout相比其他计算开销小的正则化方法更有效。Dropout还可以与其他形式的正则化合并,得到进一步的提升。

缺点

在训练过程中会降低训练效率。因为引入Dropout之后相当于每次只是训练的原先网络的一个子网络,为了达到同样的精度需要的训练次数会增多。

损失函数无法被明确定义。因为每次迭代都会随机消除一些神经元的影响,因此无法确保损失函数单调递减。

梗直哥提示:其实介绍了这么多概念,只要记得Dropout能够有效降低过拟合的风险,同时适用多种模型即可。如果你想了解更多内容,欢迎入群学习(加V: gengzhige99)

更多视频讲解,欢迎点击收看机器学习必修课

最新上线《深度学习必修课》带你探索ChatGPT等前沿技术。




5.6 Dropout的评论 (共 条)

分享到微博请遵守国家法律