BERT分类器出现ValueError错误:目标大小(torch.Size([4,1]))必须与输入大小(torch.Size([4,2]))相同
创始人
2024-11-30 21:30:06
0

这种错误通常是由于目标大小与模型的输出不匹配导致的。可以在处理标签时对其进行one-hot编码,以确保匹配模型的输出大小。下面是对标签进行one-hot编码以解决此问题的示例代码:

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer

class MyDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, index):
        text = self.texts[index]
        label = self.labels[index]
        
        # 对标签进行one-hot编码
        label = torch.eye(2)[label].squeeze(0)
        
        encoding = self.tokenizer.encode_plus(
            text, 
            add_special_tokens=True, 
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return { 
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': label
        }

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

texts = ['This is the first sentence', 'This is the second sentence']
labels = [0, 1]

dataset = MyDataset(texts, labels, tokenizer, max_len=10)
dataloader = DataLoader(dataset, batch_size=2)

for batch in dataloader:
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    label = batch['label']
    
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden_state = outputs.last_hidden_state
    
    # 注:这里假设模型输出的是二分类问题的结果,因此使用sigmoid函数处理
    logits = torch.sigmoid(last_hidden_state[:, 0, :])
    
    loss_fn = torch.nn.BCELoss()
    loss = loss_fn(logits, label)
    
    print(loss)

在上面的示例中,我们将标签进行one-hot编码,并使用BCELoss作为损失函数计算总

相关内容

热门资讯

两分钟辅助!开心泉州小程序开挂... 两分钟辅助!开心泉州小程序开挂有什么技巧,原来真的是有辅助插件(有挂教学)开心泉州小程序开挂有什么技...
七分钟辅助!奇迹脚本辅助,真是... 七分钟辅助!奇迹脚本辅助,真是有辅助软件(确实有挂)1、超多福利:超高返利,海量正版游戏,奇迹脚本辅...
一分钟辅助!天天贵阳智能辅助器... 一分钟辅助!天天贵阳智能辅助器,原来是有辅助脚本(真的有挂)亲,关键说明,天天贵阳智能辅助器透视脚本...
3分钟辅助!一起宁德钓蟹黑科技... 3分钟辅助!一起宁德钓蟹黑科技辅助软件推荐,其实真的有辅助挂(有挂存在)1、玩家可以在一起宁德钓蟹黑...
第二分钟辅助!大菠萝789辅助... 第二分钟辅助!大菠萝789辅助器下载,原来存在有辅助挂(存在有挂)运大菠萝789辅助器下载辅助工具,...
3分钟辅助!科乐填坑辅助,原来... 3分钟辅助!科乐填坑辅助,原来真的是有辅助器(有挂方略)1、下载好科乐填坑辅助透视辅助下载之后点击打...
3分钟辅助!潮友会透视辅助教程... 3分钟辅助!潮友会透视辅助教程,果然存在有辅助器(有挂辅助)亲,关键说明,潮友会透视辅助教程透视脚本...
4分钟辅助!福建兄弟十三冰修改... 4分钟辅助!福建兄弟十三冰修改器,本来真的是有辅助app(有挂讲解)1、游戏颠覆性的策略玩法,独创攻...
第二分钟辅助!wepoker插... 第二分钟辅助!wepoker插件程序,真是是真的有辅助技巧(有挂细节)1、不需要AI权限,帮助你快速...
1分钟辅助!悠悠互娱辅助,真是... 1分钟辅助!悠悠互娱辅助,真是是有辅助神器(有挂解密)悠悠互娱辅助透视方法中分为三种模型:悠悠互娱辅...