diff --git a/pkg/distribution/internal/bundle/parse.go b/pkg/distribution/internal/bundle/parse.go index cd928937..cf769cea 100644 --- a/pkg/distribution/internal/bundle/parse.go +++ b/pkg/distribution/internal/bundle/parse.go @@ -12,9 +12,9 @@ import ( "github.com/docker/model-runner/pkg/distribution/types" ) -// errFoundSafetensors is a sentinel error used to stop filepath.Walk early -// after finding the first safetensors file. -var errFoundSafetensors = fmt.Errorf("found safetensors file") +// errFoundModelFile is a sentinel error used to stop filepath.Walk early after +// finding the first matching model file. +var errFoundModelFile = fmt.Errorf("found model file") // Parse returns the Bundle at the given rootDir func Parse(rootDir string) (*Bundle, error) { @@ -37,10 +37,14 @@ func Parse(rootDir string) (*Bundle, error) { if err != nil { return nil, err } + ddufPath, err := findDDUFFile(modelDir) + if err != nil { + return nil, err + } // Ensure at least one model weight format is present - if ggufPath == "" && safetensorsPath == "" { - return nil, fmt.Errorf("no supported model weights found (neither GGUF nor safetensors)") + if ggufPath == "" && safetensorsPath == "" && ddufPath == "" { + return nil, fmt.Errorf("no supported model weights found (neither GGUF, safetensors, nor DDUF)") } mmprojPath, err := findMultiModalProjectorFile(modelDir) @@ -62,6 +66,7 @@ func Parse(rootDir string) (*Bundle, error) { mmprojPath: mmprojPath, ggufFile: ggufPath, safetensorsFile: safetensorsPath, + ddufFile: ddufPath, runtimeConfig: cfg, chatTemplatePath: templatePath, }, nil @@ -92,60 +97,69 @@ func parseRuntimeConfig(rootDir string) (types.ModelConfig, error) { return &cfg, nil } -func findGGUFFile(modelDir string) (string, error) { - ggufs, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.gguf")) +// findModelFile finds a supported model file by extension. It prefers a +// top-level match in modelDir and falls back to a recursive search when needed. +// Hidden files are ignored. +func findModelFile(modelDir, ext string) (string, error) { + pattern := filepath.Join(modelDir, "[^.]*"+ext) + paths, err := filepath.Glob(pattern) if err != nil { - return "", fmt.Errorf("find gguf files: %w", err) + return "", fmt.Errorf("find %s files: %w", ext, err) } - if len(ggufs) == 0 { - // GGUF files are optional - safetensors models won't have them - return "", nil + if len(paths) > 0 { + return filepath.Base(paths[0]), nil } - return filepath.Base(ggufs[0]), nil -} -func findSafetensorsFile(modelDir string) (string, error) { - // First check top-level directory (most common case) - safetensors, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.safetensors")) - if err != nil { - return "", fmt.Errorf("find safetensors files: %w", err) - } - if len(safetensors) > 0 { - return filepath.Base(safetensors[0]), nil - } - - // Search recursively for V0.2 models with nested directory structure - // (e.g., text_encoder/model.safetensors) var firstFound string - walkErr := filepath.Walk(modelDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - // Propagate filesystem errors so callers can distinguish them from - // the case where no safetensors files are present. - return err - } - if info.IsDir() { - return nil - } - if filepath.Ext(path) == ".safetensors" && !strings.HasPrefix(info.Name(), ".") { + walkErr := filepath.Walk( + modelDir, + func(path string, info os.FileInfo, err error) error { + if err != nil { + // Propagate filesystem errors so callers can distinguish them + // from the case where no matching files are present. + return err + } + if info.IsDir() { + return nil + } + if filepath.Ext(path) != ext || + strings.HasPrefix(info.Name(), ".") { + return nil + } + rel, relErr := filepath.Rel(modelDir, path) if relErr != nil { - // Treat a bad relative path as a real error instead of silently - // ignoring it, so malformed bundles surface to the caller. + // Treat a bad relative path as a real error instead of + // silently ignoring it, so malformed bundles surface to the + // caller. return relErr } firstFound = rel - return errFoundSafetensors // found one, stop walking - } - return nil - }) - if walkErr != nil && !errors.Is(walkErr, errFoundSafetensors) { - return "", fmt.Errorf("walk for safetensors files: %w", walkErr) + return errFoundModelFile + }, + ) + if walkErr != nil && !errors.Is(walkErr, errFoundModelFile) { + return "", fmt.Errorf("walk for %s files: %w", ext, walkErr) } - // Safetensors files are optional - GGUF models won't have them return firstFound, nil } +func findGGUFFile(modelDir string) (string, error) { + // GGUF files are optional. + return findModelFile(modelDir, ".gguf") +} + +func findSafetensorsFile(modelDir string) (string, error) { + // Safetensors files are optional. + return findModelFile(modelDir, ".safetensors") +} + +func findDDUFFile(modelDir string) (string, error) { + // DDUF files are optional. + return findModelFile(modelDir, ".dduf") +} + func findMultiModalProjectorFile(modelDir string) (string, error) { mmprojPaths, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.mmproj")) if err != nil { diff --git a/pkg/distribution/internal/bundle/parse_test.go b/pkg/distribution/internal/bundle/parse_test.go index 625aeba0..eb601e3d 100644 --- a/pkg/distribution/internal/bundle/parse_test.go +++ b/pkg/distribution/internal/bundle/parse_test.go @@ -41,7 +41,7 @@ func TestParse_NoModelWeights(t *testing.T) { t.Fatal("Expected error when parsing bundle without model weights, got nil") } - expectedErrMsg := "no supported model weights found (neither GGUF nor safetensors)" + expectedErrMsg := "no supported model weights found (neither GGUF, safetensors, nor DDUF)" if !strings.Contains(err.Error(), expectedErrMsg) { t.Errorf("Expected error message to contain %q, got: %v", expectedErrMsg, err) } @@ -93,6 +93,63 @@ func TestParse_WithGGUF(t *testing.T) { } } +func TestParse_WithNestedGGUF(t *testing.T) { + // Create a temporary directory for the test bundle. + tempDir := t.TempDir() + + // Create model subdirectory. + modelDir := filepath.Join(tempDir, ModelSubdir) + if err := os.MkdirAll(modelDir, 0755); err != nil { + t.Fatalf("Failed to create model directory: %v", err) + } + + // Create nested directory structure. + weightsDir := filepath.Join(modelDir, "nested", "weights") + if err := os.MkdirAll(weightsDir, 0755); err != nil { + t.Fatalf("Failed to create nested weights directory: %v", err) + } + + // Create a GGUF file in the nested directory. + nestedGGUFPath := filepath.Join(weightsDir, "model.gguf") + if err := os.WriteFile(nestedGGUFPath, []byte("dummy nested gguf"), 0644); err != nil { + t.Fatalf("Failed to create nested GGUF file: %v", err) + } + + // Create a valid config.json at bundle root. + cfg := types.Config{ + Format: types.FormatGGUF, + } + configPath := filepath.Join(tempDir, "config.json") + f, err := os.Create(configPath) + if err != nil { + t.Fatalf("Failed to create config.json: %v", err) + } + if err := json.NewEncoder(f).Encode(cfg); err != nil { + f.Close() + t.Fatalf("Failed to encode config: %v", err) + } + f.Close() + + // Parse the bundle and ensure GGUF discovery falls back to recursion. + bundle, err := Parse(tempDir) + if err != nil { + t.Fatalf("Expected successful parse with nested GGUF, got: %v", err) + } + + expectedPath := filepath.Join("nested", "weights", "model.gguf") + if bundle.ggufFile != expectedPath { + t.Errorf("Expected ggufFile to be %q, got: %s", expectedPath, bundle.ggufFile) + } + + fullPath := bundle.GGUFPath() + if fullPath == "" { + t.Error("Expected GGUFPath() to return a non-empty path") + } + if !strings.HasSuffix(fullPath, expectedPath) { + t.Errorf("Expected GGUFPath() to end with %q, got: %s", expectedPath, fullPath) + } +} + func TestParse_WithSafetensors(t *testing.T) { // Create a temporary directory for the test bundle tempDir := t.TempDir() @@ -139,6 +196,56 @@ func TestParse_WithSafetensors(t *testing.T) { } } +func TestParse_WithDDUF(t *testing.T) { + // Create a temporary directory for the test bundle. + tempDir := t.TempDir() + + // Create model subdirectory. + modelDir := filepath.Join(tempDir, ModelSubdir) + if err := os.MkdirAll(modelDir, 0755); err != nil { + t.Fatalf("Failed to create model directory: %v", err) + } + + // Create a dummy DDUF file. + ddufPath := filepath.Join(modelDir, "model.dduf") + if err := os.WriteFile(ddufPath, []byte("dummy dduf content"), 0644); err != nil { + t.Fatalf("Failed to create DDUF file: %v", err) + } + + // Create a valid config.json at bundle root. + cfg := types.Config{ + Format: types.FormatDDUF, + } + configPath := filepath.Join(tempDir, "config.json") + f, err := os.Create(configPath) + if err != nil { + t.Fatalf("Failed to create config.json: %v", err) + } + if err := json.NewEncoder(f).Encode(cfg); err != nil { + f.Close() + t.Fatalf("Failed to encode config: %v", err) + } + f.Close() + + // Parse the bundle and ensure DDUF-only bundles are accepted. + bundle, err := Parse(tempDir) + if err != nil { + t.Fatalf("Expected successful parse with DDUF file, got: %v", err) + } + + if bundle.ddufFile != "model.dduf" { + t.Errorf("Expected ddufFile to be %q, got: %s", "model.dduf", bundle.ddufFile) + } + + fullPath := bundle.DDUFPath() + if fullPath == "" { + t.Error("Expected DDUFPath() to return a non-empty path") + } + if !strings.HasSuffix(fullPath, "model.dduf") { + t.Errorf("Expected DDUFPath() to end with %q, got: %s", "model.dduf", fullPath) + } +} + func TestParse_WithNestedSafetensors(t *testing.T) { // Create a temporary directory for the test bundle tempDir := t.TempDir() @@ -198,6 +305,63 @@ func TestParse_WithNestedSafetensors(t *testing.T) { } } +func TestParse_WithNestedDDUF(t *testing.T) { + // Create a temporary directory for the test bundle. + tempDir := t.TempDir() + + // Create model subdirectory. + modelDir := filepath.Join(tempDir, ModelSubdir) + if err := os.MkdirAll(modelDir, 0755); err != nil { + t.Fatalf("Failed to create model directory: %v", err) + } + + // Create nested directory structure. + diffusersDir := filepath.Join(modelDir, "sanitized", "diffusers") + if err := os.MkdirAll(diffusersDir, 0755); err != nil { + t.Fatalf("Failed to create nested diffusers directory: %v", err) + } + + // Create a DDUF file in the nested directory. + nestedDDUFPath := filepath.Join(diffusersDir, "model.dduf") + if err := os.WriteFile(nestedDDUFPath, []byte("dummy nested dduf"), 0644); err != nil { + t.Fatalf("Failed to create nested DDUF file: %v", err) + } + + // Create a valid config.json at bundle root. + cfg := types.Config{ + Format: types.FormatDDUF, + } + configPath := filepath.Join(tempDir, "config.json") + f, err := os.Create(configPath) + if err != nil { + t.Fatalf("Failed to create config.json: %v", err) + } + if err := json.NewEncoder(f).Encode(cfg); err != nil { + f.Close() + t.Fatalf("Failed to encode config: %v", err) + } + f.Close() + + // Parse the bundle and ensure DDUF discovery falls back to recursion. + bundle, err := Parse(tempDir) + if err != nil { + t.Fatalf("Expected successful parse with nested DDUF, got: %v", err) + } + + expectedPath := filepath.Join("sanitized", "diffusers", "model.dduf") + if bundle.ddufFile != expectedPath { + t.Errorf("Expected ddufFile to be %q, got: %s", expectedPath, bundle.ddufFile) + } + + fullPath := bundle.DDUFPath() + if fullPath == "" { + t.Error("Expected DDUFPath() to return a non-empty path") + } + if !strings.HasSuffix(fullPath, expectedPath) { + t.Errorf("Expected DDUFPath() to end with %q, got: %s", expectedPath, fullPath) + } +} + func TestParse_WithBothFormats(t *testing.T) { // Create a temporary directory for the test bundle tempDir := t.TempDir() diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 293a979a..f1bec8e6 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -9,6 +9,8 @@ import ( "path/filepath" "strings" + "github.com/docker/model-runner/pkg/distribution/files" + "github.com/docker/model-runner/pkg/distribution/modelpack" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -19,7 +21,7 @@ import ( // - V0.1 (legacy): Uses the original unpacking logic based on GGUFPaths(), SafetensorsPaths(), etc. func Unpack(dir string, model types.Model) (*Bundle, error) { artifact, isArtifact := model.(types.ModelArtifact) - if isArtifact && isV02Model(artifact) { + if isArtifact && (isV02Model(artifact) || isCNCFModel(artifact)) { return UnpackFromLayers(dir, artifact) } @@ -37,6 +39,17 @@ func isV02Model(model types.ModelArtifact) bool { return manifest.Config.MediaType == types.MediaTypeModelConfigV02 } +// isCNCFModel checks if the model was packaged using the CNCF ModelPack format. +// CNCF ModelPack uses a layer-per-file approach with filepath annotations, +// similar to V0.2, so it can be unpacked using UnpackFromLayers. +func isCNCFModel(model types.ModelArtifact) bool { + manifest, err := model.Manifest() + if err != nil { + return false + } + return manifest.Config.MediaType == modelpack.MediaTypeModelConfigV1 +} + // unpackLegacy is the original V0.1 unpacking logic that uses model.GGUFPaths(), model.SafetensorsPaths(), etc. func unpackLegacy(dir string, model types.Model) (*Bundle, error) { bundle := &Bundle{ @@ -511,6 +524,30 @@ func validatePathWithinDirectory(baseDir, targetPath string) error { return nil } +// sanitizeRelativePath strips leading ".." segments from a path while +// preserving the remaining directory structure. For example: +// +// "../../home/user/text_encoder/model.safetensors" → "home/user/text_encoder/model.safetensors" +// "../model.gguf" → "model.gguf" +// +// This is safer than filepath.Base which would flatten the entire path, +// losing nested directory structure and causing silent collisions. +func sanitizeRelativePath(p string) string { + // Normalize to forward slashes and clean + cleaned := filepath.ToSlash(filepath.Clean(p)) + + // Strip leading "../" segments + for strings.HasPrefix(cleaned, "../") { + cleaned = cleaned[len("../"):] + } + // Handle the case where the entire path is ".." + if cleaned == ".." { + return "" + } + + return cleaned +} + func extractTarArchiveFromReader(r io.Reader, destDir string) error { // Get absolute path of destination directory for security checks absDestDir, err := filepath.Abs(destDir) @@ -634,11 +671,28 @@ func UnpackFromLayers(dir string, model types.ModelArtifact) (*Bundle, error) { return nil, fmt.Errorf("get model layers: %w", err) } + // Determine the model format from config for resolving format-agnostic + // CNCF weight types (e.g., MediaTypeWeightRaw). + var modelFormat string + if cfg, err := model.Config(); err == nil { + modelFormat = string(cfg.GetFormat()) + } + // When the config omits the format field, infer it from filepath + // annotations on weight layers (e.g., ".gguf" → FormatGGUF). + if modelFormat == "" { + modelFormat = inferFormatFromLayerAnnotations(layers) + } + // Define the interface for getting descriptor with annotations type descriptorProvider interface { GetDescriptor() oci.Descriptor } + // Track destination paths to detect collisions from sanitization. + // Maps a resolved destination path to the original raw annotation so the + // collision error can point to the conflicting source paths. + destPaths := make(map[string]string) + // Iterate through all layers and unpack using annotations for _, layer := range layers { mediaType, err := layer.MediaType() @@ -657,17 +711,48 @@ func UnpackFromLayers(dir string, model types.ModelArtifact) (*Bundle, error) { if !exists || relPath == "" { continue } + rawAnnotation := relPath - // Validate the path to prevent directory traversal + // Validate the path to prevent directory traversal. + // Some packaging tools (e.g., modctl) may produce annotations with + // ".." components when the model was packaged using an absolute path. + // In that case, strip leading ".." segments while preserving the + // remaining directory structure to avoid flattening nested layouts. if err := validatePathWithinDirectory(modelDir, relPath); err != nil { - return nil, fmt.Errorf("invalid filepath annotation %q: %w", relPath, err) + sanitizedPath := sanitizeRelativePath(relPath) + // Re-validate the sanitized path to ensure it's safe. + if err2 := validatePathWithinDirectory(modelDir, sanitizedPath); err2 != nil { + return nil, fmt.Errorf( + "invalid filepath annotation %q (sanitized as %q): original error: %w, sanitized error: %w", + relPath, + sanitizedPath, + err, + err2, + ) + } + relPath = sanitizedPath } // Convert forward slashes to OS-specific separator relPath = filepath.FromSlash(relPath) destPath := filepath.Join(modelDir, relPath) - // Skip if file already exists + // Detect collisions: two different annotations mapping to the same destination + if origAnnotation, exists := destPaths[destPath]; exists { + if origAnnotation != rawAnnotation { + return nil, fmt.Errorf( + "filepath collision: annotations %q and %q both resolve to %q", + origAnnotation, + rawAnnotation, + destPath, + ) + } + // Same annotation seen twice (duplicate layer); skip silently. + continue + } + destPaths[destPath] = rawAnnotation + + // Skip if file already exists on disk (from a previous unpack) if _, err := os.Stat(destPath); err == nil { continue } @@ -683,7 +768,7 @@ func UnpackFromLayers(dir string, model types.ModelArtifact) (*Bundle, error) { } // Update bundle tracking fields - updateBundleFieldsFromLayer(bundle, mediaType, relPath) + updateBundleFieldsFromLayer(bundle, mediaType, relPath, modelFormat) } // Create the runtime config @@ -709,15 +794,53 @@ func unpackLayerToFile(destPath string, layer oci.Layer) error { return fmt.Errorf("layer is not a path provider") } +// inferFormatFromLayerAnnotations inspects filepath annotations on weight +// layers to determine the model format when config.format is not set. This +// handles CNCF ModelPack models that use MediaTypeWeightRaw but omit the +// format field. +func inferFormatFromLayerAnnotations(layers []oci.Layer) string { + type descriptorProvider interface { + GetDescriptor() oci.Descriptor + } + for _, l := range layers { + mt, err := l.MediaType() + if err != nil || !modelpack.IsModelPackWeightMediaType(string(mt)) { + continue + } + dp, ok := l.(descriptorProvider) + if !ok { + continue + } + fp, exists := dp.GetDescriptor().Annotations[types.AnnotationFilePath] + if !exists || fp == "" { + continue + } + // Use file classification to detect format from extension. + switch files.Classify(fp) { + case files.FileTypeGGUF: + return string(types.FormatGGUF) + case files.FileTypeSafetensors: + return string(types.FormatSafetensors) + case files.FileTypeDDUF: + return string(types.FormatDDUF) + case files.FileTypeUnknown, files.FileTypeConfig, files.FileTypeLicense, files.FileTypeChatTemplate: + // Not a weight file; skip. + } + } + return "" +} + // updateBundleFieldsFromLayer updates the bundle tracking fields based on the unpacked layer. -func updateBundleFieldsFromLayer(bundle *Bundle, mediaType oci.MediaType, relPath string) { +// modelFormat is used to resolve format-agnostic CNCF weight types (e.g., MediaTypeWeightRaw) +// to the correct bundle field. Pass empty string when the model format is unknown. +func updateBundleFieldsFromLayer(bundle *Bundle, mediaType oci.MediaType, relPath string, modelFormat string) { //nolint:exhaustive // only tracking specific model-related media types switch mediaType { - case types.MediaTypeGGUF: + case types.MediaTypeGGUF, modelpack.MediaTypeWeightGGUF: if bundle.ggufFile == "" { bundle.ggufFile = relPath } - case types.MediaTypeSafetensors: + case types.MediaTypeSafetensors, modelpack.MediaTypeWeightSafetensors: if bundle.safetensorsFile == "" { bundle.safetensorsFile = relPath } @@ -733,6 +856,24 @@ func updateBundleFieldsFromLayer(bundle *Bundle, mediaType oci.MediaType, relPat if bundle.chatTemplatePath == "" { bundle.chatTemplatePath = relPath } + default: + // Handle format-agnostic CNCF weight types (e.g., .raw) by checking the model config format. + if modelpack.IsModelPackGenericWeightMediaType(string(mediaType)) { + switch types.Format(modelFormat) { + case types.FormatGGUF: + if bundle.ggufFile == "" { + bundle.ggufFile = relPath + } + case types.FormatSafetensors: + if bundle.safetensorsFile == "" { + bundle.safetensorsFile = relPath + } + case types.FormatDDUF, types.FormatDiffusers: //nolint:staticcheck // FormatDiffusers kept for backward compatibility + if bundle.ddufFile == "" { + bundle.ddufFile = relPath + } + } + } } } diff --git a/pkg/distribution/internal/bundle/unpack_test.go b/pkg/distribution/internal/bundle/unpack_test.go index 9f972eef..61c49d32 100644 --- a/pkg/distribution/internal/bundle/unpack_test.go +++ b/pkg/distribution/internal/bundle/unpack_test.go @@ -1,9 +1,16 @@ package bundle import ( + "io" "os" "path/filepath" + "strings" "testing" + + "github.com/docker/model-runner/pkg/distribution/internal/testutil" + "github.com/docker/model-runner/pkg/distribution/modelpack" + "github.com/docker/model-runner/pkg/distribution/oci" + "github.com/docker/model-runner/pkg/distribution/types" ) func TestValidatePathWithinDirectory(t *testing.T) { @@ -105,6 +112,22 @@ func TestValidatePathWithinDirectory(t *testing.T) { expectError: false, description: "Directory path with trailing slash should be valid", }, + + // Edge cases for filepath.Base sanitization (re-validation after fallback) + // filepath.Base("foo/..") returns ".." which must be rejected + { + name: "filepath.Base returns dotdot", + targetPath: "..", + expectError: true, + description: "Double dot (filepath.Base output for 'foo/..') should be blocked", + }, + // filepath.Base("/") returns "/" which must be rejected + { + name: "filepath.Base returns slash", + targetPath: "/", + expectError: true, + description: "Slash (filepath.Base output for '/') should be blocked as absolute path", + }, } for _, tt := range tests { @@ -121,6 +144,443 @@ func TestValidatePathWithinDirectory(t *testing.T) { } } +func TestUpdateBundleFieldsFromLayer_CNCFMediaTypes(t *testing.T) { + tests := []struct { + name string + mediaType oci.MediaType + relPath string + modelFormat string + expectGGUF string + expectSafetensors string + }{ + { + name: "Docker safetensors media type", + mediaType: types.MediaTypeSafetensors, + relPath: "model/model.safetensors", + modelFormat: "", + expectSafetensors: "model/model.safetensors", + }, + { + name: "CNCF format-specific safetensors media type", + mediaType: oci.MediaType(modelpack.MediaTypeWeightSafetensors), + relPath: "model/model.safetensors", + modelFormat: "", + expectSafetensors: "model/model.safetensors", + }, + { + name: "CNCF format-specific GGUF media type", + mediaType: oci.MediaType(modelpack.MediaTypeWeightGGUF), + relPath: "model/model.gguf", + modelFormat: "", + expectGGUF: "model/model.gguf", + }, + { + name: "CNCF generic weight raw with safetensors format", + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + relPath: "model/model.safetensors", + modelFormat: string(types.FormatSafetensors), + expectSafetensors: "model/model.safetensors", + }, + { + name: "CNCF generic weight raw with GGUF format", + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + relPath: "model/model.gguf", + modelFormat: string(types.FormatGGUF), + expectGGUF: "model/model.gguf", + }, + { + name: "CNCF generic weight raw without format does nothing", + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + relPath: "model/model.safetensors", + modelFormat: "", + expectSafetensors: "", + expectGGUF: "", + }, + { + name: "unknown media type does nothing", + mediaType: "application/vnd.cncf.model.weight.config.v1.raw", + relPath: "model/config.json", + modelFormat: string(types.FormatSafetensors), + expectSafetensors: "", + expectGGUF: "", + }, + { + name: "CNCF generic weight raw with DDUF format", + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + relPath: "model/model.dduf", + modelFormat: string(types.FormatDDUF), + expectGGUF: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bundle := &Bundle{} + updateBundleFieldsFromLayer(bundle, tt.mediaType, tt.relPath, tt.modelFormat) + + if bundle.safetensorsFile != tt.expectSafetensors { + t.Errorf("safetensorsFile = %q, want %q", bundle.safetensorsFile, tt.expectSafetensors) + } + if bundle.ggufFile != tt.expectGGUF { + t.Errorf("ggufFile = %q, want %q", bundle.ggufFile, tt.expectGGUF) + } + if tt.modelFormat == string(types.FormatDDUF) && bundle.ddufFile != tt.relPath { + t.Errorf("ddufFile = %q, want %q", bundle.ddufFile, tt.relPath) + } + }) + } +} + +func TestIsCNCFModel(t *testing.T) { + tests := []struct { + name string + configMediaType oci.MediaType + expected bool + }{ + { + name: "CNCF ModelPack config V1", + configMediaType: modelpack.MediaTypeModelConfigV1, + expected: true, + }, + { + name: "Docker V0.1 config", + configMediaType: types.MediaTypeModelConfigV01, + expected: false, + }, + { + name: "Docker V0.2 config", + configMediaType: types.MediaTypeModelConfigV02, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a minimal artifact with the given config media type + artifact := &testArtifactWithConfigMediaType{ + configMediaType: tt.configMediaType, + } + result := isCNCFModel(artifact) + if result != tt.expected { + t.Errorf("isCNCFModel() = %v, want %v", result, tt.expected) + } + }) + } +} + +// testArtifactWithConfigMediaType is a minimal ModelArtifact for testing isCNCFModel/isV02Model. +type testArtifactWithConfigMediaType struct { + configMediaType oci.MediaType +} + +func (a *testArtifactWithConfigMediaType) Manifest() (*oci.Manifest, error) { + return &oci.Manifest{ + Config: oci.Descriptor{ + MediaType: a.configMediaType, + }, + }, nil +} + +// Stubs to satisfy types.ModelArtifact interface (not used in these tests). +func (a *testArtifactWithConfigMediaType) ID() (string, error) { return "", nil } +func (a *testArtifactWithConfigMediaType) Config() (types.ModelConfig, error) { + return nil, nil +} +func (a *testArtifactWithConfigMediaType) Tags() []string { return nil } +func (a *testArtifactWithConfigMediaType) Descriptor() (types.Descriptor, error) { + return types.Descriptor{}, nil +} +func (a *testArtifactWithConfigMediaType) GGUFPaths() ([]string, error) { return nil, nil } +func (a *testArtifactWithConfigMediaType) SafetensorsPaths() ([]string, error) { + return nil, nil +} +func (a *testArtifactWithConfigMediaType) Layers() ([]oci.Layer, error) { return nil, nil } +func (a *testArtifactWithConfigMediaType) RawConfigFile() ([]byte, error) { return nil, nil } +func (a *testArtifactWithConfigMediaType) RawManifest() ([]byte, error) { return nil, nil } +func (a *testArtifactWithConfigMediaType) MediaType() (oci.MediaType, error) { return "", nil } +func (a *testArtifactWithConfigMediaType) Size() (int64, error) { return 0, nil } +func (a *testArtifactWithConfigMediaType) ConfigName() (oci.Hash, error) { return oci.Hash{}, nil } +func (a *testArtifactWithConfigMediaType) ConfigFile() (*oci.ConfigFile, error) { return nil, nil } +func (a *testArtifactWithConfigMediaType) Digest() (oci.Hash, error) { return oci.Hash{}, nil } +func (a *testArtifactWithConfigMediaType) LayerByDigest(oci.Hash) (oci.Layer, error) { + return nil, nil +} +func (a *testArtifactWithConfigMediaType) LayerByDiffID(oci.Hash) (oci.Layer, error) { + return nil, nil +} + +// testLayerWithAnnotation is a minimal oci.Layer that carries a specific media +// type and a filepath annotation. Used to test inferFormatFromLayerAnnotations. +type testLayerWithAnnotation struct { + mediaType oci.MediaType + annotation string +} + +func (l *testLayerWithAnnotation) MediaType() (oci.MediaType, error) { return l.mediaType, nil } +func (l *testLayerWithAnnotation) GetDescriptor() oci.Descriptor { + annotations := map[string]string{} + if l.annotation != "" { + annotations[types.AnnotationFilePath] = l.annotation + } + return oci.Descriptor{MediaType: l.mediaType, Annotations: annotations} +} + +// Stubs to satisfy oci.Layer interface. +func (l *testLayerWithAnnotation) Digest() (oci.Hash, error) { return oci.Hash{}, nil } +func (l *testLayerWithAnnotation) DiffID() (oci.Hash, error) { return oci.Hash{}, nil } +func (l *testLayerWithAnnotation) Compressed() (io.ReadCloser, error) { return nil, nil } +func (l *testLayerWithAnnotation) Uncompressed() (io.ReadCloser, error) { return nil, nil } +func (l *testLayerWithAnnotation) Size() (int64, error) { return 0, nil } + +func TestInferFormatFromLayerAnnotations(t *testing.T) { + tests := []struct { + name string + layers []oci.Layer + expected string + }{ + { + name: "GGUF via MediaTypeWeightRaw annotation", + layers: []oci.Layer{ + &testLayerWithAnnotation{ + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + annotation: "model.gguf", + }, + }, + expected: string(types.FormatGGUF), + }, + { + name: "safetensors via MediaTypeWeightRaw annotation", + layers: []oci.Layer{ + &testLayerWithAnnotation{ + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + annotation: "model.safetensors", + }, + }, + expected: string(types.FormatSafetensors), + }, + { + name: "DDUF via MediaTypeWeightRaw annotation", + layers: []oci.Layer{ + &testLayerWithAnnotation{ + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + annotation: "model.dduf", + }, + }, + expected: string(types.FormatDDUF), + }, + { + name: "no weight layers returns empty", + layers: []oci.Layer{ + &testLayerWithAnnotation{ + mediaType: "application/json", + annotation: "config.json", + }, + }, + expected: "", + }, + { + name: "empty layers returns empty", + layers: []oci.Layer{}, + expected: "", + }, + { + name: "weight layer without annotation returns empty", + layers: []oci.Layer{ + &testLayerWithAnnotation{ + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + annotation: "", + }, + }, + expected: "", + }, + { + name: "GGUF via MediaTypeWeightGGUF annotation", + layers: []oci.Layer{ + &testLayerWithAnnotation{ + mediaType: oci.MediaType(modelpack.MediaTypeWeightGGUF), + annotation: "model.gguf", + }, + }, + expected: string(types.FormatGGUF), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := inferFormatFromLayerAnnotations(tt.layers) + if got != tt.expected { + t.Errorf("inferFormatFromLayerAnnotations() = %q, want %q", got, tt.expected) + } + }) + } +} + +func TestSanitizeRelativePath(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "no leading dotdot", + input: "text_encoder/model.safetensors", + expected: "text_encoder/model.safetensors", + }, + { + name: "single leading dotdot", + input: "../model.gguf", + expected: "model.gguf", + }, + { + name: "multiple leading dotdots", + input: "../../home/user/text_encoder/model.safetensors", + expected: "home/user/text_encoder/model.safetensors", + }, + { + name: "deep leading dotdots preserving nested dirs", + input: "../../../a/b/c/model.safetensors", + expected: "a/b/c/model.safetensors", + }, + { + name: "only dotdot", + input: "..", + expected: "", + }, + { + name: "simple filename", + input: "model.gguf", + expected: "model.gguf", + }, + { + name: "dotdot in middle is resolved by Clean", + input: "a/../b/model.safetensors", + expected: "b/model.safetensors", + }, + { + name: "trailing slash stripped by Clean", + input: "../models/", + expected: "models", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizeRelativePath(tt.input) + if got != tt.expected { + t.Errorf("sanitizeRelativePath(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} + +func TestUnpackFromLayers_PathCollisionAfterSanitization(t *testing.T) { + // Build a ModelPack artifact where two distinct raw annotations collapse to + // the same sanitized destination path. + artifact := testutil.NewModelPackArtifact( + t, + modelpack.Model{ + Config: modelpack.ModelConfig{Format: string(types.FormatGGUF)}, + }, + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), + RelativePath: "foo/model.gguf", + MediaType: oci.MediaType(modelpack.MediaTypeWeightGGUF), + }, + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), + RelativePath: "../foo/model.gguf", + MediaType: oci.MediaType(modelpack.MediaTypeWeightGGUF), + }, + ) + + _, err := UnpackFromLayers(t.TempDir(), artifact) + if err == nil { + t.Fatal("Expected unpack collision error, got nil") + } + if !strings.Contains(err.Error(), "filepath collision") { + t.Fatalf("Expected collision error, got: %v", err) + } + if !strings.Contains(err.Error(), "foo/model.gguf") { + t.Fatalf("Expected collision error to mention original annotation, got: %v", err) + } + if !strings.Contains(err.Error(), "../foo/model.gguf") { + t.Fatalf("Expected collision error to mention sanitized annotation, got: %v", err) + } +} + +func TestUnpackFromLayers_DuplicateRawAnnotationAllowed(t *testing.T) { + // Build a ModelPack artifact with the same raw annotation twice. This + // should behave like a duplicate layer, not a collision. + artifact := testutil.NewModelPackArtifact( + t, + modelpack.Model{ + Config: modelpack.ModelConfig{Format: string(types.FormatGGUF)}, + }, + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), + RelativePath: "foo/model.gguf", + MediaType: oci.MediaType(modelpack.MediaTypeWeightGGUF), + }, + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), + RelativePath: "foo/model.gguf", + MediaType: oci.MediaType(modelpack.MediaTypeWeightGGUF), + }, + ) + + bundleRoot := t.TempDir() + bundle, err := UnpackFromLayers(bundleRoot, artifact) + if err != nil { + t.Fatalf("Expected duplicate annotation to be ignored, got: %v", err) + } + if bundle.ggufFile != filepath.Join("foo", "model.gguf") { + t.Errorf("Expected ggufFile to track unpacked path, got: %s", bundle.ggufFile) + } + if _, err := os.Stat(bundle.GGUFPath()); err != nil { + t.Fatalf("Expected GGUF file to exist after unpack, got: %v", err) + } +} + +func TestUnpackFromLayers_PathSanitizationRejectsCollapsedPath(t *testing.T) { + // Build a ModelPack artifact whose annotation collapses entirely during + // sanitization. This must fail before any file is written. + artifact := testutil.NewModelPackArtifact( + t, + modelpack.Model{ + Config: modelpack.ModelConfig{Format: string(types.FormatGGUF)}, + }, + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), + RelativePath: "../../..", + MediaType: oci.MediaType(modelpack.MediaTypeWeightGGUF), + }, + ) + + bundleRoot := t.TempDir() + _, err := UnpackFromLayers(bundleRoot, artifact) + if err == nil { + t.Fatal("Expected sanitization error, got nil") + } + if !strings.Contains(err.Error(), `invalid filepath annotation "../../.."`) { + t.Fatalf("Expected error to mention original annotation, got: %v", err) + } + if !strings.Contains(err.Error(), `sanitized as ""`) { + t.Fatalf("Expected error to mention sanitized path, got: %v", err) + } + if !strings.Contains(err.Error(), "empty path is not allowed") { + t.Fatalf("Expected error to mention sanitized validation failure, got: %v", err) + } + + modelDir := filepath.Join(bundleRoot, ModelSubdir) + entries, readErr := os.ReadDir(modelDir) + if readErr != nil { + t.Fatalf("Expected model directory to exist, got: %v", readErr) + } + if len(entries) != 0 { + t.Fatalf("Expected no files to be written for rejected annotation, got %d entries", len(entries)) + } +} + func TestValidatePathWithinDirectory_RealFilesystem(t *testing.T) { // Create a temporary directory structure baseDir := t.TempDir() diff --git a/pkg/distribution/internal/partial/partial.go b/pkg/distribution/internal/partial/partial.go index eda798f7..5516719a 100644 --- a/pkg/distribution/internal/partial/partial.go +++ b/pkg/distribution/internal/partial/partial.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/docker/model-runner/pkg/distribution/files" "github.com/docker/model-runner/pkg/distribution/modelpack" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/types" @@ -125,7 +126,7 @@ func SafetensorsPaths(i WithLayers) ([]string, error) { } func DDUFPaths(i WithLayers) ([]string, error) { - return layerPathsByMediaType(i, types.MediaTypeDDUF, "") + return layerPathsByMediaType(i, types.MediaTypeDDUF, getModelFormat(i)) } func ConfigArchivePath(i WithLayers) (string, error) { @@ -143,15 +144,59 @@ func ConfigArchivePath(i WithLayers) (string, error) { return paths[0], err } -// getModelFormat reads the model config and returns the format string (e.g., "gguf", "safetensors"). -// This is used to resolve format-agnostic ModelPack weight media types (e.g., .raw, .tar) -// to specific model formats. Returns empty string if format cannot be determined. +// getModelFormat reads the model config and returns the format string (e.g., +// "gguf", "safetensors"). Used to resolve format-agnostic ModelPack weight +// media types (e.g., .raw). Falls back to inspecting layer filepath +// annotations when the config omits the format field. Returns empty string if +// format cannot be determined. func getModelFormat(i WithLayers) string { cfg, err := Config(i) + if err == nil { + if f := string(cfg.GetFormat()); f != "" { + return f + } + } + // Config did not specify a format; infer from filepath annotations. + return inferFormatFromAnnotations(i) +} + +// inferFormatFromAnnotations scans weight layers for a filepath annotation and +// uses the file extension to determine the model format. This handles CNCF +// ModelPack models that use MediaTypeWeightRaw but omit config.format. +func inferFormatFromAnnotations(i WithLayers) string { + layers, err := i.Layers() if err != nil { return "" } - return string(cfg.GetFormat()) + type descriptorProvider interface { + GetDescriptor() oci.Descriptor + } + for _, l := range layers { + mt, err := l.MediaType() + if err != nil || !modelpack.IsModelPackWeightMediaType(string(mt)) { + continue + } + dp, ok := l.(descriptorProvider) + if !ok { + continue + } + fp, exists := dp.GetDescriptor().Annotations[types.AnnotationFilePath] + if !exists || fp == "" { + continue + } + // Use file classification to detect format from extension. + switch files.Classify(fp) { + case files.FileTypeGGUF: + return string(types.FormatGGUF) + case files.FileTypeSafetensors: + return string(types.FormatSafetensors) + case files.FileTypeDDUF: + return string(types.FormatDDUF) + case files.FileTypeUnknown, files.FileTypeConfig, files.FileTypeLicense, files.FileTypeChatTemplate: + // Not a weight file; skip. + } + } + return "" } // layerPathsByMediaType is a generic helper function that finds a layer by media type and returns its path. @@ -193,7 +238,6 @@ func matchesMediaType(layerMT, targetMT oci.MediaType, modelFormat string) bool } // Native ModelPack support: check format-specific ModelPack types - //nolint:exhaustive // Only GGUF and Safetensors need cross-format matching switch targetMT { case types.MediaTypeGGUF: if layerMT == modelpack.MediaTypeWeightGGUF { @@ -203,6 +247,15 @@ func matchesMediaType(layerMT, targetMT oci.MediaType, modelFormat string) bool if layerMT == modelpack.MediaTypeWeightSafetensors { return true } + case types.MediaTypeDDUF, types.MediaTypeLicense, types.MediaTypeMultimodalProjector, + types.MediaTypeChatTemplate, types.MediaTypeModelFile, types.MediaTypeVLLMConfigArchive, + types.MediaTypeDirTar, types.MediaTypeModelConfigV01, types.MediaTypeModelConfigV02, + oci.OCIManifestSchema1, oci.OCIImageIndex, oci.OCIConfigJSON, + oci.OCILayer, oci.OCILayerGzip, oci.OCILayerZstd, + oci.OCIContentDescriptor, oci.OCIArtifactManifest, oci.OCIEmptyJSON, + oci.DockerManifestSchema2, oci.DockerManifestList, oci.DockerConfigJSON, + oci.DockerLayer, oci.DockerForeignLayer, oci.DockerUncompressedLayer: + // No format-specific ModelPack mapping for these media types } // ModelPack model-spec support: format-agnostic weight types (.raw, .tar, etc.) @@ -211,12 +264,22 @@ func matchesMediaType(layerMT, targetMT oci.MediaType, modelFormat string) bool // in their media type and are handled above; applying this fallback to them would // cause cross-format false positives (e.g., safetensors layer matching as GGUF). if modelFormat != "" && modelpack.IsModelPackGenericWeightMediaType(string(layerMT)) { - //nolint:exhaustive // Only GGUF and Safetensors need cross-format matching switch targetMT { case types.MediaTypeGGUF: return modelFormat == string(types.FormatGGUF) case types.MediaTypeSafetensors: return modelFormat == string(types.FormatSafetensors) + case types.MediaTypeDDUF: + return modelFormat == string(types.FormatDDUF) || modelFormat == string(types.FormatDiffusers) //nolint:staticcheck // FormatDiffusers kept for backward compatibility + case types.MediaTypeLicense, types.MediaTypeMultimodalProjector, + types.MediaTypeChatTemplate, types.MediaTypeModelFile, types.MediaTypeVLLMConfigArchive, + types.MediaTypeDirTar, types.MediaTypeModelConfigV01, types.MediaTypeModelConfigV02, + oci.OCIManifestSchema1, oci.OCIImageIndex, oci.OCIConfigJSON, + oci.OCILayer, oci.OCILayerGzip, oci.OCILayerZstd, + oci.OCIContentDescriptor, oci.OCIArtifactManifest, oci.OCIEmptyJSON, + oci.DockerManifestSchema2, oci.DockerManifestList, oci.DockerConfigJSON, + oci.DockerLayer, oci.DockerForeignLayer, oci.DockerUncompressedLayer: + // No generic weight resolution for these media types } } diff --git a/pkg/distribution/internal/partial/partial_test.go b/pkg/distribution/internal/partial/partial_test.go index ee22c3f4..1a92ab62 100644 --- a/pkg/distribution/internal/partial/partial_test.go +++ b/pkg/distribution/internal/partial/partial_test.go @@ -6,6 +6,7 @@ import ( "github.com/docker/model-runner/pkg/distribution/internal/partial" "github.com/docker/model-runner/pkg/distribution/internal/testutil" + "github.com/docker/model-runner/pkg/distribution/modelpack" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -289,3 +290,103 @@ func TestSafetensorsPaths_NoFalsePositive_GGUFModelPackType(t *testing.T) { t.Errorf("Expected 1 safetensors path (GGUF layer must not match), got %d", len(paths)) } } + +// TestGGUFPaths_ModelPackRawNoConfigFormat tests that GGUFPaths can find raw +// ModelPack weight layers even when the model config omits the format field. +// This exercises the annotation-based format discovery fallback introduced to +// handle CNCF ModelPack models that do not populate config.format. +func TestGGUFPaths_ModelPackRawNoConfigFormat(t *testing.T) { + // Build a CNCF ModelPack artifact with an empty config format and a raw + // weight layer whose filepath annotation ends in ".gguf". + mdl := testutil.NewModelPackArtifact( + t, + modelpack.Model{Config: modelpack.ModelConfig{}}, // format intentionally empty + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), + RelativePath: "model.gguf", + MediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + }, + ) + + paths, err := partial.GGUFPaths(mdl) + if err != nil { + t.Fatalf("GGUFPaths() error = %v", err) + } + + // Should discover one GGUF path via the ".gguf" extension fallback. + if len(paths) != 1 { + t.Errorf("Expected 1 GGUF path via annotation fallback, got %d", len(paths)) + } +} + +// TestSafetensorsPaths_ModelPackRawNoConfigFormat mirrors the GGUF test above +// for the safetensors format. +func TestSafetensorsPaths_ModelPackRawNoConfigFormat(t *testing.T) { + mdl := testutil.NewModelPackArtifact( + t, + modelpack.Model{Config: modelpack.ModelConfig{}}, // format intentionally empty + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), + RelativePath: "model.safetensors", + MediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + }, + ) + + paths, err := partial.SafetensorsPaths(mdl) + if err != nil { + t.Fatalf("SafetensorsPaths() error = %v", err) + } + + // Should discover one safetensors path via the ".safetensors" extension fallback. + if len(paths) != 1 { + t.Errorf("Expected 1 safetensors path via annotation fallback, got %d", len(paths)) + } +} + +// TestDDUFPaths_ModelPackRawWithDDUFFormat tests that DDUFPaths can find raw +// ModelPack weight layers when the model config specifies format as "dduf". +func TestDDUFPaths_ModelPackRawWithDDUFFormat(t *testing.T) { + mdl := testutil.NewModelPackArtifact( + t, + modelpack.Model{Config: modelpack.ModelConfig{Format: string(types.FormatDDUF)}}, + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), // reuse dummy file + RelativePath: "model.dduf", + MediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + }, + ) + + paths, err := partial.DDUFPaths(mdl) + if err != nil { + t.Fatalf("DDUFPaths() error = %v", err) + } + + if len(paths) != 1 { + t.Errorf("Expected 1 DDUF path, got %d", len(paths)) + } +} + +// TestDDUFPaths_ModelPackRawNoConfigFormat tests that DDUFPaths can find raw +// ModelPack weight layers even when the config omits the format field, by +// inferring the format from the ".dduf" extension in the filepath annotation. +func TestDDUFPaths_ModelPackRawNoConfigFormat(t *testing.T) { + mdl := testutil.NewModelPackArtifact( + t, + modelpack.Model{Config: modelpack.ModelConfig{}}, // format intentionally empty + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), // reuse dummy file + RelativePath: "model.dduf", + MediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + }, + ) + + paths, err := partial.DDUFPaths(mdl) + if err != nil { + t.Fatalf("DDUFPaths() error = %v", err) + } + + // Should discover one DDUF path via the ".dduf" extension fallback. + if len(paths) != 1 { + t.Errorf("Expected 1 DDUF path via annotation fallback, got %d", len(paths)) + } +} diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index b6be7e4e..7e8087ba 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -134,7 +134,15 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B return backend } - switch config.GetFormat() { + format := config.GetFormat() + // If the config does not specify a format, infer it from the model's + // available file paths. This handles CNCF ModelPack models that omit + // the optional config.format field. + if format == "" { + format = inferFormatFromModel(model) + } + + switch format { case types.FormatSafetensors: // Prefer vLLM for safetensors models (handles platform dispatch internally) if s.platformSupport.SupportsVLLM() || s.platformSupport.SupportsVLLMMetal() { @@ -185,6 +193,24 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B return backend } +// inferFormatFromModel detects the model format by checking which file types +// are present in the model's layers. Used as a fallback when the model config +// omits the format field (e.g. some CNCF ModelPack models). Order matches +// detectModelFormat in the distribution bundle package to ensure consistent +// behavior for malformed or mixed artifacts. +func inferFormatFromModel(model types.Model) types.Format { + if paths, err := model.GGUFPaths(); err == nil && len(paths) > 0 { + return types.FormatGGUF + } + if paths, err := model.SafetensorsPaths(); err == nil && len(paths) > 0 { + return types.FormatSafetensors + } + if paths, err := model.DDUFPaths(); err == nil && len(paths) > 0 { + return types.FormatDDUF + } + return "" +} + // ResetInstaller resets the backend installer with a new HTTP client. func (s *Scheduler) ResetInstaller(httpClient *http.Client) { s.installer = newInstaller(s.log, s.backends, httpClient, s.deferredBackends) diff --git a/pkg/inference/scheduling/select_backend_test.go b/pkg/inference/scheduling/select_backend_test.go index 76d17be2..dd1f211d 100644 --- a/pkg/inference/scheduling/select_backend_test.go +++ b/pkg/inference/scheduling/select_backend_test.go @@ -29,13 +29,16 @@ func (m mockPlatformSupport) SupportsDiffusers() bool { return m.diffusers } // mockModel is a minimal Model implementation for testing. type mockModel struct { - config types.ModelConfig + config types.ModelConfig + ggufPaths []string + safetensorsPaths []string + ddufPaths []string } func (m *mockModel) ID() (string, error) { return "test-id", nil } -func (m *mockModel) GGUFPaths() ([]string, error) { return nil, nil } -func (m *mockModel) SafetensorsPaths() ([]string, error) { return nil, nil } -func (m *mockModel) DDUFPaths() ([]string, error) { return nil, nil } +func (m *mockModel) GGUFPaths() ([]string, error) { return m.ggufPaths, nil } +func (m *mockModel) SafetensorsPaths() ([]string, error) { return m.safetensorsPaths, nil } +func (m *mockModel) DDUFPaths() ([]string, error) { return m.ddufPaths, nil } func (m *mockModel) ConfigArchivePath() (string, error) { return "", nil } func (m *mockModel) MMPROJPath() (string, error) { return "", nil } func (m *mockModel) Config() (types.ModelConfig, error) { return m.config, nil } @@ -202,6 +205,75 @@ func TestSelectBackendForModel(t *testing.T) { model: legacyDiffusersModel, expectedBackend: diffusers.Name, }, + // Tests for CNCF ModelPack models that omit config.format: format + // must be inferred from the model's layer paths. + { + name: "ModelPack safetensors without format field selects vLLM", + backends: map[string]inference.Backend{ + "llamacpp": llamacppBackend, + vllm.Name: vllmBackend, + }, + defaultBackend: llamacppBackend, + platform: mockPlatformSupport{vllm: true}, + model: &mockModel{ + config: &types.Config{}, + safetensorsPaths: []string{"model.safetensors"}, + }, + expectedBackend: vllm.Name, + }, + { + name: "ModelPack GGUF without format field selects default backend", + backends: map[string]inference.Backend{ + "llamacpp": llamacppBackend, + vllm.Name: vllmBackend, + }, + defaultBackend: llamacppBackend, + platform: mockPlatformSupport{vllm: true}, + model: &mockModel{ + config: &types.Config{}, + ggufPaths: []string{"model.gguf"}, + }, + expectedBackend: "llamacpp", + }, + { + name: "ModelPack with no format and no paths uses default backend", + backends: map[string]inference.Backend{ + "llamacpp": llamacppBackend, + vllm.Name: vllmBackend, + }, + defaultBackend: llamacppBackend, + platform: mockPlatformSupport{vllm: true}, + model: &mockModel{config: &types.Config{}}, + expectedBackend: "llamacpp", + }, + { + name: "config.format wins over inferred safetensors paths", + backends: map[string]inference.Backend{ + "llamacpp": llamacppBackend, + vllm.Name: vllmBackend, + }, + defaultBackend: llamacppBackend, + platform: mockPlatformSupport{vllm: true}, + model: &mockModel{ + config: &types.Config{Format: types.FormatGGUF}, + safetensorsPaths: []string{"model.safetensors"}, + }, + expectedBackend: "llamacpp", + }, + { + name: "ModelPack DDUF without format field selects diffusers", + backends: map[string]inference.Backend{ + "llamacpp": llamacppBackend, + diffusers.Name: diffusersBackend, + }, + defaultBackend: llamacppBackend, + platform: mockPlatformSupport{diffusers: true}, + model: &mockModel{ + config: &types.Config{}, + ddufPaths: []string{"model.dduf"}, + }, + expectedBackend: diffusers.Name, + }, } for _, tt := range tests {