11from abc import ABCMeta , abstractmethod
22from collections .abc import Callable , Iterable , Sequence
3+ from enum import Enum
34from typing import Any , Literal
45from typing_extensions import Self , TypeAlias
56
@@ -99,9 +100,6 @@ class Accuracy(MeanMetricWrapper):
99100class CategoricalAccuracy (MeanMetricWrapper ):
100101 def __init__ (self , name : str | None = "categorical_accuracy" , dtype : DTypeLike | None = None ) -> None : ...
101102
102- class Mean (MeanMetricWrapper ):
103- def __init__ (self , name : str | None = "mean" , dtype : DTypeLike | None = None ) -> None : ...
104-
105103class TopKCategoricalAccuracy (MeanMetricWrapper ):
106104 def __init__ (self , k : int = 5 , name : str | None = "top_k_categorical_accuracy" , dtype : DTypeLike | None = None ) -> None : ...
107105
@@ -110,6 +108,19 @@ class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
110108 self , k : int = 5 , name : str | None = "sparse_top_k_categorical_accuracy" , dtype : DTypeLike | None = None
111109 ) -> None : ...
112110
111+ class _Reduction (Enum ):
112+ SUM = "sum"
113+ SUM_OVER_BATCH_SIZE = "sum_over_batch_size"
114+ WEIGHTED_MEAN = "weighted_mean"
115+
116+ class Reduce (Metric ):
117+ def __init__ (self , reduction : _Reduction , name : str | None , dtype : DTypeLike | None = None ) -> None : ...
118+ def update_state (self , values : TensorCompatible , sample_weight : TensorCompatible | None = None ) -> Operation : ...
119+ def result (self ) -> Tensor : ...
120+
121+ class Mean (Reduce ):
122+ def __init__ (self , name : str | None = "mean" , dtype : DTypeLike | None = None ) -> None : ...
123+
113124def serialize (metric : KerasSerializable ) -> dict [str , Any ]: ...
114125def binary_crossentropy (
115126 y_true : TensorCompatible , y_pred : TensorCompatible , from_logits : bool = False , label_smoothing : float = 0.0 , axis : int = - 1
0 commit comments