MMNet micro-expression recognition on the CASME2 dataset
Original GitHub repository: https://github.com/muse1998/MMNet
Both the code and the dataset have some problems and need to be modified before the project will run. The three modified files are:
main.py
CA_block.py
PC_module.py
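
For reference, main.py builds every data path from --raf_path (default D:/CASME2/), so the dataset is expected in the following two locations (exactly as constructed in RafDataSet below):

D:/CASME2/CASME2-coding-20140508.xlsx
D:/CASME2/Cropped-updated/Cropped/sub<Subject>/<Filename>/reg_img<frame index>.jpg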

main.py
# -*- coding: utf-8 -*-
import torch
import math
import numpy as np
import torchvision.models
import torch.utils.data as data
from torchvision import transforms
import cv2
import pandas as pd
import os, torch
import torch.nn as nn
#import image_utils
import argparse, random
from functools import partial
from CA_block import resnet18_pos_attention
from PC_module import VisionTransformer_POS
from torchvision.transforms import Resize
torch.set_printoptions(precision=3, edgeitems=14, linewidth=350)
def parse_args():
parser = argparse.ArgumentParser()
    parser.add_argument('--raf_path', type=str, default='D:/CASME2/', help='CASME2 dataset path.')
parser.add_argument('--checkpoint', type=str, default='D:/CASME2/',
help='Pytorch checkpoint file path')
parser.add_argument('--pretrained', type=str, default=None,
help='Pretrained weights')
parser.add_argument('--beta', type=float, default=0.7, help='Ratio of high importance group in one mini-batch.')
    parser.add_argument('--relabel_epoch', type=int, default=1000,
                        help='Relabel samples in each mini-batch after this many epochs.')
parser.add_argument('--batch_size', type=int, default=34, help='Batch size.')
parser.add_argument('--optimizer', type=str, default="adam", help='Optimizer, adam or sgd.')
parser.add_argument('--lr', type=float, default=0.0001, help='Initial learning rate for sgd.')
parser.add_argument('--momentum', default=0.9, type=float, help='Momentum for sgd')
    parser.add_argument('--workers', default=0, type=int, help='Number of data loading workers (default: 0)')
parser.add_argument('--epochs', type=int, default=1000, help='Total training epochs.')
parser.add_argument('--drop_rate', type=float, default=0, help='Drop out rate.')
return parser.parse_args()
class RafDataSet(data.Dataset):
def __init__(self, raf_path, phase,num_loso, transform = None, basic_aug = False, transform_norm=None):
self.phase = phase
self.transform = transform
self.raf_path = raf_path
self.transform_norm = transform_norm
SUBJECT_COLUMN =0
NAME_COLUMN = 1
ONSET_COLUMN = 2
APEX_COLUMN = 3
OFF_COLUMN = 4
LABEL_AU_COLUMN = 5
LABEL_ALL_COLUMN = 6
df = pd.read_excel(os.path.join(self.raf_path, 'CASME2-coding-20140508.xlsx'),usecols=[0,1,3,4,5,7,8])
df['Subject'] = df['Subject'].apply(str)
if phase == 'train':
dataset = df.loc[df['Subject']!=num_loso]
else:
dataset = df.loc[df['Subject'] == num_loso]
Subject = dataset.iloc[:, SUBJECT_COLUMN].values
File_names = dataset.iloc[:, NAME_COLUMN].values
        Label_all = dataset.iloc[:, LABEL_ALL_COLUMN].values  # emotion labels; mapped below to 0: happiness, 1: surprise, 2: others
Onset_num = dataset.iloc[:, ONSET_COLUMN].values
Apex_num = dataset.iloc[:, APEX_COLUMN].values
Offset_num = dataset.iloc[:, OFF_COLUMN].values
Label_au = dataset.iloc[:, LABEL_AU_COLUMN].values
self.file_paths_on = []
self.file_paths_off = []
self.file_paths_apex = []
self.label_all = []
self.label_au = []
self.sub= []
self.file_names =[]
a=0
b=0
c=0
d=0
e=0
# use aligned images for training/testing
for (f,sub,onset,apex,offset,label_all,label_au) in zip(File_names,Subject,Onset_num,Apex_num,Offset_num,Label_all,Label_au):
if label_all == 'happiness' or label_all == 'repression' or label_all == 'disgust' or label_all == 'surprise' or label_all == 'fear' or label_all == 'sadness':
self.file_paths_on.append(onset)
self.file_paths_off.append(offset)
self.file_paths_apex.append(apex)
self.sub.append(sub)
self.file_names.append(f)
if label_all == 'happiness':
self.label_all.append(0)
a=a+1
elif label_all == 'surprise':
self.label_all.append(1)
b=b+1
else:
self.label_all.append(2)
c=c+1
# label_au =label_au.split("+")
if isinstance(label_au, int):
self.label_au.append([label_au])
else:
label_au = label_au.split("+")
self.label_au.append(label_au)
##label
self.basic_aug = basic_aug
#self.aug_func = [image_utils.flip_image,image_utils.add_gaussian_noise]
def __len__(self):
return len(self.file_paths_on)
def __getitem__(self, idx):
##sampling strategy for training set
if self.phase == 'train':
onset = self.file_paths_on[idx]
#onset = onset.astype('int64')
apex = self.file_paths_apex[idx]
#apex = apex.astype('int64')
offset =self.file_paths_off[idx]
#offset = offset.astype('int64')
on0 = str(random.randint(int(onset), int(onset + int(0.2* (int(apex) - int(onset)) / 4))))
# on0 = str(int(onset))
on1 = str(
random.randint(int(onset + int(0.9 * (apex - onset) / 4)), int(onset + int(1.1 * (apex - onset) / 4))))
on2 = str(
random.randint(int(onset + int(1.8 * (apex - onset) / 4)), int(onset + int(2.2 * (apex - onset) / 4))))
on3 = str(random.randint(int(onset + int(2.7 * (apex - onset) / 4)), onset + int(3.3 * (apex - onset) / 4)))
# apex0 = str(apex)
apex0 = str(
random.randint(int(apex - int(0.15* (apex - onset) / 4)), apex + int(0.15 * (offset - apex) / 4)))
off0 = str(
random.randint(int(apex + int(0.9 * (offset - apex) / 4)), int(apex + int(1.1 * (offset - apex) / 4))))
off1 = str(
random.randint(int(apex + int(1.8 * (offset - apex) / 4)), int(apex + int(2.2 * (offset - apex) / 4))))
off2 = str(
random.randint(int(apex + int(2.9 * (offset - apex) / 4)), int(apex + int(3.1 * (offset - apex) / 4))))
off3 = str(random.randint(int(apex + int(3.8 * (offset - apex) / 4)), offset))
sub =str(self.sub[idx])
f = str(self.file_names[idx])
else:##sampling strategy for testing set
onset = self.file_paths_on[idx]
apex = self.file_paths_apex[idx]
offset = self.file_paths_off[idx]
on0 = str(onset)
on1 = str(int(onset + int((apex - onset) / 4)))
on2 = str(int(onset + int(2 * (apex - onset) / 4)))
on3 = str(int(onset + int(3 * (apex - onset) / 4)))
apex0 = str(apex)
off0 = str(int(apex + int((offset - apex) / 4)))
off1 = str(int(apex + int(2 * (offset - apex) / 4)))
off2 = str(int(apex + int(3 * (offset - apex) / 4)))
off3 = str(offset)
sub = str(self.sub[idx])
f = str(self.file_names[idx])
on0 ='reg_img' + on0 + '.jpg'
on1 = 'reg_img' + on1 + '.jpg'
on2 = 'reg_img' + on2 + '.jpg'
on3 = 'reg_img' + on3 + '.jpg'
apex0 ='reg_img' + apex0 + '.jpg'
off0 ='reg_img' + off0 + '.jpg'
off1='reg_img' + off1 + '.jpg'
off2 ='reg_img' + off2 + '.jpg'
off3 = 'reg_img' + off3 + '.jpg'
path_on0 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on0).replace('\\', '/')
path_on1 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on1).replace('\\', '/')
path_on2 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on2).replace('\\', '/')
path_on3 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on3).replace('\\', '/')
path_apex0 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, apex0).replace('\\', '/')
path_off0 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off0).replace('\\', '/')
path_off1 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off1).replace('\\', '/')
path_off2 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off2).replace('\\', '/')
path_off3 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off3).replace('\\', '/')
"""
print(path_on0)
print(path_on1)
print(path_on2)
print(path_on3)
print(path_apex0)
print(path_off0)
print(path_off1)
print(path_off2)
print(path_off3)
"""
        image_on0 = cv2.imread(path_on0)
        image_on1 = cv2.imread(path_on1)
        image_on2 = cv2.imread(path_on2)
        image_on3 = cv2.imread(path_on3)
        image_apex0 = cv2.imread(path_apex0)
        image_off0 = cv2.imread(path_off0)
        image_off1 = cv2.imread(path_off1)
        image_off2 = cv2.imread(path_off2)
        image_off3 = cv2.imread(path_off3)
image_on0 = image_on0[:, :, ::-1] # BGR to RGB
image_on1 = image_on1[:, :, ::-1]
image_on2 = image_on2[:, :, ::-1]
image_on3 = image_on3[:, :, ::-1]
image_off0 = image_off0[:, :, ::-1]
image_off1 = image_off1[:, :, ::-1]
image_off2 = image_off2[:, :, ::-1]
image_off3 = image_off3[:, :, ::-1]
image_apex0 = image_apex0[:, :, ::-1]
label_all = self.label_all[idx]
label_au = self.label_au[idx]
# normalization for testing and training
if self.transform is not None:
image_on0 = self.transform(image_on0)
image_on1 = self.transform(image_on1)
image_on2 = self.transform(image_on2)
image_on3 = self.transform(image_on3)
image_off0 = self.transform(image_off0)
image_off1 = self.transform(image_off1)
image_off2 = self.transform(image_off2)
image_off3 = self.transform(image_off3)
image_apex0 = self.transform(image_apex0)
ALL = torch.cat(
(image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2,
image_off3), dim=0)
## data augmentation for training only
if self.transform_norm is not None and self.phase == 'train':
ALL = self.transform_norm(ALL)
image_on0 = ALL[0:3, :, :]
image_on1 = ALL[3:6, :, :]
image_on2 = ALL[6:9, :, :]
image_on3 = ALL[9:12, :, :]
image_apex0 = ALL[12:15, :, :]
image_off0 = ALL[15:18, :, :]
image_off1 = ALL[18:21, :, :]
image_off2 = ALL[21:24, :, :]
image_off3 = ALL[24:27, :, :]
temp = torch.zeros(38)
        for i in label_au:
            # assumes each AU entry is a plain number; entries with letter prefixes
            # in the coding sheet (if any) need to be stripped before int() is applied
            temp[int(i) - 1] = 1
return image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2, image_off3, label_all, temp
def initialize_weight_goog(m, n=''):
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
fan_out = m.weight.size(0) # fan-out
fan_in = 0
if 'routing_fn' in n:
fan_in = m.weight.size(1)
init_range = 1.0 / math.sqrt(fan_in + fan_out)
m.weight.data.uniform_(-init_range, init_range)
m.bias.data.zero_()
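# criterion2: multi-label loss over the 38-dim AU vector (logsumexp of the masked positive
# and negative logits); defined for AU supervision but not called in run_training() below.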
def criterion2(y_pred, y_true):
y_pred = (1 - 2 * y_true) * y_pred
y_pred_neg = y_pred - y_true * 1e12
y_pred_pos = y_pred - (1 - y_true) * 1e12
zeros = torch.zeros_like(y_pred[..., :1])
y_pred_neg = torch.cat((y_pred_neg, zeros), dim=-1)
y_pred_pos = torch.cat((y_pred_pos, zeros), dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
return torch.mean(neg_loss + pos_loss)
class MMNet(nn.Module):
def __init__(self):
super(MMNet, self).__init__()
self.conv_act = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=90*2, kernel_size=3, stride=2,padding=1, bias=False,groups=1),#group=2
nn.BatchNorm2d(180),
nn.ReLU(inplace=True),
)
self.pos =nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=512, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
)
##Position Calibration Module(subbranch)
self.vit_pos=VisionTransformer_POS(img_size=14,
patch_size=1, embed_dim=512, depth=3, num_heads=4, mlp_ratio=2, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6),drop_path_rate=0.3)
self.resize=Resize([14,14])
##main branch consisting of CA blocks
self.main_branch =resnet18_pos_attention()
self.head1 = nn.Sequential(
nn.Dropout(p=0.5),
nn.Linear(1 * 112 *112, 38,bias=False),
)
self.timeembed = nn.Parameter(torch.zeros(1, 4, 111, 111))
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
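        # NOTE: self.pos, self.head1, self.timeembed and self.avgpool are defined but not
        # used in forward(); the prediction relies only on the apex-onset difference (main
        # branch) and the position map from the PC sub-branch.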
def forward(self, x1, x2, x3, x4, x5, x6, x7, x8, x9, if_shuffle):
##onset:x1 apex:x5
B = x1.shape[0]
#Position Calibration Module (subbranch)
POS =self.vit_pos(self.resize(x1)).transpose(1,2).view(B,512,14,14)
act = x5 -x1
act=self.conv_act(act)
#main branch and fusion
out,_=self.main_branch(act,POS)
return out
def run_training():
args = parse_args()
    imagenet_pretrained = True  # whether to load pretrained weights
    # NOTE: 'res18' in the two blocks below is a leftover from the original RAF-DB script;
    # with the defaults above (imagenet_pretrained=True, --pretrained=None) neither block runs.
    if not imagenet_pretrained:
for m in res18.modules():
initialize_weight_goog(m)
if args.pretrained:
print("Loading pretrained weights...", args.pretrained)
pretrained = torch.load(args.pretrained)
pretrained_state_dict = pretrained['state_dict']
model_state_dict = res18.state_dict()
loaded_keys = 0
total_keys = 0
for key in pretrained_state_dict:
if ((key == 'module.fc.weight') | (key == 'module.fc.bias')):
pass
else:
model_state_dict[key] = pretrained_state_dict[key]
total_keys += 1
if key in model_state_dict:
loaded_keys += 1
print("Loaded params num:", loaded_keys)
print("Total params num:", total_keys)
res18.load_state_dict(model_state_dict, strict=False)
    ### data normalization for the training set
data_transforms = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
### data augmentation for training set only
data_transforms_norm = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(4),
transforms.RandomCrop(224, padding=4),
])
    ### data normalization for the testing set
data_transforms_val = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
criterion = torch.nn.CrossEntropyLoss()
    # leave-one-subject-out (LOSO) protocol
LOSO = ['17', '26', '16', '9', '5', '24', '2', '13', '4', '23', '11', '12', '8', '14', '3', '19', '1', '10',
'20', '21', '22', '15', '6', '25', '7']
val_now = 0
num_sum = 0
pos_pred_ALL = torch.zeros(3)
pos_label_ALL = torch.zeros(3)
TP_ALL = torch.zeros(3)
for subj in LOSO:
train_dataset = RafDataSet(args.raf_path, phase='train', num_loso=subj, transform=data_transforms,
basic_aug=True, transform_norm=data_transforms_norm)
val_dataset = RafDataSet(args.raf_path, phase='test', num_loso=subj, transform=data_transforms_val)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=24,
num_workers=args.workers,
shuffle=True,
pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=24,
num_workers=args.workers,
shuffle=False,
pin_memory=True)
print('num_sub', subj)
print('Train set size:', train_dataset.__len__())
print('Validation set size:', val_dataset.__len__())
max_corr = 0
max_f1 = 0
max_pos_pred = torch.zeros(3)
max_pos_label = torch.zeros(3)
max_TP = torch.zeros(3)
##model initialization
net_all = MMNet()
params_all = net_all.parameters()
if args.optimizer == 'adam':
optimizer_all = torch.optim.AdamW(params_all, lr=0.0008, weight_decay=0.7)
##optimizer for MMNet
        elif args.optimizer == 'sgd':
            optimizer_all = torch.optim.SGD(params_all, args.lr,
                                            momentum=args.momentum,
                                            weight_decay=1e-4)
else:
raise ValueError("Optimizer not supported.")
##lr_decay
scheduler_all = torch.optim.lr_scheduler.ExponentialLR(optimizer_all, gamma=0.987)
net_all = net_all.cuda()
for i in range(1, 100):
running_loss = 0.0
correct_sum = 0
running_loss_MASK = 0.0
correct_sum_MASK = 0
iter_cnt = 0
net_all.train()
for batch_i, (
image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2, image_off3,
label_all,
label_au) in enumerate(train_loader):
batch_sz = image_on0.size(0)
b, c, h, w = image_on0.shape
iter_cnt += 1
image_on0 = image_on0.cuda()
image_on1 = image_on1.cuda()
image_on2 = image_on2.cuda()
image_on3 = image_on3.cuda()
image_apex0 = image_apex0.cuda()
image_off0 = image_off0.cuda()
image_off1 = image_off1.cuda()
image_off2 = image_off2.cuda()
image_off3 = image_off3.cuda()
label_all = label_all.cuda()
label_au = label_au.cuda()
##train MMNet
ALL = net_all(image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1,
image_off2, image_off3, False)
loss_all = criterion(ALL, label_all)
optimizer_all.zero_grad()
loss_all.backward()
optimizer_all.step()
running_loss += loss_all
_, predicts = torch.max(ALL, 1)
correct_num = torch.eq(predicts, label_all).sum()
correct_sum += correct_num
## lr decay
if i <= 50:
scheduler_all.step()
if i>=0:
acc = correct_sum.float() / float(train_dataset.__len__())
running_loss = running_loss / iter_cnt
print('[Epoch %d] Training accuracy: %.4f. Loss: %.3f' % (i, acc, running_loss))
pos_label = torch.zeros(3)
pos_pred = torch.zeros(3)
TP = torch.zeros(3)
##test
with torch.no_grad():
running_loss = 0.0
iter_cnt = 0
bingo_cnt = 0
sample_cnt = 0
pre_lab_all = []
Y_test_all = []
net_all.eval()
# net_au.eval()
for batch_i, (
image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2,
image_off3, label_all,
label_au) in enumerate(val_loader):
batch_sz = image_on0.size(0)
b, c, h, w = image_on0.shape
image_on0 = image_on0.cuda()
image_on1 = image_on1.cuda()
image_on2 = image_on2.cuda()
image_on3 = image_on3.cuda()
image_apex0 = image_apex0.cuda()
image_off0 = image_off0.cuda()
image_off1 = image_off1.cuda()
image_off2 = image_off2.cuda()
image_off3 = image_off3.cuda()
label_all = label_all.cuda()
label_au = label_au.cuda()
##test
ALL = net_all(image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2, image_off3, False)
loss = criterion(ALL, label_all)
running_loss += loss
iter_cnt += 1
_, predicts = torch.max(ALL, 1)
correct_num = torch.eq(predicts, label_all)
bingo_cnt += correct_num.sum().cpu()
sample_cnt += ALL.size(0)
for cls in range(3):
for element in predicts:
if element == cls:
pos_label[cls] = pos_label[cls] + 1
for element in label_all:
if element == cls:
pos_pred[cls] = pos_pred[cls] + 1
for elementp, elementl in zip(predicts, label_all):
if elementp == elementl and elementp == cls:
TP[cls] = TP[cls] + 1
count = 0
SUM_F1 = 0
for index in range(3):
if pos_label[index] != 0 or pos_pred[index] != 0:
count = count + 1
SUM_F1 = SUM_F1 + 2 * TP[index] / (pos_pred[index] + pos_label[index])
AVG_F1 = SUM_F1 / count
running_loss = running_loss / iter_cnt
acc = bingo_cnt.float() / float(sample_cnt)
acc = np.around(acc.numpy(), 4)
if bingo_cnt > max_corr:
max_corr = bingo_cnt
if AVG_F1 >= max_f1:
max_f1 = AVG_F1
max_pos_label = pos_label
max_pos_pred = pos_pred
max_TP = TP
print("[Epoch %d] Validation accuracy:%.4f. Loss:%.3f, F1-score:%.3f" % (i, acc, running_loss, AVG_F1))
num_sum = num_sum + max_corr
pos_label_ALL = pos_label_ALL + max_pos_label
pos_pred_ALL = pos_pred_ALL + max_pos_pred
TP_ALL = TP_ALL + max_TP
count = 0
SUM_F1 = 0
for index in range(3):
if pos_label_ALL[index] != 0 or pos_pred_ALL[index] != 0:
count = count + 1
SUM_F1 = SUM_F1 + 2 * TP_ALL[index] / (pos_pred_ALL[index] + pos_label_ALL[index])
F1_ALL = SUM_F1 / count
val_now = val_now + val_dataset.__len__()
print("[..........%s] correctnum:%d . zongshu:%d " % (subj, max_corr, val_dataset.__len__()))
print("[ALL_corr]: %d [ALL_val]: %d" % (num_sum, val_now))
print("[F1_now]: %.4f [F1_ALL]: %.4f" % (max_f1, F1_ALL))
if __name__ == "__main__":
run_training()
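
With the three files in the same directory and the dataset laid out as described above, run_training() performs leave-one-subject-out training over all 26 subjects. Note that a CUDA GPU is required (the model and all tensors are moved with .cuda()), and that the batch size (24), the AdamW settings (lr=0.0008, weight_decay=0.7) and the 99 epochs per subject are hard-coded in run_training() rather than taken from the command line. A typical invocation:

python main.py --raf_path D:/CASME2/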

CA_block.py
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url  # used by _resnet() when pretrained=True
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1, groups=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False,groups=groups)
##CA BLOCK
class CABlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(CABlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
# if groups != 1 or base_width != 64:
# raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride,groups=groups)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv1x1(planes, planes,groups=groups)
self.bn2 = norm_layer(planes)
self.attn = nn.Sequential(
nn.Conv2d(2, 1, kernel_size=1, stride=1,bias=False), # 32*33*33
nn.BatchNorm2d(1),
nn.Sigmoid(),
)
self.downsample = downsample
self.stride = stride
self.planes=planes
def forward(self, x):
x, attn_last,if_attn =x##attn_last: downsampled attention maps from last layer as a prior knowledge
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(identity)
out = self.relu(out+identity)
avg_out = torch.mean(out, dim=1, keepdim=True)
max_out, _ = torch.max(out, dim=1, keepdim=True)
attn = torch.cat((avg_out, max_out), dim=1)
attn = self.attn(attn)
if attn_last is not None:
attn = attn_last * attn
attn = attn.repeat(1, self.planes, 1, 1)
if if_attn:
out = out *attn
return out,attn[:, 0, :, :].unsqueeze(1),True
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=4, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 128
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(90*2, self.inplanes, kernel_size=3, stride=1,padding=1,
bias=False,groups=1)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2,padding=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.layer1 = self._make_layer(block, 128, layers[0],groups=1)
self.inplanes = int(self.inplanes*1)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0],groups=1)
self.inplanes = int(self.inplanes * 1)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1],groups=1)
self.inplanes = int(self.inplanes * 1)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2],groups=1)
self.inplanes = int(self.inplanes * 1)
self.fc = nn.Linear(512* block.expansion*196, 5)
self.drop = nn.Dropout(p=0.1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            # NOTE: Bottleneck/BasicBlock are not defined in this file (only CABlock is),
            # so this branch is dead code with the default zero_init_residual=False.
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False,groups=1):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x,POS):##x->input of main branch; POS->position embeddings generated by sub branch
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
##main branch
x,attn1,_ = self.layer1((x,None,True))
temp = attn1
attn1 = self.maxpool(attn1)
x ,attn2,_= self.layer2((x,attn1,True))
attn2=self.maxpool(attn2)
x ,attn3,_= self.layer3((x,attn2,True))
#
attn3 = self.maxpool(attn3)
x,attn4,_ = self.layer4((x,attn3,True))
x=x+POS#fusion of motion pattern feature and position embeddings
x = torch.flatten(x, 1)
x = self.fc(x)
return x,temp.view(x.size(0),-1)
def forward(self, x,POS):
return self._forward_impl(x,POS)
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
##main branch consisting of CA blocks
def resnet18_pos_attention(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', CABlock, [1, 1, 1, 1], pretrained, progress,
**kwargs)
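
The snippet below is a quick shape check of this main branch as it is driven from main.py; the random tensors and the batch size of 2 are stand-ins for the conv_act output and the PC-module position map (a minimal sketch, not part of the repo):

# shape check for the CA-block main branch (a minimal sketch, not part of the repo)
import torch
from CA_block import resnet18_pos_attention

net = resnet18_pos_attention()
act = torch.randn(2, 180, 112, 112)   # apex-onset difference after conv_act in main.py (3->180 ch, stride 2)
pos = torch.randn(2, 512, 14, 14)     # position embeddings produced by the PC sub-branch
logits, attn = net(act, pos)
print(logits.shape, attn.shape)       # torch.Size([2, 5]) torch.Size([2, 12544])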

PC_module.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# -*- coding: utf-8 -*-
import math
import logging
import collections.abc
from collections import OrderedDict
from functools import partial
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.vision_transformer import _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
__all__ = [
'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
'deit_base_distilled_patch16_384',
]
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
varq = torch.var(q, dim=2).sum(dim=2).sum()/B/N
vark = torch.var(k, dim=2).sum(dim=2).sum()/B/N
varv = torch.var(v, dim=2).sum(dim=2).sum()/B/N
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x).flatten(2).transpose(1, 2)
return x
###Position Calibration Module
class VisionTransformer_POS(nn.Module):
""" Vision Transformer
A torch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
https://arxiv.org/abs/2010.11929
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.15, hybrid_backbone=None, norm_layer=None):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
norm_layer: (nn.Module): normalization layer
"""
super().__init__()
norm_layer=partial(nn.LayerNorm, eps=1e-6)
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))  # 196 tokens for img_size=14, patch_size=1
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size:
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
# Classifier head
self.head = nn.Linear(self.num_features, 5) if num_classes > 0 else nn.Identity()
# self.to_Mask = nn.Sequential(nn.Conv2d(in_channels=self.num_features,out_channels=1,kernel_size=3,padding=1),
# nn.Hardsigmoid(),
# )
# self.to_Mask = nn.Linear(self.num_features,1)
self.to_Mask = nn.Sequential(nn.Linear(self.num_features,1),
nn.Sigmoid(),
)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
x = self.pre_logits(x)
return x
def forward(self, x):
x = self.forward_features(x)
return x
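
For reference, the sub-branch is instantiated in main.py with img_size=14 and patch_size=1, so it returns one 512-d token per pixel of the downsampled onset frame; a minimal shape check (not part of the repo):

# shape check for the Position Calibration sub-branch as configured in main.py (a minimal sketch)
import torch
import torch.nn as nn
from functools import partial
from PC_module import VisionTransformer_POS

vit = VisionTransformer_POS(img_size=14, patch_size=1, embed_dim=512, depth=3, num_heads=4,
                            mlp_ratio=2, qkv_bias=True,
                            norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=0.3)
feat = vit(torch.randn(2, 3, 14, 14))                 # (B, 196, 512): one token per 1x1 patch of the 14x14 input
pos_map = feat.transpose(1, 2).view(2, 512, 14, 14)   # reshaped into the position map fused in MMNet.forward
print(pos_map.shape)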

Remaining problems in the CASME2 dataset itself have to be fixed by yourself.
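
Anything the loader cannot handle, for example an onset/apex/offset entry that is not a plain integer, or a reg_img*.jpg frame missing from the cropped folders, will make RafDataSet (the int() conversions) or cv2.imread fail, so it is worth scanning the data first. A minimal sketch, not part of the original repo, reusing the path convention from RafDataSet:

# check_casme2.py -- a minimal sketch (not part of the original repo) that scans the coding
# sheet and reports samples the loader in main.py would choke on
import os
import pandas as pd

raf_path = 'D:/CASME2/'  # same default as --raf_path in main.py
df = pd.read_excel(os.path.join(raf_path, 'CASME2-coding-20140508.xlsx'), usecols=[0, 1, 3, 4, 5, 7, 8])
for _, row in df.iterrows():
    sub, name = str(row.iloc[0]), str(row.iloc[1])
    for frame in (row.iloc[2], row.iloc[3], row.iloc[4]):  # onset, apex, offset columns
        try:
            idx = int(frame)
        except (TypeError, ValueError):
            print('non-numeric frame index in sub%s/%s: %r' % (sub, name, frame))
            continue
        path = os.path.join(raf_path, 'Cropped-updated/Cropped', 'sub' + sub, name, 'reg_img%d.jpg' % idx)
        if not os.path.exists(path):
            print('missing image:', path)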