[ENH] check_estimator utility for checking new estimators against unified API contract #1954
Conversation
@phoeenniixx, the templated [...] It would be appreciated if you could review; perhaps you overrode something?
    if obj was not a descendant of BaseObject or BaseEstimator, returns empty list
    """
    if hasattr(obj, "_pkg"):
        obj = obj._pkg
obj._pkg is a method, so it should be obj._pkg()
did you change it to be a classmethod? I originally designed it as a classproperty, so this call would be valid.
ah, this is actually another mistake. It is a classattribute, but MyClass.pkg should be called, not _pkg
No I didn't, I just looked at the DeepAR implementation, there it was already a classmethod?

pytorch-forecasting/pytorch_forecasting/models/deepar/_deepar.py, lines 40 to 45 in 3821c0b:
@classmethod
def _pkg(cls):
    """Package containing the model."""
    from pytorch_forecasting.models.deepar._deepar_pkg import DeepAR_pkg

    return DeepAR_pkg
You made it a classmethod in #1888, I guess.
Yes, I know, I understood what is going on - pkg is a classattribute that is the public interface, and it calls the private extension locus _pkg, which is a classmethod, like fit and _fit.
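Given that resolution of the sub-thread, the package lookup quoted at the top of the conversation would presumably call the classmethod rather than assign the method object itself. A hedged sketch of the implied fix, assuming _pkg is a classmethod as in the DeepAR excerpt above (the exact change made in the PR may differ):

# sketch of the corrected package resolution in the check_estimator utility
if hasattr(obj, "_pkg"):
    obj = obj._pkg()  # call it; `obj._pkg` without parentheses is just the method object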
testclass_dict = get_test_class_registry()

try:
    obj_scitypes = obj.get_tag("object_type")
it's obj.get_class_tag() in ptf
get_tag also collects class tags, but good spot
this is probably it!
I keep saying: never use try/except for condition checking, because it is bad style; instead, check for the condition explicitly. Otherwise it is hard to diagnose when it fails to behave as expected - e.g., whether it fails due to a genuine exception rather than the condition being false.
(try/except is fine for error handling, this is about abusing it for condition checking)
And now I make the same mistake. I should have listened to myself :-)
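To make the point concrete, a minimal sketch assuming the scikit-base-style get_class_tag accessor mentioned above; the exact condition to check depends on the utility, hasattr is used here purely for illustration:

# condition checking via try/except (discouraged): a genuine bug inside the
# tag accessor is swallowed and looks the same as the tag simply being absent
try:
    obj_scitypes = obj.get_tag("object_type")
except Exception:
    obj_scitypes = []

# explicit condition check (preferred): only the intended condition is tested,
# and real exceptions from the accessor still propagate and can be diagnosed
if hasattr(obj, "get_class_tag"):
    obj_scitypes = obj.get_class_tag("object_type")
else:
    obj_scitypes = []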
I actually did the same to find this out :) just removed the try-except block... Thanks to your teachings
well, all I am saying is, it is worth listening to them oneself too.
test_cls_results = test_cls().run_tests(
    estimator=estimator,
    raise_exceptions=raise_exceptions,
    tests_to_run=tests_to_run,
    fixtures_to_run=fixtures_to_run,
    tests_to_exclude=tests_to_exclude,
    fixtures_to_exclude=fixtures_to_exclude,
    verbose=verbose and raise_exceptions,
)
I was also looking at QuickTester.run_tests(), it has the following signature:

obj,
raise_exceptions=False,
tests_to_run=None,
fixtures_to_run=None,
tests_to_exclude=None,
fixtures_to_exclude=None,
are we overriding it somehow?
yes, this is also a bug!

scikit-base renames every "estimator" thing to "object" or "obj". This is clearer for users of scikit-base imo, but causes a discrepancy that is prone to bugs like this. (so I was originally against, even if "estimator" is a less clear term)

We had long discussions with other core devs on whether we should rename this way, in the end it was decided that we do.
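Under that naming convention, the call from the excerpt above would presumably need to pass the keyword obj rather than estimator. A sketch of the implied fix (not necessarily the exact change made in the PR; the verbose handling added later in the linked scikit-base change is omitted since the signature quoted above does not list it):

test_cls_results = test_cls().run_tests(
    obj=estimator,  # scikit-base names the argument "obj", not "estimator"
    raise_exceptions=raise_exceptions,
    tests_to_run=tests_to_run,
    fixtures_to_run=fixtures_to_run,
    tests_to_exclude=tests_to_exclude,
    fixtures_to_exclude=fixtures_to_exclude,
)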
I made the above changes to a branch, after pulling this PR locally, I see the [...]
ok, I am now at the point that you had also arrived at - all tests are collected, instead of only those for the estimator
@phoeenniixx, I think it works now
I just pulled the branch locally. When I try check_estimator(my_model, raise_exceptions=True), it prints output of all the tests, but I think in [...] Because if there are like 50 test-fixtures to be tested, it would be hard for the user to actually find which one is failing, as the console also shows the passing ones and they'd have to move across the whole console to actually find the failing param... What do you think?
Anything the tests print gets printed by default, if [...] Or would you prefer that only the [...]?
I have now changed the [...]
that is strange, I am suppressing stdout - line 246, via [...]
I have now added code that suppresses stderr as well. Do you still see it?
I just tried, it is still printing - this is pretty aggressive. Where is it printing to, and why does it get displayed?
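For context, a stdout-muting context manager along the lines referenced above might look roughly like this (a sketch, not the actual StdoutMute/StderrMute used here). It also hints at why logging output can slip through: a handler configured earlier keeps its own reference to the original stream object, so swapping sys.stdout or sys.stderr afterwards does not silence it.

import io
import sys
from contextlib import contextmanager


@contextmanager
def mute_stdout(active=True):
    """Temporarily replace sys.stdout with an in-memory buffer."""
    if not active:
        yield
        return
    old_stdout = sys.stdout
    sys.stdout = io.StringIO()
    try:
        yield
    finally:
        sys.stdout = old_stdout

# note: a logging.StreamHandler created before entering this context holds a
# reference to the original stream, so its output is not captured here, which
# is consistent with the behaviour observed in this thread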
I think I found the issue - it's [...]

trainer = pl.Trainer(
    ...,
    enable_progress_bar=False,   # disables the progress bar
    enable_model_summary=False,  # disables the model summary printout
)

and it would still print some logs like: [...]

And to suppress this, you'll have to use this statement beforehand:

import logging
logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
I think we'll have to edit [...]
it would be better if we could catch the printout directly at where it is printed to - otherwise the machinery is too coupled to specifics of [...]
So, maybe we should do something like this: [...]
That will prevent the tests from printing, yes - I think that is a good solution for now. The bigger problem is not resolved though: how to programmatically suppress lightning output.
I asked AI the same question and I think it solved it for us :)

import logging
from contextlib import ContextDecorator


class LoggerMute(ContextDecorator):
    """Context manager to temporarily mute specified loggers."""

    def __init__(self, logger_names, active=True):
        self.logger_names = logger_names
        self.original_levels = {}
        self.active = active

    def __enter__(self):
        if self.active:
            for name in self.logger_names:
                logger = logging.getLogger(name)
                self.original_levels[name] = logger.level
                logger.setLevel(logging.CRITICAL + 1)  # mute all levels
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.active:
            for name, level in self.original_levels.items():
                logger = logging.getLogger(name)
                logger.setLevel(level)

and use it in:

# inside loop B
loggers_to_mute = ["lightning.pytorch", "pytorch_lightning"]

try:
    # add the LoggerMute to the with statement
    with StderrMute(active=verbose < 2), \
         StdoutMute(active=verbose < 2), \
         LoggerMute(loggers_to_mute, active=verbose < 2):
        test_fun(**deepcopy(args))
    results[key] = "PASSED"
    print_if_verbose("PASSED")
This is because [...]
but logging is not the same as printing. Printing happens on the console - so where does the printout come from? The python [...]
I tried reading the [...] And I tried just shutting the logging, and then this is not printed on the console. So, my suspicion is that if you just suppress the logging, the info is never collected and it is not printed. From what I have been able to understand, sometimes people write to different streams than what [...]

import logging
logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)

how exactly would a context manager look that shuts down the [...]?
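One possible shape for such a context manager, assuming the goal is to temporarily detach the StreamHandlers of the lightning loggers rather than lower their level (a hypothetical sketch, not what the PR ended up doing):

import logging
from contextlib import contextmanager


@contextmanager
def detach_stream_handlers(logger_names):
    """Temporarily remove StreamHandlers from the named loggers, then restore them."""
    removed = []
    for name in logger_names:
        logger = logging.getLogger(name)
        for handler in list(logger.handlers):
            if isinstance(handler, logging.StreamHandler):
                logger.removeHandler(handler)
                removed.append((logger, handler))
    try:
        yield
    finally:
        for logger, handler in removed:
            logger.addHandler(handler)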
…fixture generation handling (#446)

This PR improves the functionality of QuickTester:
* improved verbosity levels, verbose can now be an integer that allows granular control over whether printout from the tests is displayed or not
* customizable fixture generation from the input object obj - in case this is different from the class/instance mechanism, e.g., using a package mechanism, see here for an example: sktime/pytorch-forecasting#1954
This adds a check_estimator utility for checking new estimators against the unified API contract, and test registry utilities. Mostly based on the sktime utilities of the same names; changes are predominantly import path changes.
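For orientation, intended usage follows the call pattern quoted earlier in the thread (a sketch; the import path for check_estimator shown here is an assumption, not confirmed by this PR):

from pytorch_forecasting import DeepAR

# hypothetical import path - adjust to wherever the PR places the utility
from pytorch_forecasting.tests import check_estimator

# run all API-contract tests registered for the estimator's scitype,
# raising on the first failure instead of silently collecting results
results = check_estimator(DeepAR, raise_exceptions=True)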