Skip to content

Commit

Permalink
more flexible tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
benlipkin committed Feb 3, 2023
1 parent c5b77ba commit 2ce5e8b
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 12 deletions.
5 changes: 1 addition & 4 deletions probsem/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

from probsem.abstract import Object, IModel
from probsem.utils import tokenize

openai.api_key_path = str(pathlib.Path.home() / ".openai_api_key")

Expand Down Expand Up @@ -115,9 +114,7 @@ def _set_torch_device(self) -> None:

# Encode ``text`` with ``self._tokenizer`` (requesting PyTorch tensors via
# ``return_tensors="pt"``) and call ``.to(self._device)`` on the result.
# Results are memoized by ``functools.lru_cache`` with at most 128 entries.
# NOTE(review): lru_cache on a method also keys on ``self`` and keeps the
# instance referenced for the lifetime of the cache -- confirm this is intended.
@functools.lru_cache(maxsize=128)
def _encode_text(self, text: str) -> typing.Dict[str, torch.Tensor]:
# NOTE(review): this view shows two return statements without +/- diff markers.
# The commit summary for this file (1 addition, 4 deletions) suggests the
# three-line return below is the removed version, which pre-split the text with
# the project's ``tokenize`` helper and passed ``is_split_into_words=True``.
return self._tokenizer(
tokenize(text), is_split_into_words=True, return_tensors="pt"
).to(self._device)
# Presumably the replacement line: the raw string is handed to the tokenizer
# directly, with no pre-splitting -- confirm against the repository.
return self._tokenizer(text, return_tensors="pt").to(self._device)

def _decode_text(self, tokens: torch.Tensor) -> str:
    """Decode ``tokens`` into a string using ``self._tokenizer``.

    The tokenizer is called with ``skip_special_tokens=True``.
    """
    decoded = self._tokenizer.decode(tokens, skip_special_tokens=True)
    return decoded
Expand Down
8 changes: 0 additions & 8 deletions probsem/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re
import typing

import nltk
import numpy as np
import numpy.typing as npt

Expand All @@ -10,13 +9,6 @@ def sanitize_filename(text: str) -> str:
return re.sub(r"^[ .]|[/<>:\"\\|?*]+|[ .]$", "-", text)


def tokenize(text: str) -> typing.List[str]:
    """Split ``text`` into tokens with NLTK's ``TreebankWordTokenizer``.

    Before tokenizing, every newline is replaced by the placeholder word
    ``NEWLINE`` and every apostrophe by a space-padded backtick. After
    tokenizing, each token has ``NEWLINE`` turned back into a newline and
    every backtick turned into an apostrophe.
    """
    # Swap in placeholders for newlines and apostrophes.
    # NOTE(review): presumably so the tokenizer keeps them as standalone
    # tokens -- confirm against the tokenizer's behaviour.
    protected = text.replace("\n", " NEWLINE ").replace("'", " ` ")
    splitter = nltk.tokenize.treebank.TreebankWordTokenizer()
    restored = []
    for piece in splitter.tokenize(protected):
        # Undo the placeholder substitution token by token.
        restored.append(piece.replace("NEWLINE", "\n").replace("`", "'"))
    return restored


def normalize(weights: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Return ``exp(weights) / sum(exp(weights))`` (a softmax over all entries).

    When the largest entry is finite it is subtracted before exponentiating.
    This leaves the result mathematically unchanged but keeps ``np.exp`` from
    overflowing to ``inf`` (which would turn the quotient into ``nan``) when
    entries are large.

    Args:
        weights: Values to exponentiate and normalize.

    Returns:
        An array with the same shape as ``weights``. For finite input the
        entries sum to 1. An empty input yields an empty array.
    """
    values = np.asarray(weights)
    if values.size == 0:
        # np.max raises on empty input; exp of an empty array is still empty.
        return np.exp(values)
    peak = np.max(values)
    if np.isfinite(peak):
        # Shift only by a finite maximum so inputs containing inf or nan
        # produce the same result as the unshifted formula.
        values = values - peak
    exps = np.exp(values)
    return exps / np.sum(exps)

Expand Down

0 comments on commit 2ce5e8b

Please sign in to comment.