Skip to content

Commit d80696c

Browse files
committed
Expand UnitOfWork and introduce SqlalchemyUnitOfWork, handling transactions; update inspect_bin test to use UnitOfWork
1 parent 81981c8 commit d80696c

File tree

3 files changed

+153
-6
lines changed

3 files changed

+153
-6
lines changed

dor/service_layer/unit_of_work.py

+80-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,91 @@
1+
from abc import ABC, abstractmethod
2+
3+
from sqlalchemy import create_engine
4+
from sqlalchemy.orm import sessionmaker
5+
6+
from dor.adapters.catalog import MemoryCatalog, SqlalchemyCatalog, _custom_json_serializer
7+
from dor.config import config
18
from dor.domain.events import Event
29
from gateway.repository_gateway import RepositoryGateway
310

411

5-
class UnitOfWork:
12+
class AbstractUnitOfWork(ABC):
13+
14+
@abstractmethod
15+
def __enter__(self):
16+
raise NotImplementedError
17+
18+
@abstractmethod
19+
def __exit__(self, *args):
20+
raise NotImplementedError
21+
22+
@abstractmethod
23+
def commit(self) -> None:
24+
raise NotImplementedError
25+
26+
@abstractmethod
27+
def rollback(self) -> None:
28+
raise NotImplementedError
29+
30+
def add_event(self, event: Event):
31+
raise NotImplementedError
32+
33+
def pop_event(self) -> Event | None:
34+
raise NotImplementedError
35+
36+
37+
class UnitOfWork(AbstractUnitOfWork):
638

739
def __init__(self, gateway: RepositoryGateway) -> None:
840
self.gateway = gateway
941
self.events: list[Event] = []
42+
self.catalog = MemoryCatalog()
43+
44+
def __enter__(self):
45+
pass
46+
47+
def __exit__(self, *args):
48+
self.rollback()
49+
50+
def commit(self):
51+
self.committed = True
52+
53+
def rollback(self):
54+
pass
55+
56+
def add_event(self, event: Event):
57+
self.events.append(event)
58+
59+
def pop_event(self) -> Event | None:
60+
if len(self.events) > 0:
61+
return self.events.pop(0)
62+
return None
63+
64+
65+
DEFAULT_SESSION_FACTORY = sessionmaker(bind=create_engine(
66+
config.get_database_engine_url(), json_serializer=_custom_json_serializer
67+
))
68+
69+
class SqlalchemyUnitOfWork(AbstractUnitOfWork):
70+
71+
def __init__(self, gateway: RepositoryGateway, session_factory=DEFAULT_SESSION_FACTORY) -> None:
72+
self.session_factory = session_factory
73+
self.gateway: RepositoryGateway = gateway
74+
self.events: list[Event] = []
75+
76+
def __enter__(self):
77+
self.session = self.session_factory()
78+
self.catalog = SqlalchemyCatalog(self.session)
79+
80+
def __exit__(self, *args):
81+
self.rollback()
82+
self.session.close()
83+
84+
def commit(self):
85+
self.session.commit()
86+
87+
def rollback(self):
88+
self.session.rollback()
1089

1190
def add_event(self, event: Event):
1291
self.events.append(event)

features/steps/inspect_bin.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
from behave import given, then, when
22
import uuid
33
from datetime import datetime, UTC
4+
45
from pydantic_core import to_jsonable_python
6+
from sqlalchemy import create_engine
7+
from sqlalchemy.orm import sessionmaker
58

6-
from dor.adapters.catalog import MemoryCatalog
9+
from dor.adapters.catalog import Base, _custom_json_serializer
10+
from dor.config import config
711
from dor.domain.models import Bin
812
from dor.service_layer import catalog_service
13+
from dor.service_layer.unit_of_work import SqlalchemyUnitOfWork, UnitOfWork
914
from dor.providers.models import (
1015
Agent, AlternateIdentifier, FileMetadata, FileReference, PackageResource,
1116
PreservationEvent, StructMap, StructMapItem, StructMapType
1217
)
18+
from gateway.fake_repository_gateway import FakeRepositoryGateway
1319

1420

1521
@given(u'a preserved monograph with an alternate identifier of "{alt_id}"')
@@ -174,13 +180,24 @@ def step_impl(context, alt_id):
174180
)
175181
]
176182
)
177-
context.catalog = MemoryCatalog()
178-
context.catalog.add(bin)
183+
184+
engine = create_engine(
185+
config.get_test_database_engine_url(), json_serializer=_custom_json_serializer
186+
)
187+
session_factory = sessionmaker(bind=engine)
188+
Base.metadata.drop_all(engine)
189+
Base.metadata.create_all(engine)
190+
191+
context.uow = SqlalchemyUnitOfWork(gateway=FakeRepositoryGateway(), session_factory=session_factory)
192+
with context.uow:
193+
context.uow.catalog.add(bin)
194+
context.uow.commit()
179195

180196
@when(u'the Collection Manager looks up the bin by "{alt_id}"')
181197
def step_impl(context, alt_id):
182198
context.alt_id = alt_id
183-
context.bin = context.catalog.get_by_alternate_identifier(alt_id)
199+
with context.uow:
200+
context.bin = context.uow.catalog.get_by_alternate_identifier(alt_id)
184201
context.summary = catalog_service.summarize(context.bin)
185202

186203
@then(u'the Collection Manager sees the summary of the bin')
@@ -203,7 +220,8 @@ def step_impl(context):
203220

204221
@when(u'the Collection Manager lists the contents of the bin for "{alt_id}"')
205222
def step_impl(context, alt_id):
206-
context.bin = context.catalog.get_by_alternate_identifier(alt_id)
223+
with context.uow:
224+
context.bin = context.uow.catalog.get_by_alternate_identifier(alt_id)
207225
context.file_sets = catalog_service.get_file_sets(context.bin)
208226

209227
@then(u'the Collection Manager sees the file sets.')

tests/test_unit_of_work.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from gateway.fake_repository_gateway import FakeRepositoryGateway
2+
import pytest
3+
from sqlalchemy import create_engine
4+
from sqlalchemy.orm import sessionmaker
5+
6+
from dor.adapters.catalog import SqlalchemyCatalog, Base, _custom_json_serializer
7+
from dor.config import config
8+
from dor.service_layer.unit_of_work import SqlalchemyUnitOfWork
9+
10+
11+
def setup_function() -> None:
12+
engine = create_engine(config.get_test_database_engine_url())
13+
Base.metadata.drop_all(engine)
14+
Base.metadata.create_all(engine)
15+
16+
17+
@pytest.fixture
18+
def session_factory():
19+
return sessionmaker(bind=create_engine(
20+
config.get_test_database_engine_url(), json_serializer=_custom_json_serializer
21+
))
22+
23+
24+
@pytest.mark.usefixtures("sample_bin")
25+
def test_uow_can_add_bin(session_factory, sample_bin):
26+
gateway = FakeRepositoryGateway()
27+
uow = SqlalchemyUnitOfWork(gateway=gateway, session_factory=session_factory)
28+
with uow:
29+
uow.catalog.add(sample_bin)
30+
uow.commit()
31+
32+
session = session_factory()
33+
with session:
34+
catalog = SqlalchemyCatalog(session)
35+
bin = catalog.get(sample_bin.identifier)
36+
assert bin == sample_bin
37+
38+
39+
@pytest.mark.usefixtures("sample_bin")
40+
def test_uow_does_not_add_bin_without_commit(session_factory, sample_bin):
41+
gateway = FakeRepositoryGateway()
42+
uow = SqlalchemyUnitOfWork(gateway=gateway, session_factory=session_factory)
43+
with uow:
44+
uow.catalog.add(sample_bin)
45+
46+
session = session_factory()
47+
with session:
48+
catalog = SqlalchemyCatalog(session)
49+
bin = catalog.get(sample_bin.identifier)
50+
assert bin is None

0 commit comments

Comments
 (0)