Skip to content

Commit d6fda15

Browse files
fix: resolve CI failures in multi-provider
- Add _is_internal_hook_provider class marker to avoid Mock false positives with runtime_checkable Protocol isinstance checks - Fix mypy no-redef errors by hoisting evaluations declaration before branch - Fix mypy no-any-return by assigning to typed local before returning - Fix mypy attr-defined by using _as_internal_hook_provider narrowing helper - Apply ruff formatting fixes Signed-off-by: Jonathan Norris <jonathan.norris@dynatrace.com>
1 parent 521d9bb commit d6fda15

File tree

6 files changed

+103
-45
lines changed

6 files changed

+103
-45
lines changed

openfeature/client.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -470,35 +470,44 @@ def _establish_hooks_and_provider(
470470
merged_eval_context,
471471
)
472472

473+
def _as_internal_hook_provider(
474+
self, provider: FeatureProvider
475+
) -> InternalHookProvider | None:
476+
"""Return the provider as InternalHookProvider if it opts in, else None."""
477+
if getattr(provider, "_is_internal_hook_provider", False) and isinstance(
478+
provider, InternalHookProvider
479+
):
480+
return provider
481+
return None
482+
473483
def _provider_uses_internal_hooks(self, provider: FeatureProvider) -> bool:
474-
return (
475-
isinstance(provider, InternalHookProvider)
476-
and provider.uses_internal_provider_hooks()
477-
)
484+
ihp = self._as_internal_hook_provider(provider)
485+
return ihp is not None and ihp.uses_internal_provider_hooks()
478486

479487
def _set_internal_provider_hook_runtime(
480488
self,
481489
provider: FeatureProvider,
482490
flag_type: FlagType,
483491
hook_hints: HookHints,
484492
) -> object | None:
485-
if not isinstance(provider, InternalHookProvider):
486-
return None
487-
if not provider.uses_internal_provider_hooks():
493+
ihp = self._as_internal_hook_provider(provider)
494+
if ihp is None or not ihp.uses_internal_provider_hooks():
488495
return None
489-
return provider.set_internal_provider_hook_runtime(
496+
result: object | None = ihp.set_internal_provider_hook_runtime(
490497
flag_type=flag_type,
491498
client_metadata=self.get_metadata(),
492499
hook_hints=hook_hints,
493500
)
501+
return result
494502

495503
def _reset_internal_provider_hook_runtime(
496504
self, provider: FeatureProvider, runtime_token: object | None
497505
) -> None:
498506
if runtime_token is None:
499507
return
500-
if isinstance(provider, InternalHookProvider):
501-
provider.reset_internal_provider_hook_runtime(runtime_token)
508+
ihp = self._as_internal_hook_provider(provider)
509+
if ihp is not None:
510+
ihp.reset_internal_provider_hook_runtime(runtime_token)
502511

503512
def _assert_provider_status(
504513
self,

openfeature/provider/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,14 @@ class InternalHookProvider(typing.Protocol):
139139
140140
The registry will also use get_status() from this protocol instead of its
141141
own internal status tracking for providers that implement it.
142+
143+
Implementations must set ``_is_internal_hook_provider = True`` as a class
144+
attribute. This marker is checked alongside ``isinstance`` to avoid false
145+
positives from duck-typed objects (e.g. ``Mock``).
142146
"""
143147

148+
_is_internal_hook_provider: typing.ClassVar[bool]
149+
144150
def uses_internal_provider_hooks(self) -> bool: ...
145151

146152
def set_internal_provider_hook_runtime(

openfeature/provider/_registry.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,12 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None:
124124
def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus:
125125
# Only InternalHookProvider implementations (e.g. MultiProvider) manage
126126
# their own status. For all other providers, use the registry's tracking.
127-
if isinstance(provider, InternalHookProvider):
127+
# We check _is_internal_hook_provider (a concrete class attribute) in
128+
# addition to isinstance, because runtime_checkable Protocols match any
129+
# object that has the right method names — including Mock objects.
130+
if getattr(provider, "_is_internal_hook_provider", False) and isinstance(
131+
provider, InternalHookProvider
132+
):
128133
return provider.get_status()
129134
return self._provider_status.get(provider, ProviderStatus.NOT_READY)
130135

openfeature/provider/multi_provider.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def _validate_run_mode(run_mode: RunMode) -> RunMode:
109109
def _format_result_error(
110110
provider_name: str, result: FlagResolutionDetails[FlagValueType]
111111
) -> str:
112-
error_code = result.error_code.value if result.error_code else ErrorCode.GENERAL.value
112+
error_code = (
113+
result.error_code.value if result.error_code else ErrorCode.GENERAL.value
114+
)
113115
error_message = result.error_message or "Unknown error"
114116
return f"{provider_name}: {error_code} ({error_message})"
115117

@@ -280,7 +282,9 @@ def determine_final_result(
280282
evaluations: list[_ProviderEvaluation[FlagValueType]],
281283
) -> FlagResolutionDetails[FlagValueType]:
282284
failed_evaluations = [
283-
evaluation for evaluation in evaluations if not _is_success(evaluation.result)
285+
evaluation
286+
for evaluation in evaluations
287+
if not _is_success(evaluation.result)
284288
]
285289
if failed_evaluations:
286290
return _build_aggregated_error(
@@ -338,6 +342,8 @@ class MultiProvider(AbstractProvider):
338342
ProviderStatus.READY,
339343
)
340344

345+
_is_internal_hook_provider: typing.ClassVar[bool] = True
346+
341347
def __init__(
342348
self,
343349
providers: list[ProviderEntry],
@@ -365,7 +371,9 @@ def __init__(
365371
provider_name: ProviderStatus.NOT_READY
366372
for provider_name, _ in self._registered_providers
367373
}
368-
validate_provider_names = getattr(self.strategy, "validate_provider_names", None)
374+
validate_provider_names = getattr(
375+
self.strategy, "validate_provider_names", None
376+
)
369377
if callable(validate_provider_names):
370378
validate_provider_names(
371379
[provider_name for provider_name, _ in self._registered_providers]
@@ -455,8 +463,12 @@ def initialize_provider(
455463
except Exception as err:
456464
return provider_name, err
457465

458-
with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor:
459-
init_results = list(executor.map(initialize_provider, self._registered_providers))
466+
with ThreadPoolExecutor(
467+
max_workers=len(self._registered_providers)
468+
) as executor:
469+
init_results = list(
470+
executor.map(initialize_provider, self._registered_providers)
471+
)
460472

461473
error_messages: list[str] = []
462474
event_details = ProviderEventDetails()
@@ -475,7 +487,9 @@ def initialize_provider(
475487
self._refresh_aggregate_status(event_details, force=True)
476488

477489
if error_messages:
478-
raise GeneralError(f"Multi-provider initialization failed: {'; '.join(error_messages)}")
490+
raise GeneralError(
491+
f"Multi-provider initialization failed: {'; '.join(error_messages)}"
492+
)
479493

480494
def shutdown(self) -> None:
481495
for _, provider in self._registered_providers:
@@ -488,7 +502,9 @@ def shutdown_provider(entry: tuple[str, FeatureProvider]) -> None:
488502
except Exception:
489503
logger.exception("Provider '%s' shutdown failed", provider_name)
490504

491-
with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor:
505+
with ThreadPoolExecutor(
506+
max_workers=len(self._registered_providers)
507+
) as executor:
492508
list(executor.map(shutdown_provider, self._registered_providers))
493509

494510
with self._status_lock:
@@ -522,7 +538,9 @@ def _handle_provider_event(
522538
provider_name,
523539
self._status_from_event_details(details),
524540
)
525-
self._refresh_aggregate_status(self._with_provider_metadata(details, provider_name))
541+
self._refresh_aggregate_status(
542+
self._with_provider_metadata(details, provider_name)
543+
)
526544

527545
def _set_provider_status(
528546
self, provider_name: str, provider_status: ProviderStatus
@@ -547,7 +565,9 @@ def _should_evaluate_provider(self, provider_name: str) -> bool:
547565
if not self._initialized:
548566
return True
549567
with self._status_lock:
550-
status = self._provider_statuses.get(provider_name, ProviderStatus.NOT_READY)
568+
status = self._provider_statuses.get(
569+
provider_name, ProviderStatus.NOT_READY
570+
)
551571
return status not in (ProviderStatus.NOT_READY, ProviderStatus.FATAL)
552572

553573
def _calculate_aggregate_status(self) -> ProviderStatus:
@@ -577,7 +597,9 @@ def _refresh_aggregate_status(
577597
if event_to_emit is not None:
578598
self.emit(event_to_emit, event_details)
579599

580-
def _event_from_status(self, provider_status: ProviderStatus) -> ProviderEvent | None:
600+
def _event_from_status(
601+
self, provider_status: ProviderStatus
602+
) -> ProviderEvent | None:
581603
if provider_status == ProviderStatus.READY:
582604
return ProviderEvent.PROVIDER_READY
583605
if provider_status == ProviderStatus.STALE:
@@ -632,9 +654,7 @@ def _details_from_exception(
632654
self, err: Exception, provider_name: str
633655
) -> ProviderEventDetails:
634656
error_code = (
635-
err.error_code
636-
if isinstance(err, OpenFeatureError)
637-
else ErrorCode.GENERAL
657+
err.error_code if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL
638658
)
639659
error_message = self._error_message_from_exception(err)
640660
return ProviderEventDetails(
@@ -652,9 +672,7 @@ def _resolution_from_exception(
652672
self, default_value: T, err: Exception
653673
) -> FlagResolutionDetails[T]:
654674
error_code = (
655-
err.error_code
656-
if isinstance(err, OpenFeatureError)
657-
else ErrorCode.GENERAL
675+
err.error_code if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL
658676
)
659677
error_message = self._error_message_from_exception(err)
660678
return FlagResolutionDetails(
@@ -709,7 +727,9 @@ def _evaluate_provider_sync( # noqa: PLR0913
709727
return _ProviderEvaluation(
710728
provider_name=provider_name,
711729
provider=provider,
712-
result=resolve_fn(provider, flag_key, default_value, evaluation_context),
730+
result=resolve_fn(
731+
provider, flag_key, default_value, evaluation_context
732+
),
713733
)
714734
except Exception as err:
715735
return _ProviderEvaluation(
@@ -821,7 +841,9 @@ async def _evaluate_provider_async( # noqa: PLR0913
821841
try:
822842
before_context = before_hooks(flag_type, hook_contexts, runtime.hook_hints)
823843
resolved_context = provider_context.merge(before_context)
824-
resolution = await resolve_fn(provider, flag_key, default_value, resolved_context)
844+
resolution = await resolve_fn(
845+
provider, flag_key, default_value, resolved_context
846+
)
825847
flag_evaluation = resolution.to_flag_evaluation_details(flag_key)
826848
if err := flag_evaluation.get_exception():
827849
error_hooks(
@@ -883,12 +905,16 @@ def _evaluate_with_providers(
883905
if self._should_evaluate_provider(name)
884906
]
885907

908+
evaluations: list[_ProviderEvaluation[T]] = []
909+
886910
if self.strategy.run_mode == "parallel":
887911
# Each worker thread gets its own copy of the current context so
888912
# that ContextVars (e.g. _hook_runtime) are propagated correctly.
889913
# ThreadPoolExecutor does not automatically copy context on
890914
# Python < 3.12, and a single Context.run() is not reentrant.
891-
with ThreadPoolExecutor(max_workers=len(eligible_providers) or 1) as executor:
915+
with ThreadPoolExecutor(
916+
max_workers=len(eligible_providers) or 1
917+
) as executor:
892918
futures = [
893919
executor.submit(
894920
contextvars.copy_context().run,
@@ -916,7 +942,6 @@ def _evaluate_with_providers(
916942
),
917943
)
918944

919-
evaluations: list[_ProviderEvaluation[T]] = []
920945
for provider_name, provider in eligible_providers:
921946
evaluation = self._evaluate_provider_sync(
922947
provider_name,
@@ -970,6 +995,8 @@ async def _evaluate_with_providers_async(
970995
if self._should_evaluate_provider(name)
971996
]
972997

998+
evaluations: list[_ProviderEvaluation[T]] = []
999+
9731000
if self.strategy.run_mode == "parallel":
9741001
tasks = [
9751002
asyncio.create_task(
@@ -985,20 +1012,19 @@ async def _evaluate_with_providers_async(
9851012
)
9861013
for provider_name, provider in eligible_providers
9871014
]
988-
evaluations = await asyncio.gather(*tasks)
1015+
evaluations = list(await asyncio.gather(*tasks))
9891016
return typing.cast(
9901017
FlagResolutionDetails[T],
9911018
self.strategy.determine_final_result(
9921019
flag_key,
9931020
default_value,
9941021
typing.cast(
9951022
list[_ProviderEvaluation[FlagValueType]],
996-
list(evaluations),
1023+
evaluations,
9971024
),
9981025
),
9991026
)
10001027

1001-
evaluations: list[_ProviderEvaluation[T]] = []
10021028
for provider_name, provider in eligible_providers:
10031029
evaluation = await self._evaluate_provider_async(
10041030
provider_name,

tests/test_multi_provider.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def resolve_boolean_details(
6060
) -> FlagResolutionDetails[bool]:
6161
del flag_key
6262
self.resolveCount += 1
63-
self.seenContexts.append(dict((evaluation_context or EvaluationContext()).attributes))
63+
self.seenContexts.append(
64+
dict((evaluation_context or EvaluationContext()).attributes)
65+
)
6466
if self.sync_blocker is not None:
6567
self.sync_blocker.wait()
6668
if self.booleanException is not None:
@@ -77,7 +79,9 @@ async def resolve_boolean_details_async(
7779
) -> FlagResolutionDetails[bool]:
7880
del flag_key
7981
self.resolveCount += 1
80-
self.seenContexts.append(dict((evaluation_context or EvaluationContext()).attributes))
82+
self.seenContexts.append(
83+
dict((evaluation_context or EvaluationContext()).attributes)
84+
)
8185
if self.async_blocker is not None:
8286
await self.async_blocker.wait()
8387
if self.booleanException is not None:
@@ -225,7 +229,9 @@ def test_comparison_strategy_rejects_unknown_fallback_provider():
225229
first_provider = BooleanProvider("first")
226230
second_provider = BooleanProvider("second")
227231

228-
with pytest.raises(ValueError, match="Fallback provider 'missing' is not registered"):
232+
with pytest.raises(
233+
ValueError, match="Fallback provider 'missing' is not registered"
234+
):
229235
MultiProvider(
230236
[
231237
ProviderEntry(first_provider, name="first"),
@@ -311,7 +317,9 @@ def test_first_successful_skips_general_errors():
311317

312318
def test_first_successful_aggregates_errors_when_all_providers_fail():
313319
first_provider = BooleanProvider("first", boolean_exception=GeneralError("first"))
314-
second_provider = BooleanProvider("second", boolean_exception=GeneralError("second"))
320+
second_provider = BooleanProvider(
321+
"second", boolean_exception=GeneralError("second")
322+
)
315323
multi_provider = MultiProvider(
316324
[
317325
ProviderEntry(first_provider, name="first"),
@@ -583,8 +591,12 @@ def test_multi_provider_forwards_configuration_changed_events():
583591
spy.provider_configuration_changed,
584592
)
585593

586-
first_provider.emit_provider_configuration_changed(ProviderEventDetails(message="one"))
587-
second_provider.emit_provider_configuration_changed(ProviderEventDetails(message="two"))
594+
first_provider.emit_provider_configuration_changed(
595+
ProviderEventDetails(message="one")
596+
)
597+
second_provider.emit_provider_configuration_changed(
598+
ProviderEventDetails(message="two")
599+
)
588600

589601
assert spy.provider_configuration_changed.call_count == 2
590602

uv.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)