Skip to content

Commit bfc1bdb

Browse files
committed
fix: Mean base class
1 parent 0ec1986 commit bfc1bdb

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

stubs/tensorflow/tensorflow/keras/metrics.pyi

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABCMeta, abstractmethod
22
from collections.abc import Callable, Iterable, Sequence
3+
from enum import Enum
34
from typing import Any, Literal
45
from typing_extensions import Self, TypeAlias
56

@@ -99,9 +100,6 @@ class Accuracy(MeanMetricWrapper):
99100
class CategoricalAccuracy(MeanMetricWrapper):
100101
def __init__(self, name: str | None = "categorical_accuracy", dtype: DTypeLike | None = None) -> None: ...
101102

102-
class Mean(MeanMetricWrapper):
103-
def __init__(self, name: str | None = "mean", dtype: DTypeLike | None = None) -> None: ...
104-
105103
class TopKCategoricalAccuracy(MeanMetricWrapper):
106104
def __init__(self, k: int = 5, name: str | None = "top_k_categorical_accuracy", dtype: DTypeLike | None = None) -> None: ...
107105

@@ -110,6 +108,19 @@ class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
110108
self, k: int = 5, name: str | None = "sparse_top_k_categorical_accuracy", dtype: DTypeLike | None = None
111109
) -> None: ...
112110

111+
class _Reduction(Enum):
112+
SUM = "sum"
113+
SUM_OVER_BATCH_SIZE = "sum_over_batch_size"
114+
WEIGHTED_MEAN = "weighted_mean"
115+
116+
class Reduce(Metric):
117+
def __init__(self, reduction: _Reduction, name: str | None, dtype: DTypeLike | None = None) -> None: ...
118+
def update_state(self, values: TensorCompatible, sample_weight: TensorCompatible | None = None) -> Operation: ...
119+
def result(self) -> Tensor: ...
120+
121+
class Mean(Reduce):
122+
def __init__(self, name: str | None = "mean", dtype: DTypeLike | None = None) -> None: ...
123+
113124
def serialize(metric: KerasSerializable) -> dict[str, Any]: ...
114125
def binary_crossentropy(
115126
y_true: TensorCompatible, y_pred: TensorCompatible, from_logits: bool = False, label_smoothing: float = 0.0, axis: int = -1

0 commit comments

Comments
 (0)