From 342c79c1d83f8a348cbf89860e82aa35c14275c8 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 23 Oct 2025 10:14:02 +0800 Subject: [PATCH] Use lastest softmax API --- image_classification/efficientnet_fp16_nchw.js | 2 +- image_classification/mobilenet_nchw.js | 4 ++-- image_classification/mobilenet_nhwc.js | 2 +- image_classification/mobilenet_uint8_nhwc.js | 2 +- image_classification/resnet50v1_fp16_nchw.js | 2 +- image_classification/resnet50v2_nchw.js | 2 +- image_classification/resnet50v2_nhwc.js | 2 +- image_classification/squeezenet_nchw.js | 2 +- image_classification/squeezenet_nhwc.js | 2 +- lenet/lenet.js | 2 +- 10 files changed, 11 insertions(+), 11 deletions(-) diff --git a/image_classification/efficientnet_fp16_nchw.js b/image_classification/efficientnet_fp16_nchw.js index 3bf9d373..c7eb577a 100644 --- a/image_classification/efficientnet_fp16_nchw.js +++ b/image_classification/efficientnet_fp16_nchw.js @@ -161,7 +161,7 @@ export class EfficientNetFP16Nchw { const pool1 = this.builder_.averagePool2d(await conv22); const reshape = this.builder_.reshape(pool1, [1, 1280]); const gemm = this.buildGemm_(reshape, '0'); - const softmax = this.builder_.softmax(await gemm); + const softmax = this.builder_.softmax(await gemm, 1); return this.builder_.cast(softmax, 'float32'); } diff --git a/image_classification/mobilenet_nchw.js b/image_classification/mobilenet_nchw.js index 6a3b4879..b14d7a0b 100644 --- a/image_classification/mobilenet_nchw.js +++ b/image_classification/mobilenet_nchw.js @@ -153,13 +153,13 @@ export class MobileNetV2Nchw { const pool = this.builder_.averagePool2d(await conv3); const reshape = this.builder_.reshape(pool, [1, 1280]); const gemm = this.buildGemm_(reshape, '104'); - return this.builder_.softmax(await gemm); + return this.builder_.softmax(await gemm, 1); } else { const conv4 = this.buildConv_(await conv3, '97', false, {groups: 1280, strides: [7, 7]}); const conv5 = this.buildConv_(await conv4, '104', false); const reshape = this.builder_.reshape(await conv5, [1, 1000]); - const softmax = this.builder_.softmax(reshape); + const softmax = this.builder_.softmax(reshape, 1); return this.builder_.cast(softmax, 'float32'); } } diff --git a/image_classification/mobilenet_nhwc.js b/image_classification/mobilenet_nhwc.js index ff59d3f2..733e225a 100644 --- a/image_classification/mobilenet_nhwc.js +++ b/image_classification/mobilenet_nhwc.js @@ -153,7 +153,7 @@ export class MobileNetV2Nhwc { const conv4 = this.buildConv_( averagePool2d, '222', 'Logits_Conv2d_1c_1x1_Conv2D', false, {autoPad, filterLayout}); const reshape = this.builder_.reshape(await conv4, [1, 1001]); - return await this.builder_.softmax(reshape); + return await this.builder_.softmax(reshape, 1); } async build(outputOperand) { diff --git a/image_classification/mobilenet_uint8_nhwc.js b/image_classification/mobilenet_uint8_nhwc.js index 92c2a326..06bd4553 100644 --- a/image_classification/mobilenet_uint8_nhwc.js +++ b/image_classification/mobilenet_uint8_nhwc.js @@ -465,7 +465,7 @@ export class MobileNetV2Uint8Nhwc { {scale: [0.06046031787991524], zero_point: [60], shape: []}, false, {autoPad, filterLayout}); const reshape = this.builder_.reshape(conv4, [1, 1001]); - const softmax = this.builder_.softmax(reshape); + const softmax = this.builder_.softmax(reshape, 1); return softmax; } diff --git a/image_classification/resnet50v1_fp16_nchw.js b/image_classification/resnet50v1_fp16_nchw.js index db4035cb..7e27e821 100644 --- a/image_classification/resnet50v1_fp16_nchw.js +++ b/image_classification/resnet50v1_fp16_nchw.js @@ -130,7 +130,7 @@ export class ResNet50V1FP16Nchw { const pool2 = this.builder_.averagePool2d(await bottleneck16); const reshape = this.builder_.reshape(pool2, [1, 2048]); const gemm = this.buildGemm_(reshape, '0'); - const softmax = this.builder_.softmax(await gemm); + const softmax = this.builder_.softmax(await gemm, 1); return this.builder_.cast(softmax, 'float32'); } diff --git a/image_classification/resnet50v2_nchw.js b/image_classification/resnet50v2_nchw.js index 17a56104..27a72c4d 100644 --- a/image_classification/resnet50v2_nchw.js +++ b/image_classification/resnet50v2_nchw.js @@ -167,7 +167,7 @@ export class ResNet50V2Nchw { const pool2 = this.builder_.averagePool2d(await bn3); const reshape = this.builder_.reshape(await pool2, [1, 2048]); const gemm = this.buildGemm_(await reshape, '0'); - return this.builder_.softmax(await gemm); + return this.builder_.softmax(await gemm, 1); } async build(outputOperand) { diff --git a/image_classification/resnet50v2_nhwc.js b/image_classification/resnet50v2_nhwc.js index 01af991d..4c110ee6 100644 --- a/image_classification/resnet50v2_nhwc.js +++ b/image_classification/resnet50v2_nhwc.js @@ -201,7 +201,7 @@ export class ResNet50V2Nhwc { const conv2 = this.buildConv_( mean, ['', '', 'logits'], {autoPad}, false); const reshape = this.builder_.reshape(await conv2, [1, 1001]); - return this.builder_.softmax(reshape); + return this.builder_.softmax(reshape, 1); } async build(outputOperand) { diff --git a/image_classification/squeezenet_nchw.js b/image_classification/squeezenet_nchw.js index 9ee5300f..c76330ed 100644 --- a/image_classification/squeezenet_nchw.js +++ b/image_classification/squeezenet_nchw.js @@ -80,7 +80,7 @@ export class SqueezeNetNchw { const pool3 = this.builder_.averagePool2d( await conv25, {windowDimensions: [13, 13], strides: [13, 13]}); const reshape0 = this.builder_.reshape(pool3, [1, 1000]); - return this.builder_.softmax(reshape0); + return this.builder_.softmax(reshape0, 1); } async build(outputOperand) { diff --git a/image_classification/squeezenet_nhwc.js b/image_classification/squeezenet_nhwc.js index 44f86668..debbe9b6 100644 --- a/image_classification/squeezenet_nhwc.js +++ b/image_classification/squeezenet_nhwc.js @@ -96,7 +96,7 @@ export class SqueezeNetNhwc { const averagePool2d = this.builder_.averagePool2d( await conv10, {windowDimensions: [13, 13], layout}); const reshape = this.builder_.reshape(averagePool2d, [1, 1001]); - return this.builder_.softmax(reshape); + return this.builder_.softmax(reshape, 1); } async build(outputOperand) { diff --git a/lenet/lenet.js b/lenet/lenet.js index 7f70470f..985d8532 100644 --- a/lenet/lenet.js +++ b/lenet/lenet.js @@ -174,7 +174,7 @@ export class LeNet { new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add4BiasShape))); const add4 = this.builder_.add(matmul2, add4Bias); - return this.builder_.softmax(add4); + return this.builder_.softmax(add4, 1); } async build(outputOperand) {