This is an interactive visual analytics framework designed to assist in the analysis and interpretation of seq2seq and generative Transformer models. Its main components are the Hidden State Projection View (including decoder state projections over each timestep), the Attention Head View for attention patterns and head importance, and the Instance View for input attribution.
We suggest installing within a virtual environment such as virtualenv or conda.
git clone https://github.com/raymondzmc/visual-analytics-for-generative-transformers.git ~/visual-analytics-for-generative-transformers
# Set up Python environment
cd ~/visual-analytics-for-generative-transformers
pip3 install -r requirements.txt
To visualize a specific dataset, create a new file in application/dataset/ containing a function that returns an iterable object without taking any arguments (see sst2.py for an example), then import the dataset function in application/models/__init__.py.
We suggest following the format of the Dataset object in the datasets library. The dataset object needs to have the attribute visualize_columns, corresponding to a list of columns to be visualized in the Data Table View, while all items in the iterable object need to have the following keys (a minimal sketch follows the list below):
- id: Index of the example
- tokens: String representation of input tokens for visualizing the saliency maps in the Instance Investigation View
- max_length (m): Maximum input length for the model
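For illustration only, here is a minimal sketch of such a dataset function. The file name my_dataset.py, the placeholder sentences, and the extra "sentence" column are assumptions; whitespace tokenization merely stands in for whatever tokenizer your model actually uses.

# application/dataset/my_dataset.py  -- hypothetical file and column names
def my_dataset():
    sentences = ['a great movie', 'a boring movie']   # placeholder data
    max_length = 128

    class VisualizableList(list):
        # A plain list cannot carry extra attributes, so use a small subclass
        # that can hold visualize_columns.
        pass

    dataset = VisualizableList()
    for idx, text in enumerate(sentences):
        dataset.append({
            'id': idx,                   # index of the example
            'tokens': text.split(),      # string tokens for the saliency maps
            'max_length': max_length,    # maximum input length for the model
            'sentence': text,            # extra column for the Data Table View
        })

    # Columns shown in the Data Table View
    dataset.visualize_columns = ['id', 'sentence']
    return dataset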
For a customized Transformer-based model, create a new file in application/models/ containing a function that returns the initialized model and tokenizer without taking any arguments (see bert_classification.py for an example). Finally, import the model function in application/models/__init__.py.
We suggest implementing the model in such a way that an item from the dataset can be used directly as its input (to avoid modifying the functions in the backend), such that:
example = dataset[0]
output = model(example)
Additionally, the model needs to have the following attributes/methods for our backend algorithms to work (a sketch of a model function follows the list):
- hidden_size (h): The hidden state dimension
- num_hidden_layers (l): Number of hidden layers
- num_attention_heads (h): Number of attention heads in each layer
- prune_heads(heads_to_prune) (optional): Method for pruning attention heads (see the implementation from Hugging Face's transformers library)
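As a rough sketch only (assuming a standard Hugging Face checkpoint; the repository's bert_classification.py may differ, and my_model.py is a hypothetical file name), a model function could look like this:

# application/models/my_model.py  -- hypothetical file name
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def my_model():
    name = 'bert-base-uncased'
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name)

    # Expose the attributes the backend expects directly on the model;
    # Hugging Face stores them on model.config.
    model.hidden_size = model.config.hidden_size
    model.num_hidden_layers = model.config.num_hidden_layers
    model.num_attention_heads = model.config.num_attention_heads
    # prune_heads(heads_to_prune) is already provided by PreTrainedModel.

    return model, tokenizer

Depending on how your dataset items are structured, you may still need a thin wrapper around the model so that model(example) accepts a dataset item directly.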
The directory for each visualization task should contain subdirectories representing the different checkpoints, and each checkpoint directory contains the data files for T3-Vis to visualize. We require the following data files to be stored in .pt format for consistency (see torch.save and torch.load); a minimal writing sketch follows the list.
- head_importance.pt: An l × h array encoding the relative task importance of each attention head
- aggregate_attn.pt: A list of length l × h, where each item is a dictionary containing the aggregated attention matrix of an attention head with keys:
  - layer: Integer representing the layer index
  - head: Integer representing the head index
  - attn: An m × m list where each value is the average attention value at the corresponding index
- projection_data.pt: Dict object where each value is a list of length n (number of examples in the dataset), with the following keys:
  - id: Index of the data example
  - projection_<l>_<0, 1>: The hidden state projection value of axis <0, 1> for layer <l> (used in the Hidden States View)
  - avg_variability: Average variability of the data example across previous epochs (used in the Cartography View)
  - avg_confidence: Average true class confidence of the data example across previous epochs (used in the Cartography View)
  - others (optional): Discrete or continuous attributes of the data examples used for color encodings (e.g. label, prediction, loss, length); these need to be defined in projection_keys in get_data()
- model.pt (optional): Model parameters for the current checkpoint (see the loading and saving documentation for PyTorch)
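To make the expected formats concrete, the sketch below writes the files for a single checkpoint using placeholder (random) values. The checkpoint path, the dimensions, and the single projection layer shown are assumptions; real importance scores, aggregated attention, and projections would come from your own analysis code.

# make_checkpoint_data.py  -- hypothetical script with placeholder values
import os
import torch

l, h, m, n = 12, 12, 128, 1000      # layers, heads, max input length, examples
ckpt_dir = 'resources/sst_bert/epoch_1'
os.makedirs(ckpt_dir, exist_ok=True)

# head_importance.pt: l x h array of per-head importance scores
torch.save(torch.rand(l, h), os.path.join(ckpt_dir, 'head_importance.pt'))

# aggregate_attn.pt: one dict per (layer, head) pair with its m x m attention
aggregate_attn = [
    {'layer': i, 'head': j, 'attn': torch.rand(m, m).tolist()}
    for i in range(l) for j in range(h)
]
torch.save(aggregate_attn, os.path.join(ckpt_dir, 'aggregate_attn.pt'))

# projection_data.pt: dict of per-example lists (only layer 0 shown here)
projection_data = {
    'id': list(range(n)),
    'projection_0_0': torch.rand(n).tolist(),
    'projection_0_1': torch.rand(n).tolist(),
    'avg_variability': torch.rand(n).tolist(),
    'avg_confidence': torch.rand(n).tolist(),
}
torch.save(projection_data, os.path.join(ckpt_dir, 'projection_data.pt'))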
For example, in the SST-2 Demo, the following file structure is required under the resource directory sst_bert.
└── sst_bert
├── epoch_1
│ ├── aggregate_attn.pt
│ ├── head_importance.pt
│ ├── model.pt
│ └── projection_data.pt
├── epoch_2
│ ├── aggregate_attn.pt
│ ├── head_importance.pt
│ ├── model.pt
│ └── projection_data.pt
├── epoch_3
│ ├── aggregate_attn.pt
│ ├── head_importance.pt
│ ├── model.pt
│ └── projection_data.pt
└── pretrained
├── aggregate_attn.pt
├── head_importance.pt
└── projection_data.pt
Please see our script run_sst2_classification.py for an example of how to process data for visualization during fine-tuning. We provide some helpful functions in the application/utils directory.
cd application
python app.py --model <model_function> --dataset <dataset_function> --resource_dir <resource_directory>
For example, for the Pegasus_XSum Demo, run the following command:
python app.py --dataset xsum_test_set --model pegasus --resource_dir resources/pegasus_xsum/