Skip to content

Commit c8da4b4

Browse files
committed
Fix tests and apply refactoring for cluster capacity and defaulting
- Apply implementer refactoring for _determine_available_capacity and _set_cluster_topology_defaults.
- Ensure _determine_available_capacity does not modify args.num_nodes directly and returns early if not using a reservation.
- Centralize slice and node defaulting in _set_cluster_topology_defaults.
- Update unit tests in cluster_test.py to match new method signatures and test logic.
- Fix missing unittest.mock imports.
- Mock _determine_available_capacity where required to accurately test _set_cluster_topology_defaults.
1 parent e795644 commit c8da4b4

2 files changed

Lines changed: 107 additions & 85 deletions

File tree

src/xpk/commands/cluster.py

Lines changed: 58 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,10 @@ def cluster_adapt(args) -> None:
216216
def _validate_cluster_create_args(
217217
args,
218218
system: SystemCharacteristics,
219-
available_capacity: list[ReservationCapacity] | None,
220219
):
221220
if FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing:
222221
validate_sub_slicing_system(system)
223222
_validate_sub_slicing_reservation(args)
224-
_validate_num_slices_and_set_default(args, available_capacity)
225223
if args.super_slicing:
226224
validate_super_slicing_system(system)
227225
_validate_super_slicing_reservation(args)
@@ -285,10 +283,10 @@ def _validate_gsc_reservation(args, creation_description: str):
285283
xpk_exit(1)
286284

287285

288-
def _validate_num_slices_and_set_default(
286+
def _set_cluster_topology_defaults(
289287
args,
290-
available_capacity: list[ReservationCapacity] | None,
291-
):
288+
system: SystemCharacteristics,
289+
) -> list[ReservationCapacity] | None:
292290
if args.num_cubes is not None and not args.super_slicing:
293291
xpk_print('--num-cubes can only be used with --super-slicing')
294292
xpk_exit(1)
@@ -301,6 +299,28 @@ def _validate_num_slices_and_set_default(
301299
xpk_print('--num-cubes must not be different from --num-slices')
302300
xpk_exit(1)
303301

302+
if system.accelerator_type == AcceleratorType.GPU and getattr(args, 'num_nodes', None) is None:
303+
capacity_type, return_code = get_capacity_type(args)
304+
if return_code == 0 and capacity_type == CapacityType.RESERVATION and args.reservation:
305+
reservations = get_reservations_list(args)
306+
temp_capacity, return_code = assess_available_slices(
307+
reservations,
308+
force_sub_block_targeting=args.super_slicing,
309+
system=system,
310+
vms_per_slice=1,
311+
)
312+
if return_code == 0:
313+
total_vms = sum(cap.available_slices for cap in temp_capacity)
314+
if total_vms > 0:
315+
xpk_print(f'Automatically setting --num-nodes to {total_vms}')
316+
args.num_nodes = total_vms
317+
318+
args.num_nodes = (
319+
2 if getattr(args, 'num_nodes', None) is None else args.num_nodes
320+
)
321+
322+
available_capacity = _determine_available_capacity(args, system)
323+
304324
if (
305325
args.num_slices is None
306326
and args.num_cubes is None
@@ -313,74 +333,50 @@ def _validate_num_slices_and_set_default(
313333
args.num_slices = total_available
314334

315335
args.num_slices = args.num_slices or args.num_cubes or 1
316-
args.num_nodes = (
317-
2 if getattr(args, 'num_nodes', None) is None else args.num_nodes
318-
)
336+
if args.super_slicing:
337+
args.num_cubes = args.num_slices
338+
339+
return available_capacity
319340

320341

321342
def _determine_available_capacity(
322343
args,
323344
system: SystemCharacteristics,
324345
) -> list[ReservationCapacity] | None:
325-
"""Determines available capacity and optionally updates args.num_nodes."""
346+
"""Determines available capacity."""
326347
capacity_type, return_code = get_capacity_type(args)
327348
if return_code != 0:
328349
xpk_exit(return_code)
329350

330-
available_capacity = None
331-
if capacity_type == CapacityType.RESERVATION and args.reservation:
332-
if FeatureFlags.RESERVATIONS_VALIDATION_ENABLED or (
333-
args.num_slices is None and args.num_cubes is None
334-
):
335-
xpk_print(
336-
'Assessing reservation capacity to determine number of slices...'
337-
)
338-
reservations = get_reservations_list(args)
351+
if not (capacity_type == CapacityType.RESERVATION and args.reservation):
352+
return None
339353

340-
if (
341-
system.accelerator_type == AcceleratorType.GPU
342-
and getattr(args, 'num_nodes', None) is None
343-
):
344-
temp_capacity, return_code = assess_available_slices(
345-
reservations,
346-
force_sub_block_targeting=args.super_slicing,
347-
system=system,
348-
vms_per_slice=1,
349-
)
350-
if return_code != 0:
351-
xpk_print('Error assessing available VMs for GPU reservation.')
352-
xpk_exit(return_code)
354+
if not (
355+
FeatureFlags.RESERVATIONS_VALIDATION_ENABLED or (
356+
args.num_slices is None and args.num_cubes is None
357+
)
358+
):
359+
return None
353360

354-
total_vms = sum(cap.available_slices for cap in temp_capacity)
355-
if total_vms > 0:
356-
xpk_print(f'Automatically setting --num-nodes to {total_vms}')
357-
args.num_nodes = total_vms
361+
xpk_print(
362+
'Assessing reservation capacity to determine number of slices...'
363+
)
364+
reservations = get_reservations_list(args)
358365

359-
available_capacity = []
360-
for cap in temp_capacity:
361-
slices = cap.available_slices // total_vms
362-
if slices > 0:
363-
available_capacity.append(
364-
ReservationCapacity(cap.reservation, slices)
365-
)
366-
else:
367-
available_capacity = []
368-
369-
else:
370-
vms_per_pool = (
371-
(2 if getattr(args, 'num_nodes', None) is None else args.num_nodes)
372-
if system.accelerator_type == AcceleratorType.GPU
373-
else system.vms_per_slice
374-
)
375-
available_capacity, return_code = assess_available_slices(
376-
reservations,
377-
force_sub_block_targeting=args.super_slicing,
378-
system=system,
379-
vms_per_slice=vms_per_pool,
380-
)
381-
if return_code != 0:
382-
xpk_print('Error assessing available slices.')
383-
xpk_exit(return_code)
366+
vms_per_pool = (
367+
args.num_nodes
368+
if system.accelerator_type == AcceleratorType.GPU
369+
else system.vms_per_slice
370+
)
371+
available_capacity, return_code = assess_available_slices(
372+
reservations,
373+
force_sub_block_targeting=args.super_slicing,
374+
system=system,
375+
vms_per_slice=vms_per_pool,
376+
)
377+
if return_code != 0:
378+
xpk_print('Error assessing available slices.')
379+
xpk_exit(return_code)
384380

385381
return available_capacity
386382

@@ -405,9 +401,9 @@ def cluster_create(args) -> None:
405401
xpk_print(f'Starting cluster create for cluster {args.cluster}:', flush=True)
406402
add_zone_and_project(args)
407403

408-
available_capacity = _determine_available_capacity(args, system)
404+
available_capacity = _set_cluster_topology_defaults(args, system)
409405

410-
_validate_cluster_create_args(args, system, available_capacity)
406+
_validate_cluster_create_args(args, system)
411407
_log_cluster_create_telemetry(args)
412408

413409
release_channel = (

0 commit comments

Comments
 (0)