From 12e678a9ce547c77f4a0efaa184e13f17da41cdc Mon Sep 17 00:00:00 2001 From: Berry Schoenmakers Date: Sun, 25 Feb 2024 12:53:59 +0100 Subject: [PATCH] PRSS details from asyncoro to runtime. --- mpyc/__init__.py | 2 +- mpyc/asyncoro.py | 24 +++++------------------- mpyc/mpctools.py | 7 ++++--- mpyc/runtime.py | 25 +++++++++++++++++++++++++ 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/mpyc/__init__.py b/mpyc/__init__.py index f8f421c..06ce6c1 100644 --- a/mpyc/__init__.py +++ b/mpyc/__init__.py @@ -29,7 +29,7 @@ and statistics (securely mimicking Python’s statistics module). """ -__version__ = '0.9.8' +__version__ = '0.9.9' __license__ = 'MIT License' import os diff --git a/mpyc/asyncoro.py b/mpyc/asyncoro.py index 9bfee6e..59b0f56 100644 --- a/mpyc/asyncoro.py +++ b/mpyc/asyncoro.py @@ -5,7 +5,6 @@ import sys import traceback from struct import pack, unpack_from -import itertools import functools import typing from asyncio import Protocol, Future, Task @@ -54,11 +53,7 @@ def connection_made(self, transport): rt = self.runtime pid_keys = [rt.pid.to_bytes(2, 'little')] # send pid if not rt.options.no_prss: - m = len(rt.parties) - t = rt.threshold - for subset in itertools.combinations(range(m), m - t): - if subset[0] == rt.pid and self.peer_pid in subset: - pid_keys.append(rt._prss_keys[subset]) # send PRSS keys + pid_keys.extend(rt._prss_keys_to_peer(self.peer_pid)) # send PRSS keys transport.writelines(pid_keys) self._key_transport_done() @@ -89,27 +84,18 @@ def data_received(self, data): return peer_pid = int.from_bytes(data[:2], 'little') - len_packet = 2 + del data[:2] rt = self.runtime if not rt.options.no_prss: - m = len(rt.parties) - t = rt.threshold - for subset in itertools.combinations(range(m), m - t): - if subset[0] == peer_pid and rt.pid in subset: - len_packet += 16 + len_packet = rt._prss_keys_from_peer(peer_pid) if len(data) < len_packet: return # record new protocol peer self.peer_pid = peer_pid if not rt.options.no_prss: - # store keys received from peer - len_packet = 2 - for subset in itertools.combinations(range(m), m - t): - if subset[0] == peer_pid and rt.pid in subset: - rt._prss_keys[subset] = data[len_packet:len_packet + 16] - len_packet += 16 - del data[:len_packet] + rt._prss_keys_from_peer(peer_pid, data) # store PRSS keys from peer + del data[:len_packet] self._key_transport_done() while len(data) >= 12: diff --git a/mpyc/mpctools.py b/mpyc/mpctools.py index 870e561..e8b4f93 100644 --- a/mpyc/mpctools.py +++ b/mpyc/mpctools.py @@ -12,7 +12,8 @@ runtime = None -_no_value = type('', (object,), {'__repr__': lambda self: ''}) +_no_value = type('mpyc.mpctools.NoValueType', (object,), {'__repr__': lambda self: ''})() +_no_value.__doc__ = 'Represents "empty" value, different from any other object including None.' def reduce(f, x, initial=_no_value): @@ -27,8 +28,8 @@ def reduce(f, x, initial=_no_value): may even be of different types. If initial is provided (possibly equal to None), it is placed before the - items of x (hence effectively serves as a default when x is empty). If - initial is not given and x contains only one item, that item is returned. + items of x (hence effectively serves as a default when x is empty). If no + initial value is given and x contains only one item, that item is returned. """ x = list(x) if initial is not _no_value: diff --git a/mpyc/runtime.py b/mpyc/runtime.py index 12568e8..1169795 100644 --- a/mpyc/runtime.py +++ b/mpyc/runtime.py @@ -108,6 +108,31 @@ def threshold(self, t): keys[subset] = secrets.token_bytes(16) # 128-bit key self._prss_keys = keys + def _prss_keys_to_peer(self, peer_pid): + """Return PRSS keys to be sent to peer.""" + m = len(self.parties) + t = self.threshold + keys = [] + for subset in itertools.combinations(range(m), m - t): + if subset[0] == self.pid and peer_pid in subset: + keys.append(self._prss_keys[subset]) + return keys + + def _prss_keys_from_peer(self, peer_pid, data=None): + """Store PRSS keys received from peer. + + If data is not given, return the size of PRSS keys to be stored. + """ + m = len(self.parties) + t = self.threshold + len_packet = 0 + for subset in itertools.combinations(range(m), m - t): + if subset[0] == peer_pid and self.pid in subset: + if data is not None: + self._prss_keys[subset] = data[len_packet:len_packet + 16] + len_packet += 16 + return len_packet + @functools.cache def prfs(self, bound): """PRFs with codomain range(bound) for pseudorandom secret sharing.