Skip to content

Commit 8ace8f3

Browse files
committed
Add T5 models
1 parent 44d030c commit 8ace8f3

File tree

4 files changed

+151
-0
lines changed

4 files changed

+151
-0
lines changed

brainscore_language/model_helpers/huggingface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def _tokenize_overflow_aware(self, context, num_previous_context_tokens: int) ->
181181
context_tokens.pop('num_truncated_tokens')
182182
if 'overflow_to_sample_mapping' in context_tokens:
183183
context_tokens.pop('overflow_to_sample_mapping')
184+
if self.basemodel.config.is_encoder_decoder:
185+
context_tokens['decoder_input_ids'] = context_tokens['input_ids']
184186
context_tokens.to(self.device)
185187
return context_tokens, num_new_context_tokens
186188

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from brainscore_language import model_registry
2+
from brainscore_language import ArtificialSubject
3+
from brainscore_language.model_helpers.huggingface import HuggingfaceSubject
4+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5+
6+
# layer assignment based on choosing the maximally scoring layer on Pereira2018-encoding
7+
8+
def _t5_subject(identifier, checkpoint, layer):
    """Build the HuggingfaceSubject for a single T5-family checkpoint.

    :param identifier: value passed as ``model_id`` (the same string used as the registry key).
    :param checkpoint: Hugging Face checkpoint name handed to both ``from_pretrained`` calls.
    :param layer: layer name mapped to the ``language_system`` recording target.
    :return: the constructed ``HuggingfaceSubject``.
    """
    return HuggingfaceSubject(
        model_id=identifier,
        model=AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map="auto"),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: layer},
    )


# Every entry stays a zero-argument lambda, so weights are only loaded when the
# registry entry is actually called.
# NOTE(review): the keys are deliberately kept as literal subscripts on their own
# assignment lines in case tooling locates plugins by scanning for them -- confirm
# before collapsing these into a loop.

# T5 v1.1 checkpoints
model_registry['t5-small'] = lambda: _t5_subject('t5-small', 'google/t5-v1_1-small', 'decoder.block.6')
model_registry['t5-base'] = lambda: _t5_subject('t5-base', 'google/t5-v1_1-base', 'encoder.block.9')
model_registry['t5-large'] = lambda: _t5_subject('t5-large', 'google/t5-v1_1-large', 'encoder.block.17')
model_registry['t5-xl'] = lambda: _t5_subject('t5-xl', 'google/t5-v1_1-xl', 'decoder.block.2')
model_registry['t5-xxl'] = lambda: _t5_subject('t5-xxl', 'google/t5-v1_1-xxl', 'decoder.block.0')

# Flan-T5 checkpoints
model_registry['flan-t5-small'] = lambda: _t5_subject('flan-t5-small', 'google/flan-t5-small', 'encoder.block.7')
model_registry['flan-t5-base'] = lambda: _t5_subject('flan-t5-base', 'google/flan-t5-base', 'decoder.block.7')
model_registry['flan-t5-large'] = lambda: _t5_subject('flan-t5-large', 'google/flan-t5-large', 'encoder.block.18')
model_registry['flan-t5-xl'] = lambda: _t5_subject('flan-t5-xl', 'google/flan-t5-xl', 'decoder.block.2')
model_registry['flan-t5-xxl'] = lambda: _t5_subject('flan-t5-xxl', 'google/flan-t5-xxl', 'decoder.block.0')

brainscore_language/models/t5/test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import numpy as np
2+
import pytest
3+
4+
from brainscore_language import load_model
5+
from brainscore_language.artificial_subject import ArtificialSubject
6+
7+
8+
# Reference reading times per model identifier, one value per word of the probe text.
_EXPECTED_READING_TIMES = {
    't5-small': [25.646585, 23.780153, 23.018826, 22.344381,
                 11.96658, 27.054287, 10.594951, 13.187043],
    't5-base': [7.7039944e-03, 6.8635613e-02, 3.1093130e+01, 1.2913298e+02,
                8.5430244e+01, 1.6261120e+01, 8.2980719e+00, 2.9535002e+01],
    't5-large': [31.604916, 18.852331, 30.816673, 48.99762,
                 49.006733, 36.088543, 14.189968, 37.781395],
    't5-xl': [5.2831264, 18.823713, 19.249414, 35.212494,
              24.10475, 19.929758, 11.064505, 16.397375],
    't5-xxl': [26.934216, 30.064108, 18.61358, 71.8481,
               20.456089, 18.108957, 25.52297, 20.845043],
    'flan-t5-small': [4.626572, 5.4074254, 2.9690156, 5.98445,
                      12.027061, 11.096782, 16.912296, 14.794151],
    'flan-t5-base': [1.8610231, 1.5091983, 2.3265584, 2.5798035,
                     0.9352376, 2.594869, 3.4819074, 2.7790558],
    'flan-t5-large': [2.2994747, 4.1134634, 1.6111257, 10.103671,
                      11.365605, 3.37785, 1.4599704, 2.9243639],
    'flan-t5-xl': [2.5323708, 2.9281907, 3.2239344, 10.614168,
                   7.162341, 3.0385818, 2.9526176, 2.7103176],
    'flan-t5-xxl': [2.3222983, 2.3133714, 2.8529167, 11.162584,
                    6.798625, 4.742971, 2.9756427, 2.9877827],
}


@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, expected_reading_times',
                         list(_EXPECTED_READING_TIMES.items()))
def test_reading_times(model_identifier, expected_reading_times):
    """Compare the model's reading-times output on a fixed word list with stored values.

    The model is loaded by identifier, switched to the ``reading_times`` behavioral
    task, and fed eight single words; the resulting ``'behavior'`` values must match
    the reference values within an absolute tolerance of 0.001.
    """
    words = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy']
    subject = load_model(model_identifier)
    subject.start_behavioral_task(task=ArtificialSubject.Task.reading_times)
    output = subject.digest_text(words)
    observed = output['behavior']
    np.testing.assert_allclose(observed, expected_reading_times, atol=0.001)
27+
28+
29+
# Reference next-word predictions per model identifier, one per probe passage.
_EXPECTED_NEXT_WORDS = {
    't5-small': ['in', 'in', 'in'],
    't5-base': ['<extra_id_27>', '</s>', '<extra_id_27>'],
    't5-large': ['<extra_id_11>', '<extra_id_11>', '<extra_id_11>'],
    't5-xl': ['', '', ''],
    't5-xxl': ['', 'ES', ','],
    'flan-t5-small': ['...', '...', '...'],
    'flan-t5-base': ['</s>', '...', '</s>'],
    'flan-t5-large': ['', '', ''],
    'flan-t5-xl': ['', '...', '</s>'],
    'flan-t5-xxl': ['</s>', '.', '.'],
}


@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, expected_next_words',
                         list(_EXPECTED_NEXT_WORDS.items()))
def test_next_word(model_identifier, expected_next_words):
    """Compare the model's next-word output on three passages with stored predictions.

    The model is loaded by identifier, switched to the ``next_word`` behavioral task,
    and fed three text passages; the resulting ``'behavior'`` values must equal the
    reference predictions exactly.
    """
    passages = ['the quick brown fox', 'jumps over', 'the lazy']
    subject = load_model(model_identifier)
    subject.start_behavioral_task(task=ArtificialSubject.Task.next_word)
    output = subject.digest_text(passages)
    predicted = output['behavior']
    np.testing.assert_array_equal(predicted, expected_next_words)
48+
49+
50+
# Expected number of neuroids in the recorded representation, per model identifier.
_FEATURE_SIZES = {
    't5-small': 512,
    't5-base': 768,
    't5-large': 1024,
    't5-xl': 2048,
    't5-xxl': 4096,
    'flan-t5-small': 512,
    'flan-t5-base': 768,
    'flan-t5-large': 1024,
    'flan-t5-xl': 2048,
    'flan-t5-xxl': 4096,
}


@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, feature_size', list(_FEATURE_SIZES.items()))
def test_neural(model_identifier, feature_size):
    """Check the shape and stimulus labels of the recorded neural representations.

    The model is loaded by identifier and set to record the ``language_system``
    target with the ``fMRI`` recording type. After digesting three passages, the
    ``'neural'`` output must hold three presentations, carry the passages as its
    ``stimulus`` values, and have ``feature_size`` neuroids.
    """
    passages = ['the quick brown fox', 'jumps over', 'the lazy dog']
    subject = load_model(model_identifier)
    subject.start_neural_recording(
        recording_target=ArtificialSubject.RecordingTarget.language_system,
        recording_type=ArtificialSubject.RecordingType.fMRI,
    )
    recorded = subject.digest_text(passages)['neural']
    assert len(recorded['presentation']) == 3
    np.testing.assert_array_equal(recorded['stimulus'], passages)
    assert len(recorded['neuroid']) == feature_size

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ dependencies = [
2222
"gensim",
2323
"joblib",
2424
"accelerate",
25+
"sentencepiece",
26+
"protobuf==3.19.4",
2527
]
2628

2729
[project.optional-dependencies]

0 commit comments

Comments
 (0)