From d05ddf57a886a5b609da59c2e165ddf4b5302198 Mon Sep 17 00:00:00 2001 From: Marcus Fredriksson Date: Tue, 4 Feb 2025 17:55:40 +0100 Subject: [PATCH] Added new schema class for defining a structure of types and flatten/unflatten dictionary functions --- tomlval/__init__.py | 1 + tomlval/errors/__init__.py | 1 + tomlval/errors/toml_schema_error.py | 18 +++ tomlval/toml_schema.py | 163 ++++++++++++++++++++++++++++ tomlval/toml_validator.py | 81 ++++++++------ tomlval/types/__init__.py | 1 + tomlval/types/validated_schema.py | 6 + tomlval/utils/__init__.py | 5 + tomlval/utils/flatten.py | 34 ++++++ tomlval/utils/unflatten.py | 66 +++++++++++ 10 files changed, 345 insertions(+), 31 deletions(-) create mode 100644 tomlval/errors/toml_schema_error.py create mode 100644 tomlval/toml_schema.py create mode 100644 tomlval/types/validated_schema.py create mode 100644 tomlval/utils/flatten.py create mode 100644 tomlval/utils/unflatten.py diff --git a/tomlval/__init__.py b/tomlval/__init__.py index 9f0ee3a..ca4a913 100644 --- a/tomlval/__init__.py +++ b/tomlval/__init__.py @@ -1,3 +1,4 @@ """ toml_parser package """ +from .toml_schema import TOMLSchemaError from .toml_validator import TOMLValidator diff --git a/tomlval/errors/__init__.py b/tomlval/errors/__init__.py index 1c613c3..dea8277 100644 --- a/tomlval/errors/__init__.py +++ b/tomlval/errors/__init__.py @@ -1,3 +1,4 @@ """ Errors specific to the 'toml_parser' package. """ from .toml_handler_error import TOMLHandlerError +from .toml_schema_error import TOMLSchemaError diff --git a/tomlval/errors/toml_schema_error.py b/tomlval/errors/toml_schema_error.py new file mode 100644 index 0000000..33bc72c --- /dev/null +++ b/tomlval/errors/toml_schema_error.py @@ -0,0 +1,18 @@ +""" Custom error for invalid schemas. """ + + +class TOMLSchemaError(Exception): + """Custom error for invalid schemas.""" + + def __init__(self, message: str = "Invalid TOML schema."): + """ + Initialize the TOMLSchemaError. 
class JSONEncoder(json.JSONEncoder):
    """A JSON encoder that serializes Python types by their name."""

    def default(self, o):
        # Schema leaves are classes such as `str` or `int`, which the stock
        # encoder cannot serialize; emit their `__name__` instead.
        if isinstance(o, type):
            return o.__name__
        return super().default(o)


class TOMLSchema:
    """A class for defining a TOML schema structure."""

    def __init__(self, schema: dict):
        """
        Initialize a new TOML schema.

        A schema is a dictionary with keys as strings and values as types.
        This is used to define an outline of how the validator should interpret
        the data and handle certain errors.

        Example:
            {
                "string": str,
                "number": (int, float),
                "boolean": bool,
                "string_list": [str],
                "number_list": [int, float],
                "mixed_list": [str, int, float],
                "nested": {
                    "key": str,
                    "value": int
                }
            }

        Args:
            schema: dict - The TOML schema.
        Returns:
            None
        Raises:
            tomlval.errors.TOMLSchemaError - If the schema is invalid.
        """

        # Validate first so an invalid schema never reaches the flattener.
        self._validate(schema)
        self._nested_schema = schema
        self._flat_schema = self._flatten(schema)

    def _validate(self, schema: dict) -> None:
        """
        Validate a TOML schema.

        Every key must be a string accepted by `key_pattern`, and every
        value must be a type, a tuple/list of types, or a nested dictionary
        obeying the same rules.

        Args:
            schema: dict - The schema to validate.
        Returns:
            None
        Raises:
            tomlval.errors.TOMLSchemaError - If the schema is invalid.
        """
        if not isinstance(schema, dict):
            raise TOMLSchemaError("Schema must be a dictionary.")

        def _check_schema(schema: dict) -> None:
            """Check the schema recursively."""
            for k, v in schema.items():
                # Keys
                if not isinstance(k, str):
                    raise TOMLSchemaError(
                        f"Invalid key type '{str(k)}' in schema."
                    )
                if not key_pattern.match(k):
                    raise TOMLSchemaError(f"Invalid key '{k}' in schema.")

                # Values
                ## Nested table: recurse, then keep checking the remaining
                ## keys. Returning here would skip every key that follows
                ## the first nested table.
                if isinstance(v, dict):
                    _check_schema(v)
                    continue

                ## Tuple/List
                if isinstance(v, (tuple, list)):
                    for t in v:
                        if not isinstance(t, type):
                            raise TOMLSchemaError(
                                f"Invalid type '{type(t).__name__}' "
                                "found in schema."
                            )

                ## Simple type
                elif not isinstance(v, type):
                    raise TOMLSchemaError(
                        f"Invalid type '{type(v).__name__}' found in schema."
                    )

        _check_schema(schema)

    def _flatten(self, schema: dict) -> dict:
        """
        A custom version of the flatten function to combine lists.

        `flatten` expands a list value into one `key.[index]` entry per
        element; this method regroups those entries into a single list
        under `key`, ordered by index. All other entries are kept as is.

        Args:
            schema: dict - The nested schema.
        Returns:
            dict - The flattened schema with dot-notation keys.
        Raises:
            None
        """

        pattern = re.compile(r"^(.*)\.\[(\d+)\]$")
        result = {}
        temp = defaultdict(list)

        for key, value in flatten(schema).items():
            match = pattern.match(key)

            if match:
                base_key, index = match.groups()
                temp[base_key].append((int(index), value))
            else:
                result[key] = value

        for base_key, items in temp.items():
            # Sort on the index only; the values are types and not orderable.
            result[base_key] = [
                val for _, val in sorted(items, key=lambda x: x[0])
            ]

        return result

    def __str__(self) -> str:
        """Return the nested schema as indented JSON with type names."""
        return json.dumps(self._nested_schema, cls=JSONEncoder, indent=2)

    def __repr__(self) -> str:
        """Return a short debug representation of the schema."""
        # The previous implementation returned an empty string, which made
        # instances invisible in logs and the REPL.
        return f"<TOMLSchema keys={len(self)}>"

    def __len__(self) -> int:
        """Return the number of flattened keys."""
        return len(self.keys())

    def __getitem__(self, key: str) -> Union[type, Tuple[type]]:
        """Get an item from a TOML schema by its dot-notation key."""
        return self._flat_schema[key]

    def __contains__(self, key: str) -> bool:
        """Check if a key is in a TOML schema."""
        return key in self._flat_schema

    def __iter__(self):
        """Iterate over the flattened keys in insertion order."""
        return iter(self._flat_schema)

    def get(self, key: str, default=None) -> Union[type, Tuple[type]]:
        """Get an item from a TOML schema, or `default` if missing."""
        return self._flat_schema.get(key, default)

    def keys(self) -> list[str]:
        """Get the keys from a TOML schema, sorted alphabetically."""
        return sorted(self._flat_schema.keys())

    def values(self) -> List[Union[type, Tuple[type]]]:
        """
        Get the values from a TOML schema.

        Values follow insertion order, not the sorted order of `keys()`,
        so do not pair the two by position; use `items()` instead.
        """
        return list(self._flat_schema.values())

    def items(self) -> List[Tuple[str, Union[type, Tuple[type]]]]:
        """Get the (key, value) pairs from a TOML schema."""
        return list(self._flat_schema.items())
def _validate_schema(self, schema: dict = None) -> bool:
    """
    Method to validate a schema.

    A schema is valid when it is a dictionary whose leaf values are all
    types; nested dictionaries are checked recursively.

    Args:
        schema: dict - The schema to check. When None, the schema stored
            on the instance (if any) is checked instead.
    Returns:
        bool - True if the schema is valid, otherwise False.
    Raises:
        None
    """
    # Only fall back when no schema was passed at all. An explicit empty
    # dict is a schema in its own right, and `__init__` calls this method
    # before `self._schema` exists, hence the defensive getattr.
    if schema is None:
        schema = getattr(self, "_schema", None)

    if not isinstance(schema, dict):
        return False

    # NOTE(review): tuple/list type specifications are rejected here;
    # confirm whether they should be accepted.
    def _check_schema(node: dict) -> bool:
        """Return True when every leaf below `node` is a type."""
        for value in node.values():
            if isinstance(value, dict):
                # Do not return the nested result directly: a valid nested
                # table must not end the scan of the remaining keys.
                if not _check_schema(value):
                    return False
            elif not isinstance(value, type):
                return False
        return True

    return _check_schema(schema)
'my', 'my.key') or global key @@ -181,3 +180,23 @@ def add_handler(self, key: str, handler: Handler): ## Too many arguments else: raise TOMLHandlerError("Handler must accept 0, 1, or 2 arguments.") + + def validate(self) -> ValidatedSchema: + """""" + + +if __name__ == "__main__": + import pathlib + import tomllib + + data_path = pathlib.Path("examples/full_spec.toml") + + with data_path.open("rb") as file: + toml_data = tomllib.load(file) + + validator = TOMLValidator(toml_data, schema={"name": str, "age": "nice"}) + + validator.add_handler("string*c", str) + + # for k, v in validator.validate().items(): + # print(f"{k}: {v} ({type(v)})") diff --git a/tomlval/types/__init__.py b/tomlval/types/__init__.py index a7df5a0..5a8ab2e 100644 --- a/tomlval/types/__init__.py +++ b/tomlval/types/__init__.py @@ -2,3 +2,4 @@ from .handler import Handler from .path_or_str import PathOrStr +from .validated_schema import ValidatedSchema diff --git a/tomlval/types/validated_schema.py b/tomlval/types/validated_schema.py new file mode 100644 index 0000000..c73416f --- /dev/null +++ b/tomlval/types/validated_schema.py @@ -0,0 +1,6 @@ +""" A type for a validated schema. """ + +from typing import Any, Tuple, Union + +# {"key": ("message", value)} +ValidatedSchema = dict[str, Union[Tuple[str, Any], "ValidatedSchema"]] diff --git a/tomlval/utils/__init__.py b/tomlval/utils/__init__.py index e69de29..bcc7eca 100644 --- a/tomlval/utils/__init__.py +++ b/tomlval/utils/__init__.py @@ -0,0 +1,5 @@ +""" 'tomlval.utils' module containing utilities used throughout the project. """ + +from .flatten import flatten +from .regex import key_pattern +from .unflatten import unflatten diff --git a/tomlval/utils/flatten.py b/tomlval/utils/flatten.py new file mode 100644 index 0000000..da71cf4 --- /dev/null +++ b/tomlval/utils/flatten.py @@ -0,0 +1,34 @@ +""" A function to flatten a dictionary into a single level dictionary. 
""" + + +def flatten(dictionary: dict) -> dict: + """ + A function to flatten a dictionary into a single level dictionary. + + Args: + dictionary: dict - The dictionary to flatten. + Returns: + dict - The flattened dictionary + Raises: + None + """ + + def _flatten(data: dict, parent_key: str = "") -> dict: + """A recursive function to flatten a dictionary.""" + _data = {} + for key, value in data.items(): + full_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, dict): + _data.update(_flatten(value, full_key)) + elif isinstance(value, list): + for idx, item in enumerate(value): + list_key = f"{full_key}.[{idx}]" + if isinstance(item, (dict, list)): + _data.update(_flatten(item, list_key)) + else: + _data[list_key] = item + else: + _data[full_key] = value + return _data + + return _flatten(dictionary) diff --git a/tomlval/utils/unflatten.py b/tomlval/utils/unflatten.py new file mode 100644 index 0000000..1e57616 --- /dev/null +++ b/tomlval/utils/unflatten.py @@ -0,0 +1,66 @@ +""" A module to unflatten a single level dictionary into a nested dictionary. """ + + +def unflatten(dictionary: dict) -> dict: + """ + A function to unflatten a single level dictionary into a nested dictionary. + + Args: + dictionary: dict - The single level dictionary to unflatten. + Returns: + dict - The nested dictionary. + Raises: + ValueError - If the dictionary is not a single level dictionary. 
+ """ + + def is_list_index(segment: str) -> bool: + return ( + segment.startswith("[") + and segment.endswith("]") + and segment[1:-1].isdigit() + ) + + result = {} + for flat_key, value in dictionary.items(): + segments = flat_key.split(".") + current = result + for i, segment in enumerate(segments): + is_last = i == len(segments) - 1 + if is_list_index(segment): + index = int(segment[1:-1]) + if not isinstance(current, list): + raise ValueError( + " ".join( + [ + "Expected list at segment", + f"'{segment}' in key '{flat_key}'", + ] + ) + ) + while len(current) <= index: + current.append(None) + if is_last: + current[index] = value + else: + if current[index] is None: + next_seg = segments[i + 1] + current[index] = [] if is_list_index(next_seg) else {} + current = current[index] + else: + if not isinstance(current, dict): + raise ValueError( + " ".join( + [ + "Expected dict at segment", + f"'{segment}' in key '{flat_key}'", + ] + ) + ) + if is_last: + current[segment] = value + else: + if segment not in current: + next_seg = segments[i + 1] + current[segment] = [] if is_list_index(next_seg) else {} + current = current[segment] + return result