欢迎光临散文网 会员登陆 & 注册

register_buffer() 与 register_parameter() 的区别

2023-07-25 14:46 作者:Enzo_Mi  | 我要投稿


import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1,  bias=False)

        self.weight = torch.ones(10, 10)
        self.bias = torch.zeros(10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x + self.weight - self.bias
        return x


1、register_parameter()

register_parameter()是 torch.nn.Module 类中的一个方法

1.1、 作用

  • 可将 self.weightself.bias 定义为可学习的参数,保存到网络对象的参数中,被优化器作用进行学习

  • self.weightself.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

1.2、用法

  register_parameter(name,param)

  • name:参数名称

  • param:参数张量, 须是 torch.nn.Parameter() 对象 或 None ,否则报错如下

TypeError: cannot assign 'torch.FloatTensor' object to parameter 'xx' (torch.nn.Parameter or None required)

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.register_parameter('weight', torch.nn.Parameter(torch.ones(10, 10)))
        self.register_parameter('bias', torch.nn.Parameter(torch.zeros(10)))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

for name, param in net.named_parameters():
    print(name, param.shape)

print('\n', '*'*40, '\n')

for key, val in net.state_dict().items():
    print(key, val.shape)


2、register_buffer()

register_buffer()torch.nn.Module() 类中的一个方法

2.1 、作用

  • self.weightself.bias 定义为不可学习的参数,不会被保存到网络对象的参数中,不会被优化器作用进行学习

  • self.weightself.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

它用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数)

  • 参数:可以被优化器更新  (requires_grad=False / True)

  • buffer 中的数据 : 不会被优化器更新

2.2、函数

  register_buffer(name,tensor)

  • name:参数名称

  • tensor:张量

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.register_buffer('weight', torch.ones(10, 10))
        self.register_buffer('bias', torch.zeros(10))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

for name, param in net.named_parameters():
    print(name, param.shape)

print('\n', '*'*40, '\n')

for key, val in net.state_dict().items():
    print(key, val.shape)



register_buffer() 与 register_parameter() 的区别的评论 (共 条)

分享到微博请遵守国家法律