Performance Degradation in implicit_acoustic_marmousi_jax Example After Long Training Runs #5

@Parallelopiped

Description

Describe the bug
In the example located at seistorch/examples/nn_embedded_fwi/model_representation/implicit_acoustic_marmousi_jax/, I have observed a significant slowdown in iteration speed during the implicit FWI calculation when the number of training iterations is large (e.g., greater than 10,000). Initially, the training speed on my setup is around 12.6 iterations per second (it/s), but it gradually decreases to around 7 it/s or less, roughly halving the throughput. The performance degradation is accompanied by a noticeable drop in both CPU and GPU utilization, as well as power consumption.

I also tested with a single GPU, and the issue persists there as well.

To Reproduce
Steps to reproduce the behavior:

  1. Go to 'seistorch/examples/nn_embedded_fwi/model_representation/implicit_acoustic_marmousi_jax/'
  2. After the forward calculation has completed, run python step2_ifwi.py configure4x2.py
  3. Monitor the training speed, CPU/GPU utilization, and power consumption over time (a minimal timing sketch is shown after this list).
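
For step 3, this is a minimal sketch of how the iteration rate can be logged; train_step, params, and opt_state are hypothetical placeholders, not the actual seistorch API, and jax.block_until_ready is used so the timing reflects completed device work rather than asynchronous dispatch:

```python
import time
import jax

def monitor_loop(train_step, params, opt_state, n_iters=20_000, log_every=500):
    """Run a generic JAX training loop and print the iteration rate periodically."""
    t0 = time.perf_counter()
    for i in range(1, n_iters + 1):
        params, opt_state, loss = train_step(params, opt_state)
        if i % log_every == 0:
            # Wait for the dispatched work to finish so the measured rate
            # reflects real device time, not just how fast ops were enqueued.
            jax.block_until_ready(loss)
            t1 = time.perf_counter()
            print(f"iter {i}: {log_every / (t1 - t0):.1f} it/s")
            t0 = time.perf_counter()
    return params, opt_state
```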

Expected behavior
The iteration speed should remain consistent throughout the training process without significant drops in performance.

System Information

  • OS: Ubuntu 20.04.6 LTS x86_64
  • Kernel: 5.15.0-105-generic
  • Shell: zsh 5.8
  • CPU: AMD Ryzen 9 7900X (24 cores) @ 4.700GHz
  • GPU: NVIDIA RTX 4090 (Dual setup)
  • Memory: 128GiB

Environment:

  • jax 0.4.35 pypi_0 pypi
  • jaxlib 0.4.34 pypi_0 pypi
  • pytorch 2.5.1 py3.10_cuda12.4_cudnn9.1.0_0
  • cudatoolkit / nvcc: 12.1

I would appreciate any insights into what might be causing this slowdown and potential solutions to address it.
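
For context, one pattern I have seen blamed for gradual slowdowns and memory growth in long JAX loops (this is only a guess on my part, not something I have confirmed in the seistorch code) is keeping per-iteration device arrays alive on the Python side, e.g. by appending raw loss values to a history list. A hedged sketch of the safer pattern, with train_step, params, and opt_state again as hypothetical placeholders:

```python
import jax

def run_with_scalar_history(train_step, params, opt_state, n_iters=20_000):
    """Generic JAX loop that stores losses as host floats instead of device arrays."""
    loss_history = []
    for _ in range(n_iters):
        params, opt_state, loss = train_step(params, opt_state)
        # Copy the loss back to the host as a plain Python float so the
        # history list does not keep old device buffers alive.
        loss_history.append(float(jax.device_get(loss)))
    return params, opt_state, loss_history
```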
