diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index 77aab22e8..f3283fd87 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -14,6 +14,7 @@ import ( "github.com/docker/model-runner/pkg/distribution/internal/progress" "github.com/docker/model-runner/pkg/distribution/internal/store" + "github.com/docker/model-runner/pkg/distribution/modelpack" "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/tarball" "github.com/docker/model-runner/pkg/distribution/types" @@ -213,11 +214,22 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter } } - // Check for supported type - if err := checkCompat(remoteModel, c.log, reference, progressWriter); err != nil { + // Check for supported type and apply ModelPack compatibility handling if needed + remoteModel, err = checkAndConvertCompat(remoteModel, c.log, reference, progressWriter) + if err != nil { return err } + // Recompute the digest in case compatibility handling rewrote the manifest (ModelPack artifacts are currently stored as-is, so it normally matches the remote digest) + convertedDigest, err := remoteModel.Digest() + if err != nil { + return fmt.Errorf("getting converted model digest: %w", err) + } + if convertedDigest != remoteDigest { + c.log.Infof("Model converted from ModelPack format, new digest: %s", convertedDigest.String()) + } + remoteDigest = convertedDigest + // Check if model exists in local store localModel, err := c.store.Read(remoteDigest.String()) if err == nil { @@ -474,19 +486,38 @@ func GetSupportedFormats() []types.Format { return []types.Format{types.FormatGGUF} } -func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string, progressWriter io.Writer) error { +// checkAndConvertCompat validates model compatibility. +// Both Docker format and CNCF ModelPack format are supported. +// ModelPack format is stored as-is and converted at read time.
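+// The artifact is currently returned unchanged; the (types.ModelArtifact, error) signature leaves room for in-place conversion later.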
+func checkAndConvertCompat(image types.ModelArtifact, log *logrus.Entry, reference string, progressWriter io.Writer) (types.ModelArtifact, error) { manifest, err := image.Manifest() if err != nil { - return err + return nil, err + } + + mediaTypeStr := string(manifest.Config.MediaType) + + // Check whether the format is supported (Docker or ModelPack) + isDocker := mediaTypeStr == string(types.MediaTypeModelConfigV01) + isModelPack := modelpack.IsModelPackMediaType(mediaTypeStr) + + if !isDocker && !isModelPack { + return nil, fmt.Errorf("config type %q is unsupported: %w", mediaTypeStr, ErrUnsupportedMediaType) } - if manifest.Config.MediaType != types.MediaTypeModelConfigV01 { - return fmt.Errorf("config type %q is unsupported: %w", manifest.Config.MediaType, ErrUnsupportedMediaType) + + // ModelPack artifacts are stored as-is and converted at read time + if isModelPack { + log.Infof("Detected ModelPack format for %s (stored as-is)", + utils.SanitizeForLog(reference)) + if err := progress.WriteInfo(progressWriter, "ModelPack format detected"); err != nil { + log.Warnf("Failed to write info message: %v", err) + } } // Check if the model format is supported config, err := image.Config() if err != nil { - return fmt.Errorf("reading model config: %w", err) + return nil, fmt.Errorf("reading model config: %w", err) } if config.Format == "" { @@ -501,5 +532,5 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string, // Don't return an error - allow the pull to continue } - return nil + return image, nil } diff --git a/pkg/distribution/internal/partial/partial.go b/pkg/distribution/internal/partial/partial.go index e1d1c44e6..29fef7efb 100644 --- a/pkg/distribution/internal/partial/partial.go +++ b/pkg/distribution/internal/partial/partial.go @@ -8,6 +8,7 @@ import ( "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/partial" ggcr "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" + "github.com/docker/model-runner/pkg/distribution/modelpack" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -21,9 +22,15 @@ func ConfigFile(i WithRawConfigFile) (*types.ConfigFile, error) { if err != nil { return nil, fmt.Errorf("get raw config file: %w", err) } + + // Auto-detect ModelPack format and convert it to Docker format at read time + if modelpack.IsModelPackConfig(raw) { + return modelpack.ConvertToDockerConfig(raw) + } + var cf types.ConfigFile if err := json.Unmarshal(raw, &cf); err != nil { - return nil, fmt.Errorf("unmarshal : %w", err) + return nil, fmt.Errorf("unmarshal config: %w", err) } return &cf, nil } @@ -127,7 +134,12 @@ func layerPathsByMediaType(i WithLayers, mediaType ggcr.MediaType) ([]string, er var paths []string for _, l := range layers { mt, err := l.MediaType() - if err != nil || mt != mediaType { + if err != nil { + continue + } + // Map ModelPack media types to their Docker equivalents before comparing + mappedMT := ggcr.MediaType(modelpack.MapLayerMediaType(string(mt))) + if mappedMT != mediaType { continue } layer, ok := l.(*Layer) diff --git a/pkg/distribution/internal/partial/partial_test.go b/pkg/distribution/internal/partial/partial_test.go index 40337e68d..098f24399 100644 --- a/pkg/distribution/internal/partial/partial_test.go +++ b/pkg/distribution/internal/partial/partial_test.go @@ -8,8 +8,76 @@ import ( "github.com/docker/model-runner/pkg/distribution/internal/mutate" "github.com/docker/model-runner/pkg/distribution/internal/partial" "github.com/docker/model-runner/pkg/distribution/types" + ggcr "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" ) +// mockConfig is a test double for exercising the ConfigFile function +type mockConfig struct { + raw []byte +
err error +} + +func (m *mockConfig) RawConfigFile() ([]byte, error) { + return m.raw, m.err +} + +// TestConfigFile_AutoDetection verifies that ConfigFile auto-detects and converts ModelPack-format configs +func TestConfigFile_AutoDetection(t *testing.T) { + t.Run("Docker format passes through", func(t *testing.T) { + // Config in Docker format + dockerJSON := `{ + "config": {"format": "gguf", "parameters": "8B"}, + "descriptor": {}, + "rootfs": {"type": "layers", "diff_ids": []} + }` + + mock := &mockConfig{raw: []byte(dockerJSON)} + cf, err := partial.ConfigFile(mock) + if err != nil { + t.Fatalf("ConfigFile() error = %v", err) + } + + if cf.Config.Format != types.FormatGGUF { + t.Errorf("Format = %v, want %v", cf.Config.Format, types.FormatGGUF) + } + if cf.Config.Parameters != "8B" { + t.Errorf("Parameters = %q, want %q", cf.Config.Parameters, "8B") + } + }) + + t.Run("ModelPack format auto-converts", func(t *testing.T) { + // Config in ModelPack format (uses paramSize instead of parameters) + modelPackJSON := `{ + "descriptor": {"createdAt": "2025-01-15T10:30:00Z"}, + "config": {"format": "gguf", "paramSize": "8B"}, + "modelfs": {"type": "layers", "diffIds": []} + }` + + mock := &mockConfig{raw: []byte(modelPackJSON)} + cf, err := partial.ConfigFile(mock) + if err != nil { + t.Fatalf("ConfigFile() error = %v", err) + } + + // After conversion the Docker-format fields should be populated + if cf.Config.Format != types.FormatGGUF { + t.Errorf("Format = %v, want %v", cf.Config.Format, types.FormatGGUF) + } + // paramSize should be mapped to parameters + if cf.Config.Parameters != "8B" { + t.Errorf("Parameters = %q, want %q", cf.Config.Parameters, "8B") + } + }) + + t.Run("invalid JSON returns error", func(t *testing.T) { + mock := &mockConfig{raw: []byte("not valid json")} + _, err := partial.ConfigFile(mock) + if err == nil { + t.Error("expected error for invalid JSON") + } + }) +} + func TestMMPROJPath(t *testing.T) { // Create a model from GGUF file mdl, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf")) @@ -122,3 +190,33 @@ func TestLayerPathByMediaType(t *testing.T) { } } + +// TestGGUFPaths_ModelPackMediaType verifies that GGUFPaths finds layers that use the ModelPack media type +func TestGGUFPaths_ModelPackMediaType(t *testing.T) { + // Create a layer with the ModelPack GGUF media type + modelPackGGUFType := ggcr.MediaType("application/vnd.cncf.model.weight.v1.gguf") + + layer, err := partial.NewLayer(filepath.Join("..", "..", "assets", "dummy.gguf"), modelPackGGUFType) + if err != nil { + t.Fatalf("failed to create ModelPack layer: %v", err) + } + + // Use mutate to build a model that contains this layer + mdl, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf")) + if err != nil { + t.Fatalf("failed to create GGUF model: %v", err) + } + + mdlWithModelPackLayer := mutate.AppendLayers(mdl, layer) + + // GGUFPaths should find the ModelPack-format GGUF layer + paths, err := partial.GGUFPaths(mdlWithModelPackLayer) + if err != nil { + t.Fatalf("GGUFPaths() error = %v", err) + } + + // Two paths expected: the original Docker-format layer plus the appended ModelPack-format one + if len(paths) != 2 { + t.Errorf("expected 2 GGUF paths, got %d", len(paths)) + } +} diff --git a/pkg/distribution/internal/progress/reporter.go b/pkg/distribution/internal/progress/reporter.go index db038aa08..bb5db0990 100644 --- a/pkg/distribution/internal/progress/reporter.go +++ b/pkg/distribution/internal/progress/reporter.go @@ -139,6 +139,14 @@ func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64 }) } +// WriteInfo writes an info message +func WriteInfo(w io.Writer, message string) error { + return write(w, Message{ + Type: "info", + Message: message, + }) +} + // WriteSuccess writes a success message func WriteSuccess(w io.Writer, 
message string) error { return write(w, Message{ diff --git a/pkg/distribution/modelpack/convert.go b/pkg/distribution/modelpack/convert.go new file mode 100644 index 000000000..9eaa0f66b --- /dev/null +++ b/pkg/distribution/modelpack/convert.go @@ -0,0 +1,239 @@ +package modelpack + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/opencontainers/go-digest" + + v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" + + "github.com/docker/model-runner/pkg/distribution/types" +) + +// IsModelPackMediaType checks if the given media type indicates a CNCF ModelPack format. +// It returns true if the media type has the CNCF model prefix. +func IsModelPackMediaType(mediaType string) bool { + return strings.HasPrefix(mediaType, MediaTypePrefix) +} + +// IsModelPackConfig reports whether the raw config bytes are in CNCF ModelPack format. +// It inspects the parsed JSON structure rather than matching substrings, which avoids false positives. +// ModelPack markers: config.paramSize or descriptor.createdAt. +// Docker format uses config.parameters and descriptor.created. +func IsModelPackConfig(raw []byte) bool { + if len(raw) == 0 { + return false + } + + // Parse into a map so the actual JSON structure can be inspected + var parsed map[string]json.RawMessage + if err := json.Unmarshal(raw, &parsed); err != nil { + return false + } + + // Check for config.paramSize (ModelPack-specific) + if configRaw, ok := parsed["config"]; ok { + var config map[string]json.RawMessage + if err := json.Unmarshal(configRaw, &config); err == nil { + if _, hasParamSize := config["paramSize"]; hasParamSize { + return true + } + } + } + + // Check for descriptor.createdAt (ModelPack uses camelCase) + if descRaw, ok := parsed["descriptor"]; ok { + var desc map[string]json.RawMessage + if err := json.Unmarshal(descRaw, &desc); err == nil { + if _, hasCreatedAt := desc["createdAt"]; hasCreatedAt { + return true + } + } + } + + // Check for modelfs (field name specific to ModelPack) + if _, hasModelFS := parsed["modelfs"]; hasModelFS { + return true + } + + return false +} + +// MapLayerMediaType maps a ModelPack layer media type to its Docker equivalent. +// Non-ModelPack media types are returned unchanged. +func MapLayerMediaType(mediaType string) string { + // Only ModelPack media types are candidates for mapping + if !strings.HasPrefix(mediaType, MediaTypePrefix) { + return mediaType + } + + // Pick the Docker media type based on the format embedded in the ModelPack media type + switch { + case strings.Contains(mediaType, "weight") && strings.Contains(mediaType, "gguf"): + return string(types.MediaTypeGGUF) + case strings.Contains(mediaType, "weight") && strings.Contains(mediaType, "safetensors"): + return string(types.MediaTypeSafetensors) + default: + // Other layer types (doc, code, etc.) are left as-is + return mediaType + } +} + +// ConvertToDockerConfig converts a raw ModelPack config JSON to a Docker model-spec ConfigFile. +// It maps common fields directly and preserves extended ModelPack metadata in the ModelPack map.
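+// Fields with no Docker equivalent (authors, licenses, capabilities, and so on) are preserved via extractExtendedMetadata.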
+func ConvertToDockerConfig(rawConfig []byte) (*types.ConfigFile, error) { + var mp Model + if err := json.Unmarshal(rawConfig, &mp); err != nil { + return nil, fmt.Errorf("unmarshal modelpack config: %w", err) + } + + // Build the Docker format config + dockerConfig := &types.ConfigFile{ + Config: types.Config{ + Format: convertFormat(mp.Config.Format), + Architecture: mp.Config.Architecture, + Quantization: mp.Config.Quantization, + Parameters: mp.Config.ParamSize, + Size: "0", // ModelPack doesn't have an equivalent field + ModelPack: extractExtendedMetadata(&mp), + }, + Descriptor: types.Descriptor{ + Created: mp.Descriptor.CreatedAt, + }, + RootFS: v1.RootFS{ + Type: normalizeRootFSType(mp.ModelFS.Type), + DiffIDs: convertDiffIDs(mp.ModelFS.DiffIDs), + }, + } + + return dockerConfig, nil +} + +// convertFormat maps ModelPack format strings to the Docker Format type. +// Format strings are normalized to lowercase for consistent matching. +func convertFormat(mpFormat string) types.Format { + switch strings.ToLower(mpFormat) { + case "gguf": + return types.FormatGGUF + case "safetensors": + return types.FormatSafetensors + default: + // Pass through unknown formats, lowercased + return types.Format(strings.ToLower(mpFormat)) + } +} + +// normalizeRootFSType ensures the rootfs type is set, defaulting empty values to "layers". +// Both ModelPack and Docker use "layers" as the type. +func normalizeRootFSType(mpType string) string { + if mpType == "" { + return "layers" + } + return mpType +} + +// convertDiffIDs converts an opencontainers digest.Digest slice to a go-containerregistry v1.Hash slice. +// Note: Invalid digests are silently skipped here because they will be caught +// during layer validation when the model is actually loaded. This avoids +// failing early for formats we might not fully understand yet. +func convertDiffIDs(digests []digest.Digest) []v1.Hash { + if len(digests) == 0 { + return nil + } + + result := make([]v1.Hash, 0, len(digests)) + for _, d := range digests { + // digest.Digest format is "algorithm:hex", same as v1.Hash + hash, err := v1.NewHash(d.String()) + if err != nil { + // Skip invalid digests; layer validation will catch them later + continue + } + result = append(result, hash) + } + return result +} + +// extractExtendedMetadata extracts ModelPack-specific metadata that doesn't have +// a direct mapping to Docker format fields. These are preserved in the ModelPack map.
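+// It returns nil when there is no extended metadata, so the modelpack field is omitted from the marshalled config.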
+func extractExtendedMetadata(mp *Model) map[string]string { + metadata := make(map[string]string) + + // Descriptor fields + if len(mp.Descriptor.Authors) > 0 { + metadata["authors"] = strings.Join(mp.Descriptor.Authors, ", ") + } + if mp.Descriptor.Family != "" { + metadata["family"] = mp.Descriptor.Family + } + if mp.Descriptor.Name != "" { + metadata["name"] = mp.Descriptor.Name + } + if mp.Descriptor.DocURL != "" { + metadata["docURL"] = mp.Descriptor.DocURL + } + if mp.Descriptor.SourceURL != "" { + metadata["sourceURL"] = mp.Descriptor.SourceURL + } + if len(mp.Descriptor.DatasetsURL) > 0 { + metadata["datasetsURL"] = strings.Join(mp.Descriptor.DatasetsURL, ", ") + } + if mp.Descriptor.Version != "" { + metadata["version"] = mp.Descriptor.Version + } + if mp.Descriptor.Revision != "" { + metadata["revision"] = mp.Descriptor.Revision + } + if mp.Descriptor.Vendor != "" { + metadata["vendor"] = mp.Descriptor.Vendor + } + if len(mp.Descriptor.Licenses) > 0 { + metadata["licenses"] = strings.Join(mp.Descriptor.Licenses, ", ") + } + if mp.Descriptor.Title != "" { + metadata["title"] = mp.Descriptor.Title + } + if mp.Descriptor.Description != "" { + metadata["description"] = mp.Descriptor.Description + } + + // Config fields not in Docker format + if mp.Config.Precision != "" { + metadata["precision"] = mp.Config.Precision + } + + // Capabilities + if mp.Config.Capabilities != nil { + caps := mp.Config.Capabilities + if len(caps.InputTypes) > 0 { + metadata["capabilities.inputTypes"] = strings.Join(caps.InputTypes, ", ") + } + if len(caps.OutputTypes) > 0 { + metadata["capabilities.outputTypes"] = strings.Join(caps.OutputTypes, ", ") + } + if caps.KnowledgeCutoff != nil { + metadata["capabilities.knowledgeCutoff"] = caps.KnowledgeCutoff.Format("2006-01-02") + } + if caps.Reasoning != nil && *caps.Reasoning { + metadata["capabilities.reasoning"] = "true" + } + if caps.ToolUsage != nil && *caps.ToolUsage { + metadata["capabilities.toolUsage"] = "true" + } + if caps.Reward != nil && *caps.Reward { + metadata["capabilities.reward"] = "true" + } + if len(caps.Languages) > 0 { + metadata["capabilities.languages"] = strings.Join(caps.Languages, ", ") + } + } + + // Return nil if no metadata to avoid empty map in JSON + if len(metadata) == 0 { + return nil + } + + return metadata +} diff --git a/pkg/distribution/modelpack/convert_test.go b/pkg/distribution/modelpack/convert_test.go new file mode 100644 index 000000000..2b50281ec --- /dev/null +++ b/pkg/distribution/modelpack/convert_test.go @@ -0,0 +1,542 @@ +package modelpack + +import ( + "encoding/json" + "testing" + "time" + + "github.com/opencontainers/go-digest" + + "github.com/docker/model-runner/pkg/distribution/types" +) + +func TestIsModelPackMediaType(t *testing.T) { + tests := []struct { + name string + mediaType string + expected bool + }{ + { + name: "CNCF v1 config", + mediaType: "application/vnd.cncf.model.config.v1+json", + expected: true, + }, + { + name: "CNCF future version", + mediaType: "application/vnd.cncf.model.config.v2+json", + expected: true, + }, + { + name: "CNCF weight media type", + mediaType: "application/vnd.cncf.model.weight.v1.raw", + expected: true, + }, + { + name: "Docker format", + mediaType: "application/vnd.docker.ai.model.config.v0.1+json", + expected: false, + }, + { + name: "Generic JSON", + mediaType: "application/json", + expected: false, + }, + { + name: "Empty string", + mediaType: "", + expected: false, + }, + { + name: "OCI image config", + mediaType: "application/vnd.oci.image.config.v1+json", + 
expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsModelPackMediaType(tt.mediaType) + if result != tt.expected { + t.Errorf("IsModelPackMediaType(%q) = %v, want %v", tt.mediaType, result, tt.expected) + } + }) + } +} + +func TestConvertToDockerConfig(t *testing.T) { + t.Run("full config conversion", func(t *testing.T) { + created := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC) + knowledgeCutoff := time.Date(2024, 6, 1, 0, 0, 0, 0, time.UTC) + reasoning := true + toolUsage := true + + mpConfig := Model{ + Descriptor: ModelDescriptor{ + CreatedAt: &created, + Authors: []string{"Author1", "Author2"}, + Family: "llama", + Name: "llama3-8b-instruct", + DocURL: "https://example.com/docs", + SourceURL: "https://example.com/source", + DatasetsURL: []string{"https://example.com/dataset1", "https://example.com/dataset2"}, + Version: "1.0.0", + Revision: "abc123", + Vendor: "TestVendor", + Licenses: []string{"MIT", "Apache-2.0"}, + Title: "Llama 3 8B Instruct", + Description: "A test model for testing", + }, + Config: ModelConfig{ + Architecture: "transformer", + Format: "gguf", + ParamSize: "8B", + Precision: "fp16", + Quantization: "Q4_K_M", + Capabilities: &ModelCapabilities{ + InputTypes: []string{"text"}, + OutputTypes: []string{"text"}, + KnowledgeCutoff: &knowledgeCutoff, + Reasoning: &reasoning, + ToolUsage: &toolUsage, + Languages: []string{"en", "zh"}, + }, + }, + ModelFS: ModelFS{ + Type: "layers", + DiffIDs: []digest.Digest{"sha256:abc123def456abc123def456abc123def456abc123def456abc123def456abc1"}, + }, + } + + rawConfig, err := json.Marshal(mpConfig) + if err != nil { + t.Fatalf("Failed to marshal test config: %v", err) + } + + dockerConfig, err := ConvertToDockerConfig(rawConfig) + if err != nil { + t.Fatalf("ConvertToDockerConfig failed: %v", err) + } + + // Verify direct field mappings + if dockerConfig.Config.Format != types.FormatGGUF { + t.Errorf("Format = %v, want %v", dockerConfig.Config.Format, types.FormatGGUF) + } + if dockerConfig.Config.Architecture != "transformer" { + t.Errorf("Architecture = %q, want %q", dockerConfig.Config.Architecture, "transformer") + } + if dockerConfig.Config.Quantization != "Q4_K_M" { + t.Errorf("Quantization = %q, want %q", dockerConfig.Config.Quantization, "Q4_K_M") + } + if dockerConfig.Config.Parameters != "8B" { + t.Errorf("Parameters = %q, want %q", dockerConfig.Config.Parameters, "8B") + } + if dockerConfig.Config.Size != "0" { + t.Errorf("Size = %q, want %q", dockerConfig.Config.Size, "0") + } + + // Verify descriptor + if dockerConfig.Descriptor.Created == nil { + t.Error("Descriptor.Created should not be nil") + } else if !dockerConfig.Descriptor.Created.Equal(created) { + t.Errorf("Descriptor.Created = %v, want %v", dockerConfig.Descriptor.Created, created) + } + + // Verify RootFS + if dockerConfig.RootFS.Type != "layers" { + t.Errorf("RootFS.Type = %q, want %q", dockerConfig.RootFS.Type, "layers") + } + if len(dockerConfig.RootFS.DiffIDs) != 1 { + t.Errorf("RootFS.DiffIDs length = %d, want 1", len(dockerConfig.RootFS.DiffIDs)) + } + + // Verify extended metadata preserved + mp := dockerConfig.Config.ModelPack + if mp == nil { + t.Fatal("ModelPack metadata should not be nil") + } + if mp["vendor"] != "TestVendor" { + t.Errorf("ModelPack[vendor] = %q, want %q", mp["vendor"], "TestVendor") + } + if mp["precision"] != "fp16" { + t.Errorf("ModelPack[precision] = %q, want %q", mp["precision"], "fp16") + } + if mp["family"] != "llama" { + t.Errorf("ModelPack[family] = %q, want %q", 
mp["family"], "llama") + } + if mp["capabilities.reasoning"] != "true" { + t.Errorf("ModelPack[capabilities.reasoning] = %q, want %q", mp["capabilities.reasoning"], "true") + } + }) + + t.Run("minimal config", func(t *testing.T) { + mpConfig := Model{ + Config: ModelConfig{ + Format: "gguf", + }, + ModelFS: ModelFS{ + Type: "layers", + DiffIDs: []digest.Digest{"sha256:abc123"}, + }, + } + + rawConfig, _ := json.Marshal(mpConfig) + dockerConfig, err := ConvertToDockerConfig(rawConfig) + if err != nil { + t.Fatalf("ConvertToDockerConfig failed for minimal config: %v", err) + } + + if dockerConfig.Config.Format != types.FormatGGUF { + t.Errorf("Format = %v, want %v", dockerConfig.Config.Format, types.FormatGGUF) + } + if dockerConfig.Config.ModelPack != nil { + t.Errorf("ModelPack should be nil for minimal config, got %v", dockerConfig.Config.ModelPack) + } + }) + + t.Run("empty config", func(t *testing.T) { + mpConfig := Model{} + rawConfig, _ := json.Marshal(mpConfig) + + dockerConfig, err := ConvertToDockerConfig(rawConfig) + if err != nil { + t.Fatalf("ConvertToDockerConfig failed for empty config: %v", err) + } + + if dockerConfig.Config.Format != "" { + t.Errorf("Format should be empty, got %v", dockerConfig.Config.Format) + } + if dockerConfig.RootFS.Type != "layers" { + t.Errorf("RootFS.Type should default to 'layers', got %q", dockerConfig.RootFS.Type) + } + }) + + t.Run("invalid JSON", func(t *testing.T) { + _, err := ConvertToDockerConfig([]byte("invalid json")) + if err == nil { + t.Error("Expected error for invalid JSON, got nil") + } + }) + + t.Run("empty input", func(t *testing.T) { + _, err := ConvertToDockerConfig([]byte("")) + if err == nil { + t.Error("Expected error for empty input, got nil") + } + }) +} + +func TestConvertFormat(t *testing.T) { + tests := []struct { + input string + expected types.Format + }{ + {"gguf", types.FormatGGUF}, + {"GGUF", types.FormatGGUF}, + {"GgUf", types.FormatGGUF}, + {"safetensors", types.FormatSafetensors}, + {"SafeTensors", types.FormatSafetensors}, + {"SAFETENSORS", types.FormatSafetensors}, + {"onnx", types.Format("onnx")}, + {"pytorch", types.Format("pytorch")}, + {"", types.Format("")}, + {"unknown", types.Format("unknown")}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := convertFormat(tt.input) + if result != tt.expected { + t.Errorf("convertFormat(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestConvertDiffIDs(t *testing.T) { + t.Run("valid digests", func(t *testing.T) { + digests := []digest.Digest{ + "sha256:abc123def456abc123def456abc123def456abc123def456abc123def456abc1", + "sha256:123456789012345678901234567890123456789012345678901234567890abcd", + } + + result := convertDiffIDs(digests) + if len(result) != 2 { + t.Errorf("Expected 2 hashes, got %d", len(result)) + } + }) + + t.Run("empty slice", func(t *testing.T) { + result := convertDiffIDs([]digest.Digest{}) + if result != nil { + t.Errorf("Expected nil for empty slice, got %v", result) + } + }) + + t.Run("nil slice", func(t *testing.T) { + result := convertDiffIDs(nil) + if result != nil { + t.Errorf("Expected nil for nil slice, got %v", result) + } + }) + + t.Run("invalid digest skipped", func(t *testing.T) { + digests := []digest.Digest{ + "sha256:abc123def456abc123def456abc123def456abc123def456abc123def456abc1", + "invalid-digest-format", // This should be skipped + "sha256:123456789012345678901234567890123456789012345678901234567890abcd", + } + + result := convertDiffIDs(digests) + // Should only 
have 2 valid hashes, invalid one skipped + if len(result) != 2 { + t.Errorf("Expected 2 valid hashes (invalid skipped), got %d", len(result)) + } + }) +} + +func TestExtractExtendedMetadata(t *testing.T) { + t.Run("all fields", func(t *testing.T) { + knowledgeCutoff := time.Date(2024, 6, 1, 0, 0, 0, 0, time.UTC) + reasoning := true + + mp := &Model{ + Descriptor: ModelDescriptor{ + Authors: []string{"A", "B"}, + Family: "llama", + Name: "test", + DocURL: "https://doc", + SourceURL: "https://src", + DatasetsURL: []string{"https://d1", "https://d2"}, + Version: "1.0", + Revision: "rev1", + Vendor: "vendor1", + Licenses: []string{"MIT"}, + Title: "Title", + Description: "Desc", + }, + Config: ModelConfig{ + Precision: "fp16", + Capabilities: &ModelCapabilities{ + InputTypes: []string{"text"}, + OutputTypes: []string{"text", "image"}, + KnowledgeCutoff: &knowledgeCutoff, + Reasoning: &reasoning, + Languages: []string{"en"}, + }, + }, + } + + metadata := extractExtendedMetadata(mp) + + expectedFields := map[string]string{ + "authors": "A, B", + "family": "llama", + "name": "test", + "docURL": "https://doc", + "sourceURL": "https://src", + "datasetsURL": "https://d1, https://d2", + "version": "1.0", + "revision": "rev1", + "vendor": "vendor1", + "licenses": "MIT", + "title": "Title", + "description": "Desc", + "precision": "fp16", + "capabilities.inputTypes": "text", + "capabilities.outputTypes": "text, image", + "capabilities.knowledgeCutoff": "2024-06-01", + "capabilities.reasoning": "true", + "capabilities.languages": "en", + } + + for key, expected := range expectedFields { + if metadata[key] != expected { + t.Errorf("metadata[%q] = %q, want %q", key, metadata[key], expected) + } + } + }) + + t.Run("empty model returns nil", func(t *testing.T) { + mp := &Model{} + metadata := extractExtendedMetadata(mp) + if metadata != nil { + t.Errorf("Expected nil for empty model, got %v", metadata) + } + }) + + t.Run("false booleans not included", func(t *testing.T) { + falseVal := false + mp := &Model{ + Config: ModelConfig{ + Capabilities: &ModelCapabilities{ + Reasoning: &falseVal, + ToolUsage: &falseVal, + }, + }, + } + + metadata := extractExtendedMetadata(mp) + if metadata != nil { + t.Errorf("Expected nil when only false booleans, got %v", metadata) + } + }) +} + +func TestNormalizeRootFSType(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"layers", "layers"}, + {"", "layers"}, + {"rootfs", "rootfs"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := normalizeRootFSType(tt.input) + if result != tt.expected { + t.Errorf("normalizeRootFSType(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// TestMapLayerMediaType verifies the mapping of layer media types +func TestMapLayerMediaType(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + // ModelPack GGUF media types + { + name: "ModelPack weight gguf v1", + input: "application/vnd.cncf.model.weight.v1.gguf", + expected: "application/vnd.docker.ai.gguf.v3", + }, + { + name: "ModelPack weight gguf no version", + input: "application/vnd.cncf.model.weight.gguf", + expected: "application/vnd.docker.ai.gguf.v3", + }, + // ModelPack safetensors media types + { + name: "ModelPack weight safetensors", + input: "application/vnd.cncf.model.weight.v1.safetensors", + expected: "application/vnd.docker.ai.safetensors", + }, + // Docker media types pass through unchanged + { + name: "Docker GGUF passthrough", + input: "application/vnd.docker.ai.gguf.v3", + expected: "application/vnd.docker.ai.gguf.v3",
+ }, + { + name: "Docker safetensors passthrough", + input: "application/vnd.docker.ai.safetensors", + expected: "application/vnd.docker.ai.safetensors", + }, + // Other types are not mapped + { + name: "generic octet-stream", + input: "application/octet-stream", + expected: "application/octet-stream", + }, + { + name: "ModelPack doc layer unchanged", + input: "application/vnd.cncf.model.doc.v1", + expected: "application/vnd.cncf.model.doc.v1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MapLayerMediaType(tt.input) + if got != tt.expected { + t.Errorf("MapLayerMediaType(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} + +// TestIsModelPackConfig verifies ModelPack detection from raw config bytes +func TestIsModelPackConfig(t *testing.T) { + // ModelPack-format config fixture (has a paramSize field) + modelPackConfig := `{ + "descriptor": {"createdAt": "2025-01-15T10:30:00Z"}, + "config": {"paramSize": "8B", "format": "gguf"} + }` + + // Docker-format config fixture (uses parameters instead of paramSize) + dockerConfig := `{ + "config": {"parameters": "8B", "format": "gguf"}, + "descriptor": {"created": "2025-01-15T10:30:00Z"} + }` + + tests := []struct { + name string + input []byte + expected bool + }{ + { + name: "ModelPack config with paramSize", + input: []byte(modelPackConfig), + expected: true, + }, + { + name: "Docker config with parameters", + input: []byte(dockerConfig), + expected: false, + }, + { + name: "empty JSON object", + input: []byte("{}"), + expected: false, + }, + { + name: "invalid JSON", + input: []byte("not json"), + expected: false, + }, + { + name: "nil input", + input: nil, + expected: false, + }, + { + name: "empty input", + input: []byte(""), + expected: false, + }, + { + name: "config with createdAt field", + input: []byte(`{"descriptor": {"createdAt": "2025-01-01T00:00:00Z"}}`), + expected: true, + }, + { + name: "config with modelfs field", + input: []byte(`{"modelfs": {"type": "layers", "diffIds": []}}`), + expected: true, + }, + { + name: "false positive prevention - paramSize as value", + input: []byte(`{"config": {"description": "paramSize is 8B"}}`), + expected: false, + }, + { + name: "false positive prevention - createdAt as value", + input: []byte(`{"descriptor": {"note": "createdAt was yesterday"}}`), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsModelPackConfig(tt.input) + if got != tt.expected { + t.Errorf("IsModelPackConfig() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/pkg/distribution/modelpack/types.go b/pkg/distribution/modelpack/types.go new file mode 100644 index 000000000..3d60690f3 --- /dev/null +++ b/pkg/distribution/modelpack/types.go @@ -0,0 +1,139 @@ +// Package modelpack provides compatibility support for CNCF ModelPack format models. +// It enables docker/model-runner to pull and use models packaged in the ModelPack format +// by converting them to the Docker model-spec format on-the-fly. +// +// Note: JSON tags in this package use camelCase (e.g., "createdAt", "paramSize") to match +// the CNCF ModelPack spec, which differs from Docker model-spec's snake_case convention +// (e.g., "context_size"). +// +// See: https://github.com/modelpack/model-spec +package modelpack + +import ( + "time" + + "github.com/opencontainers/go-digest" +) + +const ( + // MediaTypePrefix is the prefix shared by all CNCF ModelPack media types. + MediaTypePrefix = "application/vnd.cncf.model." + + // MediaTypeModelConfigV1 is the CNCF model config v1 media type.
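+	// Detection keys off MediaTypePrefix rather than this exact value, so future config versions (e.g. v2) are also recognized.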
+ MediaTypeModelConfigV1 = "application/vnd.cncf.model.config.v1+json" +) + +// Model represents the CNCF ModelPack config structure. +// It provides the `application/vnd.cncf.model.config.v1+json` mediatype when marshalled to JSON. +type Model struct { + // Descriptor provides metadata about the model provenance and identity. + Descriptor ModelDescriptor `json:"descriptor"` + + // ModelFS describes the layer content addresses. + ModelFS ModelFS `json:"modelfs"` + + // Config defines the execution parameters for the model. + Config ModelConfig `json:"config,omitempty"` +} + +// ModelDescriptor defines the general information of a model. +type ModelDescriptor struct { + // CreatedAt is the date and time on which the model was built. + CreatedAt *time.Time `json:"createdAt,omitempty"` + + // Authors contains the contact details of the people or organization responsible for the model. + Authors []string `json:"authors,omitempty"` + + // Family is the model family, such as llama3, gpt2, qwen2, etc. + Family string `json:"family,omitempty"` + + // Name is the model name, such as llama3-8b-instruct, gpt2-xl, etc. + Name string `json:"name,omitempty"` + + // DocURL is the URL to get documentation on the model. + DocURL string `json:"docURL,omitempty"` + + // SourceURL is the URL to get source code for building the model. + SourceURL string `json:"sourceURL,omitempty"` + + // DatasetsURL contains URLs referencing datasets that the model was trained upon. + DatasetsURL []string `json:"datasetsURL,omitempty"` + + // Version is the version of the packaged software. + Version string `json:"version,omitempty"` + + // Revision is the source control revision identifier for the packaged software. + Revision string `json:"revision,omitempty"` + + // Vendor is the name of the distributing entity, organization or individual. + Vendor string `json:"vendor,omitempty"` + + // Licenses contains the license(s) under which contained software is distributed + // as an SPDX License Expression. + Licenses []string `json:"licenses,omitempty"` + + // Title is the human-readable title of the model. + Title string `json:"title,omitempty"` + + // Description is the human-readable description of the software packaged in the model. + Description string `json:"description,omitempty"` +} + +// ModelConfig defines the execution parameters which should be used as a base +// when running a model using an inference engine. +type ModelConfig struct { + // Architecture is the model architecture, such as transformer, cnn, rnn, etc. + Architecture string `json:"architecture,omitempty"` + + // Format is the model format, such as gguf, safetensors, onnx, etc. + Format string `json:"format,omitempty"` + + // ParamSize is the size of the model parameters, such as "8b", "16b", "32b", etc. + ParamSize string `json:"paramSize,omitempty"` + + // Precision is the model precision, such as bf16, fp16, int8, mixed etc. + Precision string `json:"precision,omitempty"` + + // Quantization is the model quantization method, such as awq, gptq, etc. + Quantization string `json:"quantization,omitempty"` + + // Capabilities defines special capabilities that the model supports. + Capabilities *ModelCapabilities `json:"capabilities,omitempty"` +} + +// ModelCapabilities defines the special capabilities that the model supports. +type ModelCapabilities struct { + // InputTypes specifies what input modalities the model can process. + // Values can be: "text", "image", "audio", "video", "embedding", "other". 
+ InputTypes []string `json:"inputTypes,omitempty"` + + // OutputTypes specifies what output modalities the model can produce. + // Values can be: "text", "image", "audio", "video", "embedding", "other". + OutputTypes []string `json:"outputTypes,omitempty"` + + // KnowledgeCutoff is the date of the datasets that the model was trained on. + KnowledgeCutoff *time.Time `json:"knowledgeCutoff,omitempty"` + + // Reasoning indicates whether the model can perform reasoning tasks. + Reasoning *bool `json:"reasoning,omitempty"` + + // ToolUsage indicates whether the model can use external tools. + ToolUsage *bool `json:"toolUsage,omitempty"` + + // Reward indicates whether the model is a reward model. + Reward *bool `json:"reward,omitempty"` + + // Languages indicates the languages that the model can speak. + // Encoded as ISO 639 two-letter codes. For example, ["en", "fr", "zh"]. + Languages []string `json:"languages,omitempty"` +} + +// ModelFS describes the layer content addresses. +type ModelFS struct { + // Type is the type of the rootfs. MUST be set to "layers". + Type string `json:"type"` + + // DiffIDs is an array of layer content hashes (DiffIDs), + // in order from bottom-most to top-most. + DiffIDs []digest.Digest `json:"diffIds"` +} diff --git a/pkg/distribution/types/config.go b/pkg/distribution/types/config.go index 62e45ebad..e88c956ab 100644 --- a/pkg/distribution/types/config.go +++ b/pkg/distribution/types/config.go @@ -68,6 +68,10 @@ type Config struct { GGUF map[string]string `json:"gguf,omitempty"` Safetensors map[string]string `json:"safetensors,omitempty"` ContextSize *int32 `json:"context_size,omitempty"` + // ModelPack stores extended metadata from CNCF ModelPack models for fields + // that have no direct mapping to Docker format fields. + // This includes authors, family, name, docURL, licenses, etc. + ModelPack map[string]string `json:"modelpack,omitempty"` } // Descriptor provides metadata about the provenance of the model.