From 62e1c46ed763de813c1df287f57c45159acb61f0 Mon Sep 17 00:00:00 2001 From: Ian Stride Date: Thu, 6 Mar 2025 17:24:14 +0000 Subject: [PATCH] Migrate to Pydantic v2 (#174) --- src/rpft/parsers/common/model_inference.py | 2 +- src/rpft/parsers/common/rowparser.py | 21 +++++++------------ .../parsers/creation/campaigneventrowmodel.py | 12 +++++++---- .../parsers/creation/contentindexparser.py | 2 +- src/rpft/parsers/creation/flowparser.py | 10 +++++---- src/rpft/parsers/creation/flowrowmodel.py | 6 ++++++ src/rpft/parsers/creation/triggerrowmodel.py | 14 +++++++------ src/rpft/rapidpro/models/exceptions.py | 2 ++ tests/test_contentindexparser.py | 6 +++--- tests/test_differentways.py | 4 +--- tests/test_flowparser_reverse.py | 4 ++-- tests/test_model_inference.py | 4 ++-- tests/test_rowparser.py | 4 ++-- tests/test_to_row_model.py | 4 ++-- tests/test_unparse.py | 2 +- 15 files changed, 53 insertions(+), 44 deletions(-) diff --git a/src/rpft/parsers/common/model_inference.py b/src/rpft/parsers/common/model_inference.py index 3d6aa36..ae836a9 100644 --- a/src/rpft/parsers/common/model_inference.py +++ b/src/rpft/parsers/common/model_inference.py @@ -1,7 +1,7 @@ from collections import defaultdict from typing import List, ForwardRef, _eval_type from pydoc import locate -from pydantic.v1 import create_model +from pydantic import create_model from rpft.parsers.common.rowparser import ( ParserModel, diff --git a/src/rpft/parsers/common/rowparser.py b/src/rpft/parsers/common/rowparser.py index 18d80bb..3971a69 100644 --- a/src/rpft/parsers/common/rowparser.py +++ b/src/rpft/parsers/common/rowparser.py @@ -3,7 +3,7 @@ from collections.abc import Iterable from typing import List -from pydantic.v1 import BaseModel +from pydantic import BaseModel from rpft.parsers.common.cellparser import CellParser @@ -97,9 +97,7 @@ def is_basic_instance(value): def is_default_value(model_instance, field, field_value): - # Note: In pydantic V2, __fields__ will become model_fields - if field_value == type(model_instance).__fields__[field].get_default(): - return True + return field_value == type(model_instance).model_fields[field].default def str_to_bool(string): @@ -137,12 +135,12 @@ def try_assign_as_kwarg(self, field, key, value, model): # model, assign value to field[key] (which represents the field in the model) if is_list_instance(value) and len(value) == 2 and type(value[0]) is str: first_entry_as_key = model.header_name_to_field_name(value[0]) - if first_entry_as_key in model.__fields__: + if first_entry_as_key in model.model_fields: self.assign_value( field[key], first_entry_as_key, value[1], - model.__fields__[first_entry_as_key].outer_type_, + model.model_fields[first_entry_as_key].annotation, ) return True return False @@ -164,7 +162,7 @@ def assign_value(self, field, key, value, model): # Get the list of keys that are available for the target model # Note: The fields have a well defined ordering. # See https://pydantic-docs.helpmanual.io/usage/models/#field-ordering - model_fields = list(model.__fields__.keys()) + model_fields = list(model.model_fields.keys()) if type(value) is not list: # It could be that an object is specified via a single element. @@ -193,7 +191,7 @@ def assign_value(self, field, key, value, model): field[key], entry_key, entry, - model.__fields__[entry_key].outer_type_, + model.model_fields[entry_key].annotation, ) elif is_basic_dict_type(model): field[key] = {} @@ -281,12 +279,9 @@ def find_entry(self, model, output_field, field_path): else: assert is_parser_model_type(model) key = model.header_name_to_field_name(field_name) - if key not in model.__fields__: + if key not in model.model_fields: raise ValueError(f"Field {key} doesn't exist in target type {model}.") - child_model = model.__fields__[key].outer_type_ - # TODO: how does ModelField.outer_type_ and ModelField.type_ - # deal with nested lists, e.g. List[List[str]]? - # Write test cases and fix code. + child_model = model.model_fields[key].annotation if key not in output_field: # Create a new entry for this, if necessary diff --git a/src/rpft/parsers/creation/campaigneventrowmodel.py b/src/rpft/parsers/creation/campaigneventrowmodel.py index e48c955..216caec 100644 --- a/src/rpft/parsers/creation/campaigneventrowmodel.py +++ b/src/rpft/parsers/creation/campaigneventrowmodel.py @@ -1,5 +1,6 @@ +from pydantic import field_validator + from rpft.parsers.common.rowparser import ParserModel -from pydantic.v1 import validator class CampaignEventRowModel(ParserModel): @@ -14,13 +15,15 @@ class CampaignEventRowModel(ParserModel): flow: str = "" base_language: str = "" - @validator("unit") + @field_validator("unit") + @classmethod def validate_unit(cls, v): if v not in ["M", "H", "D", "W"]: raise ValueError("unit must be M (minute), H (hour), D (day) or W (week)") return v - @validator("start_mode") + @field_validator("start_mode") + @classmethod def validate_start_mode(cls, v): if v not in ["I", "S", "P"]: raise ValueError( @@ -29,7 +32,8 @@ def validate_start_mode(cls, v): ) return v - @validator("event_type") + @field_validator("event_type") + @classmethod def validate_event_type(cls, v): if v not in ["M", "F"]: raise ValueError("event_type must be F (flow) or M (message)") diff --git a/src/rpft/parsers/creation/contentindexparser.py b/src/rpft/parsers/creation/contentindexparser.py index efcc418..c4440b9 100644 --- a/src/rpft/parsers/creation/contentindexparser.py +++ b/src/rpft/parsers/creation/contentindexparser.py @@ -36,7 +36,7 @@ def __init__(self, rows, row_model): def to_dict(self): return { "model": self.row_model.__name__, - "rows": [content.dict() for content in self.rows.values()], + "rows": [content.model_dump() for content in self.rows.values()], } diff --git a/src/rpft/parsers/creation/flowparser.py b/src/rpft/parsers/creation/flowparser.py index 970e64a..9265c16 100644 --- a/src/rpft/parsers/creation/flowparser.py +++ b/src/rpft/parsers/creation/flowparser.py @@ -514,7 +514,7 @@ def _get_row_action(self, row): ) elif row.type == "remove_from_group": if not row.mainarg_groups: - LOGGER.warning(f"Removing contact from ALL groups.") + LOGGER.warning("Removing contact from ALL groups.") return RemoveContactGroupAction(groups=[], all_groups=True) elif row.mainarg_groups[0] == "ALL": return RemoveContactGroupAction(groups=[], all_groups=True) @@ -559,7 +559,8 @@ def _get_or_create_group(self, name, uuid=None): def _get_row_node(self, row): if ( row.type in ["add_to_group", "remove_from_group", "split_by_group"] - and row.obj_id and row.mainarg_groups + and row.obj_id + and row.mainarg_groups ): self.rapidpro_container.record_group_uuid(row.mainarg_groups[0], row.obj_id) @@ -804,8 +805,9 @@ def _compile_flow(self): to fill in these missing UUIDs in a consistent way. """ - # Caveat/TODO: Need to ensure starting node comes first. - flow_container = FlowContainer(flow_name=self.flow_name, uuid=self.flow_uuid, type=self.flow_type) + flow_container = FlowContainer( + flow_name=self.flow_name, uuid=self.flow_uuid, type=self.flow_type + ) if not len(self.node_group_stack) == 1: raise Exception("Unexpected end of flow. Did you forget end_for/end_block?") self.current_node_group().add_nodes_to_flow(flow_container) diff --git a/src/rpft/parsers/creation/flowrowmodel.py b/src/rpft/parsers/creation/flowrowmodel.py index 38f66aa..baaaadc 100644 --- a/src/rpft/parsers/creation/flowrowmodel.py +++ b/src/rpft/parsers/creation/flowrowmodel.py @@ -1,3 +1,5 @@ +from pydantic import ConfigDict + from rpft.parsers.common.rowparser import ParserModel from rpft.parsers.creation.models import Condition @@ -45,6 +47,8 @@ class WhatsAppTemplating(ParserModel): class Edge(ParserModel): + model_config = ConfigDict(coerce_numbers_to_str=True) + from_: str = "" condition: Condition = Condition() @@ -65,6 +69,8 @@ def header_name_to_field_name_with_context(header, row): class FlowRowModel(ParserModel): + model_config = ConfigDict(coerce_numbers_to_str=True) + row_id: str = "" type: str edges: list[Edge] diff --git a/src/rpft/parsers/creation/triggerrowmodel.py b/src/rpft/parsers/creation/triggerrowmodel.py index f8a2ed9..0ff6b83 100644 --- a/src/rpft/parsers/creation/triggerrowmodel.py +++ b/src/rpft/parsers/creation/triggerrowmodel.py @@ -1,4 +1,4 @@ -from pydantic.v1 import validator +from pydantic import field_validator, model_validator from rpft.parsers.common.rowparser import ParserModel @@ -12,7 +12,8 @@ class TriggerRowModel(ParserModel): channel: str = "" match_type: str = "" - @validator("type") + @field_validator("type") + @classmethod def validate_type(cls, v): if v not in ["K", "C", "M", "T"]: raise ValueError( @@ -21,10 +22,11 @@ def validate_type(cls, v): ) return v - @validator("match_type") - def validate_match_type(cls, v, values): - if values["type"] == "K" and v not in ["F", "O", ""]: + @model_validator(mode="after") + def validate_match_type(self): + if self.type == "K" and self.match_type not in ["F", "O", ""]: raise ValueError( 'match_type must be "F" (starts with) or "O" (only) if type is "K".' ) - return v + + return self diff --git a/src/rpft/rapidpro/models/exceptions.py b/src/rpft/rapidpro/models/exceptions.py index abbacff..c859c34 100644 --- a/src/rpft/rapidpro/models/exceptions.py +++ b/src/rpft/rapidpro/models/exceptions.py @@ -1,8 +1,10 @@ class RapidProActionError(Exception): "Raised if some parameter of a RapidProAction is invalid." + pass class RapidProRouterError(Exception): "Raised if some parameter of a RapidProRouter is invalid." + pass diff --git a/tests/test_contentindexparser.py b/tests/test_contentindexparser.py index 908d7a5..e2bb72e 100644 --- a/tests/test_contentindexparser.py +++ b/tests/test_contentindexparser.py @@ -137,9 +137,9 @@ def test_basic_user_model(self): def test_flow_type(self): ci_sheet = ( - "type,sheet_name,data_sheet,data_row_id,new_name,data_model,options\n" - "create_flow,my_basic_flow,,,,,\n" - "create_flow,my_basic_flow,,,my_other_flow,,flow_type;messaging_background\n" + "type,sheet_name,new_name,data_model,options\n" + "create_flow,my_basic_flow,,,\n" + "create_flow,my_basic_flow,my_other_flow,,flow_type;messaging_background\n" ) my_basic_flow = csv_join( "row_id,type,from,message_text", diff --git a/tests/test_differentways.py b/tests/test_differentways.py index b65eccd..40fa579 100644 --- a/tests/test_differentways.py +++ b/tests/test_differentways.py @@ -124,11 +124,9 @@ def test_different_ways(self): for inp in inputs: out = self.parser.parse_row(inp) # We get an instance of the model outputs.append(out) - # Note: we can also serialize via out.json(indent=4) for printing - # or out.dict() for out in outputs: - self.assertEqual(out.dict(), output_instance) + self.assertEqual(out.model_dump(), output_instance) def test_single_kwarg(self): output_single_kwarg = self.parser.parse_row(input_single_kwarg) diff --git a/tests/test_flowparser_reverse.py b/tests/test_flowparser_reverse.py index 8a42f3d..ba04533 100644 --- a/tests/test_flowparser_reverse.py +++ b/tests/test_flowparser_reverse.py @@ -10,8 +10,8 @@ def setUp(self) -> None: pass def compare_models_leniently(self, model, model_exp): - model_dict = model.dict() - model_exp_dict = model_exp.dict() + model_dict = model.model_dump() + model_exp_dict = model_exp.model_dump() for edge, edge_exp in zip(model_dict["edges"], model_exp_dict["edges"]): if ( edge["condition"]["variable"] == "@input.text" diff --git a/tests/test_model_inference.py b/tests/test_model_inference.py index d677583..fa7a85b 100644 --- a/tests/test_model_inference.py +++ b/tests/test_model_inference.py @@ -1,6 +1,6 @@ from typing import List import unittest -from pydantic.v1 import create_model, BaseModel +from pydantic import create_model, BaseModel from rpft.parsers.common.model_inference import ( get_value_for_type, @@ -60,7 +60,7 @@ def test_parse_header_annotations(self): self.assertEqual(parse_header_annotations("field:int=5"), (int, 5)) def compare_models(self, model1, model2, **kwargs): - self.assertEqual(model1(**kwargs).dict(), model2(**kwargs).dict()) + self.assertEqual(model1(**kwargs).model_dump(), model2(**kwargs).model_dump()) def test_model_from_headers(self): self.compare_models( diff --git a/tests/test_rowparser.py b/tests/test_rowparser.py index ce0b8e8..1c35596 100644 --- a/tests/test_rowparser.py +++ b/tests/test_rowparser.py @@ -173,7 +173,7 @@ def test_convert_empty(self): self.assertEqual(out, self.emptyModel) def test_convert_single_element(self): - self.oneModel = DictModel(**{"dict_field": {"K" : "V"}}) + self.oneModel = DictModel(**{"dict_field": {"K": "V"}}) inputs = [ {"dict_field": ["K", "V"]}, {"dict_field": [["K", "V"]]}, @@ -184,7 +184,7 @@ def test_convert_single_element(self): self.assertEqual(out, self.oneModel) def test_convert_two_element(self): - self.onetwoModel = DictModel(**{"dict_field": {"K1" : "V1", "K2" : "V2"}}) + self.onetwoModel = DictModel(**{"dict_field": {"K1": "V1", "K2": "V2"}}) inputs = [ {"dict_field": [["K1", "V1"], ["K2", "V2"]]}, {"dict_field.K1": "V1", "dict_field.K2": "V2"}, diff --git a/tests/test_to_row_model.py b/tests/test_to_row_model.py index e0b2c5d..bdd020e 100644 --- a/tests/test_to_row_model.py +++ b/tests/test_to_row_model.py @@ -44,8 +44,8 @@ def compare_row_models_without_uuid(self, row_models1, row_models2): self.maxDiff = None self.assertEqual(len(row_models1), len(row_models2)) for model1, model2 in zip(row_models1, row_models2): - data1 = model1.dict() - data2 = model2.dict() + data1 = model1.model_dump() + data2 = model2.model_dump() if not data1["node_uuid"] or not data2["node_uuid"]: # If one of them is blank, skip the comparison data1.pop("node_uuid") diff --git a/tests/test_unparse.py b/tests/test_unparse.py index ad1b605..8291749 100644 --- a/tests/test_unparse.py +++ b/tests/test_unparse.py @@ -15,7 +15,7 @@ class ModelWithStuff(ParserModel): class MainModel(ParserModel): str_field: str = "" - model_optional: Optional[ModelWithStuff] + model_optional: Optional[ModelWithStuff] = None model_default: ModelWithStuff = ModelWithStuff() model_list: List[ModelWithStuff] = []