|
6 | 6 | "metadata": {}, |
7 | 7 | "outputs": [], |
8 | 8 | "source": [ |
| 9 | + "import sys\n", |
| 10 | + "sys.path.append(\"/Users/zizhouhuang/Desktop/polyfem-python/build/\")\n", |
9 | 11 | "import polyfempy as pf\n", |
10 | 12 | "import json\n", |
11 | | - "import numpy as np\n", |
12 | 13 | "import torch\n", |
13 | 14 | "\n", |
14 | 15 | "torch.set_default_dtype(torch.float64)" |
|
20 | 21 | "metadata": {}, |
21 | 22 | "outputs": [], |
22 | 23 | "source": [ |
23 | | - "# Differentiable simulator that computes shape derivatives\n", |
| 24 | + "# Differentiable simulator that computes initial derivatives\n", |
24 | 25 | "class Simulate(torch.autograd.Function):\n", |
25 | 26 | "\n", |
26 | 27 | " @staticmethod\n", |
27 | | - " def forward(ctx, solver, vertices):\n", |
| 28 | + " def forward(ctx, solver, body_ids, initial_velocities):\n", |
28 | 29 | " # Update solver setup\n", |
29 | | - " solver.mesh().set_vertices(vertices)\n", |
| 30 | + " for bid, vel in zip(body_ids, initial_velocities):\n", |
| 31 | + " print(bid, vel)\n", |
| 32 | + " solver.set_initial_velocity(bid, vel)\n", |
| 33 | + " sys.stdout.flush()\n", |
30 | 34 | " # Enable caching intermediate variables in the simulation, which will be used for solve_adjoint\n", |
31 | 35 | " solver.set_cache_level(pf.CacheLevel.Derivatives)\n", |
32 | 36 | " # Run simulation\n", |
33 | 37 | " solver.solve()\n", |
34 | | - " # Collect transient simulation solutions\n", |
35 | | - " cache = solver.get_solution_cache()\n", |
36 | | - " sol = torch.zeros((solver.ndof(), cache.size()))\n", |
37 | | - " for t in range(cache.size()):\n", |
38 | | - " sol[:, t] = torch.tensor(cache.solution(t))\n", |
39 | 38 | " # Cache solver for backward gradient propagation\n", |
40 | 39 | " ctx.solver = solver\n", |
41 | | - " return sol\n", |
| 40 | + " ctx.bids = body_ids\n", |
| 41 | + " return torch.tensor(solver.get_solutions())\n", |
42 | 42 | "\n", |
43 | 43 | " @staticmethod\n", |
44 | 44 | " @torch.autograd.function.once_differentiable\n", |
45 | 45 | " def backward(ctx, grad_output):\n", |
46 | 46 | " # solve_adjoint only needs to be called once per solver, independent of number of types of optimization variables\n", |
47 | | - " ctx.solver.solve_adjoint(grad_output)\n", |
48 | | - " # Compute shape derivatives\n", |
49 | | - " return None, torch.tensor(pf.shape_derivative(ctx.solver))" |
| 47 | + " ctx.solver.solve_adjoint(grad_output.detach().numpy())\n", |
| 48 | + " # Compute initial derivatives\n", |
| 49 | + " grads = pf.initial_velocity_derivative(ctx.solver)\n", |
| 50 | + " flat_grad = torch.zeros((len(ctx.bids), len(grads[ctx.bids[0]])), dtype=float)\n", |
| 51 | + " for id, g in grads.items():\n", |
| 52 | + " flat_grad[ctx.bids.index(id), :] = torch.tensor(g)\n", |
| 53 | + " return None, None, flat_grad" |
50 | 54 | ] |
51 | 55 | }, |
52 | 56 | { |
|
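With this change, forward() no longer takes mesh vertices; it takes the solver, a list of body ids, and a matching tensor of per-body initial velocities. A minimal usage sketch (hypothetical, not part of this commit; it assumes get_solutions() returns the transient solutions as an (ndof, n_steps) array, consistent with the loss below indexing [:, -1], and reuses solver1, body_ids, and param as set up later in the notebook):

# hypothetical usage sketch
param = torch.tensor([[5., 0.], [0., 0.]], requires_grad=True)  # one row of initial velocity per body id
body_ids = [1, 3]
solutions = Simulate.apply(solver1, body_ids, param)            # (ndof, n_steps)
objective = torch.linalg.norm(solutions[:, -1])                 # norm of the final frame
objective.backward()                                            # adjoint solve fills param.grad
print(param.grad)
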
68 | 72 | "solver1.set_log_level(2)\n", |
69 | 73 | "solver1.load_mesh_from_settings()\n", |
70 | 74 | "\n", |
71 | | - "mesh = solver1.mesh()\n", |
72 | | - "v = mesh.vertices()\n", |
73 | | - "vertices = torch.tensor(solver1.mesh().vertices(), requires_grad=True)\n", |
74 | | - "\n", |
75 | 75 | "# Simulation 2\n", |
76 | 76 | "\n", |
77 | 77 | "config[\"initial_conditions\"][\"velocity\"][0][\"value\"] = [3, 0]\n", |
78 | 78 | "solver2 = pf.Solver()\n", |
79 | 79 | "solver2.set_settings(json.dumps(config), False)\n", |
80 | | - "solver2.set_log_level(2)\n", |
| 80 | + "solver2.set_log_level(1)\n", |
81 | 81 | "solver2.load_mesh_from_settings()" |
82 | 82 | ] |
83 | 83 | }, |
|
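The assignment to config["initial_conditions"]["velocity"][0]["value"] above assumes the scene JSON already declares per-body initial velocities. A hypothetical minimal fragment of that config (key names other than "initial_conditions", "velocity", and "value" are assumptions, not shown in this diff):

# hypothetical sketch of the relevant part of the scene config
config = {
    # ... geometry, materials, time stepping ...
    "initial_conditions": {
        "velocity": [
            {"id": 1, "value": [5, 0]},  # body 1, overridden per simulation above
            {"id": 3, "value": [0, 0]},  # body 3
        ]
    },
}
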
89 | 89 | "source": [ |
90 | 90 | "\n", |
91 | 91 | "# Verify gradient\n", |
92 | | - "\n", |
93 | | - "def loss(vertices):\n", |
94 | | - " solutions1 = Simulate.apply(solver1, vertices)\n", |
95 | | - " solutions2 = Simulate.apply(solver2, vertices)\n", |
96 | | - " print(obj.value(vertices.reshape(-1,1).detach().numpy()))\n", |
| 92 | + "def loss(param):\n", |
| 93 | + " solutions1 = Simulate.apply(solver1, body_ids, param)\n", |
| 94 | + " solutions2 = Simulate.apply(solver2, body_ids, param)\n", |
97 | 95 | " return torch.linalg.norm(solutions1[:, -1]) * torch.linalg.norm(solutions2[:, -1])\n", |
98 | 96 | "\n", |
99 | 97 | "torch.set_printoptions(12)\n", |
100 | 98 | "\n", |
101 | | - "param = vertices.clone().detach().requires_grad_(True)\n", |
| 99 | + "dt = 0.04\n", |
| 100 | + "solver1.set_cache_level(pf.CacheLevel.Derivatives)\n", |
| 101 | + "solver1.build_basis()\n", |
| 102 | + "solver1.assemble()\n", |
| 103 | + "solver1.init_timestepping(0, dt)\n", |
| 104 | + "solver2.set_cache_level(pf.CacheLevel.Derivatives)\n", |
| 105 | + "solver2.build_basis()\n", |
| 106 | + "solver2.assemble()\n", |
| 107 | + "solver2.init_timestepping(0, dt)\n", |
| 108 | + "param = torch.tensor([[5., 0], [0, 0]], requires_grad=True)\n", |
| 109 | + "body_ids = [1, 3]\n", |
| 110 | + "\n", |
102 | 111 | "theta = torch.randn_like(param)\n", |
103 | 112 | "l = loss(param)\n", |
104 | 113 | "l.backward()\n", |
|
109 | 118 | " f1 = loss(param + theta * t)\n", |
110 | 119 | " f2 = loss(param - theta * t)\n", |
111 | 120 | " fd = (f1 - f2) / (2 * t)\n", |
112 | | - " print(f'grad {analytic}, fd {fd} {(f1 - l) / t} {(l - f2) / t}, relative err {abs(analytic - fd) / abs(analytic):.3e}')\n", |
| 121 | + " print(f'\\ngrad {analytic}, fd {fd} {(f1 - l) / t} {(l - f2) / t}, relative err {abs(analytic - fd) / abs(analytic):.3e}')\n", |
113 | 122 | " print(f'f(x+dx)={f1}, f(x)={l.detach()}, f(x-dx)={f2}')\n", |
114 | 123 | " assert(abs(analytic - fd) <= 1e-4 * abs(analytic))" |
115 | 124 | ] |
|
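Once the finite-difference check passes, the same Simulate function can drive an optimization over the initial velocities. A sketch of such a follow-up (hypothetical, not part of this commit; it reuses loss, param, and the two solvers exactly as defined above):

# hypothetical follow-up: minimize the product-of-final-norms loss over the initial velocities
optimizer = torch.optim.Adam([param], lr=0.1)
for step in range(20):
    optimizer.zero_grad()
    l = loss(param)       # two forward simulations
    l.backward()          # two adjoint solves + initial-velocity derivatives
    optimizer.step()
    print(step, l.item())
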