博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
GCN代码分析 2019.03.12 22:34:54字数 560阅读 5714 本文主要对GCN源码进行分析。
阅读量:5288 次
发布时间:2019-06-14

本文共 5536 字,大约阅读时间需要 18 分钟。

GCN代码分析

 

1 代码结构

.├── data      // 图数据├── inits    // 初始化的一些公用函数├── layers     // GCN层的定义├── metrics    // 评测指标的计算├── models     // 模型结构定义├── train    // 训练└── utils    //  工具函数的定义

utils.py

def parse_index_file(filename) # 处理index文件并返回index矩阵

def sample_mask(idx, l) #创建 mask 并返回mask矩阵

def load_data(dataset_str) # 读取数据

  • 从gcn/data文件夹下读取数据,文件包括有:

  • ind.dataset_str.x => 训练实例的特征向量,如scipy.sparse.csr.csr_matrix类的实例

  • ind.dataset_str.tx => 测试实例的特征向量,如scipy.sparse.csr.csr_matrix类的实例

  • ind.dataset_str.allx => 有标签的+无无标签训练实例的特征向量,是ind.dataset_str.x的超集

  • ind.dataset_str.y => 训练实例的标签,独热编码,numpy.ndarray类的实例

  • ind.dataset_str.ty => 测试实例的标签,独热编码,numpy.ndarray类的实例

  • ind.dataset_str.ally => 有标签的+无无标签训练实例的标签,独热编码,numpy.ndarray类的实例

  • ind.dataset_str.graph => 图数据,collections.defaultdict类的实例,格式为 {index:[index_of_neighbor_nodes]}

  • ind.dataset_str.test.index => 测试实例的id

​ 上述文件必须都用python的pickle模块存储

  • 返回: adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask

def sparse_to_tuple(sparse_mx) # 将矩阵转换成tuple格式并返回

def preprocess_features(features) # 处理特征:将特征进行归一化并返回tuple (coords, values, shape)

def normalize_adj(adj) # 图归一化并返回

def preprocess_adj(adj) # 处理得到GCN中的归一化矩阵并返回

def construct_feed_dict(features, support, labels, labels_mask, placeholders) # 构建输入字典并返回

def chebyshev_polynomials(adj, k) # 切比雪夫多项式近似:计算K阶的切比雪夫近似矩阵

def chebyshev_polynomials(adj, k): """Calculate Chebyshev polynomials up to order k. Return a list of sparse matrices (tuple representation).""" print("Calculating Chebyshev polynomials up to order {}...".format(k)) adj_normalized = normalize_adj(adj) # D^{-1/2}AD^{1/2} laplacian = sp.eye(adj.shape[0]) - adj_normalized # L = I_N - D^{-1/2}AD^{1/2} largest_eigval, _ = eigsh(laplacian, 1, which='LM') # \lambda_{max} scaled_laplacian = (2. / largest_eigval[0]) * laplacian - sp.eye(adj.shape[0]) # 2/\lambda_{max}L-I_N # 将切比雪夫多项式的 T_0(x) = 1和 T_1(x) = x 项加入到t_k中 t_k = list() t_k.append(sp.eye(adj.shape[0])) t_k.append(scaled_laplacian) # 依据公式 T_n(x) = 2xT_n(x) - T_{n-1}(x) 构造递归程序,计算T_2 -> T_k项目 def chebyshev_recurrence(t_k_minus_one, t_k_minus_two, scaled_lap): s_lap = sp.csr_matrix(scaled_lap, copy=True) return 2 * s_lap.dot(t_k_minus_one) - t_k_minus_two for i in range(2, k+1): t_k.append(chebyshev_recurrence(t_k[-1], t_k[-2], scaled_laplacian)) return sparse_to_tuple(t_k)

layers.py

  • 各层定义的方式与keras类似

  • 定义基类 Layer

    属性:name (String) => 定义了变量范围;logging (Boolean) => 打开或关闭TensorFlow直方图日志记录

    方法:__init__()(初始化),_call()(定义计算),__call__()(调用_call()函数),_log_vars()

  • 定义Dense Layer类,继承自Layer类

  • 定义GraphConvolution类,继承自Layer类。重点来看一下这个类的实现。

class GraphConvolution(Layer): """Graph convolution layer.""" def __init__(self, input_dim, output_dim, placeholders, dropout=0., sparse_inputs=False, act=tf.nn.relu, bias=False, featureless=False, **kwargs): super(GraphConvolution, self).__init__(**kwargs) if dropout: self.dropout = placeholders['dropout'] else: self.dropout = 0. self.act = act self.support = placeholders['support'] self.sparse_inputs = sparse_inputs self.featureless = featureless self.bias = bias # helper variable for sparse dropout self.num_features_nonzero = placeholders['num_features_nonzero'] # 下面是定义变量,主要是通过调用utils.py中的glorot函数实现 with tf.variable_scope(self.name + '_vars'): for i in range(len(self.support)): self.vars['weights_' + str(i)] = glorot([input_dim, output_dim], name='weights_' + str(i)) if self.bias: self.vars['bias'] = zeros([output_dim], name='bias') if self.logging: self._log_vars() def _call(self, inputs): x = inputs # dropout 设置dropout if self.sparse_inputs: x = sparse_dropout(x, 1-self.dropout, self.num_features_nonzero) else: x = tf.nn.dropout(x, 1-self.dropout) # convolve 卷积的实现。主要是根据论文中公式Z = \tilde{D}^{-1/2}\tilde{A}^{-1/2}X\theta实现 supports = list() for i in range(len(self.support)): if not self.featureless: pre_sup = dot(x, self.vars['weights_' + str(i)], sparse=self.sparse_inputs) else: pre_sup = self.vars['weights_' + str(i)] support = dot(self.support[i], pre_sup, sparse=True) supports.append(support) output = tf.add_n(supports) # bias if self.bias: output += self.vars['bias'] return self.act(output)

model.py

定义了一个model基类,以及两个继承自model类的MLP、GCN类。重点来看看GCN类的定义

class GCN(Model): def __init__(self, placeholders, input_dim, **kwargs): super(GCN, self).__init__(**kwargs) self.inputs = placeholders['features'] self.input_dim = input_dim # self.input_dim = self.inputs.get_shape().as_list()[1] # To be supported in future Tensorflow versions self.output_dim = placeholders['labels'].get_shape().as_list()[1] self.placeholders = placeholders self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) self.build() # 损失计算 def _loss(self): # Weight decay loss # 正则化项 for var in self.layers[0].vars.values(): self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var) # Cross entropy error # 交叉熵损失函数 self.loss += masked_softmax_cross_entropy(self.outputs, self.placeholders['labels'], self.placeholders['labels_mask']) # 计算模型准确度 def _accuracy(self): self.accuracy = masked_accuracy(self.outputs, self.placeholders['labels'], self.placeholders['labels_mask']) # 构建模型:两层GCN def _build(self): self.layers.append(GraphConvolution(input_dim=self.input_dim, output_dim=FLAGS.hidden1, placeholders=self.placeholders, act=tf.nn.relu, dropout=True, sparse_inputs=True, logging=self.logging)) self.layers.append(GraphConvolution(input_dim=FLAGS.hidden1, output_dim=self.output_dim, placeholders=self.placeholders, act=lambda x: x, dropout=True, logging=self.logging)) # 模型预测 def predict(self): return tf.nn.softmax(self.outputs)

2 实践

更新中...

转载于:https://www.cnblogs.com/think90/p/11502647.html

你可能感兴趣的文章
Struts2 注释类型
查看>>
JSP中EL表达式语言不能使用的解决方法
查看>>
做XH2.54杜邦线材料-导线
查看>>
如何刻录cd音乐
查看>>
Codeforces Round #318(Div 1) 573A, 573B,573C
查看>>
51Nod 1091 线段重叠 贪心 区间重叠
查看>>
[翻译] NimbusKit
查看>>
POJ 2196
查看>>
熟悉下 mysql 的数据库导入导出
查看>>
5个数组Array方法: indexOf、filter、forEach、map、reduce使用实例(转)
查看>>
Machine Learning for hackers读书笔记(七)优化:密码破译
查看>>
Python基础第24天
查看>>
使用NPOI 做Excel导出
查看>>
L0/L1/L2范数(转载)
查看>>
[deviceone开发]-数据绑定示例
查看>>
CSU - 1770 按钮控制彩灯实验
查看>>
使用函数处理数据
查看>>
C语言函数返回数组
查看>>
动态对象(dynamic)的用法
查看>>
第九周软件工程作业-每周例行报告
查看>>