From 4dc8ce54a967c045eafe4dca29535eb11a6ae968 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Wed, 24 Dec 2025 13:46:37 +0000 Subject: [PATCH] Add ability to pull vllm-compatible hf models This commit introduces native HuggingFace model support by adding a new HuggingFace client implementation that can download safetensors files directly from HuggingFace Hub repositories. The changes include: A new HuggingFace client with authentication, file listing, and download capabilities. The client handles LFS files, error responses, and rate limiting appropriately. A downloader component that manages parallel file downloads with progress reporting and temporary file storage. It includes progress tracking and concurrent download limiting. Model building functionality that downloads files from HuggingFace repositories and constructs OCI model artifacts using the existing builder framework. Repository utilities for file classification, filtering, and size calculations to identify safetensors and config files needed for model construction. Integration with the existing pull mechanism to detect HuggingFace references and attempt native pulling when no OCI manifest is found. This preserves existing OCI functionality while adding fallback support for raw HuggingFace repositories. 
Signed-off-by: Eric Curtin --- pkg/distribution/distribution/client.go | 129 +++++++++- .../distribution/normalize_test.go | 116 ++++++++- pkg/distribution/huggingface/client.go | 192 +++++++++++++++ pkg/distribution/huggingface/client_test.go | 157 +++++++++++++ pkg/distribution/huggingface/downloader.go | 221 ++++++++++++++++++ pkg/distribution/huggingface/model.go | 159 +++++++++++++ pkg/distribution/huggingface/repository.go | 119 ++++++++++ .../huggingface/repository_test.go | 138 +++++++++++ pkg/distribution/packaging/safetensors.go | 12 +- 9 files changed, 1235 insertions(+), 8 deletions(-) create mode 100644 pkg/distribution/huggingface/client.go create mode 100644 pkg/distribution/huggingface/client_test.go create mode 100644 pkg/distribution/huggingface/downloader.go create mode 100644 pkg/distribution/huggingface/model.go create mode 100644 pkg/distribution/huggingface/repository.go create mode 100644 pkg/distribution/huggingface/repository_test.go diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index 9382e87ac..4181fee54 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -6,9 +6,11 @@ import ( "fmt" "io" "net/http" + "os" "slices" "strings" + "github.com/docker/model-runner/pkg/distribution/huggingface" "github.com/docker/model-runner/pkg/distribution/internal/progress" "github.com/docker/model-runner/pkg/distribution/internal/store" "github.com/docker/model-runner/pkg/distribution/registry" @@ -162,10 +164,11 @@ func (c *Client) normalizeModelName(model string) string { return model } - // Normalize HuggingFace model names (lowercase path) + // Normalize HuggingFace model names if strings.HasPrefix(model, "hf.co/") { // Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect. 
- model = "huggingface.co" + strings.ToLower(strings.TrimPrefix(model, "hf.co")) + // Note: We preserve case since HuggingFace's native API is case-sensitive + model = "huggingface.co" + strings.TrimPrefix(model, "hf.co") } // Check if model contains a registry (domain with dot before first slash) @@ -267,15 +270,22 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter // Use the client's registry, or create a temporary one if bearer token is provided registryClient := c.registry + var token string if len(bearerToken) > 0 && bearerToken[0] != "" { + token = bearerToken[0] // Create a temporary registry client with bearer token authentication - auth := &authn.Bearer{Token: bearerToken[0]} + auth := &authn.Bearer{Token: token} registryClient = registry.FromClient(c.registry, registry.WithAuth(auth)) } // First, fetch the remote model to get the manifest remoteModel, err := registryClient.Model(ctx, reference) if err != nil { + // Check if this is a HuggingFace reference and the error indicates no OCI manifest + if isHuggingFaceReference(reference) && isNotOCIError(err) { + c.log.Infoln("No OCI manifest found, attempting native HuggingFace pull") + return c.pullNativeHuggingFace(ctx, reference, progressWriter, token) + } return fmt.Errorf("reading model from registry: %w", err) } @@ -637,3 +647,116 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string, return nil } + +// isHuggingFaceReference checks if a reference is a HuggingFace model reference +func isHuggingFaceReference(reference string) bool { + return strings.HasPrefix(reference, "huggingface.co/") +} + +// isNotOCIError checks if the error indicates the model is not OCI-formatted +// This happens when the HuggingFace repository doesn't have an OCI manifest +func isNotOCIError(err error) bool { + if err == nil { + return false + } + + // Check for registry errors indicating no manifest + var regErr *registry.Error + if errors.As(err, &regErr) { + if 
regErr.Code == "MANIFEST_UNKNOWN" || regErr.Code == "NAME_UNKNOWN" { + return true + } + } + + // Check for invalid reference error (e.g., uppercase letters not allowed in OCI) + // This happens with HuggingFace model names like "Qwen/Qwen3-0.6B" + if errors.Is(err, registry.ErrInvalidReference) { + return true + } + + // Also check error message for common patterns + errStr := err.Error() + return strings.Contains(errStr, "MANIFEST_UNKNOWN") || + strings.Contains(errStr, "NAME_UNKNOWN") || + strings.Contains(errStr, "manifest unknown") || + // HuggingFace returns this error for non-GGUF repositories + strings.Contains(errStr, "Repository is not GGUF") || + strings.Contains(errStr, "not compatible with llama.cpp") +} + +// parseHFReference extracts repo and revision from a normalized HF reference +// e.g., "huggingface.co/org/model:revision" -> ("org/model", "revision") +// e.g., "huggingface.co/org/model:latest" -> ("org/model", "main") +func parseHFReference(reference string) (repo, revision string) { + // Remove registry prefix + ref := strings.TrimPrefix(reference, "huggingface.co/") + + // Split by colon to get tag + parts := strings.SplitN(ref, ":", 2) + repo = parts[0] + + revision = "main" + if len(parts) == 2 && parts[1] != "" && parts[1] != "latest" { + revision = parts[1] + } + + return repo, revision +} + +// pullNativeHuggingFace pulls a native HuggingFace repository (non-OCI format) +// This is used when the model is stored as raw files (safetensors) on HuggingFace Hub +func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, progressWriter io.Writer, token string) error { + repo, revision := parseHFReference(reference) + c.log.Infof("Pulling native HuggingFace model: repo=%s, revision=%s", utils.SanitizeForLog(repo), utils.SanitizeForLog(revision)) + + // Create HuggingFace client + hfOpts := []huggingface.ClientOption{ + huggingface.WithUserAgent(registry.DefaultUserAgent), + } + if token != "" { + hfOpts = append(hfOpts, 
huggingface.WithToken(token)) + } + hfClient := huggingface.NewClient(hfOpts...) + + // Create temp directory for downloads + tempDir, err := os.MkdirTemp("", "hf-model-*") + if err != nil { + return fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + // Build model from HuggingFace repository + model, err := huggingface.BuildModel(ctx, hfClient, repo, revision, tempDir, progressWriter) + if err != nil { + // Convert HuggingFace errors to registry errors for consistent handling + var authErr *huggingface.AuthError + var notFoundErr *huggingface.NotFoundError + if errors.As(err, &authErr) { + return registry.ErrUnauthorized + } + if errors.As(err, &notFoundErr) { + return registry.ErrModelNotFound + } + if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil { + c.log.Warnf("Failed to write error message: %v", writeErr) + } + return fmt.Errorf("build model from HuggingFace: %w", err) + } + + // Write model to store + // Lowercase the reference for storage since OCI tags don't allow uppercase + storageTag := strings.ToLower(reference) + c.log.Infof("Writing model to store with tag: %s", utils.SanitizeForLog(storageTag)) + if err := c.store.Write(model, []string{storageTag}, progressWriter); err != nil { + if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil { + c.log.Warnf("Failed to write error message: %v", writeErr) + } + return fmt.Errorf("writing model to store: %w", err) + } + + if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil { + c.log.Warnf("Failed to write success message: %v", err) + } + + return nil +} diff --git a/pkg/distribution/distribution/normalize_test.go b/pkg/distribution/distribution/normalize_test.go index b970bcf41..e644f5cf4 100644 --- a/pkg/distribution/distribution/normalize_test.go +++ b/pkg/distribution/distribution/normalize_test.go @@ -2,12 +2,14 @@ package distribution 
import ( "context" + "errors" "io" "path/filepath" "strings" "testing" "github.com/docker/model-runner/pkg/distribution/builder" + "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/tarball" "github.com/sirupsen/logrus" ) @@ -66,7 +68,7 @@ func TestNormalizeModelName(t *testing.T) { expected: "registry.example.com/myorg/model:v1", }, - // HuggingFace cases + // HuggingFace cases (case is preserved for native HF API compatibility) { name: "huggingface short form lowercase", input: "hf.co/model", @@ -75,12 +77,12 @@ func TestNormalizeModelName(t *testing.T) { { name: "huggingface short form uppercase", input: "hf.co/Model", - expected: "huggingface.co/model:latest", + expected: "huggingface.co/Model:latest", }, { name: "huggingface short form with org", input: "hf.co/MyOrg/MyModel", - expected: "huggingface.co/myorg/mymodel:latest", + expected: "huggingface.co/MyOrg/MyModel:latest", }, { name: "huggingface with tag", @@ -355,6 +357,114 @@ func createTestClient(t *testing.T) (*Client, func()) { return client, cleanup } +func TestIsHuggingFaceReference(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"huggingface.co prefix", "huggingface.co/org/model:latest", true}, + {"huggingface.co without tag", "huggingface.co/org/model", true}, + {"not huggingface", "registry.example.com/model:latest", false}, + {"docker hub", "ai/gemma3:latest", false}, + {"hf.co prefix (not normalized)", "hf.co/org/model", false}, // This is the un-normalized form + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isHuggingFaceReference(tt.input) + if result != tt.expected { + t.Errorf("isHuggingFaceReference(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestParseHFReference(t *testing.T) { + tests := []struct { + name string + input string + expectedRepo string + expectedRev string + }{ + { + name: "basic with 
latest tag", + input: "huggingface.co/org/model:latest", + expectedRepo: "org/model", + expectedRev: "main", // latest maps to main + }, + { + name: "with explicit revision", + input: "huggingface.co/org/model:v1.0", + expectedRepo: "org/model", + expectedRev: "v1.0", + }, + { + name: "without tag", + input: "huggingface.co/org/model", + expectedRepo: "org/model", + expectedRev: "main", + }, + { + name: "with commit hash as tag", + input: "huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct:abc123", + expectedRepo: "HuggingFaceTB/SmolLM2-135M-Instruct", + expectedRev: "abc123", + }, + { + name: "single name (no org)", + input: "huggingface.co/model:latest", + expectedRepo: "model", + expectedRev: "main", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, rev := parseHFReference(tt.input) + if repo != tt.expectedRepo { + t.Errorf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, tt.expectedRepo) + } + if rev != tt.expectedRev { + t.Errorf("parseHFReference(%q) rev = %q, want %q", tt.input, rev, tt.expectedRev) + } + }) + } +} + +func TestIsNotOCIError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + {"nil error", nil, false}, + {"generic error", errors.New("some error"), false}, + {"manifest unknown in message", errors.New("MANIFEST_UNKNOWN: manifest not found"), true}, + {"name unknown in message", errors.New("NAME_UNKNOWN: repository not found"), true}, + {"manifest unknown lowercase", errors.New("manifest unknown"), true}, + {"unrelated error", errors.New("network timeout"), false}, + {"HuggingFace not GGUF error", errors.New("Repository is not GGUF or is not compatible with llama.cpp"), true}, + {"HuggingFace llama.cpp incompatible", errors.New("not compatible with llama.cpp"), true}, + // registry.Error typed error cases + {"registry error MANIFEST_UNKNOWN", &registry.Error{Code: "MANIFEST_UNKNOWN"}, true}, + {"registry error NAME_UNKNOWN", &registry.Error{Code: "NAME_UNKNOWN"}, true}, + 
{"registry error other code", &registry.Error{Code: "UNAUTHORIZED"}, false}, + // ErrInvalidReference case + {"invalid reference error", registry.ErrInvalidReference, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isNotOCIError(tt.err) + if result != tt.expected { + t.Errorf("isNotOCIError(%v) = %v, want %v", tt.err, result, tt.expected) + } + }) + } +} + // Helper function to load a test model and return its ID func loadTestModel(t *testing.T, client *Client, ggufPath string) string { t.Helper() diff --git a/pkg/distribution/huggingface/client.go b/pkg/distribution/huggingface/client.go new file mode 100644 index 000000000..9dc5a64e9 --- /dev/null +++ b/pkg/distribution/huggingface/client.go @@ -0,0 +1,192 @@ +package huggingface + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +const ( + defaultBaseURL = "https://huggingface.co" + defaultUserAgent = "model-distribution" +) + +// Client handles HuggingFace Hub API interactions +type Client struct { + httpClient *http.Client + userAgent string + token string + baseURL string +} + +// ClientOption configures a Client +type ClientOption func(*Client) + +// WithToken sets the HuggingFace API token for authentication +func WithToken(token string) ClientOption { + return func(c *Client) { + if token != "" { + c.token = token + } + } +} + +// WithTransport sets the HTTP transport for the client +func WithTransport(transport http.RoundTripper) ClientOption { + return func(c *Client) { + if transport != nil { + c.httpClient.Transport = transport + } + } +} + +// WithUserAgent sets the User-Agent header for requests +func WithUserAgent(userAgent string) ClientOption { + return func(c *Client) { + if userAgent != "" { + c.userAgent = userAgent + } + } +} + +// WithBaseURL sets a custom base URL (useful for testing) +func WithBaseURL(baseURL string) ClientOption { + return func(c *Client) { + if baseURL != "" { + c.baseURL = 
strings.TrimSuffix(baseURL, "/") + } + } +} + +// NewClient creates a new HuggingFace Hub API client +func NewClient(opts ...ClientOption) *Client { + c := &Client{ + httpClient: &http.Client{}, + userAgent: defaultUserAgent, + baseURL: defaultBaseURL, + } + for _, opt := range opts { + opt(c) + } + return c +} + +// ListFiles returns all files in a repository at a given revision +func (c *Client) ListFiles(ctx context.Context, repo, revision string) ([]RepoFile, error) { + if revision == "" { + revision = "main" + } + + // HuggingFace API endpoint for listing files + url := fmt.Sprintf("%s/api/models/%s/tree/%s", c.baseURL, repo, revision) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + c.setHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("list files: %w", err) + } + defer resp.Body.Close() + + if err := c.checkResponse(resp, repo); err != nil { + return nil, err + } + + var files []RepoFile + if err := json.NewDecoder(resp.Body).Decode(&files); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + + return files, nil +} + +// DownloadFile streams a file from the repository +// Returns the reader, content length (-1 if unknown), and any error +func (c *Client) DownloadFile(ctx context.Context, repo, revision, filename string) (io.ReadCloser, int64, error) { + if revision == "" { + revision = "main" + } + + // HuggingFace file download endpoint (handles LFS redirects automatically) + url := fmt.Sprintf("%s/%s/resolve/%s/%s", c.baseURL, repo, revision, filename) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, 0, fmt.Errorf("create request: %w", err) + } + + c.setHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("download file: %w", err) + } + + if err := c.checkResponse(resp, repo); 
err != nil { + resp.Body.Close() + return nil, 0, err + } + + return resp.Body, resp.ContentLength, nil +} + +// setHeaders sets common headers for HuggingFace API requests +func (c *Client) setHeaders(req *http.Request) { + req.Header.Set("User-Agent", c.userAgent) + if c.token != "" { + req.Header.Set("Authorization", "Bearer "+c.token) + } +} + +// checkResponse checks the HTTP response for errors +func (c *Client) checkResponse(resp *http.Response, repo string) error { + switch resp.StatusCode { + case http.StatusOK: + return nil + case http.StatusUnauthorized, http.StatusForbidden: + return &AuthError{Repo: repo, StatusCode: resp.StatusCode} + case http.StatusNotFound: + return &NotFoundError{Repo: repo} + case http.StatusTooManyRequests: + return &RateLimitError{Repo: repo} + default: + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body)) + } +} + +// AuthError indicates authentication failure +type AuthError struct { + Repo string + StatusCode int +} + +func (e *AuthError) Error() string { + return fmt.Sprintf("authentication required for repository %q (status %d)", e.Repo, e.StatusCode) +} + +// NotFoundError indicates the repository or file was not found +type NotFoundError struct { + Repo string +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("repository %q not found", e.Repo) +} + +// RateLimitError indicates rate limiting +type RateLimitError struct { + Repo string +} + +func (e *RateLimitError) Error() string { + return fmt.Sprintf("rate limited while accessing repository %q", e.Repo) +} diff --git a/pkg/distribution/huggingface/client_test.go b/pkg/distribution/huggingface/client_test.go new file mode 100644 index 000000000..e33dff7a0 --- /dev/null +++ b/pkg/distribution/huggingface/client_test.go @@ -0,0 +1,157 @@ +package huggingface + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" 
+) + +func TestClientListFiles(t *testing.T) { + // Mock HuggingFace API response + mockFiles := []RepoFile{ + {Type: "file", Path: "model.safetensors", Size: 1000}, + {Type: "file", Path: "config.json", Size: 100}, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/models/test-org/test-model/tree/main" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockFiles) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + + files, err := client.ListFiles(context.Background(), "test-org/test-model", "main") + if err != nil { + t.Fatalf("ListFiles failed: %v", err) + } + + if len(files) != 2 { + t.Errorf("Expected 2 files, got %d", len(files)) + } +} + +func TestClientListFilesDefaultRevision(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the default revision is "main" + if !strings.Contains(r.URL.Path, "/tree/main") { + t.Errorf("Expected /tree/main in path, got %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode([]RepoFile{}) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + _, err := client.ListFiles(context.Background(), "test/model", "") + if err != nil { + t.Fatalf("ListFiles failed: %v", err) + } +} + +func TestClientDownloadFile(t *testing.T) { + expectedContent := "test file content" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test-org/test-model/resolve/main/test.txt" { + w.Header().Set("Content-Length", "17") + w.Write([]byte(expectedContent)) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + + reader, size, err := client.DownloadFile(context.Background(), "test-org/test-model", "main", "test.txt") + if err 
!= nil { + t.Fatalf("DownloadFile failed: %v", err) + } + defer reader.Close() + + if size != 17 { + t.Errorf("Expected size 17, got %d", size) + } + + content, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("ReadAll failed: %v", err) + } + + if string(content) != expectedContent { + t.Errorf("Expected content %q, got %q", expectedContent, string(content)) + } +} + +func TestClientAuthError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + + _, err := client.ListFiles(context.Background(), "private/model", "main") + if err == nil { + t.Fatal("Expected error, got nil") + } + + var authErr *AuthError + if !errors.As(err, &authErr) { + t.Errorf("Expected AuthError, got %T", err) + } +} + +func TestClientNotFoundError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + + _, err := client.ListFiles(context.Background(), "nonexistent/model", "main") + if err == nil { + t.Fatal("Expected error, got nil") + } + + var notFoundErr *NotFoundError + if !errors.As(err, &notFoundErr) { + t.Errorf("Expected NotFoundError, got %T", err) + } +} + +func TestClientWithToken(t *testing.T) { + var receivedToken string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedToken = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode([]RepoFile{}) + })) + defer server.Close() + + client := NewClient( + WithBaseURL(server.URL), + WithToken("test-token"), + ) + + _, err := client.ListFiles(context.Background(), "test/model", "main") + if err != nil { + t.Fatalf("ListFiles failed: %v", err) + } + + if receivedToken != "Bearer 
test-token" { + t.Errorf("Expected 'Bearer test-token', got %q", receivedToken) + } +} diff --git a/pkg/distribution/huggingface/downloader.go b/pkg/distribution/huggingface/downloader.go new file mode 100644 index 000000000..0bb537319 --- /dev/null +++ b/pkg/distribution/huggingface/downloader.go @@ -0,0 +1,221 @@ +package huggingface + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "os" + "path/filepath" + "sync" + + "github.com/docker/model-runner/pkg/distribution/internal/progress" +) + +// Downloader manages file downloads from HuggingFace repositories +type Downloader struct { + client *Client + repo string + revision string + tempDir string +} + +// NewDownloader creates a new downloader for a HuggingFace repository +func NewDownloader(client *Client, repo, revision, tempDir string) *Downloader { + if revision == "" { + revision = "main" + } + return &Downloader{ + client: client, + repo: repo, + revision: revision, + tempDir: tempDir, + } +} + +// DownloadResult contains the result of downloading files +type DownloadResult struct { + // LocalPaths maps original repo paths to local file paths + LocalPaths map[string]string + // TotalBytes is the total number of bytes downloaded + TotalBytes int64 +} + +// syncWriter wraps an io.Writer with a mutex for thread-safe concurrent writes +type syncWriter struct { + mu sync.Mutex + w io.Writer +} + +func (sw *syncWriter) Write(p []byte) (n int, err error) { + sw.mu.Lock() + defer sw.mu.Unlock() + return sw.w.Write(p) +} + +// fileIDFromPath generates a unique ID for a file based on its path +// Returns a sha256: prefixed hash to match layer ID format +func fileIDFromPath(path string) string { + hash := sha256.Sum256([]byte(path)) + return "sha256:" + hex.EncodeToString(hash[:]) +} + +// DownloadAll downloads all specified files with progress reporting +// Files are downloaded in parallel with per-file progress updates written to progressWriter +func (d *Downloader) DownloadAll(ctx 
context.Context, files []RepoFile, progressWriter io.Writer) (*DownloadResult, error) { + if len(files) == 0 { + return &DownloadResult{LocalPaths: make(map[string]string)}, nil + } + + totalSize := TotalSize(files) + + // Create result map (thread-safe access) + var mu sync.Mutex + localPaths := make(map[string]string, len(files)) + + // Create thread-safe writer for concurrent progress reporting + var safeWriter io.Writer + if progressWriter != nil { + safeWriter = &syncWriter{w: progressWriter} + } + + // Download files in parallel (limit concurrency to avoid overwhelming) + const maxConcurrent = 4 + sem := make(chan struct{}, maxConcurrent) + var wg sync.WaitGroup + errChan := make(chan error, len(files)) + + for _, file := range files { + wg.Add(1) + go func(f RepoFile) { + defer wg.Done() + + // Acquire semaphore + select { + case sem <- struct{}{}: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + defer func() { <-sem }() + + localPath, err := d.downloadFileWithProgress(ctx, f, uint64(totalSize), safeWriter) + if err != nil { + errChan <- fmt.Errorf("download %s: %w", f.Path, err) + return + } + + mu.Lock() + localPaths[f.Path] = localPath + mu.Unlock() + }(file) + } + + // Wait for all downloads to complete + wg.Wait() + close(errChan) + + // Collect any errors + var errs []error + for err := range errChan { + if err != nil { + errs = append(errs, err) + } + } + + if len(errs) > 0 { + return nil, fmt.Errorf("download errors: %v", errs) + } + + // Calculate total downloaded + var totalDownloaded int64 + for _, f := range files { + totalDownloaded += f.ActualSize() + } + + return &DownloadResult{ + LocalPaths: localPaths, + TotalBytes: totalDownloaded, + }, nil +} + +// downloadFileWithProgress downloads a single file with progress reporting +func (d *Downloader) downloadFileWithProgress(ctx context.Context, file RepoFile, totalImageSize uint64, progressWriter io.Writer) (string, error) { + // Create local file path (preserve directory structure) + 
localPath := filepath.Join(d.tempDir, file.Path) + + // Ensure parent directory exists + if err := os.MkdirAll(filepath.Dir(localPath), 0o755); err != nil { + return "", fmt.Errorf("create directory: %w", err) + } + + // Download from HuggingFace + reader, _, err := d.client.DownloadFile(ctx, d.repo, d.revision, file.Path) + if err != nil { + return "", err + } + defer reader.Close() + + // Create local file + f, err := os.Create(localPath) + if err != nil { + return "", fmt.Errorf("create file: %w", err) + } + defer f.Close() + + // Generate unique ID for this file (for progress tracking) + fileID := fileIDFromPath(file.Path) + fileSize := uint64(file.ActualSize()) + + // Copy with progress tracking + pr := &progressReader{ + reader: reader, + progressWriter: progressWriter, + totalImageSize: totalImageSize, + fileSize: fileSize, + fileID: fileID, + } + + if _, err := io.Copy(f, pr); err != nil { + os.Remove(localPath) // Clean up on error + return "", fmt.Errorf("write file: %w", err) + } + + // Write final progress for this file (100% complete) + if progressWriter != nil { + _ = progress.WriteProgress(progressWriter, "", totalImageSize, fileSize, fileSize, fileID) + } + + return localPath, nil +} + +// progressReader wraps a reader and reports per-file progress +type progressReader struct { + reader io.Reader + progressWriter io.Writer + totalImageSize uint64 + fileSize uint64 + fileID string + bytesRead uint64 + lastReported uint64 +} + +func (pr *progressReader) Read(p []byte) (n int, err error) { + n, err = pr.reader.Read(p) + if n > 0 { + pr.bytesRead += uint64(n) + + // Report progress periodically (every 1MB or when complete) + if pr.progressWriter != nil && (pr.bytesRead-pr.lastReported >= progress.MinBytesForUpdate || pr.bytesRead == pr.fileSize) { + _ = progress.WriteProgress(pr.progressWriter, "", pr.totalImageSize, pr.fileSize, pr.bytesRead, pr.fileID) + pr.lastReported = pr.bytesRead + } + } + return n, err +} + +// DownloadSingleFile downloads a 
single file and returns its local path +func (d *Downloader) DownloadSingleFile(ctx context.Context, file RepoFile) (string, error) { + return d.downloadFileWithProgress(ctx, file, uint64(file.ActualSize()), nil) +} diff --git a/pkg/distribution/huggingface/model.go b/pkg/distribution/huggingface/model.go new file mode 100644 index 000000000..fd072b836 --- /dev/null +++ b/pkg/distribution/huggingface/model.go @@ -0,0 +1,159 @@ +package huggingface + +import ( + "context" + "fmt" + "io" + "log" + "path/filepath" + "sort" + "strings" + + "github.com/docker/model-runner/pkg/distribution/builder" + "github.com/docker/model-runner/pkg/distribution/internal/progress" + "github.com/docker/model-runner/pkg/distribution/packaging" + "github.com/docker/model-runner/pkg/distribution/types" +) + +// BuildModel downloads files from a HuggingFace repository and constructs an OCI model artifact +// This is the main entry point for pulling native HuggingFace models +func BuildModel(ctx context.Context, client *Client, repo, revision string, tempDir string, progressWriter io.Writer) (types.ModelArtifact, error) { + // Step 1: List files in the repository + if progressWriter != nil { + _ = progress.WriteProgress(progressWriter, "Fetching file list...", 0, 0, 0, "") + } + + files, err := client.ListFiles(ctx, repo, revision) + if err != nil { + return nil, fmt.Errorf("list files: %w", err) + } + + // Step 2: Filter to model files (safetensors + configs) + safetensorsFiles, configFiles := FilterModelFiles(files) + + if len(safetensorsFiles) == 0 { + return nil, fmt.Errorf("no safetensors files found in repository %s", repo) + } + + // Combine all files to download + allFiles := append(safetensorsFiles, configFiles...) 
+ + if progressWriter != nil { + totalSize := TotalSize(allFiles) + msg := fmt.Sprintf("Found %d files (%.2f MB total)", + len(allFiles), float64(totalSize)/1024/1024) + _ = progress.WriteProgress(progressWriter, msg, uint64(totalSize), 0, 0, "") + } + + // Step 3: Download all files + downloader := NewDownloader(client, repo, revision, tempDir) + result, err := downloader.DownloadAll(ctx, allFiles, progressWriter) + if err != nil { + return nil, fmt.Errorf("download files: %w", err) + } + + // Step 4: Build the model artifact + if progressWriter != nil { + _ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "") + } + + model, err := buildModelFromFiles(result.LocalPaths, safetensorsFiles, configFiles, tempDir) + if err != nil { + return nil, fmt.Errorf("build model: %w", err) + } + + return model, nil +} + +// buildModelFromFiles constructs an OCI model artifact from downloaded files +func buildModelFromFiles(localPaths map[string]string, safetensorsFiles, configFiles []RepoFile, tempDir string) (types.ModelArtifact, error) { + // Collect safetensors paths (sorted for reproducibility) + var safetensorsPaths []string + for _, f := range safetensorsFiles { + localPath, ok := localPaths[f.Path] + if !ok { + return nil, fmt.Errorf("missing local path for %s", f.Path) + } + safetensorsPaths = append(safetensorsPaths, localPath) + } + sort.Strings(safetensorsPaths) + + // Create builder from safetensors files + b, err := builder.FromSafetensors(safetensorsPaths) + if err != nil { + return nil, fmt.Errorf("create builder: %w", err) + } + + // Create config archive if we have config files + if len(configFiles) > 0 { + configArchive, err := createConfigArchive(localPaths, configFiles, tempDir) + if err != nil { + return nil, fmt.Errorf("create config archive: %w", err) + } + // Note: configArchive is cleaned up by the caller's tempDir cleanup + + if configArchive != "" { + b, err = b.WithConfigArchive(configArchive) + if err != nil { + return 
nil, fmt.Errorf("add config archive: %w", err) + } + } + } + + // Check for chat template and add it + for _, f := range configFiles { + if isChatTemplate(f.Path) { + localPath := localPaths[f.Path] + b, err = b.WithChatTemplateFile(localPath) + if err != nil { + // Non-fatal: log warning but continue to try other potential templates + log.Printf("Warning: failed to add chat template from %s: %v", f.Path, err) + continue + } + break // Only add one chat template + } + } + + return b.Model(), nil +} + +// createConfigArchive creates a tar archive of config files in the specified tempDir +func createConfigArchive(localPaths map[string]string, configFiles []RepoFile, tempDir string) (string, error) { + // Collect config file paths (excluding chat templates which are added separately) + var configPaths []string + for _, f := range configFiles { + if isChatTemplate(f.Path) { + continue // Chat templates are added as separate layers + } + localPath, ok := localPaths[f.Path] + if !ok { + return "", fmt.Errorf("internal error: missing local path for downloaded config file %s", f.Path) + } + configPaths = append(configPaths, localPath) + } + + if len(configPaths) == 0 { + // No config files to archive + return "", nil + } + + // Sort for reproducibility + sort.Strings(configPaths) + + // Create the archive in our tempDir so it gets cleaned up with everything else + archivePath, err := packaging.CreateConfigArchiveInDir(configPaths, tempDir) + if err != nil { + return "", fmt.Errorf("create config archive: %w", err) + } + + return archivePath, nil +} + +// isChatTemplate checks if a file is a chat template +func isChatTemplate(path string) bool { + filename := filepath.Base(path) + lower := strings.ToLower(filename) + return strings.HasSuffix(lower, ".jinja") || + strings.Contains(lower, "chat_template") || + filename == "tokenizer_config.json" // Often contains chat_template +} diff --git a/pkg/distribution/huggingface/repository.go 
// RepoFile represents a file in a HuggingFace repository
type RepoFile struct {
	Type string   `json:"type"` // "file" or "directory"
	Path string   `json:"path"` // Relative path in repo
	Size int64    `json:"size"` // File size in bytes (0 for directories)
	OID  string   `json:"oid"`  // Git blob ID
	LFS  *LFSInfo `json:"lfs"`  // Present if LFS file
}

// LFSInfo contains LFS-specific file information
type LFSInfo struct {
	OID         string `json:"oid"`          // LFS object ID (sha256)
	Size        int64  `json:"size"`         // Actual file size
	PointerSize int64  `json:"pointer_size"` // Size of pointer file
}

// ActualSize returns the actual file size, accounting for LFS: when LFS
// information is present its size is returned instead of Size.
func (f *RepoFile) ActualSize() int64 {
	if f.LFS != nil {
		return f.LFS.Size
	}
	return f.Size
}

// Filename returns the base filename without directory path
func (f *RepoFile) Filename() string {
	return path.Base(f.Path)
}

// configExtensions defines file extensions treated as config files
// This matches the existing packaging/safetensors.go logic
var configExtensions = []string{".md", ".txt", ".json", ".vocab", ".jinja"}

// specialConfigFiles are specific filenames treated as config files
var specialConfigFiles = []string{"tokenizer.model"}

// FileType represents the type of file for model packaging
type FileType int

const (
	// FileTypeUnknown is an unrecognized file type
	FileTypeUnknown FileType = iota
	// FileTypeSafetensors is a safetensors model weight file
	FileTypeSafetensors
	// FileTypeConfig is a configuration file (json, txt, etc.)
	FileTypeConfig
)

// ClassifyFile determines the file type of filename.
//
// filename may be a bare name or a slash-separated repository path; only the
// final path element is considered. Matching is case-insensitive. Names ending
// in ".safetensors" are weights, names ending in one of configExtensions or
// equal to one of specialConfigFiles are config files, and everything else is
// FileTypeUnknown.
func ClassifyFile(filename string) FileType {
	// Drop any directory part so that the exact-name check below behaves the
	// same for "tokenizer.model" and "sub/tokenizer.model", just like the
	// suffix checks already do.
	if i := strings.LastIndexByte(filename, '/'); i >= 0 {
		filename = filename[i+1:]
	}
	lower := strings.ToLower(filename)

	// Check for safetensors files
	if strings.HasSuffix(lower, ".safetensors") {
		return FileTypeSafetensors
	}

	// Check for config file extensions
	for _, ext := range configExtensions {
		if strings.HasSuffix(lower, ext) {
			return FileTypeConfig
		}
	}

	// Check for special config files
	for _, special := range specialConfigFiles {
		if strings.EqualFold(filename, special) {
			return FileTypeConfig
		}
	}

	return FileTypeUnknown
}

// FilterModelFiles filters repository files to only include files needed for model-runner.
// Entries whose Type is not "file" and files of unknown type are skipped.
// Returns safetensors files and config files separately, each in input order.
func FilterModelFiles(files []RepoFile) (safetensors []RepoFile, configs []RepoFile) {
	for _, f := range files {
		if f.Type != "file" {
			continue
		}

		switch ClassifyFile(f.Filename()) {
		case FileTypeSafetensors:
			safetensors = append(safetensors, f)
		case FileTypeConfig:
			configs = append(configs, f)
		case FileTypeUnknown:
			// Skip unknown file types
		}
	}
	return safetensors, configs
}

// TotalSize calculates the total size of files, using ActualSize for each
// entry so that LFS sizes are taken into account.
func TotalSize(files []RepoFile) int64 {
	var total int64
	for i := range files {
		total += files[i].ActualSize()
	}
	return total
}

// IsSafetensorsModel checks if the files contain at least one safetensors file
func IsSafetensorsModel(files []RepoFile) bool {
	for _, f := range files {
		if f.Type == "file" && ClassifyFile(f.Filename()) == FileTypeSafetensors {
			return true
		}
	}
	return false
}
[]struct { + name string + filename string + want FileType + }{ + {"safetensors file", "model.safetensors", FileTypeSafetensors}, + {"safetensors uppercase", "model.SAFETENSORS", FileTypeSafetensors}, + {"safetensors mixed case", "Model.SafeTensors", FileTypeSafetensors}, + {"sharded safetensors", "model-00001-of-00003.safetensors", FileTypeSafetensors}, + + {"json config", "config.json", FileTypeConfig}, + {"tokenizer json", "tokenizer.json", FileTypeConfig}, + {"tokenizer config", "tokenizer_config.json", FileTypeConfig}, + {"txt file", "README.txt", FileTypeConfig}, + {"markdown file", "README.md", FileTypeConfig}, + {"vocab file", "vocab.vocab", FileTypeConfig}, + {"jinja template", "chat_template.jinja", FileTypeConfig}, + {"tokenizer model", "tokenizer.model", FileTypeConfig}, + + {"unknown extension", "model.bin", FileTypeUnknown}, + {"python file", "model.py", FileTypeUnknown}, + {"pytorch model", "pytorch_model.bin", FileTypeUnknown}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ClassifyFile(tt.filename); got != tt.want { + t.Errorf("ClassifyFile(%q) = %v, want %v", tt.filename, got, tt.want) + } + }) + } +} + +func TestFilterModelFiles(t *testing.T) { + files := []RepoFile{ + {Type: "file", Path: "model.safetensors", Size: 1000}, + {Type: "file", Path: "config.json", Size: 100}, + {Type: "file", Path: "tokenizer.json", Size: 200}, + {Type: "file", Path: "README.md", Size: 50}, + {Type: "file", Path: "model.py", Size: 500}, + {Type: "directory", Path: "subdir", Size: 0}, + {Type: "file", Path: "model-00001-of-00002.safetensors", Size: 2000}, + {Type: "file", Path: "model-00002-of-00002.safetensors", Size: 2000}, + } + + safetensors, configs := FilterModelFiles(files) + + if len(safetensors) != 3 { + t.Errorf("Expected 3 safetensors files, got %d", len(safetensors)) + } + if len(configs) != 3 { + t.Errorf("Expected 3 config files, got %d", len(configs)) + } +} + +func TestTotalSize(t *testing.T) { + files := 
[]RepoFile{ + {Type: "file", Path: "a.safetensors", Size: 1000}, + {Type: "file", Path: "b.safetensors", Size: 2000, LFS: &LFSInfo{Size: 5000}}, + } + + total := TotalSize(files) + if total != 6000 { // 1000 + 5000 (LFS size takes precedence) + t.Errorf("TotalSize() = %d, want 6000", total) + } +} + +func TestRepoFileActualSize(t *testing.T) { + tests := []struct { + name string + file RepoFile + want int64 + }{ + { + name: "regular file", + file: RepoFile{Size: 1000}, + want: 1000, + }, + { + name: "LFS file", + file: RepoFile{Size: 100, LFS: &LFSInfo{Size: 5000}}, + want: 5000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.file.ActualSize(); got != tt.want { + t.Errorf("ActualSize() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestIsSafetensorsModel(t *testing.T) { + tests := []struct { + name string + files []RepoFile + want bool + }{ + { + name: "has safetensors", + files: []RepoFile{ + {Type: "file", Path: "model.safetensors"}, + {Type: "file", Path: "config.json"}, + }, + want: true, + }, + { + name: "no safetensors", + files: []RepoFile{ + {Type: "file", Path: "config.json"}, + {Type: "file", Path: "README.md"}, + }, + want: false, + }, + { + name: "empty", + files: []RepoFile{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsSafetensorsModel(tt.files); got != tt.want { + t.Errorf("IsSafetensorsModel() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/distribution/packaging/safetensors.go b/pkg/distribution/packaging/safetensors.go index 74a4694c5..0b2159fcd 100644 --- a/pkg/distribution/packaging/safetensors.go +++ b/pkg/distribution/packaging/safetensors.go @@ -71,8 +71,16 @@ func PackageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfig // It returns the path to the temporary tar file and any error encountered. // The caller is responsible for removing the temporary file when done. 
func CreateTempConfigArchive(configFiles []string) (string, error) { - // Create temp file - tmpFile, err := os.CreateTemp("", "vllm-config-*.tar") + return CreateConfigArchiveInDir(configFiles, "") +} + +// CreateConfigArchiveInDir creates a tar archive containing the specified config files in the given directory. +// If dir is empty, the system temp directory is used. +// It returns the path to the tar file and any error encountered. +// The caller is responsible for removing the file when done. +func CreateConfigArchiveInDir(configFiles []string, dir string) (string, error) { + // Create temp file in specified directory + tmpFile, err := os.CreateTemp(dir, "vllm-config-*.tar") if err != nil { return "", fmt.Errorf("create temp file: %w", err) }