Skip to content

Commit

Permalink
Migrate to Pydantic v2 (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
istride authored Mar 6, 2025
1 parent bfdc04d commit 62e1c46
Show file tree
Hide file tree
Showing 15 changed files with 53 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/rpft/parsers/common/model_inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from typing import List, ForwardRef, _eval_type
from pydoc import locate
from pydantic.v1 import create_model
from pydantic import create_model

from rpft.parsers.common.rowparser import (
ParserModel,
Expand Down
21 changes: 8 additions & 13 deletions src/rpft/parsers/common/rowparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Iterable
from typing import List

from pydantic.v1 import BaseModel
from pydantic import BaseModel

from rpft.parsers.common.cellparser import CellParser

Expand Down Expand Up @@ -97,9 +97,7 @@ def is_basic_instance(value):


def is_default_value(model_instance, field, field_value):
# Note: In pydantic V2, __fields__ will become model_fields
if field_value == type(model_instance).__fields__[field].get_default():
return True
return field_value == type(model_instance).model_fields[field].default


def str_to_bool(string):
Expand Down Expand Up @@ -137,12 +135,12 @@ def try_assign_as_kwarg(self, field, key, value, model):
# model, assign value to field[key] (which represents the field in the model)
if is_list_instance(value) and len(value) == 2 and type(value[0]) is str:
first_entry_as_key = model.header_name_to_field_name(value[0])
if first_entry_as_key in model.__fields__:
if first_entry_as_key in model.model_fields:
self.assign_value(
field[key],
first_entry_as_key,
value[1],
model.__fields__[first_entry_as_key].outer_type_,
model.model_fields[first_entry_as_key].annotation,
)
return True
return False
Expand All @@ -164,7 +162,7 @@ def assign_value(self, field, key, value, model):
# Get the list of keys that are available for the target model
# Note: The fields have a well defined ordering.
# See https://pydantic-docs.helpmanual.io/usage/models/#field-ordering
model_fields = list(model.__fields__.keys())
model_fields = list(model.model_fields.keys())

if type(value) is not list:
# It could be that an object is specified via a single element.
Expand Down Expand Up @@ -193,7 +191,7 @@ def assign_value(self, field, key, value, model):
field[key],
entry_key,
entry,
model.__fields__[entry_key].outer_type_,
model.model_fields[entry_key].annotation,
)
elif is_basic_dict_type(model):
field[key] = {}
Expand Down Expand Up @@ -281,12 +279,9 @@ def find_entry(self, model, output_field, field_path):
else:
assert is_parser_model_type(model)
key = model.header_name_to_field_name(field_name)
if key not in model.__fields__:
if key not in model.model_fields:
raise ValueError(f"Field {key} doesn't exist in target type {model}.")
child_model = model.__fields__[key].outer_type_
# TODO: how does ModelField.outer_type_ and ModelField.type_
# deal with nested lists, e.g. List[List[str]]?
# Write test cases and fix code.
child_model = model.model_fields[key].annotation

if key not in output_field:
# Create a new entry for this, if necessary
Expand Down
12 changes: 8 additions & 4 deletions src/rpft/parsers/creation/campaigneventrowmodel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import field_validator

from rpft.parsers.common.rowparser import ParserModel
from pydantic.v1 import validator


class CampaignEventRowModel(ParserModel):
Expand All @@ -14,13 +15,15 @@ class CampaignEventRowModel(ParserModel):
flow: str = ""
base_language: str = ""

@validator("unit")
@field_validator("unit")
@classmethod
def validate_unit(cls, v):
if v not in ["M", "H", "D", "W"]:
raise ValueError("unit must be M (minute), H (hour), D (day) or W (week)")
return v

@validator("start_mode")
@field_validator("start_mode")
@classmethod
def validate_start_mode(cls, v):
if v not in ["I", "S", "P"]:
raise ValueError(
Expand All @@ -29,7 +32,8 @@ def validate_start_mode(cls, v):
)
return v

@validator("event_type")
@field_validator("event_type")
@classmethod
def validate_event_type(cls, v):
if v not in ["M", "F"]:
raise ValueError("event_type must be F (flow) or M (message)")
Expand Down
2 changes: 1 addition & 1 deletion src/rpft/parsers/creation/contentindexparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, rows, row_model):
def to_dict(self):
return {
"model": self.row_model.__name__,
"rows": [content.dict() for content in self.rows.values()],
"rows": [content.model_dump() for content in self.rows.values()],
}


Expand Down
10 changes: 6 additions & 4 deletions src/rpft/parsers/creation/flowparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def _get_row_action(self, row):
)
elif row.type == "remove_from_group":
if not row.mainarg_groups:
LOGGER.warning(f"Removing contact from ALL groups.")
LOGGER.warning("Removing contact from ALL groups.")
return RemoveContactGroupAction(groups=[], all_groups=True)
elif row.mainarg_groups[0] == "ALL":
return RemoveContactGroupAction(groups=[], all_groups=True)
Expand Down Expand Up @@ -559,7 +559,8 @@ def _get_or_create_group(self, name, uuid=None):
def _get_row_node(self, row):
if (
row.type in ["add_to_group", "remove_from_group", "split_by_group"]
and row.obj_id and row.mainarg_groups
and row.obj_id
and row.mainarg_groups
):
self.rapidpro_container.record_group_uuid(row.mainarg_groups[0], row.obj_id)

Expand Down Expand Up @@ -804,8 +805,9 @@ def _compile_flow(self):
to fill in these missing UUIDs in a consistent way.
"""

# Caveat/TODO: Need to ensure starting node comes first.
flow_container = FlowContainer(flow_name=self.flow_name, uuid=self.flow_uuid, type=self.flow_type)
flow_container = FlowContainer(
flow_name=self.flow_name, uuid=self.flow_uuid, type=self.flow_type
)
if not len(self.node_group_stack) == 1:
raise Exception("Unexpected end of flow. Did you forget end_for/end_block?")
self.current_node_group().add_nodes_to_flow(flow_container)
Expand Down
6 changes: 6 additions & 0 deletions src/rpft/parsers/creation/flowrowmodel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pydantic import ConfigDict

from rpft.parsers.common.rowparser import ParserModel
from rpft.parsers.creation.models import Condition

Expand Down Expand Up @@ -45,6 +47,8 @@ class WhatsAppTemplating(ParserModel):


class Edge(ParserModel):
model_config = ConfigDict(coerce_numbers_to_str=True)

from_: str = ""
condition: Condition = Condition()

Expand All @@ -65,6 +69,8 @@ def header_name_to_field_name_with_context(header, row):


class FlowRowModel(ParserModel):
model_config = ConfigDict(coerce_numbers_to_str=True)

row_id: str = ""
type: str
edges: list[Edge]
Expand Down
14 changes: 8 additions & 6 deletions src/rpft/parsers/creation/triggerrowmodel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic.v1 import validator
from pydantic import field_validator, model_validator

from rpft.parsers.common.rowparser import ParserModel

Expand All @@ -12,7 +12,8 @@ class TriggerRowModel(ParserModel):
channel: str = ""
match_type: str = ""

@validator("type")
@field_validator("type")
@classmethod
def validate_type(cls, v):
if v not in ["K", "C", "M", "T"]:
raise ValueError(
Expand All @@ -21,10 +22,11 @@ def validate_type(cls, v):
)
return v

@validator("match_type")
def validate_match_type(cls, v, values):
if values["type"] == "K" and v not in ["F", "O", ""]:
@model_validator(mode="after")
def validate_match_type(self):
if self.type == "K" and self.match_type not in ["F", "O", ""]:
raise ValueError(
'match_type must be "F" (starts with) or "O" (only) if type is "K".'
)
return v

return self
2 changes: 2 additions & 0 deletions src/rpft/rapidpro/models/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
class RapidProActionError(Exception):
"Raised if some parameter of a RapidProAction is invalid."

pass


class RapidProRouterError(Exception):
"Raised if some parameter of a RapidProRouter is invalid."

pass
6 changes: 3 additions & 3 deletions tests/test_contentindexparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def test_basic_user_model(self):

def test_flow_type(self):
ci_sheet = (
"type,sheet_name,data_sheet,data_row_id,new_name,data_model,options\n"
"create_flow,my_basic_flow,,,,,\n"
"create_flow,my_basic_flow,,,my_other_flow,,flow_type;messaging_background\n"
"type,sheet_name,new_name,data_model,options\n"
"create_flow,my_basic_flow,,,\n"
"create_flow,my_basic_flow,my_other_flow,,flow_type;messaging_background\n"
)
my_basic_flow = csv_join(
"row_id,type,from,message_text",
Expand Down
4 changes: 1 addition & 3 deletions tests/test_differentways.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,9 @@ def test_different_ways(self):
for inp in inputs:
out = self.parser.parse_row(inp) # We get an instance of the model
outputs.append(out)
# Note: we can also serialize via out.json(indent=4) for printing
# or out.dict()

for out in outputs:
self.assertEqual(out.dict(), output_instance)
self.assertEqual(out.model_dump(), output_instance)

def test_single_kwarg(self):
output_single_kwarg = self.parser.parse_row(input_single_kwarg)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_flowparser_reverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def setUp(self) -> None:
pass

def compare_models_leniently(self, model, model_exp):
model_dict = model.dict()
model_exp_dict = model_exp.dict()
model_dict = model.model_dump()
model_exp_dict = model_exp.model_dump()
for edge, edge_exp in zip(model_dict["edges"], model_exp_dict["edges"]):
if (
edge["condition"]["variable"] == "@input.text"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_model_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List
import unittest
from pydantic.v1 import create_model, BaseModel
from pydantic import create_model, BaseModel

from rpft.parsers.common.model_inference import (
get_value_for_type,
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_parse_header_annotations(self):
self.assertEqual(parse_header_annotations("field:int=5"), (int, 5))

def compare_models(self, model1, model2, **kwargs):
self.assertEqual(model1(**kwargs).dict(), model2(**kwargs).dict())
self.assertEqual(model1(**kwargs).model_dump(), model2(**kwargs).model_dump())

def test_model_from_headers(self):
self.compare_models(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_rowparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_convert_empty(self):
self.assertEqual(out, self.emptyModel)

def test_convert_single_element(self):
self.oneModel = DictModel(**{"dict_field": {"K" : "V"}})
self.oneModel = DictModel(**{"dict_field": {"K": "V"}})
inputs = [
{"dict_field": ["K", "V"]},
{"dict_field": [["K", "V"]]},
Expand All @@ -184,7 +184,7 @@ def test_convert_single_element(self):
self.assertEqual(out, self.oneModel)

def test_convert_two_element(self):
self.onetwoModel = DictModel(**{"dict_field": {"K1" : "V1", "K2" : "V2"}})
self.onetwoModel = DictModel(**{"dict_field": {"K1": "V1", "K2": "V2"}})
inputs = [
{"dict_field": [["K1", "V1"], ["K2", "V2"]]},
{"dict_field.K1": "V1", "dict_field.K2": "V2"},
Expand Down
4 changes: 2 additions & 2 deletions tests/test_to_row_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def compare_row_models_without_uuid(self, row_models1, row_models2):
self.maxDiff = None
self.assertEqual(len(row_models1), len(row_models2))
for model1, model2 in zip(row_models1, row_models2):
data1 = model1.dict()
data2 = model2.dict()
data1 = model1.model_dump()
data2 = model2.model_dump()
if not data1["node_uuid"] or not data2["node_uuid"]:
# If one of them is blank, skip the comparison
data1.pop("node_uuid")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_unparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ModelWithStuff(ParserModel):

class MainModel(ParserModel):
str_field: str = ""
model_optional: Optional[ModelWithStuff]
model_optional: Optional[ModelWithStuff] = None
model_default: ModelWithStuff = ModelWithStuff()
model_list: List[ModelWithStuff] = []

Expand Down

0 comments on commit 62e1c46

Please sign in to comment.