Skip to content

Commit 8806311

Browse files
authored
Merge pull request #412 from transform-data/qmalcolm--improve-transformations-workflow
Improve Terminology and Implementation of `TransformModel`
2 parents 24d5310 + 453238d commit 8806311

15 files changed

+142
-116
lines changed

metricflow/model/model_transformer.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import logging
33

4-
from typing import Sequence
4+
from typing import Sequence, Tuple
55

66
from metricflow.model.objects.user_configured_model import UserConfiguredModel
77
from metricflow.model.transformations.add_input_metric_measures import AddInputMetricMeasuresRule
@@ -23,12 +23,12 @@ class ModelTransformer:
2323
Generally used to make it more convenient for the user to develop their model.
2424
"""
2525

26-
DEFAULT_PRE_VALIDATION_RULES: Sequence[ModelTransformRule] = (
26+
PRIMARY_RULES: Sequence[ModelTransformRule] = (
2727
LowerCaseNamesRule(),
2828
SetMeasureAggregationTimeDimensionRule(),
2929
)
3030

31-
DEFAULT_POST_VALIDATION_RULES: Sequence[ModelTransformRule] = (
31+
SECONDARY_RULES: Sequence[ModelTransformRule] = (
3232
CreateProxyMeasureRule(),
3333
BooleanMeasureAggregationRule(),
3434
CompositeIdentifierExpressionRule(),
@@ -37,25 +37,49 @@ class ModelTransformer:
3737
AddInputMetricMeasuresRule(),
3838
)
3939

40+
DEFAULT_RULES: Tuple[Sequence[ModelTransformRule], ...] = (
41+
PRIMARY_RULES,
42+
SECONDARY_RULES,
43+
)
44+
4045
@staticmethod
41-
def pre_validation_transform_model(
42-
model: UserConfiguredModel, rules: Sequence[ModelTransformRule] = DEFAULT_PRE_VALIDATION_RULES
46+
def transform(
47+
model: UserConfiguredModel,
48+
ordered_rule_sequences: Tuple[Sequence[ModelTransformRule], ...] = DEFAULT_RULES,
4349
) -> UserConfiguredModel:
44-
"""Transform a model according to configured rules before validations are run."""
50+
"""Copies the passed-in model, applies the rules to the new model, and then returns that model.
51+
52+
It's important to note that some rules need to happen before or after other rules. Thus rules
53+
are passed in as an ordered tuple of rule sequences. Primary rules are run first, and then
54+
secondary rules. We don't currently have tertiary, quaternary, etc. rule sequences, but this
55+
system easily allows for it.
56+
"""
4557
model_copy = copy.deepcopy(model)
4658

47-
for transform_rule in rules:
48-
model_copy = transform_rule.transform_model(model_copy)
59+
for rule_sequence in ordered_rule_sequences:
60+
for rule in rule_sequence:
61+
model_copy = rule.transform_model(model_copy)
4962

5063
return model_copy
5164

65+
@staticmethod
66+
def pre_validation_transform_model(
67+
model: UserConfiguredModel, rules: Sequence[ModelTransformRule] = PRIMARY_RULES
68+
) -> UserConfiguredModel:
69+
"""Transform a model according to configured rules before validations are run."""
70+
logger.warning(
71+
"DEPRECATION: `ModelTransformer.pre_validation_transform_model` is deprecated. Please use `ModelTransformer.transform` instead."
72+
)
73+
74+
return ModelTransformer.transform(model=model, ordered_rule_sequences=(rules,))
75+
5276
@staticmethod
5377
def post_validation_transform_model(
54-
model: UserConfiguredModel, rules: Sequence[ModelTransformRule] = DEFAULT_POST_VALIDATION_RULES
78+
model: UserConfiguredModel, rules: Sequence[ModelTransformRule] = SECONDARY_RULES
5579
) -> UserConfiguredModel:
5680
"""Transform a model according to configured rules after validations are run."""
57-
model_copy = copy.deepcopy(model)
58-
for transform_rule in rules:
59-
model_copy = transform_rule.transform_model(model_copy)
81+
logger.warning(
82+
"DEPRECATION: `ModelTransformer.post_validation_transform_model` is deprecated. Please use `ModelTransformer.transform` instead."
83+
)
6084

61-
return model_copy
85+
return ModelTransformer.transform(model=model, ordered_rule_sequences=(rules,))

metricflow/model/parsing/dbt_cloud_to_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def get_dbt_cloud_metrics(auth: str, job_id: str) -> list[MetricNode]:
4040
def parse_dbt_cloud_metrics_to_model(dbt_metrics: List[MetricNode]) -> ModelBuildResult:
4141
"""Builds a UserConfiguredModel from a list of dbt cloud MetricNodes"""
4242
build_result = DbtConverter().convert(dbt_metrics=tuple(dbt_metrics))
43-
transformed_model = ModelTransformer.pre_validation_transform_model(model=build_result.model)
44-
transformed_model = ModelTransformer.post_validation_transform_model(model=transformed_model)
43+
transformed_model = ModelTransformer.transform(model=build_result.model)
4544
return ModelBuildResult(model=transformed_model, issues=build_result.issues)
4645

4746

metricflow/model/parsing/dbt_dir_to_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,5 @@ def parse_dbt_project_to_model(
4040
"""Parse dbt model files in the given directory to a UserConfiguredModel."""
4141
manifest = get_dbt_project_manifest(directory=directory, profile=profile, target=target)
4242
build_result = DbtManifestTransformer(manifest=manifest).build_user_configured_model()
43-
transformed_model = ModelTransformer.pre_validation_transform_model(model=build_result.model)
44-
transformed_model = ModelTransformer.post_validation_transform_model(model=transformed_model)
43+
transformed_model = ModelTransformer.transform(model=build_result.model)
4544
return ModelBuildResult(model=transformed_model, issues=build_result.issues)

metricflow/model/parsing/dir_to_model.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ def collect_yaml_config_file_paths(directory: str) -> List[str]:
104104
def parse_directory_of_yaml_files_to_model(
105105
directory: str,
106106
template_mapping: Optional[Dict[str, str]] = None,
107-
apply_pre_transformations: Optional[bool] = True,
108-
apply_post_transformations: Optional[bool] = True,
107+
apply_transformations: Optional[bool] = True,
109108
raise_issues_as_exceptions: bool = True,
110109
) -> ModelBuildResult:
111110
"""Parse files in the given directory to a UserConfiguredModel.
@@ -116,17 +115,15 @@ def parse_directory_of_yaml_files_to_model(
116115
return parse_yaml_file_paths_to_model(
117116
file_paths=file_paths,
118117
template_mapping=template_mapping,
119-
apply_pre_transformations=apply_pre_transformations,
120-
apply_post_transformations=apply_post_transformations,
118+
apply_transformations=apply_transformations,
121119
raise_issues_as_exceptions=raise_issues_as_exceptions,
122120
)
123121

124122

125123
def parse_yaml_file_paths_to_model(
126124
file_paths: List[str],
127125
template_mapping: Optional[Dict[str, str]] = None,
128-
apply_pre_transformations: Optional[bool] = True,
129-
apply_post_transformations: Optional[bool] = True,
126+
apply_transformations: Optional[bool] = True,
130127
raise_issues_as_exceptions: bool = True,
131128
) -> ModelBuildResult:
132129
"""Parse files in the given list of file paths to a UserConfiguredModel.
@@ -157,16 +154,14 @@ def parse_yaml_file_paths_to_model(
157154

158155
return parse_yaml_files_to_validation_ready_model(
159156
yaml_config_files=yaml_config_files,
160-
apply_pre_transformations=apply_pre_transformations,
161-
apply_post_transformations=apply_post_transformations,
157+
apply_transformations=apply_transformations,
162158
raise_issues_as_exceptions=raise_issues_as_exceptions,
163159
)
164160

165161

166162
def parse_yaml_files_to_validation_ready_model(
167163
yaml_config_files: List[YamlConfigFile],
168-
apply_pre_transformations: Optional[bool] = True,
169-
apply_post_transformations: Optional[bool] = True,
164+
apply_transformations: Optional[bool] = True,
170165
raise_issues_as_exceptions: bool = True,
171166
) -> ModelBuildResult:
172167
"""Parse and transform the given set of in-memory YamlConfigFiles to a UserConfigured model
@@ -182,11 +177,8 @@ def parse_yaml_files_to_validation_ready_model(
182177

183178
build_issues = build_result.issues
184179
try:
185-
if apply_pre_transformations:
186-
model = ModelTransformer.pre_validation_transform_model(model)
187-
188-
if apply_post_transformations:
189-
model = ModelTransformer.post_validation_transform_model(model)
180+
if apply_transformations:
181+
model = ModelTransformer.transform(model)
190182
except Exception as e:
191183
transformation_issue_results = ModelValidationResults(errors=[ValidationError(message=str(e))])
192184
build_issues = ModelValidationResults.merge([build_issues, transformation_issue_results])

metricflow/test/fixtures/model_fixtures.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from metricflow.dataflow.builder.source_node import SourceNodeBuilder
1212
from metricflow.dataflow.dataflow_plan import ReadSqlSourceNode, BaseOutput
1313
from metricflow.dataset.convert_data_source import DataSourceToDataSetConverter
14+
from metricflow.model.model_transformer import ModelTransformer
1415
from metricflow.model.model_validator import ModelValidator
1516
from metricflow.model.objects.data_source import DataSource
1617
from metricflow.model.objects.user_configured_model import UserConfiguredModel
@@ -252,16 +253,18 @@ def simple_user_configured_model(template_mapping: Dict[str, str]) -> UserConfig
252253

253254

254255
@pytest.fixture(scope="session")
255-
def simple_model__pre_transforms(template_mapping: Dict[str, str]) -> UserConfiguredModel:
256+
def simple_model__with_primary_transforms(template_mapping: Dict[str, str]) -> UserConfiguredModel:
256257
"""Model used for tests, with only the primary transformation rules applied."""
257258

258259
model_build_result = parse_directory_of_yaml_files_to_model(
259260
os.path.join(os.path.dirname(__file__), "model_yamls/simple_model"),
260261
template_mapping=template_mapping,
261-
apply_pre_transformations=True,
262-
apply_post_transformations=False,
262+
apply_transformations=False,
263263
)
264-
return model_build_result.model
264+
transformed_model = ModelTransformer.transform(
265+
model=model_build_result.model, ordered_rule_sequences=(ModelTransformer.PRIMARY_RULES,)
266+
)
267+
return transformed_model
265268

266269

267270
@pytest.fixture(scope="session")

metricflow/test/model/transformations/test_configurable_transform_rules.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,13 @@ def transform_model(model: UserConfiguredModel) -> UserConfiguredModel: # noqa:
1616
return model
1717

1818

19-
def test_can_configure_model_transform_rules(simple_model__pre_transforms: UserConfiguredModel) -> None: # noqa: D
20-
pre_model = simple_model__pre_transforms
19+
def test_can_configure_model_transform_rules( # noqa: D
20+
simple_model__with_primary_transforms: UserConfiguredModel,
21+
) -> None:
22+
pre_model = simple_model__with_primary_transforms
2123
assert not all(len(x.name) == 3 for x in pre_model.data_sources)
2224

23-
# Confirms that a custom transformation works for pre-validation transform
24-
pre_model = ModelTransformer.pre_validation_transform_model(pre_model, rules=[SliceNamesRule()])
25-
assert all(len(x.name) == 3 for x in pre_model.data_sources)
26-
27-
post_model = simple_model__pre_transforms
28-
assert not all(len(x.name) == 3 for x in post_model.data_sources)
29-
30-
# Confirms that a custom transformation works for post-validation transform
31-
post_model = ModelTransformer.post_validation_transform_model(post_model, rules=[SliceNamesRule()])
32-
assert all(len(x.name) == 3 for x in post_model.data_sources)
25+
# Confirms that a custom transformation works for `ModelTransformer.transform`
26+
rules = [SliceNamesRule()]
27+
transformed_model = ModelTransformer.transform(pre_model, ordered_rule_sequences=(rules,))
28+
assert all(len(x.name) == 3 for x in transformed_model.data_sources)

metricflow/test/model/validations/test_common_identifiers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313

1414
@pytest.mark.skip("TODO: re-enforce after validations improvements")
15-
def test_lonely_identifier_raises_issue(simple_model__pre_transforms: UserConfiguredModel) -> None: # noqa: D
16-
model = copy.deepcopy(simple_model__pre_transforms)
15+
def test_lonely_identifier_raises_issue(simple_model__with_primary_transforms: UserConfiguredModel) -> None: # noqa: D
16+
model = copy.deepcopy(simple_model__with_primary_transforms)
1717
lonely_identifier_name = "hi_im_lonely"
1818

1919
func: Callable[[DataSource], bool] = lambda data_source: len(data_source.identifiers) > 0

metricflow/test/model/validations/test_configurable_rules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
from metricflow.test.test_utils import model_with_materialization
99

1010

11-
def test_can_configure_model_validator_rules(simple_model__pre_transforms: UserConfiguredModel) -> None: # noqa: D
11+
def test_can_configure_model_validator_rules( # noqa: D
12+
simple_model__with_primary_transforms: UserConfiguredModel,
13+
) -> None:
1214
model = model_with_materialization(
13-
simple_model__pre_transforms,
15+
simple_model__with_primary_transforms,
1416
[
1517
materialization_with_guaranteed_meta(
1618
name="foobar",

metricflow/test/model/validations/test_data_warehouse_tasks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,7 @@ def test_validate_metrics( # noqa: D
280280
)
281281
model.data_sources[0].measures = new_measures
282282
model.metrics = []
283-
model = ModelTransformer.pre_validation_transform_model(model)
284-
model = ModelTransformer.post_validation_transform_model(model)
283+
model = ModelTransformer.transform(model)
285284

286285
# Validate new metric created by proxy causes an issue (because the column used doesn't exist)
287286
dw_validator = DataWarehouseModelValidator(

metricflow/test/model/validations/test_element_const.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def _categorical_dimensions(data_source: DataSource) -> Tuple[Dimension, ...]:
1515
return tuple(dim for dim in data_source.dimensions if dim.type == DimensionType.CATEGORICAL)
1616

1717

18-
def test_cross_element_names(simple_model__pre_transforms: UserConfiguredModel) -> None: # noqa:D
19-
model = copy.deepcopy(simple_model__pre_transforms)
18+
def test_cross_element_names(simple_model__with_primary_transforms: UserConfiguredModel) -> None: # noqa:D
19+
model = copy.deepcopy(simple_model__with_primary_transforms)
2020

2121
# ensure we have a usable data source for the test
2222
usable_ds, usable_ds_index = find_data_source_with(

0 commit comments

Comments
 (0)