Commit ced1fdd

fix bug in test
1 parent 37cdee6 commit ced1fdd

File tree

1 file changed (+34 -25)

test/test_differentiable.ipynb

Lines changed: 34 additions & 25 deletions
@@ -6,9 +6,10 @@
   "metadata": {},
   "outputs": [],
   "source": [
+  "import sys\n",
+  "sys.path.append(\"/Users/zizhouhuang/Desktop/polyfem-python/build/\")\n",
   "import polyfempy as pf\n",
   "import json\n",
-  "import numpy as np\n",
   "import torch\n",
   "\n",
   "torch.set_default_dtype(torch.float64)"
@@ -20,33 +21,36 @@
   "metadata": {},
   "outputs": [],
   "source": [
-  "# Differentiable simulator that computes shape derivatives\n",
+  "# Differentiable simulator that computes initial derivatives\n",
   "class Simulate(torch.autograd.Function):\n",
   "\n",
   "    @staticmethod\n",
-  "    def forward(ctx, solver, vertices):\n",
+  "    def forward(ctx, solver, body_ids, initial_velocities):\n",
   "        # Update solver setup\n",
-  "        solver.mesh().set_vertices(vertices)\n",
+  "        for bid, vel in zip(body_ids, initial_velocities):\n",
+  "            print(bid, vel)\n",
+  "            solver.set_initial_velocity(bid, vel)\n",
+  "        sys.stdout.flush()\n",
   "        # Enable caching intermediate variables in the simulation, which will be used for solve_adjoint\n",
   "        solver.set_cache_level(pf.CacheLevel.Derivatives)\n",
   "        # Run simulation\n",
   "        solver.solve()\n",
-  "        # Collect transient simulation solutions\n",
-  "        cache = solver.get_solution_cache()\n",
-  "        sol = torch.zeros((solver.ndof(), cache.size()))\n",
-  "        for t in range(cache.size()):\n",
-  "            sol[:, t] = torch.tensor(cache.solution(t))\n",
   "        # Cache solver for backward gradient propagation\n",
   "        ctx.solver = solver\n",
-  "        return sol\n",
+  "        ctx.bids = body_ids\n",
+  "        return torch.tensor(solver.get_solutions())\n",
   "\n",
   "    @staticmethod\n",
   "    @torch.autograd.function.once_differentiable\n",
   "    def backward(ctx, grad_output):\n",
   "        # solve_adjoint only needs to be called once per solver, independent of number of types of optimization variables\n",
-  "        ctx.solver.solve_adjoint(grad_output)\n",
-  "        # Compute shape derivatives\n",
-  "        return None, torch.tensor(pf.shape_derivative(ctx.solver))"
+  "        ctx.solver.solve_adjoint(grad_output.detach().numpy())\n",
+  "        # Compute initial derivatives\n",
+  "        grads = pf.initial_velocity_derivative(ctx.solver)\n",
+  "        flat_grad = torch.zeros((len(ctx.bids), len(grads[ctx.bids[0]])), dtype=float)\n",
+  "        for id, g in grads.items():\n",
+  "            flat_grad[ctx.bids.index(id), :] = torch.tensor(g)\n",
+  "        return None, None, flat_grad"
   ]
  },
  {
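The Simulate wrapper above follows the standard torch.autograd.Function pattern: forward runs the non-differentiable solver and stashes whatever backward will need on ctx, and backward maps the incoming adjoint grad_output to gradients of the differentiable inputs, returning None for the rest. Below is a minimal self-contained sketch of the same wiring; the ToySimulate class and its one-step "dynamics" are made up for illustration and do not use polyfempy.

import torch

class ToySimulate(torch.autograd.Function):
    # Hypothetical stand-in for Simulate: a single damped step instead of a FEM solve

    @staticmethod
    def forward(ctx, damping, initial_velocity):
        # "Run the simulation" and cache non-tensor state on ctx, like ctx.solver / ctx.bids above
        ctx.damping = damping
        return (1.0 - damping) * initial_velocity

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        # Chain rule through the one step; one return value per forward argument,
        # None for the non-differentiable damping constant
        return None, grad_output * (1.0 - ctx.damping)

v0 = torch.tensor([3.0, 0.0], requires_grad=True)
loss = torch.linalg.norm(ToySimulate.apply(0.1, v0))
loss.backward()
print(v0.grad)  # 0.81 * v0 / ||0.9 * v0||, i.e. the gradient of ||0.9 * v0||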
@@ -68,16 +72,12 @@
   "solver1.set_log_level(2)\n",
   "solver1.load_mesh_from_settings()\n",
   "\n",
-  "mesh = solver1.mesh()\n",
-  "v = mesh.vertices()\n",
-  "vertices = torch.tensor(solver1.mesh().vertices(), requires_grad=True)\n",
-  "\n",
   "# Simulation 2\n",
   "\n",
   "config[\"initial_conditions\"][\"velocity\"][0][\"value\"] = [3, 0]\n",
   "solver2 = pf.Solver()\n",
   "solver2.set_settings(json.dumps(config), False)\n",
-  "solver2.set_log_level(2)\n",
+  "solver2.set_log_level(1)\n",
   "solver2.load_mesh_from_settings()"
   ]
  },
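Both solvers are built from the same parsed config, with only the first body's initial velocity edited before re-serializing for solver2. A minimal sketch of that dict-editing pattern follows; the config shown is illustrative and contains only the key path touched above, not the full scene description.

import copy
import json

# Illustrative config: only the key path edited above, not the real scene file
config = {"initial_conditions": {"velocity": [{"id": 1, "value": [5, 0]}]}}

# Copy before editing if the original settings must stay intact; editing in place
# (as the notebook does) also works because each solver receives its own
# json.dumps(config) snapshot at set_settings time
config2 = copy.deepcopy(config)
config2["initial_conditions"]["velocity"][0]["value"] = [3, 0]

print(json.dumps(config2))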
@@ -89,16 +89,25 @@
   "source": [
   "\n",
   "# Verify gradient\n",
-  "\n",
-  "def loss(vertices):\n",
-  "    solutions1 = Simulate.apply(solver1, vertices)\n",
-  "    solutions2 = Simulate.apply(solver2, vertices)\n",
-  "    print(obj.value(vertices.reshape(-1,1).detach().numpy()))\n",
+  "def loss(param):\n",
+  "    solutions1 = Simulate.apply(solver1, body_ids, param)\n",
+  "    solutions2 = Simulate.apply(solver2, body_ids, param)\n",
   "    return torch.linalg.norm(solutions1[:, -1]) * torch.linalg.norm(solutions2[:, -1])\n",
   "\n",
   "torch.set_printoptions(12)\n",
   "\n",
-  "param = vertices.clone().detach().requires_grad_(True)\n",
+  "dt = 0.04\n",
+  "solver1.set_cache_level(pf.CacheLevel.Derivatives)\n",
+  "solver1.build_basis()\n",
+  "solver1.assemble()\n",
+  "solver1.init_timestepping(0, dt)\n",
+  "solver2.set_cache_level(pf.CacheLevel.Derivatives)\n",
+  "solver2.build_basis()\n",
+  "solver2.assemble()\n",
+  "solver2.init_timestepping(0, dt)\n",
+  "param = torch.tensor([[5., 0], [0, 0]], requires_grad=True)\n",
+  "body_ids = [1, 3]\n",
+  "\n",
   "theta = torch.randn_like(param)\n",
   "l = loss(param)\n",
   "l.backward()\n",
@@ -109,7 +118,7 @@
   "    f1 = loss(param + theta * t)\n",
   "    f2 = loss(param - theta * t)\n",
   "    fd = (f1 - f2) / (2 * t)\n",
-  "    print(f'grad {analytic}, fd {fd} {(f1 - l) / t} {(l - f2) / t}, relative err {abs(analytic - fd) / abs(analytic):.3e}')\n",
+  "    print(f'\\ngrad {analytic}, fd {fd} {(f1 - l) / t} {(l - f2) / t}, relative err {abs(analytic - fd) / abs(analytic):.3e}')\n",
   "    print(f'f(x+dx)={f1}, f(x)={l.detach()}, f(x-dx)={f2}')\n",
   "    assert(abs(analytic - fd) <= 1e-4 * abs(analytic))"
   ]
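The loop above is a central-difference gradient check: it shrinks the perturbation t until (f(x + t*theta) - f(x - t*theta)) / (2t) agrees with the analytic directional derivative to a 1e-4 relative tolerance. The variable analytic is defined outside the hunks shown; presumably it is the dot product of param.grad with the random direction theta. A self-contained version of the same check on a toy loss, with that assumption made explicit:

import torch

torch.set_default_dtype(torch.float64)

def toy_loss(p):
    # Stand-in for the simulation loss; any smooth scalar function of the parameters works
    return (p ** 3).sum() + torch.linalg.norm(p)

param = torch.tensor([[5.0, 0.0], [0.0, 0.0]], requires_grad=True)
theta = torch.randn_like(param)

l = toy_loss(param)
l.backward()
# Assumed definition of the analytic value: directional derivative of the loss along theta
analytic = (param.grad * theta).sum().item()

t = 1e-2
for _ in range(5):
    t /= 2
    with torch.no_grad():
        f1 = toy_loss(param + theta * t)
        f2 = toy_loss(param - theta * t)
    fd = ((f1 - f2) / (2 * t)).item()  # central difference, O(t^2) error
    print(f"grad {analytic}, fd {fd}, relative err {abs(analytic - fd) / abs(analytic):.3e}")

assert abs(analytic - fd) <= 1e-4 * abs(analytic)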
