diff --git a/nada-project.toml b/nada-project.toml index 62069d2..cc3ae99 100644 --- a/nada-project.toml +++ b/nada-project.toml @@ -237,4 +237,9 @@ prime_size = 128 [[programs]] path = "src/shuffle.py" name = "shuffle" -prime_size = 128 \ No newline at end of file +prime_size = 128 + +[[programs]] +path = "src/shuffle_simple.py" +name = "shuffle_simple" +prime_size = 128 diff --git a/src/shuffle_simple.py b/src/shuffle_simple.py new file mode 100644 index 0000000..2c786d0 --- /dev/null +++ b/src/shuffle_simple.py @@ -0,0 +1,21 @@ +from nada_dsl import SecretInteger + +import nada_numpy as na +from nada_numpy import shuffle + + +def nada_main(): + + # Note: + # The current shuffle operation only supports vectors with + # a power-of-two size, e.g., 2, 4, 8, 16, 32, ... + size=4 + + parties = na.parties(2) + nums = na.array([size], parties[0], "num", SecretInteger) + + shuffled_nums = shuffle(nums) + + return ( + na.output(shuffled_nums, parties[1], "shuffled_num") + ) diff --git a/tests/shuffle_simple_test.py b/tests/shuffle_simple_test.py new file mode 100644 index 0000000..ec9713e --- /dev/null +++ b/tests/shuffle_simple_test.py @@ -0,0 +1,45 @@ +from nada_test import nada_test + +inputs = {"num_0": 10, "num_1": 20, "num_2": 30, "num_3": 40} + +# Test that the shuffled array contains the same values as the input, regardless of order +@nada_test(program="shuffle_simple") +def shuffle_simple_test_same_values(): + outputs = yield inputs + + shuffled_nums = [ + outputs["shuffled_num_0"], + outputs["shuffled_num_1"], + outputs["shuffled_num_2"], + outputs["shuffled_num_3"] + ] + + # Assert that the sorted output contains the same values as the sorted input + assert sorted(shuffled_nums) == sorted([ + inputs["num_0"], + inputs["num_1"], + inputs["num_2"], + inputs["num_3"] + ]), "Test failed: the shuffled array contains different values." + +# Test that the resulting shuffled array is not in the same order as the input +@nada_test(program="shuffle_simple") +def shuffle_simple_test_not_same_order(): + outputs = yield inputs + + original_nums = [ + inputs["num_0"], + inputs["num_1"], + inputs["num_2"], + inputs["num_3"] + ] + + shuffled_nums = [ + outputs["shuffled_num_0"], + outputs["shuffled_num_1"], + outputs["shuffled_num_2"], + outputs["shuffled_num_3"] + ] + + # Assert that the shuffled numbers are NOT in the same order as the input + assert shuffled_nums != original_nums, "Test failed: the order did not change" diff --git a/tests/shuffle_simple_test.yaml b/tests/shuffle_simple_test.yaml new file mode 100644 index 0000000..214f43b --- /dev/null +++ b/tests/shuffle_simple_test.yaml @@ -0,0 +1,12 @@ +--- +program: shuffle_simple +inputs: + num_0: 3 + num_1: 3 + num_2: 3 + num_3: 3 +expected_outputs: + shuffled_num_0: 3 + shuffled_num_1: 3 + shuffled_num_2: 3 + shuffled_num_3: 3