-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsnn_test.py
More file actions
67 lines (58 loc) · 2.05 KB
/
snn_test.py
File metadata and controls
67 lines (58 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import sys
import numpy as np
from snnpy import *
import time
# Coordinates per point in the binary input file: read_points reads this many
# float32 values (4 bytes each) per point and returns an (n, dim) array.
dim = 2
def read_points(fname, ndim=None):
    """Read a binary point file into an ``(n, ndim)`` float32 array.

    File layout, as parsed here: an 8-byte little-endian unsigned point count
    ``n``, followed by ``n * ndim`` float32 values stored point by point.
    NOTE(review): the count is explicitly little-endian but the float32
    payload is decoded in native byte order (as before); presumably the
    writer is little-endian throughout — confirm.

    Args:
        fname: path to the binary point file.
        ndim: coordinates per point; defaults to the module-level ``dim``.

    Returns:
        A writable ``np.float32`` array of shape ``(n, ndim)``.

    Raises:
        ValueError: if the header or the point payload is truncated.
    """
    if ndim is None:
        ndim = dim
    t = -time.perf_counter()
    with open(fname, "rb") as f:
        header = f.read(8)
        # A short header would otherwise be silently decoded as a bogus count.
        if len(header) < 8:
            raise ValueError(
                f"{fname}: truncated header ({len(header)} of 8 bytes)")
        n = int.from_bytes(header, byteorder="little")
        nbytes = 4 * ndim * n
        # One bulk read instead of n tiny per-point reads.
        payload = f.read(nbytes)
    if len(payload) < nbytes:
        raise ValueError(
            f"{fname}: expected {nbytes} bytes of point data for {n} points, "
            f"got {len(payload)}")
    data = np.zeros((n, ndim), dtype=np.float32)
    if nbytes:
        # Copy into an owned array: frombuffer views of bytes are read-only.
        data[:] = np.frombuffer(payload, dtype=np.float32).reshape(n, ndim)
    t += time.perf_counter()
    print(f"[time={t:.4f}] :: (read_points) [filename='{fname}']")
    return data
def build_graph(data, radius):
    """Build a radius-neighbour graph over the rows of ``data``.

    An index is built with ``build_snn_model`` (from ``snnpy``), then each
    point is queried for the indices of points within ``radius``. Timings for
    index construction, for the queries, and their sum are printed.

    Args:
        data: 2-D array of points, one per row (as returned by read_points).
        radius: query radius passed to ``query_radius``.

    Returns:
        ``(edges, n_edges)``: ``edges[i]`` is the index collection returned by
        ``query_radius`` for point ``i``; ``n_edges`` is the total number of
        entries across all of them. NOTE(review): whether a point's own index
        appears in its result is up to snnpy — confirm before treating
        ``n_edges`` as a self-loop-free edge count.
    """
    n_verts = data.shape[0]

    t = -time.perf_counter()
    snn_model = build_snn_model(data)
    t += time.perf_counter()
    print(f"[time={t:.4f}] :: (build_snn_model) [n={n_verts}]")
    snn_time = t

    edges = []
    n_edges = 0
    query_time = 0.0
    for point in data:
        # Time only the query itself, not the bookkeeping around it.
        t1 = -time.perf_counter()
        ind = snn_model.query_radius(point, radius, return_distance=False)
        t1 += time.perf_counter()
        query_time += t1
        n_edges += len(ind)
        edges.append(ind)
    print(f"[time={query_time:.4f}] :: (queries)")

    # Avoid ZeroDivisionError when data has no rows.
    avg_deg = n_edges / n_verts if n_verts else 0.0
    print(f"[time={query_time + snn_time:.4f}] :: (build_graph) [n_verts={n_verts},n_edges={n_edges},avg_deg={avg_deg:.4f},radius={radius:.3f}]")
    return edges, n_edges
def write_graph(edges, n_edges, ofname):
    """Write the graph as a text edge list and print how long it took.

    The first line is ``"<vertex count> <n_edges>"``; every following line is
    one ``"u v"`` pair with 1-based vertex numbers, neighbours of each vertex
    listed in ascending order.

    Args:
        edges: per-vertex neighbour index collections (0-based).
        n_edges: edge count written verbatim into the header line.
        ofname: output path (overwritten).
    """
    start = time.perf_counter()
    with open(ofname, "w") as out:
        out.write(f"{len(edges)} {n_edges}\n")
        for src, nbrs in enumerate(edges, start=1):
            out.writelines(f"{src} {dst + 1}\n" for dst in sorted(nbrs))
    elapsed = time.perf_counter() - start
    print(f"[time={elapsed:.4f}] :: (write_graph) [filename='{ofname}']")
def main(ifname, radius, ofname=None):
    """Read points, build their radius graph, and optionally save it.

    Args:
        ifname: binary point file to read.
        radius: neighbourhood radius for the graph.
        ofname: edge-list output path; nothing is written when falsy.

    Returns:
        0, used as the process exit status.
    """
    points = read_points(ifname)
    graph_edges, edge_count = build_graph(points, radius)
    if ofname:
        write_graph(graph_edges, edge_count, ofname)
    return 0
if __name__ == "__main__":
    # CLI: <points> <radius> [graph]; the output graph path is optional.
    args = sys.argv[1:]
    if len(args) < 2:
        sys.stderr.write(f"Usage: {sys.argv[0]} <points> <radius> [graph]\n")
        sys.stderr.flush()
        sys.exit(1)
    out_path = args[2] if len(args) > 2 else None
    sys.exit(main(args[0], float(args[1]), out_path))