BERT模型的注意力矩阵是一个非常重要的输出,它可以告诉我们哪些token和哪些token之间的交互对模型最重要。我们可以使用Python中的transformers库来打印BERT模型的注意力矩阵。
首先,我们需要安装transformers库:
!pip install transformers
接下来,我们可以使用以下代码来打印在第一次序列中所有token和第二次序列中所有token之间的交互的注意力矩阵:
from transformers import BertTokenizer, BertModel
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)
attention = output.attentions[0]
print(attention.shape)
这将输出注意力张量,即BERT的注意力矩阵:[batch_size, num_heads, sequence_length, sequence_length]。在BERT-base模型中,有12个attention head,我们可以使用以下代码来打印第一个attention head的注意力矩阵:
attention_head = attention[:, 0, :, :]
print(attention_head)
这将输出第一个注意力头的注意力矩阵。每一行代表一个token的注意力权重,每一列代表它与那些token的注意力权重。