Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic model inference #137

Merged
merged 5 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Models

## Automatic model inference

Models of sheets can now be automatically inferred if no explicit model is provided.

This is done exclusively by parsing the header row of a sheet. Headers can be annotated with types (basic types and list; dict and existing models are currently not supported). If no annotation is present, the column is assumed to be a string.

Examples of what the data in a column can represent:
- `field`: `field` is inferred to be a string
- `field:int`: `field` is inferred to be an integer
- `field:list`: `field` is inferred to be a list
- `field:List[int]`: `field` is inferred to be a list of integers
- `field.1`: `field` is inferred to be a list, and this column contains its first entry
- `field.1:int`: `field` is inferred to be a list of integers, and this column contains its first entry
- `field.subfield`: `field` is inferred to be another model with one or multiple subfields, and this column contains values for the `subfield` subfield
- `field.subfield:int`: `field` is inferred to be another model with one or multiple subfields, and this column contains values for the `subfield` subfield which is inferred to be an integer
- `field.1.subfield`: `field` is inferred to be a list of another model with one or multiple subfields, and this column contains values for the `subfield` subfield of the first list entry

Intermediate models, such as those in the last three examples, are created automatically.
117 changes: 117 additions & 0 deletions src/rpft/parsers/common/model_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from collections import defaultdict
from typing import List, ForwardRef, _eval_type
from pydoc import locate
from pydantic import create_model

from rpft.parsers.common.rowparser import (
ParserModel,
RowParser,
RowParserError,
get_field_name,
is_list_type,
is_parser_model_type,
str_to_bool,
)


def type_from_string(string):
    """Resolve a type annotation string from a sheet header to a type.

    An empty (or missing) annotation yields ``str``. Otherwise the string is
    first looked up with ``pydoc.locate`` (which resolves names such as
    ``int`` or ``list``); if that finds nothing, the string is evaluated as a
    typing expression such as ``List[int]`` against this module's globals.

    Args:
        string: the annotation text, e.g. ``"int"`` or ``"List[int]"``.

    Returns:
        The object the annotation resolves to.

    Raises:
        RowParserError: if the annotation cannot be evaluated.
    """
    if not string:
        # By default, assume str
        return str
    basic_type = locate(string)
    if basic_type:
        return basic_type
    # NOTE(review): the annotation comes from a spreadsheet header and is
    # evaluated as a Python expression here (and may be imported by
    # ``locate`` above). Only feed this headers from trusted sheets.
    try:
        return _eval_type(ForwardRef(string), globals(), globals())
    except (NameError, SyntaxError, TypeError, AttributeError) as e:
        # Malformed annotations (not just unknown names) must surface as a
        # parser error rather than a raw interpreter exception.
        raise RowParserError(
            f'Error while parsing type "{string}": {str(e)}'
        ) from e


def get_value_for_type(type, value=None):
    """Build a default value for a field of the given type.

    List types always yield an empty list, and ParserModel types a
    default-constructed instance; ``value`` is ignored for both. For any
    other type, ``value`` (when given) is converted to that type, with bools
    going through ``str_to_bool``. Without a value, the result of calling the
    type with no arguments is returned.
    """
    if is_list_type(type):
        # Default values for lists are not supported.
        return []
    no_value_given = value is None
    if is_parser_model_type(type) or no_value_given:
        # Default values for ParserModel are not supported either.
        return type()
    return str_to_bool(value) if type is bool else type(value)


def infer_type(string):
    """Determine the type annotated in a header string.

    The annotation is the text between the first type-annotation separator
    and the following default-value separator (if any). A header without an
    annotation is resolved like an empty annotation.
    """
    _name, separator, annotation = string.partition(
        RowParser.TYPE_ANNOTATION_SEPARATOR
    )
    if not separator:
        return type_from_string("")
    type_name, _, _default = annotation.partition(
        RowParser.DEFAULT_VALUE_SEPARATOR
    )
    return type_from_string(type_name.strip())


def infer_default_value(type, string):
    """Determine the default value declared in a header string.

    Everything after the first default-value separator, stripped of
    surrounding whitespace, is converted via ``get_value_for_type``. When the
    header declares no default, ``get_value_for_type`` is called without a
    value.
    """
    _head, separator, raw_default = string.partition(
        RowParser.DEFAULT_VALUE_SEPARATOR
    )
    if separator:
        return get_value_for_type(type, raw_default.strip())
    # No explicit default: fall back to the default value for the type
    return get_value_for_type(type)


def parse_header_annotations(string):
    """Return the ``(type, default value)`` pair described by a header."""
    header_type = infer_type(string)
    default_value = infer_default_value(header_type, string)
    return header_type, default_value


def represents_integer(string):
    """Tell whether ``int()`` accepts the given string."""
    try:
        int(string)
    except ValueError:
        return False
    else:
        return True


def dict_to_list(dict):
    """Convert a mapping keyed by list index into a dense list.

    Args:
        dict: mapping from integer index to value.

    Returns:
        A list whose length is the largest key plus one, holding each value
        at its index and ``None`` at indices missing from the mapping. An
        empty mapping yields an empty list.
    """
    # ``default=-1`` makes an empty mapping produce [] instead of raising
    # ValueError from max() on an empty sequence.
    out = [None] * (max(dict, default=-1) + 1)
    for k, v in dict.items():
        out[k] = v
    return out


def model_from_headers(name, headers):
    """Infer a model from sheet headers, discarding its default value."""
    model, _default_value = model_from_headers_rec(name, headers)
    return model


def model_from_headers_rec(name, headers):
    """Recursively infer a model and its default value from headers.

    Headers whose field-name part contains the header field separator are
    grouped by their first path component and turned into nested models (or
    lists, when the sub-components are integers) by recursion. All other
    headers become plain fields whose type and default come from their
    annotations.

    Args:
        name: name for the model to create; nested models get the parent
            name with the field name appended.
        headers: iterable of header strings (relative to this model).

    Returns:
        A ``(model, default_value)`` tuple. If any field name is an integer,
        the result is ``(List[child_model], list_of_defaults)``; otherwise it
        is a newly created ParserModel subclass and its default value.
    """
    fields = {}
    complex_fields = defaultdict(list)
    for header in headers:
        # Only the field-name part carries path separators. A "." inside a
        # type annotation or a default value (e.g. "x:float=1.5") must not
        # be mistaken for nesting, so the check ignores the annotations.
        if RowParser.HEADER_FIELD_SEPARATOR in get_field_name(header):
            # The first separator in the header lies within the name part,
            # so the subheader keeps the leaf's annotations intact.
            field, subheader = header.split(RowParser.HEADER_FIELD_SEPARATOR, 1)
            complex_fields[field].append(subheader)
        else:
            field = get_field_name(header)
            field_type, default_value = parse_header_annotations(header)
            fields[field] = (field_type, default_value)
    for field, subheaders in complex_fields.items():
        # Assign model and default value
        fields[field] = model_from_headers_rec(name.title() + field.title(), subheaders)

    # In case the model that we're creating is a list,
    # all its fields are numbers (indices).
    list_model = None
    list_default_values = {}
    for field, value in fields.items():
        if represents_integer(field):
            # We do not check whether the models for each list entry match.
            # We just take one of them.
            list_model = value[0]
            # Index shift: because in the headers, we count from 1
            list_default_values[int(field) - 1] = value[1]
    if list_model is not None:
        return List[list_model], dict_to_list(list_default_values)

    # If the model we're creating is not a list, it's a class
    model = create_model(name.title(), __base__=ParserModel, **fields)
    return model, get_value_for_type(model)
86 changes: 53 additions & 33 deletions src/rpft/parsers/common/rowparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,34 @@ def header_name_to_field_name_with_context(header, row):
return header


def get_list_child_model(model):
    """Return the element type of a list model.

    A plain ``list`` carries no element information, so its elements are
    assumed to be strings. Any other list model is expected to have exactly
    one type argument, which is returned.
    """
    if not is_basic_list_type(model):
        # Parametrised list: pull out the single type argument
        type_args = model.__args__
        assert len(type_args) == 1
        return type_args[0]
    # If not specified, list elements may be anything; assume strings.
    return str


def is_list_type(model):
# Determine whether model is a list type,
# such as list, List, List[str], ...
# issubclass only works for Python <=3.6
# model.__dict__.get('__origin__') returns different things in different Python
# version.
# This function tries to accommodate both 3.6 and 3.8 (at least)
return model == List or model.__dict__.get("__origin__") in [list, List]
return (
is_basic_list_type(model)
or model is List
or model.__dict__.get("__origin__") in [list, List]
)


def is_basic_list_type(model):
return model == list
return model is list


def is_list_instance(value):
Expand Down Expand Up @@ -81,12 +97,30 @@ def is_default_value(model_instance, field, field_value):
return True


def str_to_bool(string):
    """Interpret a string as a boolean.

    Only the word "false" (compared case-insensitively, without stripping
    whitespace) maps to ``False``; every other string maps to ``True``.
    """
    return string.lower() != "false"


def get_field_name(string):
    """Strip type annotation and default value from a header string.

    Returns the text before the first type-annotation separator and before
    the first default-value separator, with surrounding whitespace removed.
    """
    before_type, _, _ = string.partition(RowParser.TYPE_ANNOTATION_SEPARATOR)
    name, _, _ = before_type.partition(RowParser.DEFAULT_VALUE_SEPARATOR)
    return name.strip()


class RowParser:
# Takes a dictionary of cell entries, whose keys are the column names
# and the values are the cell content converted into nested lists.
# Turns this into an instance of the provided model.

HEADER_FIELD_SEPARATOR = "."
TYPE_ANNOTATION_SEPARATOR = ":"
DEFAULT_VALUE_SEPARATOR = "="

def __init__(self, model, cell_parser):
self.model = model
Expand All @@ -96,7 +130,7 @@ def __init__(self, model, cell_parser):
def try_assign_as_kwarg(self, field, key, value, model):
# If value can be interpreted as a (field, field_value) pair for a field of
# model, assign value to field[key] (which represents the field in the model)
if type(value) is list and len(value) == 2 and type(value[0]) is str:
if is_list_instance(value) and len(value) == 2 and type(value[0]) is str:
first_entry_as_key = model.header_name_to_field_name(value[0])
if first_entry_as_key in model.__fields__:
self.assign_value(
Expand Down Expand Up @@ -156,17 +190,20 @@ def assign_value(self, field, key, value, model):
entry,
model.__fields__[entry_key].outer_type_,
)

elif is_basic_list_type(model):
# We cannot iterate deeper if we don't know what to expect.
if is_iterable_instance(value):
field[key] = list(value)
else:
field[key] = [value]
elif is_list_type(model):
# Get the type that's inside the list
assert len(model.__args__) == 1
child_model = model.__args__[0]
child_model = get_list_child_model(model)
# The created entry should be a list. Value should also be a list
field[key] = []
# Note: This makes a decision on how to resolve an ambiguity when the target
# field is a list of lists, but the cell value is a 1-dimensional list.
# 1;2 → [[1],[2]] rather than [[1,2]]
if type(value) is not list:
if not is_list_instance(value):
# It could be that a list is specified via a single element.
if value == "":
# Interpret an empty cell as [] rather than ['']
Expand All @@ -178,27 +215,16 @@ def assign_value(self, field, key, value, model):
# recursively
field[key].append(None)
self.assign_value(field[key], -1, entry, child_model)

elif is_basic_list_type(model):
if is_iterable_instance(value):
field[key] = list(value)
else:
field[key] = [value]
else:
assert is_basic_type(model)
# The value should be a basic type
# TODO: Ensure the types match. E.g. we don't want value to be a list
if model == bool:
if type(value) is str:
if value.strip():
# Special case: empty string is not assigned at all;
# in this case, the default value takes effect.
if value.strip().lower() == "false":
field[key] = False
else:
# This is consistent with python: Anything that's
# not '' or explicitly False is True.
field[key] = True
stripped = value.strip()
# Special case: empty string is not assigned at all.
if stripped:
field[key] = str_to_bool(stripped)
else:
field[key] = bool(value)
else:
Expand All @@ -219,10 +245,7 @@ def find_entry(self, model, output_field, field_path):
# It'd be nicer to already have a template.
field_name = field_path[0]
if is_list_type(model):
# Get the type that's inside the list
assert len(model.__args__) == 1
child_model = model.__args__[0]

child_model = get_list_child_model(model)
index = int(field_name) - 1
if len(output_field) <= index:
# Create a new list entry for this, if necessary
Expand All @@ -248,7 +271,7 @@ def find_entry(self, model, output_field, field_path):
output_field[key] = None

if len(field_path) == 1:
# We're reach the end of the field_path
# We've reached the end of the field_path
# Therefore we've found where we need to assign
return output_field, key, child_model
else:
Expand All @@ -266,6 +289,7 @@ def parse_entry(
):
# This creates/populates a field in self.output
# The field is determined by column_name, its value by value
column_name = get_field_name(column_name)
field_path = column_name.split(RowParser.HEADER_FIELD_SEPARATOR)
# Find the destination subfield in self.output that corresponds to field_path
field, key, model = self.find_entry(self.model, self.output, field_path)
Expand All @@ -277,11 +301,7 @@ def parse_entry(
# The model of field[key] is model, and thus value should also be interpreted
# as being of type model.
if not value_is_parsed:
if (
is_basic_list_type(model)
or is_list_type(model)
or is_parser_model_type(model)
):
if is_list_type(model) or is_parser_model_type(model):
# If the expected type of the value is list/object,
# parse the cell content as such.
# Otherwise leave it as a string
Expand Down
30 changes: 15 additions & 15 deletions src/rpft/parsers/creation/contentindexparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from rpft.logger.logger import get_logger, logging_context
from rpft.parsers.common.cellparser import CellParser
from rpft.parsers.common.model_inference import model_from_headers
from rpft.parsers.common.rowparser import RowParser
from rpft.parsers.common.sheetparser import SheetParser
from rpft.parsers.creation.campaigneventrowmodel import CampaignEventRowModel
Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(
self.campaign_parsers = {} # name-indexed dict of CampaignParser
self.trigger_parsers = []

self.user_models_module = None
if user_data_model_module_name:
self.user_models_module = importlib.import_module(
user_data_model_module_name
Expand Down Expand Up @@ -186,11 +188,6 @@ def _get_sheet_or_die(self, sheet_name):
return active

def _process_data_sheet(self, row):
if not hasattr(self, "user_models_module"):
LOGGER.critical(
"If there are data sheets, a user_data_model_module_name "
"has to be provided"
)
sheet_names = row.sheet_name
if row.operation.type in ["filter", "sort"] and len(sheet_names) > 1:
LOGGER.warning(
Expand Down Expand Up @@ -236,19 +233,22 @@ def _get_data_sheet(self, sheet_name, data_model_name):
else:
return self._get_new_data_sheet(sheet_name, data_model_name)

def _get_new_data_sheet(self, sheet_name, data_model_name):
if not data_model_name:
LOGGER.critical("No data_model_name provided for data sheet.")
try:
user_model = getattr(self.user_models_module, data_model_name)
except AttributeError:
LOGGER.critical(
f'Undefined data_model_name "{data_model_name}" '
f"in {self.user_models_module}."
)
def _get_new_data_sheet(self, sheet_name, data_model_name=None):
user_model = None
if self.user_models_module and data_model_name:
try:
user_model = getattr(self.user_models_module, data_model_name)
except AttributeError:
LOGGER.critical(
f'Undefined data_model_name "{data_model_name}" '
f"in {self.user_models_module}."
)
data_table = self._get_sheet_or_die(sheet_name)
with logging_context(sheet_name):
data_table = self._get_sheet_or_die(sheet_name).table
if not user_model:
LOGGER.info("Inferring RowModel automatically")
user_model = model_from_headers(sheet_name, data_table.headers)
row_parser = RowParser(user_model, CellParser())
sheet_parser = SheetParser(row_parser, data_table)
data_rows = sheet_parser.parse_all()
Expand Down
Loading
Loading