This is pretty obscure, and possibly un-fixable. I encountered a weird torch error when benchmarking the population model log prob.
RuntimeError: r.nvmlDeviceGetNvLinkRemoteDeviceType_ INTERNAL ASSERT FAILED at "../c10/cuda/driver_api.cpp":27, please report a bug to PyTorch. Can't find nvmlDeviceGetNvLinkRemoteDeviceType: /lib64/libnvidia-ml.so.1: undefined symbol: nvmlDeviceGetNvLinkRemoteDeviceType
The full traceback passes through some vmap machinery (torch.vmap / torch.func.jacrev); I'll paste it in full at the bottom. Most interestingly, this error only triggered when the batch size I was passing to the log prob hit ~1,000,000 (2**20). So there's perhaps something that changes under the hood for large batch sizes (cool/interesting).
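Since the failure is batch-size dependent, a sweep over power-of-two batch sizes is a quick way to pin down the threshold where it kicks in. This is a hypothetical sketch, not the real bench.py: `log_prob` stands in for a wrapper around `model.log_prob(phi[:N])`, and the function names are made up for illustration.

```python
def find_failing_batch_size(log_prob, max_log2=21):
    """Sweep power-of-two batch sizes and return the first size at which
    `log_prob` raises a RuntimeError, or None if every size succeeds.

    `log_prob` is any callable taking a batch size; in the real benchmark
    it would slice the input and call model.log_prob(phi[:N]).
    """
    for k in range(max_log2 + 1):
        n = 2 ** k
        try:
            log_prob(n)
        except RuntimeError:
            # First batch size that trips the CUDA/NVML assert.
            return n
    return None
```

In my case such a sweep would have returned 1048576 (2**20), the first size at which the NVML assert fired.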
Full traceback...
batch size: 1048576 repeats: 2
Traceback (most recent call last):
  File "/cfs/home/stth6288/pop-cosmos-public/bench.py", line 31, in <module>
    p = model.log_prob(phi[:N])
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/flowfusion/diffusion.py", line 632, in log_prob
    xT, lp = self.score_model.solve_odes_forward(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/flowfusion/diffusion.py", line 292, in solve_odes_forward
    state = odeint_adjoint(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/adjoint.py", line 206, in odeint_adjoint
    ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/adjoint.py", line 24, in forward
    ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 80, in odeint
    solution = solver.integrate(t)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 32, in integrate
    self._before_integrate(t)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/rk_common.py", line 213, in _before_integrate
    f0 = self.func(t[0], self.y0)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 197, in forward
    return self.base_func(t, y)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 197, in forward
    return self.base_func(t, y)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 144, in forward
    f = self.base_func(t, _flat_to_shape(y, (), self.shapes))
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/flowfusion/diffusion.py", line 150, in forward
    divergence = torch.vmap(get_trace_of_jacobian, in_dims=in_dims)(x, self.conditional)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
    return _flat_vmap(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/flowfusion/diffusion.py", line 146, in get_trace_of_jacobian
    return torch.trace(torch.func.jacrev(f)(x_sample))
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 633, in wrapper_fn
    flat_jacobians_per_input = compute_jacobian_stacked()
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 564, in compute_jacobian_stacked
    chunked_result = vmap(vjp_fn)(basis)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
    return _flat_vmap(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 363, in wrapper
    result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 145, in _autograd_grad
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/autograd/__init__.py", line 412, in grad
    result = _engine_run_backward(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: r.nvmlDeviceGetNvLinkRemoteDeviceType_ INTERNAL ASSERT FAILED at "../c10/cuda/driver_api.cpp":27, please report a bug to PyTorch. Can't find nvmlDeviceGetNvLinkRemoteDeviceType: /lib64/libnvidia-ml.so.1: undefined symbol: nvmlDeviceGetNvLinkRemoteDeviceType
A quick Google took me to this SO question:
https://stackoverflow.com/questions/79185168/runtimeerror-r-nvmldevicegetnvlinkremotedevicetype-internal-assert-failed-at
I got the error when running on Sunrise with torch==2.3.0, cuda==11.8, and CUDA Driver Version: 465.19.01. Downgrading to torch==2.2.2 with the cu118 wheel seemed to fix the problem. IDK what, if anything, we should do about this. Maybe relaxing the minimum torch version further to torch>=2.2.0 might save others from this situation? Having said that, such a large batch size is right on the edge of what is doable (memory-wise) on a single GPU anyway, so maybe it doesn't matter too much. 🤷‍♂️
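If we do relax the pin rather than (or in addition to) anything else, a defensive check could at least surface the known-bad combination early. This is a hypothetical sketch: `parse_version` and `warn_if_known_bad` are made-up helper names, and the check only covers the one torch/CUDA combination I actually hit. The caller would pass in `torch.__version__` and `torch.version.cuda`.

```python
import warnings


def parse_version(version_string):
    """Return the numeric (major, minor, patch) prefix of a version
    string such as '2.3.0+cu118'."""
    numeric = version_string.split("+")[0]
    return tuple(int(part) for part in numeric.split(".")[:3])


def warn_if_known_bad(torch_version, cuda_version):
    """Warn about the torch 2.3.0 / CUDA 11.8 combination that hit the
    nvmlDeviceGetNvLinkRemoteDeviceType internal assert at large batch
    sizes. Returns True if the combination is the known-bad one."""
    if parse_version(torch_version) == (2, 3, 0) and cuda_version.startswith("11.8"):
        warnings.warn(
            "torch 2.3.0 with CUDA 11.8 has been seen to fail with an NVML "
            "internal assert at batch sizes around 2**20; torch 2.2.2 "
            "(cu118 wheel) worked."
        )
        return True
    return False
```

This obviously doesn't fix anything, but it turns an obscure C++-engine assert into an actionable warning for anyone running the same stack.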