在PyTorch的BatchNorm2d中,每个通道都保存有一个running_mean和running_var,即均值和方差的移动平均值(移动统计量)。在训练阶段,每处理一个mini-batch,这两个统计量就会按momentum参数更新一次(并非每个epoch才更新一次)。在测试(eval)阶段,则使用这些running_mean和running_var来归一化输入数据。
以下是使用BatchNorm2d的示例代码:
import torch.nn as nn
class Net(nn.Module):
    """Small CNN: two conv+BN+ReLU stages, one 2x2 max-pool, then two FC layers.

    Expects input of shape (N, 3, 16, 16): the convolutions preserve the
    spatial size (kernel 3, stride 1, padding 1) and the pool halves it to
    8x8, which matches fc1's expected 128 * 8 * 8 input features.
    """

    def __init__(self):
        # Bug fix: the original defined `def init(self)` and called
        # `super(Net, self).init()` — neither is the real constructor, so
        # nn.Module.__init__ never ran and no layers were registered.
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.bn3 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Run the forward pass; returns logits of shape (N, 10)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.pool(x)
        # Flatten (N, 128, 8, 8) -> (N, 128 * 8 * 8) for the linear layers.
        x = x.view(-1, 128 * 8 * 8)
        x = self.fc1(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
# Bug fix: the original fused two statements on one line ("net = Net() print(net)"),
# which is a SyntaxError. Instantiate the network, then print its layer summary.
net = Net()
print(net)