@carterprince carterprince commented Dec 19, 2025

Issue

N/A - Parity with Classifier (see merged PR #550).

Motivation and Context

Adds predict_raw_logits to TabPFNRegressor to allow analysis of individual estimator predictions (e.g., for uncertainty estimation).

Implementation:

  • Refactored predict to use a shared _raw_predict helper.
  • Implemented predict_raw_logits returning stacked log-probabilities.
  • Handles Bar Distribution translation internally to ensure logits from different estimators (which may have shifted borders) are aligned to the global borders before stacking.
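
The alignment step in the last bullet can be illustrated with a simplified NumPy sketch. It is not the code from this PR: `realign_bin_probs` is a hypothetical helper, and it assumes that probability mass is spread uniformly inside each bin and that mass falling outside the global range is simply dropped.

```python
import numpy as np


def realign_bin_probs(probs, src_borders, dst_borders):
    """Redistribute bin probabilities from src_borders onto dst_borders.

    Hypothetical helper. Assumes mass is uniform inside each source bin.
    probs: (n_samples, n_src_bins); borders are sorted 1-D arrays.
    """
    src_lo, src_hi = src_borders[:-1], src_borders[1:]
    dst_lo, dst_hi = dst_borders[:-1], dst_borders[1:]
    # overlap[i, j] = length of the intersection of source bin i and target bin j
    overlap = np.clip(
        np.minimum(src_hi[:, None], dst_hi[None, :])
        - np.maximum(src_lo[:, None], dst_lo[None, :]),
        0.0,
        None,
    )
    # weights[i, j] = fraction of source bin i that lies inside target bin j
    weights = overlap / (src_hi - src_lo)[:, None]
    return probs @ weights


global_borders = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
shifted_borders = global_borders + 0.5  # one estimator's shifted grid
probs = np.array([[0.1, 0.2, 0.3, 0.4]])

aligned = realign_bin_probs(probs, shifted_borders, global_borders)
with np.errstate(divide="ignore"):
    aligned_logits = np.log(aligned)  # a bin with zero mass would become -inf
```

Once every estimator's output lives on the same grid, the per-estimator arrays can be stacked along a new leading axis.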

Public API Changes

  • No Public API changes
  • Yes, Public API changes (Details below)

Added TabPFNRegressor.predict_raw_logits(X) -> np.ndarray returning shape (n_estimators, n_samples, n_bins).
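
As a rough illustration of how an array of this shape might be consumed, the sketch below uses a synthetic stand-in rather than a real call to `predict_raw_logits`. The borders, the normalisation of the log-probabilities, and the disagreement measure are all assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_estimators, n_samples, n_bins = 4, 3, 5

# Synthetic stand-in for the method's output: per-bin log-probabilities.
raw = rng.normal(size=(n_estimators, n_samples, n_bins)).astype(np.float32)
raw -= np.log(np.exp(raw).sum(axis=-1, keepdims=True))  # normalise each row

borders = np.linspace(-2.0, 2.0, n_bins + 1)  # hypothetical global borders
midpoints = (borders[:-1] + borders[1:]) / 2

probs = np.exp(raw)                              # (n_estimators, n_samples, n_bins)
per_estimator_mean = probs @ midpoints           # (n_estimators, n_samples)
ensemble_mean = per_estimator_mean.mean(axis=0)  # (n_samples,)
disagreement = per_estimator_mean.std(axis=0)    # spread across estimators
```

A large `disagreement` value for a sample indicates that the individual estimators place their mass in different regions, which is the kind of per-estimator analysis the new method is meant to enable.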


How Has This Been Tested?

Added test_predict_raw_logits in tests/test_regressor_interface.py.

  • Verifies output shape and float32 dtype.
  • Verifies handling of -inf (log-zero probability) while forbidding +inf/NaN.
  • Verifies ensemble diversity (estimators return different values).

Ran locally: pytest tests/test_regressor_interface.py::test_predict_raw_logits
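
The three checks listed above can be sketched as a standalone helper. This is not the test added in the PR: `check_raw_logits` is a hypothetical name, and the example array is hand-built rather than produced by a fitted regressor.

```python
import numpy as np


def check_raw_logits(logits, n_estimators, n_samples):
    """Hypothetical sanity checks mirroring the bullets above."""
    # Shape and dtype.
    assert logits.ndim == 3
    assert logits.shape[:2] == (n_estimators, n_samples)
    assert logits.dtype == np.float32
    # -inf is allowed (log of zero probability); NaN and +inf are not.
    assert not np.isnan(logits).any()
    assert not np.isposinf(logits).any()
    # Ensemble diversity: at least one estimator differs from the first.
    if n_estimators > 1:
        assert any(
            not np.array_equal(logits[0], logits[i])
            for i in range(1, n_estimators)
        )
    return True


with np.errstate(divide="ignore"):
    example = np.log(
        np.array([[[0.5, 0.5, 0.0]], [[0.2, 0.3, 0.5]]], dtype=np.float32)
    )

ok = check_raw_logits(example, n_estimators=2, n_samples=1)
```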


Checklist

  • The changes have been tested locally.
  • Documentation has been updated (if the public API or usage changes).
  • An entry has been added to CHANGELOG.md (if relevant for users).
  • The code follows the project's style guidelines.
  • I have considered the impact of these changes on the public API.

@carterprince carterprince requested a review from a team as a code owner December 19, 2025 04:14
@carterprince carterprince requested review from klemens-floege and removed request for a team December 19, 2025 04:14
CLAassistant commented Dec 19, 2025

CLA assistant check
All committers have signed the CLA.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new predict_raw_logits method to the TabPFNRegressor, achieving parity with the classifier. This is accomplished by refactoring the existing predict method to use a new internal _raw_predict helper, which is a clean and well-structured implementation. The changes are well-tested with a new test case that verifies the output shape, data type, and diversity of predictions from different estimators. My review includes a couple of minor suggestions to improve documentation clarity and ensure correct telemetry tracking. Overall, this is a solid contribution.

Handles input validation, preprocessing, and the forward pass.
Returns the stack of aligned probabilities from all estimators
(shape: [n_estimators, n_samples, n_borders]) mapped to the global

medium

The docstring for _raw_predict states that the returned tensor has a shape of [n_estimators, n_samples, n_borders]. However, the last dimension corresponds to the number of bins/buckets, which is len(borders) - 1. To avoid confusion, it would be clearer to use n_bins instead of n_borders.

Suggested change
- (shape: [n_estimators, n_samples, n_borders]) mapped to the global
+ (shape: [n_estimators, n_samples, n_bins]) mapped to the global
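
The off-by-one relationship behind this suggestion can be checked with a small sketch. The evenly spaced grid is an assumption chosen for illustration, not the regressor's actual borders.

```python
import numpy as np

borders = np.linspace(0.0, 1.0, 11)  # hypothetical grid with 11 border values
n_borders = len(borders)
n_bins = n_borders - 1               # one bin per pair of consecutive borders
widths = np.diff(borders)            # one width per bin, so shape (n_bins,)
```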

return logit_to_output(output_type=output_type)

@config_context(transform_output="default")
@track_model_call(model_method="predict", param_names=["X"])

medium

The @track_model_call decorator for predict_raw_logits appears to have a copy-pasted model_method from the predict method. For accurate telemetry, this should be updated to "predict_raw_logits".

Suggested change
- @track_model_call(model_method="predict", param_names=["X"])
+ @track_model_call(model_method="predict_raw_logits", param_names=["X"])
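
To see why the label matters, here is a deliberately simplified stand-in for a call-tracking decorator. It is not the project's `track_model_call` implementation: the counter backend and the `Model` class are hypothetical, and `param_names` is accepted but ignored.

```python
import functools
from collections import Counter

call_counts = Counter()  # stand-in for a telemetry backend


def track_model_call(model_method, param_names):
    """Simplified, hypothetical decorator: counts calls under model_method."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # The label comes from the decorator argument, not from fn.__name__,
            # so a copy-pasted label silently misattributes the call.
            call_counts[model_method] += 1
            return fn(*args, **kwargs)
        return wrapper
    return decorator


class Model:
    @track_model_call(model_method="predict", param_names=["X"])
    def predict(self, X):
        return X

    @track_model_call(model_method="predict_raw_logits", param_names=["X"])
    def predict_raw_logits(self, X):
        return X


m = Model()
m.predict([1])
m.predict_raw_logits([1])
```

With the copy-pasted label, both calls in this sketch would have been counted under "predict".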

@klemens-floege klemens-floege left a comment

Dear @carterprince, thank you for opening a PR on our repo. I think it would be good to standardize the Classifier and Regressor interfaces. You highlighted uncertainty estimates; please note that you can already use the Regressor's binned output distribution to get some uncertainty measures.

Please click the re-review button once all tests are passing, and I will then give it a thorough review.
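
As a rough sketch of the kind of uncertainty measure mentioned above, the following computes a mean, standard deviation, and quantiles from per-bin probabilities. It is not TabPFN's API: `bin_distribution_stats` is a hypothetical helper that treats the distribution as piecewise uniform over bounded bins, and the borders and probabilities are made up.

```python
import numpy as np


def bin_distribution_stats(probs, borders, quantiles=(0.1, 0.5, 0.9)):
    """Mean, std, and quantiles of a piecewise-uniform binned distribution.

    Hypothetical helper. probs: (n_samples, n_bins) with rows summing to 1
    and strictly positive entries; borders: (n_bins + 1,), sorted.
    """
    mid = (borders[:-1] + borders[1:]) / 2
    width = np.diff(borders)
    mean = probs @ mid
    # Second moment of a uniform distribution on a bin: mid^2 + width^2 / 12.
    second = probs @ (mid**2 + width**2 / 12)
    std = np.sqrt(np.maximum(second - mean**2, 0.0))
    # The CDF is piecewise linear between borders, so invert it by interpolation.
    cdf = np.concatenate(
        [np.zeros((probs.shape[0], 1)), np.cumsum(probs, axis=1)], axis=1
    )
    qs = np.array([[np.interp(q, row, borders) for q in quantiles] for row in cdf])
    return mean, std, qs


borders = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
probs = np.array([[0.1, 0.2, 0.3, 0.4]])
mean, std, qs = bin_distribution_stats(probs, borders)
```

The width of the interval between the outer quantiles, or the standard deviation, can then serve as a per-sample uncertainty estimate.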
