Skip to content

Commit ca3d42f

Browse files
authored
Merge pull request Huelse#20 from CarmenLee111/master
add pickle support in wrapper (#1)
2 parents e837b8f + a2ad817 commit ca3d42f

File tree

3 files changed

+181
-8
lines changed

3 files changed

+181
-8
lines changed

src/base64.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
base64.cpp and base64.h
3+
base64 encoding and decoding with C++.
4+
Version: 1.01.00
5+
Copyright (C) 2004-2017 René Nyffenegger
6+
This source code is provided 'as-is', without any express or implied
7+
warranty. In no event will the author be held liable for any damages
8+
arising from the use of this software.
9+
Permission is granted to anyone to use this software for any purpose,
10+
including commercial applications, and to alter it and redistribute it
11+
freely, subject to the following restrictions:
12+
1. The origin of this source code must not be misrepresented; you must not
13+
claim that you wrote the original source code. If you use this source code
14+
in a product, an acknowledgment in the product documentation would be
15+
appreciated but is not required.
16+
2. Altered source versions must be plainly marked as such, and must not be
17+
misrepresented as being the original source code.
18+
3. This notice may not be removed or altered from any source distribution.
19+
René Nyffenegger [email protected]
20+
*/
21+
22+
#include "base64.h"
23+
#include <iostream>
24+
25+
static const std::string base64_chars =
26+
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
27+
"abcdefghijklmnopqrstuvwxyz"
28+
"0123456789+/";
29+
30+
31+
static inline bool is_base64(unsigned char c) {
32+
return (isalnum(c) || (c == '+') || (c == '/'));
33+
}
34+
35+
std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) {
36+
std::string ret;
37+
int i = 0;
38+
int j = 0;
39+
unsigned char char_array_3[3];
40+
unsigned char char_array_4[4];
41+
42+
while (in_len--) {
43+
char_array_3[i++] = *(bytes_to_encode++);
44+
if (i == 3) {
45+
char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
46+
char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
47+
char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
48+
char_array_4[3] = char_array_3[2] & 0x3f;
49+
50+
for(i = 0; (i <4) ; i++)
51+
ret += base64_chars[char_array_4[i]];
52+
i = 0;
53+
}
54+
}
55+
56+
if (i)
57+
{
58+
for(j = i; j < 3; j++)
59+
char_array_3[j] = '\0';
60+
61+
char_array_4[0] = ( char_array_3[0] & 0xfc) >> 2;
62+
char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
63+
char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
64+
65+
for (j = 0; (j < i + 1); j++)
66+
ret += base64_chars[char_array_4[j]];
67+
68+
while((i++ < 3))
69+
ret += '=';
70+
71+
}
72+
73+
return ret;
74+
75+
}
76+
77+
std::string base64_decode(std::string const& encoded_string) {
78+
int in_len = encoded_string.size();
79+
int i = 0;
80+
int j = 0;
81+
int in_ = 0;
82+
unsigned char char_array_4[4], char_array_3[3];
83+
std::string ret;
84+
85+
while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
86+
char_array_4[i++] = encoded_string[in_]; in_++;
87+
if (i ==4) {
88+
for (i = 0; i <4; i++)
89+
char_array_4[i] = base64_chars.find(char_array_4[i]);
90+
91+
char_array_3[0] = ( char_array_4[0] << 2 ) + ((char_array_4[1] & 0x30) >> 4);
92+
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
93+
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
94+
95+
for (i = 0; (i < 3); i++)
96+
ret += char_array_3[i];
97+
i = 0;
98+
}
99+
}
100+
101+
if (i) {
102+
for (j = 0; j < i; j++)
103+
char_array_4[j] = base64_chars.find(char_array_4[j]);
104+
105+
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
106+
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
107+
108+
for (j = 0; (j < i - 1); j++) ret += char_array_3[j];
109+
}
110+
111+
return ret;
112+
}

src/base64.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//
2+
// base64 encoding and decoding with C++.
3+
// Version: 1.01.00
4+
//
5+
6+
#ifndef BASE64_H_C0CE2A47_D10E_42C9_A27C_C883944E704A
7+
#define BASE64_H_C0CE2A47_D10E_42C9_A27C_C883944E704A
8+
9+
#include <string>
10+
11+
std::string base64_encode(unsigned char const* , unsigned int len);
12+
std::string base64_decode(std::string const& s);
13+
14+
#endif /* BASE64_H_C0CE2A47_D10E_42C9_A27C_C883944E704A */

src/wrapper.cpp

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <pybind11/stl.h>
66
#include "seal/seal.h"
77
#include <fstream>
8+
#include "base64.h"
89

910
using namespace std;
1011
using namespace seal;
@@ -18,6 +19,44 @@ PYBIND11_MAKE_OPAQUE(std::vector<std::int64_t>);
1819

1920
using parms_id_type = std::array<std::uint64_t, 4>;
2021

22+
template <class T>
23+
py::tuple serialize(T &c)
24+
{
25+
std::stringstream output(std::ios::binary | std::ios::out);
26+
c.save(output);
27+
std::string cipherstr = output.str();
28+
std::string base64_encoded_cipher = base64_encode(reinterpret_cast<const unsigned char *>(cipherstr.c_str()), cipherstr.length());
29+
return py::make_tuple(base64_encoded_cipher);
30+
}
31+
32+
template <class T>
33+
T deserialize(py::tuple t)
34+
{
35+
if (t.size() != 1)
36+
throw std::runtime_error("(Pickle) Invalid input tuple!");
37+
T c = T();
38+
std::string cipherstr_encoded = t[0].cast<std::string>();
39+
std::string cipherstr_decoded = base64_decode(cipherstr_encoded);
40+
std::stringstream input(std::ios::binary | std::ios::in);
41+
input.str(cipherstr_decoded);
42+
c.load(input);
43+
return c;
44+
}
45+
46+
template <class T>
47+
T deserialize_context(py::tuple t)
48+
{
49+
if (t.size() != 2)
50+
throw std::runtime_error("(Pickle) Invalid input tuple!");
51+
T c = T();
52+
std::string cipherstr_encoded = t[1].cast<std::string>();
53+
std::string cipherstr_decoded = base64_decode(cipherstr_encoded);
54+
std::stringstream input(std::ios::binary | std::ios::in);
55+
input.str(cipherstr_decoded);
56+
c.load(t[0].cast<std::shared_ptr<SEALContext>>(), input);
57+
return c;
58+
}
59+
2160
PYBIND11_MODULE(seal, m)
2261
{
2362
m.doc() = "Microsoft SEAL (3.4.5) For Python. From https://github.com/Huelse/SEAL-Python";
@@ -61,7 +100,8 @@ PYBIND11_MODULE(seal, m)
61100
std::ifstream in(path, std::ifstream::binary);
62101
p.load(in);
63102
in.close();
64-
});
103+
})
104+
.def(py::pickle(&serialize<EncryptionParameters>, &deserialize<EncryptionParameters>));
65105

66106
// context.h
67107
py::class_<EncryptionParameterQualifiers, std::unique_ptr<EncryptionParameterQualifiers, py::nodelete>>(m, "EncryptionParameterQualifiers")
@@ -147,7 +187,8 @@ PYBIND11_MODULE(seal, m)
147187
std::ifstream in(path, std::ifstream::binary);
148188
c.load(context, in);
149189
in.close();
150-
});
190+
})
191+
.def(py::pickle(&serialize<Plaintext>, &deserialize_context<Plaintext>));
151192

152193
// ciphertext.h
153194
py::class_<Ciphertext>(m, "Ciphertext")
@@ -177,7 +218,8 @@ PYBIND11_MODULE(seal, m)
177218
std::ifstream in(path, std::ifstream::binary);
178219
c.load(context, in);
179220
in.close();
180-
});
221+
})
222+
.def(py::pickle(&serialize<Ciphertext>, &deserialize_context<Ciphertext>));
181223

182224
// secretkey.h
183225
py::class_<SecretKey>(m, "SecretKey")
@@ -192,7 +234,8 @@ PYBIND11_MODULE(seal, m)
192234
std::ifstream in(path, std::ifstream::binary);
193235
c.load(context, in);
194236
in.close();
195-
});
237+
})
238+
.def(py::pickle(&serialize<SecretKey>, &deserialize_context<SecretKey>));
196239

197240
// publickey.h
198241
py::class_<PublicKey>(m, "PublicKey")
@@ -207,7 +250,8 @@ PYBIND11_MODULE(seal, m)
207250
std::ifstream in(path, std::ifstream::binary);
208251
c.load(context, in);
209252
in.close();
210-
});
253+
})
254+
.def(py::pickle(&serialize<PublicKey>, &deserialize_context<PublicKey>));
211255

212256
// kswitchkeys.h
213257
py::class_<KSwitchKeys>(m, "KSwitchKeys")
@@ -222,7 +266,8 @@ PYBIND11_MODULE(seal, m)
222266
std::ifstream in(path, std::ifstream::binary);
223267
c.load(context, in);
224268
in.close();
225-
});
269+
})
270+
.def(py::pickle(&serialize<KSwitchKeys>, &deserialize_context<KSwitchKeys>));
226271

227272
// relinKeys.h
228273
py::class_<RelinKeys, KSwitchKeys>(m, "RelinKeys")
@@ -237,7 +282,8 @@ PYBIND11_MODULE(seal, m)
237282
std::ifstream in(path, std::ifstream::binary);
238283
c.load(context, in);
239284
in.close();
240-
});
285+
})
286+
.def(py::pickle(&serialize<RelinKeys>, &deserialize_context<RelinKeys>));
241287

242288
// galoisKeys.h
243289
py::class_<GaloisKeys, KSwitchKeys>(m, "GaloisKeys")
@@ -252,7 +298,8 @@ PYBIND11_MODULE(seal, m)
252298
std::ifstream in(path, std::ifstream::binary);
253299
c.load(context, in);
254300
in.close();
255-
});
301+
})
302+
.def(py::pickle(&serialize<GaloisKeys>, &deserialize_context<GaloisKeys>));
256303

257304
// keygenerator.h
258305
py::class_<KeyGenerator>(m, "KeyGenerator")

0 commit comments

Comments
 (0)