-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/prophet #11
Merged
Merged
Feature/prophet #11
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
3f05e19
Restructure models & add prophet boilerplate
mathias-nillion 0217be8
Baseline prophet skeleton
mathias-nillion eb99afa
Merge branch 'main' into feature/prophet
mathias-nillion 65b433e
Implement prophet model
mathias-nillion e76f593
Update module & parameter logic
mathias-nillion bc5eee6
Update tests
mathias-nillion 688e0b0
Update poetry
mathias-nillion a771eed
Update examples
mathias-nillion eee449a
Update tests
mathias-nillion 50e0853
Fix prophet bug
mathias-nillion 8e43b30
Fix type checking
mathias-nillion 1ef4145
Refactoring
mathias-nillion 5b73d51
Fix multiplicative seasonality
mathias-nillion 56add95
fix flat trend
mathias-nillion 83d7206
Variable naming
mathias-nillion 6726a7a
Refactor fourier
mathias-nillion efab681
Fix formatting
mathias-nillion c60da03
Fix formatting
mathias-nillion 7610e74
Fix formatting
mathias-nillion 868c9d4
Fix formatting
mathias-nillion 8cc6489
Bump nada-algebra version
mathias-nillion 3dca259
Merge main
mathias-nillion d7c205b
Update poetry
mathias-nillion 4a3ce9e
Update examples
mathias-nillion b1b64f0
Update tests
mathias-nillion 56ca1b8
Fix compatibility with legacy python versions
mathias-nillion 1911508
Fix example tests
mathias-nillion 074cde1
Fix rescaling
mathias-nillion 3830420
Bump nada-algebra version
mathias-nillion 295a3ce
Update poetry
mathias-nillion File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,5 +3,5 @@ version = "0.1.0" | |
authors = [""] | ||
|
||
[[programs]] | ||
path = "src/main.py" | ||
path = "src/complex_model.py" | ||
prime_size = 128 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# This directory is kept purposely, so that no compilation errors arise. | ||
# Ignore everything in this directory | ||
* | ||
# Except this file | ||
!.gitignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# This directory is kept purposely, so that no compilation errors arise. | ||
# Ignore everything in this directory | ||
* | ||
# Except this file | ||
!.gitignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "complex_model" | ||
name = "neural_net" | ||
version = "0.1.0" | ||
authors = [""] | ||
|
||
[[programs]] | ||
path = "src/main.py" | ||
path = "src/neural_net.py" | ||
prime_size = 128 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# This directory is kept purposely, so that no compilation errors arise. | ||
# Ignore everything in this directory | ||
* | ||
# Except this file | ||
!.gitignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
name = "time_series" | ||
version = "0.1.0" | ||
authors = [""] | ||
|
||
[[programs]] | ||
path = "src/time_series.py" | ||
prime_size = 128 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import asyncio | ||
import py_nillion_client as nillion | ||
import os | ||
import sys | ||
import time | ||
import numpy as np | ||
import nada_algebra as na | ||
import pandas as pd | ||
from nada_ai.client import ProphetClient | ||
from prophet import Prophet | ||
from dotenv import load_dotenv | ||
|
||
# Add the parent directory to the system path to import modules from it | ||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | ||
|
||
# Import helper functions for creating nillion client and getting keys | ||
from neural_net.network.helpers.nillion_client_helper import create_nillion_client | ||
from neural_net.network.helpers.nillion_keypath_helper import ( | ||
getUserKeyFromFile, | ||
getNodeKeyFromFile, | ||
) | ||
import nada_algebra.client as na_client | ||
|
||
# Load environment variables from a .env file | ||
load_dotenv() | ||
|
||
|
||
# Decorator function to measure and log the execution time of asynchronous functions | ||
def async_timer(file_path): | ||
def decorator(func): | ||
async def wrapper(*args, **kwargs): | ||
start_time = time.time() | ||
result = await func(*args, **kwargs) | ||
end_time = time.time() | ||
elapsed_time = end_time - start_time | ||
|
||
# Log the execution time to a file | ||
with open(file_path, "a") as file: | ||
file.write(f"{elapsed_time:.6f},\n") | ||
return result | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
# Asynchronous function to store a program on the nillion client | ||
@async_timer("bench/store_program.txt") | ||
async def store_program(client, user_id, cluster_id, program_name, program_mir_path): | ||
action_id = await client.store_program(cluster_id, program_name, program_mir_path) | ||
program_id = f"{user_id}/{program_name}" | ||
print("Stored program. action_id:", action_id) | ||
print("Stored program_id:", program_id) | ||
return program_id | ||
|
||
|
||
# Asynchronous function to store secrets on the nillion client | ||
@async_timer("bench/store_secrets.txt") | ||
async def store_secrets(client, cluster_id, program_id, party_id, party_name, secrets): | ||
secret_bindings = nillion.ProgramBindings(program_id) | ||
secret_bindings.add_input_party(party_name, party_id) | ||
|
||
# Store the secret for the specified party | ||
store_id = await client.store_secrets(cluster_id, secret_bindings, secrets, None) | ||
return store_id | ||
|
||
|
||
# Asynchronous function to perform computation on the nillion client | ||
@async_timer("bench/compute.txt") | ||
async def compute( | ||
client, cluster_id, compute_bindings, store_ids, computation_time_secrets | ||
): | ||
compute_id = await client.compute( | ||
cluster_id, | ||
compute_bindings, | ||
store_ids, | ||
computation_time_secrets, | ||
nillion.PublicVariables({}), | ||
) | ||
|
||
# Monitor and print the computation result | ||
print(f"The computation was sent to the network. compute_id: {compute_id}") | ||
while True: | ||
compute_event = await client.next_compute_event() | ||
if isinstance(compute_event, nillion.ComputeFinishedEvent): | ||
print(f"✅ Compute complete for compute_id {compute_event.uuid}") | ||
return compute_event.result.value | ||
|
||
|
||
# Main asynchronous function to coordinate the process | ||
async def main(): | ||
cluster_id = os.getenv("NILLION_CLUSTER_ID") | ||
userkey = getUserKeyFromFile(os.getenv("NILLION_USERKEY_PATH_PARTY_1")) | ||
nodekey = getNodeKeyFromFile(os.getenv("NILLION_NODEKEY_PATH_PARTY_1")) | ||
client = create_nillion_client(userkey, nodekey) | ||
party_id = client.party_id | ||
user_id = client.user_id | ||
party_names = na_client.parties(2) | ||
program_name = "main" | ||
program_mir_path = f"./target/{program_name}.nada.bin" | ||
|
||
if not os.path.exists("bench"): | ||
os.mkdir("bench") | ||
|
||
na.set_log_scale(50) | ||
|
||
# Store the program | ||
program_id = await store_program( | ||
client, user_id, cluster_id, program_name, program_mir_path | ||
) | ||
|
||
# Train prophet model | ||
model = Prophet() | ||
|
||
ds = pd.date_range("2024-05-01", "2024-05-17").tolist() | ||
y = np.arange(1, 18).tolist() | ||
|
||
fit_model = model.fit(df=pd.DataFrame({"ds": ds, "y": y})) | ||
|
||
print("Model params are:", fit_model.params) | ||
print("Number of detected changepoints:", fit_model.n_changepoints) | ||
|
||
# Create and store model secrets via ModelClient | ||
model_client = ProphetClient(fit_model) | ||
model_secrets = nillion.Secrets( | ||
model_client.export_state_as_secrets("my_prophet", na.SecretRational) | ||
) | ||
|
||
model_store_id = await store_secrets( | ||
client, cluster_id, program_id, party_id, party_names[0], model_secrets | ||
) | ||
|
||
# Store inputs to perform inference for | ||
future_df = fit_model.make_future_dataframe(periods=3) | ||
inference_ds = fit_model.setup_dataframe(future_df.copy()) | ||
|
||
my_input = {} | ||
my_input.update( | ||
na_client.array(inference_ds["floor"].to_numpy(), "floor", na.SecretRational) | ||
) | ||
my_input.update( | ||
na_client.array(inference_ds["t"].to_numpy(), "t", na.SecretRational) | ||
) | ||
|
||
input_secrets = nillion.Secrets(my_input) | ||
|
||
data_store_id = await store_secrets( | ||
client, cluster_id, program_id, party_id, party_names[1], input_secrets | ||
) | ||
|
||
# Set up the compute bindings for the parties | ||
compute_bindings = nillion.ProgramBindings(program_id) | ||
[ | ||
compute_bindings.add_input_party(party_name, party_id) | ||
for party_name in party_names | ||
] | ||
compute_bindings.add_output_party(party_names[1], party_id) | ||
|
||
print(f"Computing using program {program_id}") | ||
print(f"Use secret store_id: {model_store_id} {data_store_id}") | ||
|
||
# Perform the computation and return the result | ||
result = await compute( | ||
client, | ||
cluster_id, | ||
compute_bindings, | ||
[model_store_id, data_store_id], | ||
nillion.Secrets({}), | ||
) | ||
|
||
# Sort & rescale the obtained results by the quantization scale | ||
outputs = [ | ||
na_client.float_from_rational(result[1]) | ||
for result in sorted( | ||
result.items(), | ||
key=lambda x: int(x[0].replace("my_output", "").replace("_", "")), | ||
) | ||
] | ||
|
||
print(f"🖥️ The result is {outputs}") | ||
|
||
expected = fit_model.predict(inference_ds)["yhat"].to_numpy() | ||
print(f"🖥️ VS expected plain-text result {expected}") | ||
return result | ||
|
||
|
||
# Run the main function if the script is executed directly | ||
if __name__ == "__main__": | ||
asyncio.run(main()) |
12 changes: 12 additions & 0 deletions
12
examples/time_series/network/helpers/nillion_client_helper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import os | ||
import py_nillion_client as nillion | ||
from helpers.nillion_payments_helper import create_payments_config | ||
|
||
|
||
def create_nillion_client(userkey, nodekey): | ||
bootnodes = [os.getenv("NILLION_BOOTNODE_MULTIADDRESS")] | ||
payments_config = create_payments_config() | ||
|
||
return nillion.NillionClient( | ||
nodekey, bootnodes, nillion.ConnectionMode.relay(), userkey, payments_config | ||
) |
10 changes: 10 additions & 0 deletions
10
examples/time_series/network/helpers/nillion_keypath_helper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import os | ||
import py_nillion_client as nillion | ||
|
||
|
||
def getUserKeyFromFile(userkey_filepath): | ||
return nillion.UserKey.from_file(userkey_filepath) | ||
|
||
|
||
def getNodeKeyFromFile(nodekey_filepath): | ||
return nillion.NodeKey.from_file(nodekey_filepath) |
12 changes: 12 additions & 0 deletions
12
examples/time_series/network/helpers/nillion_payments_helper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import os | ||
import py_nillion_client as nillion | ||
|
||
|
||
def create_payments_config(): | ||
return nillion.PaymentsConfig( | ||
os.getenv("NILLION_BLOCKCHAIN_RPC_ENDPOINT"), | ||
os.getenv("NILLION_WALLET_PRIVATE_KEY"), | ||
int(os.getenv("NILLION_CHAIN_ID")), | ||
os.getenv("NILLION_PAYMENTS_SC_ADDRESS"), | ||
os.getenv("NILLION_BLINDING_FACTORS_MANAGER_SC_ADDRESS"), | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jcabrero this is indeed stupidly high - reason is that some prophet parameters are extremely small (ie 1e-13) which means that they get rounded to zero - which is not yet supported.
instead of going into
export_state
and patching this (which will be unnecessary once zero secrets are allowed), I opted to put this here temporarily