Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
anomaly_match/pretrained_cache/models--timm--*/blobs/* filter=lfs diff=lfs merge=lfs -text
2 changes: 2 additions & 0 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ jobs:
id-token: write
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up Python 3.11
uses: actions/setup-python@v5
with:
Expand Down
9 changes: 8 additions & 1 deletion anomaly_match/models/FixMatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,15 @@ def __init__(
self.ema_m = ema_m

# Create two versions of the model: one for training and one for evaluation with EMA
# eval_model skips pretrained weight download because its weights are immediately
# overwritten by copying from train_model below.
self.train_model = net_builder(num_classes=num_classes, in_channels=in_channels)
self.eval_model = net_builder(num_classes=num_classes, in_channels=in_channels)
try:
self.eval_model = net_builder(
num_classes=num_classes, in_channels=in_channels, pretrained=False
)
except TypeError:
self.eval_model = net_builder(num_classes=num_classes, in_channels=in_channels)
self.T = T
self.p_cutoff = p_cutoff
self.lambda_u = lambda_u
Expand Down
Git LFS file not shown
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
e074755b39c3cd8ca77d917a7cfb62e70da6bf88
35 changes: 32 additions & 3 deletions anomaly_match/utils/get_net_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
from pathlib import Path

import timm
import torch.nn as nn
from loguru import logger

# Local cache directory for pretrained weights, avoids flock issues on NFS filesystems.
_PACKAGE_DIR = Path(__file__).resolve().parent.parent
_PRETRAINED_CACHE_DIR = _PACKAGE_DIR / "pretrained_cache"

# Mapping from AnomalyMatch net names to timm model identifiers.
# efficientnet-lite variants use the tf_ prefix for TF-style same-padding,
# which is backward compatible with the previous efficientnet_lite_pytorch package.
Expand Down Expand Up @@ -113,7 +119,7 @@ def get_net_builder(net_name, pretrained=False, in_channels=3):
if net_name == "test-cnn":
logger.debug("Using test-cnn model (for testing only)")

def build_test_cnn(num_classes, in_channels, pretrained=None):
    """Builder for the lightweight test-only CNN.

    The `pretrained` argument is accepted purely for signature compatibility
    with the timm-backed builders and is ignored — TestCNN has no pretrained
    weights.
    """
    model = TestCNN(num_classes=num_classes, in_channels=in_channels)
    return model

return build_test_cnn
Expand All @@ -124,10 +130,33 @@ def build_test_cnn(num_classes, in_channels):
f"(timm: {timm_name})"
)

def build_model(
    num_classes, in_channels, _timm_name=timm_name, _pretrained=use_pretrained, pretrained=None
):
    """Build the configured timm model.

    Args:
        num_classes: Number of output classes for the classifier head.
        in_channels: Number of input image channels (timm's `in_chans`).
        _timm_name: timm model identifier, bound at builder-creation time
            via the default argument (do not pass explicitly).
        _pretrained: Default pretrained-weights flag, bound at
            builder-creation time (do not pass explicitly).
        pretrained: Optional per-call override; if None, falls back to
            `_pretrained`.

    Returns:
        An instantiated timm model.
    """
    # Per-call override takes precedence over the builder-level default.
    effective_pretrained = pretrained if pretrained is not None else _pretrained
    if effective_pretrained:
        try:
            # Prefer the bundled local cache so weights are not re-downloaded
            # and to avoid flock issues on NFS filesystems.
            return timm.create_model(
                _timm_name,
                pretrained=True,
                num_classes=num_classes,
                in_chans=in_channels,
                cache_dir=str(_PRETRAINED_CACHE_DIR),
            )
        except Exception:
            # Broad catch is deliberate: any cache failure (e.g. git-lfs
            # pointer files not fetched, corrupt blobs) falls back to a
            # normal HuggingFace download.
            logger.warning(
                f"Bundled pretrained weights not available (clone with git-lfs to avoid "
                f"re-downloading). Downloading {_timm_name} from HuggingFace."
            )
            return timm.create_model(
                _timm_name,
                pretrained=True,
                num_classes=num_classes,
                in_chans=in_channels,
            )
    # No pretrained weights requested: random initialization.
    return timm.create_model(
        _timm_name,
        pretrained=False,
        num_classes=num_classes,
        in_chans=in_channels,
    )
Expand Down