Skip to content

Commit

Permalink
Add T5 models
Browse files Browse the repository at this point in the history
  • Loading branch information
awwkl committed Jul 11, 2023
1 parent 44d030c commit 8ace8f3
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 0 deletions.
2 changes: 2 additions & 0 deletions brainscore_language/model_helpers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def _tokenize_overflow_aware(self, context, num_previous_context_tokens: int) ->
context_tokens.pop('num_truncated_tokens')
if 'overflow_to_sample_mapping' in context_tokens:
context_tokens.pop('overflow_to_sample_mapping')
if self.basemodel.config.is_encoder_decoder:
context_tokens['decoder_input_ids'] = context_tokens['input_ids']
context_tokens.to(self.device)
return context_tokens, num_new_context_tokens

Expand Down
76 changes: 76 additions & 0 deletions brainscore_language/models/t5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from brainscore_language import model_registry
from brainscore_language import ArtificialSubject
from brainscore_language.model_helpers.huggingface import HuggingfaceSubject
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# layer assignment based on choosing the maximally scoring layer on Pereira2018-encoding

model_registry['t5-small'] = lambda: HuggingfaceSubject(
model_id='t5-small',
model=AutoModelForSeq2SeqLM.from_pretrained('google/t5-v1_1-small', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/t5-v1_1-small'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'decoder.block.6'}
)

model_registry['t5-base'] = lambda: HuggingfaceSubject(
model_id='t5-base',
model=AutoModelForSeq2SeqLM.from_pretrained('google/t5-v1_1-base', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/t5-v1_1-base'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'encoder.block.9'}
)

model_registry['t5-large'] = lambda: HuggingfaceSubject(
model_id='t5-large',
model=AutoModelForSeq2SeqLM.from_pretrained('google/t5-v1_1-large', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/t5-v1_1-large'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'encoder.block.17'}
)

model_registry['t5-xl'] = lambda: HuggingfaceSubject(
model_id='t5-xl',
model=AutoModelForSeq2SeqLM.from_pretrained('google/t5-v1_1-xl', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/t5-v1_1-xl'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'decoder.block.2'}
)

model_registry['t5-xxl'] = lambda: HuggingfaceSubject(
model_id='t5-xxl',
model=AutoModelForSeq2SeqLM.from_pretrained('google/t5-v1_1-xxl', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/t5-v1_1-xxl'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'decoder.block.0'}
)

model_registry['flan-t5-small'] = lambda: HuggingfaceSubject(
model_id='flan-t5-small',
model=AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/flan-t5-small'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'encoder.block.7'}
)

model_registry['flan-t5-base'] = lambda: HuggingfaceSubject(
model_id='flan-t5-base',
model=AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/flan-t5-base'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'decoder.block.7'}
)

model_registry['flan-t5-large'] = lambda: HuggingfaceSubject(
model_id='flan-t5-large',
model=AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-large', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/flan-t5-large'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'encoder.block.18'}
)

model_registry['flan-t5-xl'] = lambda: HuggingfaceSubject(
model_id='flan-t5-xl',
model=AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-xl', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/flan-t5-xl'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'decoder.block.2'}
)

model_registry['flan-t5-xxl'] = lambda: HuggingfaceSubject(
model_id='flan-t5-xxl',
model=AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-xxl', device_map="auto"),
tokenizer=AutoTokenizer.from_pretrained('google/flan-t5-xxl'),
region_layer_mapping={ArtificialSubject.RecordingTarget.language_system: 'decoder.block.0'}
)
71 changes: 71 additions & 0 deletions brainscore_language/models/t5/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import pytest

from brainscore_language import load_model
from brainscore_language.artificial_subject import ArtificialSubject


@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, expected_reading_times', [
('t5-small', [25.646585, 23.780153, 23.018826, 22.344381, 11.96658, 27.054287, 10.594951, 13.187043]),
('t5-base', [7.7039944e-03, 6.8635613e-02, 3.1093130e+01, 1.2913298e+02, 8.5430244e+01, 1.6261120e+01, 8.2980719e+00, 2.9535002e+01]),
('t5-large', [31.604916, 18.852331, 30.816673, 48.99762 , 49.006733, 36.088543, 14.189968, 37.781395]),
('t5-xl', [ 5.2831264, 18.823713, 19.249414, 35.212494, 24.10475, 19.929758 , 11.064505 , 16.397375 ]),
('t5-xxl', [26.934216, 30.064108, 18.61358, 71.8481, 20.456089, 18.108957, 25.52297, 20.845043]),
('flan-t5-small', [4.626572, 5.4074254, 2.9690156, 5.98445, 12.027061, 11.096782, 16.912296, 14.794151]),
('flan-t5-base', [1.8610231, 1.5091983, 2.3265584, 2.5798035, 0.9352376, 2.594869 , 3.4819074, 2.7790558]),
('flan-t5-large', [2.2994747, 4.1134634, 1.6111257, 10.103671, 11.365605, 3.37785, 1.4599704, 2.9243639]),
('flan-t5-xl', [2.5323708, 2.9281907, 3.2239344, 10.614168, 7.162341, 3.0385818, 2.9526176, 2.7103176]),
('flan-t5-xxl', [2.3222983, 2.3133714, 2.8529167, 11.162584, 6.798625, 4.742971, 2.9756427, 2.9877827]),
])
def test_reading_times(model_identifier, expected_reading_times):
model = load_model(model_identifier)
text = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy']
model.start_behavioral_task(task=ArtificialSubject.Task.reading_times)
reading_times = model.digest_text(text)['behavior']
np.testing.assert_allclose(reading_times, expected_reading_times, atol=0.001)


@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, expected_next_words', [
('t5-small', ['in', 'in', 'in']),
('t5-base', ['<extra_id_27>', '</s>', '<extra_id_27>']),
('t5-large', ['<extra_id_11>', '<extra_id_11>', '<extra_id_11>']),
('t5-xl', ['', '', '']),
('t5-xxl', ['', 'ES', ',']),
('flan-t5-small', ['...', '...', '...']),
('flan-t5-base', ['</s>', '...', '</s>']),
('flan-t5-large', ['', '', '']),
('flan-t5-xl', ['', '...', '</s>']),
('flan-t5-xxl', ['</s>', '.', '.']),
])
def test_next_word(model_identifier, expected_next_words):
model = load_model(model_identifier)
text = ['the quick brown fox', 'jumps over', 'the lazy']
model.start_behavioral_task(task=ArtificialSubject.Task.next_word)
next_word_predictions = model.digest_text(text)['behavior']
np.testing.assert_array_equal(next_word_predictions, expected_next_words)


@pytest.mark.memory_intense
@pytest.mark.parametrize('model_identifier, feature_size', [
('t5-small', 512),
('t5-base', 768),
('t5-large', 1024),
('t5-xl', 2048),
('t5-xxl', 4096),
('flan-t5-small', 512),
('flan-t5-base', 768),
('flan-t5-large', 1024),
('flan-t5-xl', 2048),
('flan-t5-xxl', 4096),
])
def test_neural(model_identifier, feature_size):
model = load_model(model_identifier)
text = ['the quick brown fox', 'jumps over', 'the lazy dog']
model.start_neural_recording(recording_target=ArtificialSubject.RecordingTarget.language_system,
recording_type=ArtificialSubject.RecordingType.fMRI)
representations = model.digest_text(text)['neural']
assert len(representations['presentation']) == 3
np.testing.assert_array_equal(representations['stimulus'], text)
assert len(representations['neuroid']) == feature_size
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ dependencies = [
"gensim",
"joblib",
"accelerate",
"sentencepiece",
"protobuf==3.19.4",
]

[project.optional-dependencies]
Expand Down

0 comments on commit 8ace8f3

Please sign in to comment.