Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/mr_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:

- name: Install dependencies
run: |
python -m pip install ".[dev]"
python -m pip install ".[dev, lightning]"
python -m pip install pytest-cov

- name: Quality Assurance
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/mr_ci_text_spotting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
run: |
python -m pip install --no-cache-dir wheel
python -m pip install --no-cache-dir numpy==1.26.4 torch==2.2.2 torchvision==0.17.2 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install --no-cache-dir ".[dev]"
python -m pip install --no-cache-dir ".[dev, lightning]"
python -m pip install --no-cache-dir pytest-cov
python -m pip install --no-cache-dir --no-build-isolation 'git+https://github.com/facebookresearch/detectron2.git'
python -m pip install --no-cache-dir --no-build-isolation 'git+https://github.com/maps-as-data/DeepSolo.git'
Expand Down
4 changes: 4 additions & 0 deletions mapreader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from mapreader.classify.datasets import PatchDataset
from mapreader.classify.datasets import PatchContextDataset
from mapreader.classify.classifier import ClassifierContainer
try:
from mapreader.classify.lightning_classifier import LightningClassifierContainer
except ImportError:
pass
from mapreader.classify import custom_models

# spot_text
Expand Down
30 changes: 17 additions & 13 deletions mapreader/classify/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ class ClassifierContainer:
A dictionary to store dataloaders for the model.
labels_map : dict
A dictionary mapping label indices to their labels.
dataset_sizes : dict
A dictionary to store sizes of datasets for the model.
model : torch.nn.Module
The model.
input_size : Tuple of int
Expand Down Expand Up @@ -152,17 +150,18 @@ def __init__(
elif isinstance(model, str):
if huggingface:
try:
from transformers import AutoModelForImageClassification, AutoImageProcessor
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
)
except ImportError:
raise ImportError(
"Hugging Face models require the 'transformers' library: 'pip install transformers'."
)
print(f"[INFO] Initializing Hugging Face model: {model}")
num_labels = len(self.labels_map)
self.model = AutoModelForImageClassification.from_pretrained(
model,
num_labels=num_labels,
ignore_mismatched_sizes=True
model, num_labels=num_labels, ignore_mismatched_sizes=True
).to(self.device)
hf_processor = AutoImageProcessor.from_pretrained(model)
size = getattr(hf_processor, "size", {})
Expand Down Expand Up @@ -289,7 +288,7 @@ def _initialize_model(
num_ftrs = model_dw.fc.in_features
model_dw.fc = nn.Linear(num_ftrs, last_layer_num_classes)
is_inception = True
input_size = 299
input_size = (299, 299)

else:
raise NotImplementedError(
Expand Down Expand Up @@ -340,7 +339,7 @@ def generate_layerwise_lrs(
elif spacing.lower() in ["log", "geomspace"]:
lrs = np.geomspace(min_lr, max_lr, len(list(self.model.named_parameters())))
params2optimize = [
{"params": params, "learning rate": lr}
{"params": params, "lr": lr}
for (_, params), lr in zip(self.model.named_parameters(), lrs)
]

Expand Down Expand Up @@ -587,10 +586,14 @@ def model_summary(
if trainable_col:
col_names = ["num_params", "output_size", "trainable"]
else:
col_names = ["output_size", "output_size", "num_params"]
col_names = ["input_size", "output_size", "num_params"]

model_summary = summary(
self.model, input_size=input_size, col_names=col_names, device=self.device, **kwargs
self.model,
input_size=input_size,
col_names=col_names,
device=self.device,
**kwargs,
)
print(model_summary)

Expand Down Expand Up @@ -1109,7 +1112,7 @@ def train_core(
best_model_wts = copy.deepcopy(self.model.state_dict())

if phase.lower() in valid_phase_names:
if epoch % tmp_file_save_freq == 0:
if tmp_file_save_freq and epoch % tmp_file_save_freq == 0:
tmp_str = f'[INFO] Checkpoint file saved to "{self.tmp_save_filename}".' # noqa
print(
self._print_colors["lgrey"]
Expand Down Expand Up @@ -1149,7 +1152,7 @@ def _get_logits(out):
try:
out = out.logits
except AttributeError as err:
raise AttributeError(err.message)
raise AttributeError(str(err))
return out

def _gen_epoch_msg(self, phase: str, epoch_msg: str) -> str:
Expand Down Expand Up @@ -1236,8 +1239,9 @@ def calculate_add_metrics(
y_score = np.array(y_score)

for average in [None, "micro", "macro", "weighted"]:
labels = list(range(y_score.shape[1])) if average is None else None
precision, recall, fscore, support = precision_recall_fscore_support(
y_true, y_pred, average=average
y_true, y_pred, average=average, labels=labels
)

if average is None:
Expand Down
Loading
Loading