初级图神经网络（GNN）示例，基于 PyTorch Geometric 实现。
代码示例：
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GraphConv
# Toy graph: 3 nodes, each carrying a single scalar feature (tensor shape [3, 1]).
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
# Connectivity as a [2, num_edges] tensor: each column (source, target) is one
# directed edge. Both directions are listed, so the path 0 - 1 - 2 is undirected.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)
class Net(torch.nn.Module):
    """Two-layer graph network: ``GraphConv -> ReLU -> GraphConv``.

    Args:
        in_channels: Size of each input node feature vector (default 1).
        hidden_channels: Size of the hidden node representation (default 16).
        out_channels: Size of each output node feature vector (default 1).
    """

    def __init__(self, in_channels: int = 1, hidden_channels: int = 16,
                 out_channels: int = 1) -> None:
        # Must be ``__init__`` (not ``init``): nn.Module only registers
        # submodules assigned after its own constructor has run.
        super().__init__()
        # GraphConv comes from torch_geometric.nn; torch.nn has no such layer.
        self.conv1 = GraphConv(in_channels, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, out_channels)

    def forward(self, data):
        """Run both graph convolutions over ``data``.

        Args:
            data: Graph object exposing ``x`` (node features) and
                ``edge_index`` (connectivity), e.g. a ``torch_geometric.data.Data``.

        Returns:
            The output of the second ``GraphConv`` layer (per-node outputs).
        """
        x, edge_index = data.x, data.edge_index
        x = torch.nn.functional.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)
# Prefer the GPU when one is available; the model is placed on ``device``.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Mean-squared-error loss for the per-node regression output.
criterion = torch.nn.MSELoss()
# Train for a fixed number of epochs. The regression target is the input node
# features themselves, so the network learns to reconstruct ``data.x``.
NUM_EPOCHS = 100
model.train()
# Moving the graph to the device is loop-invariant, so do it once up front.
train_data = data.to(device)
target = train_data.x
for _ in range(NUM_EPOCHS):
    optimizer.zero_grad()
    out = model(train_data)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
# Inference on an unseen 2-node graph joined by a single undirected edge 0 - 1.
new_x = torch.tensor([[2], [-2]], dtype=torch.float)
new_edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
new_data = Data(x=new_x, edge_index=new_edge_index)

# eval() switches the model to inference mode; no_grad() skips building the
# autograd graph, which prediction does not need.
model.eval()
with torch.no_grad():
    prediction = model(new_data.to(device))
print(prediction)