CUDA driver sensitivity of vmap (?) #3

@stevet40

This is pretty obscure, and possibly unfixable. I hit a weird torch error while benchmarking the population model's log prob:

RuntimeError: r.nvmlDeviceGetNvLinkRemoteDeviceType_ INTERNAL ASSERT FAILED at "../c10/cuda/driver_api.cpp":27, please report a bug to PyTorch. Can't find nvmlDeviceGetNvLinkRemoteDeviceType: /lib64/libnvidia-ml.so.1: undefined symbol: nvmlDeviceGetNvLinkRemoteDeviceType
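The message says the symbol is missing from the system's NVML library, so one way to check a given machine without going through torch is to try resolving the symbol directly. This is a minimal sketch using `ctypes`; the helper name `nvml_exports` is made up for illustration, and the default library name is taken from the error message above.

```python
import ctypes


def nvml_exports(symbol, libname="libnvidia-ml.so.1"):
    """Report whether the NVML shared library exports `symbol`.

    Returns True or False if the library loads, or None if it cannot be
    loaded at all (e.g. no NVIDIA driver installed on this machine).
    """
    try:
        lib = ctypes.CDLL(libname)
    except OSError:
        return None
    # ctypes resolves symbols lazily on attribute access and raises
    # AttributeError when the symbol is absent, which hasattr() turns
    # into False.
    return hasattr(lib, symbol)


if __name__ == "__main__":
    # Symbol name copied from the RuntimeError in this issue.
    print(nvml_exports("nvmlDeviceGetNvLinkRemoteDeviceType"))
```

A `False` result would match the "undefined symbol" part of the error; it does not by itself explain why torch only asks for the symbol at large batch sizes.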

The full traceback (pasted at the bottom) passes through torch.vmap and torch.func.jacrev before failing in the autograd backward pass. Most interestingly, the error only triggered once the batch size I was passing to the log prob hit ~1,000,000 (the failing run below used 1048576). So perhaps something changes under the hood for large batch sizes (cool/interesting).
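Since the error only shows up at very large batch sizes, one thing to try is evaluating the log prob in smaller pieces. The sketch below is a generic helper, not part of flowfusion; the name `apply_in_chunks` and the chunk size in the usage comment are assumptions, and whether chunking actually avoids the error has not been tested here.

```python
def apply_in_chunks(fn, batch, chunk_size):
    """Call `fn` on consecutive slices of `batch`; return the per-chunk results.

    `batch` only needs to support len() and slicing, so this works for
    lists as well as tensors indexed along their first dimension.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [
        fn(batch[start:start + chunk_size])
        for start in range(0, len(batch), chunk_size)
    ]


# Hypothetical usage with the call from bench.py in the traceback
# (the chunk size 2**18 is an arbitrary choice, not a recommendation):
#   parts = apply_in_chunks(model.log_prob, phi[:N], 2**18)
#   p = torch.cat(parts)
```

The helper returns a list rather than concatenating so that it stays independent of torch; the caller decides how to combine the pieces.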

A quick Google took me to this SO question:
https://stackoverflow.com/questions/79185168/runtimeerror-r-nvmldevicegetnvlinkremotedevicetype-internal-assert-failed-at

I got the error when running on Sunrise with torch==2.3.0, CUDA 11.8, and CUDA driver version 465.19.01. Downgrading to torch==2.2.2 with the cu118 wheel seemed to fix the problem. I don't know what, if anything, we should do about this. Relaxing the minimum torch version further to torch>=2.2.0 might save others from this situation. Having said that, such a large batch size is right on the edge of what is doable (memory-wise) on a single GPU anyway, so maybe it doesn't matter too much. 🤷‍♂️
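If it were useful to flag this combination for other users, a small check along these lines could do it. This is only a sketch: both function names are hypothetical, and the rule encodes nothing more than the single data point reported here (torch 2.3.x with a 465-series driver), so it is a heuristic rather than a known compatibility boundary.

```python
import re


def version_tuple(text):
    """Leading numeric components of a version string.

    For example '2.3.0+cu118' gives (2, 3, 0) and '465.19.01' gives
    (465, 19, 1).
    """
    match = re.match(r"\d+(?:\.\d+)*", text.strip())
    if match is None:
        raise ValueError(f"no version number found in {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def matches_reported_failure(torch_version, driver_version):
    """True if the versions look like the combination from this issue.

    Assumption: only torch 2.3.x together with a 465.x driver is flagged,
    because that is the only failing setup reported so far.
    """
    torch_v = version_tuple(torch_version)
    driver_v = version_tuple(driver_version)
    return torch_v[:2] == (2, 3) and driver_v[:1] == (465,)
```

In practice the torch version would come from `torch.__version__` and the driver version from something like `nvidia-smi`; both are passed in as strings here to keep the check free of side effects.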

Full traceback...

batch size: 1048576 repeats: 2
Traceback (most recent call last):
  File "/cfs/home/stth6288/pop-cosmos-public/bench.py", line 31, in <module>
    p = model.log_prob(phi[:N])
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/flowfusion/diffusion.py", line 632, in log_prob
    xT, lp = self.score_model.solve_odes_forward(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/flowfusion/diffusion.py", line 292, in solve_odes_forward
    state = odeint_adjoint(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/adjoint.py", line 206, in odeint_adjoint
    ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/adjoint.py", line 24, in forward
    ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 80, in odeint
    solution = solver.integrate(t)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 32, in integrate
    self._before_integrate(t)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/rk_common.py", line 213, in _before_integrate
    f0 = self.func(t[0], self.y0)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 197, in forward
    return self.base_func(t, y)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 197, in forward
    return self.base_func(t, y)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 144, in forward
    f = self.base_func(t, _flat_to_shape(y, (), self.shapes))
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/flowfusion/diffusion.py", line 150, in forward
    divergence = torch.vmap(get_trace_of_jacobian, in_dims=in_dims)(x, self.conditional)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
    return _flat_vmap(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/flowfusion/diffusion.py", line 146, in get_trace_of_jacobian
    return torch.trace(torch.func.jacrev(f)(x_sample))
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 633, in wrapper_fn
    flat_jacobians_per_input = compute_jacobian_stacked()
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 564, in compute_jacobian_stacked
    chunked_result = vmap(vjp_fn)(basis)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
    return _flat_vmap(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 363, in wrapper
    result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 145, in _autograd_grad
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/autograd/__init__.py", line 412, in grad
    result = _engine_run_backward(
  File "/cfs/home/stth6288/.local-co/envs/st_torch/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: r.nvmlDeviceGetNvLinkRemoteDeviceType_ INTERNAL ASSERT FAILED at "../c10/cuda/driver_api.cpp":27, please report a bug to PyTorch. Can't find nvmlDeviceGetNvLinkRemoteDeviceType: /lib64/libnvidia-ml.so.1: undefined symbol: nvmlDeviceGetNvLinkRemoteDeviceType

Labels: bug, question, wontfix