diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index 9382e87ac..4181fee54 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -6,9 +6,11 @@ import ( "fmt" "io" "net/http" + "os" "slices" "strings" + "github.com/docker/model-runner/pkg/distribution/huggingface" "github.com/docker/model-runner/pkg/distribution/internal/progress" "github.com/docker/model-runner/pkg/distribution/internal/store" "github.com/docker/model-runner/pkg/distribution/registry" @@ -162,10 +164,11 @@ func (c *Client) normalizeModelName(model string) string { return model } - // Normalize HuggingFace model names (lowercase path) + // Normalize HuggingFace model names if strings.HasPrefix(model, "hf.co/") { // Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect. - model = "huggingface.co" + strings.ToLower(strings.TrimPrefix(model, "hf.co")) + // Note: We preserve case since HuggingFace's native API is case-sensitive + model = "huggingface.co" + strings.TrimPrefix(model, "hf.co") } // Check if model contains a registry (domain with dot before first slash) @@ -267,15 +270,22 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter // Use the client's registry, or create a temporary one if bearer token is provided registryClient := c.registry + var token string if len(bearerToken) > 0 && bearerToken[0] != "" { + token = bearerToken[0] // Create a temporary registry client with bearer token authentication - auth := &authn.Bearer{Token: bearerToken[0]} + auth := &authn.Bearer{Token: token} registryClient = registry.FromClient(c.registry, registry.WithAuth(auth)) } // First, fetch the remote model to get the manifest remoteModel, err := registryClient.Model(ctx, reference) if err != nil { + // Check if this is a HuggingFace reference and the error indicates no OCI manifest + if isHuggingFaceReference(reference) && isNotOCIError(err) { + c.log.Infoln("No OCI manifest found, attempting native HuggingFace pull") + return c.pullNativeHuggingFace(ctx, reference, progressWriter, token) + } return fmt.Errorf("reading model from registry: %w", err) } @@ -637,3 +647,116 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string, return nil } + +// isHuggingFaceReference checks if a reference is a HuggingFace model reference +func isHuggingFaceReference(reference string) bool { + return strings.HasPrefix(reference, "huggingface.co/") +} + +// isNotOCIError checks if the error indicates the model is not OCI-formatted +// This happens when the HuggingFace repository doesn't have an OCI manifest +func isNotOCIError(err error) bool { + if err == nil { + return false + } + + // Check for registry errors indicating no manifest + var regErr *registry.Error + if errors.As(err, ®Err) { + if regErr.Code == "MANIFEST_UNKNOWN" || regErr.Code == "NAME_UNKNOWN" { + return true + } + } + + // Check for invalid reference error (e.g., uppercase letters not allowed in OCI) + // This happens with HuggingFace model names like "Qwen/Qwen3-0.6B" + if errors.Is(err, registry.ErrInvalidReference) { + return true + } + + // Also check error message for common patterns + errStr := err.Error() + return strings.Contains(errStr, "MANIFEST_UNKNOWN") || + strings.Contains(errStr, "NAME_UNKNOWN") || + strings.Contains(errStr, "manifest unknown") || + // HuggingFace returns this error for non-GGUF repositories + strings.Contains(errStr, "Repository is not GGUF") || + strings.Contains(errStr, "not compatible with llama.cpp") +} + +// parseHFReference extracts repo and revision from a normalized HF reference +// e.g., "huggingface.co/org/model:revision" -> ("org/model", "revision") +// e.g., "huggingface.co/org/model:latest" -> ("org/model", "main") +func parseHFReference(reference string) (repo, revision string) { + // Remove registry prefix + ref := strings.TrimPrefix(reference, "huggingface.co/") + + // Split by colon to get tag + parts := strings.SplitN(ref, ":", 2) + repo = parts[0] + + revision = "main" + if len(parts) == 2 && parts[1] != "" && parts[1] != "latest" { + revision = parts[1] + } + + return repo, revision +} + +// pullNativeHuggingFace pulls a native HuggingFace repository (non-OCI format) +// This is used when the model is stored as raw files (safetensors) on HuggingFace Hub +func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, progressWriter io.Writer, token string) error { + repo, revision := parseHFReference(reference) + c.log.Infof("Pulling native HuggingFace model: repo=%s, revision=%s", utils.SanitizeForLog(repo), utils.SanitizeForLog(revision)) + + // Create HuggingFace client + hfOpts := []huggingface.ClientOption{ + huggingface.WithUserAgent(registry.DefaultUserAgent), + } + if token != "" { + hfOpts = append(hfOpts, huggingface.WithToken(token)) + } + hfClient := huggingface.NewClient(hfOpts...) + + // Create temp directory for downloads + tempDir, err := os.MkdirTemp("", "hf-model-*") + if err != nil { + return fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + // Build model from HuggingFace repository + model, err := huggingface.BuildModel(ctx, hfClient, repo, revision, tempDir, progressWriter) + if err != nil { + // Convert HuggingFace errors to registry errors for consistent handling + var authErr *huggingface.AuthError + var notFoundErr *huggingface.NotFoundError + if errors.As(err, &authErr) { + return registry.ErrUnauthorized + } + if errors.As(err, ¬FoundErr) { + return registry.ErrModelNotFound + } + if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil { + c.log.Warnf("Failed to write error message: %v", writeErr) + } + return fmt.Errorf("build model from HuggingFace: %w", err) + } + + // Write model to store + // Lowercase the reference for storage since OCI tags don't allow uppercase + storageTag := strings.ToLower(reference) + c.log.Infof("Writing model to store with tag: %s", utils.SanitizeForLog(storageTag)) + if err := c.store.Write(model, []string{storageTag}, progressWriter); err != nil { + if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil { + c.log.Warnf("Failed to write error message: %v", writeErr) + } + return fmt.Errorf("writing model to store: %w", err) + } + + if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil { + c.log.Warnf("Failed to write success message: %v", err) + } + + return nil +} diff --git a/pkg/distribution/distribution/normalize_test.go b/pkg/distribution/distribution/normalize_test.go index b970bcf41..e644f5cf4 100644 --- a/pkg/distribution/distribution/normalize_test.go +++ b/pkg/distribution/distribution/normalize_test.go @@ -2,12 +2,14 @@ package distribution import ( "context" + "errors" "io" "path/filepath" "strings" "testing" "github.com/docker/model-runner/pkg/distribution/builder" + "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/tarball" "github.com/sirupsen/logrus" ) @@ -66,7 +68,7 @@ func TestNormalizeModelName(t *testing.T) { expected: "registry.example.com/myorg/model:v1", }, - // HuggingFace cases + // HuggingFace cases (case is preserved for native HF API compatibility) { name: "huggingface short form lowercase", input: "hf.co/model", @@ -75,12 +77,12 @@ func TestNormalizeModelName(t *testing.T) { { name: "huggingface short form uppercase", input: "hf.co/Model", - expected: "huggingface.co/model:latest", + expected: "huggingface.co/Model:latest", }, { name: "huggingface short form with org", input: "hf.co/MyOrg/MyModel", - expected: "huggingface.co/myorg/mymodel:latest", + expected: "huggingface.co/MyOrg/MyModel:latest", }, { name: "huggingface with tag", @@ -355,6 +357,114 @@ func createTestClient(t *testing.T) (*Client, func()) { return client, cleanup } +func TestIsHuggingFaceReference(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"huggingface.co prefix", "huggingface.co/org/model:latest", true}, + {"huggingface.co without tag", "huggingface.co/org/model", true}, + {"not huggingface", "registry.example.com/model:latest", false}, + {"docker hub", "ai/gemma3:latest", false}, + {"hf.co prefix (not normalized)", "hf.co/org/model", false}, // This is the un-normalized form + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isHuggingFaceReference(tt.input) + if result != tt.expected { + t.Errorf("isHuggingFaceReference(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestParseHFReference(t *testing.T) { + tests := []struct { + name string + input string + expectedRepo string + expectedRev string + }{ + { + name: "basic with latest tag", + input: "huggingface.co/org/model:latest", + expectedRepo: "org/model", + expectedRev: "main", // latest maps to main + }, + { + name: "with explicit revision", + input: "huggingface.co/org/model:v1.0", + expectedRepo: "org/model", + expectedRev: "v1.0", + }, + { + name: "without tag", + input: "huggingface.co/org/model", + expectedRepo: "org/model", + expectedRev: "main", + }, + { + name: "with commit hash as tag", + input: "huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct:abc123", + expectedRepo: "HuggingFaceTB/SmolLM2-135M-Instruct", + expectedRev: "abc123", + }, + { + name: "single name (no org)", + input: "huggingface.co/model:latest", + expectedRepo: "model", + expectedRev: "main", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, rev := parseHFReference(tt.input) + if repo != tt.expectedRepo { + t.Errorf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, tt.expectedRepo) + } + if rev != tt.expectedRev { + t.Errorf("parseHFReference(%q) rev = %q, want %q", tt.input, rev, tt.expectedRev) + } + }) + } +} + +func TestIsNotOCIError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + {"nil error", nil, false}, + {"generic error", errors.New("some error"), false}, + {"manifest unknown in message", errors.New("MANIFEST_UNKNOWN: manifest not found"), true}, + {"name unknown in message", errors.New("NAME_UNKNOWN: repository not found"), true}, + {"manifest unknown lowercase", errors.New("manifest unknown"), true}, + {"unrelated error", errors.New("network timeout"), false}, + {"HuggingFace not GGUF error", errors.New("Repository is not GGUF or is not compatible with llama.cpp"), true}, + {"HuggingFace llama.cpp incompatible", errors.New("not compatible with llama.cpp"), true}, + // registry.Error typed error cases + {"registry error MANIFEST_UNKNOWN", ®istry.Error{Code: "MANIFEST_UNKNOWN"}, true}, + {"registry error NAME_UNKNOWN", ®istry.Error{Code: "NAME_UNKNOWN"}, true}, + {"registry error other code", ®istry.Error{Code: "UNAUTHORIZED"}, false}, + // ErrInvalidReference case + {"invalid reference error", registry.ErrInvalidReference, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isNotOCIError(tt.err) + if result != tt.expected { + t.Errorf("isNotOCIError(%v) = %v, want %v", tt.err, result, tt.expected) + } + }) + } +} + // Helper function to load a test model and return its ID func loadTestModel(t *testing.T, client *Client, ggufPath string) string { t.Helper() diff --git a/pkg/distribution/huggingface/client.go b/pkg/distribution/huggingface/client.go new file mode 100644 index 000000000..9dc5a64e9 --- /dev/null +++ b/pkg/distribution/huggingface/client.go @@ -0,0 +1,192 @@ +package huggingface + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +const ( + defaultBaseURL = "https://huggingface.co" + defaultUserAgent = "model-distribution" +) + +// Client handles HuggingFace Hub API interactions +type Client struct { + httpClient *http.Client + userAgent string + token string + baseURL string +} + +// ClientOption configures a Client +type ClientOption func(*Client) + +// WithToken sets the HuggingFace API token for authentication +func WithToken(token string) ClientOption { + return func(c *Client) { + if token != "" { + c.token = token + } + } +} + +// WithTransport sets the HTTP transport for the client +func WithTransport(transport http.RoundTripper) ClientOption { + return func(c *Client) { + if transport != nil { + c.httpClient.Transport = transport + } + } +} + +// WithUserAgent sets the User-Agent header for requests +func WithUserAgent(userAgent string) ClientOption { + return func(c *Client) { + if userAgent != "" { + c.userAgent = userAgent + } + } +} + +// WithBaseURL sets a custom base URL (useful for testing) +func WithBaseURL(baseURL string) ClientOption { + return func(c *Client) { + if baseURL != "" { + c.baseURL = strings.TrimSuffix(baseURL, "/") + } + } +} + +// NewClient creates a new HuggingFace Hub API client +func NewClient(opts ...ClientOption) *Client { + c := &Client{ + httpClient: &http.Client{}, + userAgent: defaultUserAgent, + baseURL: defaultBaseURL, + } + for _, opt := range opts { + opt(c) + } + return c +} + +// ListFiles returns all files in a repository at a given revision +func (c *Client) ListFiles(ctx context.Context, repo, revision string) ([]RepoFile, error) { + if revision == "" { + revision = "main" + } + + // HuggingFace API endpoint for listing files + url := fmt.Sprintf("%s/api/models/%s/tree/%s", c.baseURL, repo, revision) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + c.setHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("list files: %w", err) + } + defer resp.Body.Close() + + if err := c.checkResponse(resp, repo); err != nil { + return nil, err + } + + var files []RepoFile + if err := json.NewDecoder(resp.Body).Decode(&files); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + + return files, nil +} + +// DownloadFile streams a file from the repository +// Returns the reader, content length (-1 if unknown), and any error +func (c *Client) DownloadFile(ctx context.Context, repo, revision, filename string) (io.ReadCloser, int64, error) { + if revision == "" { + revision = "main" + } + + // HuggingFace file download endpoint (handles LFS redirects automatically) + url := fmt.Sprintf("%s/%s/resolve/%s/%s", c.baseURL, repo, revision, filename) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, 0, fmt.Errorf("create request: %w", err) + } + + c.setHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("download file: %w", err) + } + + if err := c.checkResponse(resp, repo); err != nil { + resp.Body.Close() + return nil, 0, err + } + + return resp.Body, resp.ContentLength, nil +} + +// setHeaders sets common headers for HuggingFace API requests +func (c *Client) setHeaders(req *http.Request) { + req.Header.Set("User-Agent", c.userAgent) + if c.token != "" { + req.Header.Set("Authorization", "Bearer "+c.token) + } +} + +// checkResponse checks the HTTP response for errors +func (c *Client) checkResponse(resp *http.Response, repo string) error { + switch resp.StatusCode { + case http.StatusOK: + return nil + case http.StatusUnauthorized, http.StatusForbidden: + return &AuthError{Repo: repo, StatusCode: resp.StatusCode} + case http.StatusNotFound: + return &NotFoundError{Repo: repo} + case http.StatusTooManyRequests: + return &RateLimitError{Repo: repo} + default: + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body)) + } +} + +// AuthError indicates authentication failure +type AuthError struct { + Repo string + StatusCode int +} + +func (e *AuthError) Error() string { + return fmt.Sprintf("authentication required for repository %q (status %d)", e.Repo, e.StatusCode) +} + +// NotFoundError indicates the repository or file was not found +type NotFoundError struct { + Repo string +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("repository %q not found", e.Repo) +} + +// RateLimitError indicates rate limiting +type RateLimitError struct { + Repo string +} + +func (e *RateLimitError) Error() string { + return fmt.Sprintf("rate limited while accessing repository %q", e.Repo) +} diff --git a/pkg/distribution/huggingface/client_test.go b/pkg/distribution/huggingface/client_test.go new file mode 100644 index 000000000..e33dff7a0 --- /dev/null +++ b/pkg/distribution/huggingface/client_test.go @@ -0,0 +1,157 @@ +package huggingface + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestClientListFiles(t *testing.T) { + // Mock HuggingFace API response + mockFiles := []RepoFile{ + {Type: "file", Path: "model.safetensors", Size: 1000}, + {Type: "file", Path: "config.json", Size: 100}, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/models/test-org/test-model/tree/main" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockFiles) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + + files, err := client.ListFiles(context.Background(), "test-org/test-model", "main") + if err != nil { + t.Fatalf("ListFiles failed: %v", err) + } + + if len(files) != 2 { + t.Errorf("Expected 2 files, got %d", len(files)) + } +} + +func TestClientListFilesDefaultRevision(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the default revision is "main" + if !strings.Contains(r.URL.Path, "/tree/main") { + t.Errorf("Expected /tree/main in path, got %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode([]RepoFile{}) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + _, err := client.ListFiles(context.Background(), "test/model", "") + if err != nil { + t.Fatalf("ListFiles failed: %v", err) + } +} + +func TestClientDownloadFile(t *testing.T) { + expectedContent := "test file content" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test-org/test-model/resolve/main/test.txt" { + w.Header().Set("Content-Length", "17") + w.Write([]byte(expectedContent)) + return + } + http.NotFound(w, r) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + + reader, size, err := client.DownloadFile(context.Background(), "test-org/test-model", "main", "test.txt") + if err != nil { + t.Fatalf("DownloadFile failed: %v", err) + } + defer reader.Close() + + if size != 17 { + t.Errorf("Expected size 17, got %d", size) + } + + content, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("ReadAll failed: %v", err) + } + + if string(content) != expectedContent { + t.Errorf("Expected content %q, got %q", expectedContent, string(content)) + } +} + +func TestClientAuthError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + + _, err := client.ListFiles(context.Background(), "private/model", "main") + if err == nil { + t.Fatal("Expected error, got nil") + } + + var authErr *AuthError + if !errors.As(err, &authErr) { + t.Errorf("Expected AuthError, got %T", err) + } +} + +func TestClientNotFoundError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + client := NewClient(WithBaseURL(server.URL)) + + _, err := client.ListFiles(context.Background(), "nonexistent/model", "main") + if err == nil { + t.Fatal("Expected error, got nil") + } + + var notFoundErr *NotFoundError + if !errors.As(err, ¬FoundErr) { + t.Errorf("Expected NotFoundError, got %T", err) + } +} + +func TestClientWithToken(t *testing.T) { + var receivedToken string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedToken = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode([]RepoFile{}) + })) + defer server.Close() + + client := NewClient( + WithBaseURL(server.URL), + WithToken("test-token"), + ) + + _, err := client.ListFiles(context.Background(), "test/model", "main") + if err != nil { + t.Fatalf("ListFiles failed: %v", err) + } + + if receivedToken != "Bearer test-token" { + t.Errorf("Expected 'Bearer test-token', got %q", receivedToken) + } +} diff --git a/pkg/distribution/huggingface/downloader.go b/pkg/distribution/huggingface/downloader.go new file mode 100644 index 000000000..0bb537319 --- /dev/null +++ b/pkg/distribution/huggingface/downloader.go @@ -0,0 +1,221 @@ +package huggingface + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "os" + "path/filepath" + "sync" + + "github.com/docker/model-runner/pkg/distribution/internal/progress" +) + +// Downloader manages file downloads from HuggingFace repositories +type Downloader struct { + client *Client + repo string + revision string + tempDir string +} + +// NewDownloader creates a new downloader for a HuggingFace repository +func NewDownloader(client *Client, repo, revision, tempDir string) *Downloader { + if revision == "" { + revision = "main" + } + return &Downloader{ + client: client, + repo: repo, + revision: revision, + tempDir: tempDir, + } +} + +// DownloadResult contains the result of downloading files +type DownloadResult struct { + // LocalPaths maps original repo paths to local file paths + LocalPaths map[string]string + // TotalBytes is the total number of bytes downloaded + TotalBytes int64 +} + +// syncWriter wraps an io.Writer with a mutex for thread-safe concurrent writes +type syncWriter struct { + mu sync.Mutex + w io.Writer +} + +func (sw *syncWriter) Write(p []byte) (n int, err error) { + sw.mu.Lock() + defer sw.mu.Unlock() + return sw.w.Write(p) +} + +// fileIDFromPath generates a unique ID for a file based on its path +// Returns a sha256: prefixed hash to match layer ID format +func fileIDFromPath(path string) string { + hash := sha256.Sum256([]byte(path)) + return "sha256:" + hex.EncodeToString(hash[:]) +} + +// DownloadAll downloads all specified files with progress reporting +// Files are downloaded in parallel with per-file progress updates written to progressWriter +func (d *Downloader) DownloadAll(ctx context.Context, files []RepoFile, progressWriter io.Writer) (*DownloadResult, error) { + if len(files) == 0 { + return &DownloadResult{LocalPaths: make(map[string]string)}, nil + } + + totalSize := TotalSize(files) + + // Create result map (thread-safe access) + var mu sync.Mutex + localPaths := make(map[string]string, len(files)) + + // Create thread-safe writer for concurrent progress reporting + var safeWriter io.Writer + if progressWriter != nil { + safeWriter = &syncWriter{w: progressWriter} + } + + // Download files in parallel (limit concurrency to avoid overwhelming) + const maxConcurrent = 4 + sem := make(chan struct{}, maxConcurrent) + var wg sync.WaitGroup + errChan := make(chan error, len(files)) + + for _, file := range files { + wg.Add(1) + go func(f RepoFile) { + defer wg.Done() + + // Acquire semaphore + select { + case sem <- struct{}{}: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + defer func() { <-sem }() + + localPath, err := d.downloadFileWithProgress(ctx, f, uint64(totalSize), safeWriter) + if err != nil { + errChan <- fmt.Errorf("download %s: %w", f.Path, err) + return + } + + mu.Lock() + localPaths[f.Path] = localPath + mu.Unlock() + }(file) + } + + // Wait for all downloads to complete + wg.Wait() + close(errChan) + + // Collect any errors + var errs []error + for err := range errChan { + if err != nil { + errs = append(errs, err) + } + } + + if len(errs) > 0 { + return nil, fmt.Errorf("download errors: %v", errs) + } + + // Calculate total downloaded + var totalDownloaded int64 + for _, f := range files { + totalDownloaded += f.ActualSize() + } + + return &DownloadResult{ + LocalPaths: localPaths, + TotalBytes: totalDownloaded, + }, nil +} + +// downloadFileWithProgress downloads a single file with progress reporting +func (d *Downloader) downloadFileWithProgress(ctx context.Context, file RepoFile, totalImageSize uint64, progressWriter io.Writer) (string, error) { + // Create local file path (preserve directory structure) + localPath := filepath.Join(d.tempDir, file.Path) + + // Ensure parent directory exists + if err := os.MkdirAll(filepath.Dir(localPath), 0o755); err != nil { + return "", fmt.Errorf("create directory: %w", err) + } + + // Download from HuggingFace + reader, _, err := d.client.DownloadFile(ctx, d.repo, d.revision, file.Path) + if err != nil { + return "", err + } + defer reader.Close() + + // Create local file + f, err := os.Create(localPath) + if err != nil { + return "", fmt.Errorf("create file: %w", err) + } + defer f.Close() + + // Generate unique ID for this file (for progress tracking) + fileID := fileIDFromPath(file.Path) + fileSize := uint64(file.ActualSize()) + + // Copy with progress tracking + pr := &progressReader{ + reader: reader, + progressWriter: progressWriter, + totalImageSize: totalImageSize, + fileSize: fileSize, + fileID: fileID, + } + + if _, err := io.Copy(f, pr); err != nil { + os.Remove(localPath) // Clean up on error + return "", fmt.Errorf("write file: %w", err) + } + + // Write final progress for this file (100% complete) + if progressWriter != nil { + _ = progress.WriteProgress(progressWriter, "", totalImageSize, fileSize, fileSize, fileID) + } + + return localPath, nil +} + +// progressReader wraps a reader and reports per-file progress +type progressReader struct { + reader io.Reader + progressWriter io.Writer + totalImageSize uint64 + fileSize uint64 + fileID string + bytesRead uint64 + lastReported uint64 +} + +func (pr *progressReader) Read(p []byte) (n int, err error) { + n, err = pr.reader.Read(p) + if n > 0 { + pr.bytesRead += uint64(n) + + // Report progress periodically (every 1MB or when complete) + if pr.progressWriter != nil && (pr.bytesRead-pr.lastReported >= progress.MinBytesForUpdate || pr.bytesRead == pr.fileSize) { + _ = progress.WriteProgress(pr.progressWriter, "", pr.totalImageSize, pr.fileSize, pr.bytesRead, pr.fileID) + pr.lastReported = pr.bytesRead + } + } + return n, err +} + +// DownloadSingleFile downloads a single file and returns its local path +func (d *Downloader) DownloadSingleFile(ctx context.Context, file RepoFile) (string, error) { + return d.downloadFileWithProgress(ctx, file, uint64(file.ActualSize()), nil) +} diff --git a/pkg/distribution/huggingface/model.go b/pkg/distribution/huggingface/model.go new file mode 100644 index 000000000..fd072b836 --- /dev/null +++ b/pkg/distribution/huggingface/model.go @@ -0,0 +1,159 @@ +package huggingface + +import ( + "context" + "fmt" + "io" + "log" + "path/filepath" + "sort" + "strings" + + "github.com/docker/model-runner/pkg/distribution/builder" + "github.com/docker/model-runner/pkg/distribution/internal/progress" + "github.com/docker/model-runner/pkg/distribution/packaging" + "github.com/docker/model-runner/pkg/distribution/types" +) + +// BuildModel downloads files from a HuggingFace repository and constructs an OCI model artifact +// This is the main entry point for pulling native HuggingFace models +func BuildModel(ctx context.Context, client *Client, repo, revision string, tempDir string, progressWriter io.Writer) (types.ModelArtifact, error) { + // Step 1: List files in the repository + if progressWriter != nil { + _ = progress.WriteProgress(progressWriter, "Fetching file list...", 0, 0, 0, "") + } + + files, err := client.ListFiles(ctx, repo, revision) + if err != nil { + return nil, fmt.Errorf("list files: %w", err) + } + + // Step 2: Filter to model files (safetensors + configs) + safetensorsFiles, configFiles := FilterModelFiles(files) + + if len(safetensorsFiles) == 0 { + return nil, fmt.Errorf("no safetensors files found in repository %s", repo) + } + + // Combine all files to download + allFiles := append(safetensorsFiles, configFiles...) + + if progressWriter != nil { + totalSize := TotalSize(allFiles) + msg := fmt.Sprintf("Found %d files (%.2f MB total)", + len(allFiles), float64(totalSize)/1024/1024) + _ = progress.WriteProgress(progressWriter, msg, uint64(totalSize), 0, 0, "") + } + + // Step 3: Download all files + downloader := NewDownloader(client, repo, revision, tempDir) + result, err := downloader.DownloadAll(ctx, allFiles, progressWriter) + if err != nil { + return nil, fmt.Errorf("download files: %w", err) + } + + // Step 4: Build the model artifact + if progressWriter != nil { + _ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "") + } + + model, err := buildModelFromFiles(result.LocalPaths, safetensorsFiles, configFiles, tempDir) + if err != nil { + return nil, fmt.Errorf("build model: %w", err) + } + + return model, nil +} + +// buildModelFromFiles constructs an OCI model artifact from downloaded files +func buildModelFromFiles(localPaths map[string]string, safetensorsFiles, configFiles []RepoFile, tempDir string) (types.ModelArtifact, error) { + // Collect safetensors paths (sorted for reproducibility) + var safetensorsPaths []string + for _, f := range safetensorsFiles { + localPath, ok := localPaths[f.Path] + if !ok { + return nil, fmt.Errorf("missing local path for %s", f.Path) + } + safetensorsPaths = append(safetensorsPaths, localPath) + } + sort.Strings(safetensorsPaths) + + // Create builder from safetensors files + b, err := builder.FromSafetensors(safetensorsPaths) + if err != nil { + return nil, fmt.Errorf("create builder: %w", err) + } + + // Create config archive if we have config files + if len(configFiles) > 0 { + configArchive, err := createConfigArchive(localPaths, configFiles, tempDir) + if err != nil { + return nil, fmt.Errorf("create config archive: %w", err) + } + // Note: configArchive is cleaned up by the caller's tempDir cleanup + + if configArchive != "" { + b, err = b.WithConfigArchive(configArchive) + if err != nil { + return nil, fmt.Errorf("add config archive: %w", err) + } + } + } + + // Check for chat template and add it + for _, f := range configFiles { + if isChatTemplate(f.Path) { + localPath := localPaths[f.Path] + b, err = b.WithChatTemplateFile(localPath) + if err != nil { + // Non-fatal: log warning but continue to try other potential templates + log.Printf("Warning: failed to add chat template from %s: %v", f.Path, err) + continue + } + break // Only add one chat template + } + } + + return b.Model(), nil +} + +// createConfigArchive creates a tar archive of config files in the specified tempDir +func createConfigArchive(localPaths map[string]string, configFiles []RepoFile, tempDir string) (string, error) { + // Collect config file paths (excluding chat templates which are added separately) + var configPaths []string + for _, f := range configFiles { + if isChatTemplate(f.Path) { + continue // Chat templates are added as separate layers + } + localPath, ok := localPaths[f.Path] + if !ok { + return "", fmt.Errorf("internal error: missing local path for downloaded config file %s", f.Path) + } + configPaths = append(configPaths, localPath) + } + + if len(configPaths) == 0 { + // No config files to archive + return "", nil + } + + // Sort for reproducibility + sort.Strings(configPaths) + + // Create the archive in our tempDir so it gets cleaned up with everything else + archivePath, err := packaging.CreateConfigArchiveInDir(configPaths, tempDir) + if err != nil { + return "", fmt.Errorf("create config archive: %w", err) + } + + return archivePath, nil +} + +// isChatTemplate checks if a file is a chat template +func isChatTemplate(path string) bool { + filename := filepath.Base(path) + lower := strings.ToLower(filename) + return strings.HasSuffix(lower, ".jinja") || + strings.Contains(lower, "chat_template") || + filename == "tokenizer_config.json" // Often contains chat_template +} diff --git a/pkg/distribution/huggingface/repository.go b/pkg/distribution/huggingface/repository.go new file mode 100644 index 000000000..deb41c1b3 --- /dev/null +++ b/pkg/distribution/huggingface/repository.go @@ -0,0 +1,119 @@ +package huggingface + +import ( + "path" + "strings" +) + +// RepoFile represents a file in a HuggingFace repository +type RepoFile struct { + Type string `json:"type"` // "file" or "directory" + Path string `json:"path"` // Relative path in repo + Size int64 `json:"size"` // File size in bytes (0 for directories) + OID string `json:"oid"` // Git blob ID + LFS *LFSInfo `json:"lfs"` // Present if LFS file +} + +// LFSInfo contains LFS-specific file information +type LFSInfo struct { + OID string `json:"oid"` // LFS object ID (sha256) + Size int64 `json:"size"` // Actual file size + PointerSize int64 `json:"pointer_size"` // Size of pointer file +} + +// ActualSize returns the actual file size, accounting for LFS +func (f *RepoFile) ActualSize() int64 { + if f.LFS != nil { + return f.LFS.Size + } + return f.Size +} + +// Filename returns the base filename without directory path +func (f *RepoFile) Filename() string { + return path.Base(f.Path) +} + +// configExtensions defines file extensions treated as config files +// This matches the existing packaging/safetensors.go logic +var configExtensions = []string{".md", ".txt", ".json", ".vocab", ".jinja"} + +// specialConfigFiles are specific filenames treated as config files +var specialConfigFiles = []string{"tokenizer.model"} + +// FileType represents the type of file for model packaging +type FileType int + +const ( + // FileTypeUnknown is an unrecognized file type + FileTypeUnknown FileType = iota + // FileTypeSafetensors is a safetensors model weight file + FileTypeSafetensors + // FileTypeConfig is a configuration file (json, txt, etc.) + FileTypeConfig +) + +// ClassifyFile determines the file type based on filename +func ClassifyFile(filename string) FileType { + lower := strings.ToLower(filename) + + // Check for safetensors files + if strings.HasSuffix(lower, ".safetensors") { + return FileTypeSafetensors + } + + // Check for config file extensions + for _, ext := range configExtensions { + if strings.HasSuffix(lower, ext) { + return FileTypeConfig + } + } + + // Check for special config files + for _, special := range specialConfigFiles { + if strings.EqualFold(filename, special) { + return FileTypeConfig + } + } + + return FileTypeUnknown +} + +// FilterModelFiles filters repository files to only include files needed for model-runner +// Returns safetensors files and config files separately +func FilterModelFiles(files []RepoFile) (safetensors []RepoFile, configs []RepoFile) { + for _, f := range files { + if f.Type != "file" { + continue + } + + switch ClassifyFile(f.Filename()) { + case FileTypeSafetensors: + safetensors = append(safetensors, f) + case FileTypeConfig: + configs = append(configs, f) + case FileTypeUnknown: + // Skip unknown file types + } + } + return safetensors, configs +} + +// TotalSize calculates the total size of files +func TotalSize(files []RepoFile) int64 { + var total int64 + for _, f := range files { + total += f.ActualSize() + } + return total +} + +// IsSafetensorsModel checks if the files contain at least one safetensors file +func IsSafetensorsModel(files []RepoFile) bool { + for _, f := range files { + if f.Type == "file" && ClassifyFile(f.Filename()) == FileTypeSafetensors { + return true + } + } + return false +} diff --git a/pkg/distribution/huggingface/repository_test.go b/pkg/distribution/huggingface/repository_test.go new file mode 100644 index 000000000..63918b710 --- /dev/null +++ b/pkg/distribution/huggingface/repository_test.go @@ -0,0 +1,138 @@ +package huggingface + +import ( + "testing" +) + +func TestClassifyFile(t *testing.T) { + tests := []struct { + name string + filename string + want FileType + }{ + {"safetensors file", "model.safetensors", FileTypeSafetensors}, + {"safetensors uppercase", "model.SAFETENSORS", FileTypeSafetensors}, + {"safetensors mixed case", "Model.SafeTensors", FileTypeSafetensors}, + {"sharded safetensors", "model-00001-of-00003.safetensors", FileTypeSafetensors}, + + {"json config", "config.json", FileTypeConfig}, + {"tokenizer json", "tokenizer.json", FileTypeConfig}, + {"tokenizer config", "tokenizer_config.json", FileTypeConfig}, + {"txt file", "README.txt", FileTypeConfig}, + {"markdown file", "README.md", FileTypeConfig}, + {"vocab file", "vocab.vocab", FileTypeConfig}, + {"jinja template", "chat_template.jinja", FileTypeConfig}, + {"tokenizer model", "tokenizer.model", FileTypeConfig}, + + {"unknown extension", "model.bin", FileTypeUnknown}, + {"python file", "model.py", FileTypeUnknown}, + {"pytorch model", "pytorch_model.bin", FileTypeUnknown}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ClassifyFile(tt.filename); got != tt.want { + t.Errorf("ClassifyFile(%q) = %v, want %v", tt.filename, got, tt.want) + } + }) + } +} + +func TestFilterModelFiles(t *testing.T) { + files := []RepoFile{ + {Type: "file", Path: "model.safetensors", Size: 1000}, + {Type: "file", Path: "config.json", Size: 100}, + {Type: "file", Path: "tokenizer.json", Size: 200}, + {Type: "file", Path: "README.md", Size: 50}, + {Type: "file", Path: "model.py", Size: 500}, + {Type: "directory", Path: "subdir", Size: 0}, + {Type: "file", Path: "model-00001-of-00002.safetensors", Size: 2000}, + {Type: "file", Path: "model-00002-of-00002.safetensors", Size: 2000}, + } + + safetensors, configs := FilterModelFiles(files) + + if len(safetensors) != 3 { + t.Errorf("Expected 3 safetensors files, got %d", len(safetensors)) + } + if len(configs) != 3 { + t.Errorf("Expected 3 config files, got %d", len(configs)) + } +} + +func TestTotalSize(t *testing.T) { + files := []RepoFile{ + {Type: "file", Path: "a.safetensors", Size: 1000}, + {Type: "file", Path: "b.safetensors", Size: 2000, LFS: &LFSInfo{Size: 5000}}, + } + + total := TotalSize(files) + if total != 6000 { // 1000 + 5000 (LFS size takes precedence) + t.Errorf("TotalSize() = %d, want 6000", total) + } +} + +func TestRepoFileActualSize(t *testing.T) { + tests := []struct { + name string + file RepoFile + want int64 + }{ + { + name: "regular file", + file: RepoFile{Size: 1000}, + want: 1000, + }, + { + name: "LFS file", + file: RepoFile{Size: 100, LFS: &LFSInfo{Size: 5000}}, + want: 5000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.file.ActualSize(); got != tt.want { + t.Errorf("ActualSize() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestIsSafetensorsModel(t *testing.T) { + tests := []struct { + name string + files []RepoFile + want bool + }{ + { + name: "has safetensors", + files: []RepoFile{ + {Type: "file", Path: "model.safetensors"}, + {Type: "file", Path: "config.json"}, + }, + want: true, + }, + { + name: "no safetensors", + files: []RepoFile{ + {Type: "file", Path: "config.json"}, + {Type: "file", Path: "README.md"}, + }, + want: false, + }, + { + name: "empty", + files: []RepoFile{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsSafetensorsModel(tt.files); got != tt.want { + t.Errorf("IsSafetensorsModel() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/distribution/packaging/safetensors.go b/pkg/distribution/packaging/safetensors.go index 74a4694c5..0b2159fcd 100644 --- a/pkg/distribution/packaging/safetensors.go +++ b/pkg/distribution/packaging/safetensors.go @@ -71,8 +71,16 @@ func PackageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfig // It returns the path to the temporary tar file and any error encountered. // The caller is responsible for removing the temporary file when done. func CreateTempConfigArchive(configFiles []string) (string, error) { - // Create temp file - tmpFile, err := os.CreateTemp("", "vllm-config-*.tar") + return CreateConfigArchiveInDir(configFiles, "") +} + +// CreateConfigArchiveInDir creates a tar archive containing the specified config files in the given directory. +// If dir is empty, the system temp directory is used. +// It returns the path to the tar file and any error encountered. +// The caller is responsible for removing the file when done. +func CreateConfigArchiveInDir(configFiles []string, dir string) (string, error) { + // Create temp file in specified directory + tmpFile, err := os.CreateTemp(dir, "vllm-config-*.tar") if err != nil { return "", fmt.Errorf("create temp file: %w", err) }