Welcome to probaforms

License: MIT

Probaforms is a Python library of conditional Generative Adversarial Networks, Normalizing Flows, Variational Autoencoders and other generative models for tabular data. All models share a scikit-learn-like interface (fit, then sample), so they can be used quickly in a variety of science and engineering applications.
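To make the "fit, then sample" pattern concrete, here is a minimal sketch of a conditional generative model with that interface. The class `ConditionalGaussianBaseline` is hypothetical and is not part of probaforms: it assumes X given C is Gaussian, with a mean that is linear in C and one covariance shared by all conditions. Only the `fit(X, C)` / `sample(C)` shape of the interface mirrors the usage example in this README.

```python
import numpy as np


class ConditionalGaussianBaseline:
    """Hypothetical toy model (not part of probaforms).

    Models X | C as a Gaussian whose mean is a linear function of C and
    whose covariance is shared across all conditions.
    """

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit(self, X, C):
        X = np.asarray(X, dtype=float)
        C = np.asarray(C, dtype=float)
        # Append a bias column and solve the least-squares problem X ~ [C, 1] @ W
        C1 = np.hstack([C, np.ones((len(C), 1))])
        self.W_, *_ = np.linalg.lstsq(C1, X, rcond=None)
        residuals = X - C1 @ self.W_
        # Shared residual covariance, kept 2-D even when X has one column
        self.cov_ = np.atleast_2d(np.cov(residuals, rowvar=False))
        self.rng_ = np.random.default_rng(self.random_state)
        return self  # returning self allows fit(...).sample(...) chaining

    def sample(self, C):
        C = np.asarray(C, dtype=float)
        C1 = np.hstack([C, np.ones((len(C), 1))])
        mean = C1 @ self.W_
        noise = self.rng_.multivariate_normal(
            np.zeros(self.cov_.shape[0]), self.cov_, size=len(C)
        )
        return mean + noise  # one generated row per condition row


# Usage: binary conditions shift the mean of 2-D data by 3
rng = np.random.default_rng(0)
C = rng.integers(0, 2, size=(500, 1))
X = 3.0 * C + rng.normal(size=(500, 2))
X_gen = ConditionalGaussianBaseline(random_state=0).fit(X, C).sample(C)
```

A baseline like this cannot capture non-Gaussian shapes such as the two moons used later in this README; that limitation is what the neural models listed below address.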

Implemented conditional models

| Model | Type | Paper |
|---|---|---|
| ConditionalNormal | MDN | Bishop CM. Mixture density networks. 1994. |
| CVAE | VAE | Kingma DP, Welling M. Auto-encoding variational Bayes. arXiv:1312.6114. ICLR 2014. |
| ConditionalWGAN | GAN | Arjovsky M, Chintala S, Bottou L. Wasserstein generative adversarial networks. arXiv:1701.07875. ICML 2017. |
| RealNVP | Normalizing Flow | Dinh L, Sohl-Dickstein J, Bengio S. Density estimation using Real NVP. arXiv:1605.08803. ICLR 2017. |
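The building block of RealNVP (Dinh et al.) is the affine coupling layer: half of the input passes through unchanged, and the other half is scaled and shifted by functions of the first half, so the map is exactly invertible and its log-determinant is just the sum of the log-scales. The NumPy sketch below illustrates this for a single conditional layer. It is not probaforms' implementation; the random linear maps `W_s` and `W_t` stand in for trained scale and translation networks, and both the tanh bound on the scale and the conditioning by concatenating `c` to the untouched half are assumptions of this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
D, D_COND, D_HALF = 2, 1, 1  # data dim, condition dim, size of the untouched half

# Stand-ins for the scale s(.) and translation t(.) networks: fixed random
# linear maps of [x1, c]; tanh keeps the log-scale bounded.
W_s = rng.normal(size=(D_HALF + D_COND, D - D_HALF))
W_t = rng.normal(size=(D_HALF + D_COND, D - D_HALF))


def s_net(h):
    return np.tanh(h @ W_s)


def t_net(h):
    return h @ W_t


def coupling_forward(x, c):
    """Map x -> z and return log|det dz/dx| for each sample."""
    x1, x2 = x[:, :D_HALF], x[:, D_HALF:]
    h = np.hstack([x1, c])
    s, t = s_net(h), t_net(h)
    z2 = x2 * np.exp(s) + t           # affine transform of the second half
    return np.hstack([x1, z2]), s.sum(axis=1)  # Jacobian is triangular


def coupling_inverse(z, c):
    """Exact inverse: z1 equals x1, so s and t can be recomputed from it."""
    z1, z2 = z[:, :D_HALF], z[:, D_HALF:]
    h = np.hstack([z1, c])
    x2 = (z2 - t_net(h)) * np.exp(-s_net(h))
    return np.hstack([z1, x2])


x = rng.normal(size=(5, D))
c = rng.integers(0, 2, size=(5, D_COND)).astype(float)
z, log_det = coupling_forward(x, c)
x_rec = coupling_inverse(z, c)
```

A full flow stacks several such layers, alternating which half is left unchanged, and is trained by maximizing the log-density of `z` under a standard normal plus the accumulated `log_det`; sampling runs the inverse maps on normal noise for the given conditions.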

Installation

pip install probaforms

or install from source:

git clone https://github.com/hse-cs/probaforms
cd probaforms
pip install -e .

or, with Poetry, from inside the cloned repository:

poetry install

Basic usage

(See more examples in the documentation.)

The following code snippet generates noisy synthetic data, fits a conditional generative model, samples new objects, and plots the results.

from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
from probaforms.models import RealNVP

# generate a sample X with conditions C (the moon labels as a column)
X, y = make_moons(n_samples=1000, noise=0.1)
C = y.reshape(-1, 1)

# fit a conditional normalizing flow model
model = RealNVP(lr=0.01, n_epochs=100)
model.fit(X, C)

# sample new objects
X_gen = model.sample(C)

# display the results
plt.scatter(X_gen[y==0, 0], X_gen[y==0, 1])
plt.scatter(X_gen[y==1, 0], X_gen[y==1, 1])
plt.show()
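The scatter plot is a visual check. A quick numeric complement is to compare per-condition means and standard deviations of the real and generated samples. The helper `per_condition_summary` below is hypothetical (it is not a probaforms function) and uses only NumPy; matching moments is necessary but not sufficient, so it does not replace looking at the shape of the generated moons. The demo uses synthetic arrays and a deliberately shifted "generator" so the result is predictable.

```python
import numpy as np


def per_condition_summary(X, X_gen, labels):
    """Compare real vs generated data separately for each condition label.

    Returns {label: (max abs mean difference, max abs std difference)},
    taking the maximum over feature columns.
    """
    summary = {}
    for label in np.unique(labels):
        mask = labels == label
        mean_gap = np.abs(X[mask].mean(axis=0) - X_gen[mask].mean(axis=0)).max()
        std_gap = np.abs(X[mask].std(axis=0) - X_gen[mask].std(axis=0)).max()
        summary[label.item()] = (mean_gap, std_gap)
    return summary


# Demo: a "generator" that is biased by +1 in every feature
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=1000)
X_real = rng.normal(size=(1000, 2)) + labels[:, None]
X_shifted = X_real + 1.0
print(per_condition_summary(X_real, X_shifted, labels))
```

With the snippet above, the same check is `per_condition_summary(X, X_gen, y)`.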

Support

Thanks to all our contributors