Skip to content

Commit

Permalink
Added support for optional values in schema
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusfrdk committed Feb 4, 2025
1 parent 026e702 commit 30439e6
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
21 changes: 20 additions & 1 deletion tomlval/toml_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def __init__(self, schema: dict):
}
}
Optional values can be added by suffixing the key with a question mark.
Example:
{
"string?": str,
"number": (int, float)
}
In this case, string is an optional value, while number is required.
Args:
schema: dict - The TOML schema.
Returns:
Expand Down Expand Up @@ -119,6 +129,8 @@ def __len__(self) -> int:

def __getitem__(self, key: str) -> Union[type, Tuple[type]]:
"""Get an item from a TOML schema."""
if self.get(f"{key}?") is not None:
key = f"{key}?"
return self._flat_schema[key]

def __contains__(self, key: str) -> bool:
Expand All @@ -130,7 +142,9 @@ def __iter__(self):

def get(self, key: str, default=None) -> Union[type, Tuple[type]]:
"""Get an item from a TOML schema."""
return self._flat_schema.get(key, default)
if (value := self._flat_schema.get(key)) is None:
value = self._flat_schema.get(f"{key}?")
return value

def keys(self) -> list[str]:
"""Get the keys from a TOML schema."""
Expand All @@ -143,3 +157,8 @@ def values(self) -> List[Union[type, Tuple[type]]]:
def items(self) -> List[Tuple[str, Union[type, Tuple[type]]]]:
"""Get the items from a TOML schema."""
return list(self._flat_schema.items())


if __name__ == "__main__":
s = TOMLSchema({"string?": str, "number": (int, float)})
print(s.get("string"))
29 changes: 21 additions & 8 deletions tomlval/toml_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ def _inspect_function(self, fn: Callable) -> list[str]:

def _get_missing_keys(self) -> list[str]:
"""Get a list of keys missing in the data."""
return [k for k in self._schema if k not in self._data]
# return [k for k in self._schema if k not in self._data]
return [
k
for k in self._schema
if k not in self._data and not k.endswith("?")
]

def _get_invalid_types(self) -> List[Tuple[str, Tuple[type, Any]]]:
"""Get a list of keys with invalid types."""
Expand All @@ -101,15 +106,20 @@ def _get_invalid_types(self) -> List[Tuple[str, Tuple[type, Any]]]:
if key in self._schema:
# List of types
if isinstance(self._schema[key], list):
invalid_list_types = set()

for t in value:
if type(t) not in self._schema[key]:
invalid_list_types.add(type(t))
# Check if any of the types are valid
if isinstance(value, list):
invalid_list_types = set()
for t in value:
if type(t) not in self._schema[key]:
invalid_list_types.add(type(t))
invalid_list_types = list(invalid_list_types)
else:
invalid_list_types = type(value)

if invalid_list_types:
invalid_types.append(
(key, (self._schema[key], list(invalid_list_types)))
(key, (self._schema[key], invalid_list_types))
)

# Single type
Expand All @@ -123,6 +133,9 @@ def _get_invalid_types(self) -> List[Tuple[str, Tuple[type, Any]]]:

return invalid_types

def _get_handler_results(self) -> dict[str, Any]:
"""Runs the handlers and gets the results."""

def add_handler(self, key: str, handler: Handler):
"""
Adds a new handler for a specific (e.g. 'my', 'my.key') or global key
Expand Down Expand Up @@ -210,12 +223,12 @@ def validate(self) -> ValidatedSchema:
toml_data = tomllib.load(file)

# schema = TOMLSchema({"string_basic": (int, float)})
_schema = TOMLSchema({"int_positive": [str, list]})
_schema = TOMLSchema({"int_non_existing": int})

validator = TOMLValidator(toml_data, _schema)

# print(validator._get_missing_keys())
print(validator._get_invalid_types())
print(validator._get_missing_keys())

# validator.add_handler("string*c", str)

Expand Down
2 changes: 1 addition & 1 deletion tomlval/utils/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

import re

key_pattern = re.compile(r"^(?:\*|\w+)(?:\.(?:\*|\w+))*$")
key_pattern = re.compile(r"^(?:\*|\w+)(?:\.(?:\*|\w+))*\??$")

0 comments on commit 30439e6

Please sign in to comment.