diff --git a/.github/workflows/linux_cuda_wheel.yaml b/.github/workflows/linux_cuda_wheel.yaml index bd57cac5..7329e460 100644 --- a/.github/workflows/linux_cuda_wheel.yaml +++ b/.github/workflows/linux_cuda_wheel.yaml @@ -137,7 +137,7 @@ jobs: ls - name: Run Python tests run: | - ${CONDA_RUN} FAIL_WITHOUT_CUDA=1 pytest test -v --tb=short + ${CONDA_RUN} FAIL_WITHOUT_CUDA=1 pytest test -v --tb=short -k test_num_channels - name: Run Python benchmark run: | ${CONDA_RUN} time python benchmarks/decoders/gpu_benchmark.py --devices=cuda:0,cpu --resize_devices=none diff --git a/.github/workflows/linux_wheel.yaml b/.github/workflows/linux_wheel.yaml index 1855e904..eaaa4235 100644 --- a/.github/workflows/linux_wheel.yaml +++ b/.github/workflows/linux_wheel.yaml @@ -123,4 +123,4 @@ jobs: ls - name: Run Python tests run: | - pytest test -vvv + pytest test -vvv -k test_num_channels diff --git a/.github/workflows/macos_wheel.yaml b/.github/workflows/macos_wheel.yaml index ee436b7a..2c39e9dc 100644 --- a/.github/workflows/macos_wheel.yaml +++ b/.github/workflows/macos_wheel.yaml @@ -122,4 +122,4 @@ jobs: - name: Run Python tests run: | - pytest test -vvv + pytest test -vvv -k test_num_channels diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 5a6a0d7e..ce33613c 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -99,8 +99,18 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) { } // namespace AudioEncoder::~AudioEncoder() { - if (avFormatContext_ && avFormatContext_->pb && !avioContextHolder_) { - avio_close(avFormatContext_->pb); + close_avio(); +} + +void AudioEncoder::close_avio() { + if (avFormatContext_ && avFormatContext_->pb) { + avio_flush(avFormatContext_->pb); + + if (!avioContextHolder_) { + avio_close(avFormatContext_->pb); + // avoids closing again in destructor, which would segfault. + avFormatContext_->pb = nullptr; + } } } @@ -308,6 +318,8 @@ void AudioEncoder::encode() { status == AVSUCCESS, "Error in: av_write_trailer", getFFMPEGErrorStringFromErrorCode(status)); + + close_avio(); } UniqueAVFrame AudioEncoder::maybeConvertAVFrame(const UniqueAVFrame& avFrame) { diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 04ed6a13..723849f0 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -33,6 +33,7 @@ class AudioEncoder { void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); void maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket); void flushBuffers(); + void close_avio(); UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; diff --git a/test/test_encoders.py b/test/test_encoders.py index d432263a..73410831 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -367,8 +367,9 @@ def test_contiguity(self): @pytest.mark.parametrize("num_channels_input", (1, 2)) @pytest.mark.parametrize("num_channels_output", (1, 2, None)) @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("i", range(10_000)) def test_num_channels( - self, num_channels_input, num_channels_output, method, tmp_path + self, num_channels_input, num_channels_output, method, tmp_path, i ): # We just check that the num_channels parameter is respected. # Correctness is checked in other tests (like test_against_cli())