Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion image_classification/efficientnet_fp16_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ export class EfficientNetFP16Nchw {
const pool1 = this.builder_.averagePool2d(await conv22);
const reshape = this.builder_.reshape(pool1, [1, 1280]);
const gemm = this.buildGemm_(reshape, '0');
const softmax = this.builder_.softmax(await gemm);
const softmax = this.builder_.softmax(await gemm, 1);

return this.builder_.cast(softmax, 'float32');
}
Expand Down
4 changes: 2 additions & 2 deletions image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,13 @@ export class MobileNetV2Nchw {
const pool = this.builder_.averagePool2d(await conv3);
const reshape = this.builder_.reshape(pool, [1, 1280]);
const gemm = this.buildGemm_(reshape, '104');
return this.builder_.softmax(await gemm);
return this.builder_.softmax(await gemm, 1);
} else {
const conv4 = this.buildConv_(await conv3, '97', false,
{groups: 1280, strides: [7, 7]});
const conv5 = this.buildConv_(await conv4, '104', false);
const reshape = this.builder_.reshape(await conv5, [1, 1000]);
const softmax = this.builder_.softmax(reshape);
const softmax = this.builder_.softmax(reshape, 1);
return this.builder_.cast(softmax, 'float32');
}
}
Expand Down
2 changes: 1 addition & 1 deletion image_classification/mobilenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ export class MobileNetV2Nhwc {
const conv4 = this.buildConv_(
averagePool2d, '222', 'Logits_Conv2d_1c_1x1_Conv2D', false, {autoPad, filterLayout});
const reshape = this.builder_.reshape(await conv4, [1, 1001]);
return await this.builder_.softmax(reshape);
return await this.builder_.softmax(reshape, 1);
}

async build(outputOperand) {
Expand Down
2 changes: 1 addition & 1 deletion image_classification/mobilenet_uint8_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ export class MobileNetV2Uint8Nhwc {
{scale: [0.06046031787991524], zero_point: [60], shape: []},
false, {autoPad, filterLayout});
const reshape = this.builder_.reshape(conv4, [1, 1001]);
const softmax = this.builder_.softmax(reshape);
const softmax = this.builder_.softmax(reshape, 1);

return softmax;
}
Expand Down
2 changes: 1 addition & 1 deletion image_classification/resnet50v1_fp16_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ export class ResNet50V1FP16Nchw {
const pool2 = this.builder_.averagePool2d(await bottleneck16);
const reshape = this.builder_.reshape(pool2, [1, 2048]);
const gemm = this.buildGemm_(reshape, '0');
const softmax = this.builder_.softmax(await gemm);
const softmax = this.builder_.softmax(await gemm, 1);
return this.builder_.cast(softmax, 'float32');
}

Expand Down
2 changes: 1 addition & 1 deletion image_classification/resnet50v2_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ export class ResNet50V2Nchw {
const pool2 = this.builder_.averagePool2d(await bn3);
const reshape = this.builder_.reshape(await pool2, [1, 2048]);
const gemm = this.buildGemm_(await reshape, '0');
return this.builder_.softmax(await gemm);
return this.builder_.softmax(await gemm, 1);
}

async build(outputOperand) {
Expand Down
2 changes: 1 addition & 1 deletion image_classification/resnet50v2_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ export class ResNet50V2Nhwc {
const conv2 = this.buildConv_(
mean, ['', '', 'logits'], {autoPad}, false);
const reshape = this.builder_.reshape(await conv2, [1, 1001]);
return this.builder_.softmax(reshape);
return this.builder_.softmax(reshape, 1);
}

async build(outputOperand) {
Expand Down
2 changes: 1 addition & 1 deletion image_classification/squeezenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ export class SqueezeNetNchw {
const pool3 = this.builder_.averagePool2d(
await conv25, {windowDimensions: [13, 13], strides: [13, 13]});
const reshape0 = this.builder_.reshape(pool3, [1, 1000]);
return this.builder_.softmax(reshape0);
return this.builder_.softmax(reshape0, 1);
}

async build(outputOperand) {
Expand Down
2 changes: 1 addition & 1 deletion image_classification/squeezenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ export class SqueezeNetNhwc {
const averagePool2d = this.builder_.averagePool2d(
await conv10, {windowDimensions: [13, 13], layout});
const reshape = this.builder_.reshape(averagePool2d, [1, 1001]);
return this.builder_.softmax(reshape);
return this.builder_.softmax(reshape, 1);
}

async build(outputOperand) {
Expand Down
2 changes: 1 addition & 1 deletion lenet/lenet.js
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ export class LeNet {
new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add4BiasShape)));
const add4 = this.builder_.add(matmul2, add4Bias);

return this.builder_.softmax(add4);
return this.builder_.softmax(add4, 1);
}

async build(outputOperand) {
Expand Down