欢迎您访问365答案网,请分享给你的朋友!
生活常识 学习资料

【stgcn】代码pytorch解读

时间:2023-04-21

解读 utils.py

import osimport zipfileimport numpy as npimport torch

一、加载矩阵数据

def load_metr_la_data(): if (not os.path.isfile("../PeMSD7(M)/adj_mat.npy") or not os.path.isfile("../PeMSD7(M)/node_values.npy")): with zipfile.ZipFile("../PeMSD7(M)/METR-LA.zip", 'r') as zip_ref: zip_ref.extractall("data/") # 如果文件路径不存在,则打开zip文件 A = np.load("../PeMSD7(M)/adj_mat.npy") X = np.load("../PeMSD7(M)/node_values.npy").transpose((1, 2, 0)) X = X.astype(np.float32) # Normalization using Z-score method means = np.mean(X, axis=(0, 2)) # 均值 X = X - means.reshape(1, -1, 1) stds = np.std(X, axis=(0, 2)) # 方差 X = X / stds.reshape(1, -1, 1) # 标准化 return A, X, means, stds

注释
1、np.transpose():转轴,(0,1,2)–》(1,2,0)

二、拉普拉斯矩阵归一化

def get_normalized_adj(A): """ Returns the degree normalized adjacency matrix. """ A = A + np.diag(np.ones(A.shape[0], dtype=np.float32))# A=A+E 邻接矩阵 D = np.array(np.sum(A, axis=1)).reshape((-1,)) # D 度矩阵 D[D <= 10e-5] = 10e-5 # Prevent infs diag = np.reciprocal(np.sqrt(D)) A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A), diag.reshape((1, -1))) return A_wave

注释

np.sqrt(D):返回数组的平方根np.reciprocal():数返回参数逐元素的倒数。 三、生成迭代器

def generate_dataset(X, num_timesteps_input, num_timesteps_output): """ Takes node features for the graph and divides them into multiple samples along the time-axis by sliding a window of size (num_timesteps_input+ num_timesteps_output) across it in steps of 1. 获取图的节点特征,并将其划分为窗口大小为(输入时间步长+输出时间步长)的多维样本每隔一步。 :param X: Node features of shape (num_vertices, num_features, num_timesteps) :return: - Node features divided into multiple samples、Shape is (num_samples, num_vertices, num_features, num_timesteps_input).=(样本案例数,顶点,特征,输入时间步长) - Node targets for the samples、Shape is (num_samples, num_vertices, num_features, num_timesteps_output).=(样本案例数,顶点,特征,输出时间步长) """ # Generate the beginning index and the ending index of a sample, 生成样本的开始和结束索引 # which contains (num_points_for_training + num_points_for_predicting) points共包含(训练点+特征点) indices = [(i, i + (num_timesteps_input + num_timesteps_output)) for i in range(X.shape[2] - ( num_timesteps_input + num_timesteps_output) + 1) ] # Save samples features, target = [], [] for i, j in indices: features.append( X[:, :, i: i + num_timesteps_input].transpose( (0, 2, 1))) target.append(X[:, 0, i + num_timesteps_input: j]) return torch.from_numpy(np.array(features)), torch.from_numpy(np.array(target))

注释

node_values.shape=(34272, 207, 2)X.transpose((1, 2, 0))X为X_train共(207,2,20563), X_test共(207,2,6854), X_val共(207,2,6854)indices的范围【(0,总数-(时间输入步长+时间输出步长)】,每个索引为(i,i+(时间输入步长+时间输出步长))每个切片的特征维度为【(207,2,时间输入步长)】.transpose((0,2,1))->【(207,时间输入步长,2)】每个标签维度为【(207,1,时间输出步长)】

Copyright © 2016-2020 www.365daan.com All Rights Reserved. 365答案网 版权所有 备案号:

部分内容来自互联网,版权归原作者所有,如有冒犯请联系我们,我们将在三个工作时内妥善处理。