diff --git a/graphistry/tests/test_umap_utils.py b/graphistry/tests/test_umap_utils.py
index dd764d084..a550a41d0 100644
--- a/graphistry/tests/test_umap_utils.py
+++ b/graphistry/tests/test_umap_utils.py
@@ -74,6 +74,24 @@ node_numeric = node_ints + node_floats
 node_target = triangleNodes[["y"]]
 
 
+node_graph_with_index = pd.DataFrame(
+    {
+        "index": range(1, 13),
+        "a": ["a", "b", "c", "d"] * 3,
+        "b": ["w", "x", "y", "z"] * 3,
+    }
+)
+
+edge_graph_with_index = pd.DataFrame(
+    {
+        "index": range(1, 13),
+        "a": ["a", "b", "c", "d"] * 3,
+        "b": ["w", "x", "y", "z"] * 3,
+        "src": [1, 2, 3, 4] * 3,
+        "dst": [4, 3, 1, 2] * 3,
+    }
+)
+
 def _eq(df1, df2):
     try:
         df1 = df1.to_pandas()
@@ -150,6 +168,15 @@ def setUp(self):
         )
         self.g2e = g2
 
+        # graph with index
+        self.g_index_nodes = graphistry.nodes(node_graph_with_index)
+        self.g_index_nodes_umaped = self.g_index_nodes.umap(engine="umap_learn")
+        assert "_n" == self.g_index_nodes_umaped._node
+
+        self.g_index_edges = graphistry.nodes(edge_graph_with_index)
+        self.g_index_edges_umaped = self.g_index_edges.umap(engine="umap_learn")
+        assert "_n" == self.g_index_edges_umaped._node
+
 
     @pytest.mark.skipif(not has_umap, reason="requires umap feature dependencies")
     def test_columns_match(self):
@@ -816,6 +843,15 @@ def test_base(self):
         graphistry.nodes(self.df).umap('auto')._node_embedding.shape == (self.samples, 2)
         graphistry.nodes(self.df).umap('engine')._node_embedding.shape == (self.samples, 2)
 
+        # graph with index
+        self.g_index_nodes = graphistry.nodes(node_graph_with_index)
+        self.g_index_nodes_umaped = self.g_index_nodes.umap(engine="cuml")
+        assert "_n" == self.g_index_nodes_umaped._node
+
+        self.g_index_edges = graphistry.nodes(edge_graph_with_index)
+        self.g_index_edges_umaped = self.g_index_edges.umap(engine="cuml")
+        assert "_n" == self.g_index_edges_umaped._node
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/graphistry/umap_utils.py b/graphistry/umap_utils.py
index d2561739d..fc66c88df 100644
--- a/graphistry/umap_utils.py
+++ b/graphistry/umap_utils.py
@@ -588,12 +588,15 @@ def umap(
 
         if kind == "nodes":
             index = res._nodes.index
+
             if res._node is None:
                 logger.debug("-Writing new node name")
+                # Build the id column on a copy: a user column named "index"
+                # stays untouched and the caller's DataFrame is not mutated.
                 res = res.nodes(  # type: ignore
-                    res._nodes.reset_index(drop=True)
-                    .reset_index()
-                    .rename(columns={"index": config.IMPLICIT_NODE_ID}),
+                    res._nodes.assign(
+                        **{config.IMPLICIT_NODE_ID: range(len(res._nodes))}
+                    ),
                     config.IMPLICIT_NODE_ID,
                 )
                 res._nodes.index = index