Commit 8d67540

Add current version from GitLab (#1)
* add lib files
* add notebooks
* remove scripts dir
* extend gitignore
* add tests
* Update pyproject.toml
* update changelog
* Update README.md
* add short documentation to readme
* ignore unused imports in init files
* install all required dependencies for tests
* update imports formatting
* fix imports in notebooks
* ignore E402 errors
* fix dependencies
* decrease required precision in test
* fix doctest
* update line separators to LF
* update date in changelog
1 parent fa7d9c2 commit 8d67540

File tree: 65 files changed (+12129, −13 lines)


.ci/test.sh

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 #!/usr/bin/env bash
 set -ex
 pip install poetry
-poetry install
+poetry install --all-extras --with test
 poetry run pytest

.flake8

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,11 @@ max-line-length = 120
 # W503: we prefer line breaks _before_ operators (as changed in PEP8 in 2016).
 # E203: whitespace before : , black is right here: https://github.com/psf/black/issues/315
 ignore = W503,E203
+# Ignore `F401` (unused imports) in all `__init__.py` files.
+# Ignore `E402` (import not at top of file) in all notebooks. `# flake8-noqa-cell-E402` doesn't work.
+per-file-ignores =
+    __init__.py: F401
+    notebooks/*: E402
 show-source = True
 statistics = True
 exclude =

.gitignore

Lines changed: 7 additions & 0 deletions
@@ -7,3 +7,10 @@
 .vscode
 venv
 .venv
+tmp
+/models
+lightning_logs
+
+# generated package files
+mim_nlp.egg-info
+/dist

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
@@ -1,4 +1,18 @@
 # Changelog
 
+## 0.2.0 April 9, 2024
+* Moved files from GitLab project.
+* Classification
+  * Neural Network
+  * SVM
+* Regression
+  * Neural Network
+* Seq2Seq
+  * Summarization
+* Preprocessing
+  * Text cleaning
+  * Lemmatization
+  * Deduplication
+
 ## 0.1.0 April 2, 2024
 * Project created.

README.md

Lines changed: 58 additions & 6 deletions
@@ -1,25 +1,73 @@
 # MIM NLP
+With this package you can easily use pre-trained models and fine-tune them,
+as well as create and train your own neural networks.
 
-## Project goal
+Below, we list the NLP tasks and models that are available:
+* Classification
+  * Neural Network
+  * SVM
+* Regression
+  * Neural Network
+* Seq2Seq
+  * Summarization (Neural Network)
+
+It comes with utilities for text pre-processing such as:
+* Text cleaning
+* Lemmatization
+* Deduplication
 
 ## Installation
 
+### TODO PyPI package
+The package comes with the following extras (optional dependencies for given modules):
+- `svm` - simple SVM model for classification
+- `classifier` - classification models: SVM, neural networks
+- `regressor` - regression models
+- `preprocessing` - text cleaning, lemmatization and deduplication
+- `seq2seq` - `Seq2Seq` and `Summarizer` models
+
 ## Usage
+Examples can be found in the [notebooks directory](/notebooks).
 
-## Development
+### Model classes
+* `classifier.nn.NNClassifier` - Neural Network Classifier
+* `classifier.svm.SVMClassifier` - Support Vector Machine Classifier
+* `classifier.svm.SVMClassifierWithFeatureSelection` - `SVMClassifier` with an additional feature-selection step
+* `regressor.AutoRegressor` - regressor based on transformers' Auto Classes
+* `regressor.NNRegressor` - Neural Network Regressor
+* `seq2seq.AutoSummarizer` - summarizer based on transformers' Auto Classes
+
+### Interface
+All the model classes have a common interface:
+* `fit`
+* `predict`
+* `save`
+* `load`
 
+and specific additional methods.
+
+### Text pre-processing
+* `preprocessing.TextCleaner` - define a pipeline for text cleaning, supports concurrent processing
+* `preprocessing.lemmatize` - lemmatize text in Polish with Morfeusz
+* `preprocessing.Deduplicator` - find near-duplicate texts (depending on `threshold`) with the Jaccard index for n-grams
+
+## Development
 Remember to use a separate environment for each project.
 Run commands below inside the project's environment.
 
 ### Dependencies
-
 We use `poetry` for dependency management.
 If you have never used it, consult
 [poetry documentation](https://python-poetry.org/docs/)
 for installation guidelines and basic usage instructions.
 
 ```sh
-poetry install
+poetry install --with dev
+```
+
+To fix the `Failed to unlock the collection!` error or stuck package installation, execute the command below:
+```sh
+export PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
 ```
 
 ### Git hooks
@@ -37,13 +85,11 @@ Fails if any changes are made, so you have to run `git add` and `git commit` onc
 * _Strip notebooks_ – produces _stripped_ versions of notebooks in `stripped` directory.
 
 ### Tests
-
 ```sh
 pytest
 ```
 
 ### Linting
-
 We use `isort` and `flake8` along with `nbqa` to ensure code quality.
 The appropriate options are set in configuration files.
 You can run them with:
@@ -62,3 +108,9 @@ You can run black to format code (including notebooks):
 ```sh
 black .
 ```
+
+### New version release
+To release the next version of the package to PyPI, follow these steps:
+- First, increment the package version in `pyproject.toml`.
+- Then build the new version: run `poetry build` in the root directory.
+- Finally, upload to PyPI: `poetry publish` (uploads the two newly created files).
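
For illustration, a minimal sketch of the common `fit`/`predict`/`save`/`load` interface described in the README above. The constructor arguments of `SVMClassifier`, the path argument to `save`/`load`, and `load` being a class method are assumptions not shown in this commit:

```python
# Minimal sketch of the common fit/predict/save/load interface.
# Assumptions (not shown in this commit): SVMClassifier constructs with
# defaults, save/load take a path, and load is a class method.
from mim_nlp.classifier.svm import SVMClassifier

x_train = ["I love this movie", "Terrible film", "Great acting", "Awful plot"]
y_train = [1, 0, 1, 0]

model = SVMClassifier()                       # assumed default construction
model.fit(x_train, y_train)                   # train on raw texts and labels
print(model.predict(["What a great story"]))  # predicted labels

model.save("models/svm_sentiment")                  # persist (assumed path arg)
model = SVMClassifier.load("models/svm_sentiment")  # restore (assumed class method)
```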
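
The `preprocessing.Deduplicator` entry above relies on the Jaccard index over n-grams. The standalone sketch below shows only the underlying metric, using character trigrams; the library's actual shingling and API may differ:

```python
def ngrams(text: str, n: int = 3) -> set[str]:
    """Set of character n-grams of `text`."""
    return {text[i : i + n] for i in range(len(text) - n + 1)}


def jaccard(a: str, b: str, n: int = 3) -> float:
    """Jaccard index |A intersect B| / |A union B| of two texts' n-gram sets."""
    x, y = ngrams(a, n), ngrams(b, n)
    return len(x & y) / len(x | y) if x | y else 1.0


# Near-duplicates score close to 1.0, unrelated texts close to 0.0;
# a Deduplicator-style `threshold` would cut between the two.
print(jaccard("the quick brown fox", "the quick brown foxes"))  # high
print(jaccard("the quick brown fox", "lorem ipsum dolor sit"))  # low
```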

mim_nlp/classifier/nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .nn_classifier import NNClassifier

mim_nlp/classifier/nn/nn_classifier.py

Lines changed: 172 additions & 0 deletions

@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch.nn as nn
+from numpy._typing import NDArray, _ArrayLikeInt_co, _ArrayLikeStr_co
+from pytorch_lightning.callbacks import Callback
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.pipeline import Pipeline
+from torch import Tensor
+from torch.nn.modules.loss import _Loss
+from torch.optim import Adam, Optimizer
+from torchmetrics import Metric
+from transformers import PreTrainedTokenizerBase
+
+from mim_nlp.models import Classifier
+from mim_nlp.neural_network import NNModelMixin
+
+
+class NNClassifier(NNModelMixin, Classifier):
+    """Neural Network Classifier
+
+    The `input_size` parameter denotes the length of a tokenized text.
+    This should be equal to the size of the input layer in the neural network.
+    In the case of using TF-IDF, the output size is constant and equal to the size of the vocabulary,
+    so the `input_size` has to be set accordingly.
+    When transformers' tokenizer is used,
+    a tokenized text is padded or truncated to a constant size equal to the `input_size`.
+
+    The neural network should omit the activation function and return logits.
+    Take that into consideration when choosing the loss function!
+    We use Sigmoid / Softmax internally to get predictions.
+
+    The `loss_function` is by default set to BCEWithLogitsLoss,
+    which combines a Sigmoid layer and the BCELoss in one single class.
+    For multiclass classification, use Cross Entropy Loss. Both losses take logits, as stated above.
+
+    Callables in `metrics_dict` take predictions (as probabilities) and targets, in that order! Callables can't be
+    lambda functions because they are not pickleable and it would cause problems with saving the model.
+
+    Tips:
+    - Change every lambda function to a function.
+    - Set every argument in the function via `functools.partial`.
+
+    Example:
+        >>> def accuracy_binary(y_pred, y_target):
+        ...     y_pred = y_pred > 0.5
+        ...     return torch.sum(y_target == y_pred) / len(y_target)
+
+    The `device` parameter can have the following values:
+    - `"cpu"` - The model will be loaded on the CPU.
+    - `"cuda"` - The model will be loaded on a single GPU.
+    - `"cuda:i"` - The model will be loaded on the specific GPU with the index `i`.
+
+    It is also possible to use multiple GPUs. To do this:
+    - Set `device` to `"cuda"`.
+    - Set `many_gpus` to `True`.
+    - As default, it will use all of them.
+
+    To use only selected GPUs - set the environmental variable `CUDA_VISIBLE_DEVICES`.
+    """
+
+    def __init__(
+        self,
+        batch_size: int,
+        epochs: int,
+        input_size: int,
+        tokenizer: Optional[Union[PreTrainedTokenizerBase, Pipeline, TfidfVectorizer]],
+        neural_network: nn.Module,
+        loss_function: Union[_Loss, Callable[[Any, Any], Any]] = nn.BCEWithLogitsLoss(),
+        optimizer: type[Optimizer] = Adam,
+        optimizer_params: Optional[dict[str, Any]] = None,
+        train_metrics_dict: Optional[dict[str, Union[Metric, Callable[[Tensor, Tensor], Any]]]] = None,
+        eval_metrics_dict: Optional[dict[str, Union[Metric, Callable[[Tensor, Tensor], Any]]]] = None,
+        callbacks: Optional[Union[Callback, list[Callback]]] = None,
+        device: str = "cuda:0",
+        many_gpus: bool = False,
+    ):
+        super().__init__(
+            batch_size=batch_size,
+            epochs=epochs,
+            input_size=input_size,
+            tokenizer=tokenizer,
+            neural_network=neural_network,
+            loss_function=loss_function,
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            train_metrics_dict=train_metrics_dict,
+            eval_metrics_dict=eval_metrics_dict,
+            callbacks=callbacks,
+            device=device,
+            many_gpus=many_gpus,
+        )
+
+    def fit(self, x_train: _ArrayLikeStr_co, y_train: _ArrayLikeInt_co, fit_tokenizer: bool = False) -> None:
+        """For multiclass classifications `y_train` labels should be encoded as categorical, i.e. integers."""
+        is_multiclass = False
+        # check if multiclass
+        if any(y >= 2 for y in y_train):
+            y_train = Tensor(y_train).long()
+            is_multiclass = True
+        else:
+            y_train = Tensor(y_train).float()
+        super()._fit(
+            x_train,
+            y_train,
+            x_eval=None,
+            y_eval=None,
+            fit_tokenizer=fit_tokenizer,
+            is_classification=True,
+            is_multiclass=is_multiclass,
+        )
+
+    def fit_eval(
+        self,
+        x_train: _ArrayLikeStr_co,
+        y_train: _ArrayLikeInt_co,
+        x_eval: _ArrayLikeStr_co,
+        y_eval: _ArrayLikeInt_co,
+        fit_tokenizer: bool = False,
+    ) -> None:
+        """For multiclass classifications `y` labels should be encoded as categorical, i.e. integers."""
+        is_multiclass = False
+        # check if multiclass
+        if any(y >= 2 for y in y_train):
+            y_train = Tensor(y_train).long()
+            y_eval = Tensor(y_eval).long()
+            is_multiclass = True
+        else:
+            y_train = Tensor(y_train).float()
+            y_eval = Tensor(y_eval).float()
+        super()._fit(
+            x_train,
+            y_train,
+            x_eval,
+            y_eval,
+            fit_tokenizer=fit_tokenizer,
+            is_classification=True,
+            is_multiclass=is_multiclass,
+        )
+
+    def fit_tokenizer(self, x_train: _ArrayLikeStr_co, y_train: Optional[_ArrayLikeInt_co] = None) -> None:
+        super().fit_tokenizer(x_train, y_train)
+
+    def predict(
+        self, x: _ArrayLikeStr_co, batch_size: Optional[int] = None, score_threshold: float = 0.5
+    ) -> NDArray[np.int64]:
+        predictions = self._get_predictions(x, batch_size)
+        if predictions.shape[1] > 1:
+            # multiclass classification
+            return np.array(np.argmax(predictions, axis=1), dtype=np.int64)
+        return np.array(predictions.flatten() > score_threshold, dtype=np.int64)
+
+    def predict_scores(self, x: _ArrayLikeStr_co, batch_size: Optional[int] = None) -> NDArray[np.float64]:
+        predictions = self._get_predictions(x, batch_size)
+        if predictions.shape[1] == 1:
+            predictions = predictions.flatten()
+        return np.array(predictions, dtype=np.float64)
+
+    def test(
+        self,
+        x: _ArrayLikeStr_co,
+        y_test: _ArrayLikeInt_co,
+        batch_size: Optional[int] = None,
+        test_metrics_dict: Optional[dict[str, Union[Metric, Callable[[Tensor, Tensor], Any]]]] = None,
+    ) -> dict[str, Any]:
+        if self.nn_module.is_multiclass:
+            y_test = Tensor(y_test).long()
+        else:
+            y_test = Tensor(y_test).float()
+        return super()._test(x, y_test, batch_size, test_metrics_dict)
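
For reference, a hedged usage sketch assembled from the signature and docstring above: the model returns logits, the default `BCEWithLogitsLoss` fits binary labels, and metrics are named functions bound via `functools.partial`. The `SimpleNN` module, the TF-IDF settings, and the toy data are illustrative assumptions, not part of this commit:

```python
# Usage sketch for NNClassifier; SimpleNN, the data, and hyperparameters are
# assumptions for illustration only.
from functools import partial

import torch
import torch.nn as nn
from sklearn.feature_extraction.text import TfidfVectorizer

from mim_nlp.classifier.nn import NNClassifier

# With TF-IDF, `input_size` must equal the fitted vocabulary size;
# the 4 distinct tokens below yield exactly 4 features.
VOCAB_SIZE = 4


class SimpleNN(nn.Module):
    """Tiny binary classifier; no final activation, it returns logits."""

    def __init__(self, input_size: int):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(input_size, 8), nn.ReLU(), nn.Linear(8, 1))

    def forward(self, x):
        return self.layers(x)


def accuracy_binary(y_pred, y_target, threshold=0.5):
    # Metrics receive probabilities and targets, in that order; a named
    # function (not a lambda) keeps the model pickleable for `save`.
    return torch.sum(y_target == (y_pred > threshold)) / len(y_target)


clf = NNClassifier(
    batch_size=2,
    epochs=2,
    input_size=VOCAB_SIZE,
    tokenizer=TfidfVectorizer(),
    neural_network=SimpleNN(VOCAB_SIZE),
    # default loss (BCEWithLogitsLoss) matches the binary labels below
    train_metrics_dict={"accuracy": partial(accuracy_binary, threshold=0.5)},
    device="cpu",
)
clf.fit(["good", "bad", "great", "awful"], [1, 0, 1, 0], fit_tokenizer=True)
print(clf.predict(["really good"]))         # hard 0/1 labels via 0.5 threshold
print(clf.predict_scores(["really good"]))  # probabilities
```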

mim_nlp/classifier/svm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .svm import SVMClassifier, SVMClassifierWithFeatureSelection
