Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 5 additions & 18 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,6 @@ class TreeAggregatorClassifier : public TreeAggregatorSum<InputType, ThresholdTy
private:
const std::vector<int64_t>& class_labels_;
bool binary_case_;
bool weights_are_all_positive_;
int64_t positive_label_;
int64_t negative_label_;

Expand All @@ -491,13 +490,11 @@ class TreeAggregatorClassifier : public TreeAggregatorSum<InputType, ThresholdTy
const std::vector<ThresholdType>& base_values,
const std::vector<int64_t>& class_labels,
bool binary_case,
bool weights_are_all_positive,
int64_t positive_label = 1,
int64_t negative_label = 0) : TreeAggregatorSum<InputType, ThresholdType, OutputType>(n_trees, n_targets_or_classes,
post_transform, base_values),
class_labels_(class_labels),
binary_case_(binary_case),
weights_are_all_positive_(weights_are_all_positive),
positive_label_(positive_label),
negative_label_(negative_label) {}

Expand Down Expand Up @@ -526,22 +523,12 @@ class TreeAggregatorClassifier : public TreeAggregatorSum<InputType, ThresholdTy
ThresholdType score1, unsigned char has_score1) const {
ThresholdType pos_weight = has_score1 ? score1 : (has_score0 ? score0 : 0); // only 1 class
if (binary_case_) {
if (weights_are_all_positive_) {
if (pos_weight > 0.5) {
write_additional_scores = 0;
return class_labels_[1]; // positive label
} else {
write_additional_scores = 1;
return class_labels_[0]; // negative label
}
if (pos_weight > 0) {
write_additional_scores = 2;
return class_labels_[1]; // positive label
} else {
if (pos_weight > 0) {
write_additional_scores = 2;
return class_labels_[1]; // positive label
} else {
write_additional_scores = 3;
return class_labels_[0]; // negative label
}
write_additional_scores = 3;
return class_labels_[0]; // negative label
}
}
return (pos_weight > 0)
Expand Down
15 changes: 3 additions & 12 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,6 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
template <typename InputType, typename ThresholdType, typename OutputType>
class TreeEnsembleCommonClassifier : public TreeEnsembleCommon<InputType, ThresholdType, OutputType> {
private:
bool weights_are_all_positive_;
bool binary_case_;
std::vector<std::string> classlabels_strings_;
std::vector<int64_t> classlabels_int64s_;
Expand Down Expand Up @@ -974,13 +973,7 @@ Status TreeEnsembleCommonClassifier<InputType, ThresholdType, OutputType>::Init(

InlinedHashSet<int64_t> weights_classes;
weights_classes.reserve(attributes.target_class_ids.size());
weights_are_all_positive_ = true;
for (size_t i = 0, end = attributes.target_class_ids.size(); i < end; ++i) {
weights_classes.insert(attributes.target_class_ids[i]);
if (weights_are_all_positive_ && (!attributes.target_class_weights.empty() ? attributes.target_class_weights[i]
: attributes.target_class_weights_as_tensor[i]) < 0)
weights_are_all_positive_ = false;
}
weights_classes.insert(attributes.target_class_ids.begin(), attributes.target_class_ids.end());
binary_case_ = this->n_targets_or_classes_ == 2 && weights_classes.size() == 1;
if (!classlabels_strings_.empty()) {
class_labels_.reserve(classlabels_strings_.size());
Expand All @@ -1001,8 +994,7 @@ Status TreeEnsembleCommonClassifier<InputType, ThresholdType, OutputType>::compu
TreeAggregatorClassifier<InputType, ThresholdType, OutputType>(
this->roots_.size(), this->n_targets_or_classes_,
this->post_transform_, this->base_values_,
classlabels_int64s_, binary_case_,
weights_are_all_positive_));
classlabels_int64s_, binary_case_));
} else {
int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0];
AllocatorPtr alloc;
Expand All @@ -1013,8 +1005,7 @@ Status TreeEnsembleCommonClassifier<InputType, ThresholdType, OutputType>::compu
TreeAggregatorClassifier<InputType, ThresholdType, OutputType>(
this->roots_.size(), this->n_targets_or_classes_,
this->post_transform_, this->base_values_,
class_labels_, binary_case_,
weights_are_all_positive_));
class_labels_, binary_case_));
const int64_t* plabel = label_int64.Data<int64_t>();
std::string* labels = label->MutableData<std::string>();
for (size_t i = 0; i < (size_t)N; ++i)
Expand Down
108 changes: 108 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -2040,6 +2040,114 @@ def test_get_graph_provider_assignment_info_not_enabled(self):
str(context.exception),
)

def test_tree_ensemble_logistic(self):
    """Check TreeEnsembleClassifier class-1 probabilities with post_transform=LOGISTIC.

    Regression test for https://github.com/microsoft/onnxruntime/issues/27533.
    Three single-tree binary models (all weights attached to class id 0) are run
    on the same input. In every case the value at probs[0][1] must equal the
    sigmoid of (base value + weight of the reached leaf), whether the leaf
    weights are mixed, zero, or negative.
    """
    try:
        import onnx  # noqa: PLC0415
    except ImportError:
        # The ARM build has no onnx package, so the models cannot be built there.
        self.skipTest("onnx is not installed")

    helper = onnx.helper
    input_info = helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, 3])
    label_info = helper.make_tensor_value_info("label", onnx.TensorProto.INT64, [None])
    probs_info = helper.make_tensor_value_info("probs", onnx.TensorProto.FLOAT, [None, 2])

    def make_model(
        nodes_modes,
        nodes_values,
        nodes_truenodeids,
        nodes_falsenodeids,
        class_treeids,
        class_nodeids,
        class_weights,
        **node_kwargs,
    ):
        """Return a one-tree TreeEnsembleClassifier model (labels [0, 1], LOGISTIC).

        Every node belongs to tree 0 and splits on feature 0. Extra keyword
        arguments are forwarded as node attributes; base_values falls back to
        [-0.405] when the caller does not provide it.
        """
        node_count = len(nodes_modes)
        node_kwargs.setdefault("base_values", [-0.405])  # logit(0.4)
        attributes = {
            "nodes_treeids": [0] * node_count,
            "nodes_nodeids": list(range(node_count)),
            "nodes_featureids": [0] * node_count,
            "nodes_values": nodes_values,
            "nodes_modes": nodes_modes,
            "nodes_truenodeids": nodes_truenodeids,
            "nodes_falsenodeids": nodes_falsenodeids,
            "nodes_missing_value_tracks_true": [0] * node_count,
            "nodes_hitrates": [1.0] * node_count,
            "class_treeids": class_treeids,
            "class_nodeids": class_nodeids,
            "class_ids": [0] * len(class_weights),
            "class_weights": class_weights,
            "classlabels_int64s": [0, 1],
            "post_transform": "LOGISTIC",
        }
        attributes.update(node_kwargs)
        classifier = helper.make_node(
            "TreeEnsembleClassifier",
            inputs=["X"],
            outputs=["label", "probs"],
            domain="ai.onnx.ml",
            **attributes,
        )
        graph = helper.make_graph([classifier], "test", [input_info], [label_info, probs_info])
        opsets = [
            helper.make_opsetid("", 15),
            helper.make_opsetid("ai.onnx.ml", 3),
        ]
        return helper.make_model(graph, opset_imports=opsets)

    feeds = {"X": np.array([[0.1, 0.0, 0.0]], dtype=np.float32)}

    # x[0]=0.1 < 0.5 selects the true branch (node 1, weight 0.3):
    # aggregate = -0.405 + 0.3 = -0.105, expected value is sigmoid(-0.105).
    prob_after_split = 1 / (1 + np.exp(0.105))
    # aggregate = -0.405 in both leaf-only models: sigmoid(-0.405) ≈ 0.400.
    prob_leaf_only = 1 / (1 + np.exp(0.405))

    # A lone LEAF node; shared by the two leaf-only scenarios below.
    single_leaf = {
        "nodes_modes": ["LEAF"],
        "nodes_values": [0.0],
        "nodes_truenodeids": [0],
        "nodes_falsenodeids": [0],
        "class_treeids": [0],
        "class_nodeids": [0],
    }

    scenarios = [
        (
            # Root splits on feature 0 at 0.5; leaves carry mixed positive/negative weights.
            "Case 1: Tree with a real split",
            {
                "nodes_modes": ["BRANCH_LT", "LEAF", "LEAF"],
                "nodes_values": [0.5, 0.0, 0.0],
                "nodes_truenodeids": [1, 0, 0],
                "nodes_falsenodeids": [2, 0, 0],
                "class_treeids": [0, 0],
                "class_nodeids": [1, 2],
                "class_weights": [0.3, -0.3],
            },
            prob_after_split,
        ),
        (
            # Non-negative weight; the score comes entirely from the default base value.
            "Case 2: Leaf-only tree (single LEAF node, no splits)",
            {**single_leaf, "class_weights": [0.0]},
            prob_leaf_only,
        ),
        (
            # Workaround form: the base value is moved into a negative weight, base is zero.
            "Case 3: Same leaf-only tree but with a negative weight",
            {**single_leaf, "class_weights": [-0.405], "base_values": [0.0]},
            prob_leaf_only,
        ),
    ]

    for title, model_spec, expected_p1 in scenarios:
        # Session creation and inference stay outside subTest, so a failure
        # there aborts the whole test instead of being recorded per case.
        model = make_model(**model_spec)
        session = onnxrt.InferenceSession(model.SerializeToString())
        outputs = session.run(None, feeds)
        with self.subTest(case=title):
            np.testing.assert_allclose(outputs[1][0][1], expected_p1, atol=1e-5)

Comment thread
xadupre marked this conversation as resolved.

# When this file is executed directly as a script, hand control to the
# unittest runner with verbosity level 1.
if __name__ == "__main__":
    unittest.main(verbosity=1)
Loading