feat: ✨ Add llama-based implementation
dino3616 committed Jan 28, 2025
1 parent 8a02a3c commit c640a93
Showing 7 changed files with 476 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/bert/modeling.py
@@ -154,9 +154,9 @@ def forward( # noqa: C901, PLR0912, PLR0913
         if scalar is not None:
             if scalar.dim() == 1:
                 scalar = scalar.unsqueeze(-1)
-            pooled_output = torch.cat([pooled_output, scalar], dim=1)
+            pooled_output_with_scalar = torch.cat([pooled_output, scalar], dim=1)

-        logits = self.classifier(pooled_output)
+        logits = self.classifier(pooled_output_with_scalar)

         loss = None
         if labels is not None:
3 changes: 1 addition & 2 deletions src/bert/train.py
@@ -21,8 +21,7 @@
     TrainingArguments,
 )

-from bert.dataset import TextPairWithScalarDataset
-
+from .dataset import TextPairWithScalarDataset
 from .modeling import ModernBertForSequenceClassificationWithScalar

 load_dotenv()
Empty file added src/llama/__init__.py
43 changes: 43 additions & 0 deletions src/llama/conf/config.yaml
@@ -0,0 +1,43 @@
project_name: lockerai-reranking
project_dir: ${oc.env:PROJECT_DIR}

model:
  name: ashitano-dcon/lockerai-reranking-llama
  base_name: princeton-nlp/Sheared-LLaMA-2.7B
  save_dir: ${project_dir}/checkpoints/${model.name}
  classifier_pooling: mean

train:
  split: train
  num_epochs: 3
  learning_rate: 5e-5
  batch_size_per_device: 16
  gradient_accumulation_steps: 1
  gradient_checkpointing: True
  max_grad_norm: 1.0
  optim: adamw_torch
  weight_decay: 0.0
  scheduler: linear
  warmup_steps: 0
  warmup_ratio: 0.0

eval:
  split: test
  batch_size_per_device: 16
  gradient_accumulation_steps: 1

huggingface:
  token: ${oc.env:HF_TOKEN}

wandb:
  project: ${oc.env:WANDB_PROJECT}
  key: ${oc.env:WANDB_API_KEY}

hydra:
  run:
    dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S}
  sweep:
    dir: ${hydra.run.dir}
  job:
    chdir: False
  verbose: INFO
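Hydra resolves this file at startup: `${oc.env:...}` entries read environment variables, and references such as `${model.name}` interpolate other keys. A minimal sketch of how a training entrypoint could consume the config follows; the `main` function and the printed keys are illustrative assumptions, since the accompanying src/llama/train.py is not rendered in this diff.

# Illustrative only: a hypothetical entrypoint reading src/llama/conf/config.yaml with Hydra.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Interpolations like ${oc.env:PROJECT_DIR} and ${model.name} resolve on access.
    print(cfg.model.base_name)   # princeton-nlp/Sheared-LLaMA-2.7B
    print(cfg.model.save_dir)    # <PROJECT_DIR>/checkpoints/ashitano-dcon/lockerai-reranking-llama
    print(cfg.train.num_epochs)  # 3


if __name__ == "__main__":
    main()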
46 changes: 46 additions & 0 deletions src/llama/dataset.py
@@ -0,0 +1,46 @@
from typing import Any

import torch
from datasets import Dataset
from transformers import (
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)


class TextPairWithScalarDataset(Dataset):
    def __init__(
        self,
        text1s: Any,
        text2s: Any,
        labels: Any,
        scalars: Any,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    ) -> None:
        self.text1s = text1s
        self.text2s = text2s
        self.labels = labels
        self.scalars = scalars
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        return len(self.text1s)

    def __getitem__(self, idx: Any) -> dict[str, Any]:
        text1 = self.text1s[idx]
        text2 = self.text2s[idx]
        scalar = self.scalars[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text1,
            text2,
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(0),  # type: ignore # noqa: PGH003
            "attention_mask": encoding["attention_mask"].squeeze(0),  # type: ignore # noqa: PGH003
            "scalar": torch.tensor(scalar, dtype=torch.float),
            "labels": torch.tensor(label, dtype=torch.int),
        }
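The dataset wraps parallel lists of text pairs, labels, and per-pair scalar features, tokenizing lazily in `__getitem__`. No padding or truncation is applied here, so batching presumably relies on a padding collator in the training script, which is not shown in this diff. A small usage sketch follows, with the tokenizer taken from the configured base model; the sample texts, label, and scalar are invented for illustration.

# Usage sketch; sample data is made up, and src/ is assumed to be importable as the llama package.
from transformers import AutoTokenizer

from llama.dataset import TextPairWithScalarDataset

tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/Sheared-LLaMA-2.7B")

dataset = TextPairWithScalarDataset(
    text1s=["a lost black leather wallet with two cards"],
    text2s=["I dropped my wallet near the station; it is black leather"],
    labels=[1],
    scalars=[0.73],
    tokenizer=tokenizer,
)

item = dataset[0]
# Each item pairs the tokenized texts with the extra scalar feature and the label.
print(item["input_ids"].shape, item["scalar"], item["labels"])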
228 changes: 228 additions & 0 deletions src/llama/modeling.py
@@ -0,0 +1,228 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)

LLAMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlamaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""


@add_start_docstrings(
"""
The LLaMa Model transformer with a sequence classification head on top (linear layer).
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassificationWithScalar(LlamaPreTrainedModel):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(  # noqa: C901, PLR0912, PLR0913
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        scalar: torch.Tensor | None = None,
    ) -> tuple | SequenceClassifierOutputWithPast:
        r"""Labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
        config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
        `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """  # noqa: D205
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]  # type: ignore # noqa: PGH003

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        elif input_ids is not None:
            # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(last_hidden_state.device)
        else:
            sequence_lengths = -1

        batch_indices = torch.arange(batch_size, device=last_hidden_state.device)
        last_hidden_state = last_hidden_state[batch_indices, sequence_lengths]

        if scalar is not None:
            if scalar.dim() == 1:
                scalar = scalar.unsqueeze(-1)
            last_hidden_state_with_scalar = torch.cat([last_hidden_state, scalar], dim=1)

        logits = self.classifier(last_hidden_state_with_scalar)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss, *output)) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
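The class follows the upstream `LlamaForSequenceClassification` pattern: it pools the hidden state of the last non-padding token, concatenates the scalar feature onto it, and passes the result through a linear head. A hedged sketch of a single forward pass follows; `num_labels=2`, the pad-token fallback to EOS, and a head resized to `hidden_size + 1` (so the concatenated feature dimension matches the linear layer) are assumptions made for illustration, not details of this commit.

# Illustrative forward pass; assumptions are noted inline, and src/ is assumed to be importable as the llama package.
import torch
from torch import nn
from transformers import AutoTokenizer

from llama.modeling import LlamaForSequenceClassificationWithScalar

base = "princeton-nlp/Sheared-LLaMA-2.7B"
tokenizer = AutoTokenizer.from_pretrained(base)

model = LlamaForSequenceClassificationWithScalar.from_pretrained(base, num_labels=2)
model.config.pad_token_id = tokenizer.eos_token_id  # assumption: LLaMA defines no pad token, so reuse EOS
model.classifier = nn.Linear(model.config.hidden_size + 1, model.config.num_labels, bias=False)  # assumption: head sized for hidden_size + scalar

encoding = tokenizer(
    "a lost black leather wallet with two cards",
    "I dropped my wallet near the station; it is black leather",
    return_tensors="pt",
)
scalar = torch.tensor([[0.73]])  # one auxiliary feature per example

with torch.no_grad():
    output = model(**encoding, scalar=scalar, labels=torch.tensor([1]))

print(output.logits.shape, output.loss)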