Conversation
F3z11
left a comment
Hi Matteo! Nice to hear from you :))
First of all, I am glad you appreciated our implementation; it is really important for us. And secondly, but certainly no less important, thank you for your contribution from a software engineering point of view. We really appreciate it.
I reviewed your PR, and you made a lot of improvements in stability and compliance that make our RDA implementation stronger.
I noticed just a couple of minor issues, which you'll find commented in the related files, but nothing substantial.
Regarding the variable names in the fit method, I proposed a schema right at the beginning of the function. I'm happy to discuss this further if you have questions.
If you have any other questions, feel free to ask.
Thank you again for your contribution, we really appreciate it! 😊
| """ | ||
| return hasattr(self, "_is_fitted") and self._is_fitted | ||
|
|
||
| def fit(self, X: np.ndarray, y: np.ndarray) -> RegularizedDiscriminantAnalysis: |
About the variable naming in `fit`, I propose this framework (a rough sketch follows the list):
- Public attributes: I think we can start from the public attributes provided by the LDA and QDA implementations in sklearn: `classes_`, `n_features_in_`, `means_`, `priors_`, `pooled_covariances_` (we can think about renaming it to just `covariance_` to match sklearn style).
- Local variables: `covariances` (QDA uses this name for the raw per-class matrices, but in RDA we use the pooled version, which we make public), `class_counts`, `n_samples` (these last two are not used after `fit` and are not very informative, since the priors are rarely known in advance and the number of samples can be obtained by inspecting the data).
- Private attributes: I do not see any attribute needed after fitting the model that should be kept private.
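To make this concrete, here is how `fit` could look under this schema. The statistical computations below are simplified placeholders, not our actual RDA implementation; only the naming is the point:

```python
import numpy as np

# Rough sketch of the proposed naming inside fit(); the computations are
# simplified placeholders, not the actual RDA implementation.
def fit(self, X: np.ndarray, y: np.ndarray):
    classes, y_idx = np.unique(y, return_inverse=True)

    # Local variables: needed only during fitting, then discarded.
    n_samples, n_features = X.shape
    class_counts = np.bincount(y_idx)
    covariances = np.stack(  # raw per-class covariances, QDA-style
        [np.cov(X[y_idx == k], rowvar=False) for k in range(classes.size)]
    )

    # Public attributes: trailing underscore, mirroring sklearn's LDA/QDA.
    self.classes_ = classes
    self.n_features_in_ = n_features
    self.priors_ = class_counts / n_samples
    self.means_ = np.stack(
        [X[y_idx == k].mean(axis=0) for k in range(classes.size)]
    )
    # Placeholder: the real version would hold the regularized/pooled matrices.
    self.pooled_covariances_ = covariances
    return self
```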
I'm afraid `self._class_counts` must remain a private field, as it's defined within the `fit()` method and later used in `_apply_regularization()`. We could also opt to pass it as a parameter.
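For instance (a toy class just to contrast the two options, not the actual estimator):

```python
import numpy as np


class _CountsSketch:
    """Toy sketch contrasting the two options; not the actual estimator."""

    def fit(self, X: np.ndarray, y: np.ndarray):
        _, counts = np.unique(y, return_counts=True)
        # Option (a): keep the private field set in fit() ...
        self._class_counts = counts
        self._apply_regularization()
        # ... or option (b): pass the counts explicitly instead.
        self._apply_regularization_explicit(counts)
        return self

    def _apply_regularization(self) -> None:
        print("counts via private attribute:", self._class_counts)

    def _apply_regularization_explicit(self, class_counts: np.ndarray) -> None:
        print("counts via parameter:", class_counts)
```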
Alright, I reviewed it, and I agree with you. Let's keep it private.
I've addressed your reviews; thanks for the feedback. About the tests which are not currently being run: should I adapt them to your workflow and enable them? Also, I suggest adding badges to the README to highlight that the Ruff, MyPy, and Pytest actions are passing, something like the standard GitHub Actions status badges.
P.S. If you're looking for an alternative to mypy for type checking, check out ty.
Yes! They would be very useful, especially because RDA is a distance-based method, so standardization is recommended as preprocessing, and its integration into a pipeline could be useful. Also, the GridSearch test can be helpful, since RDA has two hyperparameters to tune (sorry, I forgot to mention this part of your commit last time, my bad). For the Ruff, MyPy, and Pytest badges, feel free to add them in the next (and, I believe, final) commit :)) Thank you for the suggestion about ty, I will definitely look into it; it seems interesting! For now, let's keep mypy for type checking :) Thanks again for your work, I appreciate it!
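For reference, the adapted tests could look roughly like this. The import path and the hyperparameter names `lam` and `gamma` are assumptions and should be adjusted to the actual estimator:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from rda import RegularizedDiscriminantAnalysis  # assumed import path


def test_pipeline_usage():
    # Standardize first, since RDA is distance-based.
    X, y = make_classification(n_samples=200, n_features=10, random_state=0)
    pipe = Pipeline(
        [("scale", StandardScaler()), ("rda", RegularizedDiscriminantAnalysis())]
    )
    assert pipe.fit(X, y).score(X, y) > 0.5  # loose sanity check


def test_grid_search_cv():
    # Tune both regularization hyperparameters; "lam" and "gamma" are
    # assumed parameter names.
    X, y = make_classification(n_samples=200, n_features=10, random_state=0)
    grid = GridSearchCV(
        RegularizedDiscriminantAnalysis(),
        param_grid={"lam": [0.0, 0.5, 1.0], "gamma": [0.0, 0.5, 1.0]},
        cv=3,
    )
    grid.fit(X, y)
    assert set(grid.best_params_) == {"lam", "gamma"}
```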
That should do it :)
Hey guys,
what’s up?
First of all, great work — I really like your implementation! 🙂
I took the liberty of slightly refactoring the code. In particular, I focused on the following points:
**Type hinting and formatting**

Using `ruff` and `mypy`, I added some basic type-hinting and formatting rules directly in the `pyproject.toml`. I feel these make the code more readable and maintainable. I also added GitHub Actions workflows to run these checks automatically on every push/PR.
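To give an idea of the shape of that configuration (the exact rule selection below is illustrative, not necessarily what ended up in the PR):

```toml
# Sketch of the kind of pyproject.toml sections added; values illustrative.
[tool.ruff]
line-length = 88

[tool.ruff.lint]
select = ["E", "F", "I", "ANN"]  # pycodestyle, pyflakes, isort, annotations

[tool.mypy]
python_version = "3.11"
disallow_untyped_defs = true
```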
**Unit testing**

I refactored the unit test you provided to use `pytest`.
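For illustration, the refactored tests follow the usual fixture-based `pytest` style, something like this toy example (assumed import path, not one of the actual tests):

```python
import numpy as np
import pytest

from rda import RegularizedDiscriminantAnalysis  # assumed import path


@pytest.fixture
def toy_data():
    # Small synthetic binary problem for smoke-testing fit/predict.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 4))
    y = (X[:, 0] > 0).astype(int)
    return X, y


def test_fit_predict_shape(toy_data):
    X, y = toy_data
    clf = RegularizedDiscriminantAnalysis().fit(X, y)
    assert clf.predict(X).shape == (100,)
```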
**Scikit-learn integration**

I noticed that you inherit from `BaseEstimator` and `ClassifierMixin`. Scikit-learn provides a useful guide on how to correctly implement estimators that remain coherent with the rest of the library. You can find it here. Unfortunately, the guide is somewhat outdated and misses a few changes introduced in recent versions (>= 1.6), such as the recommended use of the `validate_data` API, which I talk about more here.
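For anyone unfamiliar with it, the pattern looks roughly like this (a minimal sketch of the >= 1.6 convention, not the code in this PR):

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_is_fitted, validate_data


class SketchClassifier(ClassifierMixin, BaseEstimator):
    """Minimal sketch of the validate_data pattern; not a real classifier."""

    def fit(self, X, y):
        # reset=True (the default) validates inputs and records n_features_in_.
        X, y = validate_data(self, X, y)
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        check_is_fitted(self)
        # reset=False enforces the same n_features_in_ seen during fit.
        X = validate_data(self, X, reset=False)
        return np.full(X.shape[0], self.classes_[0])  # placeholder prediction
```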
I've implemented a basic testing pipeline to verify the estimator's consistency using scikit-learn's `check_estimator`. You can find these tests in `test_sklearn_compatibility.py`. The key test is:
https://github.com/Dr4k3z/RegularizedDiscriminantAnalysis/blob/5010f0122a87cfd2c1a5e306b2781268347afe41/tests/test_sklearn_compatibility.py#L41-L43
which passes in the latest commit of this PR.
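For context, a test of this kind is typically just a thin wrapper around `check_estimator`, along these lines (a generic sketch with an assumed import path, not the linked lines verbatim):

```python
from sklearn.utils.estimator_checks import check_estimator

from rda import RegularizedDiscriminantAnalysis  # assumed import path


def test_sklearn_compatibility():
    # Runs scikit-learn's full battery of API-conformance checks.
    check_estimator(RegularizedDiscriminantAnalysis())
```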
In the same file, you'll also find other tests (`_test_pipeline_usage` and `_test_grid_search_cv`) that are currently not being executed. They come from another project of mine. I think they could be useful because they show how the estimator behaves in a real pipeline, but they would need to be adapted to your workflow.

All modifications to `fit()` and `predict()` were made solely to comply with scikit-learn's API standards; they do not alter the core logic of your implementation.

One point we may want to discuss further is variable naming. Scikit-learn requires that public attributes not be created inside `fit()`, so I've changed most of them to private. In my opinion, some of these do not need to be class members at all and could simply be local variables. Let me know what you think.

I hope you find the changes useful; I really enjoyed working on this!
Have a great day ❤️