本文主要是介绍PyG edge index 转换回 邻接矩阵,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
PyG的edge index形式是 [ ( n o d e 1 , n o d e 2 ) , ( n o d e 1 , n o d e 3 ) . . . ] [(node_1,node_2), (node_1, node_3)...] [(node1,node2),(node1,node3)...]这种edge pair。
naive
直接for循环,吧edge index里面的位置填充1:
import torch def edge_index_to_adjacency_matrix(edge_index, num_nodes): # 创建大小为 (num_nodes, num_nodes) 的二维张量 adjacency_matrix = torch.zeros(num_nodes, num_nodes) # 根据边索引填充邻接矩阵的元素 for i, j in zip(*edge_index): adjacency_matrix[i, j] = 1 adjacency_matrix[j, i] = 1 return adjacency_matrix
效率很低
利用传播机制
用PyTorch的广播机制,通过将边索引直接作为索引,一次性将对应的邻接矩阵元素设置为1,避免了使用for循环进行逐个元素的填充。这种方法在大规模图形上具有更高的效率。
import torch def edge_index_to_adjacency_matrix(edge_index, num_nodes): # 构建一个大小为 (num_nodes, num_nodes) 的零矩阵 adjacency_matrix = torch.zeros(num_nodes, num_nodes, dtype=torch.uint8) # 使用索引广播机制,一次性将边索引映射到邻接矩阵的相应位置上 adjacency_matrix[edge_index[0], edge_index[1]] = 1 adjacency_matrix[edge_index[1], edge_index[0]] = 1 return adjacency_matrix
这篇关于PyG edge index 转换回 邻接矩阵的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!