diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 91d9a37..06f377e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,13 +40,23 @@ jobs: - name: Install test dependencies run: | - pip install -r test-requirements.txt + pip install -r dev-requirements.txt - name: Run Pylama run: | - pylama redis_dict.py -i E501,E231 + python -m pylama -i E501,E231 src + + - name: Run mypy strict + run: | + mypy + + - name: Doctype Check + run: | + darglint src/redis_dict/ - name: Run Unit Tests + env: + PYTHONPATH: src run: | coverage run -m unittest discover -p "*tests.py" diff --git a/.gitignore b/.gitignore index 07e1cac..0b1e0d8 100644 --- a/.gitignore +++ b/.gitignore @@ -7,9 +7,13 @@ build dist venv .venv +.venv_* +dev_venv .hypothesis/ - -.coverage +.coverage* htmlcov + + +.idea/ diff --git a/README.md b/README.md index 863a7d5..91c17a2 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # Redis-dict +[![PyPI](https://img.shields.io/pypi/v/redis-dict.svg)](https://pypi.org/project/redis-dict/) [![CI](https://github.com/Attumm/redis-dict/actions/workflows/ci.yml/badge.svg)](https://github.com/Attumm/redis-dict/actions/workflows/ci.yml) [![codecov](https://codecov.io/gh/Attumm/redis-dict/graph/badge.svg?token=Lqs7McQGEs)](https://codecov.io/gh/Attumm/redis-dict) [![Downloads](https://static.pepy.tech/badge/redis-dict/month)](https://pepy.tech/project/redis-dict) @@ -52,7 +53,6 @@ In Redis our example looks like this. ### Namespaces Acting as an identifier for your dictionary across different systems, RedisDict employs namespaces for organized data management. When a namespace isn't specified, "main" becomes the default. Thus allowing for data organization across systems and projects with the same redis instance. - This approach also minimizes the risk of key collisions between different applications, preventing hard-to-debug issues. 
By leveraging namespaces, RedisDict ensures a cleaner and more maintainable data management experience for developers working on multiple projects. ## Advanced Features @@ -101,7 +101,6 @@ dic['gone'] = 'gone in 5 seconds' Efficiently batch your requests using the Pipeline feature, which can be easily utilized with a context manager. ```python -from redis_dict import RedisDict dic = RedisDict(namespace="example") # one round trip to redis @@ -229,14 +228,11 @@ This approach optimizes Redis database performance and efficiency by ensuring th Following types are supported: `str, int, float, bool, NoneType, list, dict, tuple, set, datetime, date, time, timedelta, Decimal, complex, bytes, UUID, OrderedDict, defaultdict, frozenset` ```python -from redis_dict import RedisDict - from uuid import UUID from decimal import Decimal from collections import OrderedDict, defaultdict from datetime import datetime, date, time, timedelta - dic = RedisDict() dic["string"] = "Hello World" @@ -265,6 +261,32 @@ dic["default"] = defaultdict(int, {'a': 1, 'b': 2}) dic["frozen"] = frozenset([1, 2, 3]) ``` + + +### Nested types + +RedisDict supports nested structures with mixed types through JSON serialization. The feature works by utilizing JSON encoding and decoding under the hood. While this represents an upgrade in functionality, the feature is not fully implemented and should be used with caution. For optimal performance, using shallow dictionaries is recommended. +```python +from datetime import datetime, timedelta + +dic["mixed"] = [1, "foobar", 3.14, [1, 2, 3], datetime.now()] + +dic['dic'] = {"elapsed_time": timedelta(hours=60)} +``` + +### JSON Encoding - Decoding +The nested type support in RedisDict is implemented using custom JSON encoders and decoders. These JSON encoders and decoders are built on top of RedisDict's own encoding and decoding functionality, extending it for JSON compatibility. 
Since JSON serialization was a frequently requested feature, these enhanced encoders and decoders are available for use in other projects: +```python +import json +from datetime import datetime +from redis_dict import RedisDictJSONDecoder, RedisDictJSONEncoder + +data = [1, "foobar", 3.14, [1, 2, 3], datetime.now()] +encoded = json.dumps(data, cls=RedisDictJSONEncoder) +result = json.loads(encoded, cls=RedisDictJSONDecoder) +``` + + ### Extending RedisDict with Custom Types RedisDict supports custom type serialization. Here's how to add a new type: @@ -272,7 +294,6 @@ RedisDict supports custom type serialization. Here's how to add a new type: ```python import json -from redis_dict import RedisDict class Person: def __init__(self, name, age): @@ -301,23 +322,13 @@ assert result.name == person.name assert result.age == person.age ``` -```python ->>> from datetime import datetime ->>> redis_dict.extends_type(datetime, datetime.isoformat, datetime.fromisoformat) ->>> redis_dict["now"] = datetime.now() ->>> redis_dict -{'now': datetime.datetime(2024, 10, 14, 18, 41, 53, 493775)} ->>> redis_dict["now"] -datetime.datetime(2024, 10, 14, 18, 41, 53, 493775) -``` - -For more information on [extending types](https://github.com/Attumm/redis-dict/blob/main/extend_types_tests.py). +For more information on [extending types](https://github.com/Attumm/redis-dict/blob/main/tests/unit/extend_types_tests.py). ### Redis Encryption Setup guide for configuring and utilizing encrypted Redis TLS for redis-dict. -[Setup guide](https://github.com/Attumm/redis-dict/blob/main/encrypted_redis.MD) +[Setup guide](https://github.com/Attumm/redis-dict/blob/main/docs/tutorials/encrypted_redis.MD) ### Redis Storage Encryption -For storing encrypted data values, it's possible to use extended types. Take a look at this [encrypted test](https://github.com/Attumm/redis-dict/blob/main/encrypt_tests.py). +For storing encrypted data values, it's possible to use extended types. 
Take a look at this [encrypted test](https://github.com/Attumm/redis-dict/blob/main/tests/unit/encrypt_tests.py). ### Tests The RedisDict library includes a comprehensive suite of tests that ensure its correctness and resilience. The test suite covers various data types, edge cases, and error handling scenarios. It also employs the Hypothesis library for property-based testing, which provides fuzz testing to evaluate the implementation @@ -325,19 +336,16 @@ The RedisDict library includes a comprehensive suite of tests that ensure its co ### Redis config To configure RedisDict using your Redis config. -Configure both the host and port. +Configure both the host and port. Or configuration with a setting dictionary. ```python dic = RedisDict(host='127.0.0.1', port=6380) -``` -Configuration with a dictionary. -```python redis_config = { 'host': '127.0.0.1', 'port': 6380, } -dic = RedisDict(**redis_config) +config_dic = RedisDict(**redis_config) ``` ## Installation @@ -348,4 +356,3 @@ pip install redis-dict ### Note * Please be aware that this project is currently being utilized by various organizations in their production environments. 
If you have any questions or concerns, feel free to raise issues * This project only uses redis as dependency - diff --git a/test-requirements.txt b/dev-requirements.txt similarity index 52% rename from test-requirements.txt rename to dev-requirements.txt index d6027fe..eaff09b 100644 --- a/test-requirements.txt +++ b/dev-requirements.txt @@ -1,23 +1,32 @@ +astroid==3.2.4 attrs==22.2.0 cffi==1.15.1 coverage==5.5 cryptography==43.0.1 +darglint==1.8.1 +dill==0.3.9 exceptiongroup==1.1.1 future==0.18.3 hypothesis==6.70.1 +isort==5.13.2 mccabe==0.7.0 -mypy==1.1.1 +mypy==1.13.0 mypy-extensions==1.0.0 +platformdirs==4.3.6 pycodestyle==2.10.0 pycparser==2.21 pydocstyle==6.3.0 pyflakes==3.0.1 pylama==8.4.1 pylint==3.2.7 -redis==4.5.4 +redis==5.2.0 +setuptools==75.3.0 snowballstemmer==2.2.0 sortedcontainers==2.4.0 tomli==2.0.1 -types-pyOpenSSL==23.1.0.0 -types-redis==4.5.3.0 -typing_extensions==4.5.0 +tomlkit==0.13.2 +types-cffi==1.16.0.20240331 +types-pyOpenSSL==24.1.0.20240722 +types-redis==4.6.0.20241004 +types-setuptools==75.2.0.20241025 +typing_extensions==4.12.2 diff --git a/encrypted_redis.MD b/docs/tutorials/encrypted_redis.MD similarity index 100% rename from encrypted_redis.MD rename to docs/tutorials/encrypted_redis.MD diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8af7fa6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,133 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "redis-dict" +version = "3.0.0" +description = "Dictionary with Redis as storage backend" +authors = [ + {name = "Melvin Bijman", email = "bijman.m.m@gmail.com"}, +] +readme = "README.md" + +requires-python = ">=3.8" +license = {text = "MIT"} +dependencies = [ + "redis>=4.0.0", +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: Science/Research", + "Topic :: 
Internet", + "Topic :: Scientific/Engineering", + "Topic :: Database", + "Topic :: System :: Distributed Computing", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Software Development :: Object Brokering", + "Topic :: Database :: Database Engines/Servers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Typing :: Typed", +] + +keywords = [ + "redis", "python", "dictionary", "dict", "key-value", + "database", "caching", "distributed-computing", + "dictionary-interface", "large-datasets", + "scientific-computing", "data-persistence", + "high-performance", "scalable", "pipelining", + "batching", "big-data", "data-types", + "distributed-algorithms", "encryption", + "data-management", +] + +[project.optional-dependencies] +dev = [ + "coverage==5.5", + "hypothesis==6.70.1", + + "mypy>=1.8.0", + "mypy-extensions>=1.0.0", + "types-pyOpenSSL>=24.0.0.0", + "types-redis>=4.6.0", + "typing_extensions>=4.5.0", + + "pylama>=8.4.1", + "pycodestyle==2.10.0", + "pydocstyle==6.3.0", + "pyflakes==3.0.1", + "pylint==3.2.7", + "mccabe==0.7.0", + + "attrs==22.2.0", + "cffi==1.15.1", + "cryptography==43.0.1", + "exceptiongroup==1.1.1", + "future==0.18.3", + "pycparser==2.21", + "snowballstemmer==2.2.0", + "sortedcontainers==2.4.0", + "tomli==2.0.1", + "setuptools>=68.0.0", + "darglint", + "pydocstyle", +] + +docs = [ + "sphinx", + "sphinx-rtd-theme", + "sphinx-autodoc-typehints", + "tomli", + "myst-parser", +] + + +[tool.setuptools] +package-dir = {"" = "src"} +packages = ["redis_dict"] + +[tool.setuptools.package-data] +redis_dict = ["py.typed"] + +[tool.coverage.run] +source = ["redis_dict"] +branch = true + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "if 
__name__ == .__main__.:", + "raise NotImplementedError", + "if TYPE_CHECKING:", +] +show_missing = true + +[tool.mypy] +python_version = "3.8" +strict = true +mypy_path = "src" +files = ["src"] +namespace_packages = true +explicit_package_bases = true + +[tool.pylama] +ignore = "E501,E231" +skip = "*/.tox/*,*/.env/*,build/*" +linters = "pycodestyle,pyflakes,mccabe" +max_line_length = 120 +paths = ["src/redis_dict"] + +[project.urls] +Homepage = "https://github.com/Attumm/redisdict" +Documentation = "https://github.com/Attumm/redisdict#readme" +Repository = "https://github.com/Attumm/redisdict.git" +Changelog = "https://github.com/Attumm/redisdict/releases" diff --git a/scripts/build_dev_checks.sh b/scripts/build_dev_checks.sh new file mode 100644 index 0000000..f6e385d --- /dev/null +++ b/scripts/build_dev_checks.sh @@ -0,0 +1,28 @@ +#!/bin/bash +set -e + +rm -rf dev_venv .venv_dev +python3 -m venv .venv_dev +source .venv_dev/bin/activate + +pip install --upgrade pip +pip install -e ".[dev]" + +pip freeze > dev-requirements.txt + +# Type Check +python -m mypy + +# Doctype Check +darglint src/redis_dict/ + +# Multiple linters +python -m pylama -i E501,E231 src + +# Unit tests +python -m unittest discover -s tests + +# Docstring Check +# pydocstyle src/redis_dict/ + +deactivate diff --git a/scripts/build_docs.sh b/scripts/build_docs.sh new file mode 100644 index 0000000..0972153 --- /dev/null +++ b/scripts/build_docs.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -e + +rm -rf docs/Makefile docs/build/* docs/source/* + +#python3 -m venv .venv_docs + +source .venv_docs/bin/activate +pip install --upgrade pip +pip install -e ".[docs]" + +pip freeze + +python3 scripts/generate_sphinx_config.py + +sphinx-apidoc -o docs/source src/redis_dict + +cd docs +make html + +echo "Documentation built successfully in docs/build/html/" diff --git a/scripts/build_local_pkg.sh b/scripts/build_local_pkg.sh new file mode 100644 index 0000000..fd7a4d1 --- /dev/null +++ b/scripts/build_local_pkg.sh @@ -0,0 
+1,6 @@ +#!/bin/bash +mkdir test_install_0.1.0 +cd test_install_0.1.0 +python3 -m venv venv +./venv/bin/pip install -e .. +./venv/bin/python ../tests/misc/simple_test.py diff --git a/scripts/generate_sphinx_config.py b/scripts/generate_sphinx_config.py new file mode 100644 index 0000000..a56d06c --- /dev/null +++ b/scripts/generate_sphinx_config.py @@ -0,0 +1,142 @@ +import tomli +import os +from pathlib import Path + + +def generate_configs(): + """Generate Sphinx configuration files from pyproject.toml.""" + print("Current working directory:", os.getcwd()) + + root_dir = Path(os.getcwd()) + package_dir = root_dir / 'src' / 'redis_dict' + docs_dir = root_dir / 'docs' + + print(f"Package directory: {package_dir}") + print(f"Docs directory: {docs_dir}") + + with open('pyproject.toml', 'rb') as f: + config = tomli.load(f) + + project_info = config['project'] + + docs_path = Path('docs') + docs_path.mkdir(exist_ok=True) + source_path = docs_path / 'source' + source_path.mkdir(exist_ok=True) + + docs_path = Path('docs') + docs_path.mkdir(exist_ok=True) + source_path = docs_path / 'source' + source_path.mkdir(exist_ok=True) + + tutorials_source = docs_path / 'tutorials' + tutorials_source.mkdir(exist_ok=True) + + tutorials_build = source_path / 'tutorials' + tutorials_build.mkdir(exist_ok=True) + + conf_content = f""" +import os +import sys + +# Add the package directory to Python path +package_path = os.path.abspath('{package_dir}') +src_path = os.path.dirname(package_path) +print(f"Adding to path: {{src_path}}") +print(f"Package path: {{package_path}}") +sys.path.insert(0, src_path) + +project = "{project_info['name']}" +copyright = "2024, {project_info['authors'][0]['name']}" +author = "{project_info['authors'][0]['name']}" +version = "{project_info['version']}" + +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx_autodoc_typehints', + 'myst_parser', +] + +myst_update_mathjax = False +myst_enable_extensions = [ + 
"colon_fence", + "deflist", +] +myst_heading_anchors = 3 + +html_extra_path = ['../tutorials'] + +def setup(app): + print(f"Python path: {{sys.path}}") + +html_theme = 'sphinx_rtd_theme' +""" + + index_content = """Redis Dict Documentation +======================== + +.. include:: ../../README.md + :parser: myst_parser.sphinx_ + +.. toctree:: + :maxdepth: 2 + :caption: API Reference + + modules + redis_dict + +Indices and tables +================== + +* :ref:`genindex` +""" + + index_content1 = """ +Redis Dict Documentation +===================== + +.. include:: ../../README.md + :parser: myst_parser.sphinx_ + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + modules + +Indices and Tables +================ + +* :ref:`genindex` +""" + + makefile_content = """ +# Minimal makefile for Sphinx documentation +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +""" + + with open(source_path / 'conf.py', 'w') as f: + f.write(conf_content) + + with open(source_path / 'index.rst', 'w') as f: + f.write(index_content) + + with open(docs_path / 'Makefile', 'w') as f: + f.write(makefile_content) + + +if __name__ == '__main__': + generate_configs() \ No newline at end of file diff --git a/scripts/view_docs.sh b/scripts/view_docs.sh new file mode 100644 index 0000000..4197530 --- /dev/null +++ b/scripts/view_docs.sh @@ -0,0 +1,2 @@ +open docs/build/html/index.html # On macOS +#xdg-open docs/build/html/index.html # On Linux diff --git a/setup.py b/setup.py deleted file mode 100644 index d074fd0..0000000 --- a/setup.py +++ /dev/null @@ -1,56 +0,0 @@ -from os import path -from setuptools import setup -import io - -current = path.abspath(path.dirname(__file__)) - -with io.open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() - - 
-setup( - name='redis dict', - author='Melvin Bijman', - author_email='bijman.m.m@gmail.com', - - description='Dictionary with Redis as storage backend', - long_description=long_description, - long_description_content_type='text/markdown', - - version='2.7.0', - py_modules=['redis_dict'], - install_requires=['redis',], - license='MIT', - - platforms=['any'], - - url='https://github.com/Attumm/redisdict', - - classifiers=[ - 'Development Status :: 5 - Production/Stable', - - 'Intended Audience :: Developers', - 'Intended Audience :: Information Technology', - 'Intended Audience :: Science/Research', - - 'Topic :: Internet', - 'Topic :: Scientific/Engineering', - 'Topic :: Database', - 'Topic :: System :: Distributed Computing', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Topic :: Software Development :: Object Brokering', - 'Topic :: Database :: Database Engines/Servers', - - 'License :: OSI Approved :: MIT License', - - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - ], - keywords='redis python dictionary dict key-value key:value database caching distributed-computing dictionary-interface large-datasets scientific-computing data-persistence high-performance scalable pipelining batching big-data data-types distributed-algorithms encryption data-management', -) diff --git a/src/redis_dict/__init__.py b/src/redis_dict/__init__.py new file mode 100644 index 0000000..5b3a3b0 --- /dev/null +++ b/src/redis_dict/__init__.py @@ -0,0 +1,17 @@ +"""__init__ module for redis dict.""" +from importlib.metadata import version, PackageNotFoundError + +from .core import RedisDict +from .type_management import decoding_registry, encoding_registry, 
RedisDictJSONEncoder, RedisDictJSONDecoder + +__all__ = [ + 'RedisDict', + 'decoding_registry', + 'encoding_registry', + 'RedisDictJSONEncoder', + 'RedisDictJSONDecoder', +] +try: + __version__ = version("redis-dict") +except PackageNotFoundError: + __version__ = "0.0.0" diff --git a/redis_dict.py b/src/redis_dict/core.py similarity index 66% rename from redis_dict.py rename to src/redis_dict/core.py index ca85561..8aaec87 100644 --- a/redis_dict.py +++ b/src/redis_dict/core.py @@ -1,174 +1,23 @@ -""" -redis_dict.py - -RedisDict is a Python library that provides a convenient and familiar interface for -interacting with Redis as if it were a Python dictionary. The simple yet powerful library -enables you to manage key-value pairs in Redis using native Python syntax of dictionary. It supports -various data types, including strings, integers, floats, booleans, lists, and dictionaries, -and includes additional utility functions for more complex use cases. - -By leveraging Redis for efficient key-value storage, RedisDict allows for high-performance -data management and is particularly useful for handling large datasets that may exceed local -memory capacity. - -## Features - -* **Dictionary-like interface**: Use familiar Python dictionary syntax to interact with Redis. -* **Data Type Support**: Comprehensive support for various data types, including strings, - integers, floats, booleans, lists, dictionaries, sets, and tuples. -* **Pipelining support**: Use pipelines for batch operations to improve performance. -* **Expiration Support**: Enables the setting of expiration times either globally or individually - per key, through the use of context managers. -* **Efficiency and Scalability**: RedisDict is designed for use with large datasets and is - optimized for efficiency. It retrieves only the data needed for a particular operation, - ensuring efficient memory usage and fast performance. 
-* **Namespace Management**: Provides simple and efficient namespace handling to help organize - and manage data in Redis, streamlining data access and manipulation. -* **Distributed Computing**: With its ability to seamlessly connect to other instances or - servers with access to the same Redis instance, RedisDict enables easy distributed computing. -* **Custom data types**: Add custom types and transformations to suit your specific needs. - -New feature - -Custom extendable Validity checks on keys, and values.to support redis-dict base exceptions with messages from -enabling detailed reporting on the reasons for specific validation failures. This refactor would allow users -to configure which validity checks to execute, integrate custom validation functions, and specify whether -to raise an error on validation failures or to drop the operation and log a warning. - -For example, in a caching scenario, data should only be cached if it falls within defined minimum and -maximum size constraints. This approach enables straightforward dictionary set operations while ensuring -that only meaningful data is cached: values greater than 10 MB and less than 100 MB should be cached; -otherwise, they will be dropped. - ->>> def my_custom_validity_check(value: str) -> None: - \""" - Validates the size of the input. - - Args: - value (str): string to validate. - - Raises: - RedisDictValidityException: If the length of the input is not within the allowed range. - \""" - min_size = 10 * 1024: # Minimum size: 10 KB - max_size = 10 * 1024 * 1024: # Maximum size: 10 MB - if len(value) < min_size - raise RedisDictValidityException(f"value is too small: {len(value)} bytes") - if len(value) > max_size - raise RedisDictValidityException(f"value is too large: {len(value)} bytes") - ->>> cache = RedisDict(namespace='cache_valid_results_for_1_minute', -... expire=60, -... custom_valid_values_checks=[my_custom_validity_check], # extend with new valid check -... 
validity_exception_suppress=True) # when value is invalid, don't store, and don't raise an exception. ->>> too_small = "too small to cache" ->>> cache["1234"] = too_small # Since the value is below 10kb, thus there is no reason to cache the value. ->>> cache.get("1234") is None ->>> True -""" -# Types imports -import json -from datetime import datetime, time, timedelta, date -from decimal import Decimal -from uuid import UUID -from collections import OrderedDict, defaultdict -import base64 - -from typing import Any, Callable, Dict, Iterator, Set, List, Tuple, Union, Optional +"""Redis Dict module.""" +from typing import Any, Dict, Iterator, List, Tuple, Union, Optional, Type + +from datetime import timedelta from contextlib import contextmanager +from collections.abc import Mapping from redis import StrictRedis -SENTINEL = object() - -EncodeFuncType = Callable[[Any], str] -DecodeFuncType = Callable[[str], Any] - -EncodeType = Dict[str, EncodeFuncType] -DecodeType = Dict[str, DecodeFuncType] - - -def _create_default_encode(custom_encode_method: str) -> EncodeFuncType: - def default_encode(obj: Any) -> str: - return getattr(obj, custom_encode_method)() # type: ignore[no-any-return] - return default_encode - - -def _create_default_decode(cls: type, custom_decode_method: str) -> DecodeFuncType: - def default_decode(encoded_str: str) -> Any: - return getattr(cls, custom_decode_method)(encoded_str) - return default_decode - - -def _decode_tuple(val: str) -> Tuple[Any, ...]: - """ - Deserialize a JSON-formatted string to a tuple. - - This function takes a JSON-formatted string, deserializes it to a list, and - then converts the list to a tuple. - - Args: - val (str): A JSON-formatted string representing a list. - - Returns: - Tuple[Any, ...]: A tuple with the deserialized values from the input string. - """ - return tuple(json.loads(val)) - - -def _encode_tuple(val: Tuple[Any, ...]) -> str: - """ - Serialize a tuple to a JSON-formatted string. 
- - This function takes a tuple, converts it to a list, and then serializes - the list to a JSON-formatted string. - - Args: - val (Tuple[Any, ...]): A tuple with values to be serialized. - - Returns: - str: A JSON-formatted string representing the input tuple. - """ - return json.dumps(list(val)) - - -def _decode_set(val: str) -> Set[Any]: - """ - Deserialize a JSON-formatted string to a set. - - This function takes a JSON-formatted string, deserializes it to a list, and - then converts the list to a set. - - Args: - val (str): A JSON-formatted string representing a list. - - Returns: - set[Any]: A set with the deserialized values from the input string. - """ - return set(json.loads(val)) - - -def _encode_set(val: Set[Any]) -> str: - """ - Serialize a set to a JSON-formatted string. - - This function takes a set, converts it to a list, and then serializes the - list to a JSON-formatted string. - - Args: - val (set[Any]): A set with values to be serialized. - - Returns: - str: A JSON-formatted string representing the input set. - """ - return json.dumps(list(val)) +from redis_dict.type_management import SENTINEL, EncodeFuncType, DecodeFuncType, EncodeType, DecodeType +from redis_dict.type_management import _create_default_encode, _create_default_decode, _default_decoder +from redis_dict.type_management import encoding_registry as enc_reg +from redis_dict.type_management import decoding_registry as dec_reg # pylint: disable=R0902, R0904 class RedisDict: - """ - A Redis-backed dictionary-like data structure with support for advanced features, such as - custom data types, pipelining, and key expiration. + """Python dictionary with Redis as backend. + + With support for advanced features, such as custom data types, pipelining, and key expiration. This class provides a dictionary-like interface that interacts with a Redis database, allowing for efficient storage and retrieval of key-value pairs. 
It supports various data types, including @@ -194,70 +43,35 @@ class RedisDict: expire (Union[int, None]): An optional expiration time for keys, in seconds. """ - decoding_registry: DecodeType = { - type('').__name__: str, - type(1).__name__: int, - type(0.1).__name__: float, - type(True).__name__: lambda x: x == "True", - type(None).__name__: lambda x: None, - - "list": json.loads, - "dict": json.loads, - "tuple": _decode_tuple, - type(set()).__name__: _decode_set, - - datetime.__name__: datetime.fromisoformat, - date.__name__: date.fromisoformat, - time.__name__: time.fromisoformat, - timedelta.__name__: lambda x: timedelta(seconds=float(x)), - - Decimal.__name__: Decimal, - complex.__name__: lambda x: complex(*map(float, x.split(','))), - bytes.__name__: base64.b64decode, - - UUID.__name__: UUID, - OrderedDict.__name__: lambda x: OrderedDict(json.loads(x)), - defaultdict.__name__: lambda x: defaultdict(type(None), json.loads(x)), - frozenset.__name__: lambda x: frozenset(json.loads(x)), - } - - encoding_registry: EncodeType = { - "list": json.dumps, - "dict": json.dumps, - "tuple": _encode_tuple, - type(set()).__name__: _encode_set, - - datetime.__name__: datetime.isoformat, - date.__name__: date.isoformat, - time.__name__: time.isoformat, - timedelta.__name__: lambda x: str(x.total_seconds()), - - complex.__name__: lambda x: f"{x.real},{x.imag}", - bytes.__name__: lambda x: base64.b64encode(x).decode('ascii'), - OrderedDict.__name__: lambda x: json.dumps(list(x.items())), - defaultdict.__name__: lambda x: json.dumps(dict(x)), - frozenset.__name__: lambda x: json.dumps(list(x)), - } + + encoding_registry: EncodeType = enc_reg + decoding_registry: DecodeType = dec_reg def __init__(self, namespace: str = 'main', expire: Union[int, timedelta, None] = None, preserve_expiration: Optional[bool] = False, + redis: "Optional[StrictRedis[Any]]" = None, **redis_kwargs: Any) -> None: - """ - Initialize a RedisDict instance. + """Initialize a RedisDict instance. 
+ + Init the RedisDict instance. Args: - namespace (str, optional): A prefix for keys stored in Redis. - expire (int, timedelta, optional): Expiration time for keys in seconds. - preserve_expiration (bool, optional): Preserve the expiration count when the key is updated. - **redis_kwargs: Additional keyword arguments passed to StrictRedis. + namespace (str): A prefix for keys stored in Redis. + expire (Union[int, timedelta, None], optional): Expiration time for keys. + preserve_expiration (Optional[bool], optional): Preserve expiration on key updates. + redis (Optional[StrictRedis[Any]], optional): A Redis connection instance. + **redis_kwargs (Any): Additional kwargs for Redis connection if not provided. """ self.namespace: str = namespace self.expire: Union[int, timedelta, None] = expire self.preserve_expiration: Optional[bool] = preserve_expiration - self.redis: StrictRedis[Any] = StrictRedis(decode_responses=True, **redis_kwargs) + if redis: + redis.connection_pool.connection_kwargs["decode_responses"] = True + + self.redis: StrictRedis[Any] = redis or StrictRedis(decode_responses=True, **redis_kwargs) self.get_redis: StrictRedis[Any] = self.redis self.custom_encode_method = "encode" @@ -288,7 +102,7 @@ def _valid_input(self, val: Any, val_type: str) -> bool: length does not exceed the maximum allowed size (500 MB). Args: - val (Union[str, int, float, bool]): The input value to be validated. + val (Any): The input value to be validated. val_type (str): The type of the input value ("str", "int", "float", or "bool"). Returns: @@ -298,7 +112,19 @@ def _valid_input(self, val: Any, val_type: str) -> bool: return len(val) < self._max_string_size return True - def _format_value(self, key: str, value: Any) -> str: + def _format_value(self, key: str, value: Any) -> str: + """Format a valid value with the type and encoded representation of the value. + + Args: + key (str): The key of the value to be formatted. + value (Any): The value to be encoded and formatted. 
+ + Raises: + ValueError: If the value or key fail validation. + + Returns: + str: The formatted value with the type and encoded representation of the value. + """ store_type, key = type(value).__name__, str(key) if not self._valid_input(value, store_type) or not self._valid_input(key, "str"): raise ValueError("Invalid input value or key size exceeded the maximum limit.") @@ -355,7 +181,7 @@ def _transform(self, result: str) -> Any: Any: The transformed Python object. """ type_, value = result.split(':', 1) - return self.decoding_registry.get(type_, lambda x: x)(value) + return self.decoding_registry.get(type_, _default_decoder)(value) def new_type_compliance( self, @@ -363,8 +189,7 @@ def new_type_compliance( encode_method_name: Optional[str] = None, decode_method_name: Optional[str] = None, ) -> None: - """ - Checks if a class complies with the required encoding and decoding methods. + """Check if a class complies with the required encoding and decoding methods. Args: class_type (type): The class to check for compliance. @@ -386,6 +211,7 @@ def new_type_compliance( raise NotImplementedError( f"Class {class_type.__name__} does not implement the required {decode_method_name} class method.") + # pylint: disable=too-many-arguments def extends_type( self, class_type: type, @@ -394,31 +220,21 @@ def extends_type( encoding_method_name: Optional[str] = None, decoding_method_name: Optional[str] = None, ) -> None: - """ - Extends RedisDict to support a custom type in the encode/decode mapping. + """Extend RedisDict to support a custom type in the encode/decode mapping. This method enables serialization of instances based on their type, allowing for custom types, specialized storage formats, and more. There are three ways to add custom types: - 1. Have a class with an `encode` instance method and a `decode` class method. - 2. 
Have a class and pass encoding and decoding functions, where - `encode` converts the class instance to a string, and - `decode` takes the string and recreates the class instance. - 3. Have a class that already has serialization methods, that satisfies the: - EncodeFuncType = Callable[[Any], str] - DecodeFuncType = Callable[[str], Any] + 1. Have a class with an `encode` instance method and a `decode` class method. + 2. Have a class and pass encoding and decoding functions, where + `encode` converts the class instance to a string, and + `decode` takes the string and recreates the class instance. + 3. Have a class that already has serialization methods, that satisfies the: + EncodeFuncType = Callable[[Any], str] + DecodeFuncType = Callable[[str], Any] - `custom_encode_method` - `custom_decode_method` attributes. - - Args: - class_type (Type[type]): The class `__name__` will become the key for the encoding and decoding functions. - encode (Optional[EncodeFuncType]): function that encodes an object into a storable string format. - This function should take an instance of `class_type` as input and return a string. - decode (Optional[DecodeFuncType]): function that decodes a string back into an object of `class_type`. - This function should take a string as input and return an instance of `class_type`. - encoding_method_name (str, optional): Name of encoding method of the class for redis-dict custom types. - decoding_method_name (str, optional): Name of decoding method of the class for redis-dict custom types. + `custom_encode_method` + `custom_decode_method` If no encoding or decoding function is provided, default to use the `encode` and `decode` methods of the class. @@ -445,11 +261,21 @@ def decode(cls, encoded_str: str) -> 'Person': redis_dict.extends_type(Person) + Args: + class_type (type): The class `__name__` will become the key for the encoding and decoding functions. 
+ encode (Optional[EncodeFuncType]): function that encodes an object into a storable string format. + decode (Optional[DecodeFuncType]): function that decodes a string back into an object of `class_type`. + encoding_method_name (str, optional): Name of encoding method of the class for redis-dict custom types. + decoding_method_name (str, optional): Name of decoding method of the class for redis-dict custom types. + + Raises: + NotImplementedError + Note: - You can check for compliance of a class separately using the `new_type_compliance` method: + You can check for compliance of a class separately using the `new_type_compliance` method: - This method raises a NotImplementedError if either `encode` or `decode` is `None` - and the class does not implement the corresponding method. + This method raises a NotImplementedError if either `encode` or `decode` is `None` + and the class does not implement the corresponding method. """ if encode is None or decode is None: @@ -460,7 +286,7 @@ def decode(cls, encoded_str: str) -> 'Person': if decode is None: decode_method_name = decoding_method_name or self.custom_decode_method - self.new_type_compliance(class_type, decode_method_name=decode_method_name) + self.new_type_compliance(class_type, decode_method_name=decode_method_name) decode = _create_default_decode(class_type, decode_method_name) type_name = class_type.__name__ @@ -479,7 +305,7 @@ def __eq__(self, other: Any) -> bool: """ if len(self) != len(other): return False - for key, value in self.iteritems(): + for key, value in self.items(): if value != other.get(key, SENTINEL): return False return True @@ -561,7 +387,7 @@ def __iter__(self) -> Iterator[str]: Returns: Iterator[str]: An iterator over the keys of the RedisDict. 
""" - self._iter = self.iterkeys() + self._iter = self.keys() return self def __repr__(self) -> str: @@ -582,15 +408,102 @@ def __str__(self) -> str: """ return str(self.to_dict()) + def __or__(self, other: Dict[str, Any]) -> Dict[str, Any]: + """ + Implements the | operator (dict union). + Returns a new dictionary with items from both dictionaries. + + Args: + other (Dict[str, Any]): The dictionary to merge with. + + Raises: + TypeError: If other does not adhere to Mapping. + + Returns: + Dict[str, Any]: A new dictionary containing items from both dictionaries. + """ + if not isinstance(other, Mapping): + raise TypeError(f"unsupported operand type(s) for |: '{type(other).__name__}' and 'RedisDict'") + + result = {} + result.update(self.to_dict()) + result.update(other) + return result + + def __ror__(self, other: Dict[str, Any]) -> Dict[str, Any]: + """ + Implements the reverse | operator. + Called when RedisDict is on the right side of |. + + Args: + other (Dict[str, Any]): The dictionary to merge with. + + Raises: + TypeError: If other does not adhere to Mapping. + + Returns: + Dict[str, Any]: A new dictionary containing items from both dictionaries. + """ + if not isinstance(other, Mapping): + raise TypeError(f"unsupported operand type(s) for |: 'RedisDict' and '{type(other).__name__}'") + + result = {} + result.update(other) + result.update(self.to_dict()) + return result + + def __ior__(self, other: Dict[str, Any]) -> 'RedisDict': + """ + Implements the |= operator (in-place union). + Modifies the current dictionary by adding items from other. + + Args: + other (Dict[str, Any]): The dictionary to merge with. + + Raises: + TypeError: If other does not adhere to Mapping. + + Returns: + RedisDict: The modified RedisDict instance. 
+ """ + if not isinstance(other, Mapping): + raise TypeError(f"unsupported operand type(s) for |: '{type(other).__name__}' and 'RedisDict'") + + self.update(other) + return self + + @classmethod + def __class_getitem__(cls: Type['RedisDict'], key: Any) -> Type['RedisDict']: + """ + Enables type hinting support like RedisDict[str, Any]. + + Args: + key (Any): The type parameter(s) used in the type hint. + + Returns: + Type[RedisDict]: The class itself, enabling type hint usage. + """ + return cls + + def __reversed__(self) -> Iterator[str]: + """ + Implements reversed() built-in: + Returns an iterator over dictionary keys in reverse insertion order. + + Warning: + RedisDict Currently does not support 'insertion order' as property thus also not reversed. + + Returns: + Iterator[str]: An iterator yielding the dictionary keys in reverse order. + """ + return reversed(list(self.keys())) + def __next__(self) -> str: """ Get the next item in the iterator. Returns: str: The next item in the iterator. - - Raises: - StopIteration: If there are no more items. """ return next(self._iter) @@ -601,8 +514,6 @@ def next(self) -> str: Returns: str: The next item in the iterator. - Raises: - StopIteration: If there are no more items. """ return next(self) @@ -633,7 +544,7 @@ def _scan_keys(self, search_term: str = '') -> Iterator[str]: Scan for Redis keys matching the given search term. Args: - search_term (str, optional): A search term to filter keys. Defaults to ''. + search_term (str): A search term to filter keys. Defaults to ''. Returns: Iterator[str]: An iterator of matching Redis keys. @@ -642,8 +553,8 @@ def _scan_keys(self, search_term: str = '') -> Iterator[str]: return self.get_redis.scan_iter(match=search_query) def get(self, key: str, default: Optional[Any] = None) -> Any: - """ - Return the value for the given key if it exists, otherwise return the default value. + """Return the value for the given key if it exists, otherwise return the default value. 
+
+    Analogous to a dictionary's get method.
 
     Args:
@@ -651,23 +562,30 @@ def get(self, key: str, default: Optional[Any] = None) -> Any:
         default (Optional[Any], optional): The value to return if the key is not found.
 
     Returns:
-        Optional[Any]: The value associated with the key or the default value.
+        Any: The value associated with the key or the default value.
     """
     found, item = self._load(key)
     if not found:
         return default
     return item
 
-    def iterkeys(self) -> Iterator[str]:
-        """
-        Note: for python2 str is needed
+    def keys(self) -> Iterator[str]:
+        """Return an iterator of keys in the RedisDict, analogous to a dictionary's keys method.
+
+        Returns:
+            Iterator[str]: An iterator over the keys in the RedisDict.
+        """
        to_rm = len(self.namespace) + 1
        return (str(item[to_rm:]) for item in self._scan_keys())

    def key(self, search_term: str = '') -> Optional[str]:
-        """
-        Note: for python2 str is needed
+        """Return the first key matching search_term if it exists, otherwise return None.
+
+        Args:
+            search_term (str): A search term to filter keys. Defaults to ''.
+
+        Returns:
+            Optional[str]: The first key associated with the given search term, or None.
+        """
        to_rm = len(self.namespace) + 1
        search_query = self._create_iter_query(search_term)
@@ -677,18 +595,11 @@ def key(self, search_term: str = '') -> Optional[str]:

        return None

-    def keys(self) -> List[str]:
-        """
-        Return a list of keys in the RedisDict, analogous to a dictionary's keys method.
+    def items(self) -> Iterator[Tuple[str, Any]]:
+        """Yield key-value pairs (tuples) in the RedisDict, analogous to a dictionary's items method.

-        Returns:
-            List[str]: A list of keys in the RedisDict.
-        """
-        return list(self.iterkeys())
-
-    def iteritems(self) -> Iterator[Tuple[str, Any]]:
-        """
-        Note: for python2 str is needed
+    Yields:
+        Iterator[Tuple[str, Any]]: Key-value pairs in the RedisDict.
""" to_rm = len(self.namespace) + 1 for item in self._scan_keys(): @@ -697,31 +608,14 @@ def iteritems(self) -> Iterator[Tuple[str, Any]]: except KeyError: pass - def items(self) -> List[Tuple[str, Any]]: - """ - Return a list of key-value pairs (tuples) in the RedisDict, analogous to a dictionary's items method. + def values(self) -> Iterator[Any]: + """Analogous to a dictionary's values method. - Returns: - List[Tuple[str, Any]]: A list of key-value pairs in the RedisDict. - """ - return list(self.iteritems()) + Return a list of values in the RedisDict, - def values(self) -> List[Any]: - """ - Return a list of values in the RedisDict, analogous to a dictionary's values method. - - Returns: + Yields: List[Any]: A list of values in the RedisDict. """ - return list(self.itervalues()) - - def itervalues(self) -> Iterator[Any]: - """ - Iterate over the values in the RedisDict. - - Returns: - Iterator[Any]: An iterator of values in the RedisDict. - """ to_rm = len(self.namespace) + 1 for item in self._scan_keys(): try: @@ -730,8 +624,7 @@ def itervalues(self) -> Iterator[Any]: pass def to_dict(self) -> Dict[str, Any]: - """ - Convert the RedisDict to a Python dictionary. + """Convert the RedisDict to a Python dictionary. Returns: Dict[str, Any]: A dictionary with the same key-value pairs as the RedisDict. @@ -739,8 +632,7 @@ def to_dict(self) -> Dict[str, Any]: return dict(self.items()) def clear(self) -> None: - """ - Remove all key-value pairs from the RedisDict in one batch operation using pipelining. + """Remove all key-value pairs from the RedisDict in one batch operation using pipelining. This method mimics the behavior of the `clear` method from a standard Python dictionary. Redis pipelining is employed to group multiple commands into a single request, minimizing @@ -754,33 +646,33 @@ def clear(self) -> None: del self[key] def pop(self, key: str, default: Union[Any, object] = SENTINEL) -> Any: - """ + """Analogous to a dictionary's pop method. 
+ Remove the value associated with the given key and return it, or return the default value - if the key is not found. Analogous to a dictionary's pop method. + if the key is not found. Args: key (str): The key to remove the value. default (Optional[Any], optional): The value to return if the key is not found. Returns: - Optional[Any]: The value associated with the key or the default value. + Any: The value associated with the key or the default value. Raises: KeyError: If the key is not found and no default value is provided. """ - try: - value = self[key] - except KeyError: + formatted_key = self._format_key(key) + value = self.get_redis.execute_command("GETDEL", formatted_key) + if value is None: if default is not SENTINEL: return default - raise + raise KeyError(formatted_key) - del self[key] - return value + return self._transform(value) def popitem(self) -> Tuple[str, Any]: - """ - Remove and return a random (key, value) pair from the RedisDict as a tuple. + """Remove and return a random (key, value) pair from the RedisDict as a tuple. + This method is analogous to the `popitem` method of a standard Python dictionary. Returns: @@ -799,7 +691,8 @@ def popitem(self) -> Tuple[str, Any]: continue def setdefault(self, key: str, default_value: Optional[Any] = None) -> Any: - """ + """Get value under key, and if not present set default value. + Return the value associated with the given key if it exists, otherwise set the value to the default value and return it. Analogous to a dictionary's setdefault method. @@ -810,15 +703,28 @@ def setdefault(self, key: str, default_value: Optional[Any] = None) -> Any: Returns: Any: The value associated with the key or the default value. 
""" - found, value = self._load(key) - if not found: - self[key] = default_value + formatted_key = self._format_key(key) + formatted_value = self._format_value(key, default_value) + + # Setting {"get": True} enables parsing of the redis result as "GET", instead of "SET" command + options = {"get": True} + args = ["SET", formatted_key, formatted_value, "NX", "GET"] + if self.preserve_expiration: + args.append("KEEPTTL") + elif self.expire is not None: + expire_val = int(self.expire.total_seconds()) if isinstance(self.expire, timedelta) else self.expire + expire_str = str(1) if expire_val <= 1 else str(expire_val) + args.extend(["EX", expire_str]) + + result = self.get_redis.execute_command(*args, **options) + if result is None: return default_value - return value + + return self._transform(result) def copy(self) -> Dict[str, Any]: - """ - Create a shallow copy of the RedisDict and return it as a standard Python dictionary. + """Create a shallow copy of the RedisDict and return it as a standard Python dictionary. + This method is analogous to the `copy` method of a standard Python dictionary Returns: @@ -841,7 +747,8 @@ def update(self, dic: Dict[str, Any]) -> None: self[key] = value def fromkeys(self, iterable: List[str], value: Optional[Any] = None) -> 'RedisDict': - """ + """Create a new RedisDict from an iterable of key-value pairs. + Create a new RedisDict with keys from the provided iterable and values set to the given value. This method is analogous to the `fromkeys` method of a standard Python dictionary, populating the RedisDict with the keys from the iterable and setting their corresponding values to the @@ -861,10 +768,11 @@ def fromkeys(self, iterable: List[str], value: Optional[Any] = None) -> 'RedisDi return self def __sizeof__(self) -> int: - """ - Return the approximate size of the RedisDict in memory, in bytes. + """Return the approximate size of the RedisDict in memory, in bytes. 
+ This method is analogous to the `__sizeof__` method of a standard Python dictionary, estimating the memory consumption of the RedisDict based on the serialized in-memory representation. + Should be changed to redis view of the size. Returns: int: The approximate size of the RedisDict in memory, in bytes. @@ -906,13 +814,12 @@ def chain_del(self, iterable: List[str]) -> None: # compatibility with Python 3.9 typing @contextmanager def expire_at(self, sec_epoch: Union[int, timedelta]) -> Iterator[None]: - """ - Context manager to set the expiration time for keys in the RedisDict. + """Context manager to set the expiration time for keys in the RedisDict. Args: sec_epoch (int, timedelta): The expiration duration is set using either an integer or a timedelta. - Returns: + Yields: ContextManager: A context manager during which the expiration time is the time set. """ self.expire, temp = sec_epoch, self.expire @@ -924,7 +831,7 @@ def pipeline(self) -> Iterator[None]: """ Context manager to create a Redis pipeline for batch operations. - Returns: + Yields: ContextManager: A context manager to create a Redis pipeline batching all operations within the context. """ top_level = False @@ -1006,7 +913,8 @@ def get_redis_info(self) -> Dict[str, Any]: return dict(self.redis.info()) def get_ttl(self, key: str) -> Optional[int]: - """ + """Get the Time To Live from Redis. + Get the Time To Live (TTL) in seconds for a given key. If the key does not exist or does not have an associated `expire`, return None. 
diff --git a/src/redis_dict/py.typed b/src/redis_dict/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/redis_dict/type_management.py b/src/redis_dict/type_management.py new file mode 100644 index 0000000..9920747 --- /dev/null +++ b/src/redis_dict/type_management.py @@ -0,0 +1,273 @@ +"""Type management module.""" + +import json +import base64 +from collections import OrderedDict, defaultdict +from datetime import datetime, date, time, timedelta + +from typing import Callable, Any, Dict, Tuple, Set + +from uuid import UUID +from decimal import Decimal + + +SENTINEL = object() + +EncodeFuncType = Callable[[Any], str] +DecodeFuncType = Callable[[str], Any] +EncodeType = Dict[str, EncodeFuncType] +DecodeType = Dict[str, DecodeFuncType] + + +def _create_default_encode(custom_encode_method: str) -> EncodeFuncType: + def default_encode(obj: Any) -> str: + return getattr(obj, custom_encode_method)() # type: ignore[no-any-return] + return default_encode + + +def _create_default_decode(cls: type, custom_decode_method: str) -> DecodeFuncType: + def default_decode(encoded_str: str) -> Any: + return getattr(cls, custom_decode_method)(encoded_str) + return default_decode + + +def _decode_tuple(val: str) -> Tuple[Any, ...]: + """ + Deserialize a JSON-formatted string to a tuple. + + This function takes a JSON-formatted string, deserializes it to a list, and + then converts the list to a tuple. + + Args: + val (str): A JSON-formatted string representing a list. + + Returns: + Tuple[Any, ...]: A tuple with the deserialized values from the input string. + """ + return tuple(json.loads(val)) + + +def _encode_tuple(val: Tuple[Any, ...]) -> str: + """ + Serialize a tuple to a JSON-formatted string. + + This function takes a tuple, converts it to a list, and then serializes + the list to a JSON-formatted string. + + Args: + val (Tuple[Any, ...]): A tuple with values to be serialized. + + Returns: + str: A JSON-formatted string representing the input tuple. 
+ """ + return json.dumps(list(val)) + + +def _decode_set(val: str) -> Set[Any]: + """ + Deserialize a JSON-formatted string to a set. + + This function takes a JSON-formatted string, deserializes it to a list, and + then converts the list to a set. + + Args: + val (str): A JSON-formatted string representing a list. + + Returns: + set[Any]: A set with the deserialized values from the input string. + """ + return set(json.loads(val)) + + +def _encode_set(val: Set[Any]) -> str: + """ + Serialize a set to a JSON-formatted string. + + This function takes a set, converts it to a list, and then serializes the + list to a JSON-formatted string. + + Args: + val (set[Any]): A set with values to be serialized. + + Returns: + str: A JSON-formatted string representing the input set. + """ + return json.dumps(list(val)) + + +decoding_registry: DecodeType = { + type('').__name__: str, + type(1).__name__: int, + type(0.1).__name__: float, + type(True).__name__: lambda x: x == "True", + type(None).__name__: lambda x: None, + + "list": json.loads, + "dict": json.loads, + "tuple": _decode_tuple, + type(set()).__name__: _decode_set, + + datetime.__name__: datetime.fromisoformat, + date.__name__: date.fromisoformat, + time.__name__: time.fromisoformat, + timedelta.__name__: lambda x: timedelta(seconds=float(x)), + + Decimal.__name__: Decimal, + complex.__name__: lambda x: complex(*map(float, x.split(','))), + bytes.__name__: base64.b64decode, + + UUID.__name__: UUID, + OrderedDict.__name__: lambda x: OrderedDict(json.loads(x)), + defaultdict.__name__: lambda x: defaultdict(type(None), json.loads(x)), + frozenset.__name__: lambda x: frozenset(json.loads(x)), +} + + +encoding_registry: EncodeType = { + "list": json.dumps, + "dict": json.dumps, + "tuple": _encode_tuple, + type(set()).__name__: _encode_set, + + datetime.__name__: datetime.isoformat, + date.__name__: date.isoformat, + time.__name__: time.isoformat, + timedelta.__name__: lambda x: str(x.total_seconds()), + + 
complex.__name__: lambda x: f"{x.real},{x.imag}", + bytes.__name__: lambda x: base64.b64encode(x).decode('ascii'), + OrderedDict.__name__: lambda x: json.dumps(list(x.items())), + defaultdict.__name__: lambda x: json.dumps(dict(x)), + frozenset.__name__: lambda x: json.dumps(list(x)), +} + + +class RedisDictJSONEncoder(json.JSONEncoder): + """Extends JSON encoding capabilities by reusing RedisDict type conversion. + + Uses existing decoding_registry to know which types to handle specially and + encoding_registry (falls back to str) for converting to JSON-compatible formats. + + Example: + The encoded format looks like:: + + { + "__type__": "TypeName", + "value": + } + + Notes: + + Uses decoding_registry (containing all supported types) to check if type + needs special handling. For encoding, defaults to str() if no encoder exists + in encoding_registry. + """ + def default(self, o: Any) -> Any: + """Overwrite default from json encoder. + + Args: + o (Any): Object to be serialized. + + Raises: + TypeError: If the object `o` cannot be serialized. + + Returns: + Any: Serialized value. + """ + type_name = type(o).__name__ + if type_name in decoding_registry: + return { + "__type__": type_name, + "value": encoding_registry.get(type_name, _default_encoder)(o) + } + try: + return json.JSONEncoder.default(self, o) + except TypeError as e: + raise TypeError(f"Object of type {type_name} is not JSON serializable") from e + + +class RedisDictJSONDecoder(json.JSONDecoder): + """JSON decoder leveraging RedisDict existing type conversion system. + + Works with RedisDictJSONEncoder to reconstruct Python objects from JSON using + RedisDict decoding_registry. + + Still needs work but allows for more types than without. + """ + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + Overwrite the __init__ method from JSON decoder. + + Args: + *args (Any): Positional arguments for initialization. + **kwargs (Any): Keyword arguments for initialization. 
+ + """ + def _object_hook(obj: Dict[Any, Any]) -> Any: + if "__type__" in obj and "value" in obj: + type_name = obj["__type__"] + if type_name in decoding_registry: + return decoding_registry[type_name](obj["value"]) + return obj + + super().__init__(object_hook=_object_hook, *args, **kwargs) + + +def encode_json(obj: Any) -> str: + """ + Encode a Python object to a JSON string using the existing encoding registry. + + Args: + obj (Any): The Python object to be encoded. + + Returns: + str: The JSON-encoded string representation of the object. + """ + return json.dumps(obj, cls=RedisDictJSONEncoder) + + +def decode_json(s: str) -> Any: + """ + Decode a JSON string to a Python object using the existing decoding registry. + + Args: + s (str): The JSON string to be decoded. + + Returns: + Any: The decoded Python object. + """ + return json.loads(s, cls=RedisDictJSONDecoder) + + +def _default_decoder(x: str) -> str: + """ + Pass-through decoder that returns the input string unchanged. + + Args: + x (str): The input string. + + Returns: + str: The same input string. + """ + return x + + +def _default_encoder(x: Any) -> str: + """ + Takes x and returns the result str of the object. 
+ + Args: + x (Any): The input object + + Returns: + str: output of str of the object + """ + return str(x) + + +encoding_registry["dict"] = encode_json +decoding_registry["dict"] = decode_json + + +encoding_registry["list"] = encode_json +decoding_registry["list"] = decode_json diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/load/__init__.py b/tests/load/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/load_test.py b/tests/load/load_test.py similarity index 100% rename from load_test.py rename to tests/load/load_test.py diff --git a/tests/load/load_test_compression.py b/tests/load/load_test_compression.py new file mode 100644 index 0000000..2a74eac --- /dev/null +++ b/tests/load/load_test_compression.py @@ -0,0 +1,223 @@ +import time +import statistics + +from redis_dict import RedisDict +import json + +# Constants +BATCH_SIZE = 1000 + + +import os +import csv +import zipfile +import requests +from typing import Iterator, Dict +from io import TextIOWrapper +import gzip +import base64 + +class GzippedDict: + """ + A class that can encode its attributes to a compressed string and decode from a compressed string, + optimized for the fastest possible gzipping. + + Methods: + encode: Compresses and encodes the object's attributes to a base64 string using the fastest settings. + decode: Creates a new object from a compressed and encoded base64 string. + """ + + def encode(self) -> str: + """ + Encodes the object's attributes to a compressed base64 string using the fastest possible settings. + + Returns: + str: A base64 encoded string of the compressed object attributes. 
+ """ + json_data = json.dumps(self.__dict__, separators=(',', ':')) + compressed_data = gzip.compress(json_data.encode('utf-8'), compresslevel=1) + return base64.b64encode(compressed_data).decode('ascii') + + @classmethod + def decode(cls, encoded_str: str) -> 'GzippedDict': + """ + Creates a new object from a compressed and encoded base64 string. + + Args: + encoded_str (str): A base64 encoded string of compressed object attributes. + + Returns: + GzippedDict: A new instance of the class with decoded attributes. + """ + json_data = gzip.decompress(base64.b64decode(encoded_str)).decode('utf-8') + attributes = json.loads(json_data) + return cls(**attributes) + + +def encode_dict(dic: dict) -> str: + json_data = json.dumps(dic, separators=(',', ':')) + compressed_data = gzip.compress(json_data.encode('utf-8'), compresslevel=1) + return str(base64.b64encode(compressed_data).decode('ascii')) + + +def decode_dict(s) -> dict: + return json.loads(gzip.decompress(base64.b64decode(s)).decode('utf-8')) + +import binascii + +def encode_dict(dic: dict) -> str: + json_data = json.dumps(dic, separators=(',', ':')) + compressed_data = gzip.compress(json_data.encode('utf-8'), compresslevel=1) + return binascii.hexlify(compressed_data).decode('ascii') + +def decode_dict(s: str) -> dict: + compressed_data = binascii.unhexlify(s) + return json.loads(gzip.decompress(compressed_data).decode('utf-8')) + + +import os +import zipfile +import gzip +import csv +from typing import Iterator, Dict +from io import TextIOWrapper +import requests +from urllib.parse import urlparse + +def download_file(url: str, filename: str): + response = requests.get(url) + with open(filename, 'wb') as f: + f.write(response.content) + +def csv_iterator(file) -> Iterator[Dict[str, str]]: + reader = csv.DictReader(file) + for row in reader: + yield row + +def get_filename_from_url(url: str) -> str: + return os.path.basename(urlparse(url).path) + +def create_data_gen(url: str) -> Iterator[Dict[str, str]]: + 
filename = get_filename_from_url(url) + print(filename) + if not os.path.exists(filename): + download_file(url, filename) + + if filename.endswith('.zip'): + with zipfile.ZipFile(filename, 'r') as zip_ref: + csv_filename = zip_ref.namelist()[0] + with zip_ref.open(csv_filename) as csv_file: + text_file = TextIOWrapper(csv_file, encoding='utf-8') + yield from csv_iterator(text_file) + elif filename.endswith('.gz'): + with gzip.open(filename, 'rt', encoding='utf-8') as gz_file: + yield from csv_iterator(gz_file) + else: + raise ValueError("Unsupported file format. Use .zip or .gz files.") + + +def run_load_test(dataset, times=1, use_compression=False): + redis_dict = RedisDict() + redis_dict.clear() + initial_size = redis_dict.redis.info(section="memory")["used_memory"] + if use_compression: + redis_dict.extends_type(dict, encode_dict, decode_dict) + + + operation_times = [] + start_total = time.time() + + total_operations = 0 + + for _ in range(times): + key = "bla" + for i, value in enumerate(create_data_gen(dataset), 1): + #key = f"key{i}" + #print(value) + start_time = time.time() + redis_dict[key] = value + _ = redis_dict[key] + end_time = time.time() + + operation_times.append(end_time - start_time) + + total_operations += i + + print(f"\nTotal operations completed: {total_operations}") + + end_total = time.time() + total_time = end_total - start_total + + final_size = redis_dict.redis.info(section="memory")["used_memory"] + redis_dict.clear() + + return { + "dataset": dataset, + "compression": use_compression, + "total_operations": total_operations, + "batch_size": BATCH_SIZE, + "mean_time": statistics.mean(operation_times) if operation_times else None, + "min_time": min(operation_times) if operation_times else None, + "max_time": max(operation_times) if operation_times else None, + "std_dev": statistics.stdev(operation_times) if len(operation_times) > 1 else None, + "total_time": total_time, + "initial_size": human_readable_size(initial_size), + "final_size": 
human_readable_size(final_size), + "size_difference": human_readable_size(final_size - initial_size), + } + +def format_value(value): + if isinstance(value, bool): + return "With" if value else "Without" + elif isinstance(value, float): + return f"{value:.6f}" + return str(value) + + +def human_readable_size(size_in_bytes): + for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + if size_in_bytes < 1024.0: + return f"{size_in_bytes:.2f} {unit}" + size_in_bytes /= 1024.0 + return f"{size_in_bytes:.2f} PB" + + +def display_results(results, sort_key="mean_time", reverse=False): + if not results: + print("No results to display.") + return + + sorted_results = sorted(results, key=lambda x: x[sort_key], reverse=reverse) + + keys = list(sorted_results[0].keys()) + + headers = [key.replace("_", " ").capitalize() for key in keys] + + col_widths = [max(len(header), max(len(format_value(result[key])) for result in sorted_results)) for header, key in zip(headers, keys)] + + header = " | ".join(header.ljust(width) for header, width in zip(headers, col_widths)) + print(header) + print("-" * len(header)) + + # Print each result row + for result in sorted_results: + row = [format_value(result[key]).ljust(width) for key, width in zip(keys, col_widths)] + print(" | ".join(row)) + + +if __name__ == "__main__": + times = 1 + results = [] + datasets = [ + "https://www.briandunning.com/sample-data/us-500.zip", + #"https://datasets.imdbws.com/name.basics.tsv.gz", + "https://datasets.imdbws.com/title.basics.tsv.gz" + ] + for dataset in datasets: + print("Running load test without compression...") + results.append(run_load_test(times=times, use_compression=False, dataset=dataset)) + print("\nRunning load test with compression...") + results.append(run_load_test(times=times, use_compression=True, dataset=dataset)) + + print("\nPerformance Comparison (sorted by Mean Time):") + display_results(results) \ No newline at end of file diff --git a/tests/misc/__init__.py b/tests/misc/__init__.py new file 
mode 100644 index 0000000..e69de29 diff --git a/assert_test.py b/tests/misc/assert_test.py similarity index 72% rename from assert_test.py rename to tests/misc/assert_test.py index 4b7a0e7..2c605f1 100644 --- a/assert_test.py +++ b/tests/misc/assert_test.py @@ -1,23 +1,30 @@ - import time +from datetime import datetime + from redis_dict import RedisDict -d = RedisDict(namespace='app_name2') -assert 'random' not in d -d['random'] = 4 -assert d['random'] == 4 -assert 'random' in d -del d['random'] -assert 'random' not in d +dic = RedisDict(namespace='assert_test') +assert 'random' not in dic +dic['random'] = 4 +assert dic['random'] == 4 +assert 'random' in dic +del dic['random'] +assert 'random' not in dic + +now = datetime.now() +dic['datetime'] = now +assert dic['datetime'] == now +dic.clear() + deep = ['key', 'key1', 'key2'] deep_val = 'mister' -d.chain_set(deep, deep_val) +dic.chain_set(deep, deep_val) -assert deep_val == d.chain_get(deep) -d.chain_del(deep) +assert deep_val == dic.chain_get(deep) +dic.chain_del(deep) try: - d.chain_get(deep) + dic.chain_get(deep) except KeyError: pass except Exception: @@ -25,21 +32,21 @@ else: print('failed to throw KeyError') -assert 'random' not in d -d['random'] = 4 +assert 'random' not in dic +dic['random'] = 4 dd = RedisDict(namespace='app_name_too') assert len(dd) == 0 dd['random'] = 5 -assert d['random'] == 4 -assert 'random' in d +assert dic['random'] == 4 +assert 'random' in dic assert dd['random'] == 5 assert 'random' in dd -del d['random'] -assert 'random' not in d +del dic['random'] +assert 'random' not in dic assert dd['random'] == 5 assert 'random' in dd diff --git a/tests/misc/simple_test.py b/tests/misc/simple_test.py new file mode 100644 index 0000000..23a69dd --- /dev/null +++ b/tests/misc/simple_test.py @@ -0,0 +1,18 @@ +from datetime import datetime + +from redis_dict import RedisDict + +dic = RedisDict(namespace='assert_test') +assert 'random' not in dic +dic['random'] = 4 +assert dic['random'] == 4 +assert 
'random' in dic +del dic['random'] +assert 'random' not in dic + +now = datetime.now() +dic['datetime'] = now +assert dic['datetime'] == now +dic.clear() + +print("passed assert test") diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/encrypt_tests.py b/tests/unit/encrypt_tests.py similarity index 100% rename from encrypt_tests.py rename to tests/unit/encrypt_tests.py diff --git a/tests/unit/extend_json_types_tests.py b/tests/unit/extend_json_types_tests.py new file mode 100644 index 0000000..ec71379 --- /dev/null +++ b/tests/unit/extend_json_types_tests.py @@ -0,0 +1,339 @@ +import sys +import json +import unittest + +from typing import Any + +from collections import Counter, ChainMap +from dataclasses import dataclass +from ipaddress import IPv4Address, IPv6Address +from pathlib import Path +from queue import Queue, PriorityQueue +from typing import NamedTuple +from enum import Enum + +from datetime import datetime, date, time, timedelta +from decimal import Decimal +from collections import OrderedDict, defaultdict +from uuid import UUID + +sys.path.append(str(Path(__file__).parent.parent.parent / "src")) +from redis_dict.type_management import encode_json, decode_json, RedisDictJSONEncoder, RedisDictJSONDecoder + + +class TestJsonEncoding(unittest.TestCase): + def setUp(self): + + # Below are tests that contain types that handled by default json encoding/decoding + self.skip_assert_raise_type_error_test = { + "str", "int", "float", "dict", "list", + "NoneType", "defaultdict", "OrderedDict", + "bool", "str,int,bool in list", "None,float,list in list", + "str,dict,set in list", + } + + def _assert_value_encodes_decodes(self, value: Any) -> None: + """Helper method to assert a value can be encoded and decoded correctly""" + encoded = json.dumps(value, cls=RedisDictJSONEncoder) + result = json.loads(encoded, cls=RedisDictJSONDecoder) + self.assertEqual(value, result) + + def test_happy_path(self): + 
test_cases = [ + ("Hello World", "str"), + (42, "int"), + (3.14, "float"), + (True, "bool"), + (None, "NoneType"), + + ([1, 2, 3], "list"), + ({"a": 1, "b": 2}, "dict"), + #((1, 2, 3), "tuple"), + ({1, 2, 3}, "set"), + + (datetime(2024, 1, 1, 12, 30, 45), "datetime"), + (date(2024, 1, 1), "date"), + (time(12, 30, 45), "time"), + (timedelta(days=1, hours=2), "timedelta"), + + (Decimal("3.14159"), "Decimal"), + (complex(1, 2), "complex"), + (bytes([72, 101, 108, 108, 111]), "bytes"), + (UUID('12345678-1234-5678-1234-567812345678'), "UUID"), + + (OrderedDict([('a', 1), ('b', 2)]), "OrderedDict"), + (defaultdict(type(None), {'a': 1, 'b': 2}), "defaultdict"), + (frozenset([1, 2, 3]), "frozenset"), + ] + + for test_case_input, test_case_title in test_cases: + with self.subTest(f"Testing happy mixed list path: {test_case_title}"): + self._assert_value_encodes_decodes(test_case_input) + + if test_case_title not in self.skip_assert_raise_type_error_test: + with self.assertRaises(TypeError): + json.loads(json.dumps(test_case_input)) + + def test_empty_path(self): + test_cases = [ + ("", "str"), + (0, "int"), + (0.0, "float"), + (False, "bool"), + (None, "NoneType"), + + ([], "list"), + ({}, "dict"), + # ((), "tuple"), TODO Handle tuple + (set(), "set"), + + (datetime.min, "datetime"), + (date.min, "date"), + (time.min, "time"), + (timedelta(), "timedelta"), + + (Decimal("0"), "Decimal"), + (complex(0, 0), "complex"), + (bytes(), "bytes"), + (UUID('00000000-0000-0000-0000-000000000000'), "UUID"), + + (OrderedDict(), "OrderedDict"), + (defaultdict(type(None)), "defaultdict"), + (frozenset(), "frozenset"), + ] + + for test_case_input, test_case_title in test_cases: + with self.subTest(f"Testing happy mixed list path: {test_case_title}"): + self._assert_value_encodes_decodes(test_case_input) + + if test_case_title not in self.skip_assert_raise_type_error_test: + with self.assertRaises(TypeError): + json.loads(json.dumps(test_case_input)) + + def test_happy_nested_dict(self): + 
test_cases = [ + ({"value": "Hello World"}, "str"), + ({"value": 42}, "int"), + ({"value": 3.14}, "float"), + ({"value": True}, "bool"), + ({"value": None}, "NoneType"), + + ({"value": [1, 2, 3]}, "list"), + ({"value": {"a": 1, "b": 2}}, "dict"), + # ({"value": (1, 2, 3)}, "tuple"), TODO Handle tuple + ({"value": {1, 2, 3}}, "set"), + + ({"value": datetime(2024, 1, 1, 12, 30, 45)}, "datetime"), + ({"value": date(2024, 1, 1)}, "date"), + ({"value": time(12, 30, 45)}, "time"), + ({"value": timedelta(days=1, hours=2)}, "timedelta"), + + ({"value": Decimal("3.14159")}, "Decimal"), + ({"value": complex(1, 2)}, "complex"), + ({"value": bytes([72, 101, 108, 108, 111])}, "bytes"), + ({"value": UUID('12345678-1234-5678-1234-567812345678')}, "UUID"), + + ({"value": OrderedDict([('a', 1), ('b', 2)])}, "OrderedDict"), + ({"value": defaultdict(type(None), {'a': 1, 'b': 2})}, "defaultdict"), + ({"value": frozenset([1, 2, 3])}, "frozenset"), + ] + + for test_case_input, test_case_title in test_cases: + with self.subTest(f"Testing happy mixed list path: {test_case_title}"): + self._assert_value_encodes_decodes(test_case_input) + + if test_case_title not in self.skip_assert_raise_type_error_test: + with self.assertRaises(TypeError): + json.loads(json.dumps(test_case_input)) + + def test_happy_nested_dict_two_levels(self): + test_cases = [ + ({"level1": {"value": "Hello World"}}, "str"), + ({"level1": {"value": 42}}, "int"), + ({"level1": {"value": 3.14}}, "float"), + ({"level1": {"value": True}}, "bool"), + ({"level1": {"value": None}}, "NoneType"), + + ({"level1": {"value": [1, 2, 3]}}, "list"), + ({"level1": {"value": {"a": 1, "b": 2}}}, "dict"), + # ({"level1": {"value": (1, 2, 3)}}, "tuple"), TODO Handle tuple + ({"level1": {"value": {1, 2, 3}}}, "set"), + + ({"level1": {"value": datetime(2024, 1, 1, 12, 30, 45)}}, "datetime"), + ({"level1": {"value": date(2024, 1, 1)}}, "date"), + ({"level1": {"value": time(12, 30, 45)}}, "time"), + ({"level1": {"value": timedelta(days=1, 
hours=2)}}, "timedelta"), + + ({"level1": {"value": Decimal("3.14159")}}, "Decimal"), + ({"level1": {"value": complex(1, 2)}}, "complex"), + ({"level1": {"value": bytes([72, 101, 108, 108, 111])}}, "bytes"), + ({"level1": {"value": UUID('12345678-1234-5678-1234-567812345678')}}, "UUID"), + + ({"level1": {"value": OrderedDict([('a', 1), ('b', 2)])}}, "OrderedDict"), + ({"level1": {"value": defaultdict(type(None), {'a': 1, 'b': 2})}}, "defaultdict"), + ({"level1": {"value": frozenset([1, 2, 3])}}, "frozenset"), + ] + + for test_case_input, test_case_title in test_cases: + with self.subTest(f"Testing happy mixed list path: {test_case_title}"): + self._assert_value_encodes_decodes(test_case_input) + + if test_case_title not in self.skip_assert_raise_type_error_test: + with self.assertRaises(TypeError): + json.loads(json.dumps(test_case_input)) + + def test_happy_list(self): + test_cases = [ + (["Hello World"], "str"), + ([42], "int"), + ([3.14], "float"), + ([True], "bool"), + ([None], "NoneType"), + + ([[1, 2, 3]], "list"), + ([{"a": 1, "b": 2}], "dict"), + # ([(1, 2, 3)], "tuple"), TODO Handle tuple + ([{1, 2, 3}], "set"), + + ([datetime(2024, 1, 1, 12, 30, 45)], "datetime"), + ([date(2024, 1, 1)], "date"), + ([time(12, 30, 45)], "time"), + ([timedelta(days=1, hours=2)], "timedelta"), + + ([Decimal("3.14159")], "Decimal"), + ([complex(1, 2)], "complex"), + ([bytes([72, 101, 108, 108, 111])], "bytes"), + ([UUID('12345678-1234-5678-1234-567812345678')], "UUID"), + + ([OrderedDict([('a', 1), ('b', 2)])], "OrderedDict"), + ([defaultdict(type(None), {'a': 1, 'b': 2})], "defaultdict"), + ([frozenset([1, 2, 3])], "frozenset"), + ] + + for test_case_input, test_case_title in test_cases: + with self.subTest(f"Testing happy mixed list path: {test_case_title}"): + self._assert_value_encodes_decodes(test_case_input) + + if test_case_title not in self.skip_assert_raise_type_error_test: + with self.assertRaises(TypeError): + json.loads(json.dumps(test_case_input)) + + def 
test_happy_mixed_list(self): + test_cases = [ + (["Hello World", 42, True], "str,int,bool in list"), + ([None, 3.14, [1, 2, 3]], "None,float,list in list"), + (["test", {"a": 1}, {1, 2, 3}], "str,dict,set in list"), + + ([datetime(2024, 1, 1), date(2024, 1, 1), time(12, 30, 45)], "datetime,date,time in list"), + ([timedelta(days=1), Decimal("3.14159"), complex(1, 2)], "timedelta,Decimal,complex in list"), + ([bytes([72, 101]), UUID('12345678-1234-5678-1234-567812345678'), "test"], "bytes,UUID,str in list"), + + ([OrderedDict([('a', 1)]), defaultdict(type(None), {'b': 2}), frozenset([1, 2])], + "OrderedDict,defaultdict,frozenset in list"), + (["a", 1, Decimal("3.14159")], "str,int,Decimal in list"), + ([True, None, datetime(2024, 1, 1)], "bool,None,datetime in list"), + ] + + for test_case_input, test_case_title in test_cases: + with self.subTest(f"Testing happy mixed list path: {test_case_title}"): + self._assert_value_encodes_decodes(test_case_input) + + if test_case_title not in self.skip_assert_raise_type_error_test: + with self.assertRaises(TypeError): + json.loads(json.dumps(test_case_input)) + + def test_happy_list_dicts_mixed(self): + test_cases = [ + ( + [ + {"decimal": Decimal("3.14159"), "bytes": bytes([72, 101, 108, 108, 111])}, + {"complex": complex(1, 2), "uuid": UUID('12345678-1234-5678-1234-567812345678')}, + {"date": datetime(2024, 1, 1), "set": {1, 2, 3}} + ], + "list of dicts with decimal/bytes, complex/uuid, datetime/set" + ), ( + [ + {"ordered": OrderedDict([('a', 1)]), "default": defaultdict(type(None), {'x': 1})}, + {"frozen": frozenset([1, 2, 3]), "time": time(12, 30, 45)}, + {"delta": timedelta(days=1), "nested": {"a": 1, "b": 2}} + ], + "list of dicts with OrderedDict/defaultdict, frozenset/time, timedelta/dict" + ), ( + [ + {"bytes": bytes([65, 66, 67]), "decimal": Decimal("10.5"), "date": date(2024, 1, 1)}, + {"uuid": UUID('12345678-1234-5678-1234-567812345678'), "complex": complex(3, 4), "set": {4, 5, 6}}, + {"time": time(1, 2, 3), 
"delta": timedelta(hours=5), "list": [1, 2, 3]} + ], + "list of dicts with three mixed types each" + ), + ] + + for test_case_input, test_case_title in test_cases: + with self.subTest(f"Testing happy list of dicts mixed path: {test_case_title}"): + self._assert_value_encodes_decodes(test_case_input) + + encoded = encode_json(test_case_input) + result = decode_json(encoded) + + for test_index, expected_test in enumerate(test_case_input): + for index, key in enumerate(expected_test): + expected_value = test_case_input[test_index][key] + result_value = result[test_index][key] + self.assertEqual(expected_value, result_value) + + # Ordered dict, becomes regular dictionary. Since the idea is to extend json types fixing this + # issue is not within the scope of this feature + if test_case_title == "list of dicts with OrderedDict/defaultdict, frozenset/time, timedelta/dict": + continue + self.assertEqual(type(expected_value), type(result_value)) + + def test_potential_candidates(self): + """Test cases for types that could be added encoding/decoding in the future""" + + @dataclass + class DataClassType: + name: str + value: int + + class NamedTupleType(NamedTuple): + name: str + value: int + + class EnumType_(Enum): + ONE = 1 + TWO = 2 + + potential_candidates = [ + # (Counter(['a', 'b', 'a']), "Counter"), Encodes into other type + (ChainMap({'a': 1}, {'b': 2}), "ChainMap"), + (Queue(), "Queue"), + (PriorityQueue(), "PriorityQueue"), + + (IPv4Address('192.168.1.1'), "IPv4Address"), + (IPv6Address('2001:db8::1'), "IPv6Address"), + + (Path('/foo/bar.txt'), "Path"), + + (DataClassType("test", 42), "DataClass"), + #(NamedTupleType("test", 42), "NamedTuple"), Encodes into other type + (EnumType_.ONE, "Enum"), + + #(re_compile(r'\d+'), "Pattern"), + #(memoryview(b'Hello'), "memoryview"), + ] + + for test_case_input, test_case_title in potential_candidates: + with self.subTest(f"Testing potential candidate: {test_case_title}"): + + # fails with json out of the box + with 
self.assertRaises(TypeError, msg=test_case_title): + result = json.dumps(test_case_input) + print(result) + + # Since these types are not yet added + with self.assertRaises(TypeError, msg=test_case_title): + result = json.dumps(test_case_input, cls=RedisDictJSONEncoder) + print(result) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/extend_types_tests.py b/tests/unit/extend_types_tests.py similarity index 57% rename from extend_types_tests.py rename to tests/unit/extend_types_tests.py index fa986f7..7614584 100644 --- a/extend_types_tests.py +++ b/tests/unit/extend_types_tests.py @@ -1,5 +1,8 @@ -import unittest import json +import gzip +import time +import base64 +import unittest from datetime import datetime @@ -33,6 +36,18 @@ def setUp(self): def tearDown(self): self.redis_dict.clear() + new_types = [ + 'Customer', + 'GzippedDict', + 'Person', + 'CompressedString', + 'EncryptedStringClassBased', + 'EncryptedRot13String', + 'EncryptedString', + ] + for extend_type in new_types: + self.redis_dict.encoding_registry.pop(extend_type, None) + self.redis_dict.decoding_registry.pop(extend_type, None) def helper_get_redis_internal_value(self, key): sep = self.redis_dict_seperator @@ -337,6 +352,251 @@ def test_datetime_encoding_and_decoding_with_functions(self): self.assertEqual(str(redis_dict), str({key: expected})) +class GzippedDict: + """ + A class that can encode its attributes to a compressed string and decode from a compressed string, + optimized for the fastest possible gzipping. + + Methods: + encode: Compresses and encodes the object's attributes to a base64 string using the fastest settings. + decode: Creates a new object from a compressed and encoded base64 string. + """ + + def __init__(self, name, age, address): + self.name = name + self.age = age + self.address = address + + def encode(self) -> str: + """ + Encodes the object's attributes to a compressed base64 string using the fastest possible settings. 
+ + Returns: + str: A base64 encoded string of the compressed object attributes. + """ + json_data = json.dumps(self.__dict__, separators=(',', ':')) + compressed_data = gzip.compress(json_data.encode('utf-8'), compresslevel=1) + return base64.b64encode(compressed_data).decode('ascii') + + @classmethod + def decode(cls, encoded_str: str) -> 'GzippedDict': + """ + Creates a new object from a compressed and encoded base64 string. + + Args: + encoded_str (str): A base64 encoded string of compressed object attributes. + + Returns: + GzippedDict: A new instance of the class with decoded attributes. + """ + json_data = gzip.decompress(base64.b64decode(encoded_str)).decode('utf-8') + attributes = json.loads(json_data) + return cls(**attributes) + + +class TestRedisDictExtendTypesGzipped(BaseRedisDictTest): + + def test_gzipped_dict_encoding_and_decoding(self): + """Test adding new type and test if encoding and decoding works.""" + self.redis_dict.extends_type(GzippedDict) + + redis_dict = self.redis_dict + key = "person" + expected = GzippedDict("John Doe", 30, "123 Main St, Anytown, USA 12345") + expected_type = GzippedDict.__name__ + + # Store GzippedDict that should be encoded + redis_dict[key] = expected + + # Assert the stored value is correctly encoded + internal_result_type, internal_result_value = self.helper_get_redis_internal_value(key) + + self.assertNotEqual(internal_result_value, expected.__dict__) + self.assertEqual(internal_result_type, expected_type) + self.assertIsInstance(internal_result_value, str) + + # Assert the result from getting the value is decoding correctly + result = redis_dict[key] + self.assertIsInstance(result, GzippedDict) + self.assertDictEqual(result.__dict__, expected.__dict__) + + def test_encoding_decoding_should_remain_equal(self): + """Test adding new type and test if encoding and decoding results in the same value""" + redis_dict = self.redis_dict + self.redis_dict.extends_type(GzippedDict) + + key = "person1" + key2 = "person2" + 
expected = GzippedDict("Jane Doe", 28, "456 Elm St, Othertown, USA 67890") + + redis_dict[key] = expected + + # Decodes the value, And stores the value encoded. Seamless usage of new type. + redis_dict[key2] = redis_dict[key] + + result_one = redis_dict[key] + result_two = redis_dict[key2] + + # Assert the single encoded decoded value is the same as double encoding decoded value. + self.assertDictEqual(result_one.__dict__, expected.__dict__) + self.assertDictEqual(result_one.__dict__, result_two.__dict__) + self.assertEqual(result_one.name, expected.name) + + +class CompressedString(str): + """ + A string subclass that provides methods for encoding (compressing) and decoding (decompressing) its content. + + Methods: + encode: Compresses the string content and returns a base64 encoded string. + decode: Creates a new CompressedString instance from a compressed and encoded base64 string. + """ + + def compress(self) -> str: + """ + Compresses the string content and returns a base64 encoded string. + + Returns: + str: A base64 encoded string of the compressed content. + """ + compressed_data = gzip.compress(self.encode('utf-8'), compresslevel=1) + return base64.b64encode(compressed_data).decode('ascii') + + @classmethod + def decompress(cls, compressed_str: str) -> 'CompressedString': + """ + Creates a new CompressedString instance from a compressed and encoded base64 string. + + Args: + compressed_str (str): A base64 encoded string of compressed content. + + Returns: + CompressedString: A new instance of the class with decompressed content. 
+ """ + decompressed_data = gzip.decompress(base64.b64decode(compressed_str)).decode('utf-8') + return cls(decompressed_data) + + +class TestRedisDictExtendTypesCompressed(BaseRedisDictTest): + + def test_compressed_string_encoding_and_decoding(self): + """Test adding new type and test if encoding and decoding works.""" + redis_dict = self.redis_dict + redis_dict.extends_type(CompressedString,encoding_method_name='compress', decoding_method_name='decompress') + key = "message" + expected = CompressedString("This is a test message that will be compressed and stored in Redis.") + expected_type = CompressedString.__name__ + + # Store CompressedString that should be encoded + redis_dict[key] = expected + + # Assert the stored value is correctly encoded + internal_result_type, internal_result_value = self.helper_get_redis_internal_value(key) + + self.assertNotEqual(internal_result_value, expected) + self.assertEqual(internal_result_type, expected_type) + self.assertIsInstance(internal_result_value, str) + + # Assert the result from getting the value is decoding correctly + result = redis_dict[key] + self.assertIsInstance(result, CompressedString) + self.assertEqual(result, expected) + + def test_encoding_decoding_should_remain_equal(self): + """Test adding new type and test if encoding and decoding results in the same value""" + redis_dict = self.redis_dict + redis_dict.extends_type(CompressedString, encoding_method_name='compress', decoding_method_name='decompress') + + key = "message1" + key2 = "message2" + expected = CompressedString("Another test message to ensure consistent encoding and decoding.") + + redis_dict[key] = expected + + # Decodes the value, And stores the value encoded. Seamless usage of new type. + redis_dict[key2] = redis_dict[key] + + result_one = redis_dict[key] + result_two = redis_dict[key2] + + # Assert the single encoded decoded value is the same as double encoding decoded value. 
+ self.assertEqual(result_one, expected) + self.assertEqual(result_one, result_two) + self.assertEqual(result_one[:10], expected[:10]) + + def test_compression_size_reduction(self): + """Test that compression significantly reduces the size of stored data""" + redis_dict = self.redis_dict + redis_dict.extends_type(CompressedString, encoding_method_name='compress', decoding_method_name='decompress') + key = "large_message" + + # Create a large string with some repetitive content to ensure good compression + large_string = "This is a test message. " * 1000 + "Some unique content to mix things up." + expected = CompressedString(large_string) + + # Store the large CompressedString + redis_dict[key] = expected + + # Get the internal (compressed) value + internal_result_type, internal_result_value = self.helper_get_redis_internal_value(key) + + # Calculate sizes + original_size = len(large_string) + compressed_size = len(internal_result_value) + + # Print sizes for information (optional) + print(f"Original size: {original_size} bytes") + print(f"Compressed size: {compressed_size} bytes") + print(f"Compression ratio: {compressed_size / original_size:.2f}") + + # Assert that compression achieved significant size reduction + self.assertLess(compressed_size, original_size * 0.5, "Compression should reduce size by at least 50%") + + # Verify that we can still recover the original string + decoded = redis_dict[key] + self.assertEqual(decoded, expected) + self.assertEqual(len(decoded), original_size) + + + def test_compression_timing_comparison(self): + """Compare timing of operations between compressed and uncompressed strings""" + redis_dict = self.redis_dict # A new instance for regular strings + + key_compressed = "compressed" + key = "regular" + + # Create a large string with some repetitive content + large_string = "This is a test message. " * 1000 + "Some unique content to mix things up." 
+ compressed_string = CompressedString(large_string) + + # Timing for setting compressed string + start_time = time.time() + redis_dict[key_compressed] = compressed_string + compressed_set_time = time.time() - start_time + + # Timing for setting regular string + start_time = time.time() + redis_dict[key] = large_string + regular_set_time = time.time() - start_time + + # Timing for getting compressed string + start_time = time.time() + _ = redis_dict[key_compressed] + compressed_get_time = time.time() - start_time + + # Timing for getting regular string + start_time = time.time() + _ = redis_dict[key] + regular_get_time = time.time() - start_time + + # Print timing results + print(f"Compressed string set time: {compressed_set_time:.6f} seconds") + print(f"Regular string set time: {regular_set_time:.6f} seconds") + print(f"Compressed string get time: {compressed_get_time:.6f} seconds") + print(f"Regular string get time: {regular_get_time:.6f} seconds") + + + class TestNewTypeComplianceFailures(BaseRedisDictTest): def test_missing_encode_method(self): class MissingEncodeMethod: diff --git a/standard_types_tests.py b/tests/unit/standard_types_tests.py similarity index 94% rename from standard_types_tests.py rename to tests/unit/standard_types_tests.py index e141973..c5679ff 100644 --- a/standard_types_tests.py +++ b/tests/unit/standard_types_tests.py @@ -1,13 +1,20 @@ import sys import unittest -from uuid import UUID, uuid4 +from pathlib import Path +from uuid import UUID +from pathlib import Path from decimal import Decimal from datetime import datetime, date, time, timedelta, timezone from collections import OrderedDict, defaultdict from redis_dict import RedisDict +import src.redis_dict.type_management + +sys.path.append(str(Path(__file__).parent.parent.parent / "src")) +from redis_dict.type_management import _default_decoder + class TypeCodecTests(unittest.TestCase): def setUp(self): @@ -20,16 +27,15 @@ def _assert_value_encodes_decodes(self, expected_value): 
self.assertIsInstance(encoded_value, str) - result = self.dic.decoding_registry.get(expected_type, lambda x: x)(encoded_value) + result = self.dic.decoding_registry.get(expected_type, _default_decoder)(encoded_value) self.assertEqual(type(result).__name__, expected_type) self.assertEqual(expected_value, result) def _ensure_testcases_have_all_types(self, test_cases): """ - Instances are colliding during unit tests, refactor encoding/decoding registeries and turn the test back on + Ensure the testcases tests all the current standard types. """ - return test_types = {i[1] for i in test_cases} registry_types = set(self.dic.decoding_registry.keys()) diff --git a/tests.py b/tests/unit/tests.py similarity index 80% rename from tests.py rename to tests/unit/tests.py index 24f8e37..9c0a71f 100644 --- a/tests.py +++ b/tests/unit/tests.py @@ -1,10 +1,17 @@ +from typing import Any + +import sys import time +import json import unittest -from datetime import timedelta + +from datetime import datetime, timedelta import redis from redis_dict import RedisDict +from redis_dict import RedisDictJSONEncoder, RedisDictJSONDecoder + from hypothesis import given, strategies as st # !! Make sure you don't have keys within redis named like this, they will be deleted. @@ -17,6 +24,28 @@ } +def skip_before_python39(test_item): + """ + Decorator to skip tests for Python versions before 3.9 + where dictionary union operations are not supported. + + Can be used to decorate both test methods and test classes. 
+ + Args: + test_item: The test method or class to be decorated + + Returns: + The decorated test item that will be skipped if Python version < 3.9 + """ + reason = "Dictionary union operators (|, |=) require Python 3.9+" + + if sys.version_info < (3, 9): + if isinstance(test_item, type): + return unittest.skip(reason)(test_item) + return unittest.skip(reason)(test_item) + return test_item + + class TestRedisDictBehaviorDict(unittest.TestCase): @classmethod def setUpClass(cls): @@ -42,16 +71,12 @@ def clear_test_namespace(cls): def setUp(self): self.clear_test_namespace() - @unittest.skip def test_python3_all_methods_from_dictionary_are_implemented(self): - import sys - if sys.version_info[0] == 3: - redis_dic = self.create_redis_dict() - dic = dict() + redis_dic = self.create_redis_dict() + dic = dict() - # reversed is currently not supported - self.assertEqual(set(dir({})) - set(dir(RedisDict)), set()) - self.assertEqual(len(set(dir(dic)) - set(dir(redis_dic))), 0) + self.assertEqual(set(dir({})) - set(dir(RedisDict)), set()) + self.assertEqual(len(set(dir(dic)) - set(dir(redis_dic))), 0) def test_input_items(self): """Calling RedisDict.keys() should return an empty list.""" @@ -178,9 +203,6 @@ def test_iter(self): for key in redis_dic: self.assertTrue(key in input_items) - for key in redis_dic.iterkeys(): - self.assertTrue(key in input_items) - for key in redis_dic.keys(): self.assertTrue(key in input_items) @@ -188,18 +210,14 @@ def test_iter(self): self.assertEqual(input_items[key], value) self.assertEqual(dic[key], value) - for key, value in redis_dic.iteritems(): - self.assertEqual(input_items[key], value) - self.assertEqual(dic[key], value) - input_values = list(input_items.values()) dic_values = list(dic.values()) - result_values = list(redis_dic.itervalues()) + result_values = list(redis_dic.values()) self.assertEqual(sorted(map(str, input_values)), sorted(map(str, result_values))) self.assertEqual(sorted(map(str, dic_values)), sorted(map(str, 
result_values))) - result_values = list(redis_dic.itervalues()) + result_values = list(redis_dic.values()) self.assertEqual(sorted(map(str, input_values)), sorted(map(str, result_values))) self.assertEqual(sorted(map(str, dic_values)), sorted(map(str, result_values))) @@ -319,6 +337,203 @@ def test_dict_method_popitem(self): with self.assertRaises(KeyError): redis_dic.popitem() + @skip_before_python39 + def test_dict_method_or(self): + redis_dic = self.create_redis_dict() + dic = dict() + + input_items = { + "int": 1, + "float": 0.9, + "str": "im a string", + "bool": True, + "None": None, + } + + additional_items = { + "str": "new string", + "new_int": 42, + "new_bool": False, + } + + redis_dic.update(input_items) + dic.update(input_items) + + self.assertEqual(len(redis_dic), 5) + self.assertEqual(len(dic), 5) + self.assertEqual(len(input_items), 5) + + redis_result = redis_dic | additional_items + dict_result = dic | additional_items + + self.assertEqual(len(redis_result), len(dict_result)) + self.assertEqual(dict(redis_result), dict_result) + + self.assertEqual(len(redis_dic), 5) + self.assertEqual(len(dic), 5) + self.assertEqual(dict(redis_dic), dict(dic)) + + with self.assertRaises(TypeError): + dic | [1, 2] + + with self.assertRaises(TypeError): + redis_dic | [1, 2] + + @skip_before_python39 + def test_dict_method_ror(self): + redis_dic = self.create_redis_dict() + dic = dict() + + input_items = { + "int": 1, + "float": 0.9, + "str": "im a string", + "bool": True, + "None": None, + } + + additional_items = { + "str": "new string", + "new_int": 42, + "new_bool": False, + } + + redis_dic.update(input_items) + dic.update(input_items) + + self.assertEqual(len(redis_dic), 5) + self.assertEqual(len(dic), 5) + self.assertEqual(len(input_items), 5) + + redis_result = additional_items | redis_dic + dict_result = additional_items | dic + + self.assertEqual(len(redis_result), len(dict_result)) + self.assertEqual(dict(redis_result), dict_result) + + # Verify original 
dicts weren't modified + self.assertEqual(len(redis_dic), 5) + self.assertEqual(len(dic), 5) + self.assertEqual(dict(redis_dic), dict(dic)) + + with self.assertRaises(TypeError): + [1, 2] | dic + + with self.assertRaises(TypeError): + [1, 2] | redis_dic + + @skip_before_python39 + def test_dict_method_ior(self): + redis_dic = self.create_redis_dict() + dic = dict() + + input_items = { + "int": 1, + "float": 0.9, + "str": "im a string", + "bool": True, + "None": None, + } + + additional_items = { + "str": "new string", + "new_int": 42, + "new_bool": False, + } + + redis_dic.update(input_items) + dic.update(input_items) + + self.assertEqual(len(redis_dic), 5) + self.assertEqual(len(dic), 5) + self.assertEqual(len(input_items), 5) + + redis_dic |= additional_items + dic |= additional_items + + self.assertEqual(len(redis_dic), len(dic)) + self.assertEqual(dict(redis_dic), dict(dic)) + + with self.assertRaises(TypeError): + dic |= [1, 2] + + with self.assertRaises(TypeError): + redis_dic |= [1, 2] + + + def test_dict_method_reversed_(self): + """ + RedisDict Currently does not support insertion order as property thus also not reversed. + This test only test `reversed` can be called. + """ + redis_dic = self.create_redis_dict() + dic = dict() + + input_items = { + "int": 1, + "bool": True, + "None": None, + } + + redis_dic.update(input_items) + dic.update(input_items) + redis_reversed = sorted(reversed(redis_dic)) + dict_reversed = sorted(reversed(dic)) + + self.assertEqual(redis_reversed, dict_reversed) + + @unittest.skip + def test_dict_method_reversed(self): + """ + RedisDict Currently does not support insertion order as property thus also not reversed. 
+ """ + redis_dic = self.create_redis_dict() + dic = dict() + + input_items = { + "int": 1, + "float": 0.9, + "str": "im a string", + "bool": True, + "None": None, + } + + redis_dic.update(input_items) + dic.update(input_items) + + self.assertEqual(len(redis_dic), 5) + self.assertEqual(len(dic), 5) + self.assertEqual(len(input_items), 5) + + redis_reversed = list(reversed(redis_dic)) + dict_reversed = list(reversed(dic)) + + self.assertEqual(redis_reversed, dict_reversed) + + def test_dict_method_class_getitem(self): + redis_dic = self.create_redis_dict() + dic = dict() + + input_items = { + "int": 1, + "float": 0.9, + "str": "im a string", + "bool": True, + "None": None, + } + + redis_dic.update(input_items) + dic.update(input_items) + + self.assertEqual(len(redis_dic), 5) + self.assertEqual(len(dic), 5) + self.assertEqual(len(input_items), 5) + + def accepts_redis_dict(d: RedisDict[str, Any]) -> None: + self.assertIsInstance(d, RedisDict) + + accepts_redis_dict(redis_dic) + def test_dict_method_setdefault(self): redis_dic = self.create_redis_dict() dic = dict() @@ -348,6 +563,114 @@ def test_dict_method_setdefault(self): self.assertEqual(len(dic), 2) self.assertEqual(len(redis_dic), 2) + def test_dict_method_setdefault_with_expire(self): + """Test setdefault with expiration setting""" + redis_dic = self.create_redis_dict(expire=3600) + key = "test_expire_key" + expected_value = "expected value" + other_expected_value = "other_default_value" + + # Clear any existing values + redis_dic.clear() + + # First call - should set with expiry + result_one = redis_dic.setdefault( + key, expected_value + ) + self.assertEqual(result_one, expected_value) + # Check TTL + actual_ttl = redis_dic.get_ttl(key) + self.assertAlmostEqual(3600, actual_ttl, delta=2) + + # Second call - should get existing value and maintain TTL + time.sleep(1) + result_two = redis_dic.setdefault( + key, other_expected_value, + ) + self.assertEqual(result_one, expected_value) + 
self.assertNotEqual(result_two, other_expected_value) + # TTL should be ~1 second less + new_ttl = redis_dic.get_ttl(key) + self.assertAlmostEqual(3600 - 1, new_ttl, delta=2) + + # Value should be unchanged + self.assertEqual(result_one, result_two) + + self.assertEqual(expected_value, redis_dic[key]) + del redis_dic[key] + with redis_dic.expire_at(timedelta(seconds=1)): + result_one_three = redis_dic.setdefault( + key, other_expected_value, + ) + self.assertEqual(other_expected_value, redis_dic[key]) + time.sleep(1.5) + with self.assertRaisesRegex(KeyError, key): + redis_dic[key] + + def test_setdefault_with_preserve_ttl(self): + """Test setdefault with preserve_expiration=True""" + redis_dic = self.create_redis_dict(expire=5, preserve_expiration=True) + key = "test_preserve_key" + expected_value = "expected_value" + default_value = "default" + sleep_time = 2 + + redis_dic[key] = expected_value + initial_ttl = redis_dic.get_ttl(key) + + time.sleep(sleep_time) + # Try setdefault - should keep original TTL + result = redis_dic.setdefault( + key, default_value + ) + self.assertEqual(result, expected_value) + + time.sleep(sleep_time) + # TTL should have been preserved, thus new_ttl+sleep_time should less than initial_ttl since sleep 1 second. + new_ttl = redis_dic.get_ttl(key) + self.assertLess(new_ttl+sleep_time, initial_ttl) + time.sleep(sleep_time) + + # TTL should be expired, thus key and value should be missing, and thus we will set the default value. 
+ with self.assertRaisesRegex(KeyError, key): + redis_dic[key] + + expected_value_two = "expected_value_two" + result_two = redis_dic.setdefault( + key, expected_value_two + ) + self.assertEqual(result_two, expected_value_two) + self.assertEqual(redis_dic[key], expected_value_two) + + def test_setdefault_concurrent_ttl(self): + """Test TTL behavior with concurrent setdefault operations""" + redis_dic = self.create_redis_dict(expire=3600) + other_redis_dic = self.create_redis_dict(expire=1800) # Different TTL + + key = "test_concurrent_key" + default_value = "default" + other_default_value = "other_default" + + redis_dic.clear() + + # First operation sets with 3600s TTL + value1 = redis_dic.setdefault( + key, default_value + ) + + ttl1 = redis_dic.get_ttl(key) + self.assertAlmostEqual(3600, ttl1, delta=2) + + # Competing operation tries with 1800s TTL + value2 = other_redis_dic.setdefault( + key, other_default_value + ) + + # Original TTL should be maintained + ttl2 = other_redis_dic.get_ttl(key) + self.assertAlmostEqual(3600, ttl2, delta=3) + self.assertEqual(value1, value2) # Should get same value + def test_dict_method_get(self): redis_dic = self.create_redis_dict() dic = dict() @@ -672,9 +995,9 @@ def test_sizeof(self): self.assertEqual(expected, result) def test_keys_empty(self): - """Calling RedisDict.keys() should return an empty list.""" + """Calling RedisDict.keys() should return an empty Iterator.""" keys = self.r.keys() - self.assertEqual(keys, []) + self.assertEqual(list(keys), []) def test_set_and_get_foobar(self): """Test setting a key and retrieving it.""" @@ -1117,7 +1440,36 @@ def test_set_get_single_element_set(self): self.r[key] = value self.assertEqual(self.r[key], value) - @unittest.skip # this highlights that sets, and tuples not fully supported + def test_set_get_mixed_type_list(self): + key = "mixed_type_list" + value = [1, "foobar", 3.14, [1, 2, 3]] + self.r[key] = value + self.assertEqual(self.r[key], value) + + def 
test_set_get_mixed_type_list_readme(self): + key = "mixed_type_list" + now = datetime.now() + value = [1, "foobar", 3.14, [1, 2, 3], now] + self.r[key] = value + self.assertEqual(self.r[key], value) + + def test_set_get_dict_with_timedelta_readme(self): + key = "dic_with_timedelta" + value = {"elapsed_time": timedelta(hours=60)} + self.r[key] = value + self.assertEqual(self.r[key], value) + + def test_json_encoder_decoder_readme(self): + """Test the custom JSON encoder and decoder""" + now = datetime.now() + expected = [1, "foobar", 3.14, [1, 2, 3], now] + + encoded = json.dumps(expected, cls=RedisDictJSONEncoder) + result = json.loads(encoded, cls=RedisDictJSONDecoder) + + self.assertEqual(result, expected) + + @unittest.skip def test_set_get_mixed_type_set(self): key = "mixed_type_set" value = {1, "foobar", 3.14, (1, 2, 3)} @@ -1126,11 +1478,41 @@ def test_set_get_mixed_type_set(self): @unittest.skip # this highlights that sets, and tuples not fully supported def test_set_get_nested_tuple(self): + key = "nested_tuple" + value = (1, (2, 3), (4, 5)) + self.r[key] = value + self.assertEqual(self.r[key], value) + + @unittest.skip # this highlights that sets, and tuples not fully supported + def test_set_get_nested_tuple_triple(self): key = "nested_tuple" value = (1, (2, 3), (4, (5, 6))) self.r[key] = value self.assertEqual(self.r[key], value) + def test_init_redis_dict_with_redis_instance(self): + test_key = "test_key" + expected = "expected value" + test_inputs = { + "config from_url": redis.Redis.from_url("redis://127.0.0.1/0"), + "config from kwargs": redis.Redis(**redis_config), + "config passed as keywords": redis.Redis(host="127.0.0.1", port=6379), + } + for test_name, test_input in test_inputs.items(): + assert_fail_msg = f"test with: {test_name} failed" + + dict_ = RedisDict(redis=test_input) + dict_[test_key] = expected + result = dict_[test_key] + self.assertEqual(result, expected, msg=assert_fail_msg) + + self.assertIs(dict_.redis, test_input) + 
self.assertTrue( + dict_.redis.get_connection_kwargs().get("decode_responses"), + msg=assert_fail_msg, + ) + + test_input.flushdb() class TestRedisDictSecurity(unittest.TestCase): @classmethod @@ -1405,7 +1787,7 @@ def test_sequential_comparison(self): self.assertTrue(d == d2) self.assertTrue(d == rd) self.assertTrue(d.items() == d2.items()) - self.assertTrue(list(d.items()) == rd.items()) + self.assertTrue(list(d.items()) == list(rd.items())) d["foo1"] = "bar1" @@ -1413,7 +1795,7 @@ def test_sequential_comparison(self): self.assertTrue(d != d2) self.assertTrue(d != rd) self.assertTrue(d.items() != d2.items()) - self.assertTrue(list(d.items()) != rd.items()) + self.assertTrue(list(d.items()) != list(rd.items())) # Modifying 'd2' and 'rd' d2["foo1"] = "bar1" @@ -1423,7 +1805,7 @@ def test_sequential_comparison(self): self.assertTrue(d == d2) self.assertTrue(d == rd) self.assertTrue(d.items() == d2.items()) - self.assertTrue(list(d.items()) == rd.items()) + self.assertTrue(list(d.items()) == list(rd.items())) d.clear() d2.clear() @@ -1437,15 +1819,15 @@ def test_sequential_comparison(self): self.assertTrue(d == d2) self.assertTrue(d == rd) self.assertTrue(d.items() == d2.items()) - self.assertTrue(list(d.items()) == rd.items()) + self.assertTrue(list(d.items()) == list(rd.items())) d.clear() # Testing for inequality after clear - self.assertTrue(d != d2) - self.assertTrue(d != rd) - self.assertTrue(d.items() != d2.items()) - self.assertTrue(list(d.items()) != rd.items()) + self.assertFalse(d == d2) + self.assertFalse(d == rd) + self.assertFalse(d.items() == d2.items()) + self.assertFalse(list(d.items()) == list(rd.items())) d2.clear() rd.clear() @@ -1454,7 +1836,7 @@ def test_sequential_comparison(self): self.assertTrue(d == d2) self.assertTrue(d == rd) self.assertTrue(d.items() == d2.items()) - self.assertTrue(list(d.items()) == rd.items()) + self.assertTrue(list(d.items()) == list(rd.items())) class TestRedisDictPreserveExpire(unittest.TestCase): @@ -1655,6 +2037,5 
@@ def test_set_get_set(self, key, value): self.r[key] = value self.assertEqual(self.r[key], value) - if __name__ == '__main__': unittest.main() diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 2c1a95d..0000000 --- a/tox.ini +++ /dev/null @@ -1,5 +0,0 @@ -[tox] -envlist = py27, py34, py35, py36, py37 -[testenv] -deps = -rrequirements.txt -commands=python tests.py