Skip to content

Commit 5c90090

Browse files
refactor: improve auto-model resolution for task
1 parent f523672 commit 5c90090

File tree

2 files changed

+41
-36
lines changed

2 files changed

+41
-36
lines changed

src/Pipelines/Pipeline.php

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,25 @@ function pipeline(
7373

7474
$modelName ??= $task->defaultModelName();
7575

76-
$model = $task->autoModel($modelName, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress);
76+
// Resolve the candidate auto-model class(es) for this task. autoModelClass()
// returns either a single class-string or a list of them, so normalise to a list.
$modelClasses = $task->autoModelClass();
$modelClasses = is_array($modelClasses) ? $modelClasses : [$modelClasses];

$model = null;
$lastException = null;

// Try each candidate in order and keep the first one that loads. The loop
// variable is kept distinct from the list being iterated so neither is
// overwritten while the loop runs.
foreach ($modelClasses as $modelClass) {
    try {
        $model = $modelClass::fromPretrained($modelName, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress);
        break;
    } catch (\Throwable $e) {
        // Remember the failure so it can be attached as the cause if every candidate fails.
        $lastException = $e;
    }
}

// No candidate could be loaded: surface the most recent failure as the previous exception.
if ($model === null) {
    throw new \RuntimeException('Could not instantiate model for task: ' . $task->value, 0, $lastException);
}
7795

7896
$tokenizer = $task->autoTokenizer($modelName, $cacheDir, $revision, $onProgress);
7997

src/Pipelines/Task.php

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -151,61 +151,50 @@ public function defaultModelName(): string
151151
};
152152
}
153153

154-
public function autoModel(
155-
string $modelNameOrPath,
156-
bool $quantized = true,
157-
?array $config = null,
158-
?string $cacheDir = null,
159-
string $revision = 'main',
160-
?string $modelFilename = null,
161-
?callable $onProgress = null
162-
): PretrainedModel
154+
/**
155+
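 * Returns the auto-model class (or classes) used to load a model for this task.
 *
 * A single class-string means there is exactly one candidate. An array lists
 * candidates in order of preference: the pipeline tries each one in turn and
 * uses the first that loads successfully.
 *
 * @return class-string<PretrainedModel>|array<class-string<PretrainedModel>>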
156+
*/
157+
public function autoModelClass(): string|array
163158
{
164159
return match ($this) {
165160
self::SentimentAnalysis,
166161
self::TextClassification,
167-
self::ZeroShotClassification => AutoModelForSequenceClassification::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
162+
self::ZeroShotClassification => AutoModelForSequenceClassification::class,
168163

169-
self::FillMask => AutoModelForMaskedLM::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
164+
self::FillMask => AutoModelForMaskedLM::class,
170165

171-
self::QuestionAnswering => AutoModelForQuestionAnswering::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
166+
self::QuestionAnswering => AutoModelForQuestionAnswering::class,
172167

173168
self::FeatureExtraction,
174-
self::Embeddings => AutoModel::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
169+
self::Embeddings => AutoModel::class,
175170

176171
self::Text2TextGeneration,
177172
self::Translation,
178-
self::Summarization => AutoModelForSeq2SeqLM::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
173+
self::Summarization => AutoModelForSeq2SeqLM::class,
179174

180-
self::TextGeneration => AutoModelForCausalLM::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
175+
self::TextGeneration => AutoModelForCausalLM::class,
181176

182177
self::TokenClassification,
183-
self::Ner => AutoModelForTokenClassification::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
178+
self::Ner => AutoModelForTokenClassification::class,
184179

185-
self::ImageToText => AutoModelForVision2Seq::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
180+
self::ImageToText => AutoModelForVision2Seq::class,
186181

187-
self::ImageClassification => AutoModelForImageClassification::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
182+
self::ImageClassification => AutoModelForImageClassification::class,
188183

189-
self::ImageFeatureExtraction => AutoModelForImageFeatureExtraction::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
184+
self::ImageFeatureExtraction => [AutoModelForImageFeatureExtraction::class, AutoModel::class],
190185

191-
self::ZeroShotImageClassification => AutoModel::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
186+
self::ZeroShotImageClassification => AutoModel::class,
192187

193-
self::ImageToImage => AutoModelForImageToImage::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
188+
self::ImageToImage => AutoModelForImageToImage::class,
194189

195-
self::ObjectDetection => AutoModelForObjectDetection::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
190+
self::ObjectDetection => AutoModelForObjectDetection::class,
196191

197-
self::ZeroShotObjectDetection => AutoModelForZeroShotObjectDetection::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
192+
self::ZeroShotObjectDetection => AutoModelForZeroShotObjectDetection::class,
198193

199-
self::AudioClassification => AutoModelForAudioClassification::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress),
194+
self::AudioClassification => AutoModelForAudioClassification::class,
200195

201196
self::ASR,
202-
self::AutomaticSpeechRecognition => (function () use ($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress) {
203-
try {
204-
return AutoModelForSpeechSeq2Seq::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress);
205-
} catch (UnsupportedModelTypeException) {
206-
return AutoModelForCTC::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $revision, $modelFilename, $onProgress);
207-
}
208-
})(),
197+
self::AutomaticSpeechRecognition => [AutoModelForSpeechSeq2Seq::class, AutoModelForCTC::class],
209198
};
210199
}
211200

@@ -214,8 +203,7 @@ public function autoTokenizer(
214203
?string $cacheDir = null,
215204
string $revision = 'main',
216205
?callable $onProgress = null
217-
): ?PreTrainedTokenizer
218-
{
206+
): ?PreTrainedTokenizer {
219207
return match ($this) {
220208

221209
self::ImageClassification,
@@ -252,8 +240,7 @@ public function autoProcessor(
252240
?string $cacheDir = null,
253241
string $revision = 'main',
254242
?callable $onProgress = null
255-
): ?Processor
256-
{
243+
): ?Processor {
257244
return match ($this) {
258245

259246
self::ImageToText,

0 commit comments

Comments
 (0)