
MMNet Micro-Expression Recognition (CASME2 Dataset)

2023-06-24 18:22  Author: 感觉__站不如油管

Original GitHub repository: https://github.com/muse1998/MMNet

Both the code and the dataset have some problems; the versions below have been modified so that they actually run (the expected on-disk layout is sketched after the file list).

main.py

CA_block.py

PC_module.py
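
For reference, the path strings built in RafDataSet.__getitem__ (in main.py below) imply the following on-disk layout. The drive letter, subject number, clip name, and frame index are placeholders; if your copy of the cropped data uses zero-padded subject folders such as sub01, adjust either the folder names or the Subject strings read from the spreadsheet.

<raf_path>/
    CASME2-coding-20140508.xlsx
    Cropped-updated/
        Cropped/
            sub<Subject>/
                <Filename>/
                    reg_img<frame>.jpg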


main.py

# -*- coding: utf-8 -*-
import torch
import math
import numpy as np
import torchvision.models
import torch.utils.data as data
from torchvision import transforms
import cv2
import pandas as pd
import os
import torch.nn as nn
#import image_utils
import argparse, random
from functools import partial

from CA_block import resnet18_pos_attention  # CA_block.py sits next to main.py in this post

from PC_module import VisionTransformer_POS

from torchvision.transforms import Resize
torch.set_printoptions(precision=3, edgeitems=14, linewidth=350)



def parse_args():
   parser = argparse.ArgumentParser()
   parser.add_argument('--raf_path', type=str, default='D:/CASME2/', help='CASME2 dataset root path.')
   parser.add_argument('--checkpoint', type=str, default='D:/CASME2/',
                       help='Pytorch checkpoint file path')
   parser.add_argument('--pretrained', type=str, default=None,
                       help='Pretrained weights')
   parser.add_argument('--beta', type=float, default=0.7, help='Ratio of high importance group in one mini-batch.')
   parser.add_argument('--relabel_epoch', type=int, default=1000,
                       help='Relabel samples after this many epochs (not used in this script).')
   parser.add_argument('--batch_size', type=int, default=34, help='Batch size.')
   parser.add_argument('--optimizer', type=str, default="adam", help='Optimizer, adam or sgd.')
   parser.add_argument('--lr', type=float, default=0.0001, help='Initial learning rate for sgd.')
   parser.add_argument('--momentum', default=0.9, type=float, help='Momentum for sgd')
   parser.add_argument('--workers', default=0, type=int, help='Number of data loading workers (default: 0).')
   parser.add_argument('--epochs', type=int, default=1000, help='Total training epochs.')
   parser.add_argument('--drop_rate', type=float, default=0, help='Drop out rate.')
   return parser.parse_args()






class RafDataSet(data.Dataset):
   def __init__(self, raf_path, phase,num_loso, transform = None, basic_aug = False, transform_norm=None):
       self.phase = phase
       self.transform = transform
       self.raf_path = raf_path
       self.transform_norm = transform_norm
       SUBJECT_COLUMN =0
       NAME_COLUMN = 1
       ONSET_COLUMN = 2
       APEX_COLUMN = 3
       OFF_COLUMN = 4
       LABEL_AU_COLUMN = 5
       LABEL_ALL_COLUMN = 6


       df = pd.read_excel(os.path.join(self.raf_path, 'CASME2-coding-20140508.xlsx'),usecols=[0,1,3,4,5,7,8])
       df['Subject'] = df['Subject'].apply(str)

       if phase == 'train':
           dataset = df.loc[df['Subject']!=num_loso]
       else:
           dataset = df.loc[df['Subject'] == num_loso]

       Subject = dataset.iloc[:, SUBJECT_COLUMN].values
       File_names = dataset.iloc[:, NAME_COLUMN].values
       Label_all = dataset.iloc[:, LABEL_ALL_COLUMN].values  # emotion strings; mapped below to 0: happiness, 1: surprise, 2: others (repression/disgust/fear/sadness)
       Onset_num = dataset.iloc[:, ONSET_COLUMN].values
       Apex_num = dataset.iloc[:, APEX_COLUMN].values
       Offset_num = dataset.iloc[:, OFF_COLUMN].values
       Label_au = dataset.iloc[:, LABEL_AU_COLUMN].values
       self.file_paths_on = []
       self.file_paths_off = []
       self.file_paths_apex = []
       self.label_all = []
       self.label_au = []
       self.sub= []
       self.file_names =[]
       a=0
       b=0
       c=0
       d=0
       e=0
       # use aligned images for training/testing
       for (f,sub,onset,apex,offset,label_all,label_au) in zip(File_names,Subject,Onset_num,Apex_num,Offset_num,Label_all,Label_au):


           if label_all == 'happiness' or label_all == 'repression' or label_all == 'disgust' or label_all == 'surprise' or label_all == 'fear' or label_all == 'sadness':

               self.file_paths_on.append(onset)
               self.file_paths_off.append(offset)
               self.file_paths_apex.append(apex)
               self.sub.append(sub)
               self.file_names.append(f)
               if label_all == 'happiness':
                   self.label_all.append(0)
                   a=a+1
               elif label_all == 'surprise':
                   self.label_all.append(1)
                   b=b+1
               else:
                   self.label_all.append(2)
                   c=c+1

           # label_au =label_au.split("+")
               if isinstance(label_au, int):
                   self.label_au.append([label_au])
               else:
                   label_au = label_au.split("+")
                   self.label_au.append(label_au)






           ##label

       self.basic_aug = basic_aug
       #self.aug_func = [image_utils.flip_image,image_utils.add_gaussian_noise]

   def __len__(self):
       return len(self.file_paths_on)

   def __getitem__(self, idx):
       ##sampling strategy for training set
       if self.phase == 'train':
           onset = self.file_paths_on[idx]
           #onset = onset.astype('int64')
           apex = self.file_paths_apex[idx]
           #apex = apex.astype('int64')
           offset =self.file_paths_off[idx]
           #offset = offset.astype('int64')

           on0 = str(random.randint(int(onset), int(onset + int(0.2* (int(apex) - int(onset)) / 4))))
           # on0 = str(int(onset))
           on1 = str(
               random.randint(int(onset + int(0.9 * (apex - onset) / 4)), int(onset + int(1.1 * (apex - onset) / 4))))
           on2 = str(
               random.randint(int(onset + int(1.8 * (apex - onset) / 4)), int(onset + int(2.2 * (apex - onset) / 4))))
           on3 = str(random.randint(int(onset + int(2.7 * (apex - onset) / 4)), onset + int(3.3 * (apex - onset) / 4)))
           # apex0 = str(apex)
           apex0 = str(
               random.randint(int(apex - int(0.15* (apex - onset) / 4)), apex + int(0.15 * (offset - apex) / 4)))
           off0 = str(
               random.randint(int(apex + int(0.9 * (offset - apex) / 4)), int(apex + int(1.1 * (offset - apex) / 4))))
           off1 = str(
               random.randint(int(apex + int(1.8 * (offset - apex) / 4)), int(apex + int(2.2 * (offset - apex) / 4))))
           off2 = str(
               random.randint(int(apex + int(2.9 * (offset - apex) / 4)), int(apex + int(3.1 * (offset - apex) / 4))))
           off3 = str(random.randint(int(apex + int(3.8 * (offset - apex) / 4)), offset))



           sub =str(self.sub[idx])
           f = str(self.file_names[idx])
       else:##sampling strategy for testing set
           onset = self.file_paths_on[idx]
           apex = self.file_paths_apex[idx]
           offset = self.file_paths_off[idx]

           on0 = str(onset)
           on1 = str(int(onset + int((apex - onset) / 4)))
           on2 = str(int(onset + int(2 * (apex - onset) / 4)))
           on3 = str(int(onset + int(3 * (apex - onset) / 4)))
           apex0 = str(apex)
           off0 = str(int(apex + int((offset - apex) / 4)))
           off1 = str(int(apex + int(2 * (offset - apex) / 4)))
           off2 = str(int(apex + int(3 * (offset - apex) / 4)))
           off3 = str(offset)

           sub = str(self.sub[idx])
           f = str(self.file_names[idx])


       on0 ='reg_img' + on0 + '.jpg'
       on1 = 'reg_img' + on1 + '.jpg'
       on2 = 'reg_img' + on2 + '.jpg'
       on3 = 'reg_img' + on3 + '.jpg'
       apex0 ='reg_img' + apex0 + '.jpg'
       off0 ='reg_img' + off0 + '.jpg'
       off1='reg_img' + off1 + '.jpg'
       off2 ='reg_img' + off2 + '.jpg'
       off3 = 'reg_img' + off3 + '.jpg'
       path_on0 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on0).replace('\\', '/')
       path_on1 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on1).replace('\\', '/')
       path_on2 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on2).replace('\\', '/')
       path_on3 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on3).replace('\\', '/')
       path_apex0 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, apex0).replace('\\', '/')
       path_off0 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off0).replace('\\', '/')
       path_off1 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off1).replace('\\', '/')
       path_off2 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off2).replace('\\', '/')
       path_off3 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off3).replace('\\', '/')
       """
       print(path_on0)
       print(path_on1)
       print(path_on2)
       print(path_on3)
       print(path_apex0)
       print(path_off0)
       print(path_off1)
       print(path_off2)
       print(path_off3)
       """

       image_on0 = cv2.imread(path_on0)
       image_on1 = cv2.imread(path_on1)
       image_on2 = cv2.imread(path_on2)
       image_on3 = cv2.imread(path_on3)
       image_apex0 = cv2.imread(path_apex0)
       image_off0 = cv2.imread(path_off0)
       image_off1 = cv2.imread(path_off1)
       image_off2 = cv2.imread(path_off2)
       image_off3 = cv2.imread(path_off3)

       image_on0 = image_on0[:, :, ::-1] # BGR to RGB
       image_on1 = image_on1[:, :, ::-1]
       image_on2 = image_on2[:, :, ::-1]
       image_on3 = image_on3[:, :, ::-1]
       image_off0 = image_off0[:, :, ::-1]
       image_off1 = image_off1[:, :, ::-1]
       image_off2 = image_off2[:, :, ::-1]
       image_off3 = image_off3[:, :, ::-1]
       image_apex0 = image_apex0[:, :, ::-1]

       label_all = self.label_all[idx]
       label_au = self.label_au[idx]

       # normalization for testing and training
       if self.transform is not None:
           image_on0 = self.transform(image_on0)
           image_on1 = self.transform(image_on1)
           image_on2 = self.transform(image_on2)
           image_on3 = self.transform(image_on3)
           image_off0 = self.transform(image_off0)
           image_off1 = self.transform(image_off1)
           image_off2 = self.transform(image_off2)
           image_off3 = self.transform(image_off3)
           image_apex0 = self.transform(image_apex0)
           ALL = torch.cat(
               (image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2,
                image_off3), dim=0)
           ## data augmentation for training only
           if self.transform_norm is not None and self.phase == 'train':
               ALL = self.transform_norm(ALL)
           image_on0 = ALL[0:3, :, :]
           image_on1 = ALL[3:6, :, :]
           image_on2 = ALL[6:9, :, :]
           image_on3 = ALL[9:12, :, :]
           image_apex0 = ALL[12:15, :, :]
           image_off0 = ALL[15:18, :, :]
           image_off1 = ALL[18:21, :, :]
           image_off2 = ALL[21:24, :, :]
           image_off3 = ALL[24:27, :, :]


           temp = torch.zeros(38)
           for i in label_au:
               #print(i)
               temp[int(i) - 1] = 1

           return image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2, image_off3, label_all, temp


def initialize_weight_goog(m, n=''):
   if isinstance(m, nn.Conv2d):
       fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
       m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
       if m.bias is not None:
           m.bias.data.zero_()
   elif isinstance(m, nn.BatchNorm2d):
       m.weight.data.fill_(1.0)
       m.bias.data.zero_()
   elif isinstance(m, nn.Linear):
       fan_out = m.weight.size(0)  # fan-out
       fan_in = 0
       if 'routing_fn' in n:
           fan_in = m.weight.size(1)
       init_range = 1.0 / math.sqrt(fan_in + fan_out)
       m.weight.data.uniform_(-init_range, init_range)
       m.bias.data.zero_()


def criterion2(y_pred, y_true):
   y_pred = (1 - 2 * y_true) * y_pred
   y_pred_neg = y_pred - y_true * 1e12
   y_pred_pos = y_pred - (1 - y_true) * 1e12
   zeros = torch.zeros_like(y_pred[..., :1])
   y_pred_neg = torch.cat((y_pred_neg, zeros), dim=-1)
   y_pred_pos = torch.cat((y_pred_pos, zeros), dim=-1)
   neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
   pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
   return torch.mean(neg_loss + pos_loss)


class MMNet(nn.Module):
   def __init__(self):
       super(MMNet, self).__init__()


       self.conv_act = nn.Sequential(
           nn.Conv2d(in_channels=3, out_channels=90*2, kernel_size=3, stride=2,padding=1, bias=False,groups=1),#group=2
           nn.BatchNorm2d(180),
           nn.ReLU(inplace=True),
           )
       self.pos =nn.Sequential(
           nn.Conv2d(in_channels=3, out_channels=512, kernel_size=1, stride=1, bias=False),
           nn.BatchNorm2d(512),
           nn.ReLU(inplace=True),

           )
       ##Position Calibration Module(subbranch)
       self.vit_pos=VisionTransformer_POS(img_size=14,
       patch_size=1, embed_dim=512, depth=3, num_heads=4, mlp_ratio=2, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6),drop_path_rate=0.3)
       self.resize=Resize([14,14])
       ##main branch consisting of CA blocks
       self.main_branch =resnet18_pos_attention()
       self.head1 = nn.Sequential(
           nn.Dropout(p=0.5),
           nn.Linear(1 * 112 *112, 38,bias=False),

       )

       self.timeembed = nn.Parameter(torch.zeros(1, 4, 111, 111))

       self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
   def forward(self, x1, x2, x3, x4, x5, x6, x7, x8, x9, if_shuffle):
       ##onset:x1 apex:x5
       B = x1.shape[0]

       #Position Calibration Module (subbranch)
       POS =self.vit_pos(self.resize(x1)).transpose(1,2).view(B,512,14,14)
       act = x5 -x1
       act=self.conv_act(act)
       #main branch and fusion
       out,_=self.main_branch(act,POS)

       return out





def run_training():

   args = parse_args()
   imagenet_pretrained = True  # whether to treat the backbone as ImageNet-pretrained; keep True and leave --pretrained unset, since `res18` in the branches below is never defined

   if not imagenet_pretrained:
       for m in res18.modules():
           initialize_weight_goog(m)

   if args.pretrained:
       print("Loading pretrained weights...", args.pretrained)
       pretrained = torch.load(args.pretrained)
       pretrained_state_dict = pretrained['state_dict']
       model_state_dict = res18.state_dict()
       loaded_keys = 0
       total_keys = 0
       for key in pretrained_state_dict:
           if ((key == 'module.fc.weight') | (key == 'module.fc.bias')):
               pass
           else:
               model_state_dict[key] = pretrained_state_dict[key]
               total_keys += 1
               if key in model_state_dict:
                   loaded_keys += 1
       print("Loaded params num:", loaded_keys)
       print("Total params num:", total_keys)
       res18.load_state_dict(model_state_dict, strict=False)
   ### data normalization for the training set
   data_transforms = transforms.Compose([
       transforms.ToPILImage(),
       transforms.Resize((224, 224)),

       transforms.ToTensor(),
       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),

   ])
   ### data augmentation for training set only
   data_transforms_norm = transforms.Compose([

       transforms.RandomHorizontalFlip(p=0.5),
       transforms.RandomRotation(4),
       transforms.RandomCrop(224, padding=4),


   ])


   ### data normalization for the test set
   data_transforms_val = transforms.Compose([
       transforms.ToPILImage(),
       transforms.Resize((224, 224)),
       transforms.ToTensor(),
       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])])



   criterion = torch.nn.CrossEntropyLoss()
   # leave-one-subject-out (LOSO) protocol
   LOSO = ['17', '26', '16', '9', '5', '24', '2', '13', '4', '23', '11', '12', '8', '14', '3', '19', '1', '10',
           '20', '21', '22', '15', '6', '25', '7']

   val_now = 0
   num_sum = 0
   pos_pred_ALL = torch.zeros(3)
   pos_label_ALL = torch.zeros(3)
   TP_ALL = torch.zeros(3)

   for subj in LOSO:
       train_dataset = RafDataSet(args.raf_path, phase='train', num_loso=subj, transform=data_transforms,
                                  basic_aug=True, transform_norm=data_transforms_norm)
       val_dataset = RafDataSet(args.raf_path, phase='test', num_loso=subj, transform=data_transforms_val)
       train_loader = torch.utils.data.DataLoader(train_dataset,
                                                  batch_size=24,
                                                  num_workers=args.workers,
                                                  shuffle=True,
                                                  pin_memory=True)
       val_loader = torch.utils.data.DataLoader(val_dataset,
                                                batch_size=24,
                                                num_workers=args.workers,
                                                shuffle=False,
                                                pin_memory=True)
       print('num_sub', subj)
       print('Train set size:', train_dataset.__len__())
       print('Validation set size:', val_dataset.__len__())

       max_corr = 0
       max_f1 = 0
       max_pos_pred = torch.zeros(3)
       max_pos_label = torch.zeros(3)
       max_TP = torch.zeros(3)
       ##model initialization
       net_all = MMNet()

       params_all = net_all.parameters()

       if args.optimizer == 'adam':
           optimizer_all = torch.optim.AdamW(params_all, lr=0.0008, weight_decay=0.7)
           ##optimizer for MMNet

       elif args.optimizer == 'sgd':
           optimizer_all = torch.optim.SGD(params_all, args.lr,
                                           momentum=args.momentum,
                                           weight_decay=1e-4)
       else:
           raise ValueError("Optimizer not supported.")
       ##lr_decay
       scheduler_all = torch.optim.lr_scheduler.ExponentialLR(optimizer_all, gamma=0.987)

       net_all = net_all.cuda()

       for i in range(1, 100):
           running_loss = 0.0
           correct_sum = 0
           running_loss_MASK = 0.0
           correct_sum_MASK = 0
           iter_cnt = 0

           net_all.train()


           for batch_i, (
           image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2, image_off3,
           label_all,
           label_au) in enumerate(train_loader):
               batch_sz = image_on0.size(0)
               b, c, h, w = image_on0.shape
               iter_cnt += 1

               image_on0 = image_on0.cuda()
               image_on1 = image_on1.cuda()
               image_on2 = image_on2.cuda()
               image_on3 = image_on3.cuda()
               image_apex0 = image_apex0.cuda()
               image_off0 = image_off0.cuda()
               image_off1 = image_off1.cuda()
               image_off2 = image_off2.cuda()
               image_off3 = image_off3.cuda()
               label_all = label_all.cuda()
               label_au = label_au.cuda()


               ##train MMNet
               ALL = net_all(image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1,
                                  image_off2, image_off3, False)

               loss_all = criterion(ALL, label_all)

               optimizer_all.zero_grad()

               loss_all.backward()

               optimizer_all.step()
               running_loss += loss_all
               _, predicts = torch.max(ALL, 1)
               correct_num = torch.eq(predicts, label_all).sum()
               correct_sum += correct_num






           ## lr decay
           if i <= 50:

               scheduler_all.step()
           if i>=0:
               acc = correct_sum.float() / float(train_dataset.__len__())

               running_loss = running_loss / iter_cnt

               print('[Epoch %d] Training accuracy: %.4f. Loss: %.3f' % (i, acc, running_loss))


           pos_label = torch.zeros(3)
           pos_pred = torch.zeros(3)
           TP = torch.zeros(3)
           ##test
           with torch.no_grad():
               running_loss = 0.0
               iter_cnt = 0
               bingo_cnt = 0
               sample_cnt = 0
               pre_lab_all = []
               Y_test_all = []
               net_all.eval()
               # net_au.eval()
               for batch_i, (
               image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2,
               image_off3, label_all,
               label_au) in enumerate(val_loader):
                   batch_sz = image_on0.size(0)
                   b, c, h, w = image_on0.shape

                   image_on0 = image_on0.cuda()
                   image_on1 = image_on1.cuda()
                   image_on2 = image_on2.cuda()
                   image_on3 = image_on3.cuda()
                   image_apex0 = image_apex0.cuda()
                   image_off0 = image_off0.cuda()
                   image_off1 = image_off1.cuda()
                   image_off2 = image_off2.cuda()
                   image_off3 = image_off3.cuda()
                   label_all = label_all.cuda()
                   label_au = label_au.cuda()

                   ##test
                   ALL = net_all(image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2, image_off3, False)


                   loss = criterion(ALL, label_all)
                   running_loss += loss
                   iter_cnt += 1
                   _, predicts = torch.max(ALL, 1)
                   correct_num = torch.eq(predicts, label_all)
                   bingo_cnt += correct_num.sum().cpu()
                   sample_cnt += ALL.size(0)

                   for cls in range(3):

                       for element in predicts:
                           if element == cls:
                               pos_label[cls] = pos_label[cls] + 1
                       for element in label_all:
                           if element == cls:
                               pos_pred[cls] = pos_pred[cls] + 1
                       for elementp, elementl in zip(predicts, label_all):
                           if elementp == elementl and elementp == cls:
                               TP[cls] = TP[cls] + 1

                   count = 0
                   SUM_F1 = 0
                   for index in range(3):
                       if pos_label[index] != 0 or pos_pred[index] != 0:
                           count = count + 1
                           SUM_F1 = SUM_F1 + 2 * TP[index] / (pos_pred[index] + pos_label[index])

                   AVG_F1 = SUM_F1 / count


               running_loss = running_loss / iter_cnt
               acc = bingo_cnt.float() / float(sample_cnt)
               acc = np.around(acc.numpy(), 4)
               if bingo_cnt > max_corr:
                   max_corr = bingo_cnt
               if AVG_F1 >= max_f1:
                   max_f1 = AVG_F1
                   max_pos_label = pos_label
                   max_pos_pred = pos_pred
                   max_TP = TP
               print("[Epoch %d] Validation accuracy:%.4f. Loss:%.3f, F1-score:%.3f" % (i, acc, running_loss, AVG_F1))
       num_sum = num_sum + max_corr
       pos_label_ALL = pos_label_ALL + max_pos_label
       pos_pred_ALL = pos_pred_ALL + max_pos_pred
       TP_ALL = TP_ALL + max_TP
       count = 0
       SUM_F1 = 0
       for index in range(3):
           if pos_label_ALL[index] != 0 or pos_pred_ALL[index] != 0:
               count = count + 1
               SUM_F1 = SUM_F1 + 2 * TP_ALL[index] / (pos_pred_ALL[index] + pos_label_ALL[index])

       F1_ALL = SUM_F1 / count
       val_now = val_now + val_dataset.__len__()
       print("[..........%s] correctnum:%d . zongshu:%d   " % (subj, max_corr, val_dataset.__len__()))
       print("[ALL_corr]: %d [ALL_val]: %d" % (num_sum, val_now))
       print("[F1_now]: %.4f [F1_ALL]: %.4f" % (max_f1, F1_ALL))


if __name__ == "__main__":
   run_training()
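
With the defaults above, LOSO training starts by simply running the script. A typical invocation using only flags that are actually read by the code (the path is an example):

python main.py --raf_path D:/CASME2/ --optimizer adam --workers 0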

CA_block.py


# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url  # needed by _resnet() when pretrained=True

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
          'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
          'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
   'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
   'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
   'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
   'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
   'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
   'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
   'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
   'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
   'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
   """3x3 convolution with padding"""
   return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                    padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1, groups=1):
   """1x1 convolution"""
   return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False,groups=groups)

##CA BLOCK
class CABlock(nn.Module):
   expansion = 1

   def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                base_width=64, dilation=1, norm_layer=None):
       super(CABlock, self).__init__()
       if norm_layer is None:
           norm_layer = nn.BatchNorm2d
       # if groups != 1 or base_width != 64:
       #     raise ValueError('BasicBlock only supports groups=1 and base_width=64')
       if dilation > 1:
           raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
       # Both self.conv1 and self.downsample layers downsample the input when stride != 1
       self.conv1 = conv3x3(inplanes, planes, stride,groups=groups)
       self.bn1 = norm_layer(planes)
       self.relu = nn.ReLU(inplace=True)
       self.conv2 = conv1x1(planes, planes,groups=groups)
       self.bn2 = norm_layer(planes)
       self.attn = nn.Sequential(
           nn.Conv2d(2, 1, kernel_size=1, stride=1,bias=False),  # 32*33*33
           nn.BatchNorm2d(1),
           nn.Sigmoid(),
       )
       self.downsample = downsample
       self.stride = stride
       self.planes=planes

   def forward(self, x):
       x, attn_last,if_attn =x##attn_last: downsampled attention maps from last layer as a prior knowledge
       identity = x

       out = self.conv1(x)
       out = self.bn1(out)

       out = self.relu(out)

       out = self.conv2(out)
       out = self.bn2(out)
       if self.downsample is not None:
           identity = self.downsample(identity)

       out = self.relu(out+identity)
       avg_out = torch.mean(out, dim=1, keepdim=True)
       max_out, _ = torch.max(out, dim=1, keepdim=True)
       attn = torch.cat((avg_out, max_out), dim=1)
       attn = self.attn(attn)
       if attn_last is not None:
           attn = attn_last * attn

       attn = attn.repeat(1, self.planes, 1, 1)
       if if_attn:
           out = out *attn


       return out,attn[:, 0, :, :].unsqueeze(1),True





class ResNet(nn.Module):

   def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                groups=4, width_per_group=64, replace_stride_with_dilation=None,
                norm_layer=None):
       super(ResNet, self).__init__()
       if norm_layer is None:
           norm_layer = nn.BatchNorm2d
       self._norm_layer = norm_layer

       self.inplanes = 128
       self.dilation = 1
       if replace_stride_with_dilation is None:
           # each element in the tuple indicates if we should replace
           # the 2x2 stride with a dilated convolution instead
           replace_stride_with_dilation = [False, False, False]
       if len(replace_stride_with_dilation) != 3:
           raise ValueError("replace_stride_with_dilation should be None "
                            "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
       self.groups = groups
       self.base_width = width_per_group
       self.conv1 = nn.Conv2d(90*2, self.inplanes, kernel_size=3, stride=1,padding=1,
                              bias=False,groups=1)
       self.bn1 = norm_layer(self.inplanes)
       self.relu = nn.ReLU(inplace=True)
       self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2,padding=1)
       self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
       self.layer1 = self._make_layer(block, 128, layers[0],groups=1)
       self.inplanes = int(self.inplanes*1)
       self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                      dilate=replace_stride_with_dilation[0],groups=1)
       self.inplanes = int(self.inplanes * 1)

       self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                      dilate=replace_stride_with_dilation[1],groups=1)
       self.inplanes = int(self.inplanes * 1)

       self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                      dilate=replace_stride_with_dilation[2],groups=1)
       self.inplanes = int(self.inplanes * 1)





       self.fc = nn.Linear(512* block.expansion*196, 5)
       self.drop = nn.Dropout(p=0.1)
       for m in self.modules():
           if isinstance(m, nn.Conv2d):
               nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
           elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
               nn.init.constant_(m.weight, 1)
               nn.init.constant_(m.bias, 0)

       # Zero-initialize the last BN in each residual branch,
       # so that the residual branch starts with zeros, and each residual block behaves like an identity.
       # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
       if zero_init_residual:
           for m in self.modules():
               if isinstance(m, Bottleneck):
                   nn.init.constant_(m.bn3.weight, 0)
               elif isinstance(m, BasicBlock):
                   nn.init.constant_(m.bn2.weight, 0)

   def _make_layer(self, block, planes, blocks, stride=1, dilate=False,groups=1):
       norm_layer = self._norm_layer
       downsample = None
       previous_dilation = self.dilation
       if dilate:
           self.dilation *= stride
           stride = 1
       if stride != 1 or self.inplanes != planes * block.expansion:
           downsample = nn.Sequential(
               conv1x1(self.inplanes, planes * block.expansion, stride),
               norm_layer(planes * block.expansion),
           )

       layers = []
       layers.append(block(self.inplanes, planes, stride, downsample, groups,
                           self.base_width, previous_dilation, norm_layer))
       self.inplanes = planes * block.expansion
       for _ in range(1, blocks):
           layers.append(block(self.inplanes, planes, groups=self.groups,
                               base_width=self.base_width, dilation=self.dilation,
                               norm_layer=norm_layer))

       return nn.Sequential(*layers)

   def _forward_impl(self, x,POS):##x->input of main branch; POS->position embeddings generated by sub branch

       x = self.conv1(x)
       x = self.bn1(x)
       x = self.relu(x)
       ##main branch
       x,attn1,_ = self.layer1((x,None,True))
       temp = attn1
       attn1 = self.maxpool(attn1)

       x ,attn2,_= self.layer2((x,attn1,True))


       attn2=self.maxpool(attn2)

       x ,attn3,_= self.layer3((x,attn2,True))
       #
       attn3 = self.maxpool(attn3)
       x,attn4,_ = self.layer4((x,attn3,True))

       x=x+POS#fusion of motion pattern feature and position embeddings

       x = torch.flatten(x, 1)

       x = self.fc(x)

       return x,temp.view(x.size(0),-1)

   def forward(self, x,POS):
       return self._forward_impl(x,POS)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
   model = ResNet(block, layers, **kwargs)
   if pretrained:
       state_dict = load_state_dict_from_url(model_urls[arch],
                                             progress=progress)
       model.load_state_dict(state_dict)
   return model

##main branch consisting of CA blocks
def resnet18_pos_attention(pretrained=False, progress=True, **kwargs):
   r"""ResNet-18 model from
   `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

   Args:
       pretrained (bool): If True, returns a model pre-trained on ImageNet
       progress (bool): If True, displays a progress bar of the download to stderr
   """
   return _resnet('resnet18', CABlock, [1, 1, 1, 1], pretrained, progress,
                  **kwargs)


PC_module.py

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# -*- coding: utf-8 -*-
import math
import logging
import collections.abc
from functools import partial
from collections import OrderedDict
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.vision_transformer import _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

def drop_path(x, drop_prob: float = 0., training: bool = False):
   """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

   This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
   the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
   See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
   changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
   'survival rate' as the argument.

   """
   if drop_prob == 0. or not training:
       return x
   keep_prob = 1 - drop_prob
   shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
   random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
   random_tensor.floor_()  # binarize
   output = x.div(keep_prob) * random_tensor
   return output


class DropPath(nn.Module):
   """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
   """
   def __init__(self, drop_prob=None):
       super(DropPath, self).__init__()
       self.drop_prob = drop_prob

   def forward(self, x):
       return drop_path(x, self.drop_prob, self.training)
def _ntuple(n):
   def parse(x):
       if isinstance(x, collections.abc.Iterable):
           return x
       return tuple(repeat(x, n))
   return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
__all__ = [
   'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
   'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
   'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
   'deit_base_distilled_patch16_384',
]

class Mlp(nn.Module):
   def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
       super().__init__()
       out_features = out_features or in_features
       hidden_features = hidden_features or in_features
       self.fc1 = nn.Linear(in_features, hidden_features)
       self.act = act_layer()
       self.fc2 = nn.Linear(hidden_features, out_features)
       self.drop = nn.Dropout(drop)

   def forward(self, x):
       x = self.fc1(x)
       x = self.act(x)
       x = self.drop(x)
       x = self.fc2(x)
       x = self.drop(x)
       return x


class Attention(nn.Module):
   def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
       super().__init__()
       self.num_heads = num_heads
       head_dim = dim // num_heads
       # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
       self.scale = qk_scale or head_dim ** -0.5

       self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
       self.attn_drop = nn.Dropout(attn_drop)
       self.proj = nn.Linear(dim, dim)
       self.proj_drop = nn.Dropout(proj_drop)

   def forward(self, x):
       B, N, C = x.shape
       qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
       q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
       varq = torch.var(q, dim=2).sum(dim=2).sum()/B/N
       vark = torch.var(k, dim=2).sum(dim=2).sum()/B/N
       varv = torch.var(v, dim=2).sum(dim=2).sum()/B/N
       attn = (q @ k.transpose(-2, -1)) * self.scale
       attn = attn.softmax(dim=-1)
       attn = self.attn_drop(attn)

       x = (attn @ v).transpose(1, 2).reshape(B, N, C)
       x = self.proj(x)
       x = self.proj_drop(x)
       return x


class Block(nn.Module):

   def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
       super().__init__()
       self.norm1 = norm_layer(dim)
       self.attn = Attention(
           dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
       # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
       self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
       self.norm2 = norm_layer(dim)
       mlp_hidden_dim = int(dim * mlp_ratio)
       self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

   def forward(self, x):
       x = x + self.drop_path(self.attn(self.norm1(x)))
       x = x + self.drop_path(self.mlp(self.norm2(x)))
       return x


class PatchEmbed(nn.Module):
   """ Image to Patch Embedding
   """
   def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
       super().__init__()
       img_size = to_2tuple(img_size)
       patch_size = to_2tuple(patch_size)
       num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
       self.img_size = img_size
       self.patch_size = patch_size
       self.num_patches = num_patches

       self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

   def forward(self, x):
       B, C, H, W = x.shape
       # FIXME look at relaxing size constraints
       assert H == self.img_size[0] and W == self.img_size[1], \
           f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
       x = self.proj(x).flatten(2).transpose(1, 2)
       return x


class HybridEmbed(nn.Module):
   """ CNN Feature Map Embedding
   Extract feature map from CNN, flatten, project to embedding dim.
   """
   def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
       super().__init__()
       assert isinstance(backbone, nn.Module)
       img_size = to_2tuple(img_size)
       self.img_size = img_size
       self.backbone = backbone
       if feature_size is None:
           with torch.no_grad():
               # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
               # map for all networks, the feature metadata has reliable channel and stride info, but using
               # stride to calc feature dim requires info about padding of each stage that isn't captured.
               training = backbone.training
               if training:
                   backbone.eval()
               o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
               if isinstance(o, (list, tuple)):
                   o = o[-1]  # last feature if backbone outputs list/tuple of features
               feature_size = o.shape[-2:]
               feature_dim = o.shape[1]
               backbone.train(training)
       else:
           feature_size = to_2tuple(feature_size)
           if hasattr(self.backbone, 'feature_info'):
               feature_dim = self.backbone.feature_info.channels()[-1]
           else:
               feature_dim = self.backbone.num_features
       self.num_patches = feature_size[0] * feature_size[1]
       self.proj = nn.Conv2d(feature_dim, embed_dim, 1)

   def forward(self, x):
       x = self.backbone(x)
       if isinstance(x, (list, tuple)):
           x = x[-1]  # last feature if backbone outputs list/tuple of features
       x = self.proj(x).flatten(2).transpose(1, 2)
       return x


###Position Calibration Module
class VisionTransformer_POS(nn.Module):
   """ Vision Transformer

   A torch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
       https://arxiv.org/abs/2010.11929
   """
   def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                drop_rate=0., attn_drop_rate=0., drop_path_rate=0.15, hybrid_backbone=None, norm_layer=None):
       """
       Args:
           img_size (int, tuple): input image size
           patch_size (int, tuple): patch size
           in_chans (int): number of input channels
           num_classes (int): number of classes for classification head
           embed_dim (int): embedding dimension
           depth (int): depth of transformer
           num_heads (int): number of attention heads
           mlp_ratio (int): ratio of mlp hidden dim to embedding dim
           qkv_bias (bool): enable bias for qkv if True
           qk_scale (float): override default qk scale of head_dim ** -0.5 if set
           representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
           drop_rate (float): dropout rate
           attn_drop_rate (float): attention dropout rate
           drop_path_rate (float): stochastic depth rate
           hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
           norm_layer: (nn.Module): normalization layer
       """
       super().__init__()
       norm_layer=partial(nn.LayerNorm, eps=1e-6)
       self.num_classes = num_classes
       self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
       norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

       if hybrid_backbone is not None:
           self.patch_embed = HybridEmbed(
               hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
       else:
           self.patch_embed = PatchEmbed(
               img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
       num_patches = self.patch_embed.num_patches

       self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
       self.pos_embed = nn.Parameter(torch.zeros(1, 196, embed_dim))
       self.pos_drop = nn.Dropout(p=drop_rate)

       dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
       self.blocks = nn.ModuleList([
           Block(
               dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
               drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
           for i in range(depth)])
       self.norm = norm_layer(embed_dim)

       # Representation layer
       if representation_size:
           self.num_features = representation_size
           self.pre_logits = nn.Sequential(OrderedDict([
               ('fc', nn.Linear(embed_dim, representation_size)),
               ('act', nn.Tanh())
           ]))
       else:
           self.pre_logits = nn.Identity()

       # Classifier head
       self.head = nn.Linear(self.num_features, 5) if num_classes > 0 else nn.Identity()
       # self.to_Mask = nn.Sequential(nn.Conv2d(in_channels=self.num_features,out_channels=1,kernel_size=3,padding=1),
       #                              nn.Hardsigmoid(),
       #                              )
       # self.to_Mask = nn.Linear(self.num_features,1)
       self.to_Mask = nn.Sequential(nn.Linear(self.num_features,1),
                                    nn.Sigmoid(),
                                    )
       trunc_normal_(self.pos_embed, std=.02)
       trunc_normal_(self.cls_token, std=.02)
       self.apply(self._init_weights)

   def _init_weights(self, m):
       if isinstance(m, nn.Linear):
           trunc_normal_(m.weight, std=.02)
           if isinstance(m, nn.Linear) and m.bias is not None:
               nn.init.constant_(m.bias, 0)
       elif isinstance(m, nn.LayerNorm):
           nn.init.constant_(m.bias, 0)
           nn.init.constant_(m.weight, 1.0)

   @torch.jit.ignore
   def no_weight_decay(self):
       return {'pos_embed', 'cls_token'}

   def get_classifier(self):
       return self.head

   def reset_classifier(self, num_classes, global_pool=''):
       self.num_classes = num_classes
       self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

   def forward_features(self, x):
       B = x.shape[0]
       x = self.patch_embed(x)


       x = x + self.pos_embed
       x = self.pos_drop(x)


       for blk in self.blocks:
           x = blk(x)


       x = self.norm(x)
       x = self.pre_logits(x)
       return x

   def forward(self, x):
       x = self.forward_features(x)

       return x

Problems in the CASME2 dataset itself (the coding spreadsheet and the image folders) need to be fixed on your own.
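
A minimal sketch of the kind of check that helps here, assuming the same spreadsheet and column order that RafDataSet uses. It flags annotation rows whose onset/apex/offset entries are not plain integers (the loader calls int() on them) and rows whose apex image is missing on disk; the path constant is only an example.

import os
import pandas as pd

RAF_PATH = 'D:/CASME2/'  # same default as --raf_path in main.py

df = pd.read_excel(os.path.join(RAF_PATH, 'CASME2-coding-20140508.xlsx'),
                   usecols=[0, 1, 3, 4, 5, 7, 8])
for idx, row in df.iterrows():
    sub, fname = str(row.iloc[0]), str(row.iloc[1])
    # onset/apex/offset sit at positions 2/3/4, the same positions RafDataSet reads
    for name, col in (('onset', 2), ('apex', 3), ('offset', 4)):
        try:
            int(row.iloc[col])
        except (TypeError, ValueError):
            print('row %d (sub%s %s): %s frame is not an integer: %r'
                  % (idx, sub, fname, name, row.iloc[col]))
    apex_path = os.path.join(RAF_PATH, 'Cropped-updated/Cropped',
                             'sub' + sub, fname, 'reg_img' + str(row.iloc[3]) + '.jpg')
    if not os.path.isfile(apex_path):
        print('row %d (sub%s %s): apex image not found: %s' % (idx, sub, fname, apex_path))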


