Skip to content

Commit 3b0d715

Browse files
authored
feat(medcat-den): CU-869an5f00 Add remote api (#163)
* CU-869an5f00: Update backend registration tests - use mocking * CU-869an5f00: Add new optional API for den to fine tune models * CU-869an5f00: Add eval API * CU-869an5f00: Add option to disallow local fine-tune and/or push of fine-tuned models * CU-869an5f00: Allow a DenWrapper to have a den config alongside it * CU-869an5f00: Disallow local fine tune with remote dens if/when applicable * CU-869an5f00: Separate local and remote exceptions * CU-869an5f00: Add a few tests for disallowing push * CU-869an5f00: Add another simple test * CU-869an5f00: Add wrapper for training to control disallowing supervised training * CU-869an5f00: Add a few simple tests for disallowing local training with a remote den * CU-869an5f00: Add new environmental variables to README * CU-869an5f00: Add a small note in code regarding defaults * CU-869an5f00: Propagate new options to get_default_den * CU-869an5f00: Add small comment regarding pushing base models * CU-869an5f00: Add more info to failure to load a registered den type * CU-869an5f00: Fix mocked den for python 3.12. That is, needed to make sure it is explicitly specced for the protocol for the instanceof check to pass
1 parent bd9fe38 commit 3b0d715

File tree

8 files changed

+343
-52
lines changed

8 files changed

+343
-52
lines changed

medcat-den/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,7 @@ However, there's a set of environmental variables that can be set in order to cu
133133
| MEDCAT_DEN_LOCAL_CACHE_EXPIRATION_TIME | int | The expiration time for local cache (in seconds) | The default is 10 days |
134134
| MEDCAT_DEN_LOCAL_CACHE_MAX_SIZE | int | The maximum size of the cache in bytes | The default is 100 GB |
135135
| MEDCAT_DEN_LOCAL_CACHE_EVICTION_POLICY | str | The eviction policy for the local cache | The default is LRU |
136+
| MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED | bool | Whether to allow locally fine-tuned models to be pushed to remote dens | Defaults to False |
137+
| MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE | bool | Whether to allow local fine tuning for remote dens | Defaults to False |
136138

137139
When creating a den, the resolver will use the explicitly passed values first, and if none are provided, it will default to the ones defined in the environmental variables.

medcat-den/src/medcat_den/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class LocalDenConfig(DenConfig):
1717
class RemoteDenConfig(DenConfig):
    """Configuration for a den that lives on a remote host.

    Besides the connection details, this carries two policy flags that
    restrict what may be done locally with models that came from the
    remote den. Both flags default to ``False`` (the documented default
    of the corresponding environmental variables), so configs that were
    created before these options existed keep working and get the
    restrictive behaviour.
    """
    # Host of the remote den (exact format depends on the den backend;
    # not visible from this module).
    host: str
    # Credentials handed to the remote den backend; the expected keys are
    # backend specific and not defined here.
    credentials: dict
    # Whether models fetched from this den may be fine-tuned locally.
    allow_local_fine_tune: bool = False
    # Whether locally fine-tuned models may be pushed back to this den.
    allow_push_fine_tuned: bool = False
2022

2123

2224
class LocalCacheConfig(BaseModel):

medcat-den/src/medcat_den/den.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Protocol, Optional, runtime_checkable
1+
from typing import Protocol, Optional, runtime_checkable, Union
22

33
from medcat.cat import CAT
4+
from medcat.data.mctexport import MedCATTrainerExport
45

56
from medcat_den.base import ModelInfo
67
from medcat_den.wrappers import CATWrapper
@@ -102,6 +103,55 @@ def delete_model(self, model_info: ModelInfo,
102103
"""
103104
pass
104105

106+
def finetune_model(self, model_info: ModelInfo,
                   data: Union[list[str], MedCATTrainerExport]
                   ) -> ModelInfo:
    """Fine-tune a model on the den itself instead of locally.

    This API is optional and is (generally) only offered by remote
    dens: the training data is handed to the den and the fine-tuning
    is carried out there.

    When raw data (i.e. a trainer export) is passed, it is uploaded
    to the remote den unless it is already present there.

    Args:
        model_info (ModelInfo): The model to fine-tune.
        data (Union[list[str], MedCATTrainerExport]): Either the list
            of project ids (already on the remote) or the trainer
            export to train on.

    Returns:
        ModelInfo: The resulting (fine-tuned) model.

    Raises:
        UnsupportedAPIException: If the den does not support this API.
    """
    pass
129+
130+
def evaluate_model(self, model_info: ModelInfo,
                   data: Union[list[str], MedCATTrainerExport]) -> dict:
    """Evaluate a model on the remote den.

    This is an optional API that is (generally) only available
    for remote dens. The idea is that the data is sent to the remote
    den and the metrics are gathered on the remote.

    If raw data is given, unless already present remotely, it will be
    uploaded to the remote den.

    Args:
        model_info (ModelInfo): The model info.
        data (Union[list[str], MedCATTrainerExport]): The list of project
            ids (already on remote) or the trainer export to evaluate on.

    Returns:
        dict: The resulting metrics.

    Raises:
        UnsupportedAPIException: If the den does not support this API.
    """
    pass
150+
151+
152+
class UnsupportedAPIException(ValueError):
    """Raised when a den does not support an optional API.

    Used for the optional den APIs such as ``finetune_model`` and
    ``evaluate_model`` when the den in use does not implement them.
    """
154+
105155

106156
def get_default_den(
107157
type_: Optional[DenType] = None,
@@ -112,6 +162,8 @@ def get_default_den(
112162
expiration_time: Optional[int] = None,
113163
max_size: Optional[int] = None,
114164
eviction_policy: Optional[str] = None,
165+
remote_allow_local_fine_tune: Optional[str] = None,
166+
remote_allow_push_fine_tuned: Optional[str] = None,
115167
) -> Den:
116168
"""Get the default den.
117169
@@ -137,14 +189,19 @@ def get_default_den(
137189
Policies available: LRU (`least-recently-used`),
138190
LRS (`least-recently-stored`), LFU (`least-frequently-used`),
139191
and `none` (disables evictions).
192+
remote_allow_local_fine_tune (Optional[str]): Whether to allow local
193+
fine tuning of remote models.
194+
remote_allow_push_fine_tuned (Optional[str]): Whether to allow pushing
195+
of locally fine-tuned models to the remote den.
140196
141197
Returns:
142198
Den: The resolved den.
143199
"""
144200
# NOTE: doing dynamic import to avoid circular imports
145201
from medcat_den.resolver import resolve
146202
return resolve(type_, location, host, credentials, local_cache_path,
147-
expiration_time, max_size, eviction_policy)
203+
expiration_time, max_size, eviction_policy,
204+
remote_allow_local_fine_tune, remote_allow_push_fine_tuned)
148205

149206

150207
def get_default_user_local_den(

medcat-den/src/medcat_den/den_impl/file_den.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, cast, Any
1+
from typing import Optional, cast, Any, Union
22

33
import json
44
from datetime import datetime
@@ -7,8 +7,10 @@
77
import shutil
88

99
from medcat.cat import CAT
10+
from medcat.data.mctexport import MedCATTrainerExport
1011

11-
from medcat_den.den import Den, DuplicateModelException
12+
from medcat_den.den import (
13+
Den, DuplicateModelException, UnsupportedAPIException)
1214
from medcat_den.backend import DenType
1315
from medcat_den.base import ModelInfo
1416
from medcat_den.wrappers import CATWrapper
@@ -162,7 +164,8 @@ def fetch_model(self, model_info: ModelInfo) -> CATWrapper:
162164
model_path = self._get_model_zip_path(model_info)
163165
return cast(
164166
CATWrapper,
165-
CATWrapper.load_model_pack(model_path, model_info=model_info))
167+
CATWrapper.load_model_pack(model_path, model_info=model_info,
168+
den_cnf=self._cnf))
166169

167170
def push_model(self, cat: CAT, description: str) -> None:
168171
if isinstance(cat, CATWrapper):
@@ -220,3 +223,17 @@ def delete_model(self, model_info: ModelInfo,
220223
folder_path = zip_path.removesuffix(".zip")
221224
if os.path.exists(folder_path):
222225
shutil.rmtree(folder_path)
226+
227+
def finetune_model(self, model_info: ModelInfo,
                   data: Union[list[str], MedCATTrainerExport]
                   ) -> ModelInfo:
    """Reject fine-tuning on the den; the local den does not support it.

    The method exists so that the local den offers the same API as
    other dens, but it never returns normally.

    Args:
        model_info (ModelInfo): The model info (unused).
        data (Union[list[str], MedCATTrainerExport]): The training data
            (unused).

    Raises:
        UnsupportedAPIException: Always.
    """
    raise UnsupportedAPIException(
        "Local den does not support finetuning on the den. "
        "Use a remote den instead or perform training locally."
    )
233+
234+
def evaluate_model(self, model_info: ModelInfo,
                   data: Union[list[str], MedCATTrainerExport]) -> dict:
    """Reject evaluation on the den; the local den does not support it.

    Args:
        model_info (ModelInfo): The model info (unused).
        data (Union[list[str], MedCATTrainerExport]): The evaluation data
            (unused).

    Raises:
        UnsupportedAPIException: Always.
    """
    reason = (
        "Local den does not support evaluation on the den. "
        "Use a remote den instead or perform evaluation locally."
    )
    raise UnsupportedAPIException(reason)

medcat-den/src/medcat_den/resolver/resolver.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@
3838
MEDCAT_DEN_LOCAL_CACHE_MAX_SIZE = "MEDCAT_DEN_LOCAL_CACHE_MAX_SIZE"
3939
MEDCAT_DEN_LOCAL_CACHE_EVICTION_POLICY = (
4040
"MEDCAT_DEN_LOCAL_CACHE_EVICTION_POLICY")
41+
# Name of the environmental variable that controls whether local fine
# tuning is allowed for remote dens.
MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE = (
    "MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE"
)
# Name of the environmental variable that controls whether locally
# fine-tuned models may be pushed to remote dens.
MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED = (
    "MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED"
)

# Lower-cased option values that are interpreted as "allowed"; any other
# value (including an unset option) is treated as "not allowed".
ALLOW_OPTION_LOWERCASE = (
    "true",
    "yes",
    "1",
    "y",
)
4147

4248

4349
def is_writable(path: str, propgate: bool = True) -> bool:
@@ -52,7 +58,10 @@ def _init_den_cnf(
5258
type_: Optional[DenType] = None,
5359
location: Optional[str] = None,
5460
host: Optional[str] = None,
55-
credentials: Optional[dict] = None,) -> DenConfig:
61+
credentials: Optional[dict] = None,
62+
remote_allow_local_fine_tune: Optional[str] = None,
63+
remote_allow_push_fine_tuned: Optional[str] = None,
64+
) -> DenConfig:
5665
# Priority: args > env > defaults
5766
type_in = (
5867
type_
@@ -82,13 +91,27 @@ def _init_den_cnf(
8291
den_cnf = LocalDenConfig(type=type_final,
8392
location=location_final)
8493
else:
94+
host = host or os.getenv(MEDCAT_DEN_REMOTE_HOST)
8595
if not host:
8696
raise ValueError("Need to specify a host for remote den")
8797
if not credentials:
8898
raise ValueError("Need to specify credentials for remote den")
89-
den_cnf = RemoteDenConfig(type=type_final,
90-
host=host,
91-
credentials=credentials)
99+
# NOTE: these will default to False when nothing is specified
100+
# because "None" is not in ALLOW_OPTION_LOWERCASE
101+
allow_local_fine_tune = str(
102+
remote_allow_local_fine_tune or
103+
os.getenv(MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE)
104+
).lower() in ALLOW_OPTION_LOWERCASE
105+
allow_push_fine_tuned = str(
106+
remote_allow_push_fine_tuned or
107+
os.getenv(MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED)
108+
).lower() in ALLOW_OPTION_LOWERCASE
109+
den_cnf = RemoteDenConfig(
110+
type=type_final,
111+
host=host,
112+
credentials=credentials,
113+
allow_local_fine_tune=allow_local_fine_tune,
114+
allow_push_fine_tuned=allow_push_fine_tuned)
92115
return den_cnf
93116

94117

@@ -101,8 +124,12 @@ def resolve(
101124
expiration_time: Optional[int] = None,
102125
max_size: Optional[int] = None,
103126
eviction_policy: Optional[str] = None,
127+
remote_allow_local_fine_tune: Optional[str] = None,
128+
remote_allow_push_fine_tuned: Optional[str] = None,
104129
) -> Den:
105-
den_cnf = _init_den_cnf(type_, location, host, credentials)
130+
den_cnf = _init_den_cnf(type_, location, host, credentials,
131+
remote_allow_local_fine_tune,
132+
remote_allow_push_fine_tuned)
106133
den = resolve_from_config(den_cnf)
107134
lc_cnf = _init_lc_cnf(
108135
local_cache_path, expiration_time, max_size, eviction_policy)
@@ -126,19 +153,13 @@ def _resolve_local(config: LocalDenConfig) -> LocalFileDen:
126153
def resolve_from_config(config: DenConfig) -> Den:
127154
if isinstance(config, LocalDenConfig):
128155
return _resolve_local(config)
129-
# TODO: support remote (e)
130-
# elif type_final == DenType.MEDCATTERY:
131-
# host = host or os.getenv(MEDCAT_DEN_REMOTE_HOST)
132-
# if host is None:
133-
# raise ValueError("Remote DEN requires a host address")
134-
# # later you’d plug in MedcatteryRemoteDen, MLFlowDen, etc.
135-
# return MedCATteryDen(host=host, credentials=credentials)
136156
elif has_registered_remote_den(config.type):
137157
den_cls = get_registered_remote_den(config.type)
138158
den = den_cls(cnf=config)
139159
if not isinstance(den, Den):
140160
raise ValueError(
141-
f"Registered den class for {config.type} is not a Den")
161+
f"Registered den class for {config.type} is not a Den. "
162+
f"Got {type(den)}: {den}")
142163
return den
143164
else:
144165
raise ValueError(

medcat-den/src/medcat_den/wrappers.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
from medcat.cat import CAT
44
from medcat.utils.defaults import DEFAULT_PACK_NAME
55
from medcat.storage.serialisers import AvailableSerialisers
6+
from medcat.trainer import Trainer
7+
from medcat.data.mctexport import MedCATTrainerExport
68

79
from medcat_den.base import ModelInfo
10+
from medcat_den.config import DenConfig, RemoteDenConfig
811

912

1013
class CATWrapper(CAT):
@@ -20,6 +23,7 @@ class CATWrapper(CAT):
2023
"""
2124

2225
_model_info: ModelInfo
26+
_den_cnf: DenConfig
2327

2428
def save_model_pack(
2529
self, target_folder: str, pack_name: str = DEFAULT_PACK_NAME,
@@ -54,19 +58,36 @@ def save_model_pack(
5458
if not force_save_local and not is_injected_for_save():
5559
raise CannotSaveOnDiskException(
5660
f"Cannot save model on disk: {CATWrapper.__doc__}")
61+
if (is_injected_for_save() and isinstance(
62+
self._den_cnf, RemoteDenConfig) and
63+
not self._den_cnf.allow_push_fine_tuned):
64+
# NOTE: should there be a check whether this is a base model?
65+
raise CannotSendToRemoteException(
66+
"Cannot save fine-tuned model onto a remote den."
67+
"In order to make full use of the remote den capabilities, "
68+
"use the den API to fine tune a model directly on the den. "
69+
"See `Den.finetune_model` for details or set the config "
70+
"option of `allow_push_fine_tuned` to True"
71+
)
5772
return super().save_model_pack(
5873
target_folder, pack_name, serialiser_type, make_archive,
5974
only_archive, add_hash_to_pack_name, change_description)
6075

76+
@property
def trainer(self) -> Trainer:
    """The model's trainer, wrapped to honour the den configuration.

    A new ``WrappedTrainer`` is built around the underlying trainer on
    each access, so that supervised training can be refused when the
    den config does not allow it.

    Returns:
        Trainer: The wrapped trainer.
    """
    return WrappedTrainer(self._den_cnf, super().trainer)
80+
6181
@classmethod
6282
def load_model_pack(cls, model_pack_path: str,
6383
config_dict: Optional[dict] = None,
6484
addon_config_dict: Optional[dict[str, dict]] = None,
6585
model_info: Optional[ModelInfo] = None,
86+
den_cnf: Optional[DenConfig] = None,
6687
) -> 'CAT':
6788
"""Load the model pack from file.
6889
69-
This also
90+
This may also disallow model load from disk in certain scenarios.
7091
7192
Args:
7293
model_pack_path (str): The model pack path.
@@ -80,6 +101,9 @@ def load_model_pack(cls, model_pack_path: str,
80101
model_info (Optional[ModelInfo]): The base model info based on
81102
which the model was originally fetched. Should not be
82103
left None.
104+
den_cnf (Optional[DenConfig]): The config for the den being
105+
used. Should not be left None.
106+
83107
84108
Raises:
85109
ValueError: If the saved data does not represent a model pack.
@@ -95,10 +119,45 @@ def load_model_pack(cls, model_pack_path: str,
95119
cat.__class__ = CATWrapper
96120
if model_info is None:
97121
raise CannotWrapModel("Model info must be provided")
122+
if den_cnf is None:
123+
raise CannotWrapModel("den_cnf must be provided")
98124
cat._model_info = model_info
125+
cat._den_cnf = den_cnf
99126
return cat
100127

101128

129+
class WrappedTrainer(Trainer):
    """Trainer that applies the den's local fine-tuning policy.

    It is built from an existing trainer (re-using its ``cdb``, ``caller``
    and ``_pipeline``) plus the config of the den the model came from.
    Supervised training is refused when that config is a remote den
    config with ``allow_local_fine_tune`` switched off.
    """

    def __init__(self, den_cnf: DenConfig, delegate: Trainer):
        """Set up the wrapper around an existing trainer.

        Args:
            den_cnf (DenConfig): The config of the den being used.
            delegate (Trainer): The trainer whose parts are re-used.
        """
        super().__init__(delegate.cdb, delegate.caller, delegate._pipeline)
        self._den_cnf = den_cnf

    def train_supervised_raw(
            self, data: MedCATTrainerExport, reset_cui_count: bool = False,
            nepochs: int = 1, print_stats: int = 0, use_filters: bool = False,
            terminate_last: bool = False, use_overlaps: bool = False,
            use_cui_doc_limit: bool = False, test_size: float = 0,
            devalue_others: bool = False, use_groups: bool = False,
            never_terminate: bool = False,
            train_from_false_positives: bool = False,
            extra_cui_filter: Optional[set[str]] = None,
            disable_progress: bool = False, train_addons: bool = False):
        """Run supervised training, unless the den config forbids it.

        All arguments are forwarded unchanged (and in the same order) to
        ``Trainer.train_supervised_raw``; see that method for what they
        mean.

        Raises:
            NotAllowedToFineTuneLocallyException: If the den config is a
                remote den config that does not allow local fine-tuning.

        Returns:
            Whatever ``Trainer.train_supervised_raw`` returns.
        """
        cnf = self._den_cnf
        from_remote_den = isinstance(cnf, RemoteDenConfig)
        # only remote den configs carry the flag, hence the type check first
        if from_remote_den and not cnf.allow_local_fine_tune:
            raise NotAllowedToFineTuneLocallyException(
                "You are not allowed to fine-tune remote models locally. "
                "Please use the `Den.finetune_model` method directly to "
                "fine tune on the remote den, or if required, set the "
                "`allow_local_fine_tune` config value to `True`."
            )
        return super().train_supervised_raw(
            data, reset_cui_count, nepochs, print_stats, use_filters,
            terminate_last, use_overlaps, use_cui_doc_limit, test_size,
            devalue_others, use_groups, never_terminate,
            train_from_false_positives, extra_cui_filter, disable_progress,
            train_addons)
159+
160+
102161
class CannotWrapModel(ValueError):
103162

104163
def __init__(self, *args):
@@ -109,3 +168,15 @@ class CannotSaveOnDiskException(ValueError):
109168

110169
def __init__(self, *args):
111170
super().__init__(*args)
171+
172+
173+
class CannotSendToRemoteException(ValueError):
    """Raised when a model is not allowed to be sent to a remote den.

    This is used when saving a model for a push to a remote den whose
    config does not allow pushing fine-tuned models.
    """

    # NOTE: this used to override ``__call__`` and delegate to
    #       ``super().__call__``, which ``ValueError`` does not define;
    #       the constructor is what the sibling exceptions override.
    def __init__(self, *args):
        super().__init__(*args)
177+
178+
179+
class NotAllowedToFineTuneLocallyException(ValueError):
    """Raised when local fine-tuning of a remote model is not allowed.

    This is used when supervised training is attempted locally while the
    remote den config has ``allow_local_fine_tune`` switched off.
    """

    # NOTE: this used to override ``__call__`` and delegate to
    #       ``super().__call__``, which ``ValueError`` does not define;
    #       the constructor is what the sibling exceptions override.
    def __init__(self, *args):
        super().__init__(*args)

0 commit comments

Comments
 (0)