Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/major refactor #12

Merged
merged 5 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion nada_ai/__init__.py

This file was deleted.

2 changes: 2 additions & 0 deletions nada_ai/client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .model_client import ModelClient
from .clients import *
4 changes: 4 additions & 0 deletions nada_ai/client/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .sklearn import SklearnClient
from .torch import TorchClient

__all__ = ["SklearnClient", "TorchClient"]
29 changes: 29 additions & 0 deletions nada_ai/client/clients/sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Scikit-learn client implementation"""

import sklearn
from nada_ai.client.model_client import ModelClient
from nada_ai.typing import LinearModel

__all__ = ["SklearnClient"]


class SklearnClient(ModelClient):
"""ModelClient for Scikit-learn models"""

def __init__(self, model: sklearn.base.BaseEstimator) -> None:
"""
Client initialization.

Args:
model (sklearn.base.BaseEstimator): Sklearn model object to wrap around.
"""
if isinstance(model, LinearModel):
state_dict = {"coef": model.coef_}
if model.fit_intercept is True:
state_dict.update({"intercept": model.intercept_})
else:
raise NotImplementedError(
f"Instantiating ModelClient from Sklearn model type `{type(model).__name__}` is not yet implemented."
)

self.state_dict = state_dict
19 changes: 19 additions & 0 deletions nada_ai/client/clients/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""PyTorch client implementation"""

from torch import nn
from nada_ai.client.model_client import ModelClient

__all__ = ["TorchClient"]


class TorchClient(ModelClient):
"""ModelClient for PyTorch models"""

def __init__(self, model: nn.Module) -> None:
"""
Client initialization.

Args:
model (nn.Module): PyTorch model object to wrap around.
"""
self.state_dict = model.state_dict()
86 changes: 9 additions & 77 deletions nada_ai/client.py → nada_ai/client/model_client.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,13 @@
"""
This module provides functions to work with the Python Nillion Client
"""
"""Model client implementation"""

from abc import ABC, ABCMeta
import nada_algebra as na
import nada_algebra.client as na_client
from typing import Any, Dict, Sequence, Union
from typing import Any, Dict, Sequence
from nada_ai.typing import NillionType

from sklearn.linear_model import (
LinearRegression,
LogisticRegression,
LogisticRegressionCV,
)
import torch
from torch import nn
import sklearn
import numpy as np
import py_nillion_client as nillion

_NillionType = Union[
na.Rational,
na.SecretRational,
nillion.SecretInteger,
nillion.SecretUnsignedInteger,
nillion.PublicVariableInteger,
nillion.PublicVariableUnsignedInteger,
]
_LinearModel = Union[LinearRegression, LogisticRegression, LogisticRegressionCV]


class ModelClientMeta(ABCMeta):
Expand All @@ -42,9 +23,9 @@ def __call__(self, *args, **kwargs) -> object:
Returns:
object: Result object.
"""
obj = super(ModelClientMeta, self).__call__(*args, **kwargs)
obj = super().__call__(*args, **kwargs)
if not getattr(obj, "state_dict"):
raise AttributeError("required attribute `state_dict` not set")
raise AttributeError("Required attribute `state_dict` not set")
return obj


Expand All @@ -54,21 +35,21 @@ class ModelClient(ABC, metaclass=ModelClientMeta):
def export_state_as_secrets(
self,
name: str,
nada_type: _NillionType,
) -> Dict[str, _NillionType]:
nada_type: NillionType,
) -> Dict[str, NillionType]:
"""
Exports model state as a Dict of Nillion secret types.

Args:
name (str): Name to be used to store state secrets in the network.
nada_type (_NillionType): Data type to convert weights to.
nada_type (NillionType): Data type to convert weights to.

Raises:
NotImplementedError: Raised when unsupported model state type is passed.
TypeError: Raised when model state has incompatible values.

Returns:
Dict[str, _NillionType]: Dict of Nillion secret types that represents model state.
Dict[str, NillionType]: Dict of Nillion secret types that represents model state.
"""
if nada_type not in (na.Rational, na.SecretRational):
raise NotImplementedError("Exporting non-rational state is not supported")
Expand Down Expand Up @@ -104,52 +85,3 @@ def __ensure_numpy(self, array_like: Any) -> np.ndarray:
raise TypeError(
"Could not convert type `%s` to NumPy array" % type(array_like).__name__
)


class StateClient(ModelClient):
"""ModelClient for generic model states"""

def __init__(self, state_dict: Dict[str, Any]) -> None:
"""
Client initialization.
This client accepts an arbitrary model state as input.

Args:
state_dict (Dict[str, Any]): State dict.
"""
self.state_dict = state_dict


class TorchClient(ModelClient):
"""ModelClient for PyTorch models"""

def __init__(self, model: nn.Module) -> None:
"""
Client initialization.

Args:
model (nn.Module): PyTorch model object to wrap around.
"""
self.state_dict = model.state_dict()


class SklearnClient(ModelClient):
"""ModelClient for Scikit-learn models"""

def __init__(self, model: sklearn.base.BaseEstimator) -> None:
"""
Client initialization.

Args:
model (sklearn.base.BaseEstimator): Sklearn model object to wrap around.
"""
if isinstance(model, _LinearModel):
state_dict = {"coef": model.coef_}
if model.fit_intercept is True:
state_dict.update({"intercept": model.intercept_})
else:
raise NotImplementedError(
f"Instantiating ModelClient from Sklearn model type `{type(model).__name__}` is not yet implemented."
)

self.state_dict = state_dict
2 changes: 2 additions & 0 deletions nada_ai/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Custom exceptions"""

__all__ = ["MismatchedShapesException"]


class MismatchedShapesException(Exception):
"""Raised when NadaArray shapes are incompatible"""
Expand Down
2 changes: 2 additions & 0 deletions nada_ai/linear_model/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from nada_ai.nn.module import Module
from nada_ai.nn.parameter import Parameter

__all__ = ["LinearRegression"]


class LinearRegression(Module):
"""Linear regression implementation"""
Expand Down
3 changes: 1 addition & 2 deletions nada_ai/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .module import Module
from .parameter import Parameter
from .layers import *
from .activations import *
from .modules import *
Loading
Loading