|
13 | 13 | "metadata": {}, |
14 | 14 | "outputs": [], |
15 | 15 | "source": [ |
16 | | - "#|default_exp utils" |
| 16 | + "# |default_exp utils" |
17 | 17 | ] |
18 | 18 | }, |
19 | 19 | { |
|
22 | 22 | "metadata": {}, |
23 | 23 | "outputs": [], |
24 | 24 | "source": [ |
25 | | - "#|export\n", |
| 25 | + "# |export\n", |
26 | 26 | "import matplotlib.pyplot as plt\n", |
27 | | - "from matplotlib.collections import LineCollection\n", |
| 27 | + "from matplotlib.collections import LineCollection\n", |
28 | 28 | "import numpy as np\n", |
29 | 29 | "import jax\n", |
30 | 30 | "import jax.numpy as jnp\n", |
|
44 | 44 | "metadata": {}, |
45 | 45 | "outputs": [], |
46 | 46 | "source": [ |
47 | | - "#|export\n", |
48 | | - "key = jax.random.PRNGKey(0)\n", |
| 47 | + "# |export\n", |
| 48 | + "key = jax.random.PRNGKey(0)\n", |
49 | 49 | "logsumexp = jax.scipy.special.logsumexp" |
50 | 50 | ] |
51 | 51 | }, |
|
55 | 55 | "metadata": {}, |
56 | 56 | "outputs": [], |
57 | 57 | "source": [ |
58 | | - "#|export\n", |
| 58 | + "# |export\n", |
59 | 59 | "def keysplit(key, *ns):\n", |
60 | | - " if len(ns) == 0: \n", |
| 60 | + " if len(ns) == 0:\n", |
61 | 61 | " return jax.random.split(key, 1)[0]\n", |
62 | 62 | " elif len(ns) == 1:\n", |
63 | | - " n, = ns\n", |
64 | | - " if n == 1: return keysplit(key)\n", |
65 | | - " else: return jax.random.split(key, ns[0])\n", |
| 63 | + " (n,) = ns\n", |
| 64 | + " if n == 1:\n", |
| 65 | + " return keysplit(key)\n", |
| 66 | + " else:\n", |
| 67 | + " return jax.random.split(key, ns[0])\n", |
66 | 68 | " else:\n", |
67 | 69 | " keys = []\n", |
68 | | - " for n in ns: keys.append(keysplit(key, n))\n", |
69 | | - " return keys\n" |
| 70 | + " for n in ns:\n", |
| 71 | + " keys.append(keysplit(key, n))\n", |
| 72 | + " return keys" |
70 | 73 | ] |
71 | 74 | }, |
72 | 75 | { |
|
122 | 125 | "metadata": {}, |
123 | 126 | "outputs": [], |
124 | 127 | "source": [ |
125 | | - "#|export\n", |
| 128 | + "# |export\n", |
126 | 129 | "def bounding_box(arr, pad=0):\n", |
127 | 130 | " \"\"\"Takes a euclidean-like arr (`arr.shape[-1] == 2`) and returns its bounding box.\"\"\"\n", |
128 | | - " return jnp.array([\n", |
129 | | - " [jnp.min(arr[...,0])-pad, jnp.min(arr[...,1])-pad],\n", |
130 | | - " [jnp.max(arr[...,0])+pad, jnp.max(arr[...,1])+pad]\n", |
131 | | - " ])" |
| 131 | + " return jnp.array(\n", |
| 132 | + " [\n", |
| 133 | + " [jnp.min(arr[..., 0]) - pad, jnp.min(arr[..., 1]) - pad],\n", |
| 134 | + " [jnp.max(arr[..., 0]) + pad, jnp.max(arr[..., 1]) + pad],\n", |
| 135 | + " ]\n", |
| 136 | + " )" |
132 | 137 | ] |
133 | 138 | }, |
134 | 139 | { |
|
137 | 142 | "metadata": {}, |
138 | 143 | "outputs": [], |
139 | 144 | "source": [ |
140 | | - "#|export\n", |
| 145 | + "# |export\n", |
141 | 146 | "def argmax_axes(a, axes=None):\n", |
142 | 147 | " \"\"\"Argmax along specified axes\"\"\"\n", |
143 | | - " if axes is None: return jnp.argmax(a)\n", |
144 | | - " \n", |
145 | | - " n = len(axes) \n", |
146 | | - " axes_ = set(range(a.ndim))\n", |
| 148 | + " if axes is None:\n", |
| 149 | + " return jnp.argmax(a)\n", |
| 150 | + "\n", |
| 151 | + " n = len(axes)\n", |
| 152 | + " axes_ = set(range(a.ndim))\n", |
147 | 153 | " axes_0 = axes\n", |
148 | | - " axes_1 = sorted(axes_ - set(axes_0)) \n", |
149 | | - " axes_ = axes_0 + axes_1\n", |
| 154 | + " axes_1 = sorted(axes_ - set(axes_0))\n", |
| 155 | + " axes_ = axes_0 + axes_1\n", |
150 | 156 | "\n", |
151 | 157 | " b = jnp.transpose(a, axes=axes_)\n", |
152 | 158 | " c = b.reshape(np.prod(b.shape[:n]), -1)\n", |
153 | 159 | "\n", |
154 | 160 | " I = jnp.argmax(c, axis=0)\n", |
155 | | - " I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(b.shape[n:] + (n,))\n", |
| 161 | + " I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(\n", |
| 162 | + " b.shape[n:] + (n,)\n", |
| 163 | + " )\n", |
156 | 164 | "\n", |
157 | | - " return I" |
| 165 | + " return I" |
158 | 166 | ] |
159 | 167 | }, |
160 | 168 | { |
|
177 | 185 | "test_shape = (3, 99, 5, 9)\n", |
178 | 186 | "a = jnp.arange(np.prod(test_shape)).reshape(test_shape)\n", |
179 | 187 | "\n", |
180 | | - "I = argmax_axes(a, axes=[0,1])\n", |
| 188 | + "I = argmax_axes(a, axes=[0, 1])\n", |
181 | 189 | "I.shape" |
182 | 190 | ] |
183 | 191 | }, |
|
194 | 202 | "metadata": {}, |
195 | 203 | "outputs": [], |
196 | 204 | "source": [ |
197 | | - "#|export\n", |
198 | | - "def cam_to_screen(x): return jnp.array([x[0]/x[2], x[1]/x[2], jnp.linalg.norm(x)])\n", |
199 | | - "def screen_to_cam(y): return y[2]*jnp.array([y[0], y[1], 1.0])" |
| 205 | + "# |export\n", |
| 206 | + "def cam_to_screen(x):\n", |
| 207 | + " return jnp.array([x[0] / x[2], x[1] / x[2], jnp.linalg.norm(x)])\n", |
| 208 | + "\n", |
| 209 | + "\n", |
| 210 | + "def screen_to_cam(y):\n", |
| 211 | + " return y[2] * jnp.array([y[0], y[1], 1.0])" |
200 | 212 | ] |
201 | 213 | }, |
202 | 214 | { |
|
205 | 217 | "metadata": {}, |
206 | 218 | "outputs": [], |
207 | 219 | "source": [ |
208 | | - "#|export\n", |
209 | | - "def rot2d(hd): return jnp.array([\n", |
210 | | - " [jnp.cos(hd), -jnp.sin(hd)], \n", |
211 | | - " [jnp.sin(hd), jnp.cos(hd)]\n", |
212 | | - " ]);\n", |
| 220 | + "# |export\n", |
| 221 | + "def rot2d(hd):\n", |
| 222 | + " return jnp.array([[jnp.cos(hd), -jnp.sin(hd)], [jnp.sin(hd), jnp.cos(hd)]])\n", |
| 223 | + "\n", |
213 | 224 | "\n", |
214 | | - "def pack_2dpose(x,hd): \n", |
215 | | - " return jnp.concatenate([x,jnp.array([hd])])\n", |
| 225 | + "def pack_2dpose(x, hd):\n", |
| 226 | + " return jnp.concatenate([x, jnp.array([hd])])\n", |
216 | 227 | "\n", |
217 | | - "def apply_2dpose(p, ys): \n", |
218 | | - " return ys@rot2d(p[2] - jnp.pi/2).T + p[:2]\n", |
219 | 228 | "\n", |
220 | | - "def unit_vec(hd): \n", |
| 229 | + "def apply_2dpose(p, ys):\n", |
| 230 | + " return ys @ rot2d(p[2] - jnp.pi / 2).T + p[:2]\n", |
| 231 | + "\n", |
| 232 | + "\n", |
| 233 | + "def unit_vec(hd):\n", |
221 | 234 | " return jnp.array([jnp.cos(hd), jnp.sin(hd)])\n", |
222 | 235 | "\n", |
| 236 | + "\n", |
223 | 237 | "def adjust_angle(hd):\n", |
224 | 238 | " \"\"\"Adjusts angle to lie in the interval [-pi,pi).\"\"\"\n", |
225 | | - " return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi" |
| 239 | + " return (hd + jnp.pi) % (2 * jnp.pi) - jnp.pi" |
226 | 240 | ] |
227 | 241 | }, |
228 | 242 | { |
|
238 | 252 | "metadata": {}, |
239 | 253 | "outputs": [], |
240 | 254 | "source": [ |
241 | | - "#|export\n", |
| 255 | + "# |export\n", |
242 | 256 | "from genjax.incremental import UnknownChange, NoChange, Diff\n", |
243 | 257 | "\n", |
244 | 258 | "\n", |
245 | 259 | "def argdiffs(args, other=None):\n", |
246 | | - " return tuple(map(lambda v: Diff(v, UnknownChange), args))\n" |
| 260 | + " return tuple(map(lambda v: Diff(v, UnknownChange), args))" |
247 | 261 | ] |
248 | 262 | }, |
249 | 263 | { |
|
252 | 266 | "metadata": {}, |
253 | 267 | "outputs": [], |
254 | 268 | "source": [ |
255 | | - "#|export\n", |
| 269 | + "# |export\n", |
256 | 270 | "from builtins import property as _property, tuple as _tuple\n", |
257 | 271 | "from typing import Any\n", |
258 | 272 | "\n", |
259 | 273 | "\n", |
260 | 274 | "class Args(tuple):\n", |
261 | 275 | " def __new__(cls, *args, **kwargs):\n", |
262 | 276 | " return _tuple.__new__(cls, list(args) + list(kwargs.values()))\n", |
263 | | - " \n", |
| 277 | + "\n", |
264 | 278 | " def __init__(self, *args, **kwargs):\n", |
265 | 279 | " self._d = dict()\n", |
266 | | - " for k,v in kwargs.items():\n", |
| 280 | + " for k, v in kwargs.items():\n", |
267 | 281 | " self._d[k] = v\n", |
268 | 282 | " setattr(self, k, v)\n", |
269 | 283 | "\n", |
|
297 | 311 | "metadata": {}, |
298 | 312 | "outputs": [], |
299 | 313 | "source": [ |
300 | | - "#|export\n", |
301 | | - "# \n", |
| 314 | + "# |export\n", |
| 315 | + "#\n", |
302 | 316 | "# Monkey patching `sample` for `BuiltinGenerativeFunction`\n", |
303 | | - "# \n", |
| 317 | + "#\n", |
304 | 318 | "cls = genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction\n", |
305 | 319 | "\n", |
| 320 | + "\n", |
306 | 321 | "def genjax_sample(self, key, *args, **kwargs):\n", |
307 | 322 | " tr = self.simulate(key, args)\n", |
308 | 323 | " return tr.get_retval()\n", |
309 | 324 | "\n", |
| 325 | + "\n", |
310 | 326 | "setattr(cls, \"sample\", genjax_sample)\n", |
311 | 327 | "\n", |
312 | 328 | "\n", |
313 | | - "# \n", |
| 329 | + "#\n", |
314 | 330 | "# Monkey patching `sample` for `DeferredGenerativeFunctionCall`\n", |
315 | | - "# \n", |
| 331 | + "#\n", |
316 | 332 | "cls = genjax._src.generative_functions.supports_callees.SugaredGenerativeFunctionCall\n", |
317 | 333 | "\n", |
| 334 | + "\n", |
318 | 335 | "def deff_gen_func_call(self, key, **kwargs):\n", |
319 | 336 | " return self.gen_fn.sample(key, *self.args, **kwargs)\n", |
320 | 337 | "\n", |
| 338 | + "\n", |
321 | 339 | "def deff_gen_func_logpdf(self, x, **kwargs):\n", |
322 | 340 | " return self.gen_fn.logpdf(x, *self.args, **kwargs)\n", |
323 | 341 | "\n", |
| 342 | + "\n", |
324 | 343 | "setattr(cls, \"__call__\", deff_gen_func_call)\n", |
325 | 344 | "setattr(cls, \"sample\", deff_gen_func_call)\n", |
326 | 345 | "setattr(cls, \"logpdf\", deff_gen_func_logpdf)" |
|
0 commit comments