-
Notifications
You must be signed in to change notification settings - Fork 17
Futrell2018 SPRT benchmark using GAMs + control predictors #107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
eae79e8
219173d
6fba657
9e8cc36
caadb00
e420983
2291eb6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from brainscore_language import benchmark_registry | ||
| from .benchmark import Futrell2018GAMPearsonr | ||
|
|
||
| benchmark_registry['Futrell2018-GAM-pearsonr'] = Futrell2018GAMPearsonr |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,132 @@ | ||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||
| from numpy.random import RandomState | ||||||||||||||||||||||||||||||||||||||
| import pandas as pd | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| import rpy2.robjects as ro | ||||||||||||||||||||||||||||||||||||||
| from rpy2.robjects.packages import importr | ||||||||||||||||||||||||||||||||||||||
| from rpy2.robjects import numpy2ri | ||||||||||||||||||||||||||||||||||||||
| from rpy2.robjects import pandas2ri | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from brainio.assemblies import DataAssembly | ||||||||||||||||||||||||||||||||||||||
| from brainscore_core.benchmarks import BenchmarkBase | ||||||||||||||||||||||||||||||||||||||
| from brainscore_core.metrics import Score, Metric | ||||||||||||||||||||||||||||||||||||||
| from brainscore_language import load_dataset, load_metric | ||||||||||||||||||||||||||||||||||||||
| from brainscore_language.artificial_subject import ArtificialSubject | ||||||||||||||||||||||||||||||||||||||
| from brainscore_language.utils import attach_presentation_meta | ||||||||||||||||||||||||||||||||||||||
| from brainscore_language.utils.ceiling import ceiling_normalize | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| class Futrell2018GAMPearsonr(BenchmarkBase): | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Evaluate model ability to predict reading times on the natural stories corpus introduced in Futrell et al. 2018. | ||||||||||||||||||||||||||||||||||||||
| Alignment of reading times between model and human subjects is evaluated via a generalized additive | ||||||||||||||||||||||||||||||||||||||
| linear model, incorporating current- and previous-word surprisals along with control properties | ||||||||||||||||||||||||||||||||||||||
| of word length and word frequency. | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| FORMULA = "reading_time ~ s(surprisal, bs='cr', k=20) + s(prev_surp, bs='cr', k=20) + " + \ | ||||||||||||||||||||||||||||||||||||||
| "te(freq, len, bs='cr') + te(prev_freq, prev_len, bs='cr')" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||||||||||||||||||
| self.data = load_dataset('Futrell2018') | ||||||||||||||||||||||||||||||||||||||
| self.metric = load_metric('pearsonr') | ||||||||||||||||||||||||||||||||||||||
| ceiler = SplitHalvesConsistency(num_splits=10, split_coordinate='subject_id', consistency_metric=self.metric) | ||||||||||||||||||||||||||||||||||||||
| ceiling = ceiler(self.data) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Load R dependencies. | ||||||||||||||||||||||||||||||||||||||
| numpy2ri.activate() | ||||||||||||||||||||||||||||||||||||||
| pandas2ri.activate() | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| super(Futrell2018GAMPearsonr, self).__init__( | ||||||||||||||||||||||||||||||||||||||
| identifier='Futrell2018-GAM-pearsonr', | ||||||||||||||||||||||||||||||||||||||
| version=1, | ||||||||||||||||||||||||||||||||||||||
| parent='behavior', | ||||||||||||||||||||||||||||||||||||||
| ceiling=ceiling, | ||||||||||||||||||||||||||||||||||||||
| bibtex=self.data.bibtex) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def fit(self, surprisals, data_mask): | ||||||||||||||||||||||||||||||||||||||
| formula = ro.Formula(self.FORMULA) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| data = pd.DataFrame({ | ||||||||||||||||||||||||||||||||||||||
| "surprisal": surprisals, | ||||||||||||||||||||||||||||||||||||||
| "reading_time": self.data[data_mask].mean("subject"), | ||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||
| data["prev_surp"] = data["surprisal"].shift(1) | ||||||||||||||||||||||||||||||||||||||
| data["len"] = self.data[data_mask].word_core.str.len() | ||||||||||||||||||||||||||||||||||||||
| data["prev_len"] = data["len"].shift(1) | ||||||||||||||||||||||||||||||||||||||
| data["freq"] = surprisals # HACK need to look this up. | ||||||||||||||||||||||||||||||||||||||
| data["prev_freq"] = data["freq"].shift(1) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Second round of masking, excluding for which there are nan values (e.g. first word has no defined prev features) | ||||||||||||||||||||||||||||||||||||||
| data_mask = ~data.isna().any(axis=1) | ||||||||||||||||||||||||||||||||||||||
| data = data[data_mask] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # TODO check that columns match formula variable names | ||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. todo |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| r_mgcv = importr("mgcv") | ||||||||||||||||||||||||||||||||||||||
| model = r_mgcv.gam(formula, data=data) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # TODO held out data | ||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. todo |
||||||||||||||||||||||||||||||||||||||
| predictions = r_mgcv.predict_gam(model, newdata=data, type="response") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| return model, predictions, data.reading_time | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def __call__(self, candidate: ArtificialSubject) -> Score: | ||||||||||||||||||||||||||||||||||||||
| # run experiment | ||||||||||||||||||||||||||||||||||||||
| candidate.start_behavioral_task(ArtificialSubject.Task.reading_times) | ||||||||||||||||||||||||||||||||||||||
| stimuli = self.data['word'].values | ||||||||||||||||||||||||||||||||||||||
| surprisals = candidate.digest_text(stimuli)['behavior'] | ||||||||||||||||||||||||||||||||||||||
| attach_presentation_meta(surprisals, self.data['presentation']) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # exclude first words | ||||||||||||||||||||||||||||||||||||||
| surprisals = surprisals[surprisals['word_within_sentence_id'] != 1] | ||||||||||||||||||||||||||||||||||||||
| data_mask = self.data['word_within_sentence_id'] != 1 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Fit and evaluate GAM model | ||||||||||||||||||||||||||||||||||||||
| model, predictions, targets = self.fit(surprisals, data_mask) | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+81
to
+89
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # score | ||||||||||||||||||||||||||||||||||||||
| raw_score = self.metric(predictions, targets) | ||||||||||||||||||||||||||||||||||||||
| score = ceiling_normalize(raw_score, self.ceiling) | ||||||||||||||||||||||||||||||||||||||
| return score | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| class SplitHalvesConsistency: | ||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could
benchmarks/futrell2018 plugin? I'm fine with either, slightly leaning towards adding this to the futrell2018 plugin
|
||||||||||||||||||||||||||||||||||||||
| # following | ||||||||||||||||||||||||||||||||||||||
| # https://github.com/brain-score/brain-score/blob/c51b8aa2c94212a9ac56c06c556afad0bb0a3521/brainscore/metrics/ceiling.py#L25-L96 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def __init__(self, num_splits: int, split_coordinate: str, consistency_metric: Metric): | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| :param num_splits: how many times to create two halves | ||||||||||||||||||||||||||||||||||||||
| :param split_coordinate: over which coordinate to split the assembly into halves | ||||||||||||||||||||||||||||||||||||||
| :param consistency_metric: which metric to use to compute the consistency of two halves | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| self.num_splits = num_splits | ||||||||||||||||||||||||||||||||||||||
| self.split_coordinate = split_coordinate | ||||||||||||||||||||||||||||||||||||||
| self.consistency_metric = consistency_metric | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def __call__(self, assembly: DataAssembly) -> Score: | ||||||||||||||||||||||||||||||||||||||
| split_dim = np.array(assembly[self.split_coordinate].dims).item() | ||||||||||||||||||||||||||||||||||||||
| split_values = assembly[self.split_coordinate].values | ||||||||||||||||||||||||||||||||||||||
| random_state = RandomState(0) | ||||||||||||||||||||||||||||||||||||||
| consistencies, uncorrected_consistencies = [], [] | ||||||||||||||||||||||||||||||||||||||
| splits = range(self.num_splits) | ||||||||||||||||||||||||||||||||||||||
| for _ in splits: | ||||||||||||||||||||||||||||||||||||||
| half1_values = random_state.choice(split_values, size=len(split_values) // 2, replace=False) | ||||||||||||||||||||||||||||||||||||||
| half2_values = set(split_values) - set(half1_values) # this only works because of `replace=False` above | ||||||||||||||||||||||||||||||||||||||
| half1 = assembly[{split_dim: [value in half1_values for value in split_values]}].mean(split_dim) | ||||||||||||||||||||||||||||||||||||||
| half2 = assembly[{split_dim: [value in half2_values for value in split_values]}].mean(split_dim) | ||||||||||||||||||||||||||||||||||||||
| consistency = self.consistency_metric(half1, half2) | ||||||||||||||||||||||||||||||||||||||
| uncorrected_consistencies.append(consistency) | ||||||||||||||||||||||||||||||||||||||
| # Spearman-Brown correction for sub-sampling | ||||||||||||||||||||||||||||||||||||||
| corrected_consistency = 2 * consistency / (1 + (2 - 1) * consistency) | ||||||||||||||||||||||||||||||||||||||
| consistencies.append(corrected_consistency) | ||||||||||||||||||||||||||||||||||||||
| consistencies = Score(consistencies, coords={'split': splits}, dims=['split']) | ||||||||||||||||||||||||||||||||||||||
| uncorrected_consistencies = Score(uncorrected_consistencies, coords={'split': splits}, dims=['split']) | ||||||||||||||||||||||||||||||||||||||
| average_consistency = consistencies.median('split') | ||||||||||||||||||||||||||||||||||||||
| average_consistency.attrs['raw'] = consistencies | ||||||||||||||||||||||||||||||||||||||
| average_consistency.attrs['uncorrected_consistencies'] = uncorrected_consistencies | ||||||||||||||||||||||||||||||||||||||
| return average_consistency | ||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| channels: | ||
| - r | ||
| dependencies: | ||
| - pip | ||
| - r | ||
| - r-mgcv | ||
| - pip: | ||
| - rpy2 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| import numpy as np | ||
| from numpy.random import RandomState | ||
| from pytest import approx | ||
|
|
||
| from brainio.assemblies import BehavioralAssembly | ||
| from brainscore_language import load_benchmark | ||
| from brainscore_language.artificial_subject import ArtificialSubject | ||
|
|
||
|
|
||
| class TestBenchmark: | ||
| class DummyModel(ArtificialSubject): | ||
| def __init__(self, reading_times): | ||
| self.reading_times = reading_times | ||
|
|
||
| def digest_text(self, stimuli): | ||
| return {'behavior': BehavioralAssembly( | ||
| self.reading_times, | ||
| coords={'stimulus': ('presentation', stimuli), 'stimulus_id': ('presentation', np.arange(len(stimuli)))}, | ||
| dims=['presentation'])} | ||
|
|
||
| def start_behavioral_task(self, task: ArtificialSubject.Task): | ||
| if task != ArtificialSubject.Task.reading_times: | ||
| raise NotImplementedError() | ||
|
|
||
| def test_dummy_bad(self): | ||
| benchmark = load_benchmark('Futrell2018-GAM-pearsonr') | ||
| reading_times = RandomState(0).random(10256) | ||
| dummy_model = TestBenchmark.DummyModel(reading_times=reading_times) | ||
| score = benchmark(dummy_model) | ||
| assert score == approx(0.00853059 / .858, abs=0.001) | ||
|
|
||
| def test_exact(self): | ||
| benchmark = load_benchmark('Futrell2018-pearsonr') | ||
| dummy_model = TestBenchmark.DummyModel(reading_times=benchmark.data.mean('subject').values) | ||
| score = benchmark(dummy_model) | ||
| assert score == approx(1) | ||
|
|
||
| def test_ceiling(self): | ||
| benchmark = load_benchmark('Futrell2018-pearsonr') | ||
| ceiling = benchmark.ceiling | ||
| assert ceiling == approx(.858, abs=.0005) | ||
| assert ceiling.raw.median('split') == ceiling | ||
| assert ceiling.uncorrected_consistencies.median('split') < ceiling |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
todo?