Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 58 additions & 44 deletions pkg/distribution/internal/bundle/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
"github.com/docker/model-runner/pkg/distribution/types"
)

// errFoundSafetensors is a sentinel error used to stop filepath.Walk early
// after finding the first safetensors file.
var errFoundSafetensors = fmt.Errorf("found safetensors file")
// errFoundModelFile is a sentinel error used to stop filepath.Walk early after
// finding the first matching model file.
var errFoundModelFile = fmt.Errorf("found model file")

// Parse returns the Bundle at the given rootDir
func Parse(rootDir string) (*Bundle, error) {
Expand All @@ -37,10 +37,14 @@ func Parse(rootDir string) (*Bundle, error) {
if err != nil {
return nil, err
}
ddufPath, err := findDDUFFile(modelDir)
if err != nil {
return nil, err
}

// Ensure at least one model weight format is present
if ggufPath == "" && safetensorsPath == "" {
return nil, fmt.Errorf("no supported model weights found (neither GGUF nor safetensors)")
if ggufPath == "" && safetensorsPath == "" && ddufPath == "" {
return nil, fmt.Errorf("no supported model weights found (neither GGUF, safetensors, nor DDUF)")
}

mmprojPath, err := findMultiModalProjectorFile(modelDir)
Expand All @@ -62,6 +66,7 @@ func Parse(rootDir string) (*Bundle, error) {
mmprojPath: mmprojPath,
ggufFile: ggufPath,
safetensorsFile: safetensorsPath,
ddufFile: ddufPath,
runtimeConfig: cfg,
chatTemplatePath: templatePath,
}, nil
Expand Down Expand Up @@ -92,60 +97,69 @@ func parseRuntimeConfig(rootDir string) (types.ModelConfig, error) {
return &cfg, nil
}

func findGGUFFile(modelDir string) (string, error) {
ggufs, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.gguf"))
// findModelFile finds a supported model file by extension. It prefers a
// top-level match in modelDir and falls back to a recursive search when needed.
// Hidden files are ignored.
func findModelFile(modelDir, ext string) (string, error) {
pattern := filepath.Join(modelDir, "[^.]*"+ext)
paths, err := filepath.Glob(pattern)
if err != nil {
return "", fmt.Errorf("find gguf files: %w", err)
return "", fmt.Errorf("find %s files: %w", ext, err)
}
if len(ggufs) == 0 {
// GGUF files are optional - safetensors models won't have them
return "", nil
if len(paths) > 0 {
return filepath.Base(paths[0]), nil
}
return filepath.Base(ggufs[0]), nil
}

func findSafetensorsFile(modelDir string) (string, error) {
// First check top-level directory (most common case)
safetensors, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.safetensors"))
if err != nil {
return "", fmt.Errorf("find safetensors files: %w", err)
}
if len(safetensors) > 0 {
return filepath.Base(safetensors[0]), nil
}

// Search recursively for V0.2 models with nested directory structure
// (e.g., text_encoder/model.safetensors)
var firstFound string
walkErr := filepath.Walk(modelDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
// Propagate filesystem errors so callers can distinguish them from
// the case where no safetensors files are present.
return err
}
if info.IsDir() {
return nil
}
if filepath.Ext(path) == ".safetensors" && !strings.HasPrefix(info.Name(), ".") {
walkErr := filepath.Walk(
modelDir,
func(path string, info os.FileInfo, err error) error {
if err != nil {
// Propagate filesystem errors so callers can distinguish them
// from the case where no matching files are present.
return err
}
if info.IsDir() {
return nil
}
if filepath.Ext(path) != ext ||
strings.HasPrefix(info.Name(), ".") {
return nil
}

rel, relErr := filepath.Rel(modelDir, path)
if relErr != nil {
// Treat a bad relative path as a real error instead of silently
// ignoring it, so malformed bundles surface to the caller.
// Treat a bad relative path as a real error instead of
// silently ignoring it, so malformed bundles surface to the
// caller.
return relErr
}
firstFound = rel
return errFoundSafetensors // found one, stop walking
}
return nil
})
if walkErr != nil && !errors.Is(walkErr, errFoundSafetensors) {
return "", fmt.Errorf("walk for safetensors files: %w", walkErr)
return errFoundModelFile
},
)
if walkErr != nil && !errors.Is(walkErr, errFoundModelFile) {
return "", fmt.Errorf("walk for %s files: %w", ext, walkErr)
}

// Safetensors files are optional - GGUF models won't have them
return firstFound, nil
}

func findGGUFFile(modelDir string) (string, error) {
// GGUF files are optional.
return findModelFile(modelDir, ".gguf")
}

func findSafetensorsFile(modelDir string) (string, error) {
// Safetensors files are optional.
return findModelFile(modelDir, ".safetensors")
}

func findDDUFFile(modelDir string) (string, error) {
// DDUF files are optional.
return findModelFile(modelDir, ".dduf")
}

func findMultiModalProjectorFile(modelDir string) (string, error) {
mmprojPaths, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.mmproj"))
if err != nil {
Expand Down
166 changes: 165 additions & 1 deletion pkg/distribution/internal/bundle/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestParse_NoModelWeights(t *testing.T) {
t.Fatal("Expected error when parsing bundle without model weights, got nil")
}

expectedErrMsg := "no supported model weights found (neither GGUF nor safetensors)"
expectedErrMsg := "no supported model weights found (neither GGUF, safetensors, nor DDUF)"
if !strings.Contains(err.Error(), expectedErrMsg) {
t.Errorf("Expected error message to contain %q, got: %v", expectedErrMsg, err)
}
Expand Down Expand Up @@ -93,6 +93,63 @@ func TestParse_WithGGUF(t *testing.T) {
}
}

func TestParse_WithNestedGGUF(t *testing.T) {
// Create a temporary directory for the test bundle.
tempDir := t.TempDir()

// Create model subdirectory.
modelDir := filepath.Join(tempDir, ModelSubdir)
if err := os.MkdirAll(modelDir, 0755); err != nil {
t.Fatalf("Failed to create model directory: %v", err)
}

// Create nested directory structure.
weightsDir := filepath.Join(modelDir, "nested", "weights")
if err := os.MkdirAll(weightsDir, 0755); err != nil {
t.Fatalf("Failed to create nested weights directory: %v", err)
}

// Create a GGUF file in the nested directory.
nestedGGUFPath := filepath.Join(weightsDir, "model.gguf")
if err := os.WriteFile(nestedGGUFPath, []byte("dummy nested gguf"), 0644); err != nil {
t.Fatalf("Failed to create nested GGUF file: %v", err)
}

// Create a valid config.json at bundle root.
cfg := types.Config{
Format: types.FormatGGUF,
}
configPath := filepath.Join(tempDir, "config.json")
f, err := os.Create(configPath)
if err != nil {
t.Fatalf("Failed to create config.json: %v", err)
}
if err := json.NewEncoder(f).Encode(cfg); err != nil {
f.Close()
t.Fatalf("Failed to encode config: %v", err)
}
f.Close()

// Parse the bundle and ensure GGUF discovery falls back to recursion.
bundle, err := Parse(tempDir)
if err != nil {
t.Fatalf("Expected successful parse with nested GGUF, got: %v", err)
}

expectedPath := filepath.Join("nested", "weights", "model.gguf")
if bundle.ggufFile != expectedPath {
t.Errorf("Expected ggufFile to be %q, got: %s", expectedPath, bundle.ggufFile)
}

fullPath := bundle.GGUFPath()
if fullPath == "" {
t.Error("Expected GGUFPath() to return a non-empty path")
}
if !strings.HasSuffix(fullPath, expectedPath) {
t.Errorf("Expected GGUFPath() to end with %q, got: %s", expectedPath, fullPath)
}
}

func TestParse_WithSafetensors(t *testing.T) {
// Create a temporary directory for the test bundle
tempDir := t.TempDir()
Expand Down Expand Up @@ -139,6 +196,56 @@ func TestParse_WithSafetensors(t *testing.T) {
}
}

func TestParse_WithDDUF(t *testing.T) {
// Create a temporary directory for the test bundle.
tempDir := t.TempDir()

// Create model subdirectory.
modelDir := filepath.Join(tempDir, ModelSubdir)
if err := os.MkdirAll(modelDir, 0755); err != nil {
t.Fatalf("Failed to create model directory: %v", err)
}

// Create a dummy DDUF file.
ddufPath := filepath.Join(modelDir, "model.dduf")
if err := os.WriteFile(ddufPath, []byte("dummy dduf content"), 0644); err != nil {
t.Fatalf("Failed to create DDUF file: %v", err)
}

// Create a valid config.json at bundle root.
cfg := types.Config{
Format: types.FormatDDUF,
}
configPath := filepath.Join(tempDir, "config.json")
f, err := os.Create(configPath)
if err != nil {
t.Fatalf("Failed to create config.json: %v", err)
}
if err := json.NewEncoder(f).Encode(cfg); err != nil {
f.Close()
t.Fatalf("Failed to encode config: %v", err)
}
f.Close()

// Parse the bundle and ensure DDUF-only bundles are accepted.
bundle, err := Parse(tempDir)
if err != nil {
t.Fatalf("Expected successful parse with DDUF file, got: %v", err)
}

if bundle.ddufFile != "model.dduf" {
t.Errorf("Expected ddufFile to be %q, got: %s", "model.dduf", bundle.ddufFile)
}

fullPath := bundle.DDUFPath()
if fullPath == "" {
t.Error("Expected DDUFPath() to return a non-empty path")
}
if !strings.HasSuffix(fullPath, "model.dduf") {
t.Errorf("Expected DDUFPath() to end with %q, got: %s", "model.dduf", fullPath)
}
}

func TestParse_WithNestedSafetensors(t *testing.T) {
// Create a temporary directory for the test bundle
tempDir := t.TempDir()
Expand Down Expand Up @@ -198,6 +305,63 @@ func TestParse_WithNestedSafetensors(t *testing.T) {
}
}

func TestParse_WithNestedDDUF(t *testing.T) {
// Create a temporary directory for the test bundle.
tempDir := t.TempDir()

// Create model subdirectory.
modelDir := filepath.Join(tempDir, ModelSubdir)
if err := os.MkdirAll(modelDir, 0755); err != nil {
t.Fatalf("Failed to create model directory: %v", err)
}

// Create nested directory structure.
diffusersDir := filepath.Join(modelDir, "sanitized", "diffusers")
if err := os.MkdirAll(diffusersDir, 0755); err != nil {
t.Fatalf("Failed to create nested diffusers directory: %v", err)
}

// Create a DDUF file in the nested directory.
nestedDDUFPath := filepath.Join(diffusersDir, "model.dduf")
if err := os.WriteFile(nestedDDUFPath, []byte("dummy nested dduf"), 0644); err != nil {
t.Fatalf("Failed to create nested DDUF file: %v", err)
}

// Create a valid config.json at bundle root.
cfg := types.Config{
Format: types.FormatDDUF,
}
configPath := filepath.Join(tempDir, "config.json")
f, err := os.Create(configPath)
if err != nil {
t.Fatalf("Failed to create config.json: %v", err)
}
if err := json.NewEncoder(f).Encode(cfg); err != nil {
f.Close()
t.Fatalf("Failed to encode config: %v", err)
}
f.Close()

// Parse the bundle and ensure DDUF discovery falls back to recursion.
bundle, err := Parse(tempDir)
if err != nil {
t.Fatalf("Expected successful parse with nested DDUF, got: %v", err)
}

expectedPath := filepath.Join("sanitized", "diffusers", "model.dduf")
if bundle.ddufFile != expectedPath {
t.Errorf("Expected ddufFile to be %q, got: %s", expectedPath, bundle.ddufFile)
}

fullPath := bundle.DDUFPath()
if fullPath == "" {
t.Error("Expected DDUFPath() to return a non-empty path")
}
if !strings.HasSuffix(fullPath, expectedPath) {
t.Errorf("Expected DDUFPath() to end with %q, got: %s", expectedPath, fullPath)
}
}

func TestParse_WithBothFormats(t *testing.T) {
// Create a temporary directory for the test bundle
tempDir := t.TempDir()
Expand Down
Loading
Loading