Skip to content

Commit

Permalink
Merge pull request #137 from IDEMSInternational/feat/model-inference
Browse files Browse the repository at this point in the history
Automatic model inference
  • Loading branch information
geoo89 authored Jul 10, 2024
2 parents 15de1b3 + d4c1802 commit 1e8c45f
Show file tree
Hide file tree
Showing 7 changed files with 576 additions and 48 deletions.
20 changes: 20 additions & 0 deletions docs/models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Models

## Automatic model inference

Models of sheets can now be automatically inferred if no explicit model is provided.

This is done exclusively by parsing the header row of a sheet. Headers can be annotated with types (basic types and list; dict and existing models are currently not supported). If no annotation is present, the column is assumed to be a string.

Examples of what the data in a column can represent:
- `field`: `field` is inferred to be a string
- `field:int`: `field` is inferred to be an integer
- `field:list`: `field` is inferred to be a list
- `field:List[int]`: `field` is inferred to be a list of integers
- `field.1`: `field` is inferred to be a list, and this column contains its first entry
- `field.1:int`: `field` is inferred to be a list of integers, and this column contains its first entry
- `field.subfield`: `field` is inferred to be another model with one or multiple subfields, and this column contains values for the `subfield` subfield
- `field.subfield:int`: `field` is inferred to be another model with one or multiple subfields, and this column contains values for the `subfield` subfield which is inferred to be an integer
- `field.1.subfield`: `field` is inferred to be a list of models, each with one or more subfields, and this column contains values for the `subfield` subfield of the first list entry

Intermediate models like in the last three examples are created automatically.
117 changes: 117 additions & 0 deletions src/rpft/parsers/common/model_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from collections import defaultdict
from typing import List, ForwardRef, _eval_type
from pydoc import locate
from pydantic import create_model

from rpft.parsers.common.rowparser import (
ParserModel,
RowParser,
RowParserError,
get_field_name,
is_list_type,
is_parser_model_type,
str_to_bool,
)


def type_from_string(string):
    """Resolve a type annotation given as text into a Python type.

    An empty annotation defaults to ``str``. If ``pydoc.locate`` resolves the
    text to something truthy (e.g. ``"int"``), that object is returned
    directly. Otherwise the text is evaluated as a forward reference against
    this module's globals (e.g. ``"List[int]"``).

    Raises:
        RowParserError: if the annotation cannot be evaluated.
    """
    if not string:
        # By default, assume str
        return str
    basic_type = locate(string)
    if basic_type:
        return basic_type
    try:
        return _eval_type(ForwardRef(string), globals(), globals())
    except (NameError, SyntaxError, TypeError, AttributeError) as e:
        # Unknown names raise NameError, but malformed annotations such as
        # "List[int" can surface as SyntaxError, and expressions that are not
        # types as TypeError/AttributeError. Report all of them uniformly and
        # keep the original exception as the cause.
        raise RowParserError(
            f'Error while parsing type "{string}": {str(e)}'
        ) from e


def get_value_for_type(type, value=None):
    """Build a default value of the given type.

    List types always yield an empty list and ParserModel types a
    default-constructed instance; a provided ``value`` is ignored for both.
    For any other type, ``value`` (if not None) is converted to that type,
    with booleans going through ``str_to_bool``. Without a value, the type is
    called with no arguments.
    """
    if is_list_type(type):
        # Default values for lists are not supported.
        return []
    if is_parser_model_type(type) or value is None:
        # Default values for ParserModel are not supported either, so both
        # cases reduce to a default-constructed instance.
        return type()
    converter = str_to_bool if type is bool else type
    return converter(value)


def infer_type(string):
    """Infer the column type from the annotation in a header string.

    The type is the text between the type annotation separator and the
    default value separator (if any), e.g. ``int`` in ``field:int=5``.
    Headers without a type annotation are typed via ``type_from_string("")``.
    """
    # Drop the default value first, so that a type annotation separator that
    # occurs inside the default value (e.g. "field=12:30") is not mistaken
    # for a type annotation.
    declaration = string.split(RowParser.DEFAULT_VALUE_SEPARATOR, 1)[0]
    if RowParser.TYPE_ANNOTATION_SEPARATOR not in declaration:
        return type_from_string("")
    _, annotation = declaration.split(RowParser.TYPE_ANNOTATION_SEPARATOR, 1)
    return type_from_string(annotation.strip())


def infer_default_value(type, string):
    """Extract the default value declared in a header string.

    Everything after the first default value separator is stripped and
    converted via ``get_value_for_type``. If the header declares no default,
    the default value for the type itself is returned.
    """
    _, separator, raw_default = string.partition(
        RowParser.DEFAULT_VALUE_SEPARATOR
    )
    if not separator:
        # No default declared: fall back to the default for the given type
        return get_value_for_type(type)
    return get_value_for_type(type, raw_default.strip())


def parse_header_annotations(string):
    """Return the ``(type, default value)`` pair declared by a header."""
    column_type = infer_type(string)
    default_value = infer_default_value(column_type, string)
    return column_type, default_value


def represents_integer(string):
    """Tell whether ``int(string)`` succeeds without a ValueError."""
    try:
        int(string)
    except ValueError:
        return False
    else:
        return True


def dict_to_list(dict):
    """Convert a mapping from integer indices to values into a list.

    The list is long enough to hold the largest index; positions without an
    entry in the mapping are filled with None. An empty mapping yields an
    empty list.
    """
    if not dict:
        # max() of an empty sequence would raise ValueError
        return []
    out = [None] * (max(dict) + 1)
    for index, entry in dict.items():
        out[index] = entry
    return out


def model_from_headers(name, headers):
    """Infer a model from the headers of a sheet.

    Delegates to ``model_from_headers_rec`` and discards the default value it
    returns alongside the model.
    """
    model, _default_value = model_from_headers_rec(name, headers)
    return model


def model_from_headers_rec(name, headers):
    """Recursively infer a model and its default value from headers.

    Headers whose field name contains the header field separator are grouped
    by their first path component, and a nested model is inferred from the
    remainders. All other headers become plain fields, typed by their
    annotations.

    Returns:
        A ``(model, default value)`` tuple. If any field name is an integer,
        the result is ``List[<entry model>]`` with a list of default values;
        otherwise it is a newly created ParserModel subclass together with
        its default value from ``get_value_for_type``.
    """
    fields = {}
    complex_fields = defaultdict(list)
    for header in headers:
        # Look for the field separator in the field name only. Annotations
        # may contain it as well (e.g. the default value in "x:float=1.5"),
        # which must not be read as nesting. The field name is a prefix of
        # the header, so the first separator in the header is the one found.
        if RowParser.HEADER_FIELD_SEPARATOR in get_field_name(header):
            field, subheader = header.split(RowParser.HEADER_FIELD_SEPARATOR, 1)
            complex_fields[field].append(subheader)
        else:
            field = get_field_name(header)
            field_type, default_value = parse_header_annotations(header)
            fields[field] = (field_type, default_value)
    for field, subheaders in complex_fields.items():
        # Assign model and default value
        fields[field] = model_from_headers_rec(
            name.title() + field.title(), subheaders
        )

    # In case the model that we're creating is a list,
    # all its fields are numbers (indices).
    list_model = None
    list_default_values = {}
    for field, value in fields.items():
        if represents_integer(field):
            # We do not check whether the models for each list entry match.
            # We just take one of them.
            list_model = value[0]
            # Index shift: because in the headers, we count from 1
            list_default_values[int(field) - 1] = value[1]
    if list_model is not None:
        return List[list_model], dict_to_list(list_default_values)

    # If the model we're creating is not a list, it's a class
    model = create_model(name.title(), __base__=ParserModel, **fields)
    return model, get_value_for_type(model)
86 changes: 53 additions & 33 deletions src/rpft/parsers/common/rowparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,34 @@ def header_name_to_field_name_with_context(header, row):
return header


def get_list_child_model(model):
    """Return the element type of a list type.

    A plain ``list`` carries no element information, so its elements are
    assumed to be strings. Otherwise the single type argument of the list
    type is returned.
    """
    if not is_basic_list_type(model):
        # Get the type that's inside the list
        assert len(model.__args__) == 1
        return model.__args__[0]
    # If not specified, list elements may be anything.
    # Without additional information, we assume strings.
    return str


def is_list_type(model):
# Determine whether model is a list type,
# such as list, List, List[str], ...
# issubclass only works for Python <=3.6
# model.__dict__.get('__origin__') returns different things in different Python
# version.
# This function tries to accommodate both 3.6 and 3.8 (at least)
return model == List or model.__dict__.get("__origin__") in [list, List]
return (
is_basic_list_type(model)
or model is List
or model.__dict__.get("__origin__") in [list, List]
)


def is_basic_list_type(model):
return model == list
return model is list


def is_list_instance(value):
Expand Down Expand Up @@ -81,12 +97,30 @@ def is_default_value(model_instance, field, field_value):
return True


def str_to_bool(string):
    """Interpret a string as a boolean.

    Only the text "false" (compared case-insensitively, without stripping
    whitespace) maps to False; every other string maps to True.
    """
    return string.lower() != "false"


def get_field_name(string):
    """Strip type and default value annotations from a header.

    Returns the part of the header before the first type annotation separator
    and before the first default value separator, with surrounding whitespace
    removed.
    """
    untyped, _, _ = string.partition(RowParser.TYPE_ANNOTATION_SEPARATOR)
    name, _, _ = untyped.partition(RowParser.DEFAULT_VALUE_SEPARATOR)
    return name.strip()


class RowParser:
# Takes a dictionary of cell entries, whose keys are the column names
# and the values are the cell content converted into nested lists.
# Turns this into an instance of the provided model.

HEADER_FIELD_SEPARATOR = "."
TYPE_ANNOTATION_SEPARATOR = ":"
DEFAULT_VALUE_SEPARATOR = "="

def __init__(self, model, cell_parser):
self.model = model
Expand All @@ -96,7 +130,7 @@ def __init__(self, model, cell_parser):
def try_assign_as_kwarg(self, field, key, value, model):
# If value can be interpreted as a (field, field_value) pair for a field of
# model, assign value to field[key] (which represents the field in the model)
if type(value) is list and len(value) == 2 and type(value[0]) is str:
if is_list_instance(value) and len(value) == 2 and type(value[0]) is str:
first_entry_as_key = model.header_name_to_field_name(value[0])
if first_entry_as_key in model.__fields__:
self.assign_value(
Expand Down Expand Up @@ -156,17 +190,20 @@ def assign_value(self, field, key, value, model):
entry,
model.__fields__[entry_key].outer_type_,
)

elif is_basic_list_type(model):
# We cannot iterate deeper if we don't know what to expect.
if is_iterable_instance(value):
field[key] = list(value)
else:
field[key] = [value]
elif is_list_type(model):
# Get the type that's inside the list
assert len(model.__args__) == 1
child_model = model.__args__[0]
child_model = get_list_child_model(model)
# The created entry should be a list. Value should also be a list
field[key] = []
# Note: This makes a decision on how to resolve an ambiguity when the target
# field is a list of lists, but the cell value is a 1-dimensional list.
# 1;2 → [[1],[2]] rather than [[1,2]]
if type(value) is not list:
if not is_list_instance(value):
# It could be that a list is specified via a single element.
if value == "":
# Interpret an empty cell as [] rather than ['']
Expand All @@ -178,27 +215,16 @@ def assign_value(self, field, key, value, model):
# recursively
field[key].append(None)
self.assign_value(field[key], -1, entry, child_model)

elif is_basic_list_type(model):
if is_iterable_instance(value):
field[key] = list(value)
else:
field[key] = [value]
else:
assert is_basic_type(model)
# The value should be a basic type
# TODO: Ensure the types match. E.g. we don't want value to be a list
if model == bool:
if type(value) is str:
if value.strip():
# Special case: empty string is not assigned at all;
# in this case, the default value takes effect.
if value.strip().lower() == "false":
field[key] = False
else:
# This is consistent with python: Anything that's
# not '' or explicitly False is True.
field[key] = True
stripped = value.strip()
# Special case: empty string is not assigned at all.
if stripped:
field[key] = str_to_bool(stripped)
else:
field[key] = bool(value)
else:
Expand All @@ -219,10 +245,7 @@ def find_entry(self, model, output_field, field_path):
# It'd be nicer to already have a template.
field_name = field_path[0]
if is_list_type(model):
# Get the type that's inside the list
assert len(model.__args__) == 1
child_model = model.__args__[0]

child_model = get_list_child_model(model)
index = int(field_name) - 1
if len(output_field) <= index:
# Create a new list entry for this, if necessary
Expand All @@ -248,7 +271,7 @@ def find_entry(self, model, output_field, field_path):
output_field[key] = None

if len(field_path) == 1:
# We're reach the end of the field_path
# We've reached the end of the field_path
# Therefore we've found where we need to assign
return output_field, key, child_model
else:
Expand All @@ -266,6 +289,7 @@ def parse_entry(
):
# This creates/populates a field in self.output
# The field is determined by column_name, its value by value
column_name = get_field_name(column_name)
field_path = column_name.split(RowParser.HEADER_FIELD_SEPARATOR)
# Find the destination subfield in self.output that corresponds to field_path
field, key, model = self.find_entry(self.model, self.output, field_path)
Expand All @@ -277,11 +301,7 @@ def parse_entry(
# The model of field[key] is model, and thus value should also be interpreted
# as being of type model.
if not value_is_parsed:
if (
is_basic_list_type(model)
or is_list_type(model)
or is_parser_model_type(model)
):
if is_list_type(model) or is_parser_model_type(model):
# If the expected type of the value is list/object,
# parse the cell content as such.
# Otherwise leave it as a string
Expand Down
30 changes: 15 additions & 15 deletions src/rpft/parsers/creation/contentindexparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from rpft.logger.logger import get_logger, logging_context
from rpft.parsers.common.cellparser import CellParser
from rpft.parsers.common.model_inference import model_from_headers
from rpft.parsers.common.rowparser import RowParser
from rpft.parsers.common.sheetparser import SheetParser
from rpft.parsers.creation.campaigneventrowmodel import CampaignEventRowModel
Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(
self.campaign_parsers = {} # name-indexed dict of CampaignParser
self.trigger_parsers = []

self.user_models_module = None
if user_data_model_module_name:
self.user_models_module = importlib.import_module(
user_data_model_module_name
Expand Down Expand Up @@ -186,11 +188,6 @@ def _get_sheet_or_die(self, sheet_name):
return active

def _process_data_sheet(self, row):
if not hasattr(self, "user_models_module"):
LOGGER.critical(
"If there are data sheets, a user_data_model_module_name "
"has to be provided"
)
sheet_names = row.sheet_name
if row.operation.type in ["filter", "sort"] and len(sheet_names) > 1:
LOGGER.warning(
Expand Down Expand Up @@ -236,19 +233,22 @@ def _get_data_sheet(self, sheet_name, data_model_name):
else:
return self._get_new_data_sheet(sheet_name, data_model_name)

def _get_new_data_sheet(self, sheet_name, data_model_name):
if not data_model_name:
LOGGER.critical("No data_model_name provided for data sheet.")
try:
user_model = getattr(self.user_models_module, data_model_name)
except AttributeError:
LOGGER.critical(
f'Undefined data_model_name "{data_model_name}" '
f"in {self.user_models_module}."
)
def _get_new_data_sheet(self, sheet_name, data_model_name=None):
user_model = None
if self.user_models_module and data_model_name:
try:
user_model = getattr(self.user_models_module, data_model_name)
except AttributeError:
LOGGER.critical(
f'Undefined data_model_name "{data_model_name}" '
f"in {self.user_models_module}."
)
data_table = self._get_sheet_or_die(sheet_name)
with logging_context(sheet_name):
data_table = self._get_sheet_or_die(sheet_name).table
if not user_model:
LOGGER.info("Inferring RowModel automatically")
user_model = model_from_headers(sheet_name, data_table.headers)
row_parser = RowParser(user_model, CellParser())
sheet_parser = SheetParser(row_parser, data_table)
data_rows = sheet_parser.parse_all()
Expand Down
Loading

0 comments on commit 1e8c45f

Please sign in to comment.