From 4fd3c8567726f43612c5e2edf50bbc4b4123b84e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 13:10:11 +0100 Subject: [PATCH 01/20] WIP --- src/torchcodec/_core/AVIOFileLikeContext.cpp | 43 ++++++++++++++++++-- src/torchcodec/_core/AVIOFileLikeContext.h | 14 +++++-- src/torchcodec/_core/Encoder.cpp | 38 ++++++++++++++--- src/torchcodec/_core/Encoder.h | 12 +++++- src/torchcodec/_core/__init__.py | 1 + src/torchcodec/_core/ops.py | 22 ++++++++++ src/torchcodec/_core/pybind_ops.cpp | 25 ++++++++++++ src/torchcodec/encoders/_audio_encoder.py | 32 +++++++++++++++ 8 files changed, 174 insertions(+), 13 deletions(-) diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index 5497f89bb..bbba5301a 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -10,20 +10,38 @@ namespace facebook::torchcodec { AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) + : AVIOFileLikeContext(fileLike, false) {} + +std::unique_ptr AVIOFileLikeContext::createForWriting(py::object fileLike) { + return std::unique_ptr(new AVIOFileLikeContext(fileLike, true)); +} + +AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike, bool isWriteMode) : fileLike_{UniquePyObject(new py::object(fileLike))} { { // TODO: Is it necessary to acquire the GIL here? Is it maybe even // harmful? At the moment, this is only called from within a pybind // function, and pybind guarantees we have the GIL. py::gil_scoped_acquire gil; - TORCH_CHECK( - py::hasattr(fileLike, "read"), - "File like object must implement a read method."); + if (isWriteMode) { + TORCH_CHECK( + py::hasattr(fileLike, "write"), + "File like object must implement a write method."); + } else { + TORCH_CHECK( + py::hasattr(fileLike, "read"), + "File like object must implement a read method."); + } TORCH_CHECK( py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } - createAVIOContext(&read, nullptr, &seek, &fileLike_); + + if (isWriteMode) { + createAVIOContext(nullptr, &write, &seek, &fileLike_); + } else { + createAVIOContext(&read, nullptr, &seek, &fileLike_); + } } int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { @@ -66,6 +84,23 @@ int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { return totalNumRead == 0 ? AVERROR_EOF : totalNumRead; } +int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) { + auto fileLike = static_cast(opaque); + + // Note that we acquire the GIL outside of the loop. This is likely more + // efficient than releasing and acquiring it each loop iteration. + py::gil_scoped_acquire gil; + + // Create a bytes object from the buffer + py::bytes data_bytes(reinterpret_cast(buf), buf_size); + + // Call the Python write method + auto bytes_written = (*fileLike)->attr("write")(data_bytes); + + // Python write() should return the number of bytes written + return py::cast(bytes_written); +} + int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { // We do not know the file size. if (whence == AVSEEK_SIZE) { diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index 3e80f1c6f..19cd33107 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -15,14 +15,22 @@ namespace py = pybind11; namespace facebook::torchcodec { -// Enables uers to pass in a Python file-like object. We then forward all read -// and seek calls back up to the methods on the Python object. -class AVIOFileLikeContext : public AVIOContextHolder { +// Enables users to pass in a Python file-like object. We then forward all read, +// write and seek calls back up to the methods on the Python object. +class __attribute__((visibility("default"))) AVIOFileLikeContext : public AVIOContextHolder { public: + // Constructor for reading from a file-like object explicit AVIOFileLikeContext(py::object fileLike); + + // Constructor for writing to a file-like object + static std::unique_ptr createForWriting(py::object fileLike); private: + // Private constructor for write mode + AVIOFileLikeContext(py::object fileLike, bool isWriteMode); + static int read(void* opaque, uint8_t* buf, int buf_size); + static int write(void* opaque, const uint8_t* buf, int buf_size); static int64_t seek(void* opaque, int64_t offset, int whence); // Note that we dynamically allocate the Python object because we need to diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 905794328..d106d2c35 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -138,7 +138,7 @@ AudioEncoder::AudioEncoder( std::unique_ptr avioContextHolder, const AudioStreamOptions& audioStreamOptions) : samples_(validateSamples(samples)), - avioContextHolder_(std::move(avioContextHolder)) { + avioTensorContextHolder_(std::move(avioContextHolder)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( @@ -153,7 +153,35 @@ AudioEncoder::AudioEncoder( getFFMPEGErrorStringFromErrorCode(status)); avFormatContext_.reset(avFormatContext); - avFormatContext_->pb = avioContextHolder_->getAVIOContext(); + avFormatContext_->pb = avioTensorContextHolder_->getAVIOContext(); + + initializeEncoder(sampleRate, audioStreamOptions); +} + +AudioEncoder::AudioEncoder( + const torch::Tensor& samples, + int sampleRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + const AudioStreamOptions& audioStreamOptions) + : samples_(validateSamples(samples)), + avioFileLikeContextHolder_(std::move(avioContextHolder)) { + setFFmpegLogLevel(); + AVFormatContext* avFormatContext = nullptr; + + int status = avformat_alloc_output_context2( + &avFormatContext, nullptr, formatName.data(), nullptr); + + TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext for file-like object. ", + "Check the desired format? Got format=", + formatName, + ". ", + getFFMPEGErrorStringFromErrorCode(status)); + avFormatContext_.reset(avFormatContext); + + avFormatContext_->pb = avioFileLikeContextHolder_->getAVIOContext(); initializeEncoder(sampleRate, audioStreamOptions); } @@ -217,10 +245,10 @@ void AudioEncoder::initializeEncoder( torch::Tensor AudioEncoder::encodeToTensor() { TORCH_CHECK( - avioContextHolder_ != nullptr, - "Cannot encode to tensor, avio context doesn't exist."); + avioTensorContextHolder_ != nullptr, + "Cannot encode to tensor, avio tensor context doesn't exist."); encode(); - return avioContextHolder_->getOutputTensor(); + return avioTensorContextHolder_->getOutputTensor(); } void AudioEncoder::encode() { diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index ea5009018..0e5d7c8b0 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -1,6 +1,7 @@ #pragma once #include #include "src/torchcodec/_core/AVIOTensorContext.h" +#include "src/torchcodec/_core/AVIOFileLikeContext.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/StreamOptions.h" @@ -26,6 +27,12 @@ class AudioEncoder { std::string_view formatName, std::unique_ptr avioContextHolder, const AudioStreamOptions& audioStreamOptions); + AudioEncoder( + const torch::Tensor& samples, + int sampleRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + const AudioStreamOptions& audioStreamOptions); void encode(); torch::Tensor encodeToTensor(); @@ -50,7 +57,10 @@ class AudioEncoder { const torch::Tensor samples_; // Stores the AVIOContext for the output tensor buffer. - std::unique_ptr avioContextHolder_; + std::unique_ptr avioTensorContextHolder_; + + // Stores the AVIOContext for file-like object output. + std::unique_ptr avioFileLikeContextHolder_; bool encodeWasCalled_ = false; }; diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 77fc7b857..3d340bff8 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -23,6 +23,7 @@ create_from_file_like, create_from_tensor, encode_audio_to_file, + encode_audio_to_file_like, encode_audio_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index a68b51e22..b4392c93d 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -153,6 +153,28 @@ def create_from_file_like( return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode)) +def encode_audio_to_file_like( + samples: torch.Tensor, + sample_rate: int, + format: str, + file_like: Union[io.RawIOBase, io.BufferedIOBase], + bit_rate: Optional[int] = None, + num_channels: Optional[int] = None, +) -> None: + """Encode audio samples to a file-like object. + + Args: + samples: Audio samples tensor + sample_rate: Sample rate in Hz + format: Audio format (e.g., "wav", "mp3", "flac") + file_like: File-like object that supports write() and seek() methods + bit_rate: Optional bit rate for encoding + num_channels: Optional number of output channels + """ + assert _pybind_ops is not None + _pybind_ops.encode_audio_to_file_like(samples, sample_rate, format, file_like, bit_rate, num_channels) + + # ============================== # Abstract impl for the operators. Needed by torch.compile. # ============================== diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 6f873f5af..616643f3e 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -11,6 +11,8 @@ #include "src/torchcodec/_core/AVIOFileLikeContext.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" +#include "src/torchcodec/_core/Encoder.h" +#include "src/torchcodec/_core/StreamOptions.h" namespace py = pybind11; @@ -38,8 +40,31 @@ int64_t create_from_file_like( return reinterpret_cast(decoder); } +void encode_audio_to_file_like( + const torch::Tensor& samples, + int64_t sample_rate, + const std::string& format, + py::object file_like, + std::optional bit_rate = std::nullopt, + std::optional num_channels = std::nullopt) { + AudioStreamOptions audioStreamOptions; + audioStreamOptions.bitRate = bit_rate; + audioStreamOptions.numChannels = num_channels; + + auto avioContextHolder = AVIOFileLikeContext::createForWriting(file_like); + + AudioEncoder encoder( + samples, + static_cast(sample_rate), + format, + std::move(avioContextHolder), + audioStreamOptions); + encoder.encode(); +} + PYBIND11_MODULE(decoder_core_pybind_ops, m) { m.def("create_from_file_like", &create_from_file_like); + m.def("encode_audio_to_file_like", &encode_audio_to_file_like); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py index 742ea9080..edaa9c5f0 100644 --- a/src/torchcodec/encoders/_audio_encoder.py +++ b/src/torchcodec/encoders/_audio_encoder.py @@ -98,3 +98,35 @@ def to_tensor( bit_rate=bit_rate, num_channels=num_channels, ) + + def encode_to_file_like( + self, + file_like, + format: str, + *, + bit_rate: Optional[int] = None, + num_channels: Optional[int] = None, + ) -> None: + """Encode samples into a file-like object. + + Args: + file_like: A file-like object that supports write() and seek() methods, + such as io.BytesIO(), an open file in binary write mode, etc. + format (str): The format of the encoded samples, e.g. "mp3", "wav" + or "flac". + bit_rate (int, optional): The output bit rate. Encoders typically + support a finite set of bit rate values, so ``bit_rate`` will be + matched to one of those supported values. The default is chosen + by FFmpeg. + num_channels (int, optional): The number of channels of the encoded + output samples. By default, the number of channels of the input + ``samples`` is used. + """ + _core.encode_audio_to_file_like( + samples=self._samples, + sample_rate=self._sample_rate, + format=format, + file_like=file_like, + bit_rate=bit_rate, + num_channels=num_channels, + ) From 8318fbe0f9bece89105c3ee7c704b7059d77ad13 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 13:39:57 +0100 Subject: [PATCH 02/20] Add tests --- src/torchcodec/_core/ops.py | 11 +- src/torchcodec/_core/pybind_ops.cpp | 26 ++- test/test_encoders.py | 265 ++++++++++++++++++++++++++++ 3 files changed, 298 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index b4392c93d..c37a964d6 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -172,7 +172,16 @@ def encode_audio_to_file_like( num_channels: Optional number of output channels """ assert _pybind_ops is not None - _pybind_ops.encode_audio_to_file_like(samples, sample_rate, format, file_like, bit_rate, num_channels) + + # Convert tensor to raw bytes and shape info for pybind + samples_contiguous = samples.contiguous() + samples_numpy = samples_contiguous.detach().cpu().numpy() + samples_bytes = samples_numpy.tobytes() + samples_shape = tuple(samples_contiguous.shape) + + _pybind_ops.encode_audio_to_file_like( + samples_bytes, samples_shape, sample_rate, format, file_like, bit_rate, num_channels + ) # ============================== diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 616643f3e..6c2556336 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -15,6 +16,7 @@ #include "src/torchcodec/_core/StreamOptions.h" namespace py = pybind11; +using namespace py::literals; namespace facebook::torchcodec { @@ -40,13 +42,26 @@ int64_t create_from_file_like( return reinterpret_cast(decoder); } -void encode_audio_to_file_like( - const torch::Tensor& samples, +int64_t encode_audio_to_file_like( + py::bytes samples_data, + py::tuple samples_shape, int64_t sample_rate, const std::string& format, py::object file_like, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt) { + + // Convert Python data back to tensor + auto shape_vec = samples_shape.cast>(); + std::string samples_str = samples_data; + + // Create tensor from raw data + auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32); + auto samples = torch::from_blob( + const_cast(static_cast(samples_str.data())), + shape_vec, + tensor_options).clone(); // Clone to ensure memory ownership + AudioStreamOptions audioStreamOptions; audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; @@ -60,11 +75,16 @@ void encode_audio_to_file_like( std::move(avioContextHolder), audioStreamOptions); encoder.encode(); + + // Return 0 to indicate success + return 0; } PYBIND11_MODULE(decoder_core_pybind_ops, m) { m.def("create_from_file_like", &create_from_file_like); - m.def("encode_audio_to_file_like", &encode_audio_to_file_like); + m.def("encode_audio_to_file_like", &encode_audio_to_file_like, + "samples_data"_a, "samples_shape"_a, "sample_rate"_a, "format"_a, + "file_like"_a, "bit_rate"_a = py::none(), "num_channels"_a = py::none()); } } // namespace facebook::torchcodec diff --git a/test/test_encoders.py b/test/test_encoders.py index 284d053e3..ba4e4ac96 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1,4 +1,5 @@ import json +import io import os import re import subprocess @@ -384,3 +385,267 @@ def test_1d_samples(self): AudioEncoder(samples_1d, sample_rate=sample_rate).to_tensor("wav"), AudioEncoder(samples_2d, sample_rate=sample_rate).to_tensor("wav"), ) + + # Test cases for encode_to_file_like method + def test_encode_to_file_like_basic(self, tmp_path): + """Test basic functionality of encode_to_file_like with BytesIO.""" + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + # Test with BytesIO + buffer = io.BytesIO() + encoder.encode_to_file_like(buffer, format="wav") + + # Verify data was written + assert buffer.tell() > 0 + + # Verify we can decode the result + buffer.seek(0) + decoded_samples = self.decode(buffer.getvalue()) + + # For lossless format like WAV, should be very close + torch.testing.assert_close( + decoded_samples.data, source_samples, rtol=0, atol=1e-4 + ) + + def test_encode_to_file_like_different_formats(self, tmp_path): + """Test encode_to_file_like with different audio formats.""" + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + formats_to_test = ["wav", "flac", "mp3"] + + for format_name in formats_to_test: + if get_ffmpeg_major_version() == 4 and format_name == "wav": + continue # Skip WAV on FFmpeg 4 due to swresample issues + + buffer = io.BytesIO() + encoder.encode_to_file_like(buffer, format=format_name) + + # Verify data was written + assert buffer.tell() > 0, f"No data written for format {format_name}" + + # Verify we can decode the result (for lossless formats) + buffer.seek(0) + decoded_samples = self.decode(buffer.getvalue()) + assert decoded_samples.data.shape[0] == source_samples.shape[0] # Same number of channels + + def test_encode_to_file_like_with_parameters(self, tmp_path): + """Test encode_to_file_like with different encoding parameters.""" + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + # Test with different bit rates + for bit_rate in [128_000, 256_000]: + buffer = io.BytesIO() + encoder.encode_to_file_like(buffer, format="mp3", bit_rate=bit_rate) + assert buffer.tell() > 0 + + # Test with different channel counts + for num_channels in [1, 2]: + buffer = io.BytesIO() + encoder.encode_to_file_like(buffer, format="wav", num_channels=num_channels) + assert buffer.tell() > 0 + + # Verify channel count + buffer.seek(0) + decoded_samples = self.decode(buffer.getvalue()) + assert decoded_samples.data.shape[0] == num_channels + + def test_encode_to_file_like_vs_to_tensor(self, tmp_path): + """Test that encode_to_file_like produces the same output as to_tensor.""" + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + # Get tensor output + tensor_output = encoder.to_tensor(format="wav") + + # Get file-like output + buffer = io.BytesIO() + encoder.encode_to_file_like(buffer, format="wav") + buffer.seek(0) + file_like_output = torch.frombuffer(buffer.getvalue(), dtype=torch.uint8) + + # They should be identical + torch.testing.assert_close(tensor_output, file_like_output) + + def test_encode_to_file_like_vs_to_file(self, tmp_path): + """Test that encode_to_file_like produces the same output as to_file.""" + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + # Get file output + file_path = tmp_path / "test.wav" + encoder.to_file(dest=str(file_path)) # Convert to string + + with open(file_path, "rb") as f: + file_output = f.read() + + # Get file-like output + buffer = io.BytesIO() + encoder.encode_to_file_like(buffer, format="wav") + buffer.seek(0) + file_like_output = buffer.getvalue() + + # They should be identical + assert file_output == file_like_output + + def test_encode_to_file_like_custom_file_object(self, tmp_path): + """Test encode_to_file_like with a custom file-like object.""" + + class CustomFileObject: + def __init__(self): + self.data = b"" + self.position = 0 + + def write(self, data): + if isinstance(data, (bytes, bytearray)): + self.data += data + self.position += len(data) + return len(data) + else: + raise TypeError("Expected bytes-like object") + + def seek(self, offset, whence=0): + if whence == 0: # SEEK_SET + self.position = offset + elif whence == 1: # SEEK_CUR + self.position += offset + elif whence == 2: # SEEK_END + self.position = len(self.data) + offset + return self.position + + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + custom_file = CustomFileObject() + encoder.encode_to_file_like(custom_file, format="wav") + + # Verify data was written + assert len(custom_file.data) > 0 + + # Verify we can decode the result + decoded_samples = self.decode(custom_file.data) + + # Allow for small differences in sample count due to encoding/padding + # Check that the shapes are approximately the same (within a few samples) + assert decoded_samples.data.shape[0] == source_samples.shape[0] # Same number of channels + sample_diff = abs(decoded_samples.data.shape[1] - source_samples.shape[1]) + assert sample_diff <= 10, f"Sample count difference too large: {sample_diff}" + + # Compare the overlapping portion + min_samples = min(decoded_samples.data.shape[1], source_samples.shape[1]) + torch.testing.assert_close( + decoded_samples.data[:, :min_samples], + source_samples[:, :min_samples], + rtol=0, atol=1e-4 + ) + + def test_encode_to_file_like_real_file(self, tmp_path): + """Test encode_to_file_like with a real file opened in binary write mode.""" + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + file_path = tmp_path / "test_file_like.wav" + + # Use encode_to_file_like with a real file + with open(file_path, "wb") as f: + encoder.encode_to_file_like(f, format="wav") + + # Verify the file was created and has content + assert file_path.exists() + assert file_path.stat().st_size > 0 + + # Verify we can decode the result + decoded_samples = self.decode(str(file_path)) + torch.testing.assert_close( + decoded_samples.data, source_samples, rtol=0, atol=1e-4 + ) + + def test_encode_to_file_like_bad_input(self): + """Test encode_to_file_like with invalid inputs.""" + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + # Test with object missing write method + class NoWriteMethod: + def seek(self, offset, whence=0): + return 0 + + with pytest.raises(RuntimeError, match="File like object must implement a write method"): + encoder.encode_to_file_like(NoWriteMethod(), format="wav") + + # Test with object missing seek method + class NoSeekMethod: + def write(self, data): + return len(data) + + with pytest.raises(RuntimeError, match="File like object must implement a seek method"): + encoder.encode_to_file_like(NoSeekMethod(), format="wav") + + # Test with invalid format + buffer = io.BytesIO() + with pytest.raises(RuntimeError, match="Check the desired format"): + encoder.encode_to_file_like(buffer, format="invalid_format") + + # Test with invalid bit rate + buffer = io.BytesIO() + with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"): + encoder.encode_to_file_like(buffer, format="wav", bit_rate=-1) + + def test_encode_to_file_like_multiple_calls(self): + """Test that encode_to_file_like can be called multiple times on the same encoder.""" + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + # First call + buffer1 = io.BytesIO() + encoder.encode_to_file_like(buffer1, format="wav") + + # Second call with different format + buffer2 = io.BytesIO() + encoder.encode_to_file_like(buffer2, format="flac") + + # Both should have data + assert buffer1.tell() > 0 + assert buffer2.tell() > 0 + + # Verify both can be decoded + buffer1.seek(0) + buffer2.seek(0) + decoded1 = self.decode(buffer1.getvalue()) + decoded2 = self.decode(buffer2.getvalue()) + + # Both should be close to the original (within format limitations) + torch.testing.assert_close( + decoded1.data, source_samples, rtol=0, atol=1e-4 + ) + torch.testing.assert_close( + decoded2.data, source_samples, rtol=0, atol=1e-4 + ) + + def test_encode_to_file_like_empty_samples(self): + """Test encode_to_file_like with very short audio samples.""" + # Create very short audio sample + short_samples = torch.rand(2, 100) # 100 samples per channel + encoder = AudioEncoder(short_samples, sample_rate=16_000) + + buffer = io.BytesIO() + encoder.encode_to_file_like(buffer, format="wav") + + # Should still produce valid output + assert buffer.tell() > 0 + + # Verify it can be decoded + buffer.seek(0) + decoded_samples = self.decode(buffer.getvalue()) + assert decoded_samples.data.shape[0] == 2 # Same number of channels From 10cdd5b3ba9b56c4ee527d045a8433d9428fa7fe Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 13:40:07 +0100 Subject: [PATCH 03/20] Linter --- src/torchcodec/_core/AVIOFileLikeContext.cpp | 12 +++--- src/torchcodec/_core/AVIOFileLikeContext.h | 10 +++-- src/torchcodec/_core/Encoder.cpp | 2 +- src/torchcodec/_core/Encoder.h | 2 +- src/torchcodec/_core/ops.py | 14 +++++-- src/torchcodec/_core/pybind_ops.cpp | 42 ++++++++++++-------- test/test_encoders.py | 37 +++++++++-------- 7 files changed, 71 insertions(+), 48 deletions(-) diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index bbba5301a..e79e47707 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -12,8 +12,10 @@ namespace facebook::torchcodec { AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) : AVIOFileLikeContext(fileLike, false) {} -std::unique_ptr AVIOFileLikeContext::createForWriting(py::object fileLike) { - return std::unique_ptr(new AVIOFileLikeContext(fileLike, true)); +std::unique_ptr AVIOFileLikeContext::createForWriting( + py::object fileLike) { + return std::unique_ptr( + new AVIOFileLikeContext(fileLike, true)); } AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike, bool isWriteMode) @@ -36,7 +38,7 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike, bool isWriteMode) py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } - + if (isWriteMode) { createAVIOContext(nullptr, &write, &seek, &fileLike_); } else { @@ -93,10 +95,10 @@ int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) { // Create a bytes object from the buffer py::bytes data_bytes(reinterpret_cast(buf), buf_size); - + // Call the Python write method auto bytes_written = (*fileLike)->attr("write")(data_bytes); - + // Python write() should return the number of bytes written return py::cast(bytes_written); } diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index 19cd33107..9a890e783 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -17,18 +17,20 @@ namespace facebook::torchcodec { // Enables users to pass in a Python file-like object. We then forward all read, // write and seek calls back up to the methods on the Python object. -class __attribute__((visibility("default"))) AVIOFileLikeContext : public AVIOContextHolder { +class __attribute__((visibility("default"))) AVIOFileLikeContext + : public AVIOContextHolder { public: // Constructor for reading from a file-like object explicit AVIOFileLikeContext(py::object fileLike); - + // Constructor for writing to a file-like object - static std::unique_ptr createForWriting(py::object fileLike); + static std::unique_ptr createForWriting( + py::object fileLike); private: // Private constructor for write mode AVIOFileLikeContext(py::object fileLike, bool isWriteMode); - + static int read(void* opaque, uint8_t* buf, int buf_size); static int write(void* opaque, const uint8_t* buf, int buf_size); static int64_t seek(void* opaque, int64_t offset, int whence); diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index d106d2c35..45ac67ed1 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -168,7 +168,7 @@ AudioEncoder::AudioEncoder( avioFileLikeContextHolder_(std::move(avioContextHolder)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; - + int status = avformat_alloc_output_context2( &avFormatContext, nullptr, formatName.data(), nullptr); diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 0e5d7c8b0..cf205daa3 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -1,7 +1,7 @@ #pragma once #include -#include "src/torchcodec/_core/AVIOTensorContext.h" #include "src/torchcodec/_core/AVIOFileLikeContext.h" +#include "src/torchcodec/_core/AVIOTensorContext.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/StreamOptions.h" diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index c37a964d6..2f1f38a66 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -162,7 +162,7 @@ def encode_audio_to_file_like( num_channels: Optional[int] = None, ) -> None: """Encode audio samples to a file-like object. - + Args: samples: Audio samples tensor sample_rate: Sample rate in Hz @@ -172,15 +172,21 @@ def encode_audio_to_file_like( num_channels: Optional number of output channels """ assert _pybind_ops is not None - + # Convert tensor to raw bytes and shape info for pybind samples_contiguous = samples.contiguous() samples_numpy = samples_contiguous.detach().cpu().numpy() samples_bytes = samples_numpy.tobytes() samples_shape = tuple(samples_contiguous.shape) - + _pybind_ops.encode_audio_to_file_like( - samples_bytes, samples_shape, sample_rate, format, file_like, bit_rate, num_channels + samples_bytes, + samples_shape, + sample_rate, + format, + file_like, + bit_rate, + num_channels, ) diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 6c2556336..52925e161 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -11,8 +11,8 @@ #include #include "src/torchcodec/_core/AVIOFileLikeContext.h" -#include "src/torchcodec/_core/SingleStreamDecoder.h" #include "src/torchcodec/_core/Encoder.h" +#include "src/torchcodec/_core/SingleStreamDecoder.h" #include "src/torchcodec/_core/StreamOptions.h" namespace py = pybind11; @@ -50,41 +50,49 @@ int64_t encode_audio_to_file_like( py::object file_like, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt) { - // Convert Python data back to tensor auto shape_vec = samples_shape.cast>(); std::string samples_str = samples_data; - + // Create tensor from raw data auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32); - auto samples = torch::from_blob( - const_cast(static_cast(samples_str.data())), - shape_vec, - tensor_options).clone(); // Clone to ensure memory ownership - + auto samples = + torch::from_blob( + const_cast(static_cast(samples_str.data())), + shape_vec, + tensor_options) + .clone(); // Clone to ensure memory ownership + AudioStreamOptions audioStreamOptions; audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; - + auto avioContextHolder = AVIOFileLikeContext::createForWriting(file_like); - + AudioEncoder encoder( - samples, - static_cast(sample_rate), + samples, + static_cast(sample_rate), format, - std::move(avioContextHolder), + std::move(avioContextHolder), audioStreamOptions); encoder.encode(); - + // Return 0 to indicate success return 0; } PYBIND11_MODULE(decoder_core_pybind_ops, m) { m.def("create_from_file_like", &create_from_file_like); - m.def("encode_audio_to_file_like", &encode_audio_to_file_like, - "samples_data"_a, "samples_shape"_a, "sample_rate"_a, "format"_a, - "file_like"_a, "bit_rate"_a = py::none(), "num_channels"_a = py::none()); + m.def( + "encode_audio_to_file_like", + &encode_audio_to_file_like, + "samples_data"_a, + "samples_shape"_a, + "sample_rate"_a, + "format"_a, + "file_like"_a, + "bit_rate"_a = py::none(), + "num_channels"_a = py::none()); } } // namespace facebook::torchcodec diff --git a/test/test_encoders.py b/test/test_encoders.py index ba4e4ac96..77d807187 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1,5 +1,5 @@ -import json import io +import json import os import re import subprocess @@ -430,7 +430,9 @@ def test_encode_to_file_like_different_formats(self, tmp_path): # Verify we can decode the result (for lossless formats) buffer.seek(0) decoded_samples = self.decode(buffer.getvalue()) - assert decoded_samples.data.shape[0] == source_samples.shape[0] # Same number of channels + assert ( + decoded_samples.data.shape[0] == source_samples.shape[0] + ) # Same number of channels def test_encode_to_file_like_with_parameters(self, tmp_path): """Test encode_to_file_like with different encoding parameters.""" @@ -532,19 +534,22 @@ def seek(self, offset, whence=0): # Verify we can decode the result decoded_samples = self.decode(custom_file.data) - + # Allow for small differences in sample count due to encoding/padding # Check that the shapes are approximately the same (within a few samples) - assert decoded_samples.data.shape[0] == source_samples.shape[0] # Same number of channels + assert ( + decoded_samples.data.shape[0] == source_samples.shape[0] + ) # Same number of channels sample_diff = abs(decoded_samples.data.shape[1] - source_samples.shape[1]) assert sample_diff <= 10, f"Sample count difference too large: {sample_diff}" - + # Compare the overlapping portion min_samples = min(decoded_samples.data.shape[1], source_samples.shape[1]) torch.testing.assert_close( - decoded_samples.data[:, :min_samples], - source_samples[:, :min_samples], - rtol=0, atol=1e-4 + decoded_samples.data[:, :min_samples], + source_samples[:, :min_samples], + rtol=0, + atol=1e-4, ) def test_encode_to_file_like_real_file(self, tmp_path): @@ -580,7 +585,9 @@ class NoWriteMethod: def seek(self, offset, whence=0): return 0 - with pytest.raises(RuntimeError, match="File like object must implement a write method"): + with pytest.raises( + RuntimeError, match="File like object must implement a write method" + ): encoder.encode_to_file_like(NoWriteMethod(), format="wav") # Test with object missing seek method @@ -588,7 +595,9 @@ class NoSeekMethod: def write(self, data): return len(data) - with pytest.raises(RuntimeError, match="File like object must implement a seek method"): + with pytest.raises( + RuntimeError, match="File like object must implement a seek method" + ): encoder.encode_to_file_like(NoSeekMethod(), format="wav") # Test with invalid format @@ -626,12 +635,8 @@ def test_encode_to_file_like_multiple_calls(self): decoded2 = self.decode(buffer2.getvalue()) # Both should be close to the original (within format limitations) - torch.testing.assert_close( - decoded1.data, source_samples, rtol=0, atol=1e-4 - ) - torch.testing.assert_close( - decoded2.data, source_samples, rtol=0, atol=1e-4 - ) + torch.testing.assert_close(decoded1.data, source_samples, rtol=0, atol=1e-4) + torch.testing.assert_close(decoded2.data, source_samples, rtol=0, atol=1e-4) def test_encode_to_file_like_empty_samples(self): """Test encode_to_file_like with very short audio samples.""" From 67962d8bc83ea98898fe5f213fdfc7a134309e12 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 13:55:47 +0100 Subject: [PATCH 04/20] Renaming --- src/torchcodec/encoders/_audio_encoder.py | 2 +- test/test_encoders.py | 89 ++++++++++++----------- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py index edaa9c5f0..75f35cda8 100644 --- a/src/torchcodec/encoders/_audio_encoder.py +++ b/src/torchcodec/encoders/_audio_encoder.py @@ -99,7 +99,7 @@ def to_tensor( num_channels=num_channels, ) - def encode_to_file_like( + def to_file_like( self, file_like, format: str, diff --git a/test/test_encoders.py b/test/test_encoders.py index 77d807187..f1746e329 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -136,13 +136,16 @@ def test_bad_input(self): ): encoder.to_tensor(format=bad_format) - @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_bad_input_parametrized(self, method, tmp_path): - valid_params = ( - dict(dest=str(tmp_path / "output.mp3")) - if method == "to_file" - else dict(format="mp3") - ) + if method == "to_file": + valid_params = dict(dest=str(tmp_path / "output.mp3")) + elif method == "to_tensor": + valid_params = dict(format="mp3") + elif method == "to_file_like": + valid_params = dict(file_like=io.BytesIO(), format="mp3") + else: + raise ValueError(f"Unknown method: {method}") decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3).data, sample_rate=10) with pytest.raises(RuntimeError, match="invalid sample rate=10"): @@ -386,16 +389,16 @@ def test_1d_samples(self): AudioEncoder(samples_2d, sample_rate=sample_rate).to_tensor("wav"), ) - # Test cases for encode_to_file_like method - def test_encode_to_file_like_basic(self, tmp_path): - """Test basic functionality of encode_to_file_like with BytesIO.""" + # Test cases for to_file_like method + def test_to_file_like_basic(self, tmp_path): + """Test basic functionality of to_file_like with BytesIO.""" asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) # Test with BytesIO buffer = io.BytesIO() - encoder.encode_to_file_like(buffer, format="wav") + encoder.to_file_like(buffer, format="wav") # Verify data was written assert buffer.tell() > 0 @@ -409,8 +412,8 @@ def test_encode_to_file_like_basic(self, tmp_path): decoded_samples.data, source_samples, rtol=0, atol=1e-4 ) - def test_encode_to_file_like_different_formats(self, tmp_path): - """Test encode_to_file_like with different audio formats.""" + def test_to_file_like_different_formats(self, tmp_path): + """Test to_file_like with different audio formats.""" asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) @@ -422,7 +425,7 @@ def test_encode_to_file_like_different_formats(self, tmp_path): continue # Skip WAV on FFmpeg 4 due to swresample issues buffer = io.BytesIO() - encoder.encode_to_file_like(buffer, format=format_name) + encoder.to_file_like(buffer, format=format_name) # Verify data was written assert buffer.tell() > 0, f"No data written for format {format_name}" @@ -434,8 +437,8 @@ def test_encode_to_file_like_different_formats(self, tmp_path): decoded_samples.data.shape[0] == source_samples.shape[0] ) # Same number of channels - def test_encode_to_file_like_with_parameters(self, tmp_path): - """Test encode_to_file_like with different encoding parameters.""" + def test_to_file_like_with_parameters(self, tmp_path): + """Test to_file_like with different encoding parameters.""" asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) @@ -443,13 +446,13 @@ def test_encode_to_file_like_with_parameters(self, tmp_path): # Test with different bit rates for bit_rate in [128_000, 256_000]: buffer = io.BytesIO() - encoder.encode_to_file_like(buffer, format="mp3", bit_rate=bit_rate) + encoder.to_file_like(buffer, format="mp3", bit_rate=bit_rate) assert buffer.tell() > 0 # Test with different channel counts for num_channels in [1, 2]: buffer = io.BytesIO() - encoder.encode_to_file_like(buffer, format="wav", num_channels=num_channels) + encoder.to_file_like(buffer, format="wav", num_channels=num_channels) assert buffer.tell() > 0 # Verify channel count @@ -457,8 +460,8 @@ def test_encode_to_file_like_with_parameters(self, tmp_path): decoded_samples = self.decode(buffer.getvalue()) assert decoded_samples.data.shape[0] == num_channels - def test_encode_to_file_like_vs_to_tensor(self, tmp_path): - """Test that encode_to_file_like produces the same output as to_tensor.""" + def test_to_file_like_vs_to_tensor(self, tmp_path): + """Test that to_file_like produces the same output as to_tensor.""" asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) @@ -468,15 +471,15 @@ def test_encode_to_file_like_vs_to_tensor(self, tmp_path): # Get file-like output buffer = io.BytesIO() - encoder.encode_to_file_like(buffer, format="wav") + encoder.to_file_like(buffer, format="wav") buffer.seek(0) file_like_output = torch.frombuffer(buffer.getvalue(), dtype=torch.uint8) # They should be identical torch.testing.assert_close(tensor_output, file_like_output) - def test_encode_to_file_like_vs_to_file(self, tmp_path): - """Test that encode_to_file_like produces the same output as to_file.""" + def test_to_file_like_vs_to_file(self, tmp_path): + """Test that to_file_like produces the same output as to_file.""" asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) @@ -490,15 +493,15 @@ def test_encode_to_file_like_vs_to_file(self, tmp_path): # Get file-like output buffer = io.BytesIO() - encoder.encode_to_file_like(buffer, format="wav") + encoder.to_file_like(buffer, format="wav") buffer.seek(0) file_like_output = buffer.getvalue() # They should be identical assert file_output == file_like_output - def test_encode_to_file_like_custom_file_object(self, tmp_path): - """Test encode_to_file_like with a custom file-like object.""" + def test_to_file_like_custom_file_object(self, tmp_path): + """Test to_file_like with a custom file-like object.""" class CustomFileObject: def __init__(self): @@ -527,7 +530,7 @@ def seek(self, offset, whence=0): encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) custom_file = CustomFileObject() - encoder.encode_to_file_like(custom_file, format="wav") + encoder.to_file_like(custom_file, format="wav") # Verify data was written assert len(custom_file.data) > 0 @@ -552,17 +555,17 @@ def seek(self, offset, whence=0): atol=1e-4, ) - def test_encode_to_file_like_real_file(self, tmp_path): - """Test encode_to_file_like with a real file opened in binary write mode.""" + def test_to_file_like_real_file(self, tmp_path): + """Test to_file_like with a real file opened in binary write mode.""" asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) file_path = tmp_path / "test_file_like.wav" - # Use encode_to_file_like with a real file + # Use to_file_like with a real file with open(file_path, "wb") as f: - encoder.encode_to_file_like(f, format="wav") + encoder.to_file_like(f, format="wav") # Verify the file was created and has content assert file_path.exists() @@ -574,8 +577,8 @@ def test_encode_to_file_like_real_file(self, tmp_path): decoded_samples.data, source_samples, rtol=0, atol=1e-4 ) - def test_encode_to_file_like_bad_input(self): - """Test encode_to_file_like with invalid inputs.""" + def test_to_file_like_bad_input(self): + """Test to_file_like with invalid inputs.""" asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) @@ -588,7 +591,7 @@ def seek(self, offset, whence=0): with pytest.raises( RuntimeError, match="File like object must implement a write method" ): - encoder.encode_to_file_like(NoWriteMethod(), format="wav") + encoder.to_file_like(NoWriteMethod(), format="wav") # Test with object missing seek method class NoSeekMethod: @@ -598,31 +601,31 @@ def write(self, data): with pytest.raises( RuntimeError, match="File like object must implement a seek method" ): - encoder.encode_to_file_like(NoSeekMethod(), format="wav") + encoder.to_file_like(NoSeekMethod(), format="wav") # Test with invalid format buffer = io.BytesIO() with pytest.raises(RuntimeError, match="Check the desired format"): - encoder.encode_to_file_like(buffer, format="invalid_format") + encoder.to_file_like(buffer, format="invalid_format") # Test with invalid bit rate buffer = io.BytesIO() with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"): - encoder.encode_to_file_like(buffer, format="wav", bit_rate=-1) + encoder.to_file_like(buffer, format="wav", bit_rate=-1) - def test_encode_to_file_like_multiple_calls(self): - """Test that encode_to_file_like can be called multiple times on the same encoder.""" + def test_to_file_like_multiple_calls(self): + """Test that to_file_like can be called multiple times on the same encoder.""" asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) # First call buffer1 = io.BytesIO() - encoder.encode_to_file_like(buffer1, format="wav") + encoder.to_file_like(buffer1, format="wav") # Second call with different format buffer2 = io.BytesIO() - encoder.encode_to_file_like(buffer2, format="flac") + encoder.to_file_like(buffer2, format="flac") # Both should have data assert buffer1.tell() > 0 @@ -638,14 +641,14 @@ def test_encode_to_file_like_multiple_calls(self): torch.testing.assert_close(decoded1.data, source_samples, rtol=0, atol=1e-4) torch.testing.assert_close(decoded2.data, source_samples, rtol=0, atol=1e-4) - def test_encode_to_file_like_empty_samples(self): - """Test encode_to_file_like with very short audio samples.""" + def test_to_file_like_empty_samples(self): + """Test to_file_like with very short audio samples.""" # Create very short audio sample short_samples = torch.rand(2, 100) # 100 samples per channel encoder = AudioEncoder(short_samples, sample_rate=16_000) buffer = io.BytesIO() - encoder.encode_to_file_like(buffer, format="wav") + encoder.to_file_like(buffer, format="wav") # Should still produce valid output assert buffer.tell() > 0 From 78276a2af2506cdb4cc4d3ec74f4b38d63b36e8b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 14:02:46 +0100 Subject: [PATCH 05/20] Add tests --- test/test_encoders.py | 283 ++++++++---------------------------------- 1 file changed, 55 insertions(+), 228 deletions(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index f1746e329..6e8d6c24a 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -176,7 +176,7 @@ def test_bad_input_parametrized(self, method, tmp_path): ): getattr(decoder, method)(**valid_params, num_channels=num_channels) - @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) @pytest.mark.parametrize("format", ("wav", "flac")) def test_round_trip(self, method, format, tmp_path): # Check that decode(encode(samples)) == samples on lossless formats @@ -193,10 +193,16 @@ def test_round_trip(self, method, format, tmp_path): encoded_path = str(tmp_path / f"output.{format}") encoded_source = encoded_path encoder.to_file(dest=encoded_path) - else: + elif method == "to_tensor": encoded_source = encoder.to_tensor(format=format) assert encoded_source.dtype == torch.uint8 assert encoded_source.ndim == 1 + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like(file_like, format=format) + encoded_source = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") rtol, atol = (0, 1e-4) if format == "wav" else (None, None) torch.testing.assert_close( @@ -208,7 +214,7 @@ def test_round_trip(self, method, format, tmp_path): @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) @pytest.mark.parametrize("num_channels", (None, 1, 2)) @pytest.mark.parametrize("format", ("mp3", "wav", "flac")) - @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_against_cli( self, asset, @@ -244,8 +250,14 @@ def test_against_cli( if method == "to_file": encoded_by_us = tmp_path / f"output.{format}" encoder.to_file(dest=str(encoded_by_us), **params) - else: + elif method == "to_tensor": encoded_by_us = encoder.to_tensor(format=format, **params) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like(file_like, format=format, **params) + encoded_by_us = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") captured = capfd.readouterr() if format == "wav": @@ -281,15 +293,14 @@ def test_against_cli( if method == "to_file": validate_frames_properties(actual=encoded_by_us, expected=encoded_by_ffmpeg) - else: - assert method == "to_tensor", "wrong test parametrization!" @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) @pytest.mark.parametrize("num_channels", (None, 1, 2)) @pytest.mark.parametrize("format", ("mp3", "wav", "flac")) - def test_to_tensor_against_to_file( - self, asset, bit_rate, num_channels, format, tmp_path + @pytest.mark.parametrize("method", ("to_tensor", "to_file_like")) + def test_against_to_file( + self, asset, bit_rate, num_channels, format, tmp_path, method ): if get_ffmpeg_major_version() == 4 and format == "wav": pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") @@ -299,12 +310,22 @@ def test_to_tensor_against_to_file( params = dict(bit_rate=bit_rate, num_channels=num_channels) encoded_file = tmp_path / f"output.{format}" encoder.to_file(dest=str(encoded_file), **params) - encoded_tensor = encoder.to_tensor( - format=format, bit_rate=bit_rate, num_channels=num_channels - ) + + if method == "to_tensor": + encoded_output = encoder.to_tensor( + format=format, bit_rate=bit_rate, num_channels=num_channels + ) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like( + file_like, format=format, bit_rate=bit_rate, num_channels=num_channels + ) + encoded_output = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") torch.testing.assert_close( - self.decode(encoded_file).data, self.decode(encoded_tensor).data + self.decode(encoded_file).data, self.decode(encoded_output).data ) def test_encode_to_tensor_long_output(self): @@ -354,7 +375,7 @@ def test_contiguity(self): @pytest.mark.parametrize("num_channels_input", (1, 2)) @pytest.mark.parametrize("num_channels_output", (1, 2, None)) - @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_num_channels( self, num_channels_input, num_channels_output, method, tmp_path ): @@ -372,8 +393,14 @@ def test_num_channels( encoded_path = str(tmp_path / f"output.{format}") encoded_source = encoded_path encoder.to_file(dest=encoded_path, **params) - else: + elif method == "to_tensor": encoded_source = encoder.to_tensor(format=format, **params) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like(file_like, format=format, **params) + encoded_source = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") if num_channels_output is None: num_channels_output = num_channels_input @@ -389,168 +416,32 @@ def test_1d_samples(self): AudioEncoder(samples_2d, sample_rate=sample_rate).to_tensor("wav"), ) - # Test cases for to_file_like method - def test_to_file_like_basic(self, tmp_path): - """Test basic functionality of to_file_like with BytesIO.""" - asset = NASA_AUDIO_MP3 - source_samples = self.decode(asset).data - encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) - - # Test with BytesIO - buffer = io.BytesIO() - encoder.to_file_like(buffer, format="wav") - - # Verify data was written - assert buffer.tell() > 0 - - # Verify we can decode the result - buffer.seek(0) - decoded_samples = self.decode(buffer.getvalue()) - - # For lossless format like WAV, should be very close - torch.testing.assert_close( - decoded_samples.data, source_samples, rtol=0, atol=1e-4 - ) - - def test_to_file_like_different_formats(self, tmp_path): - """Test to_file_like with different audio formats.""" - asset = NASA_AUDIO_MP3 - source_samples = self.decode(asset).data - encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) - - formats_to_test = ["wav", "flac", "mp3"] - - for format_name in formats_to_test: - if get_ffmpeg_major_version() == 4 and format_name == "wav": - continue # Skip WAV on FFmpeg 4 due to swresample issues - - buffer = io.BytesIO() - encoder.to_file_like(buffer, format=format_name) - - # Verify data was written - assert buffer.tell() > 0, f"No data written for format {format_name}" - - # Verify we can decode the result (for lossless formats) - buffer.seek(0) - decoded_samples = self.decode(buffer.getvalue()) - assert ( - decoded_samples.data.shape[0] == source_samples.shape[0] - ) # Same number of channels - - def test_to_file_like_with_parameters(self, tmp_path): - """Test to_file_like with different encoding parameters.""" - asset = NASA_AUDIO_MP3 - source_samples = self.decode(asset).data - encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) - - # Test with different bit rates - for bit_rate in [128_000, 256_000]: - buffer = io.BytesIO() - encoder.to_file_like(buffer, format="mp3", bit_rate=bit_rate) - assert buffer.tell() > 0 - - # Test with different channel counts - for num_channels in [1, 2]: - buffer = io.BytesIO() - encoder.to_file_like(buffer, format="wav", num_channels=num_channels) - assert buffer.tell() > 0 - - # Verify channel count - buffer.seek(0) - decoded_samples = self.decode(buffer.getvalue()) - assert decoded_samples.data.shape[0] == num_channels - - def test_to_file_like_vs_to_tensor(self, tmp_path): - """Test that to_file_like produces the same output as to_tensor.""" - asset = NASA_AUDIO_MP3 - source_samples = self.decode(asset).data - encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) - - # Get tensor output - tensor_output = encoder.to_tensor(format="wav") - - # Get file-like output - buffer = io.BytesIO() - encoder.to_file_like(buffer, format="wav") - buffer.seek(0) - file_like_output = torch.frombuffer(buffer.getvalue(), dtype=torch.uint8) - - # They should be identical - torch.testing.assert_close(tensor_output, file_like_output) - - def test_to_file_like_vs_to_file(self, tmp_path): - """Test that to_file_like produces the same output as to_file.""" - asset = NASA_AUDIO_MP3 - source_samples = self.decode(asset).data - encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) - - # Get file output - file_path = tmp_path / "test.wav" - encoder.to_file(dest=str(file_path)) # Convert to string - - with open(file_path, "rb") as f: - file_output = f.read() - - # Get file-like output - buffer = io.BytesIO() - encoder.to_file_like(buffer, format="wav") - buffer.seek(0) - file_like_output = buffer.getvalue() - - # They should be identical - assert file_output == file_like_output - def test_to_file_like_custom_file_object(self, tmp_path): - """Test to_file_like with a custom file-like object.""" - class CustomFileObject: def __init__(self): - self.data = b"" - self.position = 0 + self._file = io.BytesIO() def write(self, data): - if isinstance(data, (bytes, bytearray)): - self.data += data - self.position += len(data) - return len(data) - else: - raise TypeError("Expected bytes-like object") + return self._file.write(data) def seek(self, offset, whence=0): - if whence == 0: # SEEK_SET - self.position = offset - elif whence == 1: # SEEK_CUR - self.position += offset - elif whence == 2: # SEEK_END - self.position = len(self.data) + offset - return self.position + return self._file.seek(offset, whence) + + def get_encoded_data(self): + return self._file.getvalue() asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) - custom_file = CustomFileObject() - encoder.to_file_like(custom_file, format="wav") - - # Verify data was written - assert len(custom_file.data) > 0 + file_like = CustomFileObject() + encoder.to_file_like(file_like, format="wav") - # Verify we can decode the result - decoded_samples = self.decode(custom_file.data) + decoded_samples = self.decode(file_like.get_encoded_data()) - # Allow for small differences in sample count due to encoding/padding - # Check that the shapes are approximately the same (within a few samples) - assert ( - decoded_samples.data.shape[0] == source_samples.shape[0] - ) # Same number of channels - sample_diff = abs(decoded_samples.data.shape[1] - source_samples.shape[1]) - assert sample_diff <= 10, f"Sample count difference too large: {sample_diff}" - - # Compare the overlapping portion - min_samples = min(decoded_samples.data.shape[1], source_samples.shape[1]) torch.testing.assert_close( - decoded_samples.data[:, :min_samples], - source_samples[:, :min_samples], + decoded_samples.data, + source_samples, rtol=0, atol=1e-4, ) @@ -563,27 +454,19 @@ def test_to_file_like_real_file(self, tmp_path): file_path = tmp_path / "test_file_like.wav" - # Use to_file_like with a real file - with open(file_path, "wb") as f: - encoder.to_file_like(f, format="wav") - - # Verify the file was created and has content - assert file_path.exists() - assert file_path.stat().st_size > 0 + with open(file_path, "wb") as file_like: + encoder.to_file_like(file_like, format="wav") - # Verify we can decode the result decoded_samples = self.decode(str(file_path)) torch.testing.assert_close( decoded_samples.data, source_samples, rtol=0, atol=1e-4 ) - def test_to_file_like_bad_input(self): - """Test to_file_like with invalid inputs.""" + def test_to_file_like_bad_methods(self): asset = NASA_AUDIO_MP3 source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) - # Test with object missing write method class NoWriteMethod: def seek(self, offset, whence=0): return 0 @@ -593,7 +476,6 @@ def seek(self, offset, whence=0): ): encoder.to_file_like(NoWriteMethod(), format="wav") - # Test with object missing seek method class NoSeekMethod: def write(self, data): return len(data) @@ -602,58 +484,3 @@ def write(self, data): RuntimeError, match="File like object must implement a seek method" ): encoder.to_file_like(NoSeekMethod(), format="wav") - - # Test with invalid format - buffer = io.BytesIO() - with pytest.raises(RuntimeError, match="Check the desired format"): - encoder.to_file_like(buffer, format="invalid_format") - - # Test with invalid bit rate - buffer = io.BytesIO() - with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"): - encoder.to_file_like(buffer, format="wav", bit_rate=-1) - - def test_to_file_like_multiple_calls(self): - """Test that to_file_like can be called multiple times on the same encoder.""" - asset = NASA_AUDIO_MP3 - source_samples = self.decode(asset).data - encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) - - # First call - buffer1 = io.BytesIO() - encoder.to_file_like(buffer1, format="wav") - - # Second call with different format - buffer2 = io.BytesIO() - encoder.to_file_like(buffer2, format="flac") - - # Both should have data - assert buffer1.tell() > 0 - assert buffer2.tell() > 0 - - # Verify both can be decoded - buffer1.seek(0) - buffer2.seek(0) - decoded1 = self.decode(buffer1.getvalue()) - decoded2 = self.decode(buffer2.getvalue()) - - # Both should be close to the original (within format limitations) - torch.testing.assert_close(decoded1.data, source_samples, rtol=0, atol=1e-4) - torch.testing.assert_close(decoded2.data, source_samples, rtol=0, atol=1e-4) - - def test_to_file_like_empty_samples(self): - """Test to_file_like with very short audio samples.""" - # Create very short audio sample - short_samples = torch.rand(2, 100) # 100 samples per channel - encoder = AudioEncoder(short_samples, sample_rate=16_000) - - buffer = io.BytesIO() - encoder.to_file_like(buffer, format="wav") - - # Should still produce valid output - assert buffer.tell() > 0 - - # Verify it can be decoded - buffer.seek(0) - decoded_samples = self.decode(buffer.getvalue()) - assert decoded_samples.data.shape[0] == 2 # Same number of channels From aa10ed1d7a9aca9c7158caca1edaa17e228047c8 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 14:46:09 +0100 Subject: [PATCH 06/20] Avoid depending on numpy for bytes conversion --- src/torchcodec/_core/ops.py | 16 ++++++++++------ src/torchcodec/_core/pybind_ops.cpp | 23 +++++++++-------------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 2f1f38a66..960e92876 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import ctypes import io import json import warnings @@ -173,15 +174,18 @@ def encode_audio_to_file_like( """ assert _pybind_ops is not None - # Convert tensor to raw bytes and shape info for pybind + # Enforce float32 dtype requirement + if samples.dtype != torch.float32: + raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}") + samples_contiguous = samples.contiguous() - samples_numpy = samples_contiguous.detach().cpu().numpy() - samples_bytes = samples_numpy.tobytes() - samples_shape = tuple(samples_contiguous.shape) + + data_ptr = samples_contiguous.data_ptr() + shape = list(samples_contiguous.shape) _pybind_ops.encode_audio_to_file_like( - samples_bytes, - samples_shape, + data_ptr, + shape, sample_rate, format, file_like, diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 52925e161..380afd727 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -43,25 +43,20 @@ int64_t create_from_file_like( } int64_t encode_audio_to_file_like( - py::bytes samples_data, - py::tuple samples_shape, + uintptr_t data_ptr, + py::list shape, int64_t sample_rate, const std::string& format, py::object file_like, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt) { - // Convert Python data back to tensor - auto shape_vec = samples_shape.cast>(); - std::string samples_str = samples_data; + // Convert Python list to vector + auto shape_vec = shape.cast>(); - // Create tensor from raw data + // Create tensor from existing data pointer (enforcing float32) auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32); - auto samples = - torch::from_blob( - const_cast(static_cast(samples_str.data())), - shape_vec, - tensor_options) - .clone(); // Clone to ensure memory ownership + auto samples = torch::from_blob( + reinterpret_cast(data_ptr), shape_vec, tensor_options); AudioStreamOptions audioStreamOptions; audioStreamOptions.bitRate = bit_rate; @@ -86,8 +81,8 @@ PYBIND11_MODULE(decoder_core_pybind_ops, m) { m.def( "encode_audio_to_file_like", &encode_audio_to_file_like, - "samples_data"_a, - "samples_shape"_a, + "data_ptr"_a, + "shape"_a, "sample_rate"_a, "format"_a, "file_like"_a, From dfa5bcbbdc97b156c3e0e4445f2788233bcfbebf Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 14:58:44 +0100 Subject: [PATCH 07/20] Use string_view --- src/torchcodec/_core/pybind_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 380afd727..533628d7b 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -46,7 +46,7 @@ int64_t encode_audio_to_file_like( uintptr_t data_ptr, py::list shape, int64_t sample_rate, - const std::string& format, + std::string_view format, py::object file_like, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt) { From 6adb7dc62fbcbd8264b23f016c800ffc7b1dc40f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 15:00:44 +0100 Subject: [PATCH 08/20] make shape a vec --- src/torchcodec/_core/pybind_ops.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 533628d7b..34938d28b 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -44,19 +44,16 @@ int64_t create_from_file_like( int64_t encode_audio_to_file_like( uintptr_t data_ptr, - py::list shape, + const std::vector& shape, int64_t sample_rate, std::string_view format, py::object file_like, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt) { - // Convert Python list to vector - auto shape_vec = shape.cast>(); - // Create tensor from existing data pointer (enforcing float32) auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32); auto samples = torch::from_blob( - reinterpret_cast(data_ptr), shape_vec, tensor_options); + reinterpret_cast(data_ptr), shape, tensor_options); AudioStreamOptions audioStreamOptions; audioStreamOptions.bitRate = bit_rate; From 8951870e47b5f1471f513d9512b0b992039e7632 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 15:03:08 +0100 Subject: [PATCH 09/20] dataptr is int64_t --- src/torchcodec/_core/pybind_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 34938d28b..fb6cdf323 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -43,7 +43,7 @@ int64_t create_from_file_like( } int64_t encode_audio_to_file_like( - uintptr_t data_ptr, + int64_t data_ptr, const std::vector& shape, int64_t sample_rate, std::string_view format, From 9b6d9eee26a6e948cb2caf1c2ce773b0b369b8ce Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 15:31:25 +0100 Subject: [PATCH 10/20] lifetime management --- src/torchcodec/_core/ops.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 960e92876..ecca7376d 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -178,14 +178,11 @@ def encode_audio_to_file_like( if samples.dtype != torch.float32: raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}") - samples_contiguous = samples.contiguous() - - data_ptr = samples_contiguous.data_ptr() - shape = list(samples_contiguous.shape) + samples = samples.contiguous() _pybind_ops.encode_audio_to_file_like( - data_ptr, - shape, + samples.data_ptr(), + list(samples.shape), sample_rate, format, file_like, @@ -193,6 +190,10 @@ def encode_audio_to_file_like( num_channels, ) + # This check is useless but it's critical to keep it to ensures that samples + # is still alive during the call to encode_audio_to_file_like. + assert samples.is_contiguous() + # ============================== # Abstract impl for the operators. Needed by torch.compile. From b3fc71435adec2de1a83a96e7197b7b74fda374f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 5 Jul 2025 15:45:57 +0100 Subject: [PATCH 11/20] Add contiguity check --- src/torchcodec/_core/ops.py | 2 -- test/test_encoders.py | 34 ++++++++++++++++++++++++---------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index ecca7376d..336068411 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import ctypes import io import json import warnings @@ -174,7 +173,6 @@ def encode_audio_to_file_like( """ assert _pybind_ops is not None - # Enforce float32 dtype requirement if samples.dtype != torch.float32: raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}") diff --git a/test/test_encoders.py b/test/test_encoders.py index 6e8d6c24a..375a9efd0 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -343,21 +343,16 @@ def test_encode_to_tensor_long_output(self): torch.testing.assert_close(self.decode(encoded_tensor).data, samples) - def test_contiguity(self): + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_contiguity(self, method, tmp_path): # Ensure that 2 waveforms with the same values are encoded in the same # way, regardless of their memory layout. Here we encode 2 equal # waveforms, one is row-aligned while the other is column-aligned. - # TODO: Ideally we'd be testing all encoding methods here num_samples = 10_000 # per channel contiguous_samples = torch.rand(2, num_samples).contiguous() assert contiguous_samples.stride() == (num_samples, 1) - params = dict(format="flac", bit_rate=44_000) - encoded_from_contiguous = AudioEncoder( - contiguous_samples, sample_rate=16_000 - ).to_tensor(**params) - non_contiguous_samples = contiguous_samples.T.contiguous().T assert non_contiguous_samples.stride() == (1, 2) @@ -365,9 +360,28 @@ def test_contiguity(self): contiguous_samples, non_contiguous_samples, rtol=0, atol=0 ) - encoded_from_non_contiguous = AudioEncoder( - non_contiguous_samples, sample_rate=16_000 - ).to_tensor(**params) + def encode_to_tensor(samples): + params = dict(bit_rate=44_000) + if method == "to_file": + dest = str(tmp_path / "output.flac") + AudioEncoder(samples, sample_rate=16_000).to_file(dest=dest, **params) + with open(dest, "rb") as f: + return torch.frombuffer(f.read(), dtype=torch.uint8) + elif method == "to_tensor": + return AudioEncoder(samples, sample_rate=16_000).to_tensor( + format="flac", **params + ) + elif method == "to_file_like": + file_like = io.BytesIO() + AudioEncoder(samples, sample_rate=16_000).to_file_like( + file_like, format="flac", **params + ) + return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8) + else: + raise ValueError(f"Unknown method: {method}") + + encoded_from_contiguous = encode_to_tensor(contiguous_samples) + encoded_from_non_contiguous = encode_to_tensor(non_contiguous_samples) torch.testing.assert_close( encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 From a78ef8bd076c87a48c728e24baa6d44f7a9a25be Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 6 Jul 2025 08:46:16 +0100 Subject: [PATCH 12/20] refac --- src/torchcodec/_core/pybind_ops.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index fb6cdf323..1dfd6eb3e 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -6,7 +6,6 @@ #include #include -#include #include #include @@ -16,7 +15,6 @@ #include "src/torchcodec/_core/StreamOptions.h" namespace py = pybind11; -using namespace py::literals; namespace facebook::torchcodec { @@ -55,6 +53,8 @@ int64_t encode_audio_to_file_like( auto samples = torch::from_blob( reinterpret_cast(data_ptr), shape, tensor_options); + // TODO Fix implicit int conversion: + // https://github.com/pytorch/torchcodec/issues/679 AudioStreamOptions audioStreamOptions; audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; @@ -78,13 +78,13 @@ PYBIND11_MODULE(decoder_core_pybind_ops, m) { m.def( "encode_audio_to_file_like", &encode_audio_to_file_like, - "data_ptr"_a, - "shape"_a, - "sample_rate"_a, - "format"_a, - "file_like"_a, - "bit_rate"_a = py::none(), - "num_channels"_a = py::none()); + "data_ptr", + "shape", + "sample_rate", + "format", + "file_like", + "bit_rate", + "num_channels"); } } // namespace facebook::torchcodec From fb6e463471b4c1a64eb99175df2a1c502199f996 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 6 Jul 2025 08:48:42 +0100 Subject: [PATCH 13/20] refac --- src/torchcodec/_core/Encoder.cpp | 12 ++++++------ src/torchcodec/_core/Encoder.h | 11 ++++------- src/torchcodec/_core/pybind_ops.cpp | 8 +++----- src/torchcodec/encoders/_audio_encoder.py | 7 +++++-- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 45ac67ed1..fc9765cf1 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -138,7 +138,7 @@ AudioEncoder::AudioEncoder( std::unique_ptr avioContextHolder, const AudioStreamOptions& audioStreamOptions) : samples_(validateSamples(samples)), - avioTensorContextHolder_(std::move(avioContextHolder)) { + avioToTensorContext_(std::move(avioContextHolder)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( @@ -153,7 +153,7 @@ AudioEncoder::AudioEncoder( getFFMPEGErrorStringFromErrorCode(status)); avFormatContext_.reset(avFormatContext); - avFormatContext_->pb = avioTensorContextHolder_->getAVIOContext(); + avFormatContext_->pb = avioToTensorContext_->getAVIOContext(); initializeEncoder(sampleRate, audioStreamOptions); } @@ -165,7 +165,7 @@ AudioEncoder::AudioEncoder( std::unique_ptr avioContextHolder, const AudioStreamOptions& audioStreamOptions) : samples_(validateSamples(samples)), - avioFileLikeContextHolder_(std::move(avioContextHolder)) { + avioFileLikeContext_(std::move(avioContextHolder)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; @@ -181,7 +181,7 @@ AudioEncoder::AudioEncoder( getFFMPEGErrorStringFromErrorCode(status)); avFormatContext_.reset(avFormatContext); - avFormatContext_->pb = avioFileLikeContextHolder_->getAVIOContext(); + avFormatContext_->pb = avioFileLikeContext_->getAVIOContext(); initializeEncoder(sampleRate, audioStreamOptions); } @@ -245,10 +245,10 @@ void AudioEncoder::initializeEncoder( torch::Tensor AudioEncoder::encodeToTensor() { TORCH_CHECK( - avioTensorContextHolder_ != nullptr, + avioToTensorContext_ != nullptr, "Cannot encode to tensor, avio tensor context doesn't exist."); encode(); - return avioTensorContextHolder_->getOutputTensor(); + return avioToTensorContext_->getOutputTensor(); } void AudioEncoder::encode() { diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index cf205daa3..0d091e5bd 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -25,13 +25,13 @@ class AudioEncoder { const torch::Tensor& samples, int sampleRate, std::string_view formatName, - std::unique_ptr avioContextHolder, + std::unique_ptr AVIOToTensorContext, const AudioStreamOptions& audioStreamOptions); AudioEncoder( const torch::Tensor& samples, int sampleRate, std::string_view formatName, - std::unique_ptr avioContextHolder, + std::unique_ptr AVIOFileLikeContext, const AudioStreamOptions& audioStreamOptions); void encode(); torch::Tensor encodeToTensor(); @@ -56,11 +56,8 @@ class AudioEncoder { const torch::Tensor samples_; - // Stores the AVIOContext for the output tensor buffer. - std::unique_ptr avioTensorContextHolder_; - - // Stores the AVIOContext for file-like object output. - std::unique_ptr avioFileLikeContextHolder_; + std::unique_ptr avioToTensorContext_; + std::unique_ptr avioFileLikeContext_; bool encodeWasCalled_ = false; }; diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 1dfd6eb3e..74de4f9f5 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -40,7 +40,7 @@ int64_t create_from_file_like( return reinterpret_cast(decoder); } -int64_t encode_audio_to_file_like( +void encode_audio_to_file_like( int64_t data_ptr, const std::vector& shape, int64_t sample_rate, @@ -48,13 +48,14 @@ int64_t encode_audio_to_file_like( py::object file_like, std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt) { - // Create tensor from existing data pointer (enforcing float32) + // We assume float32 *and* contiguity, this must be enforced by the caller. auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32); auto samples = torch::from_blob( reinterpret_cast(data_ptr), shape, tensor_options); // TODO Fix implicit int conversion: // https://github.com/pytorch/torchcodec/issues/679 + // same for sample_rate parameter below AudioStreamOptions audioStreamOptions; audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; @@ -68,9 +69,6 @@ int64_t encode_audio_to_file_like( std::move(avioContextHolder), audioStreamOptions); encoder.encode(); - - // Return 0 to indicate success - return 0; } PYBIND11_MODULE(decoder_core_pybind_ops, m) { diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py index 75f35cda8..27df6dcf2 100644 --- a/src/torchcodec/encoders/_audio_encoder.py +++ b/src/torchcodec/encoders/_audio_encoder.py @@ -110,8 +110,11 @@ def to_file_like( """Encode samples into a file-like object. Args: - file_like: A file-like object that supports write() and seek() methods, - such as io.BytesIO(), an open file in binary write mode, etc. + file_like: A file-like object that supports ``write()`` and + ``seek()`` methods, such as io.BytesIO(), an open file in binary + write mode, etc. Methods must have the following signature: + ``write(data: bytes) -> int`` and ``seek(offset: int, whence: + int = 0) -> int``. format (str): The format of the encoded samples, e.g. "mp3", "wav" or "flac". bit_rate (int, optional): The output bit rate. Encoders typically From 6e88ee676c511ae9a6746e7b8a6cfa8e181e9729 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 6 Jul 2025 09:12:34 +0100 Subject: [PATCH 14/20] mend --- src/torchcodec/_core/Encoder.cpp | 9 ++++----- src/torchcodec/_core/Encoder.h | 8 ++++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index fc9765cf1..b9332f723 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -135,10 +135,10 @@ AudioEncoder::AudioEncoder( const torch::Tensor& samples, int sampleRate, std::string_view formatName, - std::unique_ptr avioContextHolder, + std::unique_ptr avioToTensorContext, const AudioStreamOptions& audioStreamOptions) : samples_(validateSamples(samples)), - avioToTensorContext_(std::move(avioContextHolder)) { + avioToTensorContext_(std::move(avioToTensorContext)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( @@ -162,13 +162,12 @@ AudioEncoder::AudioEncoder( const torch::Tensor& samples, int sampleRate, std::string_view formatName, - std::unique_ptr avioContextHolder, + std::unique_ptr avioFileLikeContext, const AudioStreamOptions& audioStreamOptions) : samples_(validateSamples(samples)), - avioFileLikeContext_(std::move(avioContextHolder)) { + avioFileLikeContext_(std::move(avioFileLikeContext)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; - int status = avformat_alloc_output_context2( &avFormatContext, nullptr, formatName.data(), nullptr); diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 0d091e5bd..1919950fc 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -21,19 +21,27 @@ class AudioEncoder { int sampleRate, std::string_view fileName, const AudioStreamOptions& audioStreamOptions); + + // We need one constructor for each type of AVIOContextHolder. We can't have a + // single constructor that accepts the base AVIOContextHolder class and hold + // that as attribute, because we are calling the getOutputTensor() method on + // the AVIOToTensorContext, which is not available in the base class. AudioEncoder( const torch::Tensor& samples, int sampleRate, std::string_view formatName, std::unique_ptr AVIOToTensorContext, const AudioStreamOptions& audioStreamOptions); + AudioEncoder( const torch::Tensor& samples, int sampleRate, std::string_view formatName, std::unique_ptr AVIOFileLikeContext, const AudioStreamOptions& audioStreamOptions); + void encode(); + torch::Tensor encodeToTensor(); private: From 1c6fad8c492e4413a8429137e78c7c7b69940bcd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 6 Jul 2025 09:24:55 +0100 Subject: [PATCH 15/20] WIP --- src/torchcodec/_core/AVIOContextHolder.cpp | 8 +-- src/torchcodec/_core/AVIOFileLikeContext.cpp | 53 +++++--------------- src/torchcodec/_core/AVIOFileLikeContext.h | 17 ++----- src/torchcodec/_core/pybind_ops.cpp | 2 +- 4 files changed, 21 insertions(+), 59 deletions(-) diff --git a/src/torchcodec/_core/AVIOContextHolder.cpp b/src/torchcodec/_core/AVIOContextHolder.cpp index e0462c28d..99db89880 100644 --- a/src/torchcodec/_core/AVIOContextHolder.cpp +++ b/src/torchcodec/_core/AVIOContextHolder.cpp @@ -23,10 +23,10 @@ void AVIOContextHolder::createAVIOContext( buffer != nullptr, "Failed to allocate buffer of size " + std::to_string(bufferSize)); - TORCH_CHECK( - (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)), - "seek method must be defined, and either write or read must be defined. " - "But not both!") + // TORCH_CHECK( + // (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)), + // "seek method must be defined, and either write or read must be + // defined. " "But not both!") avioContext_.reset(avioAllocContext( buffer, bufferSize, diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index e79e47707..3870e5a15 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -10,40 +10,20 @@ namespace facebook::torchcodec { AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) - : AVIOFileLikeContext(fileLike, false) {} - -std::unique_ptr AVIOFileLikeContext::createForWriting( - py::object fileLike) { - return std::unique_ptr( - new AVIOFileLikeContext(fileLike, true)); -} - -AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike, bool isWriteMode) : fileLike_{UniquePyObject(new py::object(fileLike))} { { // TODO: Is it necessary to acquire the GIL here? Is it maybe even // harmful? At the moment, this is only called from within a pybind // function, and pybind guarantees we have the GIL. py::gil_scoped_acquire gil; - if (isWriteMode) { - TORCH_CHECK( - py::hasattr(fileLike, "write"), - "File like object must implement a write method."); - } else { - TORCH_CHECK( - py::hasattr(fileLike, "read"), - "File like object must implement a read method."); - } + TORCH_CHECK( + py::hasattr(fileLike, "read"), + "File like object must implement a read method."); TORCH_CHECK( py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } - - if (isWriteMode) { - createAVIOContext(nullptr, &write, &seek, &fileLike_); - } else { - createAVIOContext(&read, nullptr, &seek, &fileLike_); - } + createAVIOContext(&read, &write, &seek, &fileLike_); } int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { @@ -86,23 +66,6 @@ int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { return totalNumRead == 0 ? AVERROR_EOF : totalNumRead; } -int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) { - auto fileLike = static_cast(opaque); - - // Note that we acquire the GIL outside of the loop. This is likely more - // efficient than releasing and acquiring it each loop iteration. - py::gil_scoped_acquire gil; - - // Create a bytes object from the buffer - py::bytes data_bytes(reinterpret_cast(buf), buf_size); - - // Call the Python write method - auto bytes_written = (*fileLike)->attr("write")(data_bytes); - - // Python write() should return the number of bytes written - return py::cast(bytes_written); -} - int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { // We do not know the file size. if (whence == AVSEEK_SIZE) { @@ -114,4 +77,12 @@ int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { return py::cast((*fileLike)->attr("seek")(offset, whence)); } +int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) { + auto fileLike = static_cast(opaque); + py::gil_scoped_acquire gil; + py::bytes bytes_obj(reinterpret_cast(buf), buf_size); + + return py::cast((*fileLike)->attr("write")(bytes_obj)); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index 9a890e783..009485152 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -15,25 +15,16 @@ namespace py = pybind11; namespace facebook::torchcodec { -// Enables users to pass in a Python file-like object. We then forward all read, -// write and seek calls back up to the methods on the Python object. -class __attribute__((visibility("default"))) AVIOFileLikeContext - : public AVIOContextHolder { +// Enables uers to pass in a Python file-like object. We then forward all read +// and seek calls back up to the methods on the Python object. +class AVIOFileLikeContext : public AVIOContextHolder { public: - // Constructor for reading from a file-like object explicit AVIOFileLikeContext(py::object fileLike); - // Constructor for writing to a file-like object - static std::unique_ptr createForWriting( - py::object fileLike); - private: - // Private constructor for write mode - AVIOFileLikeContext(py::object fileLike, bool isWriteMode); - static int read(void* opaque, uint8_t* buf, int buf_size); - static int write(void* opaque, const uint8_t* buf, int buf_size); static int64_t seek(void* opaque, int64_t offset, int whence); + static int write(void* opaque, const uint8_t* buf, int buf_size); // Note that we dynamically allocate the Python object because we need to // strictly control when its destructor is called. We must hold the GIL diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 74de4f9f5..82b277be0 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -60,7 +60,7 @@ void encode_audio_to_file_like( audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; - auto avioContextHolder = AVIOFileLikeContext::createForWriting(file_like); + auto avioContextHolder = std::make_unique(file_like); AudioEncoder encoder( samples, From eb1b51d44960face8caf24d057703b2bc1aa7dce Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 6 Jul 2025 11:11:38 +0100 Subject: [PATCH 16/20] WIP --- src/torchcodec/_core/AVIOFileLikeContext.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index 3870e5a15..884687720 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -16,9 +16,9 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) // harmful? At the moment, this is only called from within a pybind // function, and pybind guarantees we have the GIL. py::gil_scoped_acquire gil; - TORCH_CHECK( - py::hasattr(fileLike, "read"), - "File like object must implement a read method."); + // TORCH_CHECK( + // py::hasattr(fileLike, "read"), + // "File like object must implement a read method."); TORCH_CHECK( py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); From 4d82cbb62ecd7eb44501f27718cec26ca63dd460 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 6 Jul 2025 12:11:06 +0100 Subject: [PATCH 17/20] bypass pybind warning --- src/torchcodec/_core/AVIOFileLikeContext.h | 3 ++- src/torchcodec/_core/CMakeLists.txt | 10 +--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index 009485152..ba9ca8c79 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -17,7 +17,8 @@ namespace facebook::torchcodec { // Enables uers to pass in a Python file-like object. We then forward all read // and seek calls back up to the methods on the Python object. -class AVIOFileLikeContext : public AVIOContextHolder { +class __attribute__((visibility("default"))) AVIOFileLikeContext + : public AVIOContextHolder { public: explicit AVIOFileLikeContext(py::object fileLike); diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 7196d0487..1dbfa52f4 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -66,6 +66,7 @@ function(make_torchcodec_libraries set(decoder_sources AVIOContextHolder.cpp AVIOTensorContext.cpp + AVIOFileLikeContext.cpp FFMPEGCommon.cpp Frame.cpp DeviceInterface.cpp @@ -142,15 +143,6 @@ function(make_torchcodec_libraries "${pybind_ops_sources}" "${pybind_ops_dependencies}" ) - # pybind11 limits the visibility of symbols in the shared library to prevent - # stray initialization of py::objects. The rest of the object code must - # match. See: - # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes - target_compile_options( - ${pybind_ops_library_name} - PUBLIC - "-fvisibility=hidden" - ) # If we don't make sure this flag is set, we run into segfauls at import # time on Mac. See: # https://github.com/pybind/pybind11/issues/3907#issuecomment-1170412764 From 88b335afd7e45c1653bedfcf97b02bf48e6579e9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 6 Jul 2025 12:23:56 +0100 Subject: [PATCH 18/20] Simplify some stuff --- src/torchcodec/_core/AVIOContextHolder.cpp | 8 ++++---- src/torchcodec/_core/AVIOFileLikeContext.cpp | 12 ++++++++---- src/torchcodec/_core/AVIOFileLikeContext.h | 5 ++++- src/torchcodec/_core/pybind_ops.cpp | 6 ++++-- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/torchcodec/_core/AVIOContextHolder.cpp b/src/torchcodec/_core/AVIOContextHolder.cpp index 99db89880..1b070e4f5 100644 --- a/src/torchcodec/_core/AVIOContextHolder.cpp +++ b/src/torchcodec/_core/AVIOContextHolder.cpp @@ -23,10 +23,10 @@ void AVIOContextHolder::createAVIOContext( buffer != nullptr, "Failed to allocate buffer of size " + std::to_string(bufferSize)); - // TORCH_CHECK( - // (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)), - // "seek method must be defined, and either write or read must be - // defined. " "But not both!") + TORCH_CHECK( + (seek != nullptr) && ((write != nullptr) || (read != nullptr)), + "seek method must be defined, and at least one of write or read must be " + "defined too"); avioContext_.reset(avioAllocContext( buffer, bufferSize, diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index 884687720..b897a93dc 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -9,16 +9,20 @@ namespace facebook::torchcodec { -AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) +AVIOFileLikeContext::AVIOFileLikeContext( + py::object fileLike, + std::string_view neededMethod) : fileLike_{UniquePyObject(new py::object(fileLike))} { { // TODO: Is it necessary to acquire the GIL here? Is it maybe even // harmful? At the moment, this is only called from within a pybind // function, and pybind guarantees we have the GIL. py::gil_scoped_acquire gil; - // TORCH_CHECK( - // py::hasattr(fileLike, "read"), - // "File like object must implement a read method."); + TORCH_CHECK( + py::hasattr(fileLike, neededMethod.data()), + "File like object must implement a ", + neededMethod, + " method."); TORCH_CHECK( py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index ba9ca8c79..544a776d1 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -17,10 +17,13 @@ namespace facebook::torchcodec { // Enables uers to pass in a Python file-like object. We then forward all read // and seek calls back up to the methods on the Python object. +// TODO: explain this. We probably don't want it. class __attribute__((visibility("default"))) AVIOFileLikeContext : public AVIOContextHolder { public: - explicit AVIOFileLikeContext(py::object fileLike); + explicit AVIOFileLikeContext( + py::object fileLike, + std::string_view neededMethod); private: static int read(void* opaque, uint8_t* buf, int buf_size); diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 82b277be0..7c5f54386 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -33,7 +33,8 @@ int64_t create_from_file_like( realSeek = seekModeFromString(seek_mode.value()); } - auto avioContextHolder = std::make_unique(file_like); + auto avioContextHolder = + std::make_unique(file_like, /*neededMethod=*/"read"); SingleStreamDecoder* decoder = new SingleStreamDecoder(std::move(avioContextHolder), realSeek); @@ -60,7 +61,8 @@ void encode_audio_to_file_like( audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; - auto avioContextHolder = std::make_unique(file_like); + auto avioContextHolder = std::make_unique( + file_like, /*neededMethod=*/"write"); AudioEncoder encoder( samples, From 843ff795c96589fcd411b58a31ad55d9e78eb225 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 6 Jul 2025 13:41:45 +0100 Subject: [PATCH 19/20] Add comment --- src/torchcodec/_core/ops.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 336068411..6d532a242 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -176,8 +176,21 @@ def encode_audio_to_file_like( if samples.dtype != torch.float32: raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}") - samples = samples.contiguous() + # We're having the same problem as with the decoder's create_from_file_like: + # We should be able to pass a tensor directly, but this leads to a pybind + # error. In order to work around this, we pass the pointer to the tensor's + # data, and its shape, in order to re-construct it in C++. For this to work: + # - the tensor must be float32 + # - the tensor must be contiguous, which is why we call contiguous(). + # In theory we could avoid this restriction by also passing the strides? + # - IMPORTANT: the input samples tensor and its underlying data must be + # alive during the call. + # + # A more elegant solution would be to cast the tensor into a py::object, but + # casting the py::object backk to a tensor in C++ seems to lead to the same + # pybing error. + samples = samples.contiguous() _pybind_ops.encode_audio_to_file_like( samples.data_ptr(), list(samples.shape), From 558c8f734c109714621b6f7c37c75feef291a3d6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 6 Jul 2025 13:58:18 +0100 Subject: [PATCH 20/20] Fix timeout --- src/torchcodec/_core/AVIOContextHolder.cpp | 15 ++++++++----- src/torchcodec/_core/AVIOContextHolder.h | 1 + src/torchcodec/_core/AVIOFileLikeContext.cpp | 22 ++++++++++++-------- src/torchcodec/_core/AVIOFileLikeContext.h | 4 +--- src/torchcodec/_core/AVIOTensorContext.cpp | 6 ++++-- src/torchcodec/_core/pybind_ops.cpp | 6 +++--- 6 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/torchcodec/_core/AVIOContextHolder.cpp b/src/torchcodec/_core/AVIOContextHolder.cpp index 1b070e4f5..c1188e684 100644 --- a/src/torchcodec/_core/AVIOContextHolder.cpp +++ b/src/torchcodec/_core/AVIOContextHolder.cpp @@ -14,6 +14,7 @@ void AVIOContextHolder::createAVIOContext( AVIOWriteFunction write, AVIOSeekFunction seek, void* heldData, + bool isForWriting, int bufferSize) { TORCH_CHECK( bufferSize > 0, @@ -23,14 +24,18 @@ void AVIOContextHolder::createAVIOContext( buffer != nullptr, "Failed to allocate buffer of size " + std::to_string(bufferSize)); - TORCH_CHECK( - (seek != nullptr) && ((write != nullptr) || (read != nullptr)), - "seek method must be defined, and at least one of write or read must be " - "defined too"); + TORCH_CHECK(seek != nullptr, "seek method must be defined"); + + if (isForWriting) { + TORCH_CHECK(write != nullptr, "write method must be defined for writing"); + } else { + TORCH_CHECK(read != nullptr, "read method must be defined for reading"); + } + avioContext_.reset(avioAllocContext( buffer, bufferSize, - /*write_flag=*/write != nullptr, + /*write_flag=*/isForWriting, heldData, read, write, diff --git a/src/torchcodec/_core/AVIOContextHolder.h b/src/torchcodec/_core/AVIOContextHolder.h index 54d239cd3..16d70beaf 100644 --- a/src/torchcodec/_core/AVIOContextHolder.h +++ b/src/torchcodec/_core/AVIOContextHolder.h @@ -51,6 +51,7 @@ class AVIOContextHolder { AVIOWriteFunction write, AVIOSeekFunction seek, void* heldData, + bool isForWriting, int bufferSize = defaultBufferSize); private: diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index b897a93dc..800edb4e7 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -9,25 +9,29 @@ namespace facebook::torchcodec { -AVIOFileLikeContext::AVIOFileLikeContext( - py::object fileLike, - std::string_view neededMethod) +AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike, bool isForWriting) : fileLike_{UniquePyObject(new py::object(fileLike))} { { // TODO: Is it necessary to acquire the GIL here? Is it maybe even // harmful? At the moment, this is only called from within a pybind // function, and pybind guarantees we have the GIL. py::gil_scoped_acquire gil; - TORCH_CHECK( - py::hasattr(fileLike, neededMethod.data()), - "File like object must implement a ", - neededMethod, - " method."); + + if (isForWriting) { + TORCH_CHECK( + py::hasattr(fileLike, "write"), + "File like object must implement a write method for writing."); + } else { + TORCH_CHECK( + py::hasattr(fileLike, "read"), + "File like object must implement a read method for reading."); + } + TORCH_CHECK( py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } - createAVIOContext(&read, &write, &seek, &fileLike_); + createAVIOContext(&read, &write, &seek, &fileLike_, isForWriting); } int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index 544a776d1..9f8258008 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -21,9 +21,7 @@ namespace facebook::torchcodec { class __attribute__((visibility("default"))) AVIOFileLikeContext : public AVIOContextHolder { public: - explicit AVIOFileLikeContext( - py::object fileLike, - std::string_view neededMethod); + explicit AVIOFileLikeContext(py::object fileLike, bool isForWriting); private: static int read(void* opaque, uint8_t* buf, int buf_size); diff --git a/src/torchcodec/_core/AVIOTensorContext.cpp b/src/torchcodec/_core/AVIOTensorContext.cpp index df97e0218..3f45f5be5 100644 --- a/src/torchcodec/_core/AVIOTensorContext.cpp +++ b/src/torchcodec/_core/AVIOTensorContext.cpp @@ -105,12 +105,14 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data) TORCH_CHECK(data.numel() > 0, "data must not be empty"); TORCH_CHECK(data.is_contiguous(), "data must be contiguous"); TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8"); - createAVIOContext(&read, nullptr, &seek, &tensorContext_); + createAVIOContext( + &read, nullptr, &seek, &tensorContext_, /*isForWriting=*/false); } AVIOToTensorContext::AVIOToTensorContext() : tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} { - createAVIOContext(nullptr, &write, &seek, &tensorContext_); + createAVIOContext( + nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true); } torch::Tensor AVIOToTensorContext::getOutputTensor() { diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 7c5f54386..e4e5369d2 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -34,7 +34,7 @@ int64_t create_from_file_like( } auto avioContextHolder = - std::make_unique(file_like, /*neededMethod=*/"read"); + std::make_unique(file_like, /*isForWriting=*/false); SingleStreamDecoder* decoder = new SingleStreamDecoder(std::move(avioContextHolder), realSeek); @@ -61,8 +61,8 @@ void encode_audio_to_file_like( audioStreamOptions.bitRate = bit_rate; audioStreamOptions.numChannels = num_channels; - auto avioContextHolder = std::make_unique( - file_like, /*neededMethod=*/"write"); + auto avioContextHolder = + std::make_unique(file_like, /*isForWriting=*/true); AudioEncoder encoder( samples,