Add predict_raw_logits to TabPFNRegressor #688
Code Review
This pull request introduces a new predict_raw_logits method to the TabPFNRegressor, achieving parity with the classifier. This is accomplished by refactoring the existing predict method to use a new internal _raw_predict helper, which is a clean and well-structured implementation. The changes are well-tested with a new test case that verifies the output shape, data type, and diversity of predictions from different estimators. My review includes a couple of minor suggestions to improve documentation clarity and ensure correct telemetry tracking. Overall, this is a solid contribution.
Handles input validation, preprocessing, and the forward pass.
Returns the stack of aligned probabilities from all estimators
(shape: [n_estimators, n_samples, n_borders]) mapped to the global
The docstring for _raw_predict states that the returned tensor has a shape of [n_estimators, n_samples, n_borders]. However, the last dimension corresponds to the number of bins/buckets, which is len(borders) - 1. To avoid confusion, it would be clearer to use n_bins instead of n_borders.
Suggested change:
- (shape: [n_estimators, n_samples, n_borders]) mapped to the global
+ (shape: [n_estimators, n_samples, n_bins]) mapped to the global
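The distinction the reviewer draws can be illustrated with a small sketch. The border values below are made up for illustration and are not taken from TabPFN; the point is only that N borders delimit N - 1 bins, so a per-sample distribution has one entry per bin.

```python
# Hypothetical bin borders for illustration; not values taken from TabPFN.
borders = [-2.0, -1.0, 0.0, 1.0, 2.0]

# Each bin is the interval between two consecutive borders.
bins = list(zip(borders[:-1], borders[1:]))

n_borders = len(borders)
n_bins = len(bins)

# A per-sample distribution has one entry per bin, not one per border.
assert n_bins == n_borders - 1
print(n_borders, n_bins)  # 5 4
```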
return logit_to_output(output_type=output_type)

@config_context(transform_output="default")
@track_model_call(model_method="predict", param_names=["X"])
The @track_model_call decorator for predict_raw_logits appears to have a copy-pasted model_method from the predict method. For accurate telemetry, this should be updated to "predict_raw_logits".
Suggested change:
- @track_model_call(model_method="predict", param_names=["X"])
+ @track_model_call(model_method="predict_raw_logits", param_names=["X"])
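To see why the label matters, here is a minimal sketch with a toy decorator. `track_call`, `CALL_COUNTS`, and `ToyModel` are hypothetical stand-ins invented for this example; they are not TabPFN's `track_model_call` and make no claim about how it works internally. The sketch only assumes that usage is recorded under the string passed to the decorator.

```python
import functools
from collections import Counter

# Hypothetical stand-in for a telemetry backend; not part of TabPFN.
CALL_COUNTS = Counter()


def track_call(model_method):
    """Toy decorator that counts calls under the given label."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The counter is keyed by the label, not by the function name.
            CALL_COUNTS[model_method] += 1
            return func(*args, **kwargs)
        return wrapper
    return decorator


class ToyModel:
    @track_call(model_method="predict")
    def predict(self, X):
        return [0.0 for _ in X]

    # Copy-pasted label: calls to this method are counted as "predict".
    @track_call(model_method="predict")
    def predict_raw_logits_mislabelled(self, X):
        return [[0.0] for _ in X]

    # Corrected label: calls are counted under their own name.
    @track_call(model_method="predict_raw_logits")
    def predict_raw_logits(self, X):
        return [[0.0] for _ in X]


model = ToyModel()
model.predict([1, 2])
model.predict_raw_logits_mislabelled([1, 2])
model.predict_raw_logits([1, 2])
print(dict(CALL_COUNTS))
```

With the copy-pasted label, two of the three calls are attributed to `predict`, so usage of the new method would be undercounted.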
klemens-floege left a comment
Dear @carterprince, thank you for opening a PR on our repo. I think it would be cool to standardize the Classifier/Regressor. You highlighted uncertainty estimates; please note that you can use the Regressor's binning distribution to get some uncertainty measures.
Please click the re-review button once all tests are passing, and then I will give it a thorough review.
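As a rough illustration of deriving an uncertainty measure from a binned distribution, the sketch below computes an approximate mean and standard deviation from bin borders and per-bin probabilities. The function `binned_mean_and_std` and the example values are hypothetical; it places all of each bin's mass at the bin midpoint, which ignores within-bin spread, and it does not reflect how TabPFN itself represents or summarizes its output distribution.

```python
import math


def binned_mean_and_std(borders, probs):
    """Approximate mean and std of a binned distribution via bin midpoints."""
    if len(probs) != len(borders) - 1:
        raise ValueError("expected one probability per bin")
    midpoints = [(lo + hi) / 2 for lo, hi in zip(borders[:-1], borders[1:])]
    total = sum(probs)
    weights = [p / total for p in probs]  # renormalize defensively
    mean = sum(w * m for w, m in zip(weights, midpoints))
    var = sum(w * (m - mean) ** 2 for w, m in zip(weights, midpoints))
    return mean, math.sqrt(var)


# Made-up borders and probabilities for a single sample.
borders = [-2.0, -1.0, 0.0, 1.0, 2.0]
probs = [0.1, 0.4, 0.4, 0.1]
mean, std = binned_mean_and_std(borders, probs)
print(mean, std)
```

A wider or more bimodal set of bin probabilities yields a larger standard deviation, which is the sense in which the binned output already carries uncertainty information.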
Issue
N/A - Parity with Classifier (see merged PR #550).

Motivation and Context
Adds predict_raw_logits to TabPFNRegressor to allow analysis of individual estimator predictions (e.g., for uncertainty estimation).

Implementation:
- Refactors predict to use a shared _raw_predict helper.
- Adds predict_raw_logits, returning stacked log-probabilities.

Public API Changes
Added TabPFNRegressor.predict_raw_logits(X) -> np.ndarray, returning shape (n_estimators, n_samples, n_bins).

How Has This Been Tested?
Added test_predict_raw_logits in tests/test_regressor_interface.py.
- Verifies the output shape and float32 dtype.
- Allows -inf (log-zero probability) while forbidding +inf/NaN.

Ran locally:
pytest tests/test_regressor_interface.py::test_predict_raw_logits

Checklist
- CHANGELOG.md (if relevant for users).
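The checks described above can be sketched on plain nested lists standing in for the returned array. `check_log_probs` and `mean_probs` are hypothetical helpers written for this illustration, not the PR's test code, and averaging probabilities across estimators is an assumption about how one might combine the stack, not a statement about what `predict` does internally.

```python
import math


def check_log_probs(stack):
    """Validate a nested list shaped (n_estimators, n_samples, n_bins)."""
    n_samples = len(stack[0])
    n_bins = len(stack[0][0])
    for estimator in stack:
        if len(estimator) != n_samples:
            raise ValueError("inconsistent number of samples")
        for row in estimator:
            if len(row) != n_bins:
                raise ValueError("inconsistent number of bins")
            for value in row:
                # -inf encodes log(0) and is allowed; +inf and NaN are not.
                if math.isnan(value) or value == math.inf:
                    raise ValueError("found +inf or NaN")
    return (len(stack), n_samples, n_bins)


def mean_probs(stack):
    """Average per-bin probabilities across estimators (an assumption)."""
    n_estimators, n_samples, n_bins = check_log_probs(stack)
    return [
        [
            sum(math.exp(est[i][j]) for est in stack) / n_estimators
            for j in range(n_bins)
        ]
        for i in range(n_samples)
    ]


# Two made-up estimators, one sample, three bins.
stack = [
    [[math.log(0.5), math.log(0.5), -math.inf]],
    [[math.log(0.25), math.log(0.25), math.log(0.5)]],
]
shape = check_log_probs(stack)
averaged = mean_probs(stack)
print(shape, averaged)
```

Here `math.exp(-math.inf)` evaluates to `0.0`, so a log-zero entry contributes no probability mass to the average.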