Skip to content

Commit 9e2ee6b

Browse files
test: add comprehensive tests Inference session
1 parent e89b340 commit 9e2ee6b

File tree

13 files changed

+373
-78
lines changed

13 files changed

+373
-78
lines changed

src/FFI/OnnxRuntime.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ public function GetMapKeyType($mapTypeInfo): CData
571571

572572
public function GetMapValueType($mapTypeInfo): CData
573573
{
574-
$keyType = $this->new('OrtTypeInfo');
574+
$keyType = $this->new('OrtTypeInfo*');
575575

576576
$this->checkStatus((($this->api)->GetMapValueType)($mapTypeInfo, FFI::addr($keyType)));
577577

src/Tensor/Tensor.php

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ public function __construct(
6565
int $dtype = null,
6666
array $shape = null,
6767
int $offset = null,
68-
)
69-
{
68+
) {
7069
if ($array === null && $dtype === null && $shape === null && $offset === null) {
7170
// Empty definition for Unserialize
7271
return;
@@ -174,7 +173,7 @@ protected function isBuffer(mixed $buffer): bool
174173
return $buffer instanceof Buffer;
175174
}
176175

177-
protected function isComplex(int $dtype = null): bool
176+
protected function isComplex(?int $dtype = null): bool
178177
{
179178
$dtype = $dtype ?? $this->dtype;
180179
return $this->cistype($dtype);
@@ -193,11 +192,13 @@ protected function assertShape(array $shape): void
193192
foreach ($shape as $num) {
194193
if (!is_int($num)) {
195194
throw new InvalidArgumentException(
196-
"Invalid shape numbers. It gives ".gettype($num));
195+
"Invalid shape numbers. It gives " . gettype($num)
196+
);
197197
}
198198
if ($num < 0) {
199199
throw new InvalidArgumentException(
200-
"Invalid shape numbers. It gives ".$num);
200+
"Invalid shape numbers. It gives " . $num
201+
);
201202
}
202203
}
203204
}
@@ -282,7 +283,7 @@ public static function service(): Service
282283
{
283284
if (!isset(self::$service)) {
284285
self::$service = new TensorService();
285-
// self::$service = new MatlibPhp();
286+
// self::$service = new MatlibPhp();
286287
}
287288

288289
return self::$service;
@@ -361,8 +362,8 @@ public function reshape(array $shape): static
361362
$this->assertShape($shape);
362363

363364
if ($this->size() != array_product($shape)) {
364-
throw new InvalidArgumentException("Unmatched size to reshape: ".
365-
"[".implode(',', $this->shape())."]=>[".implode(',', $shape)."]");
365+
throw new InvalidArgumentException("Unmatched size to reshape: " .
366+
"[" . implode(',', $this->shape()) . "]=>[" . implode(',', $shape) . "]");
366367
}
367368

368369
return new self($this->buffer(), $this->dtype(), $shape, $this->offset());
@@ -422,7 +423,7 @@ public function toString(): string
422423
*/
423424
public function toBufferArray(): array
424425
{
425-
$fmt = self::$pack[$this->dtype].'*';
426+
$fmt = self::$pack[$this->dtype] . '*';
426427

427428
return array_values(unpack($fmt, $this->buffer->dump()));
428429
}
@@ -565,8 +566,9 @@ public static function concat(array $tensors, int $axis = 0): Tensor
565566
public static function safeIndex(int $index, int $size, ?int $axis = null): int
566567
{
567568
if ($index < -$size || $index >= $size) {
568-
throw new InvalidArgumentException("IndexError: index $index is out of bounds for axis"
569-
.($axis === null ? '' : ' '.$axis)." with size $size"
569+
throw new InvalidArgumentException(
570+
"IndexError: index $index is out of bounds for axis"
571+
. ($axis === null ? '' : ' ' . $axis) . " with size $size"
570572
);
571573
}
572574

@@ -643,7 +645,7 @@ public function sigmoid(): self
643645
{
644646
$mo = self::mo();
645647

646-
$ndArray = $mo->f(fn ($x) => 1 / (1 + exp(-$x)), $this);
648+
$ndArray = $mo->f(fn($x) => 1 / (1 + exp(-$x)), $this);
647649

648650
return new static($ndArray->buffer(), $ndArray->dtype(), $ndArray->shape(), $ndArray->offset());
649651
}
@@ -856,7 +858,7 @@ public function norm(int $ord = 2, ?int $axis = null, bool $keepShape = false):
856858
$mo = self::mo();
857859

858860
if ($axis === null) {
859-
$val = pow(array_reduce($this->toBufferArray(), fn ($carry, $item) => $carry + pow($item, $ord), 0), 1 / $ord);
861+
$val = pow(array_reduce($this->toBufferArray(), fn($carry, $item) => $carry + pow($item, $ord), 0), 1 / $ord);
860862

861863
return new Tensor([$val], $this->dtype(), []);
862864
}
@@ -918,7 +920,7 @@ public function clamp(float|int $min, float|int $max): static
918920
{
919921
$mo = self::mo();
920922

921-
$result = $mo->f(fn ($x) => max($min, min($max, $x)), $this);
923+
$result = $mo->f(fn($x) => max($min, min($max, $x)), $this);
922924

923925
return new static($result->buffer(), $result->dtype(), $result->shape(), $result->offset());
924926
}
@@ -932,7 +934,7 @@ public function round(int $precision = 0): static
932934
{
933935
$mo = self::mo();
934936

935-
$result = $mo->f(fn ($x) => round($x, $precision), $this);
937+
$result = $mo->f(fn($x) => round($x, $precision), $this);
936938

937939
return new static($result->buffer(), $result->dtype(), $result->shape(), $result->offset());
938940
}
@@ -1004,7 +1006,11 @@ public function stdMean(?int $axis = null, int $correction = 1, bool $keepShape
10041006
$std = sqrt(
10051007
$mo->sum(
10061008
$mo->la()->pow(
1007-
$mo->la()->increment($this, -$mean), 2)) / ($this->size() - $correction));
1009+
$mo->la()->increment($this, -$mean),
1010+
2
1011+
)
1012+
) / ($this->size() - $correction)
1013+
);
10081014

10091015
return [$std, $mean];
10101016
}
@@ -1127,14 +1133,12 @@ public function slice(...$slices): Tensor
11271133
// null or undefined means take the whole dimension
11281134
$start[] = 0;
11291135
$size[] = $this->shape()[$sliceIndex];
1130-
11311136
} elseif (is_int($slice)) {
11321137
// An integer means take a single element
11331138
$slice = $this->safeIndex($slice, $this->shape()[$sliceIndex], $sliceIndex);
11341139

11351140
$start[] = $slice;
11361141
$size[] = 1;
1137-
11381142
} elseif (is_array($slice) && count($slice) === 2) {
11391143
[$first, $second] = $slice;
11401144

@@ -1146,14 +1150,13 @@ public function slice(...$slices): Tensor
11461150

11471151
// An array of length 2 means take a range of elements
11481152
if ($first > $second) {
1149-
throw new InvalidArgumentException("Invalid slice: ".json_encode($slice));
1153+
throw new InvalidArgumentException("Invalid slice: " . json_encode($slice));
11501154
}
11511155

11521156
$start[] = $first;
11531157
$size[] = $second - $first;
1154-
11551158
} else {
1156-
throw new InvalidArgumentException("Invalid slice: ".json_encode($slice));
1159+
throw new InvalidArgumentException("Invalid slice: " . json_encode($slice));
11571160
}
11581161
}
11591162

@@ -1330,7 +1333,7 @@ public function topk(int $k = -1, bool $sorted = true): array
13301333

13311334
if ($sorted) {
13321335
// Sort the heap to get the top k elements in descending order
1333-
usort($heap, fn ($a, $b) => $b['value'] <=> $a['value']);
1336+
usort($heap, fn($a, $b) => $b['value'] <=> $a['value']);
13341337
}
13351338

13361339
// Extract top K values and indices from the heap
@@ -1430,13 +1433,15 @@ public function offsetExists($offset): bool
14301433
return false;
14311434

14321435
if (is_array($offset)) {
1433-
if (count($offset) != 2 ||
1436+
if (
1437+
count($offset) != 2 ||
14341438
!array_key_exists(0, $offset) || !array_key_exists(1, $offset) ||
1435-
$offset[0] > $offset[1]) {
1439+
$offset[0] > $offset[1]
1440+
) {
14361441
$det = '';
14371442
if (is_numeric($offset[0]) && is_numeric($offset[1]))
1438-
$det = ':['.implode(',', $offset).']';
1439-
throw new OutOfRangeException("Illegal range specification.".$det);
1443+
$det = ':[' . implode(',', $offset) . ']';
1444+
throw new OutOfRangeException("Illegal range specification." . $det);
14401445
}
14411446
$start = $offset[0];
14421447
$limit = $offset[1];
@@ -1451,8 +1456,8 @@ public function offsetExists($offset): bool
14511456
$limit = $offset->limit();
14521457
$delta = $offset->delta();
14531458
if ($start >= $limit || $delta != 1) {
1454-
$det = ":[$start,$limit".(($delta != 1) ? ",$delta" : "").']';
1455-
throw new OutOfRangeException("Illegal range specification.".$det);
1459+
$det = ":[$start,$limit" . (($delta != 1) ? ",$delta" : "") . ']';
1460+
throw new OutOfRangeException("Illegal range specification." . $det);
14561461
}
14571462
} else {
14581463
throw new OutOfRangeException("Dimension must be integer");
@@ -1499,7 +1504,7 @@ public function offsetGet($offset): mixed
14991504
$start = $offset->start();
15001505
$limit = $offset->limit();
15011506
if ($offset->delta() != 1) {
1502-
throw new OutOfRangeException("Illegal range specification.:delta=".$offset->delta());
1507+
throw new OutOfRangeException("Illegal range specification.:delta=" . $offset->delta());
15031508
}
15041509
}
15051510

@@ -1516,8 +1521,12 @@ public function offsetGet($offset): mixed
15161521

15171522
array_unshift($shape, $rowsCount);
15181523

1519-
return new self($this->buffer, $this->dtype,
1520-
$shape, $this->offset + $start * $itemSize);
1524+
return new self(
1525+
$this->buffer,
1526+
$this->dtype,
1527+
$shape,
1528+
$this->offset + $start * $itemSize
1529+
);
15211530
}
15221531

15231532

@@ -1594,7 +1603,7 @@ public function setPortableSerializeMode(bool $mode): void
15941603

15951604
public function serialize(): ?string
15961605
{
1597-
return static::SERIALIZE_NDARRAY_KEYWORD.serialize($this->__serialize());
1606+
return static::SERIALIZE_NDARRAY_KEYWORD . serialize($this->__serialize());
15981607
}
15991608

16001609
public function __serialize()
@@ -1631,8 +1640,10 @@ public function unserialize($data): void
16311640
$buffer = $data->buffer();
16321641
if (get_class($data->service()) !== get_class(self::service())) {
16331642
$newBuffer = self::service()->buffer()->Buffer($buffer->count(), $buffer->dtype());
1634-
if ($data->service()->serviceLevel() >= Service::LV_ADVANCED &&
1635-
self::service()->serviceLevel() >= Service::LV_ADVANCED) {
1643+
if (
1644+
$data->service()->serviceLevel() >= Service::LV_ADVANCED &&
1645+
self::service()->serviceLevel() >= Service::LV_ADVANCED
1646+
) {
16361647
$newBuffer->load($buffer->dump());
16371648
} else {
16381649
$count = $buffer->count();
@@ -1666,15 +1677,16 @@ public function __unserialize($data)
16661677
$this->buffer[$key] = $value;
16671678
}
16681679
} else {
1669-
throw new RuntimeException('Illegal save mode: '.$mode);
1680+
throw new RuntimeException('Illegal save mode: ' . $mode);
16701681
}
16711682
}
16721683

16731684
public function __clone()
16741685
{
16751686
if (self::service()->serviceLevel() >= Service::LV_ADVANCED) {
16761687
$newBuffer = self::service()->buffer()->Buffer(
1677-
count($this->buffer), $this->buffer->dtype()
1688+
count($this->buffer),
1689+
$this->buffer->dtype()
16781690
);
16791691

16801692
$newBuffer->load($this->buffer->dump());
@@ -1683,8 +1695,7 @@ public function __clone()
16831695
} elseif (self::service()->serviceLevel() >= Service::LV_BASIC) {
16841696
$this->buffer = clone $this->buffer;
16851697
} else {
1686-
throw new RuntimeException('Unknown buffer type is uncloneable:'.get_class($this->buffer));
1698+
throw new RuntimeException('Unknown buffer type is uncloneable:' . get_class($this->buffer));
16871699
}
16881700
}
1689-
16901701
}

0 commit comments

Comments
 (0)