
Conversation

@fkiraly (Collaborator) commented Aug 27, 2025

This adds a check_estimator utility for checking new estimators against the unified API contract, plus test registry utilities.

Mostly based on the sktime utilities of the same names; changes are predominantly import path changes.

@fkiraly fkiraly requested a review from benHeid as a code owner August 27, 2025 06:48
@fkiraly fkiraly added the enhancement (New feature or request) label Aug 27, 2025
@fkiraly (Collaborator Author) commented Aug 27, 2025

@phoeenniixx, the templated check_estimator is unable to retrieve any tests, and I do not understand why.

It would be appreciated if you could review; perhaps you overrode something?

    if obj was not a descendant of BaseObject or BaseEstimator, returns empty list
    """
    if hasattr(obj, "_pkg"):
        obj = obj._pkg
@phoeenniixx (Member) commented Aug 28, 2025

obj._pkg is a method, so it should be obj._pkg()

@fkiraly (Collaborator Author)

Did you change it to be a classmethod?

I originally designed it as a classproperty, so this call would be valid.

@fkiraly (Collaborator Author)

Ah, this is actually another mistake. It is a class attribute, but MyClass.pkg should be called, not _pkg.

@phoeenniixx (Member)

No, I didn't; I just looked at the DeepAR implementation, where it was already a classmethod:

@classmethod
def _pkg(cls):
    """Package containing the model."""
    from pytorch_forecasting.models.deepar._deepar_pkg import DeepAR_pkg

    return DeepAR_pkg

@phoeenniixx (Member)

You made it a classmethod in #1888, I guess.

@fkiraly (Collaborator Author)

Yes, I know; I now understand what is going on: pkg is a class attribute that is the public interface, and it calls the private extension locus _pkg, which is a classmethod, analogous to fit and _fit.
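For illustration, a minimal sketch of that pattern, with hypothetical class and package names (whether pkg is a plain classmethod or a classproperty is an implementation detail not pinned down here):

    class BaseModelSketch:
        """Sketch of the public/private extension pattern described above."""

        @classmethod
        def _pkg(cls):
            """Private extension locus: subclasses override this."""
            raise NotImplementedError

        @classmethod
        def pkg(cls):
            """Public interface: delegates to the private _pkg."""
            return cls._pkg()


    class MyModel(BaseModelSketch):
        """Hypothetical model class."""

        @classmethod
        def _pkg(cls):
            """Package object containing the model."""
            from my_library._my_model_pkg import MyModel_pkg  # hypothetical import

            return MyModel_pkg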

testclass_dict = get_test_class_registry()

try:
    obj_scitypes = obj.get_tag("object_type")
@phoeenniixx (Member) commented Aug 28, 2025

It's obj.get_class_tag() in ptf.

@fkiraly (Collaborator Author)

get_tag also collects class tags, but good spot.
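As a sketch of the distinction, assuming scikit-base-style tag semantics (names per scikit-base's BaseObject):

    from skbase.base import BaseObject


    class TaggedObject(BaseObject):
        """Toy object with a class-level tag."""

        _tags = {"object_type": "forecaster"}


    # classmethod: reads class-level tags only, works without an instance
    TaggedObject.get_class_tag("object_type")  # -> "forecaster"

    # instance method: reads class tags plus any dynamic instance overrides
    TaggedObject().get_tag("object_type")  # -> "forecaster"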

@fkiraly (Collaborator Author) commented Aug 28, 2025

This is probably it!

I keep saying: never use try/except for condition checking, because it is bad style; instead, check for the condition explicitly. Otherwise, it is hard to diagnose why the code fails to behave as expected, e.g., whether it fails due to a genuine exception rather than the condition being false.

(try/except is fine for error handling; this is about abusing it for condition checking.)

And now I have made the same mistake. I should have listened to myself :-)
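To illustrate the point with this very bug, a hedged sketch (simplified from the actual code):

    # anti-pattern: try/except abused as a condition check;
    # the real failure (here, a wrong method name) is silently swallowed
    try:
        obj_scitypes = obj.get_tag("object_type")
    except Exception:
        obj_scitypes = []  # reached on ANY error, masking genuine bugs

    # preferred: check the condition explicitly, so genuine exceptions surface
    if hasattr(obj, "get_tag"):
        obj_scitypes = obj.get_tag("object_type")
    else:
        obj_scitypes = []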

@phoeenniixx (Member)

I actually did the same to find this out :)
I just removed the try/except block... thanks to your teachings.

@fkiraly (Collaborator Author)

Well, all I am saying is that it is worth listening to them oneself too.

Comment on lines 116 to 123
test_cls_results = test_cls().run_tests(
    estimator=estimator,
    raise_exceptions=raise_exceptions,
    tests_to_run=tests_to_run,
    fixtures_to_run=fixtures_to_run,
    tests_to_exclude=tests_to_exclude,
    fixtures_to_exclude=fixtures_to_exclude,
    verbose=verbose and raise_exceptions,
@phoeenniixx (Member)

I was also looking at QuickTester.run_tests(); it has the following signature:

obj,
raise_exceptions=False,
tests_to_run=None,
fixtures_to_run=None,
tests_to_exclude=None,
fixtures_to_exclude=None,

Are we overriding it somehow?

@fkiraly (Collaborator Author) commented Aug 28, 2025

Yes, this is also a bug!

scikit-base renames every "estimator" thing to "object" or "obj".

This is clearer for users of scikit-base, in my opinion, but causes a discrepancy that is prone to bugs like this one (so I was originally against the rename, even if "estimator" is the less clear term).

We had long discussions with the other core devs on whether we should rename this way; in the end, it was decided that we do.
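A minimal sketch of the fix implied here: pass the keyword arguments under the scikit-base names, obj rather than estimator (whether additional arguments such as verbose are accepted depends on the skbase version):

    test_cls_results = test_cls().run_tests(
        obj=estimator,  # scikit-base names this "obj", not "estimator"
        raise_exceptions=raise_exceptions,
        tests_to_run=tests_to_run,
        fixtures_to_run=fixtures_to_run,
        tests_to_exclude=tests_to_exclude,
        fixtures_to_exclude=fixtures_to_exclude,
    )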

@phoeenniixx (Member)

I made the above changes on a branch after pulling this PR locally. I see that check_estimator is now able to collect tests, but it is collecting ALL the tests, for all the models. Maybe this is because test_clss_for_est in check_estimator gets [<class 'pytorch_forecasting.tests.test_all_estimators.TestAllPtForecasters'>] and not the actual model class?

@fkiraly (Collaborator Author) commented Aug 28, 2025

OK, I am now at the point that you had also arrived at: all tests are collected, instead of only those for the estimator.

@fkiraly fkiraly requested a review from phoeenniixx September 7, 2025 12:23
@fkiraly (Collaborator Author) commented Sep 7, 2025

@phoeenniixx, I think it works now

@phoeenniixx (Member) commented Sep 7, 2025

I just pulled the branch locally; when I try:

check_estimator(my_model, raise_exceptions=True)

it prints the output of all the tests, but I think in sktime we only get the raised exception? Can we do something similar here? I mean, check_estimator would just return the raised exception, and "ALL TESTS PASSED!" if none is raised. This would keep the console clean, and the user would get only the problematic fixtures rather than all the fixtures.

Because if there are, say, 50 test fixtures to be run, it would be hard for the user to find which one is failing, as the console also shows the passing ones, and they would have to scan the whole console output to find the failing param...

What do you think?

@fkiraly (Collaborator Author) commented Sep 7, 2025

Anything the tests print gets printed by default, if verbose=True. If you set verbose=False, then nothing gets printed.

Or would you prefer that only the check_estimator printout gets printed, and nothing from within the tests?

@fkiraly (Collaborator Author) commented Sep 7, 2025

I have now changed the verbose argument to be an integer: at the default, nothing from within the tests is printed, but it can be configured to print everything by setting it to 2.

@phoeenniixx (Member) commented Sep 7, 2025

I have now changed the verbose argument to be an integer: at the default, nothing from within the tests is printed, but it can be configured to print everything by setting it to 2.

Yes, that makes more sense! Thanks!
Though I have a few questions/suggestions:

  • I have a doubt about verbose: what is the difference between the different values of verbose? When I try verbose=0, I see something like this:

    [screenshot: per-fixture test output printed to the console]

    But I think verbose=0 should not print all this? It should appear at verbose=1 or verbose=2. I think the "amount" of info printed should increase as the value increases (see the sketch after this list):

    • Only the error (if any) and the dict printed at the end, for verbose=0.
    • The info printed in the screenshot above (one line per fixture), for verbose=1.
    • The whole output of the tests (along with the training verbose output) for each fixture, for verbose=2.
  • Should we add one more param, suppress_warnings, which is False by default? E.g., right now I was also testing this for v2, and as we know, we have warnings at every layer (D1, D2, model, etc.), and this would be noise for someone who is debugging, so it would be good if we could suppress them. Or does this already exist in skbase?
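A hedged sketch of the tiers proposed above; the helper name _emit and the gating are illustrative, not the actual implementation:

    def _emit(verbose, level, msg):
        """Print msg only if the requested verbosity reaches level.

        verbose=0: errors and the final results dict only
        verbose=1: additionally, per-fixture PASSED/FAILED lines
        verbose=2: additionally, the full output from within the tests
        """
        if verbose >= level:
            print(msg)


    # usage sketch inside the test loop:
    # _emit(verbose, 1, f"{fixture_name}: PASSED")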

@fkiraly (Collaborator Author) commented Sep 7, 2025

That is strange; I am suppressing stdout (line 246, via StdoutMute), so why is it still being printed? Is it not coming from stdout?

@fkiraly (Collaborator Author) commented Sep 7, 2025

I have now added code that suppresses stderr as well. Do you still see it?

@fkiraly (Collaborator Author) commented Sep 7, 2025

I just tried, and it is still printing; this is pretty aggressive. Where is it printing to, and why does it get displayed?

@phoeenniixx (Member)

I think I found the issue: it's lightning logging. The Trainer's logging is not suppressed by StdoutMute; you'll have to specifically add these args to the Trainer:

trainer = pl.Trainer(
    ...,
    enable_progress_bar=False,   # disables the progress bar
    enable_model_summary=False,  # disables the model summary printout
)

and it would still print some logs like:

Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

And to suppress this, you'll have to use this statement before creating the trainer:

import logging

logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)

@phoeenniixx (Member)

I think we'll have to edit _integration here for this to actually work :)
This, I think, makes things a bit more complex than they are meant to be?

@fkiraly (Collaborator Author) commented Sep 7, 2025

It would be better if we could catch the printout directly where it is printed; otherwise the machinery is too coupled to the specifics of lightning.

@phoeenniixx (Member)

So, maybe we should do something like this:

  • We set the params enable_progress_bar and enable_model_summary to False by default in _integration, to be made True only if the user says so.
  • We apply logging.getLogger("lightning.pytorch").setLevel(logging.WARNING) at all times, as I don't think we need, in the tests, the info about which device is being used; it's always CPU..? (See the sketch below.)
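A hedged sketch of what this could look like; _make_trainer is a hypothetical helper standing in for the Trainer construction inside _integration:

    import logging

    import lightning.pytorch as pl


    def _make_trainer(verbose=0, **trainer_kwargs):
        """Hypothetical helper: quiet Trainer defaults unless the user opts in."""
        if verbose < 2:
            # silence lightning's device/summary log lines
            logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
        # user-passed values win; setdefault only fills the gaps
        trainer_kwargs.setdefault("enable_progress_bar", verbose >= 2)
        trainer_kwargs.setdefault("enable_model_summary", verbose >= 2)
        return pl.Trainer(**trainer_kwargs)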

@fkiraly (Collaborator Author) commented Sep 7, 2025

That will prevent the tests from printing, yes; I think that is a good solution for now.

The bigger problem is not resolved, though: how to programmatically suppress lightning output.

@phoeenniixx (Member) commented Sep 7, 2025

I asked AI the same question, and I think it solved it for us :)
Create a special class:

import logging
from contextlib import ContextDecorator

class LoggerMute(ContextDecorator):
    """Context manager to temporarily mute specified loggers."""
    def __init__(self, logger_names, active=True):
        self.logger_names = logger_names
        self.original_levels = {}
        self.active = active

    def __enter__(self):
        if self.active:
            for name in self.logger_names:
                logger = logging.getLogger(name)
                self.original_levels[name] = logger.level
                logger.setLevel(logging.CRITICAL + 1) # Mute all levels
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.active:
            for name, level in self.original_levels.items():
                logger = logging.getLogger(name)
                logger.setLevel(level)

and use it in run_tests:

# inside loop B
loggers_to_mute = ["lightning.pytorch", "pytorch_lightning"]
try:
    # add the LoggerMute to the with statement
    with StderrMute(active=verbose < 2), \
            StdoutMute(active=verbose < 2), \
            LoggerMute(loggers_to_mute, active=verbose < 2):
        test_fun(**deepcopy(args))

    results[key] = "PASSED"
    print_if_verbose("PASSED")

@phoeenniixx (Member)

This is because the trainer does not print, it "logs". I think that is why StdoutMute was not working..?

@fkiraly (Collaborator Author) commented Sep 7, 2025

But logging is not the same as printing. Printing happens on the console, so where does the printout come from? The Python logging logger, for instance, does not automatically print.

@phoeenniixx (Member)

I tried reading the Trainer code from lightning; the code flow is quite confusing, but from what I am able to understand, it is this: there is a class called ModelSummary, which gets the model summary and logs it. See here.

And I tried just shutting off the logging, and then this is not printed on the console. So my suspicion is that if you just suppress the logging, the info is never collected, and it is not printed. From what I have been able to understand, sometimes people write to different streams than what StdoutMute can suppress, and that is why it always gets printed. And here, the log is used to save the info (see here), and then this info is somehow used to produce the final model summary and is written to some stream other than what StdoutMute can handle. So I think we could try just suppressing the logging here:

import logging

logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
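One plausible mechanism, stated as an assumption about lightning's setup: logging.StreamHandler stores a reference to its stream (sys.stderr by default) when the handler is created, so a mute that merely reassigns sys.stdout/sys.stderr afterwards never touches an already-attached handler. A minimal sketch demonstrating this:

    import io
    import logging
    import sys

    logger = logging.getLogger("demo")
    # the handler captures the sys.stderr *object* at creation time
    logger.addHandler(logging.StreamHandler(sys.stderr))
    logger.setLevel(logging.INFO)

    # naive mute: reassign sys.stderr after the handler exists
    original, sys.stderr = sys.stderr, io.StringIO()
    logger.info("still visible")  # writes to the captured stream, bypassing the swap
    sys.stderr = original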

@fkiraly (Collaborator Author) commented Sep 11, 2025

How exactly would a context manager that shuts down the lightning printouts look, then? Is it simply setting the logger to WARNING?
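For concreteness, a sketch of how the LoggerMute from above would wrap the lightning-invoking code; names follow the run_tests excerpt earlier in this thread (deepcopy is from the standard library copy module):

    # mute lightning's loggers for the duration of one test fixture
    with LoggerMute(["lightning.pytorch", "pytorch_lightning"]):
        test_fun(**deepcopy(args))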

fkiraly added a commit to sktime/skbase that referenced this pull request Sep 18, 2025
…fixture generation handling (#446)

This PR improves the functionality of `QuickTester`:

* improved verbosity levels: `verbose` can now be an integer that allows
granular control over whether printout from the tests is displayed or not
* customizable fixture generation from the input object `obj`, in case
this is different from the class/instance mechanism, e.g., using a
package mechanism; see here for an example:
sktime/pytorch-forecasting#1954