-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
144 lines (110 loc) · 3.9 KB
/
utils.py
File metadata and controls
144 lines (110 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from functools import partial
from typing import Iterable, Tuple, TypeVar, Union
import jax
# import jax.numpy as np
import jax.numpy as jnp
import numpy as np
import numpy as nnp
# Names exported by `from <this module> import *`.
__all__ = ["argtopk", "arg_approx", "arg_approx_signed", "repeat", "concatenate"]
# Select CPU as JAX's default platform at import time. The jitted helpers in
# this file that pass backend="gpu" request the GPU backend explicitly.
jax.config.update("jax_platform_name", "cpu")
def get_int_k(array: np.ndarray, k: Union[int, float]) -> int:
    """Resolve `k` into an absolute element count for `array`.

    Args:
        array: Array whose ``size`` is used to scale a fractional `k`.
        k: Either an absolute count (anything that is not a float, converted
            with ``int``) or a fraction strictly between 0 and 1 given as a
            Python or NumPy floating-point scalar.

    Returns:
        For a fractional `k`, ``round(array.size * k)`` clamped to the range
        ``[1, max(array.size - 1, 1)]``; otherwise ``int(k)``.

    Raises:
        ValueError: If `k` is a float outside the open interval (0, 1).
    """
    # NumPy floating scalars are accepted too: np.float32 is not a `float`
    # subclass and would otherwise be truncated to 0 by `int(k)`.
    if isinstance(k, (float, np.floating)):
        if not 0.0 < k < 1.0:
            raise ValueError(
                f"fractional k must lie strictly between 0 and 1, got {k!r}"
            )
        int_k = int(round(array.size * k))
        # Keep at least one element and, when possible, leave at least one
        # out. The inner max() stops the upper bound from dropping below 1
        # for arrays with fewer than two elements.
        upper = max(array.size - 1, 1)
        return min(max(int_k, 1), upper)
    return int(k)
def argtopk(array: np.ndarray, k: Union[int, float]) -> np.ndarray:
    """Return indices of the `k` largest entries of `array`, unordered.

    Args:
        array: NumPy array or any array-like that ``numpy.asarray`` can
            convert (JAX arrays included).
        k: Absolute count or fraction, resolved by ``get_int_k``.

    Returns:
        A 1-D array of indices into the flattened `array`. For ``k == 1`` it
        holds the single ``argmax`` index; otherwise it holds the last `k`
        positions of an ``argpartition``, which are in no particular order.
    """
    k = get_int_k(array, k)
    if k == 1:
        return np.array([np.argmax(array)])
    # Convert to a host NumPy array without naming a JAX array class: the
    # previously checked `jax.numpy.DeviceArray` no longer exists in current
    # JAX. `asarray` leaves NumPy inputs untouched.
    host_array = nnp.asarray(array)
    return nnp.argpartition(host_array, -k, axis=None)[-k:]
def arg_sorted_topk(array: np.ndarray, k: Union[int, float]) -> np.ndarray:
    """Return the first `k` entries of the reversed ``argsort`` of `array`.

    For a 1-D `array` these are the indices of the `k` largest values,
    largest first. `k` is resolved to a count by ``get_int_k``.
    """
    count = get_int_k(array, k)
    ascending_order = np.argsort(array)
    descending_order = ascending_order[::-1]
    return descending_order[:count]
def arg_approx(array: np.ndarray, precision: float) -> np.ndarray:
    """Return indices of the largest entries covering `precision` of the sum.

    Args:
        array: Input values; indices refer to the flattened array.
        precision: Fraction of ``array.sum()`` that the selected entries
            should reach.

    Returns:
        A 1-D index array:
        * the single ``argmax`` index when ``1 / array.size >= precision`` or
          when the total sum is not positive;
        * the indices of all strictly positive entries when even the full
          cumulative sum stays below the threshold;
        * otherwise the ``topk + 1`` largest entries (via ``argtopk``), where
          ``topk`` is the first position at which the cumulative sum of the
          values sorted largest-first reaches the threshold.
    """
    if (1 / array.size) >= precision:
        return np.array([np.argmax(array)])
    total = array.sum()
    if total <= 0:
        return np.array([np.argmax(array)])
    flat = array.flatten()
    threshold = total * precision
    # Sort largest-first. Reversing *before* sorting (as the code previously
    # did) has no effect on the sorted order and yields ascending values,
    # which makes the cumulative sum count the smallest entries instead.
    descending = np.sort(flat)[::-1]
    topk = descending.cumsum().searchsorted(threshold)
    if topk == len(flat):
        return np.where(flat > 0)[0]
    return argtopk(flat, topk + 1)
@partial(jax.jit, backend="gpu")
def argmax_batch(array: np.ndarray):
    """Return ``(row_indices, argmax_indices)`` for a batch of rows.

    The first element enumerates the leading axis of `array`; the second is
    the ``argmax`` along the last axis. The function is jitted with the GPU
    backend requested explicitly.
    """
    row_ids = jnp.arange(array.shape[0])
    best_columns = jnp.argmax(array, axis=-1)
    return (row_ids, best_columns)
# Row-wise searchsorted: maps `jnp.searchsorted` over the leading axis of both
# the sorted rows and the per-row query values.
vsearchsorted = jax.vmap(jnp.searchsorted, in_axes=(0, 0), out_axes=0)


@partial(jax.jit, static_argnums=1, backend="gpu")
def arg_approx_batch_mask(
    input,
    precision: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build a per-row boolean mask of entries above a cumulative-sum cut.

    For every row, the threshold is ``row_sum * precision``. The row is
    sorted, its cumulative sum is searched for the threshold, and the sorted
    value at that position (clamped to the last column) becomes the cut
    value. The mask marks entries strictly greater than the row's cut value.

    `precision` is a static jit argument. The value returned is a single
    boolean array shaped like `input`, although the annotation says tuple.

    NOTE(review): recent JAX releases appear to deprecate the `backend`
    argument of `jax.jit` -- confirm against the pinned JAX version.
    """
    row_totals = input.sum(axis=-1)
    thresholds = row_totals * precision
    # NOTE(review): reversing the columns before sorting does not change the
    # sorted result, so these rows are in ascending order. If a largest-first
    # sort was intended this would be `jnp.sort(input, axis=-1)[:, ::-1]`
    # (and the comparison below may then need `>=`) -- confirm.
    sorted_rows = jnp.sort(input[:, ::-1], axis=-1)
    running_sums = sorted_rows.cumsum(axis=-1)
    last_column = input.shape[-1] - 1
    cut_positions = jnp.minimum(vsearchsorted(running_sums, thresholds), last_column)
    row_ids = jnp.arange(input.shape[0])
    cut_values = sorted_rows[row_ids, cut_positions]
    return input > cut_values[:, np.newaxis]
def to_np_array(array):
    """Convert JAX arrays (also inside nested tuples) to NumPy arrays.

    Args:
        array: A JAX array, a tuple (possibly nested) of values, or anything
            else.

    Returns:
        * `array` unchanged when the module-level `np` alias is `jnp` itself;
        * a NumPy copy when `array` is a JAX array;
        * a new tuple with every element converted recursively when `array`
          is a tuple;
        * `array` unchanged otherwise.
    """
    if np is jnp:
        return array
    # `jax.Array` is the array type in current JAX; `jax.numpy.DeviceArray`
    # no longer exists there and is only consulted on old releases that
    # predate `jax.Array`.
    jax_array_type = getattr(jax, "Array", None) or jax.numpy.DeviceArray
    if isinstance(array, jax_array_type):
        return np.array(array)
    if isinstance(array, tuple):
        return tuple(to_np_array(element) for element in array)
    return array
def arg_approx_batch(
    array: np.ndarray,
    precision: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return index arrays of the selected entries for each row of a batch.

    When ``1 / array.shape[-1] >= precision`` the result is the
    ``(rows, argmax)`` pair from ``argmax_batch``. Otherwise it is
    ``np.nonzero`` of the mask produced by ``arg_approx_batch_mask``. Both
    results go through ``to_np_array`` first.
    """
    uniform_share = 1 / array.shape[-1]
    if uniform_share >= precision:
        return to_np_array(argmax_batch(array))
    mask = to_np_array(arg_approx_batch_mask(array, precision))
    return np.nonzero(mask)
def index_update(x, idx, y, inplace=False):
    """Return `x` with ``x[idx]`` set to `y`, for JAX or NumPy arrays.

    Args:
        x: JAX array, or NumPy array / array-like.
        idx: Index expression selecting the entries to overwrite.
        y: Value(s) to store.
        inplace: For non-JAX inputs, write into `x` itself instead of a copy.
            Not consulted for JAX arrays, whose update is functional.

    Returns:
        The updated array: a new JAX array for JAX input, `x` itself when
        `inplace` is true, otherwise an updated ``np.copy`` of `x`.
    """
    # `jax.Array` is the array type in current JAX; `jax.numpy.DeviceArray`
    # is only consulted on old releases that predate `jax.Array`.
    jax_array_type = getattr(jax, "Array", None) or jax.numpy.DeviceArray
    if isinstance(x, jax_array_type):
        # `.at[idx].set(y)` replaces `jax.ops.index_update`, which has been
        # removed from JAX.
        return x.at[idx].set(y)
    if not inplace:
        x = np.copy(x)
    x[idx] = y
    return x
def arg_approx_signed(array: np.ndarray, precision: float) -> np.ndarray:
    """Apply ``arg_approx`` separately to the positive and negative parts.

    Two working arrays are built: a copy of `array` and its negation, each
    with negative entries zeroed. ``arg_approx`` runs on both and the two
    index arrays are concatenated, positive part first.
    """
    positive_part = array.copy()
    negative_part = -array
    positive_part[positive_part < 0] = 0
    negative_part[negative_part < 0] = 0
    selected = [
        arg_approx(positive_part, precision),
        arg_approx(negative_part, precision),
    ]
    return np.concatenate(selected)
def repeat(a: int, repeats: int) -> np.ndarray:
    """Return a 1-D array holding `a` repeated `repeats` times.

    Args:
        a: Value to repeat.
        repeats: Number of repetitions; values below 1 give an empty array.

    Returns:
        ``np.repeat(a, repeats)`` for ``repeats > 1``, ``np.array([a])`` for
        ``repeats == 1``, and otherwise an empty array with the dtype NumPy
        infers for `a`.
    """
    if repeats > 1:
        return np.repeat(a, repeats)
    if repeats == 1:
        return np.array([a])
    # Give the empty result the dtype of `a`. A bare `np.array([])` is
    # float64 and would upcast integer arrays it is concatenated with.
    return np.empty(0, dtype=np.asarray(a).dtype)
def concatenate(a_tuple, axis=0, dtype=np.int64) -> np.ndarray:
    """Concatenate arrays, tolerating an empty sequence.

    ``np.concatenate`` is applied to `a_tuple` along `axis` when it has at
    least one element. For an empty `a_tuple` the result is an empty 1-D
    array of `dtype`; `dtype` is used only in that case.
    """
    has_parts = len(a_tuple) > 0
    if has_parts:
        return np.concatenate(a_tuple, axis)
    return np.array([], dtype=dtype)
# Element type carried through `filter_not_null` unchanged.
T = TypeVar("T")


def filter_not_null(iterable: Iterable[T]) -> Iterable[T]:
    """Lazily yield the items of `iterable` that are not ``None``.

    Only ``None`` is dropped; other falsy values such as ``0`` or ``""`` are
    kept.
    """
    return (item for item in iterable if item is not None)