Skip to content

Commit 833d48b

Browse files
committed
basic uow test, uow and conftest.py changes
1 parent 53ad798 commit 833d48b

File tree

4 files changed

+101
-9
lines changed

4 files changed

+101
-9
lines changed

src/allocation/service_layer/services.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,16 @@ def is_valid_sku(sku, batches):
1616

1717

1818
def add_batch(
19-
ref: str,
20-
sku: str,
21-
qty: int,
22-
eta: Optional[date],
23-
repo: AbstractRepository,
24-
session,
19+
ref: str, sku: str, qty: int, eta: Optional[date],
20+
repo: AbstractRepository, session,
2521
) -> None:
2622
repo.add(model.Batch(ref, sku, qty, eta))
2723
session.commit()
2824

2925

3026
def allocate(
31-
orderid: str, sku: str, qty: int, repo: AbstractRepository, session
27+
orderid: str, sku: str, qty: int,
28+
repo: AbstractRepository, session,
3229
) -> str:
3330
line = OrderLine(orderid, sku, qty)
3431
batches = repo.list()

src/allocation/unit_of_work.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# pylint: disable=attribute-defined-outside-init
2+
from __future__ import annotations
3+
import abc
4+
from sqlalchemy import create_engine
5+
from sqlalchemy.orm import sessionmaker
6+
from sqlalchemy.orm.session import Session
7+
8+
from allocation import config
9+
from allocation import repository
10+
11+
12+
class AbstractUnitOfWork(abc.ABC):
13+
batches: repository.AbstractRepository
14+
15+
def __enter__(self) -> AbstractUnitOfWork:
16+
return self
17+
18+
def __exit__(self, *args):
19+
self.rollback()
20+
21+
@abc.abstractmethod
22+
def commit(self):
23+
raise NotImplementedError
24+
25+
@abc.abstractmethod
26+
def rollback(self):
27+
raise NotImplementedError
28+
29+
30+
DEFAULT_SESSION_FACTORY = sessionmaker(bind=create_engine(config.get_postgres_uri(),))
31+
32+
33+
class SqlAlchemyUnitOfWork(AbstractUnitOfWork):
34+
def __init__(self, session_factory=DEFAULT_SESSION_FACTORY):
35+
self.session_factory = session_factory
36+
37+
def __enter__(self):
38+
self.session = self.session_factory() # type: Session
39+
self.batches = repository.SqlAlchemyRepository(self.session)
40+
return super().__enter__()
41+
42+
def __exit__(self, *args):
43+
super().__exit__(*args)
44+
self.session.close()
45+
46+
def commit(self):
47+
self.session.commit()
48+
49+
def rollback(self):
50+
self.session.rollback()

tests/conftest.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,17 @@ def in_memory_db():
2121

2222

2323
@pytest.fixture
24-
def session(in_memory_db):
24+
def session_factory(in_memory_db):
2525
start_mappers()
26-
yield sessionmaker(bind=in_memory_db)()
26+
yield sessionmaker(bind=in_memory_db)
2727
clear_mappers()
2828

2929

30+
@pytest.fixture
31+
def session(session_factory):
32+
return session_factory()
33+
34+
3035
def wait_for_postgres_to_come_up(engine):
3136
deadline = time.time() + 10
3237
while time.time() < deadline:

tests/integration/test_uow.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pytest
2+
from allocation import model
3+
from allocation import unit_of_work
4+
5+
6+
def insert_batch(session, ref, sku, qty, eta):
7+
session.execute(
8+
"INSERT INTO batches (reference, sku, _purchased_quantity, eta)"
9+
" VALUES (:ref, :sku, :qty, :eta)",
10+
dict(ref=ref, sku=sku, qty=qty, eta=eta),
11+
)
12+
13+
14+
def get_allocated_batch_ref(session, orderid, sku):
15+
[[orderlineid]] = session.execute(
16+
"SELECT id FROM order_lines WHERE orderid=:orderid AND sku=:sku",
17+
dict(orderid=orderid, sku=sku),
18+
)
19+
[[batchref]] = session.execute(
20+
"SELECT b.reference FROM allocations JOIN batches AS b ON batch_id = b.id"
21+
" WHERE orderline_id=:orderlineid",
22+
dict(orderlineid=orderlineid),
23+
)
24+
return batchref
25+
26+
27+
def test_uow_can_retrieve_a_batch_and_allocate_to_it(session_factory):
28+
session = session_factory()
29+
insert_batch(session, "batch1", "HIPSTER-WORKBENCH", 100, None)
30+
session.commit()
31+
32+
uow = unit_of_work.SqlAlchemyUnitOfWork(session_factory)
33+
with uow:
34+
batch = uow.batches.get(reference="batch1")
35+
line = model.OrderLine("o1", "HIPSTER-WORKBENCH", 10)
36+
batch.allocate(line)
37+
uow.commit()
38+
39+
batchref = get_allocated_batch_ref(session, "o1", "HIPSTER-WORKBENCH")
40+
assert batchref == "batch1"

0 commit comments

Comments
 (0)