diff --git a/src/torchcodec/_core/AVIOContextHolder.cpp b/src/torchcodec/_core/AVIOContextHolder.cpp index e0462c28..c1188e68 100644 --- a/src/torchcodec/_core/AVIOContextHolder.cpp +++ b/src/torchcodec/_core/AVIOContextHolder.cpp @@ -14,6 +14,7 @@ void AVIOContextHolder::createAVIOContext( AVIOWriteFunction write, AVIOSeekFunction seek, void* heldData, + bool isForWriting, int bufferSize) { TORCH_CHECK( bufferSize > 0, @@ -23,14 +24,18 @@ void AVIOContextHolder::createAVIOContext( buffer != nullptr, "Failed to allocate buffer of size " + std::to_string(bufferSize)); - TORCH_CHECK( - (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)), - "seek method must be defined, and either write or read must be defined. " - "But not both!") + TORCH_CHECK(seek != nullptr, "seek method must be defined"); + + if (isForWriting) { + TORCH_CHECK(write != nullptr, "write method must be defined for writing"); + } else { + TORCH_CHECK(read != nullptr, "read method must be defined for reading"); + } + avioContext_.reset(avioAllocContext( buffer, bufferSize, - /*write_flag=*/write != nullptr, + /*write_flag=*/isForWriting, heldData, read, write, diff --git a/src/torchcodec/_core/AVIOContextHolder.h b/src/torchcodec/_core/AVIOContextHolder.h index 54d239cd..16d70bea 100644 --- a/src/torchcodec/_core/AVIOContextHolder.h +++ b/src/torchcodec/_core/AVIOContextHolder.h @@ -51,6 +51,7 @@ class AVIOContextHolder { AVIOWriteFunction write, AVIOSeekFunction seek, void* heldData, + bool isForWriting, int bufferSize = defaultBufferSize); private: diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index 5497f89b..800edb4e 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -9,21 +9,29 @@ namespace facebook::torchcodec { -AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) +AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike, bool isForWriting) : fileLike_{UniquePyObject(new py::object(fileLike))} { { // TODO: Is it necessary to acquire the GIL here? Is it maybe even // harmful? At the moment, this is only called from within a pybind // function, and pybind guarantees we have the GIL. py::gil_scoped_acquire gil; - TORCH_CHECK( - py::hasattr(fileLike, "read"), - "File like object must implement a read method."); + + if (isForWriting) { + TORCH_CHECK( + py::hasattr(fileLike, "write"), + "File like object must implement a write method for writing."); + } else { + TORCH_CHECK( + py::hasattr(fileLike, "read"), + "File like object must implement a read method for reading."); + } + TORCH_CHECK( py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } - createAVIOContext(&read, nullptr, &seek, &fileLike_); + createAVIOContext(&read, &write, &seek, &fileLike_, isForWriting); } int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { @@ -77,4 +85,12 @@ int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { return py::cast((*fileLike)->attr("seek")(offset, whence)); } +int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) { + auto fileLike = static_cast(opaque); + py::gil_scoped_acquire gil; + py::bytes bytes_obj(reinterpret_cast(buf), buf_size); + + return py::cast((*fileLike)->attr("write")(bytes_obj)); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index 3e80f1c6..9f825800 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -17,13 +17,16 @@ namespace facebook::torchcodec { // Enables uers to pass in a Python file-like object. We then forward all read // and seek calls back up to the methods on the Python object. -class AVIOFileLikeContext : public AVIOContextHolder { +// TODO: explain this. We probably don't want it. +class __attribute__((visibility("default"))) AVIOFileLikeContext + : public AVIOContextHolder { public: - explicit AVIOFileLikeContext(py::object fileLike); + explicit AVIOFileLikeContext(py::object fileLike, bool isForWriting); private: static int read(void* opaque, uint8_t* buf, int buf_size); static int64_t seek(void* opaque, int64_t offset, int whence); + static int write(void* opaque, const uint8_t* buf, int buf_size); // Note that we dynamically allocate the Python object because we need to // strictly control when its destructor is called. We must hold the GIL diff --git a/src/torchcodec/_core/AVIOTensorContext.cpp b/src/torchcodec/_core/AVIOTensorContext.cpp index df97e021..3f45f5be 100644 --- a/src/torchcodec/_core/AVIOTensorContext.cpp +++ b/src/torchcodec/_core/AVIOTensorContext.cpp @@ -105,12 +105,14 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data) TORCH_CHECK(data.numel() > 0, "data must not be empty"); TORCH_CHECK(data.is_contiguous(), "data must be contiguous"); TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8"); - createAVIOContext(&read, nullptr, &seek, &tensorContext_); + createAVIOContext( + &read, nullptr, &seek, &tensorContext_, /*isForWriting=*/false); } AVIOToTensorContext::AVIOToTensorContext() : tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} { - createAVIOContext(nullptr, &write, &seek, &tensorContext_); + createAVIOContext( + nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true); } torch::Tensor AVIOToTensorContext::getOutputTensor() { diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 7196d048..1dbfa52f 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -66,6 +66,7 @@ function(make_torchcodec_libraries set(decoder_sources AVIOContextHolder.cpp AVIOTensorContext.cpp + AVIOFileLikeContext.cpp FFMPEGCommon.cpp Frame.cpp DeviceInterface.cpp @@ -142,15 +143,6 @@ function(make_torchcodec_libraries "${pybind_ops_sources}" "${pybind_ops_dependencies}" ) - # pybind11 limits the visibility of symbols in the shared library to prevent - # stray initialization of py::objects. The rest of the object code must - # match. See: - # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes - target_compile_options( - ${pybind_ops_library_name} - PUBLIC - "-fvisibility=hidden" - ) # If we don't make sure this flag is set, we run into segfauls at import # time on Mac. See: # https://github.com/pybind/pybind11/issues/3907#issuecomment-1170412764 diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 90579432..b9332f72 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -135,10 +135,10 @@ AudioEncoder::AudioEncoder( const torch::Tensor& samples, int sampleRate, std::string_view formatName, - std::unique_ptr avioContextHolder, + std::unique_ptr avioToTensorContext, const AudioStreamOptions& audioStreamOptions) : samples_(validateSamples(samples)), - avioContextHolder_(std::move(avioContextHolder)) { + avioToTensorContext_(std::move(avioToTensorContext)) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( @@ -153,7 +153,34 @@ AudioEncoder::AudioEncoder( getFFMPEGErrorStringFromErrorCode(status)); avFormatContext_.reset(avFormatContext); - avFormatContext_->pb = avioContextHolder_->getAVIOContext(); + avFormatContext_->pb = avioToTensorContext_->getAVIOContext(); + + initializeEncoder(sampleRate, audioStreamOptions); +} + +AudioEncoder::AudioEncoder( + const torch::Tensor& samples, + int sampleRate, + std::string_view formatName, + std::unique_ptr avioFileLikeContext, + const AudioStreamOptions& audioStreamOptions) + : samples_(validateSamples(samples)), + avioFileLikeContext_(std::move(avioFileLikeContext)) { + setFFmpegLogLevel(); + AVFormatContext* avFormatContext = nullptr; + int status = avformat_alloc_output_context2( + &avFormatContext, nullptr, formatName.data(), nullptr); + + TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext for file-like object. ", + "Check the desired format? Got format=", + formatName, + ". ", + getFFMPEGErrorStringFromErrorCode(status)); + avFormatContext_.reset(avFormatContext); + + avFormatContext_->pb = avioFileLikeContext_->getAVIOContext(); initializeEncoder(sampleRate, audioStreamOptions); } @@ -217,10 +244,10 @@ void AudioEncoder::initializeEncoder( torch::Tensor AudioEncoder::encodeToTensor() { TORCH_CHECK( - avioContextHolder_ != nullptr, - "Cannot encode to tensor, avio context doesn't exist."); + avioToTensorContext_ != nullptr, + "Cannot encode to tensor, avio tensor context doesn't exist."); encode(); - return avioContextHolder_->getOutputTensor(); + return avioToTensorContext_->getOutputTensor(); } void AudioEncoder::encode() { diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index ea500901..1919950f 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -1,5 +1,6 @@ #pragma once #include +#include "src/torchcodec/_core/AVIOFileLikeContext.h" #include "src/torchcodec/_core/AVIOTensorContext.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/StreamOptions.h" @@ -20,13 +21,27 @@ class AudioEncoder { int sampleRate, std::string_view fileName, const AudioStreamOptions& audioStreamOptions); + + // We need one constructor for each type of AVIOContextHolder. We can't have a + // single constructor that accepts the base AVIOContextHolder class and hold + // that as attribute, because we are calling the getOutputTensor() method on + // the AVIOToTensorContext, which is not available in the base class. AudioEncoder( const torch::Tensor& samples, int sampleRate, std::string_view formatName, - std::unique_ptr avioContextHolder, + std::unique_ptr AVIOToTensorContext, const AudioStreamOptions& audioStreamOptions); + + AudioEncoder( + const torch::Tensor& samples, + int sampleRate, + std::string_view formatName, + std::unique_ptr AVIOFileLikeContext, + const AudioStreamOptions& audioStreamOptions); + void encode(); + torch::Tensor encodeToTensor(); private: @@ -49,8 +64,8 @@ class AudioEncoder { const torch::Tensor samples_; - // Stores the AVIOContext for the output tensor buffer. - std::unique_ptr avioContextHolder_; + std::unique_ptr avioToTensorContext_; + std::unique_ptr avioFileLikeContext_; bool encodeWasCalled_ = false; }; diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 77fc7b85..3d340bff 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -23,6 +23,7 @@ create_from_file_like, create_from_tensor, encode_audio_to_file, + encode_audio_to_file_like, encode_audio_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index a68b51e2..6d532a24 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -153,6 +153,59 @@ def create_from_file_like( return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode)) +def encode_audio_to_file_like( + samples: torch.Tensor, + sample_rate: int, + format: str, + file_like: Union[io.RawIOBase, io.BufferedIOBase], + bit_rate: Optional[int] = None, + num_channels: Optional[int] = None, +) -> None: + """Encode audio samples to a file-like object. + + Args: + samples: Audio samples tensor + sample_rate: Sample rate in Hz + format: Audio format (e.g., "wav", "mp3", "flac") + file_like: File-like object that supports write() and seek() methods + bit_rate: Optional bit rate for encoding + num_channels: Optional number of output channels + """ + assert _pybind_ops is not None + + if samples.dtype != torch.float32: + raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}") + + # We're having the same problem as with the decoder's create_from_file_like: + # We should be able to pass a tensor directly, but this leads to a pybind + # error. In order to work around this, we pass the pointer to the tensor's + # data, and its shape, in order to re-construct it in C++. For this to work: + # - the tensor must be float32 + # - the tensor must be contiguous, which is why we call contiguous(). + # In theory we could avoid this restriction by also passing the strides? + # - IMPORTANT: the input samples tensor and its underlying data must be + # alive during the call. + # + # A more elegant solution would be to cast the tensor into a py::object, but + # casting the py::object backk to a tensor in C++ seems to lead to the same + # pybing error. + + samples = samples.contiguous() + _pybind_ops.encode_audio_to_file_like( + samples.data_ptr(), + list(samples.shape), + sample_rate, + format, + file_like, + bit_rate, + num_channels, + ) + + # This check is useless but it's critical to keep it to ensures that samples + # is still alive during the call to encode_audio_to_file_like. + assert samples.is_contiguous() + + # ============================== # Abstract impl for the operators. Needed by torch.compile. # ============================== diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 6f873f5a..e4e5369d 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -10,7 +10,9 @@ #include #include "src/torchcodec/_core/AVIOFileLikeContext.h" +#include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" +#include "src/torchcodec/_core/StreamOptions.h" namespace py = pybind11; @@ -31,15 +33,58 @@ int64_t create_from_file_like( realSeek = seekModeFromString(seek_mode.value()); } - auto avioContextHolder = std::make_unique(file_like); + auto avioContextHolder = + std::make_unique(file_like, /*isForWriting=*/false); SingleStreamDecoder* decoder = new SingleStreamDecoder(std::move(avioContextHolder), realSeek); return reinterpret_cast(decoder); } +void encode_audio_to_file_like( + int64_t data_ptr, + const std::vector& shape, + int64_t sample_rate, + std::string_view format, + py::object file_like, + std::optional bit_rate = std::nullopt, + std::optional num_channels = std::nullopt) { + // We assume float32 *and* contiguity, this must be enforced by the caller. + auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32); + auto samples = torch::from_blob( + reinterpret_cast(data_ptr), shape, tensor_options); + + // TODO Fix implicit int conversion: + // https://github.com/pytorch/torchcodec/issues/679 + // same for sample_rate parameter below + AudioStreamOptions audioStreamOptions; + audioStreamOptions.bitRate = bit_rate; + audioStreamOptions.numChannels = num_channels; + + auto avioContextHolder = + std::make_unique(file_like, /*isForWriting=*/true); + + AudioEncoder encoder( + samples, + static_cast(sample_rate), + format, + std::move(avioContextHolder), + audioStreamOptions); + encoder.encode(); +} + PYBIND11_MODULE(decoder_core_pybind_ops, m) { m.def("create_from_file_like", &create_from_file_like); + m.def( + "encode_audio_to_file_like", + &encode_audio_to_file_like, + "data_ptr", + "shape", + "sample_rate", + "format", + "file_like", + "bit_rate", + "num_channels"); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/encoders/_audio_encoder.py b/src/torchcodec/encoders/_audio_encoder.py index 742ea908..27df6dcf 100644 --- a/src/torchcodec/encoders/_audio_encoder.py +++ b/src/torchcodec/encoders/_audio_encoder.py @@ -98,3 +98,38 @@ def to_tensor( bit_rate=bit_rate, num_channels=num_channels, ) + + def to_file_like( + self, + file_like, + format: str, + *, + bit_rate: Optional[int] = None, + num_channels: Optional[int] = None, + ) -> None: + """Encode samples into a file-like object. + + Args: + file_like: A file-like object that supports ``write()`` and + ``seek()`` methods, such as io.BytesIO(), an open file in binary + write mode, etc. Methods must have the following signature: + ``write(data: bytes) -> int`` and ``seek(offset: int, whence: + int = 0) -> int``. + format (str): The format of the encoded samples, e.g. "mp3", "wav" + or "flac". + bit_rate (int, optional): The output bit rate. Encoders typically + support a finite set of bit rate values, so ``bit_rate`` will be + matched to one of those supported values. The default is chosen + by FFmpeg. + num_channels (int, optional): The number of channels of the encoded + output samples. By default, the number of channels of the input + ``samples`` is used. + """ + _core.encode_audio_to_file_like( + samples=self._samples, + sample_rate=self._sample_rate, + format=format, + file_like=file_like, + bit_rate=bit_rate, + num_channels=num_channels, + ) diff --git a/test/test_encoders.py b/test/test_encoders.py index 284d053e..375a9efd 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1,3 +1,4 @@ +import io import json import os import re @@ -135,13 +136,16 @@ def test_bad_input(self): ): encoder.to_tensor(format=bad_format) - @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_bad_input_parametrized(self, method, tmp_path): - valid_params = ( - dict(dest=str(tmp_path / "output.mp3")) - if method == "to_file" - else dict(format="mp3") - ) + if method == "to_file": + valid_params = dict(dest=str(tmp_path / "output.mp3")) + elif method == "to_tensor": + valid_params = dict(format="mp3") + elif method == "to_file_like": + valid_params = dict(file_like=io.BytesIO(), format="mp3") + else: + raise ValueError(f"Unknown method: {method}") decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3).data, sample_rate=10) with pytest.raises(RuntimeError, match="invalid sample rate=10"): @@ -172,7 +176,7 @@ def test_bad_input_parametrized(self, method, tmp_path): ): getattr(decoder, method)(**valid_params, num_channels=num_channels) - @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) @pytest.mark.parametrize("format", ("wav", "flac")) def test_round_trip(self, method, format, tmp_path): # Check that decode(encode(samples)) == samples on lossless formats @@ -189,10 +193,16 @@ def test_round_trip(self, method, format, tmp_path): encoded_path = str(tmp_path / f"output.{format}") encoded_source = encoded_path encoder.to_file(dest=encoded_path) - else: + elif method == "to_tensor": encoded_source = encoder.to_tensor(format=format) assert encoded_source.dtype == torch.uint8 assert encoded_source.ndim == 1 + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like(file_like, format=format) + encoded_source = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") rtol, atol = (0, 1e-4) if format == "wav" else (None, None) torch.testing.assert_close( @@ -204,7 +214,7 @@ def test_round_trip(self, method, format, tmp_path): @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) @pytest.mark.parametrize("num_channels", (None, 1, 2)) @pytest.mark.parametrize("format", ("mp3", "wav", "flac")) - @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_against_cli( self, asset, @@ -240,8 +250,14 @@ def test_against_cli( if method == "to_file": encoded_by_us = tmp_path / f"output.{format}" encoder.to_file(dest=str(encoded_by_us), **params) - else: + elif method == "to_tensor": encoded_by_us = encoder.to_tensor(format=format, **params) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like(file_like, format=format, **params) + encoded_by_us = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") captured = capfd.readouterr() if format == "wav": @@ -277,15 +293,14 @@ def test_against_cli( if method == "to_file": validate_frames_properties(actual=encoded_by_us, expected=encoded_by_ffmpeg) - else: - assert method == "to_tensor", "wrong test parametrization!" @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) @pytest.mark.parametrize("num_channels", (None, 1, 2)) @pytest.mark.parametrize("format", ("mp3", "wav", "flac")) - def test_to_tensor_against_to_file( - self, asset, bit_rate, num_channels, format, tmp_path + @pytest.mark.parametrize("method", ("to_tensor", "to_file_like")) + def test_against_to_file( + self, asset, bit_rate, num_channels, format, tmp_path, method ): if get_ffmpeg_major_version() == 4 and format == "wav": pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") @@ -295,12 +310,22 @@ def test_to_tensor_against_to_file( params = dict(bit_rate=bit_rate, num_channels=num_channels) encoded_file = tmp_path / f"output.{format}" encoder.to_file(dest=str(encoded_file), **params) - encoded_tensor = encoder.to_tensor( - format=format, bit_rate=bit_rate, num_channels=num_channels - ) + + if method == "to_tensor": + encoded_output = encoder.to_tensor( + format=format, bit_rate=bit_rate, num_channels=num_channels + ) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like( + file_like, format=format, bit_rate=bit_rate, num_channels=num_channels + ) + encoded_output = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") torch.testing.assert_close( - self.decode(encoded_file).data, self.decode(encoded_tensor).data + self.decode(encoded_file).data, self.decode(encoded_output).data ) def test_encode_to_tensor_long_output(self): @@ -318,21 +343,16 @@ def test_encode_to_tensor_long_output(self): torch.testing.assert_close(self.decode(encoded_tensor).data, samples) - def test_contiguity(self): + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_contiguity(self, method, tmp_path): # Ensure that 2 waveforms with the same values are encoded in the same # way, regardless of their memory layout. Here we encode 2 equal # waveforms, one is row-aligned while the other is column-aligned. - # TODO: Ideally we'd be testing all encoding methods here num_samples = 10_000 # per channel contiguous_samples = torch.rand(2, num_samples).contiguous() assert contiguous_samples.stride() == (num_samples, 1) - params = dict(format="flac", bit_rate=44_000) - encoded_from_contiguous = AudioEncoder( - contiguous_samples, sample_rate=16_000 - ).to_tensor(**params) - non_contiguous_samples = contiguous_samples.T.contiguous().T assert non_contiguous_samples.stride() == (1, 2) @@ -340,9 +360,28 @@ def test_contiguity(self): contiguous_samples, non_contiguous_samples, rtol=0, atol=0 ) - encoded_from_non_contiguous = AudioEncoder( - non_contiguous_samples, sample_rate=16_000 - ).to_tensor(**params) + def encode_to_tensor(samples): + params = dict(bit_rate=44_000) + if method == "to_file": + dest = str(tmp_path / "output.flac") + AudioEncoder(samples, sample_rate=16_000).to_file(dest=dest, **params) + with open(dest, "rb") as f: + return torch.frombuffer(f.read(), dtype=torch.uint8) + elif method == "to_tensor": + return AudioEncoder(samples, sample_rate=16_000).to_tensor( + format="flac", **params + ) + elif method == "to_file_like": + file_like = io.BytesIO() + AudioEncoder(samples, sample_rate=16_000).to_file_like( + file_like, format="flac", **params + ) + return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8) + else: + raise ValueError(f"Unknown method: {method}") + + encoded_from_contiguous = encode_to_tensor(contiguous_samples) + encoded_from_non_contiguous = encode_to_tensor(non_contiguous_samples) torch.testing.assert_close( encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 @@ -350,7 +389,7 @@ def test_contiguity(self): @pytest.mark.parametrize("num_channels_input", (1, 2)) @pytest.mark.parametrize("num_channels_output", (1, 2, None)) - @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_num_channels( self, num_channels_input, num_channels_output, method, tmp_path ): @@ -368,8 +407,14 @@ def test_num_channels( encoded_path = str(tmp_path / f"output.{format}") encoded_source = encoded_path encoder.to_file(dest=encoded_path, **params) - else: + elif method == "to_tensor": encoded_source = encoder.to_tensor(format=format, **params) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like(file_like, format=format, **params) + encoded_source = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") if num_channels_output is None: num_channels_output = num_channels_input @@ -384,3 +429,72 @@ def test_1d_samples(self): AudioEncoder(samples_1d, sample_rate=sample_rate).to_tensor("wav"), AudioEncoder(samples_2d, sample_rate=sample_rate).to_tensor("wav"), ) + + def test_to_file_like_custom_file_object(self, tmp_path): + class CustomFileObject: + def __init__(self): + self._file = io.BytesIO() + + def write(self, data): + return self._file.write(data) + + def seek(self, offset, whence=0): + return self._file.seek(offset, whence) + + def get_encoded_data(self): + return self._file.getvalue() + + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + file_like = CustomFileObject() + encoder.to_file_like(file_like, format="wav") + + decoded_samples = self.decode(file_like.get_encoded_data()) + + torch.testing.assert_close( + decoded_samples.data, + source_samples, + rtol=0, + atol=1e-4, + ) + + def test_to_file_like_real_file(self, tmp_path): + """Test to_file_like with a real file opened in binary write mode.""" + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + file_path = tmp_path / "test_file_like.wav" + + with open(file_path, "wb") as file_like: + encoder.to_file_like(file_like, format="wav") + + decoded_samples = self.decode(str(file_path)) + torch.testing.assert_close( + decoded_samples.data, source_samples, rtol=0, atol=1e-4 + ) + + def test_to_file_like_bad_methods(self): + asset = NASA_AUDIO_MP3 + source_samples = self.decode(asset).data + encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) + + class NoWriteMethod: + def seek(self, offset, whence=0): + return 0 + + with pytest.raises( + RuntimeError, match="File like object must implement a write method" + ): + encoder.to_file_like(NoWriteMethod(), format="wav") + + class NoSeekMethod: + def write(self, data): + return len(data) + + with pytest.raises( + RuntimeError, match="File like object must implement a seek method" + ): + encoder.to_file_like(NoSeekMethod(), format="wav")