@@ -151,61 +151,50 @@ public function defaultModelName(): string
151151 };
152152 }
153153
154- public function autoModel (
155- string $ modelNameOrPath ,
156- bool $ quantized = true ,
157- ?array $ config = null ,
158- ?string $ cacheDir = null ,
159- string $ revision = 'main ' ,
160- ?string $ modelFilename = null ,
161- ?callable $ onProgress = null
162- ): PretrainedModel
154+ /**
155+ * @return class-string<PretrainedModel>|array<class-string<PretrainedModel>>
156+ */
157+ public function autoModelClass (): string |array
163158 {
164159 return match ($ this ) {
165160 self ::SentimentAnalysis,
166161 self ::TextClassification,
167- self ::ZeroShotClassification => AutoModelForSequenceClassification::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
162+ self ::ZeroShotClassification => AutoModelForSequenceClassification::class ,
168163
169- self ::FillMask => AutoModelForMaskedLM::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
164+ self ::FillMask => AutoModelForMaskedLM::class ,
170165
171- self ::QuestionAnswering => AutoModelForQuestionAnswering::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
166+ self ::QuestionAnswering => AutoModelForQuestionAnswering::class ,
172167
173168 self ::FeatureExtraction,
174- self ::Embeddings => AutoModel::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
169+ self ::Embeddings => AutoModel::class ,
175170
176171 self ::Text2TextGeneration,
177172 self ::Translation,
178- self ::Summarization => AutoModelForSeq2SeqLM::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
173+ self ::Summarization => AutoModelForSeq2SeqLM::class ,
179174
180- self ::TextGeneration => AutoModelForCausalLM::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
175+ self ::TextGeneration => AutoModelForCausalLM::class ,
181176
182177 self ::TokenClassification,
183- self ::Ner => AutoModelForTokenClassification::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
178+ self ::Ner => AutoModelForTokenClassification::class ,
184179
185- self ::ImageToText => AutoModelForVision2Seq::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
180+ self ::ImageToText => AutoModelForVision2Seq::class ,
186181
187- self ::ImageClassification => AutoModelForImageClassification::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
182+ self ::ImageClassification => AutoModelForImageClassification::class ,
188183
189- self ::ImageFeatureExtraction => AutoModelForImageFeatureExtraction::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
184+ self ::ImageFeatureExtraction => [ AutoModelForImageFeatureExtraction::class, AutoModel::class] ,
190185
191- self ::ZeroShotImageClassification => AutoModel::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
186+ self ::ZeroShotImageClassification => AutoModel::class ,
192187
193- self ::ImageToImage => AutoModelForImageToImage::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
188+ self ::ImageToImage => AutoModelForImageToImage::class ,
194189
195- self ::ObjectDetection => AutoModelForObjectDetection::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
190+ self ::ObjectDetection => AutoModelForObjectDetection::class ,
196191
197- self ::ZeroShotObjectDetection => AutoModelForZeroShotObjectDetection::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
192+ self ::ZeroShotObjectDetection => AutoModelForZeroShotObjectDetection::class ,
198193
199- self ::AudioClassification => AutoModelForAudioClassification::fromPretrained ( $ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) ,
194+ self ::AudioClassification => AutoModelForAudioClassification::class ,
200195
201196 self ::ASR ,
202- self ::AutomaticSpeechRecognition => (function () use ($ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress ) {
203- try {
204- return AutoModelForSpeechSeq2Seq::fromPretrained ($ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress );
205- } catch (UnsupportedModelTypeException ) {
206- return AutoModelForCTC::fromPretrained ($ modelNameOrPath , $ quantized , $ config , $ cacheDir , $ revision , $ modelFilename , $ onProgress );
207- }
208- })(),
197+ self ::AutomaticSpeechRecognition => [AutoModelForSpeechSeq2Seq::class, AutoModelForCTC::class],
209198 };
210199 }
211200
@@ -214,8 +203,7 @@ public function autoTokenizer(
214203 ?string $ cacheDir = null ,
215204 string $ revision = 'main ' ,
216205 ?callable $ onProgress = null
217- ): ?PreTrainedTokenizer
218- {
206+ ): ?PreTrainedTokenizer {
219207 return match ($ this ) {
220208
221209 self ::ImageClassification,
@@ -252,8 +240,7 @@ public function autoProcessor(
252240 ?string $ cacheDir = null ,
253241 string $ revision = 'main ' ,
254242 ?callable $ onProgress = null
255- ): ?Processor
256- {
243+ ): ?Processor {
257244 return match ($ this ) {
258245
259246 self ::ImageToText,
0 commit comments