From 1f35d80353a80d66d522b4a65c518670c056dfb3 Mon Sep 17 00:00:00 2001 From: TobeTek Date: Fri, 20 Jan 2023 23:07:19 +0100 Subject: [PATCH 1/7] Update gitignore --- .gitignore | 105 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 101 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 5e68aad..6e380c1 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,6 @@ __pycache__/ # Distribution / packaging .Python -env/ bin/ build/ develop-eggs/ @@ -23,9 +22,12 @@ lib64/ parts/ sdist/ var/ +wheels/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg +MANIFEST # PyInstaller # Usually these files are written by a python script from a template @@ -40,13 +42,17 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml -*,cover +*.cover +*.py,cover .hypothesis/ +.pytest_cache/ +cover/ # Translations *.mo @@ -54,13 +60,104 @@ coverage.xml # Django stuff: *.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy # Sphinx documentation docs/_build/ # PyBuilder +.pybuilder/ target/ -#Ipython Notebook +# Jupyter Notebook .ipynb_checkpoints -*.swp + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ \ No newline at end of file From 12ce06f5072ea95cc76a692a0dbc8f5a22d98284 Mon Sep 17 00:00:00 2001 From: TobeTek Date: Fri, 20 Jan 2023 23:09:08 +0100 Subject: [PATCH 2/7] Run black --- pywebpush/__init__.py | 213 +++++++++++++----------- pywebpush/__main__.py | 24 ++- pywebpush/tests/test_webpush.py | 283 ++++++++++++++++---------------- setup.py | 26 +-- 4 files changed, 285 insertions(+), 261 deletions(-) diff --git a/pywebpush/__init__.py b/pywebpush/__init__.py index e3c3137..c60beff 100644 --- a/pywebpush/__init__.py +++ b/pywebpush/__init__.py @@ -110,16 +110,16 @@ class WebPusher: WebPusher(subscription_info).send(data, headers) """ + subscription_info = {} valid_encodings = [ # "aesgcm128", # this is draft-0, but DO NOT USE. "aesgcm", # draft-httpbis-encryption-encoding-01 - "aes128gcm" # RFC8188 Standard encoding + "aes128gcm", # RFC8188 Standard encoding ] verbose = False - def __init__(self, subscription_info, requests_session=None, - verbose=False): + def __init__(self, subscription_info, requests_session=None, verbose=False): """Initialize using the info provided by the client PushSubscription object (See https://developer.mozilla.org/en-US/docs/Web/API/PushManager/subscribe) @@ -143,24 +143,22 @@ def __init__(self, subscription_info, requests_session=None, else: self.requests_method = requests_session - if 'endpoint' not in subscription_info: + if "endpoint" not in subscription_info: raise WebPushException("subscription_info missing endpoint URL") self.subscription_info = deepcopy(subscription_info) self.auth_key = self.receiver_key = None - if 'keys' in subscription_info: - keys = self.subscription_info['keys'] - for k in ['p256dh', 'auth']: + if "keys" in subscription_info: + keys = self.subscription_info["keys"] + for k in ["p256dh", "auth"]: if keys.get(k) is None: raise WebPushException("Missing keys value: {}".format(k)) if isinstance(keys[k], six.text_type): - keys[k] = bytes(keys[k].encode('utf8')) - receiver_raw = base64.urlsafe_b64decode( - self._repad(keys['p256dh'])) + keys[k] = bytes(keys[k].encode("utf8")) + receiver_raw = base64.urlsafe_b64decode(self._repad(keys["p256dh"])) if len(receiver_raw) != 65 and receiver_raw[0] != "\x04": raise WebPushException("Invalid p256dh key specified") self.receiver_key = receiver_raw - self.auth_key = base64.urlsafe_b64decode( - self._repad(keys['auth'])) + self.auth_key = base64.urlsafe_b64decode(self._repad(keys["auth"])) def verb(self, msg, *args, **kwargs): if self.verbose: @@ -168,7 +166,7 @@ def verb(self, msg, *args, **kwargs): def _repad(self, data): """Add base64 padding to the end of a string, if required""" - return data + b"===="[:len(data) % 4] + return data + b"===="[: len(data) % 4] def encode(self, data, content_encoding="aes128gcm"): """Encrypt the data. @@ -192,9 +190,10 @@ def encode(self, data, content_encoding="aes128gcm"): self.verb("Encoding data...") salt = None if content_encoding not in self.valid_encodings: - raise WebPushException("Invalid content encoding specified. " - "Select from " + - json.dumps(self.valid_encodings)) + raise WebPushException( + "Invalid content encoding specified. " + "Select from " + json.dumps(self.valid_encodings) + ) if content_encoding == "aesgcm": self.verb("Generating salt for aesgcm...") salt = os.urandom(16) @@ -203,11 +202,11 @@ def encode(self, data, content_encoding="aes128gcm"): server_key = ec.generate_private_key(ec.SECP256R1, default_backend()) crypto_key = server_key.public_key().public_bytes( encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.UncompressedPoint + format=serialization.PublicFormat.UncompressedPoint, ) if isinstance(data, six.text_type): - data = bytes(data.encode('utf8')) + data = bytes(data.encode("utf8")) if content_encoding == "aes128gcm": self.verb("Encrypting to aes128gcm...") encrypted = http_ece.encrypt( @@ -216,13 +215,12 @@ def encode(self, data, content_encoding="aes128gcm"): private_key=server_key, dh=self.receiver_key, auth_secret=self.auth_key, - version=content_encoding) - reply = CaseInsensitiveDict({ - 'body': encrypted - }) + version=content_encoding, + ) + reply = CaseInsensitiveDict({"body": encrypted}) else: self.verb("Encrypting to aesgcm...") - crypto_key = base64.urlsafe_b64encode(crypto_key).strip(b'=') + crypto_key = base64.urlsafe_b64encode(crypto_key).strip(b"=") encrypted = http_ece.encrypt( data, salt=salt, @@ -230,13 +228,16 @@ def encode(self, data, content_encoding="aes128gcm"): keyid=crypto_key.decode(), dh=self.receiver_key, auth_secret=self.auth_key, - version=content_encoding) - reply = CaseInsensitiveDict({ - 'crypto_key': crypto_key, - 'body': encrypted, - }) + version=content_encoding, + ) + reply = CaseInsensitiveDict( + { + "crypto_key": crypto_key, + "body": encrypted, + } + ) if salt: - reply['salt'] = base64.urlsafe_b64encode(salt).strip(b'=') + reply["salt"] = base64.urlsafe_b64encode(salt).strip(b"=") return reply def as_curl(self, endpoint, encoded_data, headers): @@ -255,23 +256,33 @@ def as_curl(self, endpoint, encoded_data, headers): """ header_list = [ - '-H "{}: {}" \\ \n'.format( - key.lower(), val) for key, val in headers.items() + '-H "{}: {}" \\ \n'.format(key.lower(), val) for key, val in headers.items() ] data = "" if encoded_data: with open("encrypted.data", "wb") as f: f.write(encoded_data) data = "--data-binary @encrypted.data" - if 'content-length' not in headers: + if "content-length" not in headers: self.verb("Generating content-length header...") header_list.append( - '-H "content-length: {}" \\ \n'.format(len(encoded_data))) - return ("""curl -vX POST {url} \\\n{headers}{data}""".format( - url=endpoint, headers="".join(header_list), data=data)) + '-H "content-length: {}" \\ \n'.format(len(encoded_data)) + ) + return """curl -vX POST {url} \\\n{headers}{data}""".format( + url=endpoint, headers="".join(header_list), data=data + ) - def send(self, data=None, headers=None, ttl=0, gcm_key=None, reg_id=None, - content_encoding="aes128gcm", curl=False, timeout=None): + def send( + self, + data=None, + headers=None, + ttl=0, + gcm_key=None, + reg_id=None, + content_encoding="aes128gcm", + curl=False, + timeout=None, + ): """Encode and send the data to the Push Service. :param data: A serialized block of data (see encode() ). @@ -311,80 +322,86 @@ def send(self, data=None, headers=None, ttl=0, gcm_key=None, reg_id=None, # should use ';' instead of ',' to append the headers. # see # https://github.com/webpush-wg/webpush-encryption/issues/6 - crypto_key += ';' - crypto_key += ( - "dh=" + encoded["crypto_key"].decode('utf8')) - headers.update({ - 'crypto-key': crypto_key - }) + crypto_key += ";" + crypto_key += "dh=" + encoded["crypto_key"].decode("utf8") + headers.update({"crypto-key": crypto_key}) if "salt" in encoded: - headers.update({ - 'encryption': "salt=" + encoded['salt'].decode('utf8') - }) - headers.update({ - 'content-encoding': content_encoding, - }) + headers.update({"encryption": "salt=" + encoded["salt"].decode("utf8")}) + headers.update( + { + "content-encoding": content_encoding, + } + ) if gcm_key: # guess if it is a legacy GCM project key or actual FCM key # gcm keys are all about 40 chars (use 100 for confidence), # fcm keys are 153-175 chars if len(gcm_key) < 100: self.verb("Guessing this is legacy GCM...") - endpoint = 'https://android.googleapis.com/gcm/send' + endpoint = "https://android.googleapis.com/gcm/send" else: self.verb("Guessing this is FCM...") - endpoint = 'https://fcm.googleapis.com/fcm/send' + endpoint = "https://fcm.googleapis.com/fcm/send" reg_ids = [] if not reg_id: - reg_id = self.subscription_info['endpoint'].rsplit('/', 1)[-1] + reg_id = self.subscription_info["endpoint"].rsplit("/", 1)[-1] self.verb("Fetching out registration id: {}", reg_id) reg_ids.append(reg_id) gcm_data = dict() - gcm_data['registration_ids'] = reg_ids + gcm_data["registration_ids"] = reg_ids if data: - gcm_data['raw_data'] = base64.b64encode( - encoded.get('body')).decode('utf8') - gcm_data['time_to_live'] = int( - headers['ttl'] if 'ttl' in headers else ttl) + gcm_data["raw_data"] = base64.b64encode(encoded.get("body")).decode( + "utf8" + ) + gcm_data["time_to_live"] = int(headers["ttl"] if "ttl" in headers else ttl) encoded_data = json.dumps(gcm_data) - headers.update({ - 'Authorization': 'key='+gcm_key, - 'Content-Type': 'application/json', - }) + headers.update( + { + "Authorization": "key=" + gcm_key, + "Content-Type": "application/json", + } + ) else: - encoded_data = encoded.get('body') - endpoint = self.subscription_info['endpoint'] + encoded_data = encoded.get("body") + endpoint = self.subscription_info["endpoint"] - if 'ttl' not in headers or ttl: + if "ttl" not in headers or ttl: self.verb("Generating TTL of 0...") - headers['ttl'] = str(ttl or 0) + headers["ttl"] = str(ttl or 0) # Additionally useful headers: # Authorization / Crypto-Key (VAPID headers) if curl: return self.as_curl(endpoint, encoded_data, headers) - self.verb("\nSending request to" - "\n\thost: {}\n\theaders: {}\n\tdata: {}", - endpoint, headers, encoded_data) - resp = self.requests_method.post(endpoint, - data=encoded_data, - headers=headers, - timeout=timeout) - self.verb("\nResponse:\n\tcode: {}\n\tbody: {}\n", - resp.status_code, resp.text or "Empty") + self.verb( + "\nSending request to" "\n\thost: {}\n\theaders: {}\n\tdata: {}", + endpoint, + headers, + encoded_data, + ) + resp = self.requests_method.post( + endpoint, data=encoded_data, headers=headers, timeout=timeout + ) + self.verb( + "\nResponse:\n\tcode: {}\n\tbody: {}\n", + resp.status_code, + resp.text or "Empty", + ) return resp -def webpush(subscription_info, - data=None, - vapid_private_key=None, - vapid_claims=None, - content_encoding="aes128gcm", - curl=False, - timeout=None, - ttl=0, - verbose=False, - headers=None, - requests_session=None): +def webpush( + subscription_info, + data=None, + vapid_private_key=None, + vapid_claims=None, + content_encoding="aes128gcm", + curl=False, + timeout=None, + ttl=0, + verbose=False, + headers=None, + requests_session=None, +): """ One call solution to endcode and send `data` to the endpoint contained in `subscription_info` using optional VAPID auth headers. @@ -443,19 +460,17 @@ def webpush(subscription_info, if vapid_claims: if verbose: print("Generating VAPID headers...") - if not vapid_claims.get('aud'): - url = urlparse(subscription_info.get('endpoint')) + if not vapid_claims.get("aud"): + url = urlparse(subscription_info.get("endpoint")) aud = "{}://{}".format(url.scheme, url.netloc) - vapid_claims['aud'] = aud + vapid_claims["aud"] = aud # Remember, passed structures are mutable in python. # It's possible that a previously set `exp` field is no longer valid. - if (not vapid_claims.get('exp') - or vapid_claims.get('exp') < int(time.time())): + if not vapid_claims.get("exp") or vapid_claims.get("exp") < int(time.time()): # encryption lives for 12 hours - vapid_claims['exp'] = int(time.time()) + (12 * 60 * 60) + vapid_claims["exp"] = int(time.time()) + (12 * 60 * 60) if verbose: - print("Setting VAPID expry to {}...".format( - vapid_claims['exp'])) + print("Setting VAPID expry to {}...".format(vapid_claims["exp"])) if not vapid_private_key: raise WebPushException("VAPID dict missing 'private_key'") if isinstance(vapid_private_key, Vapid01): @@ -463,8 +478,7 @@ def webpush(subscription_info, elif os.path.isfile(vapid_private_key): # Presume that key from file is handled correctly by # py_vapid. - vv = Vapid.from_file( - private_key_file=vapid_private_key) # pragma no cover + vv = Vapid.from_file(private_key_file=vapid_private_key) # pragma no cover else: vv = Vapid.from_string(private_key=vapid_private_key) if verbose: @@ -485,7 +499,10 @@ def webpush(subscription_info, timeout=timeout, ) if not curl and response.status_code > 202: - raise WebPushException("Push failed: {} {}\nResponse body:{}".format( - response.status_code, response.reason, response.text), - response=response) + raise WebPushException( + "Push failed: {} {}\nResponse body:{}".format( + response.status_code, response.reason, response.text + ), + response=response, + ) return response diff --git a/pywebpush/__main__.py b/pywebpush/__main__.py index 26b713b..3c45d12 100644 --- a/pywebpush/__main__.py +++ b/pywebpush/__main__.py @@ -7,16 +7,25 @@ def get_config(): parser = argparse.ArgumentParser(description="WebPush tool") - parser.add_argument("--data", '-d', help="Data file") + parser.add_argument("--data", "-d", help="Data file") parser.add_argument("--info", "-i", help="Subscription Info JSON file") parser.add_argument("--head", help="Header Info JSON file") parser.add_argument("--claims", help="Vapid claim file") parser.add_argument("--key", help="Vapid private key file path") - parser.add_argument("--curl", help="Don't send, display as curl command", - default=False, action="store_true") + parser.add_argument( + "--curl", + help="Don't send, display as curl command", + default=False, + action="store_true", + ) parser.add_argument("--encoding", default="aes128gcm") - parser.add_argument("--verbose", "-v", help="Provide verbose feedback", - default=False, action="store_true") + parser.add_argument( + "--verbose", + "-v", + help="Provide verbose feedback", + default=False, + action="store_true", + ) args = parser.parse_args() @@ -45,7 +54,7 @@ def get_config(): def main(): - """ Send data """ + """Send data""" try: args = get_config() @@ -57,7 +66,8 @@ def main(): curl=args.curl, content_encoding=args.encoding, verbose=args.verbose, - headers=args.head) + headers=args.head, + ) print(result) except Exception as ex: print("ERROR: {}".format(ex)) diff --git a/pywebpush/tests/test_webpush.py b/pywebpush/tests/test_webpush.py index 4c313c8..a367ecf 100644 --- a/pywebpush/tests/test_webpush.py +++ b/pywebpush/tests/test_webpush.py @@ -24,70 +24,76 @@ class WebpushTestCase(unittest.TestCase): "M5xqEwuPM7VuQcyiLDhvovthPIXx+gsQRQ==" ) - def _gen_subscription_info(self, - recv_key=None, - endpoint="https://example.com/"): + def _gen_subscription_info(self, recv_key=None, endpoint="https://example.com/"): if not recv_key: recv_key = ec.generate_private_key(ec.SECP256R1, default_backend()) return { "endpoint": endpoint, "keys": { - 'auth': base64.urlsafe_b64encode(os.urandom(16)).strip(b'='), - 'p256dh': self._get_pubkey_str(recv_key), - } + "auth": base64.urlsafe_b64encode(os.urandom(16)).strip(b"="), + "p256dh": self._get_pubkey_str(recv_key), + }, } def _get_pubkey_str(self, priv_key): return base64.urlsafe_b64encode( priv_key.public_key().public_bytes( encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.UncompressedPoint - )).strip(b'=') + format=serialization.PublicFormat.UncompressedPoint, + ) + ).strip(b"=") def test_init(self): # use static values so we know what to look for in the reply subscription_info = { - u"endpoint": u"https://example.com/", - u"keys": { - u"p256dh": (u"BOrnIslXrUow2VAzKCUAE4sIbK00daEZCswOcf8m3T" - "F8V82B-OpOg5JbmYLg44kRcvQC1E2gMJshsUYA-_zMPR8"), - u"auth": u"k8JV6sjdbhAi1n3_LDBLvA" - } + "endpoint": "https://example.com/", + "keys": { + "p256dh": ( + "BOrnIslXrUow2VAzKCUAE4sIbK00daEZCswOcf8m3T" + "F8V82B-OpOg5JbmYLg44kRcvQC1E2gMJshsUYA-_zMPR8" + ), + "auth": "k8JV6sjdbhAi1n3_LDBLvA", + }, } - rk_decode = (b'\x04\xea\xe7"\xc9W\xadJ0\xd9P3(%\x00\x13\x8b' - b'\x08l\xad4u\xa1\x19\n\xcc\x0eq\xff&\xdd1' - b'|W\xcd\x81\xf8\xeaN\x83\x92[\x99\x82\xe0\xe3' - b'\x89\x11r\xf4\x02\xd4M\xa00\x9b!\xb1F\x00' - b'\xfb\xfc\xcc=\x1f') + rk_decode = ( + b'\x04\xea\xe7"\xc9W\xadJ0\xd9P3(%\x00\x13\x8b' + b"\x08l\xad4u\xa1\x19\n\xcc\x0eq\xff&\xdd1" + b"|W\xcd\x81\xf8\xeaN\x83\x92[\x99\x82\xe0\xe3" + b"\x89\x11r\xf4\x02\xd4M\xa00\x9b!\xb1F\x00" + b"\xfb\xfc\xcc=\x1f" + ) self.assertRaises( - WebPushException, - WebPusher, - {"keys": {'p256dh': 'AAA=', 'auth': 'AAA='}}) + WebPushException, WebPusher, {"keys": {"p256dh": "AAA=", "auth": "AAA="}} + ) self.assertRaises( WebPushException, WebPusher, - {"endpoint": "https://example.com", "keys": {'p256dh': 'AAA='}}) + {"endpoint": "https://example.com", "keys": {"p256dh": "AAA="}}, + ) self.assertRaises( WebPushException, WebPusher, - {"endpoint": "https://example.com", "keys": {'auth': 'AAA='}}) + {"endpoint": "https://example.com", "keys": {"auth": "AAA="}}, + ) self.assertRaises( WebPushException, WebPusher, - {"endpoint": "https://example.com", - "keys": {'p256dh': 'AAA=', 'auth': 'AAA='}}) + { + "endpoint": "https://example.com", + "keys": {"p256dh": "AAA=", "auth": "AAA="}, + }, + ) push = WebPusher(subscription_info) assert push.subscription_info != subscription_info - assert push.subscription_info['keys'] != subscription_info['keys'] - assert push.subscription_info['endpoint'] == subscription_info['endpoint'] + assert push.subscription_info["keys"] != subscription_info["keys"] + assert push.subscription_info["endpoint"] == subscription_info["endpoint"] assert push.receiver_key == rk_decode assert push.auth_key == b'\x93\xc2U\xea\xc8\xddn\x10"\xd6}\xff,0K\xbc' def test_encode(self): for content_encoding in ["aesgcm", "aes128gcm"]: - recv_key = ec.generate_private_key( - ec.SECP256R1, default_backend()) + recv_key = ec.generate_private_key(ec.SECP256R1, default_backend()) subscription_info = self._gen_subscription_info(recv_key) data = "Mary had a little lamb, with some nice mint jelly" push = WebPusher(subscription_info) @@ -99,48 +105,45 @@ def test_encode(self): """ # Convert these b64 strings into their raw, binary form. raw_salt = None - if 'salt' in encoded: - raw_salt = base64.urlsafe_b64decode( - push._repad(encoded['salt'])) + if "salt" in encoded: + raw_salt = base64.urlsafe_b64decode(push._repad(encoded["salt"])) raw_dh = None if content_encoding != "aes128gcm": - raw_dh = base64.urlsafe_b64decode( - push._repad(encoded['crypto_key'])) + raw_dh = base64.urlsafe_b64decode(push._repad(encoded["crypto_key"])) raw_auth = base64.urlsafe_b64decode( - push._repad(subscription_info['keys']['auth'])) + push._repad(subscription_info["keys"]["auth"]) + ) decoded = http_ece.decrypt( - encoded['body'], + encoded["body"], salt=raw_salt, dh=raw_dh, private_key=recv_key, auth_secret=raw_auth, - version=content_encoding - ) - assert decoded.decode('utf8') == data + version=content_encoding, + ) + assert decoded.decode("utf8") == data def test_bad_content_encoding(self): subscription_info = self._gen_subscription_info() data = "Mary had a little lamb, with some nice mint jelly" push = WebPusher(subscription_info) - self.assertRaises(WebPushException, - push.encode, - data, - content_encoding="aesgcm128") + self.assertRaises( + WebPushException, push.encode, data, content_encoding="aesgcm128" + ) @patch("requests.post") def test_send(self, mock_post): subscription_info = self._gen_subscription_info() - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} data = "Mary had a little lamb" WebPusher(subscription_info).send(data, headers) - assert subscription_info.get('endpoint') == mock_post.call_args[0][0] - pheaders = mock_post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' - assert pheaders.get('AUTHENTICATION') == headers.get('Authentication') - ckey = pheaders.get('crypto-key') - assert 'pre-existing' in ckey - assert pheaders.get('content-encoding') == 'aes128gcm' + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + assert pheaders.get("content-encoding") == "aes128gcm" @patch("requests.post") def test_send_vapid(self, mock_post): @@ -153,26 +156,26 @@ def test_send_vapid(self, mock_post): vapid_private_key=self.vapid_key, vapid_claims={"sub": "mailto:ops@example.com"}, content_encoding="aesgcm", - headers={"Test-Header": "test-value"} + headers={"Test-Header": "test-value"}, ) - assert subscription_info.get('endpoint') == mock_post.call_args[0][0] - pheaders = mock_post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" def repad(str): - return str + "===="[:len(str) % 4] + return str + "===="[: len(str) % 4] auth = json.loads( base64.urlsafe_b64decode( - repad(pheaders['authorization'].split('.')[1]) - ).decode('utf8') + repad(pheaders["authorization"].split(".")[1]) + ).decode("utf8") ) - assert subscription_info.get('endpoint').startswith(auth['aud']) - assert 'vapid' in pheaders.get('authorization') - ckey = pheaders.get('crypto-key') - assert 'dh=' in ckey - assert pheaders.get('content-encoding') == 'aesgcm' - assert pheaders.get('test-header') == 'test-value' + assert subscription_info.get("endpoint").startswith(auth["aud"]) + assert "vapid" in pheaders.get("authorization") + ckey = pheaders.get("crypto-key") + assert "dh=" in ckey + assert pheaders.get("content-encoding") == "aesgcm" + assert pheaders.get("test-header") == "test-value" @patch.object(WebPusher, "send") @patch.object(py_vapid.Vapid, "sign") @@ -198,9 +201,11 @@ def test_webpush_vapid_exp(self, vapid_sign, pusher_send): subscription_info = self._gen_subscription_info() data = "Mary had a little lamb" vapid_key = py_vapid.Vapid.from_string(self.vapid_key) - claims = dict(sub="mailto:ops@example.com", - aud="https://example.com", - exp=int(time.time() - 48600)) + claims = dict( + sub="mailto:ops@example.com", + aud="https://example.com", + exp=int(time.time() - 48600), + ) webpush( subscription_info=subscription_info, data=data, @@ -209,7 +214,7 @@ def test_webpush_vapid_exp(self, vapid_sign, pusher_send): ) vapid_sign.assert_called_once_with(claims) pusher_send.assert_called_once() - assert claims['exp'] > int(time.time()) + assert claims["exp"] > int(time.time()) @patch("requests.post") def test_send_bad_vapid_no_key(self, mock_post): @@ -224,8 +229,9 @@ def test_send_bad_vapid_no_key(self, mock_post): data=data, vapid_claims={ "aud": "https://example.com", - "sub": "mailto:ops@example.com" - }) + "sub": "mailto:ops@example.com", + }, + ) @patch("requests.post") def test_send_bad_vapid_bad_return(self, mock_post): @@ -240,53 +246,47 @@ def test_send_bad_vapid_bad_return(self, mock_post): data=data, vapid_claims={ "aud": "https://example.com", - "sub": "mailto:ops@example.com" + "sub": "mailto:ops@example.com", }, - vapid_private_key=self.vapid_key) + vapid_private_key=self.vapid_key, + ) @patch("requests.post") def test_send_empty(self, mock_post): subscription_info = self._gen_subscription_info() - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} - WebPusher(subscription_info).send('', headers) - assert subscription_info.get('endpoint') == mock_post.call_args[0][0] - pheaders = mock_post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' - assert 'encryption' not in pheaders - assert pheaders.get('AUTHENTICATION') == headers.get('Authentication') - ckey = pheaders.get('crypto-key') - assert 'pre-existing' in ckey + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + WebPusher(subscription_info).send("", headers) + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert "encryption" not in pheaders + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey def test_encode_empty(self): subscription_info = self._gen_subscription_info() - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} - encoded = WebPusher(subscription_info).encode('', headers) + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + encoded = WebPusher(subscription_info).encode("", headers) assert encoded is None def test_encode_no_crypto(self): subscription_info = self._gen_subscription_info() - del(subscription_info['keys']) - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} - data = 'Something' + del subscription_info["keys"] + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + data = "Something" pusher = WebPusher(subscription_info) - self.assertRaises( - WebPushException, - pusher.encode, - data, - headers) + self.assertRaises(WebPushException, pusher.encode, data, headers) @patch("requests.post") def test_send_no_headers(self, mock_post): subscription_info = self._gen_subscription_info() data = "Mary had a little lamb" WebPusher(subscription_info).send(data) - assert subscription_info.get('endpoint') == mock_post.call_args[0][0] - pheaders = mock_post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' - assert pheaders.get('content-encoding') == 'aes128gcm' + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("content-encoding") == "aes128gcm" @patch("pywebpush.open") def test_as_curl(self, opener): @@ -296,40 +296,39 @@ def test_as_curl(self, opener): data="Mary had a little lamb", vapid_claims={ "aud": "https://example.com", - "sub": "mailto:ops@example.com" + "sub": "mailto:ops@example.com", }, vapid_private_key=self.vapid_key, - curl=True + curl=True, ) for s in [ "curl -vX POST https://example.com", - "-H \"content-encoding: aes128gcm\"", - "-H \"authorization: vapid ", - "-H \"ttl: 0\"", - "-H \"content-length:" + '-H "content-encoding: aes128gcm"', + '-H "authorization: vapid ', + '-H "ttl: 0"', + '-H "content-length:', ]: assert s in result, "missing: {}".format(s) def test_ci_dict(self): ci = CaseInsensitiveDict({"Foo": "apple", "bar": "banana"}) - assert 'apple' == ci["foo"] - assert 'apple' == ci.get("FOO") - assert 'apple' == ci.get("Foo") - del (ci['FOO']) - assert ci.get('Foo') is None + assert "apple" == ci["foo"] + assert "apple" == ci.get("FOO") + assert "apple" == ci.get("Foo") + del ci["FOO"] + assert ci.get("Foo") is None @patch("requests.post") def test_gcm(self, mock_post): subscription_info = self._gen_subscription_info( - None, - endpoint="https://android.googleapis.com/gcm/send/regid123") - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} + None, endpoint="https://android.googleapis.com/gcm/send/regid123" + ) + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} data = "Mary had a little lamb" wp = WebPusher(subscription_info) wp.send(data, headers, gcm_key="gcm_key_value") - pdata = json.loads(mock_post.call_args[1].get('data')) - pheaders = mock_post.call_args[1].get('headers') + pdata = json.loads(mock_post.call_args[1].get("data")) + pheaders = mock_post.call_args[1].get("headers") assert pdata["registration_ids"][0] == "regid123" assert pheaders.get("authorization") == "key=gcm_key_value" assert pheaders.get("content-type") == "application/json" @@ -339,52 +338,50 @@ def test_timeout(self, mock_post): mock_post.return_value.status_code = 200 subscription_info = self._gen_subscription_info() WebPusher(subscription_info).send(timeout=5.2) - assert mock_post.call_args[1].get('timeout') == 5.2 + assert mock_post.call_args[1].get("timeout") == 5.2 webpush(subscription_info, timeout=10.001) - assert mock_post.call_args[1].get('timeout') == 10.001 + assert mock_post.call_args[1].get("timeout") == 10.001 @patch("requests.Session") def test_send_using_requests_session(self, mock_session): subscription_info = self._gen_subscription_info() - headers = {"Crypto-Key": "pre-existing", - "Authentication": "bearer vapid"} + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} data = "Mary had a little lamb" - WebPusher(subscription_info, - requests_session=mock_session).send(data, headers) - assert subscription_info.get( - 'endpoint') == mock_session.post.call_args[0][0] - pheaders = mock_session.post.call_args[1].get('headers') - assert pheaders.get('ttl') == '0' - assert pheaders.get('AUTHENTICATION') == headers.get('Authentication') - ckey = pheaders.get('crypto-key') - assert 'pre-existing' in ckey - assert pheaders.get('content-encoding') == 'aes128gcm' + WebPusher(subscription_info, requests_session=mock_session).send(data, headers) + assert subscription_info.get("endpoint") == mock_session.post.call_args[0][0] + pheaders = mock_session.post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + assert pheaders.get("content-encoding") == "aes128gcm" class WebpushExceptionTestCase(unittest.TestCase): - def test_exception(self): from requests import Response exp = WebPushException("foo") - assert ("{}".format(exp) == "WebPushException: foo") + assert "{}".format(exp) == "WebPushException: foo" # Really should try to load the response to verify, but this mock # covers what we need. response = Mock(spec=Response) response.text = ( - '{"code": 401, "errno": 109, "error": ' - '"Unauthorized", "more_info": "http://' - 'autopush.readthedocs.io/en/latest/htt' - 'p.html#error-codes", "message": "Requ' - 'est did not validate missing authoriz' - 'ation header"}') + '{"code": 401, "errno": 109, "error": ' + '"Unauthorized", "more_info": "http://' + "autopush.readthedocs.io/en/latest/htt" + 'p.html#error-codes", "message": "Requ' + "est did not validate missing authoriz" + 'ation header"}' + ) response.json.return_value = json.loads(response.text) response.status_code = 401 response.reason = "Unauthorized" exp = WebPushException("foo", response) assert "{}".format(exp) == "WebPushException: foo, Response {}".format( - response.text) - assert '{}'.format(exp.response), '' - assert exp.response.json().get('errno') == 109 + response.text + ) + assert "{}".format(exp.response), "" + assert exp.response.json().get("errno") == 109 exp = WebPushException("foo", [1, 2, 3]) - assert '{}'.format(exp) == "WebPushException: foo, Response [1, 2, 3]" + assert "{}".format(exp) == "WebPushException: foo, Response [1, 2, 3]" diff --git a/setup.py b/setup.py index aef04f8..7364a04 100644 --- a/setup.py +++ b/setup.py @@ -9,46 +9,46 @@ def read_from(file): reply = [] - with io.open(os.path.join(here, file), encoding='utf8') as f: + with io.open(os.path.join(here, file), encoding="utf8") as f: for line in f: line = line.strip() if not line: break - if line[:2] == '-r': - reply += read_from(line.split(' ')[1]) + if line[:2] == "-r": + reply += read_from(line.split(" ")[1]) continue - if line[0] != '#' or line[:2] != '//': + if line[0] != "#" or line[:2] != "//": reply.append(line) return reply here = os.path.abspath(os.path.dirname(__file__)) -with io.open(os.path.join(here, 'README.rst'), encoding='utf8') as f: +with io.open(os.path.join(here, "README.rst"), encoding="utf8") as f: README = f.read() -with io.open(os.path.join(here, 'CHANGELOG.md'), encoding='utf8') as f: +with io.open(os.path.join(here, "CHANGELOG.md"), encoding="utf8") as f: CHANGES = f.read() setup( name="pywebpush", version=__version__, packages=find_packages(), - description='WebPush publication library', - long_description=README + '\n\n' + CHANGES, + description="WebPush publication library", + long_description=README + "\n\n" + CHANGES, classifiers=[ "Topic :: Internet :: WWW/HTTP", "Programming Language :: Python :: Implementation :: PyPy", - 'Programming Language :: Python', + "Programming Language :: Python", "Programming Language :: Python :: 3", ], - keywords='push webpush publication', + keywords="push webpush publication", author="JR Conlin", author_email="src+webpusher@jrconlin.com", - url='https://github.com/web-push-libs/pywebpush', + url="https://github.com/web-push-libs/pywebpush", license="MPL2", include_package_data=True, zip_safe=False, - install_requires=read_from('requirements.txt'), - tests_require=read_from('test-requirements.txt'), + install_requires=read_from("requirements.txt"), + tests_require=read_from("test-requirements.txt"), entry_points=""" [console_scripts] pywebpush = pywebpush.__main__:main From 1bf1adc85314d88484cb45891504b65e6b82d7cd Mon Sep 17 00:00:00 2001 From: TobeTek Date: Fri, 20 Jan 2023 23:10:47 +0100 Subject: [PATCH 3/7] Install aiohttp --- pywebpush/__init__.py | 1 + requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/pywebpush/__init__.py b/pywebpush/__init__.py index c60beff..ff90aef 100644 --- a/pywebpush/__init__.py +++ b/pywebpush/__init__.py @@ -13,6 +13,7 @@ except ImportError: # pragma nocover from urlparse import urlparse +import aiohttp import six import http_ece import requests diff --git a/requirements.txt b/requirements.txt index 74596b3..16f1e7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiohttp cryptography>=2.6.1 http-ece>=1.1.0 requests>=2.21.0 From 6a90afc1d610b700d730034c33b44da7ccf72df9 Mon Sep 17 00:00:00 2001 From: TobeTek Date: Mon, 23 Jan 2023 09:20:25 +0100 Subject: [PATCH 4/7] Add .vscode to gitignore --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 6e380c1..d7a9866 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,6 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +#.idea/ + +.vscode/ \ No newline at end of file From 278c11d8b4028165c48d7a2983a944732015ad60 Mon Sep 17 00:00:00 2001 From: TobeTek Date: Mon, 23 Jan 2023 09:21:12 +0100 Subject: [PATCH 5/7] Create WebPusher.send_async --- pywebpush/__init__.py | 61 +++++++++++++++++++++++++++++++------ pywebpush/tests/__init__.py | 0 2 files changed, 51 insertions(+), 10 deletions(-) create mode 100644 pywebpush/tests/__init__.py diff --git a/pywebpush/__init__.py b/pywebpush/__init__.py index ff90aef..afbc58e 100644 --- a/pywebpush/__init__.py +++ b/pywebpush/__init__.py @@ -14,13 +14,18 @@ from urlparse import urlparse import aiohttp + import six import http_ece import requests + +from aiohttp import ClientResponse as AioHttpResponse from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives import serialization +from functools import partial from py_vapid import Vapid, Vapid01 +from requests import Response class WebPushException(Exception): @@ -120,7 +125,13 @@ class WebPusher: ] verbose = False - def __init__(self, subscription_info, requests_session=None, verbose=False): + def __init__( + self, + subscription_info, + requests_session=None, + aiohttp_session=None, + verbose=False, + ): """Initialize using the info provided by the client PushSubscription object (See https://developer.mozilla.org/en-US/docs/Web/API/PushManager/subscribe) @@ -144,6 +155,11 @@ def __init__(self, subscription_info, requests_session=None, verbose=False): else: self.requests_method = requests_session + if aiohttp_session is None: + self.aiohttp_method = partial(aiohttp.request, method="POST") + else: + self.aiohttp_method = aiohttp_session.post + if "endpoint" not in subscription_info: raise WebPushException("subscription_info missing endpoint URL") self.subscription_info = deepcopy(subscription_info) @@ -273,7 +289,7 @@ def as_curl(self, endpoint, encoded_data, headers): url=endpoint, headers="".join(header_list), data=data ) - def send( + def _prepare_send_data( self, data=None, headers=None, @@ -282,8 +298,7 @@ def send( reg_id=None, content_encoding="aes128gcm", curl=False, - timeout=None, - ): + ) -> dict: """Encode and send the data to the Push Service. :param data: A serialized block of data (see encode() ). @@ -304,9 +319,6 @@ def send( :type content_encoding: str :param curl: Display output as `curl` command instead of sending :type curl: bool - :param timeout: POST requests timeout - :type timeout: float or tuple - """ # Encode the data. if headers is None: @@ -371,16 +383,33 @@ def send( headers["ttl"] = str(ttl or 0) # Additionally useful headers: # Authorization / Crypto-Key (VAPID headers) - if curl: - return self.as_curl(endpoint, encoded_data, headers) + self.verb( "\nSending request to" "\n\thost: {}\n\theaders: {}\n\tdata: {}", endpoint, headers, encoded_data, ) + + return {"endpoint": endpoint, "data": encoded_data, "headers": headers} + + def send(self, *args, **kwargs) -> Response: + """Encode and send the data to the Push Service""" + timeout = kwargs.pop("timeout", 10000) + curl = kwargs.pop("curl", False) + + params = self._prepare_send_data(*args, **kwargs) + endpoint = params.pop("endpoint") + + if curl: + encoded_data = params["data"] + headers = params["headers"] + return self.as_curl(endpoint, encoded_data=encoded_data, headers=headers) + resp = self.requests_method.post( - endpoint, data=encoded_data, headers=headers, timeout=timeout + endpoint, + timeout=timeout, + **params, ) self.verb( "\nResponse:\n\tcode: {}\n\tbody: {}\n", @@ -389,6 +418,18 @@ def send( ) return resp + async def send_async(self, *args, **kwargs) -> AioHttpResponse: + timeout = kwargs.pop("timeout", 10000) + endpoint, params = self._prepare_send_data(*args, **kwargs) + resp = await self.aiohttp_method(endpoint, timeout=timeout, **params) + resp_text = await resp.text() + self.verb( + "\nResponse:\n\tcode: {}\n\tbody: {}\n", + resp.status, + resp_text or "Empty", + ) + return resp + def webpush( subscription_info, diff --git a/pywebpush/tests/__init__.py b/pywebpush/tests/__init__.py new file mode 100644 index 0000000..e69de29 From 0536c86d3fed9cdc003dc471b45d15ae298dacda Mon Sep 17 00:00:00 2001 From: TobeTek Date: Mon, 23 Jan 2023 10:05:22 +0100 Subject: [PATCH 6/7] Bug fix WebPusher.send_async --- pywebpush/__init__.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/pywebpush/__init__.py b/pywebpush/__init__.py index afbc58e..1c388d3 100644 --- a/pywebpush/__init__.py +++ b/pywebpush/__init__.py @@ -7,6 +7,7 @@ import json import os import time +from typing import Union try: from urllib.parse import urlparse @@ -393,7 +394,7 @@ def _prepare_send_data( return {"endpoint": endpoint, "data": encoded_data, "headers": headers} - def send(self, *args, **kwargs) -> Response: + def send(self, *args, **kwargs) -> Union[Response, str]: """Encode and send the data to the Push Service""" timeout = kwargs.pop("timeout", 10000) curl = kwargs.pop("curl", False) @@ -418,9 +419,18 @@ def send(self, *args, **kwargs) -> Response: ) return resp - async def send_async(self, *args, **kwargs) -> AioHttpResponse: + async def send_async(self, *args, **kwargs) -> Union[AioHttpResponse, str]: timeout = kwargs.pop("timeout", 10000) - endpoint, params = self._prepare_send_data(*args, **kwargs) + curl = kwargs.pop("curl", False) + + params = self._prepare_send_data(*args, **kwargs) + endpoint = params.pop("endpoint") + + if curl: + encoded_data = params["data"] + headers = params["headers"] + return self.as_curl(endpoint, encoded_data=encoded_data, headers=headers) + resp = await self.aiohttp_method(endpoint, timeout=timeout, **params) resp_text = await resp.text() self.verb( From 2b3168f54689310f7cad4a9e767b89d0fee3adb8 Mon Sep 17 00:00:00 2001 From: TobeTek Date: Mon, 23 Jan 2023 10:05:41 +0100 Subject: [PATCH 7/7] Write tests for WebPusher.send_async --- pywebpush/tests/test_webpush.py | 84 +++++++++++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 3 deletions(-) diff --git a/pywebpush/tests/test_webpush.py b/pywebpush/tests/test_webpush.py index a367ecf..2e7bf6e 100644 --- a/pywebpush/tests/test_webpush.py +++ b/pywebpush/tests/test_webpush.py @@ -4,7 +4,7 @@ import unittest import time -from mock import patch, Mock +from mock import patch, Mock, AsyncMock import http_ece from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives import serialization @@ -14,8 +14,7 @@ from pywebpush import WebPusher, WebPushException, CaseInsensitiveDict, webpush -class WebpushTestCase(unittest.TestCase): - +class WebpushTestUtils: # This is a exported DER formatted string of an ECDH public key # This was lifted from the py_vapid tests. vapid_key = ( @@ -43,6 +42,8 @@ def _get_pubkey_str(self, priv_key): ) ).strip(b"=") + +class WebpushTestCase(WebpushTestUtils, unittest.TestCase): def test_init(self): # use static values so we know what to look for in the reply subscription_info = { @@ -357,6 +358,83 @@ def test_send_using_requests_session(self, mock_session): assert pheaders.get("content-encoding") == "aes128gcm" +class WebPusherAsyncTestCase(WebpushTestUtils, unittest.IsolatedAsyncioTestCase): + @patch("aiohttp.request", new_callable=AsyncMock) + async def test_send(self, mock_post): + subscription_info = self._gen_subscription_info() + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + data = "Mary had a little lamb" + await WebPusher(subscription_info).send_async(data, headers) + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + assert pheaders.get("content-encoding") == "aes128gcm" + + @patch("aiohttp.request", new_callable=AsyncMock) + async def test_send_empty(self, mock_post): + subscription_info = self._gen_subscription_info() + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + await WebPusher(subscription_info).send_async("", headers) + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert "encryption" not in pheaders + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + + @patch("aiohttp.request", new_callable=AsyncMock) + async def test_send_no_headers(self, mock_post): + subscription_info = self._gen_subscription_info() + data = "Mary had a little lamb" + await WebPusher(subscription_info).send_async(data) + assert subscription_info.get("endpoint") == mock_post.call_args[0][0] + pheaders = mock_post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("content-encoding") == "aes128gcm" + + @patch("aiohttp.request", new_callable=AsyncMock) + async def test_gcm(self, mock_post): + subscription_info = self._gen_subscription_info( + None, endpoint="https://android.googleapis.com/gcm/send/regid123" + ) + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + data = "Mary had a little lamb" + wp = WebPusher(subscription_info) + await wp.send_async(data, headers, gcm_key="gcm_key_value") + pdata = json.loads(mock_post.call_args[1].get("data")) + pheaders = mock_post.call_args[1].get("headers") + assert pdata["registration_ids"][0] == "regid123" + assert pheaders.get("authorization") == "key=gcm_key_value" + assert pheaders.get("content-type") == "application/json" + + @patch("aiohttp.request", new_callable=AsyncMock) + async def test_timeout(self, mock_post): + mock_post.return_value.status_code = 200 + subscription_info = self._gen_subscription_info() + await WebPusher(subscription_info).send_async(timeout=5.2) + assert mock_post.call_args[1].get("timeout") == 5.2 + + @patch("aiohttp.ClientSession", new_callable=AsyncMock) + async def test_send_using_requests_session(self, mock_session): + subscription_info = self._gen_subscription_info() + headers = {"Crypto-Key": "pre-existing", "Authentication": "bearer vapid"} + data = "Mary had a little lamb" + await WebPusher(subscription_info, aiohttp_session=mock_session).send_async( + data, headers + ) + assert subscription_info.get("endpoint") == mock_session.post.call_args[0][0] + pheaders = mock_session.post.call_args[1].get("headers") + assert pheaders.get("ttl") == "0" + assert pheaders.get("AUTHENTICATION") == headers.get("Authentication") + ckey = pheaders.get("crypto-key") + assert "pre-existing" in ckey + assert pheaders.get("content-encoding") == "aes128gcm" + + class WebpushExceptionTestCase(unittest.TestCase): def test_exception(self): from requests import Response