|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | + |
| 4 | +from brainscore_language import load_model |
| 5 | +from brainscore_language.artificial_subject import ArtificialSubject |
| 6 | + |
| 7 | + |
@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, expected_reading_times', [
    ('t5-small',
     [25.646585, 23.780153, 23.018826, 22.344381, 11.96658, 27.054287, 10.594951, 13.187043]),
    ('t5-base',
     [7.7039944e-03, 6.8635613e-02, 3.1093130e+01, 1.2913298e+02,
      8.5430244e+01, 1.6261120e+01, 8.2980719e+00, 2.9535002e+01]),
    ('t5-large',
     [31.604916, 18.852331, 30.816673, 48.99762, 49.006733, 36.088543, 14.189968, 37.781395]),
    ('t5-xl',
     [5.2831264, 18.823713, 19.249414, 35.212494, 24.10475, 19.929758, 11.064505, 16.397375]),
    ('t5-xxl',
     [26.934216, 30.064108, 18.61358, 71.8481, 20.456089, 18.108957, 25.52297, 20.845043]),
    ('flan-t5-small',
     [4.626572, 5.4074254, 2.9690156, 5.98445, 12.027061, 11.096782, 16.912296, 14.794151]),
    ('flan-t5-base',
     [1.8610231, 1.5091983, 2.3265584, 2.5798035, 0.9352376, 2.594869, 3.4819074, 2.7790558]),
    ('flan-t5-large',
     [2.2994747, 4.1134634, 1.6111257, 10.103671, 11.365605, 3.37785, 1.4599704, 2.9243639]),
    ('flan-t5-xl',
     [2.5323708, 2.9281907, 3.2239344, 10.614168, 7.162341, 3.0385818, 2.9526176, 2.7103176]),
    ('flan-t5-xxl',
     [2.3222983, 2.3133714, 2.8529167, 11.162584, 6.798625, 4.742971, 2.9756427, 2.9877827]),
])
def test_reading_times(model_identifier, expected_reading_times):
    """Compare a model's reading-times behavior against stored reference values.

    The model named by ``model_identifier`` is loaded, switched to the
    ``reading_times`` behavioral task, and fed eight single-word stimuli.
    The resulting ``'behavior'`` output must match ``expected_reading_times``
    within an absolute tolerance of 0.001.
    """
    words = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy']

    subject = load_model(model_identifier)
    subject.start_behavioral_task(task=ArtificialSubject.Task.reading_times)

    output = subject.digest_text(words)
    observed = output['behavior']

    # Reference values are regression baselines; only small numeric drift is tolerated.
    np.testing.assert_allclose(observed, expected_reading_times, atol=0.001)
| 27 | + |
| 28 | + |
@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, expected_next_words', [
    ('t5-small', ['in', 'in', 'in']),
    ('t5-base', ['<extra_id_27>', '</s>', '<extra_id_27>']),
    ('t5-large', ['<extra_id_11>', '<extra_id_11>', '<extra_id_11>']),
    ('t5-xl', ['', '', '']),
    ('t5-xxl', ['', 'ES', ',']),
    ('flan-t5-small', ['...', '...', '...']),
    ('flan-t5-base', ['</s>', '...', '</s>']),
    ('flan-t5-large', ['', '', '']),
    ('flan-t5-xl', ['', '...', '</s>']),
    ('flan-t5-xxl', ['</s>', '.', '.']),
])
def test_next_word(model_identifier, expected_next_words):
    """Compare a model's next-word predictions against stored reference tokens.

    The model named by ``model_identifier`` is loaded, switched to the
    ``next_word`` behavioral task, and fed three text fragments. The
    ``'behavior'`` output must equal ``expected_next_words`` element for
    element.
    """
    fragments = ['the quick brown fox', 'jumps over', 'the lazy']

    subject = load_model(model_identifier)
    subject.start_behavioral_task(task=ArtificialSubject.Task.next_word)

    output = subject.digest_text(fragments)
    predicted = output['behavior']

    # Exact match: predictions are discrete tokens, so no tolerance applies.
    np.testing.assert_array_equal(predicted, expected_next_words)
| 48 | + |
| 49 | + |
@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, feature_size', [
    ('t5-small', 512),
    ('t5-base', 768),
    ('t5-large', 1024),
    ('t5-xl', 2048),
    ('t5-xxl', 4096),
    ('flan-t5-small', 512),
    ('flan-t5-base', 768),
    ('flan-t5-large', 1024),
    ('flan-t5-xl', 2048),
    ('flan-t5-xxl', 4096),
])
def test_neural(model_identifier, feature_size):
    """Check the layout of the neural recording returned for three stimuli.

    The model named by ``model_identifier`` is loaded and set up to record
    from the ``language_system`` target with the ``fMRI`` recording type.
    After digesting three text fragments, the ``'neural'`` output must have
    one presentation per fragment, carry the fragments as its ``stimulus``
    values, and expose ``feature_size`` neuroids.
    """
    sentences = ['the quick brown fox', 'jumps over', 'the lazy dog']

    subject = load_model(model_identifier)
    subject.start_neural_recording(
        recording_target=ArtificialSubject.RecordingTarget.language_system,
        recording_type=ArtificialSubject.RecordingType.fMRI,
    )

    output = subject.digest_text(sentences)
    recording = output['neural']

    # One presentation is expected per input fragment (three here).
    assert len(recording['presentation']) == len(sentences)
    np.testing.assert_array_equal(recording['stimulus'], sentences)
    assert len(recording['neuroid']) == feature_size
0 commit comments