11import unittest
2- from enum import Enum
32from typing import Dict , List , Optional , Union
43
54import pytest
6- from pydantic import BaseModel , Field , ValidationError
5+ from pydantic import BaseModel , ConfigDict , Field , ValidationError
76
87from datahub .ingestion .source .airbyte .models import (
98 AirbyteConnection ,
109 AirbyteDataSource ,
1110 AirbyteDestination ,
1211 AirbyteWorkspace ,
1312)
13+ from datahub .utilities .str_enum import StrEnum
1414
1515
1616# Mock classes for testing purposes that don't exist in the models.py
17- class AirbyteFieldType (str , Enum ):
17+ class AirbyteFieldType (StrEnum ):
1818 """Mock enum for field types in tests."""
1919
2020 STRING = "string"
@@ -79,7 +79,7 @@ def to_schema_fields(
7979 return fields
8080
8181
82- class AirbyteSyncMode (str , Enum ):
82+ class AirbyteSyncMode (StrEnum ):
8383 """Mock enum for sync modes in tests."""
8484
8585 FULL_REFRESH = "full_refresh"
@@ -88,15 +88,15 @@ class AirbyteSyncMode(str, Enum):
8888 OVERWRITE = "overwrite"
8989
9090
91- class AirbyteStatus (str , Enum ):
91+ class AirbyteStatus (StrEnum ):
9292 """Mock enum for connection status in tests."""
9393
9494 ACTIVE = "active"
9595 INACTIVE = "inactive"
9696 DEPRECATED = "deprecated"
9797
9898
99- class AirbyteConnectionScheduleType (str , Enum ):
99+ class AirbyteConnectionScheduleType (StrEnum ):
100100 """Mock enum for connection schedule types in tests."""
101101
102102 MANUAL = "manual"
@@ -107,27 +107,21 @@ class AirbyteConnectionScheduleType(str, Enum):
107107class AirbyteConnectionStream (BaseModel ):
108108 """Mock class for connection stream in tests."""
109109
110+ model_config = ConfigDict (
111+ populate_by_name = True ,
112+ )
113+
110114 name : str
111115 namespace : Optional [str ] = None
112116 sync_mode : AirbyteSyncMode = AirbyteSyncMode .FULL_REFRESH
113- cursor_field : List [str ] = Field (default_factory = list )
114- primary_key : List [List [str ]] = Field (default_factory = list )
115- destination_sync_mode : AirbyteSyncMode = AirbyteSyncMode .OVERWRITE
117+ cursor_field : List [str ] = Field (default_factory = list , alias = "cursorField" )
118+ primary_key : List [List [str ]] = Field (default_factory = list , alias = "primaryKey" )
119+ destination_sync_mode : AirbyteSyncMode = Field (
120+ default = AirbyteSyncMode .OVERWRITE , alias = "destinationSyncMode"
121+ )
116122 selected : bool = True
117123 stream_schema : Optional [AirbyteStreamSchema ] = Field (None , alias = "schema" )
118124
119- class Config :
120- """Pydantic configuration."""
121-
122- fields = {
123- "name" : {"alias" : "name" },
124- "namespace" : {"alias" : "namespace" },
125- "sync_mode" : {"alias" : "syncMode" },
126- "destination_sync_mode" : {"alias" : "destinationSyncMode" },
127- "cursor_field" : {"alias" : "cursorField" },
128- "primary_key" : {"alias" : "primaryKey" },
129- }
130-
131125
132126class FieldSelection (BaseModel ):
133127 """Mock class for field selection in tests."""
@@ -158,7 +152,7 @@ def test_from_dict(self):
158152 "securityUpdates" : True ,
159153 "displaySetupWizard" : False ,
160154 }
161- workspace = AirbyteWorkspace .parse_obj (data )
155+ workspace = AirbyteWorkspace .model_validate (data )
162156 assert workspace .workspace_id == "test-workspace-id"
163157 assert workspace .name == "Test Workspace"
164158 assert not hasattr (workspace , "slug" )
@@ -170,7 +164,7 @@ def test_from_dict(self):
170164
171165 def test_missing_required_fields (self ):
172166 with pytest .raises (ValidationError ):
173- AirbyteWorkspace .parse_obj ({})
167+ AirbyteWorkspace .model_validate ({})
174168
175169 def test_extra_fields_ignored (self ):
176170 data = {
@@ -179,7 +173,7 @@ def test_extra_fields_ignored(self):
179173 "slug" : "test-workspace" ,
180174 "extra_field" : "extra value" ,
181175 }
182- workspace = AirbyteWorkspace .parse_obj (data )
176+ workspace = AirbyteWorkspace .model_validate (data )
183177 assert not hasattr (workspace , "extra_field" )
184178
185179
@@ -211,7 +205,7 @@ def test_from_dict(self):
211205 "workspaceId" : "test-workspace-id" ,
212206 "configuration" : {"host" : "localhost" , "port" : 5432 },
213207 }
214- source = AirbyteDataSource .parse_obj (data )
208+ source = AirbyteDataSource .model_validate (data )
215209 assert source .source_id == "test-source-id"
216210 assert source .name == "Test Source"
217211 assert source .source_type == "postgres"
@@ -222,7 +216,7 @@ def test_from_dict(self):
222216
223217 def test_missing_required_fields (self ):
224218 with pytest .raises (ValidationError ):
225- AirbyteDataSource .parse_obj ({})
219+ AirbyteDataSource .model_validate ({})
226220
227221
228222class TestAirbyteDestination :
@@ -253,7 +247,7 @@ def test_from_dict(self):
253247 "workspaceId" : "test-workspace-id" ,
254248 "configuration" : {"host" : "localhost" , "port" : 5432 },
255249 }
256- destination = AirbyteDestination .parse_obj (data )
250+ destination = AirbyteDestination .model_validate (data )
257251 assert destination .destination_id == "test-destination-id"
258252 assert destination .name == "Test Destination"
259253 assert destination .destination_type == "postgres"
@@ -264,7 +258,7 @@ def test_from_dict(self):
264258
265259 def test_missing_required_fields (self ):
266260 with pytest .raises (ValidationError ):
267- AirbyteDestination .parse_obj ({})
261+ AirbyteDestination .model_validate ({})
268262
269263
270264class TestAirbyteConnection :
@@ -367,7 +361,7 @@ def test_from_dict(self):
367361 "memory_limit" : "" ,
368362 },
369363 }
370- connection = AirbyteConnection .parse_obj (data )
364+ connection = AirbyteConnection .model_validate (data )
371365 assert connection .connection_id == "test-connection-id"
372366 assert connection .name == "Test Connection"
373367 assert connection .source_id == "test-source-id"
@@ -378,7 +372,7 @@ def test_from_dict(self):
378372
379373 def test_missing_required_fields (self ):
380374 with pytest .raises (ValidationError ):
381- AirbyteConnection .parse_obj ({})
375+ AirbyteConnection .model_validate ({})
382376
383377
384378if __name__ == "__main__" :
0 commit comments