手把手构建图像分类baseline
这篇也是学姐粉丝的投稿,最近几周推文写了几篇粉丝投稿,发现大家对学习知识复盘做的是真的好。今天这篇是深度之眼学员粉丝写的这篇《图像分类baseline》需要的同学可以来做参考。
投稿还在进行中,学习心得、个人经历、学习建议、学习工具都可以投稿来。投稿就有奖品赠送,感兴趣的就找学姐吧!
图像分类baseline
定义超类
图像维度的匹配
正向传播
定义损失函数
定义训练部分
从dataset定义data_loader,从data_loader得到数据集
1.将数据放入设备
2.传入模型
3.模型结果传入损失函数
4.损失函数反向传播
device的读取
梯度清空
pytorch的梯度不清空就会累积
定义epoch,batch_size
定义data_loader()以及数据集划分
传入数据集以及数据集的标签
train_lables 训练集
train_labels['target'] 训练集的标签
train_index 划分的训练集标签
test_index 划分的验证集标签
train_images,valid_images = 数据集['图像路径或图像'].iloc[训练集标签],数据集['图像路径或图像'].iloc[测试集标签]
train_targets,valid_targets= 数据集['图像的标签'].iloc[训练集标签],数据集['图像的标签'].iloc[测试集标签]
定义dataset
传入dataset就是dataloader
定义优化器
模型训练
一个for循环就是一折
定义验证函数
定义验证集精度计算
(预测结果,预测标签)
模型保存
常见错误:out of memory
重启notebook
不要直接复制代码使用,编辑器有差异可能会因为空格和换行运行时报错。粉丝投稿,有错误理性交流不要上来就喷。我们都是高材生,有素质。
