Skip to content

Commit

Permalink
Merge pull request #20 from quantopian/encryption
Browse files Browse the repository at this point in the history
ENH: Add encryption support.
  • Loading branch information
Scott Sanderson authored Jun 23, 2016
2 parents 01056b4 + 7b8bad8 commit 8e5f713
Show file tree
Hide file tree
Showing 18 changed files with 866 additions and 330 deletions.
36 changes: 0 additions & 36 deletions bin/pgcontents
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python
from getpass import getuser
from os import getcwd
from os.path import join
import subprocess
from textwrap import dedent

Expand All @@ -16,10 +15,6 @@ from pgcontents.utils.migrate import (
temp_alembic_ini,
upgrade,
)
from pgcontents.utils.sync import (
checkpoint_all,
download_checkpoints,
)


@click.group(context_settings=dict(help_option_names=['-h', '--help']))
Expand Down Expand Up @@ -114,36 +109,5 @@ def gen_migration(db_url):
)


@main.command('download_checkpoints')
@_db_url
@_directory
@_users
def _download_checkpoints(db_url, directory, users):
    """
    Download checkpoints to a directory.
    """
    # `users` arrives as a single comma-separated string.
    names = users.split(',')

    # Several users: each one gets its own `<directory>/<user>` subdirectory.
    if len(names) != 1:
        for name in names:
            download_checkpoints(db_url, join(directory, name), name)
        return

    # Exactly one user: write straight into `directory`.
    download_checkpoints(db_url, directory, names[0])


@main.command('checkpoint_all')
@_db_url
@_directory
@_users
def _checkpoint_all(db_url, directory, users):
    """
    Upload a checkpoint for every file in a directory.
    """
    # `users` arrives as a single comma-separated string.
    names = users.split(',')
    only_one = len(names) == 1

    for name in names:
        # A lone user reads from `directory` itself; with several users each
        # one reads from its own `<directory>/<user>` subdirectory.
        source_dir = directory if only_one else join(directory, name)
        checkpoint_all(db_url, source_dir, name)

if __name__ == "__main__":
main()
19 changes: 16 additions & 3 deletions pgcontents/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import posixpath

from tornado.web import HTTPError
from .error import PathOutsideRoot
from .error import CorruptedFile, PathOutsideRoot
from .utils.ipycompat import reads, writes

NBFORMAT_VERSION = 4
Expand Down Expand Up @@ -117,7 +117,10 @@ def reads_base64(nb, as_version=NBFORMAT_VERSION):
"""
Read a notebook from base64.
"""
return reads(b64decode(nb).decode('utf-8'), as_version=as_version)
try:
return reads(b64decode(nb).decode('utf-8'), as_version=as_version)
except Exception as e:
raise CorruptedFile(e)


def _decode_text_from_base64(path, bcontent):
Expand Down Expand Up @@ -161,7 +164,17 @@ def from_b64(path, bcontent, format):
'text': _decode_text_from_base64,
None: _decode_unknown_from_base64,
}
content, real_format = decoders[format](path, bcontent)

try:
content, real_format = decoders[format](path, bcontent)
except HTTPError:
# Pass through HTTPErrors, since we intend for them to bubble all the
# way back to the API layer.
raise
except Exception as e:
# Anything else should be wrapped in a CorruptedFile, since it likely
# indicates misconfiguration of encryption.
raise CorruptedFile(e)

default_mimes = {
'text': 'text/plain',
Expand Down
54 changes: 20 additions & 34 deletions pgcontents/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .api_utils import (
_decode_unknown_from_base64,
outside_root_to_404,
prefix_dirs,
reads_base64,
to_b64,
writes_base64,
Expand All @@ -16,7 +15,6 @@
delete_remote_checkpoints,
delete_single_remote_checkpoint,
get_remote_checkpoint,
latest_remote_checkpoints,
list_remote_checkpoints,
move_remote_checkpoints,
purge_remote_checkpoints,
Expand All @@ -40,7 +38,14 @@ def create_notebook_checkpoint(self, nb, path):
"""
b64_content = writes_base64(nb)
with self.engine.begin() as db:
return save_remote_checkpoint(db, self.user_id, path, b64_content)
return save_remote_checkpoint(
db,
self.user_id,
path,
b64_content,
self.crypto.encrypt,
self.max_file_size_bytes,
)

@outside_root_to_404
def create_file_checkpoint(self, content, format, path):
Expand All @@ -53,7 +58,14 @@ def create_file_checkpoint(self, content, format, path):
except ValueError as e:
self.do_400(str(e))
with self.engine.begin() as db:
return save_remote_checkpoint(db, self.user_id, path, b64_content)
return save_remote_checkpoint(
db,
self.user_id,
path,
b64_content,
self.crypto.encrypt,
self.max_file_size_bytes,
)

@outside_root_to_404
def delete_checkpoint(self, checkpoint_id, path):
Expand All @@ -63,27 +75,28 @@ def delete_checkpoint(self, checkpoint_id, path):
db, self.user_id, path, checkpoint_id,
)

def _get_checkpoint(self, checkpoint_id, path):
def get_checkpoint_content(self, checkpoint_id, path):
"""Get the content of a checkpoint."""
with self.engine.begin() as db:
return get_remote_checkpoint(
db,
self.user_id,
path,
checkpoint_id,
self.crypto.decrypt,
)['content']

@outside_root_to_404
def get_notebook_checkpoint(self, checkpoint_id, path):
b64_content = self._get_checkpoint(checkpoint_id, path)
b64_content = self.get_checkpoint_content(checkpoint_id, path)
return {
'type': 'notebook',
'content': reads_base64(b64_content),
}

@outside_root_to_404
def get_file_checkpoint(self, checkpoint_id, path):
b64_content = self._get_checkpoint(checkpoint_id, path)
b64_content = self.get_checkpoint_content(checkpoint_id, path)
content, format = _decode_unknown_from_base64(path, b64_content)
return {
'type': 'file',
Expand Down Expand Up @@ -120,30 +133,3 @@ def purge_db(self):
"""
with self.engine.begin() as db:
purge_remote_checkpoints(db, self.user_id)

def dump(self, contents_mgr):
    """
    Synchronize the state of our database with the specified
    ContentsManager.

    Gets the most recent checkpoint for each file and passes it to the
    supplied ContentsManager to be saved.

    Parameters
    ----------
    contents_mgr : object
        Destination manager. Its ``save(model, path)`` method is called once
        for every parent directory of a notebook and once for the notebook
        itself.

    Notes
    -----
    Records whose path does not end in ``.ipynb`` are logged and skipped.
    """
    # NOTE(review): indentation was lost in the view this was written from.
    # The loop is kept inside the transaction so that lazily-produced records
    # remain readable while we iterate -- confirm against the real file.
    with self.engine.begin() as db:
        records = latest_remote_checkpoints(db, self.user_id)
        for record in records:
            path = record['path']
            if not path.endswith('.ipynb'):
                # The logging module formats with %-style placeholders; the
                # previous '{}' placeholder made the record fail to format,
                # so the warning was never emitted.
                self.log.warning('Ignoring non-notebook file: %s', path)
                continue

            # Create every parent directory before saving the notebook.
            for dirname in prefix_dirs(path):
                self.log.info("Ensuring directory [%s]", dirname)
                contents_mgr.save(
                    model={'type': 'directory'},
                    path=dirname,
                )

            self.log.info("Writing notebook [%s]", path)
            contents_mgr.save(
                self.get_notebook_checkpoint(record['id'], path),
                path,
            )
129 changes: 129 additions & 0 deletions pgcontents/crypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""
Interface definition for encryption/decryption plugins for
PostgresContentsManager.
Encryption backends should raise pgcontents.error.CorruptedFile if they
encounter an input that they cannot decrypt.
"""
from .error import CorruptedFile


class NoEncryption(object):
    """
    Pass-through crypto backend: content is stored exactly as given.

    Both ``encrypt()`` and ``decrypt()`` hand back their argument untouched.

    Methods
    -------
    encrypt : callable[bytes -> bytes]
    decrypt : callable[bytes -> bytes]
    """

    def encrypt(self, b):
        """Return ``b`` unchanged."""
        return b

    def decrypt(self, b):
        """Return ``b`` unchanged."""
        return b


class FernetEncryption(object):
    """
    Notebook encryption using cryptography.fernet for symmetric-key encryption.

    Parameters
    ----------
    fernet : cryptography.fernet.Fernet
        The Fernet object to use for encryption.

    Methods
    -------
    encrypt : callable[bytes -> bytes]
    decrypt : callable[bytes -> bytes]

    Raises
    ------
    CorruptedFile
        From ``decrypt()``, wrapping any exception raised by the underlying
        Fernet object.

    Notes
    -----
    ``cryptography.fernet.MultiFernet`` can be used instead of a vanilla
    ``Fernet`` to allow zero-downtime key rotation.

    See Also
    --------
    :func:`pgcontents.utils.sync.reencrypt_user`
    """
    __slots__ = ('_fernet',)

    def __init__(self, fernet):
        self._fernet = fernet

    def encrypt(self, s):
        """Encrypt ``s`` with the wrapped Fernet object."""
        return self._fernet.encrypt(s)

    def decrypt(self, s):
        """
        Decrypt ``s`` with the wrapped Fernet object.

        Any failure is re-raised as ``CorruptedFile`` so that callers only
        need to handle a single exception type.
        """
        try:
            return self._fernet.decrypt(s)
        except Exception as e:
            raise CorruptedFile(e)

    def __copy__(self, memo=None):
        # ``copy.copy`` invokes ``__copy__()`` with no arguments, so ``memo``
        # must be optional; it is accepted only for backwards compatibility
        # with the previous signature. The new wrapper shares the same
        # underlying Fernet object.
        return FernetEncryption(self._fernet)

    def __deepcopy__(self, memo):
        # Any value that appears in an IPython/Jupyter Config object needs to
        # be deepcopy-able. Cryptography's Fernet objects aren't deepcopy-able,
        # so we copy our underlying state to a new FernetEncryption object.
        return FernetEncryption(self._fernet)


class FallbackCrypto(object):
    """
    Notebook encryption that accepts a list of crypto instances and decrypts by
    trying them in order.

    Sub-cryptos should raise ``CorruptedFile`` if they're unable to decrypt an
    input.

    This is conceptually similar to the technique used by
    ``cryptography.fernet.MultiFernet`` for implementing key rotation.

    Parameters
    ----------
    cryptos : list[object]
        A non-empty sequence of cryptos to use for decryption. cryptos[0]
        will always be used for encryption.

    Methods
    -------
    encrypt : callable[bytes -> bytes]
    decrypt : callable[bytes -> bytes]

    Raises
    ------
    ValueError
        If ``cryptos`` is empty, or if a ``NoEncryption`` appears anywhere
        other than the last position.

    Notes
    -----
    Since NoEncryption will always succeed, it is only supported as the last
    entry in ``cryptos``. Passing a list with a NoEncryption not in the last
    location will raise a ValueError.
    """
    __slots__ = ('_cryptos',)

    def __init__(self, cryptos):
        # Snapshot the input so that later mutation of the caller's list
        # cannot bypass the validation below (this also accepts any iterable).
        cryptos = tuple(cryptos)

        # Fail fast: with no cryptos, encrypt() could only fail later with an
        # opaque IndexError.
        if not cryptos:
            raise ValueError("FallbackCrypto requires at least one crypto.")

        # Only the last crypto can be a ``NoEncryption``.
        if any(isinstance(c, NoEncryption) for c in cryptos[:-1]):
            raise ValueError(
                "NoEncryption is only supported as the last fallback."
            )

        self._cryptos = cryptos

    def encrypt(self, s):
        """Encrypt ``s`` with the first (primary) crypto."""
        return self._cryptos[0].encrypt(s)

    def decrypt(self, s):
        """
        Return the result of the first crypto that can decrypt ``s``.

        Raises ``CorruptedFile`` carrying the list of per-crypto errors if
        every crypto fails.
        """
        errors = []
        for c in self._cryptos:
            try:
                return c.decrypt(s)
            except CorruptedFile as e:
                errors.append(e)
        raise CorruptedFile(errors)
37 changes: 30 additions & 7 deletions pgcontents/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

from contextlib import contextmanager
from six.moves import zip
from six.moves import map, zip

from psycopg2.errorcodes import (
FOREIGN_KEY_VIOLATION,
Expand Down Expand Up @@ -65,14 +65,37 @@ def _get_name(column_like):
return column_like.clause.name


def to_dict(fields, row):
def to_dict_no_content(fields, row):
"""
Convert a SQLAlchemy row to a dict.
Convert a SQLAlchemy row that does not contain a 'content' field to a dict.
If row is None, return None.
Raises AssertionError if there is a field named 'content' in ``fields``.
"""
assert(len(fields) == len(row))

field_names = list(map(_get_name, fields))
assert 'content' not in field_names, "Unexpected content field."

return dict(zip(field_names, row))


def to_dict_with_content(fields, row, decrypt_func):
"""
Convert a SQLAlchemy row that contains a 'content' field to a dict.
``decrypt_func`` will be applied to the ``content`` field of the row.
If row is None, return None.
Raises AssertionError if there is no field named 'content' in ``fields``.
"""
assert(len(fields) == len(row))
return {
_get_name(field): value
for field, value in zip(fields, row)
}

field_names = list(map(_get_name, fields))
assert 'content' in field_names, "Missing content field."

result = dict(zip(field_names, row))
result['content'] = decrypt_func(result['content'])
return result
4 changes: 4 additions & 0 deletions pgcontents/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ class FileTooLarge(Exception):

class RenameRoot(Exception):
pass


class CorruptedFile(Exception):
    """Raised when stored content cannot be decrypted or decoded."""
Loading

0 comments on commit 8e5f713

Please sign in to comment.