diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..79fbefc --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[run] +omit= + */site-packages/* + */distutils/* + */src/tests/* diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..2f68891 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,12 @@ +root = true + +[*] +indent_style = space +indent_size = 2 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.{py,md}] +indent_size = 4 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..574c60c --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length=99 +inline-quotes=" +exclude= + src/tests/mockredis diff --git a/.gitignore b/.gitignore index bfcaa84..752e49c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,8 @@ bin/ include/ lib/ pip-selfcheck.json -local_settings.py \ No newline at end of file +local_settings.py +.DS_Store +.coverage +htmlcov +.env diff --git a/Pipfile b/Pipfile new file mode 100644 index 0000000..dd6a21b --- /dev/null +++ b/Pipfile @@ -0,0 +1,29 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[scripts] +test = "python -m unittest discover -s src/tests" +test-coverage = "coverage run -m unittest discover -s src/tests" +test-report = "coverage report" +test-report-html = "coverage html" +lint = "flake8" +start = "gunicorn src.server:app --reload" +start-prod = "gunicorn src.server:app" + +[packages] +falcon = "*" +gunicorn = "*" +python-mimeparse = "*" +redis = "*" +requests = "*" +six = "*" + +[dev-packages] +flake8 = "*" +flake8-quotes = "*" +coverage = "*" + +[requires] +python_version = "3.6" diff --git a/Pipfile.lock b/Pipfile.lock new file mode 100644 index 0000000..502ac9c --- /dev/null +++ b/Pipfile.lock @@ -0,0 +1,173 @@ +{ + "_meta": { + "hash": { + "sha256": "005bb44c9d727ce258e82803f26bbf60b326494cf87e2ffc2f8f52ea5757a8b0" + }, + "pipfile-spec": 6, + "requires": { + "python_version": 
"3.6" + }, + "sources": [ + { + "name": "pypi", + "url": "https://pypi.org/simple", + "verify_ssl": true + } + ] + }, + "default": { + "certifi": { + "hashes": [ + "sha256:339dc09518b07e2fa7eda5450740925974815557727d6bd35d319c1524a04a4c", + "sha256:6d58c986d22b038c8c0df30d639f23a3e6d172a05c3583e766f4c0b785c0986a" + ], + "version": "==2018.10.15" + }, + "chardet": { + "hashes": [ + "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae", + "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691" + ], + "version": "==3.0.4" + }, + "falcon": { + "hashes": [ + "sha256:0a66b33458fab9c1e400a9be1a68056abda178eb02a8cb4b8f795e9df20b053b", + "sha256:3981f609c0358a9fcdb25b0e7fab3d9e23019356fb429c635ce4133135ae1bc4" + ], + "index": "pypi", + "version": "==1.4.1" + }, + "gunicorn": { + "hashes": [ + "sha256:aa8e0b40b4157b36a5df5e599f45c9c76d6af43845ba3b3b0efe2c70473c2471", + "sha256:fa2662097c66f920f53f70621c6c58ca4a3c4d3434205e608e121b5b3b71f4f3" + ], + "index": "pypi", + "version": "==19.9.0" + }, + "idna": { + "hashes": [ + "sha256:156a6814fb5ac1fc6850fb002e0852d56c0c8d2531923a51032d1b70760e186e", + "sha256:684a38a6f903c1d71d6d5fac066b58d7768af4de2b832e426ec79c30daa94a16" + ], + "version": "==2.7" + }, + "python-mimeparse": { + "hashes": [ + "sha256:76e4b03d700a641fd7761d3cd4fdbbdcd787eade1ebfac43f877016328334f78", + "sha256:a295f03ff20341491bfe4717a39cd0a8cc9afad619ba44b77e86b0ab8a2b8282" + ], + "index": "pypi", + "version": "==1.6.0" + }, + "redis": { + "hashes": [ + "sha256:8a1900a9f2a0a44ecf6e8b5eb3e967a9909dfed219ad66df094f27f7d6f330fb", + "sha256:a22ca993cea2962dbb588f9f30d0015ac4afcc45bee27d3978c0dbe9e97c6c0f" + ], + "index": "pypi", + "version": "==2.10.6" + }, + "requests": { + "hashes": [ + "sha256:99dcfdaaeb17caf6e526f32b6a7b780461512ab3f1d992187801694cba42770c", + "sha256:a84b8c9ab6239b578f22d1c21d51b696dcfe004032bb80ea832398d6909d7279" + ], + "index": "pypi", + "version": "==2.20.0" + }, + "six": { + "hashes": [ + 
"sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9", + "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb" + ], + "index": "pypi", + "version": "==1.11.0" + }, + "urllib3": { + "hashes": [ + "sha256:61bf29cada3fc2fbefad4fdf059ea4bd1b4a86d2b6d15e1c7c0b582b9752fe39", + "sha256:de9529817c93f27c8ccbfead6985011db27bd0ddfcdb2d86f3f663385c6a9c22" + ], + "version": "==1.24.1" + } + }, + "develop": { + "coverage": { + "hashes": [ + "sha256:03481e81d558d30d230bc12999e3edffe392d244349a90f4ef9b88425fac74ba", + "sha256:0b136648de27201056c1869a6c0d4e23f464750fd9a9ba9750b8336a244429ed", + "sha256:0bf8cbbd71adfff0ef1f3a1531e6402d13b7b01ac50a79c97ca15f030dba6306", + "sha256:10a46017fef60e16694a30627319f38a2b9b52e90182dddb6e37dcdab0f4bf95", + "sha256:198626739a79b09fa0a2f06e083ffd12eb55449b5f8bfdbeed1df4910b2ca640", + "sha256:23d341cdd4a0371820eb2b0bd6b88f5003a7438bbedb33688cd33b8eae59affd", + "sha256:28b2191e7283f4f3568962e373b47ef7f0392993bb6660d079c62bd50fe9d162", + "sha256:2a5b73210bad5279ddb558d9a2bfedc7f4bf6ad7f3c988641d83c40293deaec1", + "sha256:2eb564bbf7816a9d68dd3369a510be3327f1c618d2357fa6b1216994c2e3d508", + "sha256:337ded681dd2ef9ca04ef5d93cfc87e52e09db2594c296b4a0a3662cb1b41249", + "sha256:3a2184c6d797a125dca8367878d3b9a178b6fdd05fdc2d35d758c3006a1cd694", + "sha256:3c79a6f7b95751cdebcd9037e4d06f8d5a9b60e4ed0cd231342aa8ad7124882a", + "sha256:3d72c20bd105022d29b14a7d628462ebdc61de2f303322c0212a054352f3b287", + "sha256:3eb42bf89a6be7deb64116dd1cc4b08171734d721e7a7e57ad64cc4ef29ed2f1", + "sha256:4635a184d0bbe537aa185a34193898eee409332a8ccb27eea36f262566585000", + "sha256:56e448f051a201c5ebbaa86a5efd0ca90d327204d8b059ab25ad0f35fbfd79f1", + "sha256:5a13ea7911ff5e1796b6d5e4fbbf6952381a611209b736d48e675c2756f3f74e", + "sha256:69bf008a06b76619d3c3f3b1983f5145c75a305a0fea513aca094cae5c40a8f5", + "sha256:6bc583dc18d5979dc0f6cec26a8603129de0304d5ae1f17e57a12834e7235062", + 
"sha256:701cd6093d63e6b8ad7009d8a92425428bc4d6e7ab8d75efbb665c806c1d79ba", + "sha256:7608a3dd5d73cb06c531b8925e0ef8d3de31fed2544a7de6c63960a1e73ea4bc", + "sha256:76ecd006d1d8f739430ec50cc872889af1f9c1b6b8f48e29941814b09b0fd3cc", + "sha256:7aa36d2b844a3e4a4b356708d79fd2c260281a7390d678a10b91ca595ddc9e99", + "sha256:7d3f553904b0c5c016d1dad058a7554c7ac4c91a789fca496e7d8347ad040653", + "sha256:7e1fe19bd6dce69d9fd159d8e4a80a8f52101380d5d3a4d374b6d3eae0e5de9c", + "sha256:8c3cb8c35ec4d9506979b4cf90ee9918bc2e49f84189d9bf5c36c0c1119c6558", + "sha256:9d6dd10d49e01571bf6e147d3b505141ffc093a06756c60b053a859cb2128b1f", + "sha256:be6cfcd8053d13f5f5eeb284aa8a814220c3da1b0078fa859011c7fffd86dab9", + "sha256:c1bb572fab8208c400adaf06a8133ac0712179a334c09224fb11393e920abcdd", + "sha256:de4418dadaa1c01d497e539210cb6baa015965526ff5afc078c57ca69160108d", + "sha256:e05cb4d9aad6233d67e0541caa7e511fa4047ed7750ec2510d466e806e0255d6", + "sha256:f05a636b4564104120111800021a92e43397bc12a5c72fed7036be8556e0029e", + "sha256:f3f501f345f24383c0000395b26b726e46758b71393267aeae0bd36f8b3ade80" + ], + "index": "pypi", + "version": "==4.5.1" + }, + "flake8": { + "hashes": [ + "sha256:6a35f5b8761f45c5513e3405f110a86bea57982c3b75b766ce7b65217abe1670", + "sha256:c01f8a3963b3571a8e6bd7a4063359aff90749e160778e03817cd9b71c9e07d2" + ], + "index": "pypi", + "version": "==3.6.0" + }, + "flake8-quotes": { + "hashes": [ + "sha256:fd9127ad8bbcf3b546fa7871a5266fd8623ce765ebe3d5aa5eabb80c01212b26" + ], + "index": "pypi", + "version": "==1.0.0" + }, + "mccabe": { + "hashes": [ + "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", + "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" + ], + "version": "==0.6.1" + }, + "pycodestyle": { + "hashes": [ + "sha256:cbc619d09254895b0d12c2c691e237b2e91e9b2ecf5e84c26b35400f93dcfb83", + "sha256:cbfca99bd594a10f674d0cd97a3d802a1fdef635d4361e1a2658de47ed261e3a" + ], + "version": "==2.4.0" + }, + "pyflakes": { + "hashes": [ + 
"sha256:9a7662ec724d0120012f6e29d6248ae3727d821bba522a0e6b356eff19126a49", + "sha256:f661252913bc1dbe7fcfcbf0af0db3f42ab65aabd1a6ca68fe5d466bace94dae" + ], + "version": "==2.0.0" + } + } +} diff --git a/README.md b/README.md index 43df882..e8ae59a 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,37 @@ GoonAuth2 is a REST API service that can be used to authorize membership in the Something Is Awful internet forum. +## Requirements + +- Pipenv +- Python3 (v3.6+) +- Redis (v5.0.0+) + ## Installation -The service is powered by **Python 3** (v3.4.3) and **Redis** (v2.8.4). Python dependencies include **Falcon**, **redis-py**, **requests**, and **gunicorn**. You can install all of these dependencies via the included `requirements.txt`: +Install dependencies with **Pipenv** via the included **Pipfile**: + +```sh +$> pipenv install +``` + +A few environment variables can be set within a **.env** file (placed in the root of this project) to customize functionality: + +- `REDIS_URL` + - **String** in the following format: `redis://[username]:[password]@[hostname]:6379` + - **Default:** "" (will attempt to connect to localhost:6379 without a username or password) +- `HASH_LIFESPAN_MINS` + - **Number** of minutes a hash is good for + - **Default:** 5 - virtualenv . - pip install -r requirements.txt +The only things stored in the database are short-lived `key:value` pairs that automatically expire in `HASH_LIFESPAN_MINS * 60` seconds. -There are a couple of values you'll need to update before the server will work. First, update `REDIS_HOST`, `REDIS_PORT`, and `REDIS_DB_NUM` to point to whatever Redis server you want to use. The only things stored in the database are short-lived `key:value` pairs that automatically expire in `HASH_LIFESPAN_MINS * 60` seconds. 
+The following values will also need to be set so that the server can access SA profiles: -You'll also want to set values for the following strings: `COOKIE_SESSIONID`, `COOKIE_SESSIONHASH`, `COOKIE_BBUSERID`, and `COOKIE_BBPASSWORD`. I opted to create an accompanying `local_settings.py` file and define them within that, but feel free to specify them as you wish. +- `COOKIE_SESSIONID` +- `COOKIE_SESSIONHASH` +- `COOKIE_BBUSERID` +- `COOKIE_BBPASSWORD`. These four values need to be taken from an existing logged-in user's cookies: @@ -19,7 +40,9 @@ These four values need to be taken from an existing logged-in user's cookies: Once everything is in place, you can start the server using `gunicorn`: - gunicorn server:app +```sh +$> pipenv run start-prod +``` ## Usage @@ -29,9 +52,11 @@ POST to `/v1/generate_hash/` with a JSON-encoded payload containing a `username` The returned payload will contain a `hash` key with a random 32-character alphanumeric value: - { - "hash": "hMPAtkx6xIEtVfqqP0X9bvEG8lU4Yypb" - } +```json +{ + "hash": "hMPAtkx6xIEtVfqqP0X9bvEG8lU4Yypb" +} +``` The hash will expire after **5 minutes** but can easily be re-generated after expiration by re-submitting the above request. 
@@ -49,6 +74,8 @@ Once the hash is in-place, POST a request to `/v1/validate_user/` with a JSON-en The returned payload will contain a `validated` key with a `boolean` value of whether or not the hash was detected : - { - "validated": true - } +```json +{ + "validated": true +} +``` diff --git a/helpers.py b/helpers.py deleted file mode 100644 index 89da4f6..0000000 --- a/helpers.py +++ /dev/null @@ -1,45 +0,0 @@ -import falcon -import json -import uuid - - -def get_json(req): - """ - Turn a request stream into a JSON dictionary - """ - try: - body = req.stream.read() - raw_json = json.loads(body.decode('utf-8')) - except Exception as ex: - ex_type = type(ex) - str_error = str(ex) - - if ex_type is ValueError: - str_error = 'You must specify a JSON-encoded body' - - raise falcon.HTTPBadRequest('JSON Error', str_error) - - return raw_json - - -def get_username(body): - """ - Pass in the request body (the output from json.loads()) and check for a username - """ - if 'username' not in body: - raise falcon.HTTPMissingParam('username') - - if not body['username']: - raise falcon.HTTPInvalidParam( - 'Username cannot be blank', - 'username' - ) - - return body['username'].replace(' ', '%20') - - -def get_hash(): - """ - Return a 32-character long random string - """ - return str(uuid.uuid4()).replace('-', '')[:32] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 820d00c..0000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -falcon==0.3.0 -gunicorn==19.4.5 -python-mimeparse==1.5.1 -redis==2.10.5 -requests==2.9.1 -six==1.10.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/helpers.py b/src/helpers.py new file mode 100644 index 0000000..f23cb61 --- /dev/null +++ b/src/helpers.py @@ -0,0 +1,45 @@ +import falcon +import json +import uuid + + +def get_json(req: falcon.Request) -> dict: + """ + Turn a request stream into a JSON dictionary + """ + try: + body = req.stream.read() + 
raw_json = json.loads(body.decode("utf-8")) + except Exception as ex: + ex_type = type(ex) + str_error = str(ex) + + if ex_type is ValueError: + str_error = "You must specify a JSON-encoded body" + + raise falcon.HTTPBadRequest("JSON Error", str_error) + + return raw_json + + +def get_username(body: dict) -> str: + """ + Pass in the request body (the output from json.loads()) and check for a username + """ + if "username" not in body: + raise falcon.HTTPMissingParam("username") + + if not body["username"]: + raise falcon.HTTPInvalidParam( + "Username cannot be blank", + "username" + ) + + return body["username"].replace(" ", "%20") + + +def get_hash() -> str: + """ + Return a 32-character long random string + """ + return str(uuid.uuid4()).replace("-", "")[:32] diff --git a/server.py b/src/server.py similarity index 60% rename from server.py rename to src/server.py index 62fce56..575602f 100644 --- a/server.py +++ b/src/server.py @@ -1,59 +1,50 @@ import json import re +import os import falcon import redis import requests -import helpers -from local_settings import COOKIE_SESSIONID, COOKIE_SESSIONHASH, COOKIE_BBUSERID, COOKIE_BBPASSWORD +from . 
import helpers """ Settings """ # The number of minutes hashes are good for before they're deleted -HASH_LIFESPAN_MINS = 5 -# A URL to look up SA users by their username -SA_PROFILE_URL = 'http://forums.somethingawful.com/member.php?action=getinfo&username=' +HASH_LIFESPAN_MINS = os.getenv("HASH_LIFESPAN_MINS", 5) # Cookies we'll need to spoof before we can verify a user's profile SA_COOKIES = { - 'sessionid': COOKIE_SESSIONID, - 'sessionhash': COOKIE_SESSIONHASH, - 'bbuserid': COOKIE_BBUSERID, - 'bbpassword': COOKIE_BBPASSWORD + "sessionid": os.getenv("COOKIE_SESSIONID"), + "sessionhash": os.getenv("COOKIE_SESSIONHASH"), + "bbuserid": os.getenv("COOKIE_BBUSERID"), + "bbpassword": os.getenv("COOKIE_BBPASSWORD"), } -REDIS_HOST = 'localhost' -REDIS_PORT = 6379 -REDIS_DB_NUM = 1 +# URL in the following format: redis://[username:password]@localhost:6379. +# DB number can be specified by updating "0" below +REDIS_URL = os.getenv("REDIS_URL", "") + "/0" +# A URL to look up SA users by their username +SA_PROFILE_URL = "http://forums.somethingawful.com/member.php?action=getinfo&username=" """ Begin Server """ # Connect to the Redis DB (and automatically decode values because they're all going to be strings) -redis_db = redis.StrictRedis( - host=REDIS_HOST, - port=REDIS_PORT, - db=REDIS_DB_NUM, - decode_responses=True) +redis_db = redis.StrictRedis.from_url(REDIS_URL, decode_responses=True) class RequireJSON(object): """ - The API is only intended to handle application/json requests and responses + The API is only intended to handle application/json requests """ def process_request(self, req, resp): - if not req.client_accepts_json: - raise falcon.HTTPNotAcceptable( - 'This API only supports JSON-encoded responses' - ) - - if req.method in ['POST']: - if 'application/json' not in req.content_type: + if req.method in ["POST"]: + if "application/json" not in req.content_type: raise falcon.HTTPUnsupportedMediaType( - 'This API only supports JSON-encoded requests' + "This API 
only supports JSON-encoded requests" ) @@ -71,8 +62,8 @@ def on_post(self, req, resp): user_hash = helpers.get_hash() redis_db.setex(username, HASH_LIFESPAN_MINS * 60, user_hash) - resp.status = falcon.HTTP_OK - resp.body = json.dumps({'hash': user_hash}) + resp.status = falcon.HTTP_200 + resp.body = json.dumps({"hash": user_hash}) class ValidateUserResource: @@ -86,11 +77,11 @@ def on_post(self, req, resp): user_hash = redis_db.get(username) if not user_hash: raise falcon.HTTPBadRequest( - 'Hash Missing', - 'A hash does not exist for this username. Run /generate_hash/ first' + "Hash Missing", + "A hash does not exist for this username. Run /generate_hash/ first" ) - # The URL to the user's profile page + # The URL to the user"s profile page profile_url = SA_PROFILE_URL + username # We can't view user profiles unless we're logged in, so we'll need to use a @@ -101,8 +92,8 @@ def on_post(self, req, resp): # Do a regex search to find the user's hash in their profile page result = re.search(user_hash, raw_profile.text) - resp.status = falcon.HTTP_OK - resp.body = json.dumps({'validated': result is not None}) + resp.status = falcon.HTTP_200 + resp.body = json.dumps({"validated": result is not None}) app = falcon.API(middleware=[ @@ -110,5 +101,5 @@ def on_post(self, req, resp): ]) generate_hash = GenerateHashResource() validate_user = ValidateUserResource() -app.add_route('/v1/generate_hash', generate_hash) -app.add_route('/v1/validate_user', validate_user) +app.add_route("/v1/generate_hash", generate_hash) +app.add_route("/v1/validate_user", validate_user) diff --git a/src/tests/__init__.py b/src/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/mockredis/__init__.py b/src/tests/mockredis/__init__.py new file mode 100755 index 0000000..668b617 --- /dev/null +++ b/src/tests/mockredis/__init__.py @@ -0,0 +1,10 @@ +############# +# mock-redis-py v2.9.3 +# +# Manually patched with https://github.com/locationlabs/mockredis/pull/124 
because otherwise we +# have to fight with bytestrings during testing +############# + +from mockredis.client import MockRedis, mock_redis_client, mock_strict_redis_client + +__all__ = ["MockRedis", "mock_redis_client", "mock_strict_redis_client"] diff --git a/src/tests/mockredis/client.py b/src/tests/mockredis/client.py new file mode 100755 index 0000000..fb33d01 --- /dev/null +++ b/src/tests/mockredis/client.py @@ -0,0 +1,1601 @@ +from __future__ import division +from collections import defaultdict +from copy import deepcopy +from itertools import chain +from datetime import datetime, timedelta +from hashlib import sha1 +from operator import add +from random import choice, sample +import re +import sys +import time +import fnmatch + +from mockredis.clock import SystemClock +from mockredis.lock import MockRedisLock +from mockredis.exceptions import RedisError, ResponseError, WatchError +from mockredis.pipeline import MockRedisPipeline +from mockredis.script import Script +from mockredis.sortedset import SortedSet + +if sys.version_info >= (3, 0): + long = int + xrange = range + basestring = str + from functools import reduce + + +class MockRedis(object): + """ + A Mock for a redis-py Redis object + + Expire functionality must be explicitly + invoked using do_expire(time). Automatic + expiry is NOT supported. + """ + + def __init__(self, + strict=False, + clock=None, + load_lua_dependencies=True, + blocking_timeout=1000, + blocking_sleep_interval=0.01, + decode_responses=False, + **kwargs): + """ + Initialize as either StrictRedis or Redis. + + Defaults to non-strict. 
+ """ + self.strict = strict + self.clock = SystemClock() if clock is None else clock + self.load_lua_dependencies = load_lua_dependencies + self.blocking_timeout = blocking_timeout + self.blocking_sleep_interval = blocking_sleep_interval + # The 'Redis' store + self.redis = defaultdict(dict) + self.redis_config = defaultdict(dict) + self.timeouts = defaultdict(dict) + # The 'PubSub' store + self.pubsub = defaultdict(list) + # Dictionary from script to sha ''Script'' + self.shas = dict() + self.decode_responses = decode_responses + + @classmethod + def from_url(cls, url, db=None, **kwargs): + return cls(**kwargs) + + # Connection Functions # + + def echo(self, msg): + return self._encode(msg) + + def ping(self): + return 'PONG' if self.decode_responses else b'PONG' + + # Transactions Functions # + + def lock(self, key, timeout=0, sleep=0): + """Emulate lock.""" + return MockRedisLock(self, key, timeout, sleep) + + def pipeline(self, transaction=True, shard_hint=None): + """Emulate a redis-python pipeline.""" + return MockRedisPipeline(self, transaction, shard_hint) + + def transaction(self, func, *watches, **kwargs): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. + + Copied directly from redis-py. 
+ """ + shard_hint = kwargs.pop('shard_hint', None) + value_from_callable = kwargs.pop('value_from_callable', False) + watch_delay = kwargs.pop('watch_delay', None) + with self.pipeline(True, shard_hint) as pipe: + while 1: + try: + if watches: + pipe.watch(*watches) + func_value = func(pipe) + exec_value = pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + time.sleep(watch_delay) + continue + + def watch(self, *argv, **kwargs): + """ + Mock does not support command buffering so watch + is a no-op + """ + pass + + def unwatch(self): + """ + Mock does not support command buffering so unwatch + is a no-op + """ + pass + + def multi(self, *argv, **kwargs): + """ + Mock does not support command buffering so multi + is a no-op + """ + pass + + def execute(self): + """Emulate the execute method. All piped commands are executed immediately + in this mock, so this is a no-op.""" + pass + + # Keys Functions # + + def type(self, key): + key = self._encode(key) + if key not in self.redis: + res = b'none' + else: + type_ = type(self.redis[key]) + if type_ is dict: + res = b'hash' + elif type_ is str: + res = b'string' + elif type_ is set: + res = b'set' + elif type_ is list: + res = b'list' + elif type_ is SortedSet: + res = b'zset' + else: + raise TypeError('unhandled type {}'.format(type_)) + if self.decode_responses: + return res.decode('utf-8') + return res + + def keys(self, pattern='*'): + """Emulate keys.""" + # making sure the pattern is unicode/str. + try: + pattern = pattern.decode('utf-8') + # This throws an AttributeError in python 3, or an + # UnicodeEncodeError in python 2 + except (AttributeError, UnicodeEncodeError): + pass + + # Make regex out of glob styled pattern. 
+ regex = fnmatch.translate(pattern) + regex = re.compile(re.sub(r'(^|[^\\])\.', r'\1[^/]', regex)) + + # Find every key that matches the pattern + return [key for key in self.redis.keys() + if regex.match(key if self.decode_responses + else key.decode('utf-8'))] + + def delete(self, *keys): + """Emulate delete.""" + key_counter = 0 + for key in map(self._encode, keys): + if key in self.redis: + del self.redis[key] + key_counter += 1 + if key in self.timeouts: + del self.timeouts[key] + return key_counter + + def __delitem__(self, name): + if self.delete(name) == 0: + # redispy doesn't correctly raise KeyError here, so we don't either + pass + + def exists(self, key): + """Emulate exists.""" + return self._encode(key) in self.redis + __contains__ = exists + + def _expire(self, key, delta): + if key not in self.redis: + return False + + self.timeouts[key] = self.clock.now() + delta + return True + + def expire(self, key, delta): + """Emulate expire""" + delta = delta if isinstance(delta, timedelta) else timedelta(seconds=delta) + return self._expire(self._encode(key), delta) + + def pexpire(self, key, milliseconds): + """Emulate pexpire""" + return self._expire(self._encode(key), timedelta(milliseconds=milliseconds)) + + def expireat(self, key, when): + """Emulate expireat""" + expire_time = datetime.fromtimestamp(when) + key = self._encode(key) + if key in self.redis: + self.timeouts[key] = expire_time + return True + return False + + def ttl(self, key): + """ + Emulate ttl + + Even though the official redis commands documentation at http://redis.io/commands/ttl + states "Return value: Integer reply: TTL in seconds, -2 when key does not exist or -1 + when key does not have a timeout." the redis-py lib returns None for both these cases. + The lib behavior has been emulated here. + + :param key: key for which ttl is requested. 
+ :returns: the number of seconds till timeout, None if the key does not exist or if the + key has no timeout(as per the redis-py lib behavior). + """ + value = self.pttl(key) + if value is None or value < 0: + return value + return value // 1000 + + def pttl(self, key): + """ + Emulate pttl + + :param key: key for which pttl is requested. + :returns: the number of milliseconds till timeout, None if the key does not exist or if the + key has no timeout(as per the redis-py lib behavior). + """ + """ + Returns time to live in milliseconds if output_ms is True, else returns seconds. + """ + key = self._encode(key) + if key not in self.redis: + # as of redis 2.8, -2 is returned if the key does not exist + return long(-2) if self.strict else None + if key not in self.timeouts: + # as of redis 2.8, -1 is returned if the key is persistent + # redis-py returns None; command docs say -1 + return long(-1) if self.strict else None + + time_to_live = get_total_milliseconds(self.timeouts[key] - self.clock.now()) + return long(max(-1, time_to_live)) + + def do_expire(self): + """ + Expire objects assuming now == time + """ + # Deep copy to avoid RuntimeError: dictionary changed size during iteration + _timeouts = deepcopy(self.timeouts) + for key, value in _timeouts.items(): + if value - self.clock.now() < timedelta(0): + del self.timeouts[key] + # removing the expired key + if key in self.redis: + self.redis.pop(key, None) + + def flushdb(self): + self.redis.clear() + self.pubsub.clear() + self.timeouts.clear() + + def rename(self, old_key, new_key): + return self._rename(old_key, new_key) + + def renamenx(self, old_key, new_key): + return 1 if self._rename(old_key, new_key, True) else 0 + + def _rename(self, old_key, new_key, nx=False): + old_key = self._encode(old_key) + new_key = self._encode(new_key) + if old_key in self.redis and (not nx or new_key not in self.redis): + self.redis[new_key] = self.redis.pop(old_key) + return True + return False + + def dbsize(self): + 
return len(self.redis.keys()) + + # String Functions # + + def get(self, key): + key = self._encode(key) + return self.redis.get(key) + + def __getitem__(self, name): + """ + Return the value at key ``name``, raises a KeyError if the key + doesn't exist. + """ + value = self.get(name) + if value is not None: + return value + raise KeyError(name) + + def mget(self, keys, *args): + args = self._list_or_args(keys, args) + return [self.get(arg) for arg in args] + + def set(self, key, value, ex=None, px=None, nx=False, xx=False): + """ + Set the ``value`` for the ``key`` in the context of the provided kwargs. + + As per the behavior of the redis-py lib: + If nx and xx are both set, the function does nothing and None is returned. + If px and ex are both set, the preference is given to px. + If the key is not set for some reason, the lib function returns None. + """ + key = self._encode(key) + value = self._encode(value) + + if nx and xx: + return None + mode = "nx" if nx else "xx" if xx else None + if self._should_set(key, mode): + expire = None + if ex is not None: + expire = ex if isinstance(ex, timedelta) else timedelta(seconds=ex) + if px is not None: + expire = px if isinstance(px, timedelta) else timedelta(milliseconds=px) + + if expire is not None and expire.total_seconds() <= 0: + raise ResponseError("invalid expire time in SETEX") + + result = self._set(key, value) + if expire: + self._expire(key, expire) + + return result + __setitem__ = set + + def getset(self, key, value): + old_value = self.get(key) + self.set(key, value) + return old_value + + def _set(self, key, value): + self.redis[key] = self._encode(value) + + # removing the timeout + if key in self.timeouts: + self.timeouts.pop(key, None) + + return True + + def _should_set(self, key, mode): + """ + Determine if it is okay to set a key. + + If the mode is None, returns True, otherwise, returns True of false based on + the value of ``key`` and the ``mode`` (nx | xx). 
+ """ + + if mode is None or mode not in ["nx", "xx"]: + return True + + if mode == "nx": + if key in self.redis: + # nx means set only if key is absent + # false if the key already exists + return False + elif key not in self.redis: + # at this point mode can only be xx + # xx means set only if the key already exists + # false if is absent + return False + # for all other cases, return true + return True + + def setex(self, key, time, value): + """ + Set the value of ``key`` to ``value`` that expires in ``time`` + seconds. ``time`` can be represented by an integer or a Python + timedelta object. + """ + if not self.strict: + # when not strict mode swap value and time args order + time, value = value, time + return self.set(key, value, ex=time) + + def psetex(self, key, time, value): + """ + Set the value of ``key`` to ``value`` that expires in ``time`` + milliseconds. ``time`` can be represented by an integer or a Python + timedelta object. + """ + return self.set(key, value, px=time) + + def setnx(self, key, value): + """Set the value of ``key`` to ``value`` if key doesn't exist""" + return self.set(key, value, nx=True) + + def mset(self, *args, **kwargs): + """ + Sets key/values based on a mapping. Mapping can be supplied as a single + dictionary argument or as kwargs. + """ + if args: + if len(args) != 1 or not isinstance(args[0], dict): + raise RedisError('MSET requires **kwargs or a single dict arg') + mapping = args[0] + else: + mapping = kwargs + for key, value in mapping.items(): + self.set(key, value) + return True + + def msetnx(self, *args, **kwargs): + """ + Sets key/values based on a mapping if none of the keys are already set. + Mapping can be supplied as a single dictionary argument or as kwargs. + Returns a boolean indicating if the operation was successful. 
+ """ + if args: + if len(args) != 1 or not isinstance(args[0], dict): + raise RedisError('MSETNX requires **kwargs or a single dict arg') + mapping = args[0] + else: + mapping = kwargs + + for key in mapping.keys(): + if self._encode(key) in self.redis: + return False + for key, value in mapping.items(): + self.set(key, value) + + return True + + def decr(self, key, amount=1): + key = self._encode(key) + previous_value = long(self.redis.get(key, '0')) + self.redis[key] = self._encode(previous_value - amount) + return long(self.redis[key]) + + decrby = decr + + def incr(self, key, amount=1): + """Emulate incr.""" + key = self._encode(key) + previous_value = long(self.redis.get(key, '0')) + self.redis[key] = self._encode(previous_value + amount) + return long(self.redis[key]) + + incrby = incr + + def setbit(self, key, offset, value): + """ + Set the bit at ``offset`` in ``key`` to ``value``. + """ + key = self._encode(key) + index, bits, mask = self._get_bits_and_offset(key, offset) + + if index >= len(bits): + bits.extend(b"\x00" * (index + 1 - len(bits))) + + prev_val = 1 if (bits[index] & mask) else 0 + + if value: + bits[index] |= mask + else: + bits[index] &= ~mask + + self.redis[key] = bytes(bits) + + return prev_val + + def getbit(self, key, offset): + """ + Returns the bit value at ``offset`` in ``key``. 
+ """ + key = self._encode(key) + index, bits, mask = self._get_bits_and_offset(key, offset) + + if index >= len(bits): + return 0 + + return 1 if (bits[index] & mask) else 0 + + def _get_bits_and_offset(self, key, offset): + bits = bytearray(self.redis.get(key, b"")) + index, position = divmod(offset, 8) + mask = 128 >> position + return index, bits, mask + + # Hash Functions # + + def hexists(self, hashkey, attribute): + """Emulate hexists.""" + + redis_hash = self._get_hash(hashkey, 'HEXISTS') + return self._encode(attribute) in redis_hash + + def hget(self, hashkey, attribute): + """Emulate hget.""" + + redis_hash = self._get_hash(hashkey, 'HGET') + return redis_hash.get(self._encode(attribute)) + + def hgetall(self, hashkey): + """Emulate hgetall.""" + + redis_hash = self._get_hash(hashkey, 'HGETALL') + return dict(redis_hash) + + def hdel(self, hashkey, *keys): + """Emulate hdel""" + + redis_hash = self._get_hash(hashkey, 'HDEL') + count = 0 + for key in keys: + attribute = self._encode(key) + if attribute in redis_hash: + count += 1 + del redis_hash[attribute] + if not redis_hash: + self.delete(hashkey) + return count + + def hlen(self, hashkey): + """Emulate hlen.""" + redis_hash = self._get_hash(hashkey, 'HLEN') + return len(redis_hash) + + def hmset(self, hashkey, value): + """Emulate hmset.""" + + redis_hash = self._get_hash(hashkey, 'HMSET', create=True) + for key, value in value.items(): + attribute = self._encode(key) + redis_hash[attribute] = self._encode(value) + return True + + def hmget(self, hashkey, keys, *args): + """Emulate hmget.""" + + redis_hash = self._get_hash(hashkey, 'HMGET') + attributes = self._list_or_args(keys, args) + return [redis_hash.get(self._encode(attribute)) for attribute in attributes] + + def hset(self, hashkey, attribute, value): + """Emulate hset.""" + + redis_hash = self._get_hash(hashkey, 'HSET', create=True) + attribute = self._encode(attribute) + attribute_present = attribute in redis_hash + redis_hash[attribute] = 
self._encode(value) + return long(0) if attribute_present else long(1) + + def hsetnx(self, hashkey, attribute, value): + """Emulate hsetnx.""" + + redis_hash = self._get_hash(hashkey, 'HSETNX', create=True) + attribute = self._encode(attribute) + if attribute in redis_hash: + return long(0) + else: + redis_hash[attribute] = self._encode(value) + return long(1) + + def hincrby(self, hashkey, attribute, increment=1): + """Emulate hincrby.""" + + return self._hincrby(hashkey, attribute, 'HINCRBY', long, increment) + + def hincrbyfloat(self, hashkey, attribute, increment=1.0): + """Emulate hincrbyfloat.""" + + return self._hincrby(hashkey, attribute, 'HINCRBYFLOAT', float, increment) + + def _hincrby(self, hashkey, attribute, command, type_, increment): + """Shared hincrby and hincrbyfloat routine""" + redis_hash = self._get_hash(hashkey, command, create=True) + attribute = self._encode(attribute) + previous_value = type_(redis_hash.get(attribute, '0')) + redis_hash[attribute] = self._encode(previous_value + increment) + return type_(redis_hash[attribute]) + + def hkeys(self, hashkey): + """Emulate hkeys.""" + + redis_hash = self._get_hash(hashkey, 'HKEYS') + return redis_hash.keys() + + def hvals(self, hashkey): + """Emulate hvals.""" + + redis_hash = self._get_hash(hashkey, 'HVALS') + return redis_hash.values() + + # List Functions # + + def lrange(self, key, start, stop): + """Emulate lrange.""" + redis_list = self._get_list(key, 'LRANGE') + start, stop = self._translate_range(len(redis_list), start, stop) + return redis_list[start:stop + 1] + + def lindex(self, key, index): + """Emulate lindex.""" + + redis_list = self._get_list(key, 'LINDEX') + + if self._encode(key) not in self.redis: + return None + + try: + return redis_list[index] + except (IndexError): + # Redis returns nil if the index doesn't exist + return None + + def llen(self, key): + """Emulate llen.""" + redis_list = self._get_list(key, 'LLEN') + + # Redis returns 0 if list doesn't exist + return 
len(redis_list) + + def _blocking_pop(self, pop_func, keys, timeout): + """Emulate blocking pop functionality""" + if not isinstance(timeout, (int, long)): + raise RuntimeError('timeout is not an integer or out of range') + + if timeout is None or timeout == 0: + timeout = self.blocking_timeout + + if isinstance(keys, basestring): + keys = [keys] + else: + keys = list(keys) + + elapsed_time = 0 + start = time.time() + while elapsed_time < timeout: + key, val = self._pop_first_available(pop_func, keys) + if val: + return key, val + # small delay to avoid high cpu utilization + time.sleep(self.blocking_sleep_interval) + elapsed_time = time.time() - start + return None + + def _pop_first_available(self, pop_func, keys): + for key in keys: + val = pop_func(key) + if val: + return self._encode(key), val + return None, None + + def blpop(self, keys, timeout=0): + """Emulate blpop""" + return self._blocking_pop(self.lpop, keys, timeout) + + def brpop(self, keys, timeout=0): + """Emulate brpop""" + return self._blocking_pop(self.rpop, keys, timeout) + + def lpop(self, key): + """Emulate lpop.""" + redis_list = self._get_list(key, 'LPOP') + + if self._encode(key) not in self.redis: + return None + + try: + value = redis_list.pop(0) + if len(redis_list) == 0: + self.delete(key) + return value + except (IndexError): + # Redis returns nil if popping from an empty list + return None + + def lpush(self, key, *args): + """Emulate lpush.""" + redis_list = self._get_list(key, 'LPUSH', create=True) + + # Creates the list at this key if it doesn't exist, and appends args to its beginning + args_reversed = [self._encode(arg) for arg in args] + args_reversed.reverse() + updated_list = args_reversed + redis_list + self.redis[self._encode(key)] = updated_list + + # Return the length of the list after the push operation + return len(updated_list) + + def rpop(self, key): + """Emulate lpop.""" + redis_list = self._get_list(key, 'RPOP') + + if self._encode(key) not in self.redis: + return 
None + + try: + value = redis_list.pop() + if len(redis_list) == 0: + self.delete(key) + return value + except (IndexError): + # Redis returns nil if popping from an empty list + return None + + def rpush(self, key, *args): + """Emulate rpush.""" + redis_list = self._get_list(key, 'RPUSH', create=True) + + # Creates the list at this key if it doesn't exist, and appends args to it + redis_list.extend(map(self._encode, args)) + + # Return the length of the list after the push operation + return len(redis_list) + + def lrem(self, key, value, count=0): + """Emulate lrem.""" + value = self._encode(value) + redis_list = self._get_list(key, 'LREM') + removed_count = 0 + if self._encode(key) in self.redis: + if count == 0: + # Remove all ocurrences + while redis_list.count(value): + redis_list.remove(value) + removed_count += 1 + elif count > 0: + counter = 0 + # remove first 'count' ocurrences + while redis_list.count(value): + redis_list.remove(value) + counter += 1 + removed_count += 1 + if counter >= count: + break + elif count < 0: + # remove last 'count' ocurrences + counter = -count + new_list = [] + for v in reversed(redis_list): + if v == value and counter > 0: + counter -= 1 + removed_count += 1 + else: + new_list.append(v) + redis_list[:] = list(reversed(new_list)) + if removed_count > 0 and len(redis_list) == 0: + self.delete(key) + return removed_count + + def ltrim(self, key, start, stop): + """Emulate ltrim.""" + redis_list = self._get_list(key, 'LTRIM') + if redis_list: + start, stop = self._translate_range(len(redis_list), start, stop) + self.redis[self._encode(key)] = redis_list[start:stop + 1] + return True + + def rpoplpush(self, source, destination): + """Emulate rpoplpush""" + transfer_item = self.rpop(source) + if transfer_item is not None: + self.lpush(destination, transfer_item) + return transfer_item + + def brpoplpush(self, source, destination, timeout=0): + """Emulate brpoplpush""" + transfer_item = self.brpop(source, timeout) + if transfer_item 
is None: + return None + + key, val = transfer_item + self.lpush(destination, val) + return val + + def lset(self, key, index, value): + """Emulate lset.""" + redis_list = self._get_list(key, 'LSET') + if redis_list is None: + raise ResponseError("no such key") + try: + redis_list[index] = self._encode(value) + except IndexError: + raise ResponseError("index out of range") + + def sort(self, name, + start=None, + num=None, + by=None, + get=None, + desc=False, + alpha=False, + store=None, + groups=False): + # check valid parameter combos + if [start, num] != [None, None] and None in [start, num]: + raise ValueError('start and num must both be specified together') + + # check up-front if there's anything to actually do + items = num != 0 and self.get(name) + if not items: + if store: + return 0 + else: + return [] + + by = self._encode(by) if by is not None else by + by_special = dict(zip(('*', 'nosort', '#'), + [self._encode(x) for x in (b'*', b'nosort', b'#')])) + # always organize the items as tuples of the value from the list and the sort key + if by and by_special['*'] in by: + items = [(i, self.get(by.replace(by_special['*'], self._encode(i)))) for i in items] + elif by in [None, by_special['nosort']]: + items = [(i, i) for i in items] + else: + raise ValueError('invalid value for "by": %s' % by) + + if by != by_special['nosort']: + # if sorting, do alpha sort or float (default) and take desc flag into account + sort_type = alpha and str or float + items.sort(key=lambda x: sort_type(x[1]), reverse=bool(desc)) + + # results is a list of lists to support different styles of get and also groups + results = [] + if get: + if isinstance(get, basestring): + # always deal with get specifiers as a list + get = [get] + for g in map(self._encode, get): + if g == by_special['#']: + results.append([self.get(i) for i in items]) + else: + results.append([self.get(g.replace(by_special['*'], self._encode(i[0]))) for i in items]) + else: + # if not using GET then returning just 
the item itself + results.append([i[0] for i in items]) + + # results to either list of tuples or list of values + if len(results) > 1: + results = list(zip(*results)) + elif results: + results = results[0] + + # apply the 'start' and 'num' to the results + if not start: + start = 0 + if not num: + if start: + results = results[start:] + else: + end = start + num + results = results[start:end] + + # if more than one GET then flatten if groups not wanted + if get and len(get) > 1: + if not groups: + results = list(chain(*results)) + + # either store value and return length of results or just return results + if store: + self.redis[self._encode(store)] = results + return len(results) + else: + return results + + # SCAN COMMANDS # + + def _common_scan(self, values_function, cursor='0', match=None, count=10, key=None): + """ + Common scanning skeleton. + + :param key: optional function used to identify what 'match' is applied to + """ + if count is None: + count = 10 + cursor = int(cursor) + count = int(count) + if not count: + raise ValueError('if specified, count must be > 0: %s' % count) + + values = values_function() + if cursor + count >= len(values): + # we reached the end, back to zero + result_cursor = 0 + else: + result_cursor = cursor + count + + values = values[cursor:cursor+count] + + if match is not None: + m_special = dict(zip(('^', '*', '.*', '$'), + [self._encode(x) for x in (b'^', b'\\*', b'.*', b'$')])) + regex = re.compile(m_special['^'] + re.escape(self._encode(match)).replace(m_special['*'], + m_special['.*']) + m_special['$']) + if not key: + key = lambda v: v + values = [v for v in values if regex.match(key(v))] + + return [result_cursor, values] + + def scan(self, cursor='0', match=None, count=10): + """Emulate scan.""" + def value_function(): + return sorted(self.redis.keys()) # sorted list for consistent order + return self._common_scan(value_function, cursor=cursor, match=match, count=count) + + def scan_iter(self, match=None, count=10): + 
"""Emulate scan_iter.""" + cursor = '0' + while cursor != 0: + cursor, data = self.scan(cursor=cursor, match=match, count=count) + for item in data: + yield item + + def sscan(self, name, cursor='0', match=None, count=10): + """Emulate sscan.""" + def value_function(): + members = list(self.smembers(name)) + members.sort() # sort for consistent order + return members + return self._common_scan(value_function, cursor=cursor, match=match, count=count) + + def sscan_iter(self, name, match=None, count=10): + """Emulate sscan_iter.""" + cursor = '0' + while cursor != 0: + cursor, data = self.sscan(name, cursor=cursor, + match=match, count=count) + for item in data: + yield item + + def zscan(self, name, cursor='0', match=None, count=10): + """Emulate zscan.""" + def value_function(): + values = self.zrange(name, 0, -1, withscores=True) + values.sort(key=lambda x: x[1]) # sort for consistent order + return values + return self._common_scan(value_function, cursor=cursor, match=match, count=count, key=lambda v: v[0]) # noqa + + def zscan_iter(self, name, match=None, count=10): + """Emulate zscan_iter.""" + cursor = '0' + while cursor != 0: + cursor, data = self.zscan(name, cursor=cursor, match=match, + count=count) + for item in data: + yield item + + def hscan(self, name, cursor='0', match=None, count=10): + """Emulate hscan.""" + def value_function(): + values = self.hgetall(name) + values = list(values.items()) # list of tuples for sorting and matching + values.sort(key=lambda x: x[0]) # sort for consistent order + return values + scanned = self._common_scan(value_function, cursor=cursor, match=match, count=count, key=lambda v: v[0]) # noqa + scanned[1] = dict(scanned[1]) # from list of tuples back to dict + return scanned + + def hscan_iter(self, name, match=None, count=10): + """Emulate hscan_iter.""" + cursor = '0' + while cursor != 0: + cursor, data = self.hscan(name, cursor=cursor, + match=match, count=count) + for item in data.items(): + yield item + + # SET 
COMMANDS # + + def sadd(self, key, *values): + """Emulate sadd.""" + if len(values) == 0: + raise ResponseError("wrong number of arguments for 'sadd' command") + redis_set = self._get_set(key, 'SADD', create=True) + before_count = len(redis_set) + redis_set.update(map(self._encode, values)) + after_count = len(redis_set) + return after_count - before_count + + def scard(self, key): + """Emulate scard.""" + redis_set = self._get_set(key, 'SADD') + return len(redis_set) + + def sdiff(self, keys, *args): + """Emulate sdiff.""" + func = lambda left, right: left.difference(right) + return self._apply_to_sets(func, "SDIFF", keys, *args) + + def sdiffstore(self, dest, keys, *args): + """Emulate sdiffstore.""" + result = self.sdiff(keys, *args) + self.redis[self._encode(dest)] = result + return len(result) + + def sinter(self, keys, *args): + """Emulate sinter.""" + func = lambda left, right: left.intersection(right) + return self._apply_to_sets(func, "SINTER", keys, *args) + + def sinterstore(self, dest, keys, *args): + """Emulate sinterstore.""" + result = self.sinter(keys, *args) + self.redis[self._encode(dest)] = result + return len(result) + + def sismember(self, name, value): + """Emulate sismember.""" + redis_set = self._get_set(name, 'SISMEMBER') + if not redis_set: + return 0 + + result = self._encode(value) in redis_set + return 1 if result else 0 + + def smembers(self, name): + """Emulate smembers.""" + return self._get_set(name, 'SMEMBERS').copy() + + def smove(self, src, dst, value): + """Emulate smove.""" + src_set = self._get_set(src, 'SMOVE') + dst_set = self._get_set(dst, 'SMOVE') + value = self._encode(value) + + if value not in src_set: + return False + + src_set.discard(value) + dst_set.add(value) + self.redis[self._encode(src)], self.redis[self._encode(dst)] = src_set, dst_set + return True + + def spop(self, name): + """Emulate spop.""" + redis_set = self._get_set(name, 'SPOP') + if not redis_set: + return None + member = choice(list(redis_set)) + 
redis_set.remove(member) + if len(redis_set) == 0: + self.delete(name) + return member + + def srandmember(self, name, number=None): + """Emulate srandmember.""" + redis_set = self._get_set(name, 'SRANDMEMBER') + if not redis_set: + return None if number is None else [] + if number is None: + return choice(list(redis_set)) + elif number > 0: + return sample(list(redis_set), min(number, len(redis_set))) + else: + return [choice(list(redis_set)) for _ in xrange(abs(number))] + + def srem(self, key, *values): + """Emulate srem.""" + redis_set = self._get_set(key, 'SREM') + if not redis_set: + return 0 + before_count = len(redis_set) + for value in values: + redis_set.discard(self._encode(value)) + after_count = len(redis_set) + if before_count > 0 and len(redis_set) == 0: + self.delete(key) + return before_count - after_count + + def sunion(self, keys, *args): + """Emulate sunion.""" + func = lambda left, right: left.union(right) + return self._apply_to_sets(func, "SUNION", keys, *args) + + def sunionstore(self, dest, keys, *args): + """Emulate sunionstore.""" + result = self.sunion(keys, *args) + self.redis[self._encode(dest)] = result + return len(result) + + # SORTED SET COMMANDS # + + def zadd(self, name, *args, **kwargs): + zset = self._get_zset(name, "ZADD", create=True) + + pieces = [] + + # args + if len(args) % 2 != 0: + raise RedisError("ZADD requires an equal number of " + "values and scores") + for i in xrange(len(args) // 2): + # interpretation of args order depends on whether Redis + # or StrictRedis is used + score = args[2 * i + (0 if self.strict else 1)] + member = args[2 * i + (1 if self.strict else 0)] + pieces.append((member, score)) + + # kwargs + pieces.extend(kwargs.items()) + + insert_count = lambda member, score: 1 if zset.insert(self._encode(member), float(score)) else 0 # noqa + return sum((insert_count(member, score) for member, score in pieces)) + + def zcard(self, name): + zset = self._get_zset(name, "ZCARD") + + return len(zset) if zset 
is not None else 0 + + def zcount(self, name, min, max): + zset = self._get_zset(name, "ZCOUNT") + + if not zset: + return 0 + + return len(zset.scorerange(float(min), float(max))) + + def zincrby(self, name, value, amount=1): + zset = self._get_zset(name, "ZINCRBY", create=True) + + value = self._encode(value) + score = zset.score(value) or 0.0 + score += float(amount) + zset[value] = score + return score + + def zinterstore(self, dest, keys, aggregate=None): + aggregate_func = self._aggregate_func(aggregate) + + members = {} + + for key in keys: + zset = self._get_zset(key, "ZINTERSTORE") + if not zset: + return 0 + + for score, member in zset: + members.setdefault(member, []).append(score) + + intersection = SortedSet() + for member, scores in members.items(): + if len(scores) != len(keys): + continue + intersection[member] = reduce(aggregate_func, scores) + + # always override existing keys + self.redis[self._encode(dest)] = intersection + return len(intersection) + + def zrange(self, name, start, end, desc=False, withscores=False, + score_cast_func=float): + zset = self._get_zset(name, "ZRANGE") + + if not zset: + return [] + + start, end = self._translate_range(len(zset), start, end) + + func = self._range_func(withscores, score_cast_func) + return [func(item) for item in zset.range(start, end, desc)] + + def zrangebyscore(self, name, min, max, start=None, num=None, + withscores=False, score_cast_func=float): + if (start is None) ^ (num is None): + raise RedisError('`start` and `num` must both be specified') + + zset = self._get_zset(name, "ZRANGEBYSCORE") + + if not zset: + return [] + + func = self._range_func(withscores, score_cast_func) + include_start, min = self._score_inclusive(min) + include_end, max = self._score_inclusive(max) + scorerange = zset.scorerange(min, max, start_inclusive=include_start, end_inclusive=include_end) # noqa + if start is not None and num is not None: + start, num = self._translate_limit(len(scorerange), int(start), int(num)) 
+ scorerange = scorerange[start:start + num] + return [func(item) for item in scorerange] + + def zrank(self, name, value): + zset = self._get_zset(name, "ZRANK") + + return zset.rank(self._encode(value)) if zset else None + + def zrem(self, name, *values): + zset = self._get_zset(name, "ZREM") + + if not zset: + return 0 + + count_removals = lambda value: 1 if zset.remove(self._encode(value)) else 0 + removal_count = sum((count_removals(value) for value in values)) + if removal_count > 0 and len(zset) == 0: + self.delete(name) + return removal_count + + def zremrangebyrank(self, name, start, end): + zset = self._get_zset(name, "ZREMRANGEBYRANK") + + if not zset: + return 0 + + start, end = self._translate_range(len(zset), start, end) + count_removals = lambda score, member: 1 if zset.remove(member) else 0 + removal_count = sum((count_removals(score, member) for score, member in zset.range(start, end))) # noqa + if removal_count > 0 and len(zset) == 0: + self.delete(name) + return removal_count + + def zremrangebyscore(self, name, min, max): + zset = self._get_zset(name, "ZREMRANGEBYSCORE") + + if not zset: + return 0 + + count_removals = lambda score, member: 1 if zset.remove(member) else 0 + include_start, min = self._score_inclusive(min) + include_end, max = self._score_inclusive(max) + + removal_count = sum((count_removals(score, member) + for score, member in zset.scorerange(min, max, + start_inclusive=include_start, + end_inclusive=include_end))) + if removal_count > 0 and len(zset) == 0: + self.delete(name) + return removal_count + + def zrevrange(self, name, start, end, withscores=False, + score_cast_func=float): + return self.zrange(name, start, end, + desc=True, withscores=withscores, score_cast_func=score_cast_func) + + def zrevrangebyscore(self, name, max, min, start=None, num=None, + withscores=False, score_cast_func=float): + + if (start is None) ^ (num is None): + raise RedisError('`start` and `num` must both be specified') + + zset = 
self._get_zset(name, "ZREVRANGEBYSCORE") + if not zset: + return [] + + func = self._range_func(withscores, score_cast_func) + include_start, min = self._score_inclusive(min) + include_end, max = self._score_inclusive(max) + + scorerange = [x for x in reversed(zset.scorerange(float(min), float(max), + start_inclusive=include_start, + end_inclusive=include_end))] + if start is not None and num is not None: + start, num = self._translate_limit(len(scorerange), int(start), int(num)) + scorerange = scorerange[start:start + num] + return [func(item) for item in scorerange] + + def zrevrank(self, name, value): + zset = self._get_zset(name, "ZREVRANK") + + if zset is None: + return None + + rank = zset.rank(self._encode(value)) + if rank is None: + return None + + return len(zset) - rank - 1 + + def zscore(self, name, value): + zset = self._get_zset(name, "ZSCORE") + + return zset.score(self._encode(value)) if zset is not None else None + + def zunionstore(self, dest, keys, aggregate=None): + union = SortedSet() + aggregate_func = self._aggregate_func(aggregate) + + for key in keys: + zset = self._get_zset(key, "ZUNIONSTORE") + if not zset: + continue + + for score, member in zset: + if member in union: + union[member] = aggregate_func(union[member], score) + else: + union[member] = score + + # always override existing keys + self.redis[self._encode(dest)] = union + return len(union) + + # Script Commands # + + def eval(self, script, numkeys, *keys_and_args): + """Emulate eval""" + sha = self.script_load(script) + return self.evalsha(sha, numkeys, *keys_and_args) + + def evalsha(self, sha, numkeys, *keys_and_args): + """Emulates evalsha""" + if not self.script_exists(sha)[0]: + raise RedisError("Sha not registered") + script_callable = Script(self, self.shas[sha], self.load_lua_dependencies) + numkeys = max(numkeys, 0) + keys = keys_and_args[:numkeys] + args = keys_and_args[numkeys:] + return script_callable(keys, args) + + def script_exists(self, *args): + """Emulates 
script_exists""" + return [arg in self.shas for arg in args] + + def script_flush(self): + """Emulate script_flush""" + self.shas.clear() + + def script_kill(self): + """Emulate script_kill""" + """XXX: To be implemented, should not be called before that.""" + raise NotImplementedError("Not yet implemented.") + + def script_load(self, script): + """Emulate script_load""" + sha_digest = sha1(script.encode("utf-8")).hexdigest() + self.shas[sha_digest] = script + return sha_digest + + def register_script(self, script): + """Emulate register_script""" + return Script(self, script, self.load_lua_dependencies) + + def call(self, command, *args): + """ + Sends call to the function, whose name is specified by command. + + Used by Script invocations and normalizes calls using standard + Redis arguments to use the expected redis-py arguments. + """ + command = self._normalize_command_name(command) + args = self._normalize_command_args(command, *args) + + redis_function = getattr(self, command) + value = redis_function(*args) + return self._normalize_command_response(command, value) + + def _normalize_command_name(self, command): + """ + Modifies the command string to match the redis client method name. + """ + command = command.lower() + + if command == 'del': + return 'delete' + + return command + + def _normalize_command_args(self, command, *args): + """ + Modifies the command arguments to match the + strictness of the redis client. 
+ """ + if command == 'zadd' and not self.strict and len(args) >= 3: + # Reorder score and name + zadd_args = [x for tup in zip(args[2::2], args[1::2]) for x in tup] + return [args[0]] + zadd_args + + if command in ('zrangebyscore', 'zrevrangebyscore'): + # expected format is: name min max start num with_scores score_cast_func + if len(args) <= 3: + # just plain min/max + return args + + start, num = None, None + withscores = False + + for i, arg in enumerate(args[3:], 3): + # keywords are case-insensitive + lower_arg = self._encode(arg).lower() + + # handle "limit" + if lower_arg == b"limit" and i + 2 < len(args): + start, num = args[i + 1], args[i + 2] + + # handle "withscores" + if lower_arg == b"withscores": + withscores = True + + # do not expect to set score_cast_func + + return args[:3] + (start, num, withscores) + + return args + + def _normalize_command_response(self, command, response): + if command in ('zrange', 'zrevrange', 'zrangebyscore', 'zrevrangebyscore'): + if response and isinstance(response[0], tuple): + return [value for tpl in response for value in tpl] + + return response + + # Config Set/Get commands # + + def config_set(self, name, value): + """ + Set a configuration parameter. + """ + self.redis_config[name] = value + + def config_get(self, pattern='*'): + """ + Get one or more configuration parameters. + """ + result = {} + for name, value in self.redis_config.items(): + if fnmatch.fnmatch(name, pattern): + try: + result[name] = int(value) + except ValueError: + result[name] = value + return result + + # PubSub commands # + + def publish(self, channel, message): + self.pubsub[channel].append(message) + + # Internal # + + def _get_list(self, key, operation, create=False): + """ + Get (and maybe create) a list by name. + """ + return self._get_by_type(key, operation, create, b'list', []) + + def _get_set(self, key, operation, create=False): + """ + Get (and maybe create) a set by name. 
+ """ + return self._get_by_type(key, operation, create, b'set', set()) + + def _get_hash(self, name, operation, create=False): + """ + Get (and maybe create) a hash by name. + """ + return self._get_by_type(name, operation, create, b'hash', {}) + + def _get_zset(self, name, operation, create=False): + """ + Get (and maybe create) a sorted set by name. + """ + return self._get_by_type(name, operation, create, b'zset', SortedSet(), return_default=False) # noqa + + def _get_by_type(self, key, operation, create, type_, default, return_default=True): + """ + Get (and maybe create) a redis data structure by name and type. + """ + key = self._encode(key) + keys = [self._encode(x) for x in (type_, b'none')] + if self.type(key) in keys: + if create: + return self.redis.setdefault(key, default) + else: + return self.redis.get(key, default if return_default else None) + + raise TypeError("{} requires a {}".format(operation, type_)) + + def _translate_range(self, len_, start, end): + """ + Translate range to valid bounds. + """ + if start < 0: + start += len_ + start = max(0, min(start, len_)) + if end < 0: + end += len_ + end = max(-1, min(end, len_ - 1)) + return start, end + + def _translate_limit(self, len_, start, num): + """ + Translate limit to valid bounds. + """ + if start > len_ or num <= 0: + return 0, 0 + return min(start, len_), num + + def _range_func(self, withscores, score_cast_func): + """ + Return a suitable function from (score, member) + """ + if withscores: + return lambda score_member: (score_member[1], score_cast_func(self._encode(score_member[0]))) # noqa + else: + return lambda score_member: score_member[1] + + def _aggregate_func(self, aggregate): + """ + Return a suitable aggregate score function. 
+ """ + funcs = {"sum": add, "min": min, "max": max} + func_name = aggregate.lower() if aggregate else 'sum' + try: + return funcs[func_name] + except KeyError: + raise TypeError("Unsupported aggregate: {}".format(aggregate)) + + def _apply_to_sets(self, func, operation, keys, *args): + """Helper function for sdiff, sinter, and sunion""" + keys = self._list_or_args(keys, args) + if not keys: + raise TypeError("{} takes at least two arguments".format(operation.lower())) + left = self._get_set(keys[0], operation) or set() + for key in keys[1:]: + right = self._get_set(key, operation) or set() + left = func(left, right) + return left + + def _list_or_args(self, keys, args): + """ + Shamelessly copied from redis-py. + """ + # returns a single list combining keys and args + try: + iter(keys) + # a string can be iterated, but indicates + # keys wasn't passed as a list + if isinstance(keys, basestring): + keys = [keys] + except TypeError: + keys = [keys] + if args: + keys.extend(args) + return keys + + def _score_inclusive(self, score): + if isinstance(score, basestring) and score[0] == '(': + return False, float(score[1:]) + return True, float(score) + + def _encode(self, value): + "Return a bytestring representation of the value. Taken from redis-py connection.py" + if isinstance(value, bytes): + value = value + elif isinstance(value, (int, long)): + value = str(value).encode('utf-8') + elif isinstance(value, float): + value = repr(value).encode('utf-8') + elif not isinstance(value, basestring): + value = str(value).encode('utf-8') + else: + value = value.encode('utf-8', 'strict') + + if self.decode_responses: + return value.decode('utf-8') + return value + + +def get_total_milliseconds(td): + return int((td.days * 24 * 60 * 60 + td.seconds) * 1000 + td.microseconds / 1000.0) + + +def mock_redis_client(**kwargs): + """ + Mock common.util.redis_client so we + can return a MockRedis object + instead of a Redis object. 
+ """ + return MockRedis(**kwargs) + +mock_redis_client.from_url = mock_redis_client + + +def mock_strict_redis_client(**kwargs): + """ + Mock common.util.redis_client so we + can return a MockRedis object + instead of a StrictRedis object. + """ + return MockRedis(strict=True, **kwargs) + +mock_strict_redis_client.from_url = mock_strict_redis_client diff --git a/src/tests/mockredis/clock.py b/src/tests/mockredis/clock.py new file mode 100755 index 0000000..070a948 --- /dev/null +++ b/src/tests/mockredis/clock.py @@ -0,0 +1,24 @@ +""" +Simple clock abstraction. +""" +from abc import ABCMeta, abstractmethod +from datetime import datetime + + +class Clock(object): + """ + A clock knows the current time. + + Clock can be subclassed for testing scenarios that need to control for time. + """ + __metaclass__ = ABCMeta + + @abstractmethod + def now(self): + pass + + +class SystemClock(Clock): + + def now(self): + return datetime.now() diff --git a/src/tests/mockredis/exceptions.py b/src/tests/mockredis/exceptions.py new file mode 100755 index 0000000..ddda94c --- /dev/null +++ b/src/tests/mockredis/exceptions.py @@ -0,0 +1,18 @@ +""" +Emulates exceptions raised by the Redis client, if necessary. +""" + +try: + # Prefer actual exceptions to defining our own, so code that swaps + # in implementations does not have to swap in different exception + # classes. + from redis.exceptions import RedisError, ResponseError, WatchError +except ImportError: + class RedisError(Exception): + pass + + class ResponseError(RedisError): + pass + + class WatchError(RedisError): + pass diff --git a/src/tests/mockredis/lock.py b/src/tests/mockredis/lock.py new file mode 100755 index 0000000..ddef603 --- /dev/null +++ b/src/tests/mockredis/lock.py @@ -0,0 +1,30 @@ +class MockRedisLock(object): + """ + Poorly imitate a Redis lock object from redis-py + to allow testing without a real redis server. 
+ """ + + def __init__(self, redis, name, timeout=None, sleep=0.1): + """Initialize the object.""" + + self.redis = redis + self.name = name + self.acquired_until = None + self.timeout = timeout + self.sleep = sleep + + def acquire(self, blocking=True): # pylint: disable=R0201,W0613 + """Emulate acquire.""" + + return True + + def release(self): # pylint: disable=R0201 + """Emulate release.""" + + return + + def __enter__(self): + return self.acquire() + + def __exit__(self, exc_type, exc_value, traceback): + self.release() diff --git a/src/tests/mockredis/noseplugin.py b/src/tests/mockredis/noseplugin.py new file mode 100755 index 0000000..009c614 --- /dev/null +++ b/src/tests/mockredis/noseplugin.py @@ -0,0 +1,79 @@ +""" +This module includes a nose plugin that allows unit tests to be run with a real +redis-server instance, as long as redis-py is installed. + +This provides a simple way to verify that mockredis tests are accurate (at least +for a particular version of redis-server and redis-py). + +Usage: + + nosetests --use-redis [--redis-host ] [--redis-database ] [args] + +For this plugin to work, several things need to be true: + + 1. Nose and setuptools need to be used to invoke tests (so the plugin will work). + + Note that the setuptools "entry_point" for "nose.plugins.0.10" must be activated. + + 2. A version of redis-py must be installed in the virtualenv under test. + + 3. A redis-server instance must be running locally. + + 4. The redis-server must have a database that can be flushed between tests. + + YOU WILL LOSE DATA OTHERWISE. + + By default, database 15 is used. + + 5. Tests must be written without any references to internal mockredis state. Essentially, + that means testing GET and SET together instead of separately and not looking at the contents + of `self.redis.redis` (because this won't exist for redis-py). 
+""" +from functools import partial +import os + +from nose.plugins import Plugin + +from mockredis import MockRedis + + +class WithRedis(Plugin): + """ + Nose plugin to allow selection of redis-server. + """ + def options(self, parser, env=os.environ): + parser.add_option("--use-redis", + dest="use_redis", + action="store_true", + default=False, + help="Use a local redis instance to validate tests.") + parser.add_option("--redis-host", + dest="redis_host", + default="localhost", + help="Run tests against redis database on another host") + parser.add_option("--redis-database", + dest="redis_database", + default=15, + help="Run tests against local redis database") + + def configure(self, options, conf): + if options.use_redis: + from redis import Redis, RedisError, ResponseError, StrictRedis, WatchError + + WithRedis.Redis = partial(Redis, + db=options.redis_database, + host=options.redis_host) + WithRedis.StrictRedis = partial(StrictRedis, + db=options.redis_database, + host=options.redis_host) + WithRedis.ResponseError = ResponseError + WithRedis.RedisError = RedisError + WithRedis.WatchError = WatchError + else: + from mockredis.exceptions import RedisError, ResponseError, WatchError + + WithRedis.Redis = MockRedis + WithRedis.StrictRedis = partial(MockRedis, strict=True) + WithRedis.ResponseError = ResponseError + WithRedis.RedisError = RedisError + WithRedis.WatchError = WatchError diff --git a/src/tests/mockredis/pipeline.py b/src/tests/mockredis/pipeline.py new file mode 100755 index 0000000..f178570 --- /dev/null +++ b/src/tests/mockredis/pipeline.py @@ -0,0 +1,80 @@ +from copy import deepcopy + +from mockredis.exceptions import RedisError, WatchError + + +class MockRedisPipeline(object): + """ + Simulates a redis-python pipeline object. 
+ """ + + def __init__(self, mock_redis, transaction=True, shard_hint=None): + self.mock_redis = mock_redis + self._reset() + + def __getattr__(self, name): + """ + Handle all unfound attributes by adding a deferred function call that + delegates to the underlying mock redis instance. + """ + command = getattr(self.mock_redis, name) + if not callable(command): + raise AttributeError(name) + + def wrapper(*args, **kwargs): + if self.watching and not self.explicit_transaction: + # execute the command immediately + return command(*args, **kwargs) + else: + self.commands.append(lambda: command(*args, **kwargs)) + return self + return wrapper + + def watch(self, *keys): + """ + Put the pipeline into immediate execution mode. + Does not actually watch any keys. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + self.watching = True + for key in keys: + self._watched_keys[key] = deepcopy(self.mock_redis.redis.get(self.mock_redis._encode(key))) # noqa + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self.commands: + raise RedisError("Commands without an initial WATCH have already been issued") + self.explicit_transaction = True + + def execute(self): + """ + Execute all of the saved commands and return results. + """ + try: + for key, value in self._watched_keys.items(): + if self.mock_redis.redis.get(self.mock_redis._encode(key)) != value: + raise WatchError("Watched variable changed.") + return [command() for command in self.commands] + finally: + self._reset() + + def _reset(self): + """ + Reset instance variables. 
+ """ + self.commands = [] + self.watching = False + self._watched_keys = {} + self.explicit_transaction = False + + def __exit__(self, *argv, **kwargs): + pass + + def __enter__(self, *argv, **kwargs): + return self diff --git a/src/tests/mockredis/script.py b/src/tests/mockredis/script.py new file mode 100755 index 0000000..c0a3d48 --- /dev/null +++ b/src/tests/mockredis/script.py @@ -0,0 +1,179 @@ +import sys +import threading +from mockredis.exceptions import ResponseError + +LuaLock = threading.Lock() + + +class Script(object): + """ + An executable Lua script object returned by ``MockRedis.register_script``. + """ + + def __init__(self, registered_client, script, load_dependencies=True): + self.registered_client = registered_client + self.script = script + self.load_dependencies = load_dependencies + self.sha = registered_client.script_load(script) + + def __call__(self, keys=[], args=[], client=None): + """Execute the script, passing any required ``args``""" + with LuaLock: + client = client or self.registered_client + + if not client.script_exists(self.sha)[0]: + self.sha = client.script_load(self.script) + + return self._execute_lua(keys, args, client) + + def _execute_lua(self, keys, args, client): + """ + Sets KEYS and ARGV alongwith redis.call() function in lua globals + and executes the lua redis script + """ + lua, lua_globals = Script._import_lua(self.load_dependencies) + lua_globals.KEYS = self._python_to_lua(keys) + lua_globals.ARGV = self._python_to_lua(args) + + def _call(*call_args): + # redis-py and native redis commands are mostly compatible argument + # wise, but some exceptions need to be handled here: + if str(call_args[0]).lower() == 'lrem': + response = client.call( + call_args[0], call_args[1], + call_args[3], # "count", default is 0 + call_args[2]) + else: + response = client.call(*call_args) + return self._python_to_lua(response) + + lua_globals.redis = {"call": _call} + return self._lua_to_python(lua.execute(self.script), 
return_status=True) + + @staticmethod + def _import_lua(load_dependencies=True): + """ + Import lua and dependencies. + + :param load_dependencies: should Lua library dependencies be loaded? + :raises: RuntimeError if Lua is not available + """ + try: + import lua + except ImportError: + raise RuntimeError("Lua not installed") + + lua_globals = lua.globals() + if load_dependencies: + Script._import_lua_dependencies(lua, lua_globals) + return lua, lua_globals + + @staticmethod + def _import_lua_dependencies(lua, lua_globals): + """ + Imports lua dependencies that are supported by redis lua scripts. + + The current implementation is fragile to the target platform and lua version + and may be disabled if these imports are not needed. + + Included: + - cjson lib. + Pending: + - base lib. + - table lib. + - string lib. + - math lib. + - debug lib. + - cmsgpack lib. + """ + if sys.platform not in ('darwin', 'windows'): + import ctypes + ctypes.CDLL('liblua5.2.so', mode=ctypes.RTLD_GLOBAL) + + try: + lua_globals.cjson = lua.eval('require "cjson"') + except RuntimeError: + raise RuntimeError("cjson not installed") + + @staticmethod + def _lua_to_python(lval, return_status=False): + """ + Convert Lua object(s) into Python object(s), as at times Lua object(s) + are not compatible with Python functions + """ + import lua + lua_globals = lua.globals() + if lval is None: + # Lua None --> Python None + return None + if lua_globals.type(lval) == "table": + # Lua table --> Python list + pval = [] + for i in lval: + if return_status: + if i == 'ok': + return lval[i] + if i == 'err': + raise ResponseError(lval[i]) + pval.append(Script._lua_to_python(lval[i])) + return pval + elif isinstance(lval, long): + # Lua number --> Python long + return long(lval) + elif isinstance(lval, float): + # Lua number --> Python float + return float(lval) + elif lua_globals.type(lval) == "userdata": + # Lua userdata --> Python string + return str(lval) + elif lua_globals.type(lval) == "string": + # 
Lua string --> Python string + return lval + elif lua_globals.type(lval) == "boolean": + # Lua boolean --> Python bool + return bool(lval) + raise RuntimeError("Invalid Lua type: " + str(lua_globals.type(lval))) + + @staticmethod + def _python_to_lua(pval): + """ + Convert Python object(s) into Lua object(s), as at times Python object(s) + are not compatible with Lua functions + """ + import lua + if pval is None: + # Python None --> Lua None + return lua.eval("") + if isinstance(pval, (list, tuple, set)): + # Python list --> Lua table + # e.g.: in lrange + # in Python returns: [v1, v2, v3] + # in Lua returns: {v1, v2, v3} + lua_list = lua.eval("{}") + lua_table = lua.eval("table") + for item in pval: + lua_table.insert(lua_list, Script._python_to_lua(item)) + return lua_list + elif isinstance(pval, dict): + # Python dict --> Lua dict + # e.g.: in hgetall + # in Python returns: {k1:v1, k2:v2, k3:v3} + # in Lua returns: {k1, v1, k2, v2, k3, v3} + lua_dict = lua.eval("{}") + lua_table = lua.eval("table") + for k, v in pval.iteritems(): + lua_table.insert(lua_dict, Script._python_to_lua(k)) + lua_table.insert(lua_dict, Script._python_to_lua(v)) + return lua_dict + elif isinstance(pval, str): + # Python string --> Lua userdata + return pval + elif isinstance(pval, bool): + # Python bool--> Lua boolean + return lua.eval(str(pval).lower()) + elif isinstance(pval, (int, long, float)): + # Python int --> Lua number + lua_globals = lua.globals() + return lua_globals.tonumber(str(pval)) + + raise RuntimeError("Invalid Python type: " + str(type(pval))) diff --git a/src/tests/mockredis/sortedset.py b/src/tests/mockredis/sortedset.py new file mode 100755 index 0000000..397a59a --- /dev/null +++ b/src/tests/mockredis/sortedset.py @@ -0,0 +1,153 @@ +from bisect import bisect_left, bisect_right + + +class SortedSet(object): + """ + Redis-style SortedSet implementation. + + Maintains two internal data structures: + + 1. A multimap from score to member + 2. 
class SortedSet(object):
    """
    Redis-style SortedSet implementation.

    Maintains two internal data structures:

    1. A multimap from score to member
    2. A dictionary from member to score.

    The multimap is implemented using a sorted list of (score, member) pairs. The bisect
    operations used to maintain the multimap are O(log N), but insertion into and removal
    from a list are O(N), so insertion and removal O(N). It should be possible to swap in
    an indexable skip list to get the expected O(log N) behavior.
    """
    def __init__(self):
        """
        Create an empty sorted set.
        """
        # sorted list of (score, member)
        self._scores = []
        # dictionary from member to score
        self._members = {}

    def clear(self):
        """
        Remove all members and scores from the sorted set.
        """
        self.__init__()

    def __len__(self):
        return len(self._members)

    def __contains__(self, member):
        return member in self._members

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return "SortedSet({})".format(self._scores)

    def __eq__(self, other):
        # FIX: the original unconditionally accessed ``other._scores`` and
        # raised AttributeError when compared against a non-SortedSet.
        if not isinstance(other, SortedSet):
            return NotImplemented
        return self._scores == other._scores and self._members == other._members

    def __ne__(self, other):
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __setitem__(self, member, score):
        """
        Insert member with score. If member is already present in the
        set, update its score.
        """
        self.insert(member, score)

    def __delitem__(self, member):
        """
        Remove member from the set.
        """
        self.remove(member)

    def __getitem__(self, member):
        """
        Get the score for a member.

        :raises KeyError: if member is absent
        :raises TypeError: if a slice is passed (not supported)
        """
        if isinstance(member, slice):
            raise TypeError("Slicing not supported")
        return self._members[member]

    def __iter__(self):
        return self._scores.__iter__()

    def __reversed__(self):
        return self._scores.__reversed__()

    def insert(self, member, score):
        """
        Identical to __setitem__, but returns whether a member was
        inserted (True) or updated (False)
        """
        found = self.remove(member)
        index = bisect_left(self._scores, (score, member))
        self._scores.insert(index, (score, member))
        self._members[member] = score
        return not found

    def remove(self, member):
        """
        Identical to __delitem__, but returns whether a member was removed.
        """
        if member not in self:
            return False
        score = self._members[member]
        score_index = bisect_left(self._scores, (score, member))
        del self._scores[score_index]
        del self._members[member]
        return True

    def score(self, member):
        """
        Identical to __getitem__, but returns None instead of raising
        KeyError if member is not found.
        """
        return self._members.get(member)

    def rank(self, member):
        """
        Get the rank (index of a member).
        """
        score = self._members.get(member)
        if score is None:
            return None
        return bisect_left(self._scores, (score, member))

    def range(self, start, end, desc=False):
        """
        Return (score, member) pairs between min and max ranks (inclusive).

        FIX: always returns a list; the original returned a ``reversed``
        iterator in the ``desc`` case, leaking an inconsistent return type.
        """
        if not self:
            return []

        if desc:
            return list(reversed(self._scores[len(self) - end - 1:len(self) - start]))
        return self._scores[start:end + 1]

    def scorerange(self, start, end, start_inclusive=True, end_inclusive=True):
        """
        Return (score, member) pairs between min and max scores.
        """
        if not self:
            return []

        # (score,) sorts before every (score, member) pair, so bisect_left
        # lands on the first entry with score >= start.
        left = bisect_left(self._scores, (start,))
        right = bisect_right(self._scores, (end,))

        if end_inclusive:
            # end is inclusive
            while right < len(self) and self._scores[right][0] == end:
                right += 1

        if not start_inclusive:
            while left < right and self._scores[left][0] == start:
                left += 1
        return self._scores[left:right]

    def min_score(self):
        return self._scores[0][0]

    def max_score(self):
        return self._scores[-1][0]
class GetUsernameTestCase(unittest.TestCase):
    """Tests for ``get_username``: type, URL-encoding, and validation errors."""

    def test_returns_str(self):
        result = get_username({"username": "foobar"})
        self.assertEqual(type(result), str)

    def test_returns_username_with_encoded_spaces(self):
        result = get_username({"username": "foo bar"})
        self.assertEqual(result, "foo%20bar")

    def test_raises_400_on_missing_username(self):
        with self.assertRaisesRegex(
            expected_exception=falcon.HTTPMissingParam,
            expected_regex="username",
        ):
            get_username({})

    def test_raises_400_on_empty_username(self):
        with self.assertRaisesRegex(
            expected_exception=falcon.HTTPInvalidParam,
            expected_regex="Username cannot be blank",
        ):
            get_username({"username": ""})
@patch("src.server.redis_db", mocks.redis_db)
class GenerateHashTestCase(ServerTestCase):
    """Tests for POST /v1/generate_hash/."""

    def setUp(self):
        super(GenerateHashTestCase, self).setUp()
        self.url = "/v1/generate_hash/"

    def _post(self, payload):
        # Small helper so each test reads as intent, not request plumbing.
        return self.simulate_post(
            self.url,
            body=json.dumps(payload),
            **req_params,
        )

    def test_generate_hash_require_username(self):
        resp = self._post({})
        self.assertEqual(resp.status_code, 400)
        self.assertEqual(resp.json["title"], "Missing parameter")
        self.assertEqual(resp.json["description"], "The \"username\" parameter is required.")

    def test_return_hash_for_username(self):
        resp = self._post({"username": "foobar"})
        self.assertEqual(resp.status_code, 200)
        self.assertIsNotNone(resp.json["hash"])

    def test_returns_same_hash_for_same_username(self):
        first = self._post({"username": "foobar"})
        second = self._post({"username": "foobar"})
        self.assertEqual(first.json["hash"], second.json["hash"])
test_require_username(self): + resp = self.simulate_post( + self.url, + body=json.dumps({}), + **req_params, + ) + self.assertEqual(resp.status_code, 400) + self.assertEqual(resp.json["title"], "Missing parameter") + self.assertEqual(resp.json["description"], "The \"username\" parameter is required.") + + def test_prompt_user_to_generate_hash_when_none_found(self): + resp = self.simulate_post( + self.url, + body=json.dumps({"username": "foobar"}), + **req_params, + ) + self.assertEqual(resp.status_code, 400) + self.assertEqual(resp.json["title"], "Hash Missing") + self.assertEqual( + resp.json["description"], + "A hash does not exist for this username. Run /generate_hash/ first", + ) + + @patch.object(requests.Session, "get") + def test_validate_hash_is_in_user_profile(self, mock_get): + username = "foobar" + # Generate a hash for the user + resp1 = self.simulate_post( + "/v1/generate_hash/", + body=json.dumps({"username": username}), + **req_params, + ) + + # The user has put their hash in their profile + mock_get.return_value = mocks.ProfileMock(text=resp1.json["hash"]) + + # Validate the existence of the hash in the profile + resp2 = self.simulate_post( + self.url, + body=json.dumps({"username": username}), + **req_params, + ) + + self.assertEqual(resp2.status_code, 200) + self.assertEqual(resp2.json["validated"], True) + + @patch.object(requests.Session, "get") + def test_validate_hash_is_not_in_user_profile(self, mock_get): + username = "foobar" + # Generate a hash for the user + self.simulate_post( + "/v1/generate_hash/", + body=json.dumps({"username": username}), + **req_params, + ) + + # The user has put their hash in their profile + mock_get.return_value = mocks.ProfileMock(text="hash_is_not_here") + + # Won"t be able to find hash in the user"s profile + resp = self.simulate_post( + self.url, + body=json.dumps({"username": username}), + **req_params, + ) + + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json["validated"], False)