在进行BERT模型的输入处理和输出处理过程中,需要根据每个文本输入的长度,对输入进行填充,以保证整个输入序列具有相同的长度。这样做会在输入序列的末尾填充一定数量的特殊标记"[PAD]",但是这些填充部分并不应该对文本嵌入向量的计算产生影响,因此需要在计算均值池化时进行排除。
以下是使用PyTorch实现排除填充的部分代码示例:
import torch
# 定义输入
input_tensors = torch.tensor([
[1, 2, 3],
[4, 5, 6],
[7, 0, 0] # 填充为[7, [PAD], [PAD]]
])
# 定义掩码
mask = input_tensors.gt(0).to(torch.float32)
# 通过掩码排除填充的部分进行均值池化
masked_sum = torch.sum(input_tensors, dim=1) # 取得每个样本的向量和
masked_mean = masked_sum / torch.sum(mask, dim=1) # 相应地计算出实际有效部分的均值
print(masked_mean)
执行上述代码将会输出以下均值向量:
tensor([ 2., 5., 7.])
其中,gt(0)会在所有等于0的位置返回False,否则返回True。由于我们在掩码中需要排除填充的部分,因此使用比0大的值。 最后,通过torch.sum计算在每个样本中实际部分的向量总和。在本例中,第一个样本的向量和为1+2+3=6。 接着将向量总和除以属于每个样本的有效部分的总数,即可得到对第一个样本的均值池化结果。