Skip to content

Commit 4af9799

Browse files
committed
refactor: flatten test classes to top-level functions per CLAUDE.md
Github-Issue: #1691
1 parent 6870770 commit 4af9799

File tree

1 file changed

+139
-149
lines changed

1 file changed

+139
-149
lines changed

tests/server/test_session_lifecycle.py

Lines changed: 139 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -67,217 +67,207 @@ async def _session_context(
6767
# ---------------------------------------------------------------------------
6868

6969

70-
class TestInitializationStateEnum:
71-
"""Verify the expanded InitializationState enum values."""
70+
def test_enum_all_states_present() -> None:
71+
"""Verify the expanded InitializationState enum has all expected members."""
72+
expected = {"NotInitialized", "Initializing", "Initialized", "Stateless", "Closing", "Closed"}
73+
actual = {s.name for s in InitializationState}
74+
assert actual == expected
7275

73-
def test_all_states_present(self) -> None:
74-
expected = {"NotInitialized", "Initializing", "Initialized", "Stateless", "Closing", "Closed"}
75-
actual = {s.name for s in InitializationState}
76-
assert actual == expected
7776

78-
def test_values_are_distinct(self) -> None:
79-
values = [s.value for s in InitializationState]
80-
assert len(values) == len(set(values))
77+
def test_enum_values_are_distinct() -> None:
78+
values = [s.value for s in InitializationState]
79+
assert len(values) == len(set(values))
8180

8281

8382
# ---------------------------------------------------------------------------
8483
# Transition table tests
8584
# ---------------------------------------------------------------------------
8685

8786

88-
class TestValidTransitions:
89-
"""Verify the _VALID_TRANSITIONS table is complete and correct."""
87+
def test_transitions_all_states_have_entry() -> None:
88+
for state in InitializationState:
89+
assert state in _VALID_TRANSITIONS, f"Missing entry for {state.name}"
9090

91-
def test_all_states_have_entry(self) -> None:
92-
for state in InitializationState:
93-
assert state in _VALID_TRANSITIONS, f"Missing entry for {state.name}"
9491

95-
def test_closed_is_terminal(self) -> None:
96-
assert _VALID_TRANSITIONS[InitializationState.Closed] == set()
92+
def test_transitions_closed_is_terminal() -> None:
93+
assert _VALID_TRANSITIONS[InitializationState.Closed] == set()
9794

9895

9996
# ---------------------------------------------------------------------------
10097
# _transition_state tests
10198
# ---------------------------------------------------------------------------
10299

103100

104-
class TestTransitionState:
105-
"""Unit tests for ServerSession._transition_state."""
101+
async def test_transition_valid_stateful_lifecycle() -> None:
102+
"""NotInitialized -> Initializing -> Initialized -> Closing -> Closed."""
103+
async with _session_context() as session:
104+
assert session.initialization_state == InitializationState.NotInitialized
106105

107-
async def test_valid_stateful_lifecycle(self) -> None:
108-
"""NotInitialized -> Initializing -> Initialized -> Closing -> Closed."""
109-
async with _session_context() as session:
110-
assert session.initialization_state == InitializationState.NotInitialized
106+
session._transition_state(InitializationState.Initializing)
107+
assert session.initialization_state == InitializationState.Initializing
111108

112-
session._transition_state(InitializationState.Initializing)
113-
assert session.initialization_state == InitializationState.Initializing
109+
session._transition_state(InitializationState.Initialized)
110+
assert session.initialization_state == InitializationState.Initialized
114111

115-
session._transition_state(InitializationState.Initialized)
116-
assert session.initialization_state == InitializationState.Initialized
112+
session._transition_state(InitializationState.Closing)
113+
assert session.initialization_state == InitializationState.Closing
117114

118-
session._transition_state(InitializationState.Closing)
119-
assert session.initialization_state == InitializationState.Closing
115+
session._transition_state(InitializationState.Closed)
116+
assert session.initialization_state == InitializationState.Closed
120117

121-
session._transition_state(InitializationState.Closed)
122-
assert session.initialization_state == InitializationState.Closed
123118

124-
async def test_valid_stateless_lifecycle(self) -> None:
125-
"""Stateless -> Closing -> Closed."""
126-
async with _session_context(stateless=True) as session:
127-
assert session.initialization_state == InitializationState.Stateless
119+
async def test_transition_valid_stateless_lifecycle() -> None:
120+
"""Stateless -> Closing -> Closed."""
121+
async with _session_context(stateless=True) as session:
122+
assert session.initialization_state == InitializationState.Stateless
128123

129-
session._transition_state(InitializationState.Closing)
130-
assert session.initialization_state == InitializationState.Closing
124+
session._transition_state(InitializationState.Closing)
125+
assert session.initialization_state == InitializationState.Closing
131126

132-
session._transition_state(InitializationState.Closed)
133-
assert session.initialization_state == InitializationState.Closed
127+
session._transition_state(InitializationState.Closed)
128+
assert session.initialization_state == InitializationState.Closed
134129

135-
async def test_invalid_transition_raises(self) -> None:
136-
"""Attempting an invalid transition raises RuntimeError."""
137-
async with _session_context() as session:
138-
with pytest.raises(RuntimeError, match="Invalid session state transition"):
139-
session._transition_state(InitializationState.Closed)
140130

141-
async def test_closed_to_anything_raises(self) -> None:
142-
"""Closed is terminal — no transitions allowed."""
143-
async with _session_context() as session:
144-
session._transition_state(InitializationState.Closing)
131+
async def test_transition_invalid_raises() -> None:
132+
"""Attempting an invalid transition raises RuntimeError."""
133+
async with _session_context() as session:
134+
with pytest.raises(RuntimeError, match="Invalid session state transition"):
145135
session._transition_state(InitializationState.Closed)
146136

147-
for state in InitializationState:
148-
with pytest.raises(RuntimeError, match="Invalid session state transition"):
149-
session._transition_state(state)
137+
138+
async def test_transition_closed_to_anything_raises() -> None:
139+
"""Closed is terminal — no transitions allowed."""
140+
async with _session_context() as session:
141+
session._transition_state(InitializationState.Closing)
142+
session._transition_state(InitializationState.Closed)
143+
144+
for state in InitializationState:
145+
with pytest.raises(RuntimeError, match="Invalid session state transition"):
146+
session._transition_state(state)
150147

151148

152149
# ---------------------------------------------------------------------------
153150
# is_initialized property tests
154151
# ---------------------------------------------------------------------------
155152

156153

157-
class TestIsInitialized:
158-
"""Tests for the is_initialized property."""
154+
@pytest.mark.parametrize(
155+
("stateless", "expected_state"),
156+
[
157+
(False, InitializationState.NotInitialized),
158+
(True, InitializationState.Stateless),
159+
],
160+
)
161+
async def test_is_initialized_initial_state(stateless: bool, expected_state: InitializationState) -> None:
162+
async with _session_context(stateless=stateless) as session:
163+
assert session.initialization_state == expected_state
159164

160-
@pytest.mark.parametrize(
161-
("stateless", "expected_state"),
162-
[
163-
(False, InitializationState.NotInitialized),
164-
(True, InitializationState.Stateless),
165-
],
166-
)
167-
async def test_initial_state(self, stateless: bool, expected_state: InitializationState) -> None:
168-
async with _session_context(stateless=stateless) as session:
169-
assert session.initialization_state == expected_state
170165

171-
async def test_not_initialized_returns_false(self) -> None:
172-
async with _session_context() as session:
173-
assert not session.is_initialized
166+
async def test_is_initialized_not_initialized_returns_false() -> None:
167+
async with _session_context() as session:
168+
assert not session.is_initialized
174169

175-
async def test_initializing_returns_false(self) -> None:
176-
async with _session_context() as session:
177-
session._transition_state(InitializationState.Initializing)
178-
assert not session.is_initialized
179170

180-
async def test_initialized_returns_true(self) -> None:
181-
async with _session_context() as session:
182-
session._transition_state(InitializationState.Initializing)
183-
session._transition_state(InitializationState.Initialized)
184-
assert session.is_initialized
171+
async def test_is_initialized_initializing_returns_false() -> None:
172+
async with _session_context() as session:
173+
session._transition_state(InitializationState.Initializing)
174+
assert not session.is_initialized
185175

186-
async def test_stateless_returns_true(self) -> None:
187-
async with _session_context(stateless=True) as session:
188-
assert session.is_initialized
176+
177+
async def test_is_initialized_initialized_returns_true() -> None:
178+
async with _session_context() as session:
179+
session._transition_state(InitializationState.Initializing)
180+
session._transition_state(InitializationState.Initialized)
181+
assert session.is_initialized
182+
183+
184+
async def test_is_initialized_stateless_returns_true() -> None:
185+
async with _session_context(stateless=True) as session:
186+
assert session.is_initialized
189187

190188

191189
# ---------------------------------------------------------------------------
192190
# __aexit__ lifecycle tests
193191
# ---------------------------------------------------------------------------
194192

195193

196-
class TestSessionExit:
197-
"""Test that __aexit__ transitions to Closing -> Closed."""
194+
async def test_aexit_transitions_to_closed() -> None:
195+
"""Normal exit transitions through Closing -> Closed."""
196+
async with _session_context() as session:
197+
async with session:
198+
assert session.initialization_state == InitializationState.NotInitialized
198199

199-
async def test_aexit_transitions_to_closed(self) -> None:
200-
"""Normal exit transitions through Closing -> Closed."""
201-
async with _session_context() as session:
202-
async with session:
203-
assert session.initialization_state == InitializationState.NotInitialized
200+
assert session.initialization_state == InitializationState.Closed
204201

205-
assert session.initialization_state == InitializationState.Closed
206202

207-
async def test_aexit_from_initialized(self) -> None:
208-
"""Session transitions to Closed even when initialized."""
209-
async with _session_context() as session:
210-
async with session:
211-
session._transition_state(InitializationState.Initializing)
212-
session._transition_state(InitializationState.Initialized)
213-
assert session.is_initialized
203+
async def test_aexit_from_initialized() -> None:
204+
"""Session transitions to Closed even when initialized."""
205+
async with _session_context() as session:
206+
async with session:
207+
session._transition_state(InitializationState.Initializing)
208+
session._transition_state(InitializationState.Initialized)
209+
assert session.is_initialized
214210

215-
assert session.initialization_state == InitializationState.Closed
211+
assert session.initialization_state == InitializationState.Closed
216212

217-
async def test_aexit_stateless_transitions_to_closed(self) -> None:
218-
"""Stateless sessions also transition to Closed on exit."""
219-
async with _session_context(stateless=True) as session:
220-
async with session:
221-
assert session.initialization_state == InitializationState.Stateless
222213

223-
assert session.initialization_state == InitializationState.Closed
214+
async def test_aexit_stateless_transitions_to_closed() -> None:
215+
"""Stateless sessions also transition to Closed on exit."""
216+
async with _session_context(stateless=True) as session:
217+
async with session:
218+
assert session.initialization_state == InitializationState.Stateless
219+
220+
assert session.initialization_state == InitializationState.Closed
224221

225222

226223
# ---------------------------------------------------------------------------
227224
# Integration: full handshake lifecycle
228225
# ---------------------------------------------------------------------------
229226

230227

231-
class TestFullHandshakeLifecycle:
232-
"""Integration test: client/server handshake uses state transitions correctly."""
233-
234-
async def test_stateful_handshake(self) -> None:
235-
"""Stateful handshake transitions NotInitialized -> Initializing -> Initialized."""
236-
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](
237-
1
238-
)
239-
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](
240-
1
241-
)
242-
243-
received_initialized = False
244-
245-
async def run_server() -> None:
246-
nonlocal received_initialized
247-
async with ServerSession(
248-
client_to_server_receive,
249-
server_to_client_send,
250-
_DEFAULT_INIT_OPTIONS,
251-
) as server_session:
252-
async for message in server_session.incoming_messages: # pragma: no branch
253-
if isinstance(message, Exception): # pragma: no cover
254-
raise message
255-
if isinstance(message, InitializedNotification): # pragma: no branch
256-
assert server_session.is_initialized
257-
assert server_session.initialization_state == InitializationState.Initialized
258-
received_initialized = True
259-
return
260-
261-
async def message_handler( # pragma: no cover
262-
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
263-
) -> None:
264-
if isinstance(message, Exception):
265-
raise message
228+
async def test_stateful_handshake() -> None:
229+
"""Stateful handshake transitions NotInitialized -> Initializing -> Initialized."""
230+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
231+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
266232

267-
try:
268-
async with (
233+
received_initialized = False
234+
235+
async def run_server() -> None:
236+
nonlocal received_initialized
237+
async with ServerSession(
238+
client_to_server_receive,
239+
server_to_client_send,
240+
_DEFAULT_INIT_OPTIONS,
241+
) as server_session:
242+
async for message in server_session.incoming_messages: # pragma: no branch
243+
if isinstance(message, Exception): # pragma: no cover
244+
raise message
245+
if isinstance(message, InitializedNotification): # pragma: no branch
246+
assert server_session.is_initialized
247+
assert server_session.initialization_state == InitializationState.Initialized
248+
received_initialized = True
249+
return
250+
251+
async def message_handler( # pragma: no cover
252+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
253+
) -> None:
254+
if isinstance(message, Exception):
255+
raise message
256+
257+
try:
258+
async with (
259+
server_to_client_receive,
260+
client_to_server_send,
261+
ClientSession(
269262
server_to_client_receive,
270263
client_to_server_send,
271-
ClientSession(
272-
server_to_client_receive,
273-
client_to_server_send,
274-
message_handler=message_handler,
275-
) as client_session,
276-
anyio.create_task_group() as tg,
277-
):
278-
tg.start_soon(run_server)
279-
await client_session.initialize()
280-
except anyio.ClosedResourceError: # pragma: no cover
281-
pass
282-
283-
assert received_initialized
264+
message_handler=message_handler,
265+
) as client_session,
266+
anyio.create_task_group() as tg,
267+
):
268+
tg.start_soon(run_server)
269+
await client_session.initialize()
270+
except anyio.ClosedResourceError: # pragma: no cover
271+
pass
272+
273+
assert received_initialized

0 commit comments

Comments
 (0)