From 878a929bd1b8b887938179b0d88ee979056a65a1 Mon Sep 17 00:00:00 2001 From: exdysa <91800957+exdysa@users.noreply.github.com> Date: Fri, 2 Jan 2026 09:44:54 -0500 Subject: [PATCH 1/5] +tests --- tests/__init__.py | 0 tests/test_extract_calls.py | 36 ++++++++++++++++++++++ tvln/batch.py | 14 +++++++-- tvln/clip_features.py | 12 ++++++++ tvln/main.py | 15 ++------- tvln/tests/test_extract_calls.py | 53 -------------------------------- 6 files changed, 61 insertions(+), 69 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_extract_calls.py delete mode 100644 tvln/tests/test_extract_calls.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_extract_calls.py b/tests/test_extract_calls.py new file mode 100644 index 0000000..5f4a1ce --- /dev/null +++ b/tests/test_extract_calls.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0 +# + +import pytest +from unittest.mock import patch, MagicMock + +from tvln.clip_features import CLIPFeatures + + +@pytest.fixture +def dummy_image_file(): + return {"text": ["a"], "image": ["b"]} + + +@patch("tvln.clip_features.cleanup") +@patch.object(CLIPFeatures, "extract", return_value=MagicMock(name="tensor")) +@patch.object(CLIPFeatures, "set_model_link") +@patch.object(CLIPFeatures, "set_model_type") +@patch.object(CLIPFeatures, "set_precision") +@patch.object(CLIPFeatures, "set_device") +def test_clip_features_flow(mock_dev, mock_prec, mock_type, mock_link, mock_extract, mock_cleanup, dummy_image_file): + # run the three blocks (copy‑paste the original snippet here) + # block 1 + from tvln.main import main + + tensor_stack = main() + # assertions + assert mock_dev.call_count == 3 + assert mock_prec.call_count == 3 + assert mock_link.call_count == 2 + assert mock_type.call_count == 1 + assert mock_extract.call_count == 3 + assert mock_cleanup.call_count == 3 + for name, tensors in tensor_stack.items(): + if name != "F1 VAE": + assert isinstance(tensors[1], MagicMock) # extraction was triggered diff --git a/tvln/batch.py b/tvln/batch.py index 66d52a0..4d8ca35 100644 --- a/tvln/batch.py +++ b/tvln/batch.py @@ -1,20 +1,28 @@ # SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0 # -import torch from pathlib import Path +import torch + class ImageFile: _image_path: str def __init__(self) -> None: + """Initializes an ImageFile instance with a default""" self._default_path = Path(__file__).resolve().parent / "assets" / "DSC_0047.png" self._default_path.resolve() self._default_path.as_posix() def single_image(self) -> None: - image_path = input("Enter the path to an image file (e.g. /home/user/image.png, C:/Users/user/Pictures/...): ") + """Set absolute path to an image file, ensuring the file exists, falling back to a default image if none is provided.""" + from sys import modules as sys_modules + + if "pytest" not in sys_modules: + image_path = input("Enter the path to an image file (e.g. /home/user/image.png, C:/Users/user/Pictures/...): ") + else: + image_path = None if not image_path: image_path = self._default_path if not Path(image_path).resolve().is_file(): @@ -32,8 +40,8 @@ def as_tensor(self, dtype: torch.dtype, device: torch.device, normalize: bool = :param normalize: Normalize tensor to [-1, 1]: :return: Tensor of shape ``[1, 3, H, W]`` on ``device``.""" - from numpy._typing import NDArray from numpy import array as np_array + from numpy._typing import NDArray from PIL.Image import open as open_img with open_img(str(self._image_path)).convert("RGB") as pil_image: diff --git a/tvln/clip_features.py b/tvln/clip_features.py index dbf7043..722220c 100644 --- a/tvln/clip_features.py +++ b/tvln/clip_features.py @@ -194,3 +194,15 @@ def extract(self, image: ImageFile, last_layer=False) -> Tensor: else: self._images = [image._image_path] return self.ImageEncoder() + + +def cleanup(model: CLIPFeatures, device: str) -> None: # type:ignore + import torch + import gc + + if device != "cpu": + gpu = getattr(torch, device) + gpu.empty_cache() + model: None = None + del model + gc.collect() diff --git a/tvln/main.py b/tvln/main.py index 07a851a..67010df 100644 --- a/tvln/main.py +++ b/tvln/main.py @@ -4,22 +4,10 @@ import torch -def cleanup(model, device: str): - import torch - import gc - - if device != "cpu": - gpu = getattr(torch, device) - gpu.empty_cache() - model = None - del model - gc.collect() - - @torch.no_grad def main(): import os - from tvln.clip_features import CLIPFeatures, DeviceName, PrecisionType, ModelLink, ModelType + from tvln.clip_features import CLIPFeatures, DeviceName, PrecisionType, ModelLink, ModelType, cleanup from tvln.batch import ImageFile from huggingface_hub import snapshot_download from diffusers import AutoencoderKL @@ -76,6 +64,7 @@ def main(): "F1 VAE": ["black-forest-labs/FLUX.1-dev", vae_tensor[0].sample()], } print(aggregate_data) + return aggregate_data if __name__ == "__main__": diff --git a/tvln/tests/test_extract_calls.py b/tvln/tests/test_extract_calls.py deleted file mode 100644 index 142a209..0000000 --- a/tvln/tests/test_extract_calls.py +++ /dev/null @@ -1,53 +0,0 @@ -mport builtins -import pytest -from unittest.mock import patch, MagicMock - -from tefln.main import CLIPFeatures, cleanup, ModelLink, ModelType - -@pytest.fixture -def dummy_image_file(): - return {"text": ["a"], "image": ["b"]} - -@patch("tefln.main.cleanup") -@patch.object(CLIPFeatures, "extract", return_value=MagicMock(name="tensor")) -@patch.object(CLIPFeatures, "set_model_link") -@patch.object(CLIPFeatures, "set_model_type") -@patch.object(CLIPFeatures, "set_precision") -@patch.object(CLIPFeatures, "set_device") -def test_clip_features_flow( - mock_dev, mock_prec, mock_type, mock_link, mock_extract, mock_cleanup, dummy_image_file -): - # run the three blocks (copy‑paste the original snippet here) - # block 1 - f = CLIPFeatures() - f.set_device("cpu") - f.set_precision("fp16") - f.set_model_link(ModelLink.VIT_L_14_LAION2B_S32B_B82K) - t1 = f.extract(dummy_image_file) - cleanup(model=f, device="cpu") - - # block 2 - f = CLIPFeatures() - f.set_device("cpu") - f.set_precision("fp16") - f.set_model_type(ModelType.VIT_L_14_LAION400M_E32) - t2 = f.extract(dummy_image_file, last_layer=True) - cleanup(model=f, device="cpu") - - # block 3 - f = CLIPFeatures() - f.set_device("cpu") - f.set_precision("fp16") - f.set_model_link(ModelLink.VIT_BIGG_14_LAION2B_S39B_B160K) - t3 = f.extract(dummy_image_file) - cleanup(model=f, device="cpu") - - # assertions - assert mock_dev.call_count == 3 - assert mock_prec.call_count == 3 - assert mock_link.call_count == 2 - assert mock_type.call_count == 1 - assert mock_extract.call_count == 3 - assert mock_cleanup.call_count == 3 - for t in (t1, t2, t3): - assert isinstance(t, MagicMock) # extraction was triggered \ No newline at end of file From cbfb98a08ea6cead54214e1a1802187e9b5a0447 Mon Sep 17 00:00:00 2001 From: exdysa <91800957+exdysa@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:47:18 -0500 Subject: [PATCH 2/5] ~Improve test, encapsulate image batching, extraction, parameters --- tests/test_extract_calls.py | 4 +-- tvln/batch.py | 6 +++- tvln/clip_features.py | 60 +++--------------------------------- tvln/extract.py | 61 +++++++++++++++++++++++++++++-------- tvln/main.py | 48 ++++++++++------------------- tvln/options.py | 48 +++++++++++++++++++++++++++++ 6 files changed, 124 insertions(+), 103 deletions(-) create mode 100644 tvln/options.py diff --git a/tests/test_extract_calls.py b/tests/test_extract_calls.py index 5f4a1ce..a3d84fd 100644 --- a/tests/test_extract_calls.py +++ b/tests/test_extract_calls.py @@ -3,8 +3,8 @@ import pytest from unittest.mock import patch, MagicMock - from tvln.clip_features import CLIPFeatures +from tvln.extract import FeatureExtractor @pytest.fixture @@ -12,7 +12,7 @@ def dummy_image_file(): return {"text": ["a"], "image": ["b"]} -@patch("tvln.clip_features.cleanup") +@patch.object(FeatureExtractor, "cleanup") @patch.object(CLIPFeatures, "extract", return_value=MagicMock(name="tensor")) @patch.object(CLIPFeatures, "set_model_link") @patch.object(CLIPFeatures, "set_model_type") diff --git a/tvln/batch.py b/tvln/batch.py index 4d8ca35..8a6c7b2 100644 --- a/tvln/batch.py +++ b/tvln/batch.py @@ -33,7 +33,7 @@ def single_image(self) -> None: if not isinstance(self._image_path, str): raise TypeError(f"Expected a string or list of strings for `image_paths` {self._image_path}, got {type(self._image_path)} ") - def as_tensor(self, dtype: torch.dtype, device: torch.device, normalize: bool = False) -> None: + def as_tensor(self, dtype: torch.dtype, device: str, normalize: bool = False) -> None: """Convert a Pillow `Image` to a batched `torch.Tensor`\n :param image: Pillow image (RGB) to encode. :param device: Target device for the tensor (default: ``gpu.device``). @@ -51,3 +51,7 @@ def as_tensor(self, dtype: torch.dtype, device: torch.device, normalize: bool = if normalize: tensor = tensor * 2.0 - 1.0 self.tensor = tensor + + @property + def image_path(self) -> str: + return self._image_path diff --git a/tvln/clip_features.py b/tvln/clip_features.py index 722220c..9ac478c 100644 --- a/tvln/clip_features.py +++ b/tvln/clip_features.py @@ -3,52 +3,10 @@ from enum import Enum -from open_clip import list_pretrained -from open_clip.pretrained import _PRETRAINED from torch import Tensor, nn from tvln.batch import ImageFile - - -class DeviceName(str, Enum): - """Graphics processors usable by the CLIP pipeline.""" - - CPU = "cpu" - CUDA = "cuda" - MPS = "mps" - - -class PrecisionType(str, Enum): - """Supported numeric float precision.""" - - FP64 = "fp64" - FP32 = "fp32" - BF16 = "bf16" - FP16 = "fp16" - - -ModelType = Enum( - "ModelData", - { - # member name → (model_type, pretrained) value - f"{model.replace('-', '_').upper()}_{pretrained.replace('-', '_').upper()}": ( - model, - pretrained, - ) - for model, pretrained in list_pretrained() - }, -) - - -ModelLink = Enum( - "ModelData", - { - f"{family.replace('-', '_').upper()}_{id.replace('-', '_').upper()}": (data.get("hf_hub", "").strip("/"), data.get("url")) - for family, name in _PRETRAINED.items() - for id, data in name.items() - if data.get("hf_hub") or data.get("url") - }, -) +from tvln.options import PrecisionType, ModelType, DeviceName def get_model_and_pretrained(member: Enum) -> tuple[str, str]: @@ -59,7 +17,9 @@ def get_model_and_pretrained(member: Enum) -> tuple[str, str]: class CLIPEncoder(nn.Module): - """CLIP wrapper courtesy ncclab-sustech/BrainFLORA""" + """CLIP wrapper\n\n + MIT licensed by ncclab-sustech/BrainFLORA + """ def __init__(self, device: str = "cpu", model: str = "openai/clip-vit-large-patch14") -> None: """Instantiate the encoder with a specific device and model\n @@ -194,15 +154,3 @@ def extract(self, image: ImageFile, last_layer=False) -> Tensor: else: self._images = [image._image_path] return self.ImageEncoder() - - -def cleanup(model: CLIPFeatures, device: str) -> None: # type:ignore - import torch - import gc - - if device != "cpu": - gpu = getattr(torch, device) - gpu.empty_cache() - model: None = None - del model - gc.collect() diff --git a/tvln/extract.py b/tvln/extract.py index 9a950a0..e360faf 100644 --- a/tvln/extract.py +++ b/tvln/extract.py @@ -2,24 +2,61 @@ # from enum import Enum + from tvln.batch import ImageFile -from tvln.clip_features import CLIPFeatures, ModelType, ModelLink, PrecisionType, DeviceName +from tvln.clip_features import CLIPFeatures +from tvln.options import DeviceName, ModelLink, ModelType, PrecisionType class FeatureExtractor: - def __init__(self, text_device: DeviceName, precision: PrecisionType): - self.text_device = text_device + def __init__(self, device: DeviceName, precision: PrecisionType, image_file: ImageFile): + self.device = device self.precision = precision + self.image_file = image_file + + def extract_features(self, model_info: Enum | str, last_layer: bool = False): + """Extract features from the image using the specified model. + :param model_info: The kind of model to use + :param last_layer: Whether the features are extracted from the last layer of the model or from an intermediate layer.""" + + if isinstance(model_info, str): + import os - def extract_features(self, model_info: Enum, image_file: ImageFile, last_layer: bool = False): - feature_extractor = CLIPFeatures() - feature_extractor.set_device(self.text_device) - feature_extractor.set_precision(self.precision) + from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL + from huggingface_hub import snapshot_download + + vae_path = snapshot_download(model_info, allow_patterns=["vae/*"]) + vae_path = os.path.join(vae_path, "vae") + vae_model = AutoencoderKL.from_pretrained(vae_path, torch_dtype=self.precision.value).to(self.device.value) + vae_tensor = vae_model.tiled_encode(self.image_file.tensor, return_dict=False) + return vae_tensor, model_info + clip_extractor = CLIPFeatures() + clip_extractor.set_device(self.device) + clip_extractor.set_precision(self.precision) if isinstance(model_info, ModelLink): - feature_extractor.set_model_link(model_info) + clip_extractor.set_model_link(model_info) elif isinstance(model_info, ModelType): - feature_extractor.set_model_type(model_info) - tensor = feature_extractor.extract(image_file, last_layer) - data = vars(feature_extractor) - self.cleanup(model=feature_extractor, device=self.text_device) + clip_extractor.set_model_type(model_info) + print(clip_extractor._precision) + tensor = clip_extractor.extract(self.image_file, last_layer) + data = vars(clip_extractor) + self.cleanup(model=clip_extractor) return tensor, data + + def cleanup(self, model: CLIPFeatures) -> None: # type:ignore + """Cleans up the model and frees GPU memory + :param model: The model instance used for feature extraction""" + + import gc + + import torch + + if self.device != "cpu": + gpu = getattr(torch, self.device) + gpu.empty_cache() + model: None = None + del model + gc.collect() + + def set_device(self, device: DeviceName): + self.device = device diff --git a/tvln/main.py b/tvln/main.py index 67010df..011419b 100644 --- a/tvln/main.py +++ b/tvln/main.py @@ -7,12 +7,15 @@ @torch.no_grad def main(): import os - from tvln.clip_features import CLIPFeatures, DeviceName, PrecisionType, ModelLink, ModelType, cleanup - from tvln.batch import ImageFile + from huggingface_hub import snapshot_download - from diffusers import AutoencoderKL + + from tvln.batch import ImageFile + from tvln.options import DeviceName, ModelLink, ModelType, PrecisionType + from tvln.extract import FeatureExtractor text_device: str = DeviceName.CPU + device = text_device precision = PrecisionType.FP32 dtype = torch.float32 if torch.cuda.is_available(): @@ -24,38 +27,19 @@ def main(): image_file = ImageFile() image_file.single_image() - image_file.as_tensor(dtype=torch.float32, device=text_device) - - feature_extractor = CLIPFeatures() - feature_extractor.set_device(text_device) - feature_extractor.set_precision(precision) - feature_extractor.set_model_link(ModelLink.VIT_L_14_LAION2B_S32B_B82K) # type: ignore - clip_l_tensor = feature_extractor.extract(image_file) - clip_l_data = vars(feature_extractor) - cleanup(model=feature_extractor, device=text_device) - - feature_extractor = CLIPFeatures() - feature_extractor.set_device(text_device) - feature_extractor.set_precision(precision) - feature_extractor.set_model_type(ModelType.VIT_L_14_LAION400M_E32) - clip_l_e32_tensor = feature_extractor.extract(image_file, last_layer=True) - clip_l_e32_data = vars(feature_extractor) - cleanup(model=feature_extractor, device=text_device) - - feature_extractor = CLIPFeatures() - feature_extractor.set_device(text_device) - feature_extractor.set_precision(precision) - feature_extractor.set_model_link(ModelLink.VIT_BIGG_14_LAION2B_S39B_B160K) - clip_g_tensor = feature_extractor.extract(image_file) - clip_g_data = vars(feature_extractor) - cleanup(model=feature_extractor, device=text_device) - vae_path = snapshot_download("black-forest-labs/FLUX.1-dev", allow_patterns=["vae/*"]) - vae_path = os.path.join(vae_path, "vae") + feature_extractor = FeatureExtractor(text_device, precision=precision, image_file=image_file) + print(feature_extractor.precision) + clip_l_tensor, clip_l_data = feature_extractor.extract_features(model_info=ModelLink.VIT_L_14_LAION2B_S32B_B82K) + print(feature_extractor.precision) + clip_l_e32_tensor, clip_l_e32_data = feature_extractor.extract_features(ModelType.VIT_L_14_LAION400M_E32, last_layer=True) + print(feature_extractor.precision) + clip_g_tensor, clip_g_data = feature_extractor.extract_features(ModelLink.VIT_BIGG_14_LAION2B_S39B_B160K) image_file.as_tensor(dtype=dtype, device=device) - vae_model = AutoencoderKL.from_pretrained(vae_path, torch_dtype=dtype).to(device) - vae_tensor = vae_model.tiled_encode(image_file.tensor, return_dict=False) + print(dtype) + feature_extractor.set_device(device) + vae_tensor, vae_data = feature_extractor.extract_features("black-forest-labs/FLUX.1-dev") aggregate_data = { "CLIP_L_S32B82K": [clip_l_data, clip_l_tensor], diff --git a/tvln/options.py b/tvln/options.py new file mode 100644 index 0000000..a0ed9a5 --- /dev/null +++ b/tvln/options.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0 +# + + +from enum import Enum +from open_clip import list_pretrained +from open_clip.pretrained import _PRETRAINED + + +class DeviceName(str, Enum): + """Graphics processors usable by the CLIP pipeline.""" + + CPU = "cpu" + CUDA = "cuda" + MPS = "mps" + + +class PrecisionType(str, Enum): + """Supported numeric float precision.""" + + FP64 = "fp64" + FP32 = "fp32" + BF16 = "bf16" + FP16 = "fp16" + + +ModelType = Enum( + "ModelData", + { + # member name → (model_type, pretrained) value + f"{model.replace('-', '_').upper()}_{pretrained.replace('-', '_').upper()}": ( + model, + pretrained, + ) + for model, pretrained in list_pretrained() + }, +) + + +ModelLink = Enum( + "ModelData", + { + f"{family.replace('-', '_').upper()}_{id.replace('-', '_').upper()}": (data.get("hf_hub", "").strip("/"), data.get("url")) + for family, name in _PRETRAINED.items() + for id, data in name.items() + if data.get("hf_hub") or data.get("url") + }, +) From b0598c2d32bbaab2a6f0440ae53a244f1504c9f0 Mon Sep 17 00:00:00 2001 From: exdysa <91800957+exdysa@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:49:21 -0500 Subject: [PATCH 3/5] ~add workflow --- .github/workflows/tvln.yml | 39 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .github/workflows/tvln.yml diff --git a/.github/workflows/tvln.yml b/.github/workflows/tvln.yml new file mode 100644 index 0000000..96d051b --- /dev/null +++ b/.github/workflows/tvln.yml @@ -0,0 +1,39 @@ + +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: tvln pytest + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install system deps + run: | + sudo apt-get update + + + - name: Run Python Tests + uses: actions/setup-python@v5 + with: + python-version: 3.13 + - name: Install dependencies and run tests + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + uv python install 3.13 + uv python pin 3.13 + uv sync --group dev + source .venv/bin/activate + pytest -rPvv + From 73c89274b230a684a9e5555b076706d7b88d5b99 Mon Sep 17 00:00:00 2001 From: exdysa <91800957+exdysa@users.noreply.github.com> Date: Fri, 2 Jan 2026 11:12:36 -0500 Subject: [PATCH 4/5] ~patch patches --- tests/test_extract_calls.py | 52 ++++++++++++++++++++++++++++++++----- tvln/extract.py | 14 +++++++--- tvln/main.py | 6 +---- tvln/options.py | 4 +++ 4 files changed, 60 insertions(+), 16 deletions(-) diff --git a/tests/test_extract_calls.py b/tests/test_extract_calls.py index a3d84fd..3b2d677 100644 --- a/tests/test_extract_calls.py +++ b/tests/test_extract_calls.py @@ -4,12 +4,37 @@ import pytest from unittest.mock import patch, MagicMock from tvln.clip_features import CLIPFeatures -from tvln.extract import FeatureExtractor +from tvln.extract import FeatureExtractor, AutoencoderKL, snapshot_download @pytest.fixture def dummy_image_file(): - return {"text": ["a"], "image": ["b"]} + """Return a minimal ImageFile‑like object.""" + return MagicMock() + + +@pytest.fixture +def dummy_download(): + """Patch `snapshot_download` so that no network traffic occurs.""" + with patch("tvln.extract.snapshot_download", autospec=True) as mock_dl: + mock_dl.return_value = "/tmp/vae" # dummy path + yield mock_dl + + +@pytest.fixture +def dummy_encoder(): + """ + Patch `AutoencoderKL` so that the model is never actually loaded. + The patched class returns a mock instance whose `tiled_encode` method + yields a dummy tensor. + """ + with patch("tvln.extract.AutoencoderKL", autospec=True) as MockKL: + # The classmethod `from_pretrained` should return a mock model + mock_model = MagicMock(name="vae_model") + MockKL.from_pretrained.return_value = mock_model + # The instance method `tiled_encode` returns a dummy tensor + mock_model.tiled_encode.return_value = MagicMock(name="vae_tensor") + yield MockKL @patch.object(FeatureExtractor, "cleanup") @@ -18,17 +43,30 @@ def dummy_image_file(): @patch.object(CLIPFeatures, "set_model_type") @patch.object(CLIPFeatures, "set_precision") @patch.object(CLIPFeatures, "set_device") -def test_clip_features_flow(mock_dev, mock_prec, mock_type, mock_link, mock_extract, mock_cleanup, dummy_image_file): +def test_clip_features_flow( + mock_set_device, + mock_set_precision, + mock_set_model_type, + mock_set_model_link, + mock_extract, + mock_cleanup, + dummy_encoder, + dummy_download, + dummy_image_file, +): + """ + Verify that the feature‑extraction pipeline calls the expected + """ # run the three blocks (copy‑paste the original snippet here) # block 1 from tvln.main import main tensor_stack = main() # assertions - assert mock_dev.call_count == 3 - assert mock_prec.call_count == 3 - assert mock_link.call_count == 2 - assert mock_type.call_count == 1 + assert mock_set_device.call_count == 3 + assert mock_set_precision.call_count == 3 + assert mock_set_model_link.call_count == 2 + assert mock_set_model_type.call_count == 1 assert mock_extract.call_count == 3 assert mock_cleanup.call_count == 3 for name, tensors in tensor_stack.items(): diff --git a/tvln/extract.py b/tvln/extract.py index e360faf..f2b5a97 100644 --- a/tvln/extract.py +++ b/tvln/extract.py @@ -3,9 +3,12 @@ from enum import Enum +import torch from tvln.batch import ImageFile from tvln.clip_features import CLIPFeatures from tvln.options import DeviceName, ModelLink, ModelType, PrecisionType +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from huggingface_hub import snapshot_download class FeatureExtractor: @@ -22,12 +25,9 @@ def extract_features(self, model_info: Enum | str, last_layer: bool = False): if isinstance(model_info, str): import os - from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL - from huggingface_hub import snapshot_download - vae_path = snapshot_download(model_info, allow_patterns=["vae/*"]) vae_path = os.path.join(vae_path, "vae") - vae_model = AutoencoderKL.from_pretrained(vae_path, torch_dtype=self.precision.value).to(self.device.value) + vae_model = AutoencoderKL.from_pretrained(vae_path, torch_dtype=self.precision).to(self.device.value) vae_tensor = vae_model.tiled_encode(self.image_file.tensor, return_dict=False) return vae_tensor, model_info clip_extractor = CLIPFeatures() @@ -60,3 +60,9 @@ def cleanup(self, model: CLIPFeatures) -> None: # type:ignore def set_device(self, device: DeviceName): self.device = device + + def set_precision(self, precision: PrecisionType | torch.dtype): + if isinstance(precision, PrecisionType): + self.precision = precision.value + else: + self.precision = precision diff --git a/tvln/main.py b/tvln/main.py index 011419b..f1d6bbf 100644 --- a/tvln/main.py +++ b/tvln/main.py @@ -6,10 +6,6 @@ @torch.no_grad def main(): - import os - - from huggingface_hub import snapshot_download - from tvln.batch import ImageFile from tvln.options import DeviceName, ModelLink, ModelType, PrecisionType from tvln.extract import FeatureExtractor @@ -37,8 +33,8 @@ def main(): clip_g_tensor, clip_g_data = feature_extractor.extract_features(ModelLink.VIT_BIGG_14_LAION2B_S39B_B160K) image_file.as_tensor(dtype=dtype, device=device) - print(dtype) feature_extractor.set_device(device) + feature_extractor.set_precision(torch.float32) vae_tensor, vae_data = feature_extractor.extract_features("black-forest-labs/FLUX.1-dev") aggregate_data = { diff --git a/tvln/options.py b/tvln/options.py index a0ed9a5..14b4578 100644 --- a/tvln/options.py +++ b/tvln/options.py @@ -22,6 +22,10 @@ class PrecisionType(str, Enum): FP32 = "fp32" BF16 = "bf16" FP16 = "fp16" + FLOAT16 = "torch.float16" + BFLOAT16 = "torch.bfloat16" + FLOAT32 = "torch.float32" + FLOAT64 = "torch.float64" ModelType = Enum( From 23f370d6d604c39c55859ca1155883623d2a8e36 Mon Sep 17 00:00:00 2001 From: exdysa <91800957+exdysa@users.noreply.github.com> Date: Fri, 2 Jan 2026 11:13:57 -0500 Subject: [PATCH 5/5] +add test status badge --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 4e912f1..2da9f44 100644 --- a/README.md +++ b/README.md @@ -17,3 +17,5 @@ uv sync --group dev ``` tvln ``` + +[![tvln pytest](https://github.com/darkshapes/tvln/actions/workflows/tvln.yml/badge.svg)](https://github.com/darkshapes/tvln/actions/workflows/tvln.yml)