diff --git a/docs/models.md b/docs/models.md new file mode 100644 index 0000000..adfc0a2 --- /dev/null +++ b/docs/models.md @@ -0,0 +1,20 @@ +# Models + +## Automatic model inference + +Models of sheets can now be automatically inferred if no explicit model is provided. + +This is done exclusively by parsing the header row of a sheet. Headers can be annotated with types (basic types and list; dict and existing models are currently not supported). If no annotation is present, the column is assumed to be a string. + +Examples of column headers and how they are interpreted: +- `field`: `field` is inferred to be a string +- `field:int`: `field` is inferred to be an integer +- `field:list`: `field` is inferred to be a list +- `field:List[int]`: `field` is inferred to be a list of integers +- `field.1`: `field` is inferred to be a list, and this column contains its first entry +- `field.1:int`: `field` is inferred to be a list of integers, and this column contains its first entry +- `field.subfield`: `field` is inferred to be another model with one or multiple subfields, and this column contains values for the `subfield` subfield +- `field.subfield:int`: `field` is inferred to be another model with one or multiple subfields, and this column contains values for the `subfield` subfield which is inferred to be an integer +- `field.1.subfield`: `field` is inferred to be a list of another model with one or multiple subfields, and this column contains values for the `subfield` subfield of the first list entry + +Intermediate models like in the last three examples are created automatically. 
diff --git a/src/rpft/parsers/common/model_inference.py b/src/rpft/parsers/common/model_inference.py new file mode 100644 index 0000000..ae836a9 --- /dev/null +++ b/src/rpft/parsers/common/model_inference.py @@ -0,0 +1,117 @@ +from collections import defaultdict +from typing import List, ForwardRef, _eval_type +from pydoc import locate +from pydantic import create_model + +from rpft.parsers.common.rowparser import ( + ParserModel, + RowParser, + RowParserError, + get_field_name, + is_list_type, + is_parser_model_type, + str_to_bool, +) + + +def type_from_string(string): + if not string: + # By default, assume str + return str + basic_type = locate(string) + if basic_type: + return basic_type + try: + inferred_type = _eval_type(ForwardRef(string), globals(), globals()) + except NameError as e: + raise RowParserError(f'Error while parsing type "{string}": {str(e)}') + return inferred_type + + +def get_value_for_type(type, value=None): + if is_list_type(type): + # We do not support default values for lists. + return [] + if is_parser_model_type(type): + # We do not support default values for ParserModel. 
+ return type() + if value is not None: + if type is bool: + return str_to_bool(value) + return type(value) + return type() + + +def infer_type(string): + if RowParser.TYPE_ANNOTATION_SEPARATOR not in string: + return type_from_string("") + # Take the stuff between colon and equal sign + prefix, suffix = string.split(RowParser.TYPE_ANNOTATION_SEPARATOR, 1) + return type_from_string(suffix.split(RowParser.DEFAULT_VALUE_SEPARATOR)[0].strip()) + + +def infer_default_value(type, string): + if RowParser.DEFAULT_VALUE_SEPARATOR not in string: + # Return the default value for the given type + return get_value_for_type(type) + prefix, suffix = string.split(RowParser.DEFAULT_VALUE_SEPARATOR, 1) + return get_value_for_type(type, suffix.strip()) + + +def parse_header_annotations(string): + inferred_type = infer_type(string) + return inferred_type, infer_default_value(inferred_type, string) + + +def represents_integer(string): + try: + _ = int(string) + return True + except ValueError: + return False + + +def dict_to_list(dict): + out = [None] * (max(dict.keys()) + 1) + for k, v in dict.items(): + out[k] = v + return out + + +def model_from_headers(name, headers): + return model_from_headers_rec(name, headers)[0] + + +def model_from_headers_rec(name, headers): + # Returns a model and a default value + fields = {} + complex_fields = defaultdict(list) + for header in headers: + if RowParser.HEADER_FIELD_SEPARATOR in header: + field, subheader = header.split(RowParser.HEADER_FIELD_SEPARATOR, 1) + complex_fields[field].append(subheader) + else: + field = get_field_name(header) + field_type, default_value = parse_header_annotations(header) + fields[field] = (field_type, default_value) + for field, subheaders in complex_fields.items(): + # Assign model and default value + fields[field] = model_from_headers_rec(name.title() + field.title(), subheaders) + + # In case the model that we're creating is a list, + # all its fields are numbers (indices). 
+ list_model = None + list_default_values = {} + for field, value in fields.items(): + if represents_integer(field): + # We do not check whether the models for each list entry match. + # We just take one of them. + list_model = value[0] + # Index shift: because in the headers, we count from 1 + list_default_values[int(field) - 1] = value[1] + if list_model is not None: + return List[list_model], dict_to_list(list_default_values) + + # If the model we're creating is not a list, it's a class + model = create_model(name.title(), __base__=ParserModel, **fields) + return model, get_value_for_type(model) diff --git a/src/rpft/parsers/common/rowparser.py b/src/rpft/parsers/common/rowparser.py index f34c3ba..817c31e 100644 --- a/src/rpft/parsers/common/rowparser.py +++ b/src/rpft/parsers/common/rowparser.py @@ -32,6 +32,18 @@ def header_name_to_field_name_with_context(header, row): return header +def get_list_child_model(model): + if is_basic_list_type(model): + # If not specified, list elements may be anything. + # Without additional information, we assume strings. + child_model = str + else: + # Get the type that's inside the list + assert len(model.__args__) == 1 + child_model = model.__args__[0] + return child_model + + def is_list_type(model): # Determine whether model is a list type, # such as list, List, List[str], ... @@ -39,11 +51,15 @@ def is_list_type(model): # model.__dict__.get('__origin__') returns different things in different Python # version. 
# This function tries to accommodate both 3.6 and 3.8 (at least) - return model == List or model.__dict__.get("__origin__") in [list, List] + return ( + is_basic_list_type(model) + or model is List + or model.__dict__.get("__origin__") in [list, List] + ) def is_basic_list_type(model): - return model == list + return model is list def is_list_instance(value): @@ -81,12 +97,30 @@ def is_default_value(model_instance, field, field_value): return True +def str_to_bool(string): + # in this case, the default value takes effect. + if string.lower() == "false": + return False + else: + return True + + +def get_field_name(string): + return ( + string.split(RowParser.TYPE_ANNOTATION_SEPARATOR)[0] + .split(RowParser.DEFAULT_VALUE_SEPARATOR)[0] + .strip() + ) + + class RowParser: # Takes a dictionary of cell entries, whose keys are the column names # and the values are the cell content converted into nested lists. # Turns this into an instance of the provided model. HEADER_FIELD_SEPARATOR = "." + TYPE_ANNOTATION_SEPARATOR = ":" + DEFAULT_VALUE_SEPARATOR = "=" def __init__(self, model, cell_parser): self.model = model @@ -96,7 +130,7 @@ def __init__(self, model, cell_parser): def try_assign_as_kwarg(self, field, key, value, model): # If value can be interpreted as a (field, field_value) pair for a field of # model, assign value to field[key] (which represents the field in the model) - if type(value) is list and len(value) == 2 and type(value[0]) is str: + if is_list_instance(value) and len(value) == 2 and type(value[0]) is str: first_entry_as_key = model.header_name_to_field_name(value[0]) if first_entry_as_key in model.__fields__: self.assign_value( @@ -156,17 +190,20 @@ def assign_value(self, field, key, value, model): entry, model.__fields__[entry_key].outer_type_, ) - + elif is_basic_list_type(model): + # We cannot iterate deeper if we don't know what to expect. 
+ if is_iterable_instance(value): + field[key] = list(value) + else: + field[key] = [value] elif is_list_type(model): - # Get the type that's inside the list - assert len(model.__args__) == 1 - child_model = model.__args__[0] + child_model = get_list_child_model(model) # The created entry should be a list. Value should also be a list field[key] = [] # Note: This makes a decision on how to resolve an ambiguity when the target # field is a list of lists, but the cell value is a 1-dimensional list. # 1;2 → [[1],[2]] rather than [[1,2]] - if type(value) is not list: + if not is_list_instance(value): # It could be that a list is specified via a single element. if value == "": # Interpret an empty cell as [] rather than [''] @@ -178,27 +215,16 @@ def assign_value(self, field, key, value, model): # recursively field[key].append(None) self.assign_value(field[key], -1, entry, child_model) - - elif is_basic_list_type(model): - if is_iterable_instance(value): - field[key] = list(value) - else: - field[key] = [value] else: assert is_basic_type(model) # The value should be a basic type # TODO: Ensure the types match. E.g. we don't want value to be a list if model == bool: if type(value) is str: - if value.strip(): - # Special case: empty string is not assigned at all; - # in this case, the default value takes effect. - if value.strip().lower() == "false": - field[key] = False - else: - # This is consistent with python: Anything that's - # not '' or explicitly False is True. - field[key] = True + stripped = value.strip() + # Special case: empty string is not assigned at all. + if stripped: + field[key] = str_to_bool(stripped) else: field[key] = bool(value) else: @@ -219,10 +245,7 @@ def find_entry(self, model, output_field, field_path): # It'd be nicer to already have a template. 
field_name = field_path[0] if is_list_type(model): - # Get the type that's inside the list - assert len(model.__args__) == 1 - child_model = model.__args__[0] - + child_model = get_list_child_model(model) index = int(field_name) - 1 if len(output_field) <= index: # Create a new list entry for this, if necessary @@ -248,7 +271,7 @@ def find_entry(self, model, output_field, field_path): output_field[key] = None if len(field_path) == 1: - # We're reach the end of the field_path + # We've reached the end of the field_path # Therefore we've found where we need to assign return output_field, key, child_model else: @@ -266,6 +289,7 @@ def parse_entry( ): # This creates/populates a field in self.output # The field is determined by column_name, its value by value + column_name = get_field_name(column_name) field_path = column_name.split(RowParser.HEADER_FIELD_SEPARATOR) # Find the destination subfield in self.output that corresponds to field_path field, key, model = self.find_entry(self.model, self.output, field_path) @@ -277,11 +301,7 @@ def parse_entry( # The model of field[key] is model, and thus value should also be interpreted # as being of type model. if not value_is_parsed: - if ( - is_basic_list_type(model) - or is_list_type(model) - or is_parser_model_type(model) - ): + if is_list_type(model) or is_parser_model_type(model): # If the expected type of the value is list/object, # parse the cell content as such. 
# Otherwise leave it as a string diff --git a/src/rpft/parsers/creation/contentindexparser.py b/src/rpft/parsers/creation/contentindexparser.py index c0eb163..c91fc7b 100644 --- a/src/rpft/parsers/creation/contentindexparser.py +++ b/src/rpft/parsers/creation/contentindexparser.py @@ -3,6 +3,7 @@ from rpft.logger.logger import get_logger, logging_context from rpft.parsers.common.cellparser import CellParser +from rpft.parsers.common.model_inference import model_from_headers from rpft.parsers.common.rowparser import RowParser from rpft.parsers.common.sheetparser import SheetParser from rpft.parsers.creation.campaigneventrowmodel import CampaignEventRowModel @@ -53,6 +54,7 @@ def __init__( self.campaign_parsers = {} # name-indexed dict of CampaignParser self.trigger_parsers = [] + self.user_models_module = None if user_data_model_module_name: self.user_models_module = importlib.import_module( user_data_model_module_name @@ -186,11 +188,6 @@ def _get_sheet_or_die(self, sheet_name): return active def _process_data_sheet(self, row): - if not hasattr(self, "user_models_module"): - LOGGER.critical( - "If there are data sheets, a user_data_model_module_name " - "has to be provided" - ) sheet_names = row.sheet_name if row.operation.type in ["filter", "sort"] and len(sheet_names) > 1: LOGGER.warning( @@ -236,19 +233,22 @@ def _get_data_sheet(self, sheet_name, data_model_name): else: return self._get_new_data_sheet(sheet_name, data_model_name) - def _get_new_data_sheet(self, sheet_name, data_model_name): - if not data_model_name: - LOGGER.critical("No data_model_name provided for data sheet.") - try: - user_model = getattr(self.user_models_module, data_model_name) - except AttributeError: - LOGGER.critical( - f'Undefined data_model_name "{data_model_name}" ' - f"in {self.user_models_module}." 
- ) + def _get_new_data_sheet(self, sheet_name, data_model_name=None): + user_model = None + if self.user_models_module and data_model_name: + try: + user_model = getattr(self.user_models_module, data_model_name) + except AttributeError: + LOGGER.critical( + f'Undefined data_model_name "{data_model_name}" ' + f"in {self.user_models_module}." + ) data_table = self._get_sheet_or_die(sheet_name) with logging_context(sheet_name): data_table = self._get_sheet_or_die(sheet_name).table + if not user_model: + LOGGER.info("Inferring RowModel automatically") + user_model = model_from_headers(sheet_name, data_table.headers) row_parser = RowParser(user_model, CellParser()) sheet_parser = SheetParser(row_parser, data_table) data_rows = sheet_parser.parse_all() diff --git a/tests/test_contentindexparser.py b/tests/test_contentindexparser.py index 3538338..3b46568 100644 --- a/tests/test_contentindexparser.py +++ b/tests/test_contentindexparser.py @@ -802,6 +802,62 @@ def test_sort_descending(self): self.check_filtersort(ci_sheet, exp_keys) +class TestModelInference(TestTemplate): + def setUp(self): + self.ci_sheet = ( + "type,sheet_name,data_sheet,data_row_id,status\n" # noqa: E501 + "template_definition,my_template,,,\n" + "create_flow,my_template,mydata,,\n" + "data_sheet,mydata,,,\n" + ) + self.my_template = ( + "row_id,type,from,message_text\n" + ",send_message,start,Lst {{lst.0}} {{lst.1}}\n" + ",send_message,,{{custom_field.happy}} and {{custom_field.sad}}\n" + ) + + + def check_example(self, sheet_dict): + sheet_reader = MockSheetReader(self.ci_sheet, sheet_dict) + ci_parser = ContentIndexParser(sheet_reader) + container = ci_parser.parse_all() + render_output = container.render() + self.compare_messages( + render_output, + "my_template - row1", + ["Lst 0 4", "Happy1 and Sad1"], + ) + self.compare_messages( + render_output, + "my_template - row2", + ["Lst 1 5", "Happy2 and Sad2"], + ) + + def test_model_inference(self): + mydata = ( + 
"ID,lst.1:int,lst.2:int,custom_field.happy,custom_field.sad\n" + "row1,0,4,Happy1,Sad1\n" + "row2,1,5,Happy2,Sad2\n" + ) + sheet_dict = { + "mydata": mydata, + "my_template": self.my_template, + } + self.check_example(sheet_dict) + + def test_model_inference_alt(self): + mydata = ( + "ID,lst:List[int],custom_field.happy,custom_field.sad\n" + "row1,0;4,Happy1,Sad1\n" + "row2,1;5,Happy2,Sad2\n" + ) + sheet_dict = { + "mydata": mydata, + "my_template": self.my_template, + } + self.check_example(sheet_dict) + + class TestParseCampaigns(unittest.TestCase): def test_parse_flow_campaign(self): ci_sheet = ( diff --git a/tests/test_model_inference.py b/tests/test_model_inference.py new file mode 100644 index 0000000..ddb0aa6 --- /dev/null +++ b/tests/test_model_inference.py @@ -0,0 +1,154 @@ +from typing import List +import unittest +from pydantic import create_model, BaseModel + +from rpft.parsers.common.model_inference import ( + get_value_for_type, + infer_default_value, + infer_type, + model_from_headers, + parse_header_annotations, + type_from_string, +) + + +class TestModelInference(unittest.TestCase): + def test_type_from_string(self): + self.assertEqual(type_from_string(""), str) + self.assertEqual(type_from_string("str"), str) + self.assertEqual(type_from_string("int"), int) + self.assertEqual(type_from_string("list"), list) + self.assertEqual(type_from_string("List"), List) + self.assertEqual(type_from_string("List[str]"), List[str]) + self.assertEqual(type_from_string("List[int]"), List[int]) + self.assertEqual(type_from_string("List[List[int]]"), List[List[int]]) + + def test_get_value_for_type(self): + self.assertEqual(get_value_for_type(int), 0) + self.assertEqual(get_value_for_type(str), "") + self.assertEqual(get_value_for_type(bool), False) + self.assertEqual(get_value_for_type(List[str]), []) + self.assertEqual(get_value_for_type(list), []) + + self.assertEqual(get_value_for_type(int, "5"), 5) + self.assertEqual(get_value_for_type(str, "abc"), "abc") + 
self.assertEqual(get_value_for_type(bool, "True"), True) + self.assertEqual(get_value_for_type(bool, "TRUE"), True) + + def test_infer_type(self): + self.assertEqual(infer_type("field:int"), int) + self.assertEqual(infer_type("field:list"), list) + self.assertEqual(infer_type("field:List[int]"), List[int]) + self.assertEqual(infer_type("field : list"), list) + self.assertEqual(infer_type("field:int=5"), int) + self.assertEqual(infer_type("field:int = 5"), int) + self.assertEqual(infer_type("field : int = 5"), int) + self.assertEqual(infer_type("field=5"), str) + + def test_infer_default_value(self): + self.assertEqual(infer_default_value(int, "field:int"), 0) + self.assertEqual(infer_default_value(list, "field:list"), []) + self.assertEqual(infer_default_value(List[int], "field:List[int]"), []) + self.assertEqual(infer_default_value(list, "field : list"), []) + self.assertEqual(infer_default_value(int, "field:int=5"), 5) + self.assertEqual(infer_default_value(int, "field:int = 5"), 5) + self.assertEqual(infer_default_value(int, "field : int = 5"), 5) + self.assertEqual(infer_default_value(str, "field=5"), "5") + self.assertEqual(infer_default_value(str, "field = 5"), "5") + + def test_parse_header_annotations(self): + self.assertEqual(parse_header_annotations("field:int=5"), (int, 5)) + + def compare_models(self, model1, model2, **kwargs): + self.assertEqual(model1(**kwargs).dict(), model2(**kwargs).dict()) + + def test_model_from_headers(self): + self.compare_models( + model_from_headers("mymodel", ["field1"]), + create_model( + "Mymodel", + field1=(str, ""), + ), + ) + self.compare_models( + model_from_headers("mymodel", ["field1:int=5"]), + create_model( + "Mymodel", + field1=(int, 5), + ), + ) + self.compare_models( + model_from_headers("mymodel", ["field1:list"]), + create_model( + "Mymodel", + field1=(list, []), + ), + ) + self.compare_models( + model_from_headers("mymodel", ["field1:list"]), + create_model( + "Mymodel", + field1=(list, []), + ), + field1=[1, 
2, 3, 4], + ) + + class MySubmodel(BaseModel): + sub1: str = "" + sub2: int = 5 + + self.compare_models( + model_from_headers("mymodel", ["field1.sub1", "field1.sub2:int=5"]), + create_model( + "Mymodel", + field1=(MySubmodel, MySubmodel()), + ), + ) + self.compare_models( + model_from_headers("mymodel", ["field1.1", "field1.2"]), + create_model( + "Mymodel", + field1=(list, ["", ""]), + ), + ) + self.compare_models( + model_from_headers("mymodel", ["field1.1:int", "field1.2:int=5"]), + create_model( + "Mymodel", + field1=(list, [0, 5]), + ), + ) + self.compare_models( + model_from_headers( + "mymodel", + [ + "field1.1.1", + "field1.1.2=a", + "field1.2.1=b", + "field1.2.2=c", + ], + ), + create_model( + "Mymodel", + field1=(List[List[str]], [["", "a"], ["b", "c"]]), + ), + ) + self.compare_models( + model_from_headers( + "mymodel", + [ + "field1.1.sub1", + "field1.1.sub2:int=5", + "field1.2.sub1", + "field1.2.sub2:int=5", + ], + ), + create_model( + "Mymodel", + field1=(List[MySubmodel], [MySubmodel(), MySubmodel()]), + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rowparser.py b/tests/test_rowparser.py new file mode 100644 index 0000000..d5af207 --- /dev/null +++ b/tests/test_rowparser.py @@ -0,0 +1,161 @@ +import unittest +from typing import List + +from rpft.parsers.common.rowparser import ParserModel, RowParser, RowParserError +from tests.mocks import MockCellParser + + +class SubModel(ParserModel): + str_field: str = "" + list_field: list = [] + + +class MyModel(ParserModel): + bool_field: int = False + int_field: int = 0 + str_field: str = "" + list_field: list = [] + submodel_field: SubModel = SubModel() + + +class BoolModel(ParserModel): + bool_field: bool = True + + +class TestRowParserBoolean(unittest.TestCase): + def setUp(self): + self.parser = RowParser(BoolModel, MockCellParser()) + self.falseModel = BoolModel(**{"bool_field" : False}) + self.trueModel = BoolModel(**{"bool_field" : True}) + + def 
test_convert_false(self): + inputs = [ + {"bool_field": "False"}, + {"bool_field": " False "}, + {"bool_field": "FALSE"}, + {"bool_field": "false"}, + ] + for inp in inputs: + out = self.parser.parse_row(inp) + self.assertEqual(out, self.falseModel) + + def test_convert_true(self): + inputs = [ + {"bool_field": "True"}, + {"bool_field": " True "}, + {"bool_field": "TRUE"}, + {"bool_field": "true"}, + {"bool_field": "something"}, + {"bool_field": "1"}, + {"bool_field": "0"}, + ] + for inp in inputs: + out = self.parser.parse_row(inp) + self.assertEqual(out, self.trueModel) + + def test_convert_default(self): + inp = {} + out = self.parser.parse_row(inp) + self.assertEqual(out, self.trueModel) + + +class IntModel(ParserModel): + int_field: int = 0 + + +class TestRowParserInt(unittest.TestCase): + def setUp(self): + self.parser = RowParser(IntModel, MockCellParser()) + self.twelveModel = IntModel(**{"int_field" : 12}) + self.zeroModel = IntModel(**{"int_field" : 0}) + + def test_convert_int(self): + inputs = [ + {"int_field": "12"}, + {"int_field": " 12 "}, + ] + for inp in inputs: + out = self.parser.parse_row(inp) + self.assertEqual(out, self.twelveModel) + inp = {"int_field": "twelve"} + with self.assertRaises(ValueError): + out = self.parser.parse_row(inp) + + def test_convert_default(self): + inp = {} + out = self.parser.parse_row(inp) + self.assertEqual(out, self.zeroModel) + + +class ListStrModel(ParserModel): + list_field: List[str] = [] + + +class TestRowParserListStr(unittest.TestCase): + def setUp(self): + self.parser = RowParser(ListStrModel, MockCellParser()) + self.emptyModel = ListStrModel(**{"list_field" : []}) + self.oneModel = ListStrModel(**{"list_field" : ["1"]}) + self.onetwoModel = ListStrModel(**{"list_field" : ["1", "2"]}) + + def test_convert_empty(self): + inputs = [ + {}, + # {"list_field": ""}, + {"list_field": []}, + ] + for inp in inputs: + out = self.parser.parse_row(inp) + self.assertEqual(out, self.emptyModel) + + def 
test_convert_single_element(self): + inputs = [ + {"list_field": ["1"]}, + # {"list_field": ["1", ""]}, + {"list_field": "1"}, + {"list_field.1": "1"}, + ] + for inp in inputs: + out = self.parser.parse_row(inp) + self.assertEqual(out, self.oneModel) + + def test_convert_two_element(self): + inputs = [ + {"list_field": ["1", "2"]}, + {"list_field.1": "1", "list_field.2": "2"}, + ] + for inp in inputs: + out = self.parser.parse_row(inp) + self.assertEqual(out, self.onetwoModel) + + # inp = {"list_field": "1"} + # with self.assertRaises(ValueError): + # out = self.parser.parse_row(inp) + + +class ListIntModel(ParserModel): + list_field: List[int] = [] + + +class TestRowParserListInt(TestRowParserListStr): + def setUp(self): + self.parser = RowParser(ListIntModel, MockCellParser()) + self.emptyModel = ListIntModel(**{"list_field" : []}) + self.oneModel = ListIntModel(**{"list_field" : [1]}) + self.onetwoModel = ListIntModel(**{"list_field" : [1, 2]}) + + +class ListModel(ParserModel): + list_field: list = [] + + +class TestRowParserList(TestRowParserListStr): + def setUp(self): + self.parser = RowParser(ListModel, MockCellParser()) + self.emptyModel = ListModel(**{"list_field" : []}) + self.oneModel = ListModel(**{"list_field" : ["1"]}) + self.onetwoModel = ListModel(**{"list_field" : ["1", "2"]}) + + +if __name__ == "__main__": + unittest.main()