Open
Labels: bug (Something isn't working), needs-triaged (for issues raised to be triaged)
Description
Is there an existing issue / discussion for this?
- I have searched the existing issues / discussions
Is there an existing answer for this in tutorial?
- I have searched tutorial
Current Behavior
- The actual return type is scipy.sparse._csc.csc_matrix
- The signature declares np.matrix
Taking calculate_symmetric_message_passing_adj as an example:
def calculate_symmetric_message_passing_adj(adj: np.ndarray) -> np.matrix:
    """
    Calculate the renormalized message-passing adjacency matrix as proposed in GCN.
    The message-passing adjacency matrix is defined as A' = D^{-1/2} (A + I) D^{-1/2}.

    Args:
        adj (np.ndarray): Adjacency matrix A.

    Returns:
        np.matrix: Renormalized message-passing adjacency matrix.
    """
    adj = adj + np.eye(adj.shape[0], dtype=np.float32)
    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1)).flatten()
    d_inv_sqrt = np.power(row_sum, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    mp_adj = d_mat_inv_sqrt.dot(adj).transpose().dot(d_mat_inv_sqrt).astype(np.float32)
    return mp_adj

Expected Behavior
The actual return type should match the type declared in the signature.
Environment
- BasicTS: 1.1.0
BasicTS logs
No response
Steps To Reproduce
No response
Anything else?
No response