From 48b346bc7a1b118ba5bb94e3c906ba50d8e11eac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Wed, 23 Apr 2025 18:21:42 +0200 Subject: [PATCH 01/17] feat: extract from docker/model-cli --- dockermodelrunner/desktop/api.go | 68 +++ dockermodelrunner/desktop/desktop.go | 487 ++++++++++++++++++ dockermodelrunner/go.mod | 5 + dockermodelrunner/go.sum | 2 + dockermodelrunner/inference/api.go | 12 + .../modeldistribution/types/types.go | 65 +++ dockermodelrunner/models/api.go | 62 +++ 7 files changed, 701 insertions(+) create mode 100644 dockermodelrunner/desktop/api.go create mode 100644 dockermodelrunner/desktop/desktop.go create mode 100644 dockermodelrunner/go.mod create mode 100644 dockermodelrunner/go.sum create mode 100644 dockermodelrunner/inference/api.go create mode 100644 dockermodelrunner/modeldistribution/types/types.go create mode 100644 dockermodelrunner/models/api.go diff --git a/dockermodelrunner/desktop/api.go b/dockermodelrunner/desktop/api.go new file mode 100644 index 000000000..27fa9cb45 --- /dev/null +++ b/dockermodelrunner/desktop/api.go @@ -0,0 +1,68 @@ +package desktop + +// ProgressMessage represents a message sent during model pull operations +type ProgressMessage struct { + Type string `json:"type"` // "progress", "success", or "error" + Message string `json:"message"` // Human-readable message +} + +type OpenAIChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type OpenAIChatRequest struct { + Model string `json:"model"` + Messages []OpenAIChatMessage `json:"messages"` + Stream bool `json:"stream"` +} + +type OpenAIChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Delta struct { + Content string `json:"content"` + Role string `json:"role,omitempty"` + } `json:"delta"` + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} 
+ +type OpenAIModel struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` +} + +type OpenAIModelList struct { + Object string `json:"object"` + Data []*OpenAIModel `json:"data"` +} + +// TODO: To be replaced by the Model struct from pianta's common/pkg/inference/models/api.go. +// (https://github.com/docker/pinata/pull/33331) +type Format string + +type Config struct { + Format Format `json:"format,omitempty"` + Quantization string `json:"quantization,omitempty"` + Parameters string `json:"parameters,omitempty"` + Architecture string `json:"architecture,omitempty"` + Size string `json:"size,omitempty"` +} + +type Model struct { + // ID is the globally unique model identifier. + ID string `json:"id"` + // Tags are the list of tags associated with the model. + Tags []string `json:"tags"` + // Created is the Unix epoch timestamp corresponding to the model creation. + Created int64 `json:"created"` + // Config describes the model. 
+ Config Config `json:"config"` +} diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go new file mode 100644 index 000000000..dd5e4e55c --- /dev/null +++ b/dockermodelrunner/desktop/desktop.go @@ -0,0 +1,487 @@ +package desktop + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "html" + "io" + "net/http" + "strconv" + "strings" + + "github.com/pkg/errors" + + "github.com/docker/docker-sdk-go/dockermodelrunner/inference" + "github.com/docker/docker-sdk-go/dockermodelrunner/models" +) + +var ( + ErrNotFound = errors.New("model not found") + ErrServiceUnavailable = errors.New("service unavailable") +) + +type Client struct { + dockerClient DockerHttpClient +} + +//go:generate mockgen -source=desktop.go -destination=../mocks/mock_desktop.go -package=mocks DockerHttpClient +type DockerHttpClient interface { + Do(req *http.Request) (*http.Response, error) +} + +func New(dockerClient DockerHttpClient) *Client { + return &Client{dockerClient} +} + +type Status struct { + Running bool `json:"running"` + Status []byte `json:"status"` + Error error `json:"error"` +} + +func (c *Client) Status() Status { + // TODO: Query "/". 
+ resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil) + if err != nil { + err = c.handleQueryError(err, inference.ModelsPrefix) + if errors.Is(err, ErrServiceUnavailable) { + return Status{ + Running: false, + } + } + return Status{ + Running: false, + Error: err, + } + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + var status []byte + statusResp, err := c.doRequest(http.MethodGet, inference.InferencePrefix+"/status", nil) + if err != nil { + status = []byte(fmt.Sprintf("error querying status: %v", err)) + } else { + defer statusResp.Body.Close() + statusBody, err := io.ReadAll(statusResp.Body) + if err != nil { + status = []byte(fmt.Sprintf("error reading status body: %v", err)) + } else { + status = statusBody + } + } + return Status{ + Running: true, + Status: status, + } + } + return Status{ + Running: false, + Error: fmt.Errorf("unexpected status code: %d", resp.StatusCode), + } +} + +func (c *Client) Pull(model string, progress func(string)) (string, bool, error) { + jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) + if err != nil { + return "", false, fmt.Errorf("error marshaling request: %w", err) + } + + createPath := inference.ModelsPrefix + "/create" + resp, err := c.doRequest( + http.MethodPost, + createPath, + bytes.NewReader(jsonData), + ) + if err != nil { + return "", false, c.handleQueryError(err, createPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body)) + } + + progressShown := false + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + progressLine := scanner.Text() + if progressLine == "" { + continue + } + + // Parse the progress message + var progressMsg ProgressMessage + if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { + return "", progressShown, 
fmt.Errorf("error parsing progress message: %w", err) + } + + // Handle different message types + switch progressMsg.Type { + case "progress": + progress(progressMsg.Message) + progressShown = true + case "error": + return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message) + case "success": + return progressMsg.Message, progressShown, nil + default: + return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type) + } + } + + // If we get here, something went wrong + return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model) +} + +func (c *Client) Push(model string, progress func(string)) (string, bool, error) { + pushPath := inference.ModelsPrefix + "/" + model + "/push" + resp, err := c.doRequest( + http.MethodPost, + pushPath, + nil, // Assuming no body is needed for the push request + ) + if err != nil { + return "", false, c.handleQueryError(err, pushPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", false, fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body)) + } + + progressShown := false + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + progressLine := scanner.Text() + if progressLine == "" { + continue + } + + // Parse the progress message + var progressMsg ProgressMessage + if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { + return "", progressShown, fmt.Errorf("error parsing progress message: %w", err) + } + + // Handle different message types + switch progressMsg.Type { + case "progress": + progress(progressMsg.Message) + progressShown = true + case "error": + return "", progressShown, fmt.Errorf("error pushing model: %s", progressMsg.Message) + case "success": + return progressMsg.Message, progressShown, nil + default: + return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type) + 
} + } + + // If we get here, something went wrong + return "", progressShown, fmt.Errorf("unexpected end of stream while pushing model %s", model) +} + +func (c *Client) List() ([]Model, error) { + modelsRoute := inference.ModelsPrefix + body, err := c.listRaw(modelsRoute, "") + if err != nil { + return []Model{}, err + } + + var modelsJson []Model + if err := json.Unmarshal(body, &modelsJson); err != nil { + return modelsJson, fmt.Errorf("failed to unmarshal response body: %w", err) + } + + return modelsJson, nil +} + +func (c *Client) ListOpenAI() (OpenAIModelList, error) { + modelsRoute := inference.InferencePrefix + "/v1/models" + rawResponse, err := c.listRaw(modelsRoute, "") + if err != nil { + return OpenAIModelList{}, err + } + var modelsJson OpenAIModelList + if err := json.Unmarshal(rawResponse, &modelsJson); err != nil { + return modelsJson, fmt.Errorf("failed to unmarshal response body: %w", err) + } + return modelsJson, nil +} + +func (c *Client) Inspect(model string) (Model, error) { + if model != "" { + if !strings.Contains(strings.Trim(model, "/"), "/") { + // Do an extra API call to check if the model parameter isn't a model ID. + modelId, err := c.fullModelID(model) + if err != nil { + return Model{}, fmt.Errorf("invalid model name: %s", model) + } + model = modelId + } + } + rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", inference.ModelsPrefix, model), model) + if err != nil { + return Model{}, err + } + var modelInspect Model + if err := json.Unmarshal(rawResponse, &modelInspect); err != nil { + return modelInspect, fmt.Errorf("failed to unmarshal response body: %w", err) + } + + return modelInspect, nil +} + +func (c *Client) InspectOpenAI(model string) (OpenAIModel, error) { + modelsRoute := inference.InferencePrefix + "/v1/models" + if !strings.Contains(strings.Trim(model, "/"), "/") { + // Do an extra API call to check if the model parameter isn't a model ID. 
+ var err error + if model, err = c.fullModelID(model); err != nil { + return OpenAIModel{}, fmt.Errorf("invalid model name: %s", model) + } + } + rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", modelsRoute, model), model) + if err != nil { + return OpenAIModel{}, err + } + var modelInspect OpenAIModel + if err := json.Unmarshal(rawResponse, &modelInspect); err != nil { + return modelInspect, fmt.Errorf("failed to unmarshal response body: %w", err) + } + return modelInspect, nil +} + +func (c *Client) listRaw(route string, model string) ([]byte, error) { + resp, err := c.doRequest(http.MethodGet, route, nil) + if err != nil { + return nil, c.handleQueryError(err, route) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if model != "" && resp.StatusCode == http.StatusNotFound { + return nil, errors.Wrap(ErrNotFound, model) + } + return nil, fmt.Errorf("failed to list models: %s", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return body, nil + +} + +func (c *Client) fullModelID(id string) (string, error) { + bodyResponse, err := c.listRaw(inference.ModelsPrefix, "") + if err != nil { + return "", err + } + + var modelsJson []Model + if err := json.Unmarshal(bodyResponse, &modelsJson); err != nil { + return "", fmt.Errorf("failed to unmarshal response body: %w", err) + } + + for _, m := range modelsJson { + if m.ID[7:19] == id || strings.TrimPrefix(m.ID, "sha256:") == id || m.ID == id { + return m.ID, nil + } + } + + return "", fmt.Errorf("model with ID %s not found", id) +} + +func (c *Client) Chat(model, prompt string) error { + if !strings.Contains(strings.Trim(model, "/"), "/") { + // Do an extra API call to check if the model parameter isn't a model ID. 
+ if expanded, err := c.fullModelID(model); err == nil { + model = expanded + } + } + + reqBody := OpenAIChatRequest{ + Model: model, + Messages: []OpenAIChatMessage{ + { + Role: "user", + Content: prompt, + }, + }, + Stream: true, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("error marshaling request: %w", err) + } + + chatCompletionsPath := inference.InferencePrefix + "/v1/chat/completions" + resp, err := c.doRequest( + http.MethodPost, + chatCompletionsPath, + bytes.NewReader(jsonData), + ) + if err != nil { + return c.handleQueryError(err, chatCompletionsPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("error response: status=%d body=%s", resp.StatusCode, body) + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + + if data == "[DONE]" { + break + } + + var streamResp OpenAIChatResponse + if err := json.Unmarshal([]byte(data), &streamResp); err != nil { + return fmt.Errorf("error parsing stream response: %w", err) + } + + if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta.Content != "" { + chunk := streamResp.Choices[0].Delta.Content + fmt.Print(chunk) + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading response stream: %w", err) + } + + return nil +} + +func (c *Client) Remove(models []string, force bool) (string, error) { + modelRemoved := "" + for _, model := range models { + // Check if not a model ID passed as parameter. 
+ if !strings.Contains(model, "/") { + if expanded, err := c.fullModelID(model); err == nil { + model = expanded + } + } + + // Construct the URL with query parameters + removePath := fmt.Sprintf("%s/%s?force=%s", + inference.ModelsPrefix, + model, + strconv.FormatBool(force), + ) + + resp, err := c.doRequest(http.MethodDelete, removePath, nil) + if err != nil { + return modelRemoved, c.handleQueryError(err, removePath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusNotFound { + return modelRemoved, fmt.Errorf("no such model: %s", model) + } + var bodyStr string + body, err := io.ReadAll(resp.Body) + if err != nil { + bodyStr = fmt.Sprintf("(failed to read response body: %v)", err) + } else { + bodyStr = string(body) + } + return modelRemoved, fmt.Errorf("removing %s failed with status %s: %s", model, resp.Status, bodyStr) + } + modelRemoved += fmt.Sprintf("Model %s removed successfully\n", model) + } + return modelRemoved, nil +} + +func URL(path string) string { + return fmt.Sprintf("http://localhost" + inference.ExperimentalEndpointsPrefix + path) +} + +// doRequest is a helper function that performs HTTP requests and handles 503 responses +func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest(method, URL(path), body) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := c.dockerClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusServiceUnavailable { + resp.Body.Close() + return nil, ErrServiceUnavailable + } + + return resp, nil +} + +func (c *Client) handleQueryError(err error, path string) error { + if errors.Is(err, ErrServiceUnavailable) { + return ErrServiceUnavailable + } + return fmt.Errorf("error querying %s: %w", path, err) +} + +func (c *Client) Tag(source, targetRepo, 
targetTag string) (string, error) { + // Check if the source is a model ID, and expand it if necessary + if !strings.Contains(strings.Trim(source, "/"), "/") { + // Do an extra API call to check if the model parameter might be a model ID + if expanded, err := c.fullModelID(source); err == nil { + source = expanded + } + } + + // Construct the URL with query parameters + tagPath := fmt.Sprintf("%s/%s/tag?repo=%s&tag=%s", + inference.ModelsPrefix, + source, + targetRepo, + targetTag, + ) + + resp, err := c.doRequest(http.MethodPost, tagPath, nil) + if err != nil { + return "", c.handleQueryError(err, tagPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("tagging failed with status %s: %s", resp.Status, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } + + return string(body), nil +} diff --git a/dockermodelrunner/go.mod b/dockermodelrunner/go.mod new file mode 100644 index 000000000..b0ad30508 --- /dev/null +++ b/dockermodelrunner/go.mod @@ -0,0 +1,5 @@ +module github.com/docker/docker-sdk-go/dockermodelrunner + +go 1.23.6 + +require github.com/pkg/errors v0.9.1 diff --git a/dockermodelrunner/go.sum b/dockermodelrunner/go.sum new file mode 100644 index 000000000..7c401c3f5 --- /dev/null +++ b/dockermodelrunner/go.sum @@ -0,0 +1,2 @@ +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/dockermodelrunner/inference/api.go b/dockermodelrunner/inference/api.go new file mode 100644 index 000000000..9d6428caa --- /dev/null +++ b/dockermodelrunner/inference/api.go @@ -0,0 +1,12 @@ +package inference + +// ExperimentalEndpointsPrefix is used to prefix all routes on the Docker +// socket while they are still in their experimental stage. 
This prefix doesn't +// apply to endpoints on model-runner.docker.internal. +const ExperimentalEndpointsPrefix = "/exp/vDD4.40" + +// InferencePrefix is the prefix for inference related routes. +var InferencePrefix = "/engines" + +// ModelsPrefix is the prefix for all model manager related routes. +var ModelsPrefix = "/models" diff --git a/dockermodelrunner/modeldistribution/types/types.go b/dockermodelrunner/modeldistribution/types/types.go new file mode 100644 index 000000000..51ea183c6 --- /dev/null +++ b/dockermodelrunner/modeldistribution/types/types.go @@ -0,0 +1,65 @@ +package types + +// Store interface for model storage operations +type Store interface { + // Push a model to the store with given tags + Push(modelPath string, tags []string) error + + // Pull a model by tag + Pull(tag string, destPath string) error + + // List all models in the store + List() ([]Model, error) + + // GetByTag Get model info by tag + GetByTag(tag string) (*Model, error) + + // Delete a model by tag + Delete(tag string) error + + // AddTags Add tags to an existing model + AddTags(tag string, newTags []string) error + + // RemoveTags Remove tags from a model + RemoveTags(tags []string) error + + // Version Get store version + Version() string + + // Upgrade store to latest version + Upgrade() error +} + +// Model represents a model with its metadata and tags +type Model struct { + // ID is the globally unique model identifier. + ID string `json:"id"` + // Tags are the list of tags associated with the model. + Tags []string `json:"tags"` + // Files are the GGUF files associated with the model. + Files []string `json:"files"` + // Created is the Unix epoch timestamp corresponding to the model creation. 
+ Created int64 `json:"created"` +} + +// ModelIndex represents the index of all models in the store +type ModelIndex struct { + Models []Model `json:"models"` +} + +// StoreLayout represents the layout information of the store +type StoreLayout struct { + Version string `json:"version"` +} + +// ManifestReference represents a reference to a manifest in the store +type ManifestReference struct { + Digest string `json:"digest"` + MediaType string `json:"mediaType"` + Size int64 `json:"size"` +} + +// StoreOptions represents options for creating a store +type StoreOptions struct { + RootPath string +} diff --git a/dockermodelrunner/models/api.go b/dockermodelrunner/models/api.go new file mode 100644 index 000000000..1d63d4f3b --- /dev/null +++ b/dockermodelrunner/models/api.go @@ -0,0 +1,62 @@ +package models + +import "github.com/docker/docker-sdk-go/dockermodelrunner/modeldistribution/types" + +// ModelCreateRequest represents a model create request. It is designed to +// follow Docker Engine API conventions, most closely following the request +// associated with POST /images/create. At the moment is only designed to +// facilitate pulls, though in the future it may facilitate model building and +// refinement (such as fine tuning, quantization, or distillation). +type ModelCreateRequest struct { + // From is the name of the model to pull. + From string `json:"from"` +} + +// ToOpenAI converts a types.Model to its OpenAI API representation. +func ToOpenAI(m *types.Model) *OpenAIModel { + return &OpenAIModel{ + ID: m.Tags[0], + Object: "model", + Created: m.Created, + OwnedBy: "docker", + } +} + +// ModelList represents a list of models. +type ModelList []*types.Model + +// ToOpenAI converts the model list to its OpenAI API representation. This function never +// returns a nil slice (though it may return an empty slice). +func (l ModelList) toOpenAI() *OpenAIModelList { + // Convert the constituent models. 
+ models := make([]*OpenAIModel, len(l)) + for m, model := range l { + models[m] = ToOpenAI(model) + } + + // Create the OpenAI model list. + return &OpenAIModelList{ + Object: "list", + Data: models, + } +} + +// OpenAIModel represents a locally stored model using OpenAI conventions. +type OpenAIModel struct { + // ID is the model tag. + ID string `json:"id"` + // Object is the object type. For OpenAIModel, it is always "model". + Object string `json:"object"` + // Created is the Unix epoch timestamp corresponding to the model creation. + Created int64 `json:"created"` + // OwnedBy is the model owner. At the moment, it is always "docker". + OwnedBy string `json:"owned_by"` +} + +// OpenAIModelList represents a list of models using OpenAI conventions. +type OpenAIModelList struct { + // Object is the object type. For OpenAIModelList, it is always "list". + Object string `json:"object"` + // Data is the list of models. + Data []*OpenAIModel `json:"data"` +} From 9e8f645d9d2eae6b67b08075e8de3b30c87aae3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Wed, 23 Apr 2025 18:22:31 +0200 Subject: [PATCH 02/17] chore: make method public --- dockermodelrunner/models/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dockermodelrunner/models/api.go b/dockermodelrunner/models/api.go index 1d63d4f3b..4da681fd1 100644 --- a/dockermodelrunner/models/api.go +++ b/dockermodelrunner/models/api.go @@ -27,7 +27,7 @@ type ModelList []*types.Model // ToOpenAI converts the model list to its OpenAI API representation. This function never // returns a nil slice (though it may return an empty slice). -func (l ModelList) toOpenAI() *OpenAIModelList { +func (l ModelList) ToOpenAI() *OpenAIModelList { // Convert the constituent models. 
models := make([]*OpenAIModel, len(l)) for m, model := range l { From 4da0a972432791e2771ae017fcea3156f8d48fa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Wed, 23 Apr 2025 18:27:05 +0200 Subject: [PATCH 03/17] docs: add Go docs for methods, types and vars --- dockermodelrunner/desktop/api.go | 89 ++++++++++++++++++++++------ dockermodelrunner/desktop/desktop.go | 23 ++++++- 2 files changed, 94 insertions(+), 18 deletions(-) diff --git a/dockermodelrunner/desktop/api.go b/dockermodelrunner/desktop/api.go index 27fa9cb45..ad116e150 100644 --- a/dockermodelrunner/desktop/api.go +++ b/dockermodelrunner/desktop/api.go @@ -6,63 +6,118 @@ type ProgressMessage struct { Message string `json:"message"` // Human-readable message } +// OpenAIChatMessage represents a message sent during OpenAI chat operations type OpenAIChatMessage struct { - Role string `json:"role"` + // Role is the role of the message sender. + Role string `json:"role"` + + // Content is the content of the message. Content string `json:"content"` } +// OpenAIChatRequest represents a request to the OpenAI chat API. type OpenAIChatRequest struct { - Model string `json:"model"` + // Model is the model to use for the chat. + Model string `json:"model"` + + // Messages is the list of messages to send to the chat. Messages []OpenAIChatMessage `json:"messages"` - Stream bool `json:"stream"` + + // Stream is whether to stream the response. + Stream bool `json:"stream"` } +// OpenAIChatResponse represents a response from the OpenAI chat API. type OpenAIChatResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` + // ID is the ID of the chat. + ID string `json:"id"` + + // Object is the object type. + Object string `json:"object"` + + // Created is the creation time of the chat. + Created int64 `json:"created"` + + // Model is the model used for the chat. 
+ Model string `json:"model"` + + // Choices is the list of choices from the chat. Choices []struct { + // Delta is the delta of the choice. Delta struct { + // Content is the content of the choice. Content string `json:"content"` - Role string `json:"role,omitempty"` + + // Role is the role of the choice. + Role string `json:"role,omitempty"` } `json:"delta"` - Index int `json:"index"` + + // Index is the index of the choice. + Index int `json:"index"` + + // FinishReason is the reason the chat finished. FinishReason string `json:"finish_reason"` } `json:"choices"` } +// OpenAIModel represents a model in the OpenAI API. type OpenAIModel struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` + // ID is the ID of the model. + ID string `json:"id"` + + // Object is the object type. + Object string `json:"object"` + + // Created is the creation time of the model. + Created int64 `json:"created"` + + // OwnedBy is the owner of the model. OwnedBy string `json:"owned_by"` } +// OpenAIModelList represents a list of models in the OpenAI API. type OpenAIModelList struct { - Object string `json:"object"` - Data []*OpenAIModel `json:"data"` + // Object is the object type. + Object string `json:"object"` + + // Data is the list of models. + Data []*OpenAIModel `json:"data"` } +// Format represents the format of a model. // TODO: To be replaced by the Model struct from pianta's common/pkg/inference/models/api.go. // (https://github.com/docker/pinata/pull/33331) type Format string +// Config represents the configuration of a model. type Config struct { - Format Format `json:"format,omitempty"` + // Format is the format of the model. + Format Format `json:"format,omitempty"` + + // Quantization is the quantization of the model. Quantization string `json:"quantization,omitempty"` - Parameters string `json:"parameters,omitempty"` + + // Parameters is the parameters of the model. 
+ Parameters string `json:"parameters,omitempty"` + + // Architecture is the architecture of the model. Architecture string `json:"architecture,omitempty"` - Size string `json:"size,omitempty"` + + // Size is the size of the model. + Size string `json:"size,omitempty"` } +// Model represents a model in the Docker Model Runner. type Model struct { // ID is the globally unique model identifier. ID string `json:"id"` + // Tags are the list of tags associated with the model. Tags []string `json:"tags"` + // Created is the Unix epoch timestamp corresponding to the model creation. Created int64 `json:"created"` + // Config describes the model. Config Config `json:"config"` } diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index dd5e4e55c..54a7ae653 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -18,23 +18,31 @@ import ( ) var ( - ErrNotFound = errors.New("model not found") + // ErrNotFound is returned when a model is not found. + ErrNotFound = errors.New("model not found") + + // ErrServiceUnavailable is returned when the service is unavailable. ErrServiceUnavailable = errors.New("service unavailable") ) +// Client is a client for the Docker Model Runner API. type Client struct { dockerClient DockerHttpClient } +// DockerHttpClient is an interface that can be used to mock the Docker client. +// //go:generate mockgen -source=desktop.go -destination=../mocks/mock_desktop.go -package=mocks DockerHttpClient type DockerHttpClient interface { Do(req *http.Request) (*http.Response, error) } +// New creates a new Client. func New(dockerClient DockerHttpClient) *Client { return &Client{dockerClient} } +// Status represents the status of the Docker Model Runner. type Status struct { Running bool `json:"running"` Status []byte `json:"status"` @@ -82,6 +90,7 @@ func (c *Client) Status() Status { } } +// Pull pulls a model from the Docker Model Runner. 
func (c *Client) Pull(model string, progress func(string)) (string, bool, error) { jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) if err != nil { @@ -137,6 +146,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model) } +// Push pushes a model to the Docker Model Runner. func (c *Client) Push(model string, progress func(string)) (string, bool, error) { pushPath := inference.ModelsPrefix + "/" + model + "/push" resp, err := c.doRequest( @@ -187,6 +197,7 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error) return "", progressShown, fmt.Errorf("unexpected end of stream while pushing model %s", model) } +// List lists all models in the Docker Model Runner. func (c *Client) List() ([]Model, error) { modelsRoute := inference.ModelsPrefix body, err := c.listRaw(modelsRoute, "") @@ -202,6 +213,7 @@ func (c *Client) List() ([]Model, error) { return modelsJson, nil } +// ListOpenAI lists all models in the Docker Model Runner using the OpenAI API. func (c *Client) ListOpenAI() (OpenAIModelList, error) { modelsRoute := inference.InferencePrefix + "/v1/models" rawResponse, err := c.listRaw(modelsRoute, "") @@ -215,6 +227,7 @@ func (c *Client) ListOpenAI() (OpenAIModelList, error) { return modelsJson, nil } +// Inspect inspects a model in the Docker Model Runner. func (c *Client) Inspect(model string) (Model, error) { if model != "" { if !strings.Contains(strings.Trim(model, "/"), "/") { @@ -238,6 +251,7 @@ func (c *Client) Inspect(model string) (Model, error) { return modelInspect, nil } +// InspectOpenAI inspects a model in the Docker Model Runner using the OpenAI API. 
func (c *Client) InspectOpenAI(model string) (OpenAIModel, error) { modelsRoute := inference.InferencePrefix + "/v1/models" if !strings.Contains(strings.Trim(model, "/"), "/") { @@ -258,6 +272,7 @@ func (c *Client) InspectOpenAI(model string) (OpenAIModel, error) { return modelInspect, nil } +// listRaw lists all models in the Docker Model Runner. func (c *Client) listRaw(route string, model string) ([]byte, error) { resp, err := c.doRequest(http.MethodGet, route, nil) if err != nil { @@ -280,6 +295,7 @@ func (c *Client) listRaw(route string, model string) ([]byte, error) { } +// fullModelID returns the full model ID for a given model ID. func (c *Client) fullModelID(id string) (string, error) { bodyResponse, err := c.listRaw(inference.ModelsPrefix, "") if err != nil { @@ -300,6 +316,7 @@ func (c *Client) fullModelID(id string) (string, error) { return "", fmt.Errorf("model with ID %s not found", id) } +// Chat chats with a model in the Docker Model Runner. func (c *Client) Chat(model, prompt string) error { if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. @@ -375,6 +392,7 @@ func (c *Client) Chat(model, prompt string) error { return nil } +// Remove removes a model from the Docker Model Runner. func (c *Client) Remove(models []string, force bool) (string, error) { modelRemoved := "" for _, model := range models { @@ -416,6 +434,7 @@ func (c *Client) Remove(models []string, force bool) (string, error) { return modelRemoved, nil } +// URL returns the URL for the Docker Model Runner. func URL(path string) string { return fmt.Sprintf("http://localhost" + inference.ExperimentalEndpointsPrefix + path) } @@ -443,6 +462,7 @@ func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, return resp, nil } +// handleQueryError is a helper function that handles query errors. 
func (c *Client) handleQueryError(err error, path string) error { if errors.Is(err, ErrServiceUnavailable) { return ErrServiceUnavailable @@ -450,6 +470,7 @@ func (c *Client) handleQueryError(err error, path string) error { return fmt.Errorf("error querying %s: %w", path, err) } +// Tag tags a model in the Docker Model Runner. func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { // Check if the source is a model ID, and expand it if necessary if !strings.Contains(strings.Trim(source, "/"), "/") { From a09ef8d306e2c61004149c9e401ff241010a9688 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 07:47:17 +0200 Subject: [PATCH 04/17] chore: add to go work --- go.work | 1 + 1 file changed, 1 insertion(+) diff --git a/go.work b/go.work index f51212b77..8e9c2e9bb 100644 --- a/go.work +++ b/go.work @@ -3,4 +3,5 @@ go 1.23.6 use ( ./dockerconfig ./dockercontext + ./dockermodelrunner ) From da23ff9597dc3b7d49969ef251946c4e8f3216d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 08:00:48 +0200 Subject: [PATCH 05/17] fix: lint --- dockermodelrunner/desktop/desktop.go | 35 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index 54a7ae653..d8370395c 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -27,18 +27,18 @@ var ( // Client is a client for the Docker Model Runner API. type Client struct { - dockerClient DockerHttpClient + dockerClient DockerHTTPClient } -// DockerHttpClient is an interface that can be used to mock the Docker client. +// DockerHTTPClient is an interface that can be used to mock the Docker client. 
// -//go:generate mockgen -source=desktop.go -destination=../mocks/mock_desktop.go -package=mocks DockerHttpClient -type DockerHttpClient interface { +//go:generate mockgen -source=desktop.go -destination=../mocks/mock_desktop.go -package=mocks DockerHTTPClient +type DockerHTTPClient interface { Do(req *http.Request) (*http.Response, error) } // New creates a new Client. -func New(dockerClient DockerHttpClient) *Client { +func New(dockerClient DockerHTTPClient) *Client { return &Client{dockerClient} } @@ -205,12 +205,12 @@ func (c *Client) List() ([]Model, error) { return []Model{}, err } - var modelsJson []Model - if err := json.Unmarshal(body, &modelsJson); err != nil { - return modelsJson, fmt.Errorf("failed to unmarshal response body: %w", err) + var modelsJSON []Model + if err := json.Unmarshal(body, &modelsJSON); err != nil { + return modelsJSON, fmt.Errorf("failed to unmarshal response body: %w", err) } - return modelsJson, nil + return modelsJSON, nil } // ListOpenAI lists all models in the Docker Model Runner using the OpenAI API. @@ -220,11 +220,11 @@ func (c *Client) ListOpenAI() (OpenAIModelList, error) { if err != nil { return OpenAIModelList{}, err } - var modelsJson OpenAIModelList - if err := json.Unmarshal(rawResponse, &modelsJson); err != nil { - return modelsJson, fmt.Errorf("failed to unmarshal response body: %w", err) + var modelsJSON OpenAIModelList + if err := json.Unmarshal(rawResponse, &modelsJSON); err != nil { + return modelsJSON, fmt.Errorf("failed to unmarshal response body: %w", err) } - return modelsJson, nil + return modelsJSON, nil } // Inspect inspects a model in the Docker Model Runner. @@ -292,7 +292,6 @@ func (c *Client) listRaw(route string, model string) ([]byte, error) { return nil, fmt.Errorf("failed to read response body: %w", err) } return body, nil - } // fullModelID returns the full model ID for a given model ID. 
@@ -302,12 +301,12 @@ func (c *Client) fullModelID(id string) (string, error) { return "", err } - var modelsJson []Model - if err := json.Unmarshal(bodyResponse, &modelsJson); err != nil { + var modelsJSON []Model + if err := json.Unmarshal(bodyResponse, &modelsJSON); err != nil { return "", fmt.Errorf("failed to unmarshal response body: %w", err) } - for _, m := range modelsJson { + for _, m := range modelsJSON { if m.ID[7:19] == id || strings.TrimPrefix(m.ID, "sha256:") == id || m.ID == id { return m.ID, nil } @@ -436,7 +435,7 @@ func (c *Client) Remove(models []string, force bool) (string, error) { // URL returns the URL for the Docker Model Runner. func URL(path string) string { - return fmt.Sprintf("http://localhost" + inference.ExperimentalEndpointsPrefix + path) + return "http://localhost" + inference.ExperimentalEndpointsPrefix + path } // doRequest is a helper function that performs HTTP requests and handles 503 responses From 75c2876e07b28902e033a96144913dfdb028ed2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 08:04:23 +0200 Subject: [PATCH 06/17] fix: user errors from stdlib --- dockermodelrunner/desktop/desktop.go | 5 ++--- dockermodelrunner/go.mod | 2 -- dockermodelrunner/go.sum | 2 -- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index d8370395c..1ad8368d4 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "fmt" "html" "io" @@ -11,8 +12,6 @@ import ( "strconv" "strings" - "github.com/pkg/errors" - "github.com/docker/docker-sdk-go/dockermodelrunner/inference" "github.com/docker/docker-sdk-go/dockermodelrunner/models" ) @@ -282,7 +281,7 @@ func (c *Client) listRaw(route string, model string) ([]byte, error) { if resp.StatusCode != http.StatusOK { if model != "" && resp.StatusCode == http.StatusNotFound { 
- return nil, errors.Wrap(ErrNotFound, model) + return nil, fmt.Errorf("%w: %s", ErrNotFound, model) } return nil, fmt.Errorf("failed to list models: %s", resp.Status) } diff --git a/dockermodelrunner/go.mod b/dockermodelrunner/go.mod index b0ad30508..6868f0b9e 100644 --- a/dockermodelrunner/go.mod +++ b/dockermodelrunner/go.mod @@ -1,5 +1,3 @@ module github.com/docker/docker-sdk-go/dockermodelrunner go 1.23.6 - -require github.com/pkg/errors v0.9.1 diff --git a/dockermodelrunner/go.sum b/dockermodelrunner/go.sum index 7c401c3f5..e69de29bb 100644 --- a/dockermodelrunner/go.sum +++ b/dockermodelrunner/go.sum @@ -1,2 +0,0 @@ -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= From 855f35458b67b749b08de751c5fcdfd543ead2b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 10:55:09 +0200 Subject: [PATCH 07/17] chore: idiomatic serialisation of time.Time --- .../modeldistribution/types/types.go | 39 ++++++++++- .../modeldistribution/types/types_test.go | 64 +++++++++++++++++++ dockermodelrunner/models/api.go | 41 +++++++++++- dockermodelrunner/models/api_test.go | 58 +++++++++++++++++ 4 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 dockermodelrunner/modeldistribution/types/types_test.go create mode 100644 dockermodelrunner/models/api_test.go diff --git a/dockermodelrunner/modeldistribution/types/types.go b/dockermodelrunner/modeldistribution/types/types.go index 51ea183c6..fd915bc27 100644 --- a/dockermodelrunner/modeldistribution/types/types.go +++ b/dockermodelrunner/modeldistribution/types/types.go @@ -1,5 +1,11 @@ package types +import ( + "encoding/json" + "fmt" + "time" +) + // Store interface for model storage operations type Store interface { // Push a model to the store with given tags @@ -39,7 +45,38 @@ type Model struct { // Files are the GGUF files associated with the model. 
Files []string `json:"files"` // Created is the Unix epoch timestamp corresponding to the model creation. - Created int64 `json:"created"` + Created time.Time `json:"created"` +} + +// modelAlias is an alias for Model to avoid recursion in JSON marshaling/unmarshaling. +// This is necessary because we want Model to contain a time.Time field which is not directly +// compatible with JSON serialization/deserialization. +type modelAlias Model + +// modelResponseJSON is a struct used for JSON marshaling/unmarshaling of Model. +// It includes a Unix timestamp for the Created field to ensure compatibility with JSON. +type modelResponseJSON struct { + modelAlias + CreatedAt int64 `json:"created"` +} + +// UnmarshalJSON implements json.Unmarshaler. +func (mr *Model) UnmarshalJSON(b []byte) error { + var resp modelResponseJSON + if err := json.Unmarshal(b, &resp); err != nil { + return fmt.Errorf("unmarshal model response: %w", err) + } + *mr = Model(resp.modelAlias) + mr.Created = time.Unix(resp.CreatedAt, 0) + return nil +} + +// MarshalJSON implements json.Marshaler. 
+func (mr Model) MarshalJSON() ([]byte, error) { + return json.Marshal(modelResponseJSON{ + modelAlias: modelAlias(mr), + CreatedAt: mr.Created.Unix(), + }) } // ModelIndex represents the index of all models in the store diff --git a/dockermodelrunner/modeldistribution/types/types_test.go b/dockermodelrunner/modeldistribution/types/types_test.go new file mode 100644 index 000000000..1be7fefcb --- /dev/null +++ b/dockermodelrunner/modeldistribution/types/types_test.go @@ -0,0 +1,64 @@ +package types + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestUnmarshalJSON(t *testing.T) { + jsonData := `{ + "id": "model123", + "tags": ["tag1", "tag2"], + "files": ["file1", "file2"], + "created": 1682179200 + }` + + var response Model + err := json.Unmarshal([]byte(jsonData), &response) + require.NoError(t, err) + require.Equal(t, Model{ + ID: "model123", + Tags: []string{"tag1", "tag2"}, + Files: []string{ + "file1", + "file2", + }, + Created: time.Unix(1682179200, 0), + }, response) +} + +func TestUnmarshalJSONError(t *testing.T) { + // Invalid JSON with malformed created timestamp + invalidJSON := `{ + "id": "model123", + "tags": ["tag1", "tag2"], + "files": ["file1", "file2"], + "created": "not-a-number" + }` + + var response Model + err := json.Unmarshal([]byte(invalidJSON), &response) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal model response") +} + +func TestMarshalJSON(t *testing.T) { + response := Model{ + ID: "model123", + Tags: []string{"tag1", "tag2"}, + Files: []string{ + "file1", + "file2", + }, + Created: time.Unix(1682179200, 0), + } + + expectedJSON := `{"id":"model123","tags":["tag1","tag2"],"files":["file1","file2"],"created":1682179200}` + + jsonData, err := json.Marshal(response) + require.NoError(t, err, "Failed to marshal JSON") + require.JSONEq(t, expectedJSON, string(jsonData), "Unexpected JSON output") +} diff --git a/dockermodelrunner/models/api.go 
b/dockermodelrunner/models/api.go index 4da681fd1..715861b2a 100644 --- a/dockermodelrunner/models/api.go +++ b/dockermodelrunner/models/api.go @@ -1,6 +1,12 @@ package models -import "github.com/docker/docker-sdk-go/dockermodelrunner/modeldistribution/types" +import ( + "encoding/json" + "fmt" + "time" + + "github.com/docker/docker-sdk-go/dockermodelrunner/modeldistribution/types" +) // ModelCreateRequest represents a model create request. It is designed to // follow Docker Engine API conventions, most closely following the request @@ -48,7 +54,7 @@ type OpenAIModel struct { // Object is the object type. For OpenAIModel, it is always "model". Object string `json:"object"` // Created is the Unix epoch timestamp corresponding to the model creation. - Created int64 `json:"created"` + Created time.Time `json:"created"` // OwnedBy is the model owner. At the moment, it is always "docker". OwnedBy string `json:"owned_by"` } @@ -60,3 +66,34 @@ type OpenAIModelList struct { // Data is the list of models. Data []*OpenAIModel `json:"data"` } + +// openAIModelAlias is an alias for OpenAIModel to avoid recursion in JSON marshaling/unmarshaling. +// This is necessary because we want OpenAIModel to contain a time.Time field which is not directly +// compatible with JSON serialization/deserialization. +type openAIModelAlias OpenAIModel + +// openAIModelResponseJSON is a struct used for JSON marshaling/unmarshaling of OpenAIModel. +// It includes a Unix timestamp for the Created field to ensure compatibility with JSON. +type openAIModelResponseJSON struct { + openAIModelAlias + CreatedAt int64 `json:"created"` +} + +// UnmarshalJSON implements json.Unmarshaler. 
+func (mr *OpenAIModel) UnmarshalJSON(b []byte) error { + var resp openAIModelResponseJSON + if err := json.Unmarshal(b, &resp); err != nil { + return fmt.Errorf("unmarshal model response: %w", err) + } + *mr = OpenAIModel(resp.openAIModelAlias) + mr.Created = time.Unix(resp.CreatedAt, 0) + return nil +} + +// MarshalJSON implements json.Marshaler. +func (mr OpenAIModel) MarshalJSON() ([]byte, error) { + return json.Marshal(openAIModelResponseJSON{ + openAIModelAlias: openAIModelAlias(mr), + CreatedAt: mr.Created.Unix(), + }) +} diff --git a/dockermodelrunner/models/api_test.go b/dockermodelrunner/models/api_test.go new file mode 100644 index 000000000..43d88c881 --- /dev/null +++ b/dockermodelrunner/models/api_test.go @@ -0,0 +1,58 @@ +package models + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestUnmarshalJSON(t *testing.T) { + jsonData := `{ + "id": "model123", + "object": "model", + "created": 1682179200, + "owned_by": "docker" + }` + + var response OpenAIModel + err := json.Unmarshal([]byte(jsonData), &response) + require.NoError(t, err) + require.Equal(t, OpenAIModel{ + ID: "model123", + Object: "model", + Created: time.Unix(1682179200, 0), + OwnedBy: "docker", + }, response) +} + +func TestUnmarshalJSONError(t *testing.T) { + // Invalid JSON with malformed created timestamp + invalidJSON := `{ + "id": "model123", + "object": "model", + "created": "not-a-number", + "owned_by": "docker" + }` + + var response OpenAIModel + err := json.Unmarshal([]byte(invalidJSON), &response) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal model response") +} + +func TestMarshalJSON(t *testing.T) { + response := OpenAIModel{ + ID: "model123", + Object: "model", + Created: time.Unix(1682179200, 0), + OwnedBy: "docker", + } + + expectedJSON := `{"id":"model123","object":"model","created":1682179200,"owned_by":"docker"}` + + jsonData, err := json.Marshal(response) + require.NoError(t, err, "Failed to 
marshal JSON") + require.JSONEq(t, expectedJSON, string(jsonData), "Unexpected JSON output") +} From c709201076e6f6955e52c315f0bcc28f52f14af5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:06:35 +0200 Subject: [PATCH 08/17] chore: proper error messages --- dockermodelrunner/desktop/desktop.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index 1ad8368d4..4505c1b00 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -206,7 +206,7 @@ func (c *Client) List() ([]Model, error) { var modelsJSON []Model if err := json.Unmarshal(body, &modelsJSON); err != nil { - return modelsJSON, fmt.Errorf("failed to unmarshal response body: %w", err) + return modelsJSON, fmt.Errorf("unmarshal response body: %w", err) } return modelsJSON, nil @@ -221,7 +221,7 @@ func (c *Client) ListOpenAI() (OpenAIModelList, error) { } var modelsJSON OpenAIModelList if err := json.Unmarshal(rawResponse, &modelsJSON); err != nil { - return modelsJSON, fmt.Errorf("failed to unmarshal response body: %w", err) + return modelsJSON, fmt.Errorf("unmarshal response body: %w", err) } return modelsJSON, nil } @@ -244,7 +244,7 @@ func (c *Client) Inspect(model string) (Model, error) { } var modelInspect Model if err := json.Unmarshal(rawResponse, &modelInspect); err != nil { - return modelInspect, fmt.Errorf("failed to unmarshal response body: %w", err) + return modelInspect, fmt.Errorf("unmarshal response body: %w", err) } return modelInspect, nil @@ -266,7 +266,7 @@ func (c *Client) InspectOpenAI(model string) (OpenAIModel, error) { } var modelInspect OpenAIModel if err := json.Unmarshal(rawResponse, &modelInspect); err != nil { - return modelInspect, fmt.Errorf("failed to unmarshal response body: %w", err) + return modelInspect, fmt.Errorf("unmarshal response body: %w", err) } return modelInspect, nil } @@ 
-283,12 +283,12 @@ func (c *Client) listRaw(route string, model string) ([]byte, error) { if model != "" && resp.StatusCode == http.StatusNotFound { return nil, fmt.Errorf("%w: %s", ErrNotFound, model) } - return nil, fmt.Errorf("failed to list models: %s", resp.Status) + return nil, fmt.Errorf("list models: %s", resp.Status) } body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("read response body: %w", err) } return body, nil } @@ -302,7 +302,7 @@ func (c *Client) fullModelID(id string) (string, error) { var modelsJSON []Model if err := json.Unmarshal(bodyResponse, &modelsJSON); err != nil { - return "", fmt.Errorf("failed to unmarshal response body: %w", err) + return "", fmt.Errorf("unmarshal response body: %w", err) } for _, m := range modelsJSON { @@ -499,7 +499,7 @@ func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("failed to read response body: %w", err) + return "", fmt.Errorf("read response body: %w", err) } return string(body), nil From f384060b62da543047b14f35ed28dd01ee92002e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:18:15 +0200 Subject: [PATCH 09/17] chore: adjust error messages --- dockermodelrunner/desktop/desktop.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index 4505c1b00..61e29a62c 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -93,7 +93,7 @@ func (c *Client) Status() Status { func (c *Client) Pull(model string, progress func(string)) (string, bool, error) { jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) if err != nil { - return "", false, fmt.Errorf("error marshaling request: %w", err) + return "", false, 
fmt.Errorf("marshal request: %w", err) } createPath := inference.ModelsPrefix + "/create" @@ -109,7 +109,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body)) + return "", false, fmt.Errorf("pull %s failed with status %s: %s", model, resp.Status, string(body)) } progressShown := false @@ -124,7 +124,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) // Parse the progress message var progressMsg ProgressMessage if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { - return "", progressShown, fmt.Errorf("error parsing progress message: %w", err) + return "", progressShown, fmt.Errorf("unmarshal progress message: %w", err) } // Handle different message types @@ -133,7 +133,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) progress(progressMsg.Message) progressShown = true case "error": - return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message) + return "", progressShown, fmt.Errorf("pull %s: %s", model, progressMsg.Message) case "success": return progressMsg.Message, progressShown, nil default: @@ -160,7 +160,7 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error) if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return "", false, fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body)) + return "", false, fmt.Errorf("push %s failed with status %s: %s", model, resp.Status, string(body)) } progressShown := false @@ -175,7 +175,7 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error) // Parse the progress message var progressMsg ProgressMessage if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), 
&progressMsg); err != nil { - return "", progressShown, fmt.Errorf("error parsing progress message: %w", err) + return "", progressShown, fmt.Errorf("unmarshal progress message: %w", err) } // Handle different message types @@ -184,7 +184,7 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error) progress(progressMsg.Message) progressShown = true case "error": - return "", progressShown, fmt.Errorf("error pushing model: %s", progressMsg.Message) + return "", progressShown, fmt.Errorf("push %s: %s", model, progressMsg.Message) case "success": return progressMsg.Message, progressShown, nil default: @@ -336,7 +336,7 @@ func (c *Client) Chat(model, prompt string) error { jsonData, err := json.Marshal(reqBody) if err != nil { - return fmt.Errorf("error marshaling request: %w", err) + return fmt.Errorf("marshal request: %w", err) } chatCompletionsPath := inference.InferencePrefix + "/v1/chat/completions" @@ -374,7 +374,7 @@ func (c *Client) Chat(model, prompt string) error { var streamResp OpenAIChatResponse if err := json.Unmarshal([]byte(data), &streamResp); err != nil { - return fmt.Errorf("error parsing stream response: %w", err) + return fmt.Errorf("unmarshal stream response: %w", err) } if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta.Content != "" { @@ -384,7 +384,7 @@ func (c *Client) Chat(model, prompt string) error { } if err := scanner.Err(); err != nil { - return fmt.Errorf("error reading response stream: %w", err) + return fmt.Errorf("read response stream: %w", err) } return nil @@ -441,7 +441,7 @@ func URL(path string) string { func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) { req, err := http.NewRequest(method, URL(path), body) if err != nil { - return nil, fmt.Errorf("error creating request: %w", err) + return nil, fmt.Errorf("new %s request: %w", method, err) } if body != nil { req.Header.Set("Content-Type", "application/json") @@ -465,7 +465,7 @@ func (c *Client) 
handleQueryError(err error, path string) error { if errors.Is(err, ErrServiceUnavailable) { return ErrServiceUnavailable } - return fmt.Errorf("error querying %s: %w", path, err) + return fmt.Errorf("query %s: %w", path, err) } // Tag tags a model in the Docker Model Runner. @@ -494,7 +494,7 @@ func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { if resp.StatusCode != http.StatusCreated { body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("tagging failed with status %s: %s", resp.Status, string(body)) + return "", fmt.Errorf("tag with status %s: %s", resp.Status, string(body)) } body, err := io.ReadAll(resp.Body) From afd4769348e5644218a360ea7d1acb2e3ef0fad8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:22:26 +0200 Subject: [PATCH 10/17] chore: consistent error message when status is not OK --- dockermodelrunner/desktop/desktop.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index 61e29a62c..877a1bcc7 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -109,7 +109,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return "", false, fmt.Errorf("pull %s failed with status %s: %s", model, resp.Status, string(body)) + return "", false, fmt.Errorf("pull %s status=%d body=%s", model, resp.StatusCode, body) } progressShown := false @@ -160,7 +160,7 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error) if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return "", false, fmt.Errorf("push %s failed with status %s: %s", model, resp.Status, string(body)) + return "", false, fmt.Errorf("push %s status=%d body=%s", model, resp.StatusCode, body) } progressShown := false @@ -352,7 +352,7 @@ func 
(c *Client) Chat(model, prompt string) error { if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("error response: status=%d body=%s", resp.StatusCode, body) + return fmt.Errorf("chat with %s status=%d body=%s", model, resp.StatusCode, body) } scanner := bufio.NewScanner(resp.Body) @@ -494,7 +494,7 @@ func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { if resp.StatusCode != http.StatusCreated { body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("tag with status %s: %s", resp.Status, string(body)) + return "", fmt.Errorf("tag %s:%s status=%d body=%s", targetRepo, targetTag, resp.StatusCode, body) } body, err := io.ReadAll(resp.Body) From f42ccd33ab7f46319e963d7f4dee8a19aab6c0cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:25:13 +0200 Subject: [PATCH 11/17] chore: inverse logic for readability --- dockermodelrunner/desktop/desktop.go | 35 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index 877a1bcc7..3001f3bec 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -64,28 +64,29 @@ func (c *Client) Status() Status { } } defer resp.Body.Close() - if resp.StatusCode == http.StatusOK { - var status []byte - statusResp, err := c.doRequest(http.MethodGet, inference.InferencePrefix+"/status", nil) + if resp.StatusCode != http.StatusOK { + return Status{ + Running: false, + Error: fmt.Errorf("unexpected status code: %d", resp.StatusCode), + } + } + + var status []byte + statusResp, err := c.doRequest(http.MethodGet, inference.InferencePrefix+"/status", nil) + if err != nil { + status = []byte(fmt.Sprintf("error querying status: %v", err)) + } else { + defer statusResp.Body.Close() + statusBody, err := io.ReadAll(statusResp.Body) if err != nil { - status = []byte(fmt.Sprintf("error querying 
status: %v", err)) + status = []byte(fmt.Sprintf("error reading status body: %v", err)) } else { - defer statusResp.Body.Close() - statusBody, err := io.ReadAll(statusResp.Body) - if err != nil { - status = []byte(fmt.Sprintf("error reading status body: %v", err)) - } else { - status = statusBody - } - } - return Status{ - Running: true, - Status: status, + status = statusBody } } return Status{ - Running: false, - Error: fmt.Errorf("unexpected status code: %d", resp.StatusCode), + Running: true, + Status: status, } } From 525757a971c0dc41df7243f31f1eb11b32bdac37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:27:53 +0200 Subject: [PATCH 12/17] chore: return idiomatic nils --- dockermodelrunner/desktop/desktop.go | 39 ++++++++++++++++------------ 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index 3001f3bec..8be13da4a 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -202,74 +202,81 @@ func (c *Client) List() ([]Model, error) { modelsRoute := inference.ModelsPrefix body, err := c.listRaw(modelsRoute, "") if err != nil { - return []Model{}, err + return nil, err } var modelsJSON []Model if err := json.Unmarshal(body, &modelsJSON); err != nil { - return modelsJSON, fmt.Errorf("unmarshal response body: %w", err) + return nil, fmt.Errorf("unmarshal response body: %w", err) } return modelsJSON, nil } // ListOpenAI lists all models in the Docker Model Runner using the OpenAI API. 
-func (c *Client) ListOpenAI() (OpenAIModelList, error) { +func (c *Client) ListOpenAI() (*OpenAIModelList, error) { modelsRoute := inference.InferencePrefix + "/v1/models" rawResponse, err := c.listRaw(modelsRoute, "") if err != nil { - return OpenAIModelList{}, err + return nil, err } + var modelsJSON OpenAIModelList if err := json.Unmarshal(rawResponse, &modelsJSON); err != nil { - return modelsJSON, fmt.Errorf("unmarshal response body: %w", err) + return nil, fmt.Errorf("unmarshal response body: %w", err) } - return modelsJSON, nil + + return &modelsJSON, nil } // Inspect inspects a model in the Docker Model Runner. -func (c *Client) Inspect(model string) (Model, error) { +func (c *Client) Inspect(model string) (*Model, error) { if model != "" { if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. modelId, err := c.fullModelID(model) if err != nil { - return Model{}, fmt.Errorf("invalid model name: %s", model) + return nil, fmt.Errorf("invalid model name: %s", model) } model = modelId } } + rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", inference.ModelsPrefix, model), model) if err != nil { - return Model{}, err + return nil, err } + var modelInspect Model if err := json.Unmarshal(rawResponse, &modelInspect); err != nil { - return modelInspect, fmt.Errorf("unmarshal response body: %w", err) + return nil, fmt.Errorf("unmarshal response body: %w", err) } - return modelInspect, nil + return &modelInspect, nil } // InspectOpenAI inspects a model in the Docker Model Runner using the OpenAI API. -func (c *Client) InspectOpenAI(model string) (OpenAIModel, error) { +func (c *Client) InspectOpenAI(model string) (*OpenAIModel, error) { modelsRoute := inference.InferencePrefix + "/v1/models" if !strings.Contains(strings.Trim(model, "/"), "/") { // Do an extra API call to check if the model parameter isn't a model ID. 
var err error if model, err = c.fullModelID(model); err != nil { - return OpenAIModel{}, fmt.Errorf("invalid model name: %s", model) + return nil, fmt.Errorf("invalid model name: %s", model) } } + rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", modelsRoute, model), model) if err != nil { - return OpenAIModel{}, err + return nil, err } + var modelInspect OpenAIModel if err := json.Unmarshal(rawResponse, &modelInspect); err != nil { - return modelInspect, fmt.Errorf("unmarshal response body: %w", err) + return nil, fmt.Errorf("unmarshal response body: %w", err) } - return modelInspect, nil + + return &modelInspect, nil } // listRaw lists all models in the Docker Model Runner. From 137d8d2f1819e28117e593a0ecb92e6164fd7569 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:28:37 +0200 Subject: [PATCH 13/17] docs: add method comment --- dockermodelrunner/desktop/desktop.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index 8be13da4a..4b87ce4a5 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -48,6 +48,7 @@ type Status struct { Error error `json:"error"` } +// Status returns the status of the Docker Model Runner. func (c *Client) Status() Status { // TODO: Query "/". 
resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil) From 061835ed34c8dad5d3c797bb1d7279fed6317407 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:35:10 +0200 Subject: [PATCH 14/17] chore: extract scanning progress to a function --- dockermodelrunner/desktop/desktop.go | 42 ++++++---------------------- 1 file changed, 8 insertions(+), 34 deletions(-) diff --git a/dockermodelrunner/desktop/desktop.go b/dockermodelrunner/desktop/desktop.go index 4b87ce4a5..e11067a8a 100644 --- a/dockermodelrunner/desktop/desktop.go +++ b/dockermodelrunner/desktop/desktop.go @@ -114,6 +114,11 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) return "", false, fmt.Errorf("pull %s status=%d body=%s", model, resp.StatusCode, body) } + return scanProgress(resp, "pull", model, progress) +} + +// scanProgress scans the progress of a model for a given action. +func scanProgress(resp *http.Response, action string, model string, progress func(string)) (string, bool, error) { progressShown := false scanner := bufio.NewScanner(resp.Body) @@ -135,7 +140,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) progress(progressMsg.Message) progressShown = true case "error": - return "", progressShown, fmt.Errorf("pull %s: %s", model, progressMsg.Message) + return "", progressShown, fmt.Errorf("%s %s: %s", action, model, progressMsg.Message) case "success": return progressMsg.Message, progressShown, nil default: @@ -143,8 +148,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) } } - // If we get here, something went wrong - return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model) + return "", progressShown, fmt.Errorf("%s model %s: unexpected end of stream", action, model) } // Push pushes a model to the Docker Model Runner. 
@@ -165,37 +169,7 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error) return "", false, fmt.Errorf("push %s status=%d body=%s", model, resp.StatusCode, body) } - progressShown := false - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - progressLine := scanner.Text() - if progressLine == "" { - continue - } - - // Parse the progress message - var progressMsg ProgressMessage - if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { - return "", progressShown, fmt.Errorf("unmarshal progress message: %w", err) - } - - // Handle different message types - switch progressMsg.Type { - case "progress": - progress(progressMsg.Message) - progressShown = true - case "error": - return "", progressShown, fmt.Errorf("push %s: %s", model, progressMsg.Message) - case "success": - return progressMsg.Message, progressShown, nil - default: - return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type) - } - } - - // If we get here, something went wrong - return "", progressShown, fmt.Errorf("unexpected end of stream while pushing model %s", model) + return scanProgress(resp, "push", model, progress) } // List lists all models in the Docker Model Runner. 
From b36ec438177f72d6de0f373f6cb664d2040802f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:35:31 +0200 Subject: [PATCH 15/17] chore: mod tidy --- dockermodelrunner/go.mod | 11 +++++++++++ dockermodelrunner/go.sum | 8 ++++++++ 2 files changed, 19 insertions(+) diff --git a/dockermodelrunner/go.mod b/dockermodelrunner/go.mod index 6868f0b9e..9d9ee2c82 100644 --- a/dockermodelrunner/go.mod +++ b/dockermodelrunner/go.mod @@ -1,3 +1,14 @@ module github.com/docker/docker-sdk-go/dockermodelrunner go 1.23.6 + +require github.com/stretchr/testify v1.10.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/dockermodelrunner/go.sum b/dockermodelrunner/go.sum index e69de29bb..a2aa25d63 100644 --- a/dockermodelrunner/go.sum +++ b/dockermodelrunner/go.sum @@ -0,0 +1,8 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From 55d46965941f9cf314d992a9a9957f8ce7f29916 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:39:07 +0200 Subject: [PATCH 16/17] chore: more idiomatic usage of time.Time for models 
--- dockermodelrunner/desktop/api.go | 39 ++++++++++++- dockermodelrunner/desktop/api_test.go | 82 +++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 dockermodelrunner/desktop/api_test.go diff --git a/dockermodelrunner/desktop/api.go b/dockermodelrunner/desktop/api.go index ad116e150..801edaf9f 100644 --- a/dockermodelrunner/desktop/api.go +++ b/dockermodelrunner/desktop/api.go @@ -1,5 +1,11 @@ package desktop +import ( + "encoding/json" + "fmt" + "time" +) + // ProgressMessage represents a message sent during model pull operations type ProgressMessage struct { Type string `json:"type"` // "progress", "success", or "error" @@ -116,8 +122,39 @@ type Model struct { Tags []string `json:"tags"` // Created is the Unix epoch timestamp corresponding to the model creation. - Created int64 `json:"created"` + Created time.Time `json:"created"` // Config describes the model. Config Config `json:"config"` } + +// modelAlias is an alias for Model to avoid recursion in JSON marshaling/unmarshaling. +// This is necessary because we want Model to contain a time.Time field which is not directly +// compatible with JSON serialization/deserialization. +type modelAlias Model + +// modelResponseJSON is a struct used for JSON marshaling/unmarshaling of Model. +// It includes a Unix timestamp for the Created field to ensure compatibility with JSON. +type modelResponseJSON struct { + modelAlias + CreatedAt int64 `json:"created"` +} + +// UnmarshalJSON implements json.Unmarshaler. +func (mr *Model) UnmarshalJSON(b []byte) error { + var resp modelResponseJSON + if err := json.Unmarshal(b, &resp); err != nil { + return fmt.Errorf("unmarshal model response: %w", err) + } + *mr = Model(resp.modelAlias) + mr.Created = time.Unix(resp.CreatedAt, 0) + return nil +} + +// MarshalJSON implements json.Marshaler. 
+func (mr Model) MarshalJSON() ([]byte, error) { + return json.Marshal(modelResponseJSON{ + modelAlias: modelAlias(mr), + CreatedAt: mr.Created.Unix(), + }) +} diff --git a/dockermodelrunner/desktop/api_test.go b/dockermodelrunner/desktop/api_test.go new file mode 100644 index 000000000..43406ec14 --- /dev/null +++ b/dockermodelrunner/desktop/api_test.go @@ -0,0 +1,82 @@ +package desktop + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestUnmarshalJSON(t *testing.T) { + jsonData := `{ + "id": "model123", + "tags": ["tag1", "tag2"], + "config": { + "format": "format1", + "quantization": "quantization1", + "parameters": "parameters1", + "architecture": "architecture1", + "size": "size1" + }, + "created": 1682179200 + }` + + var response Model + err := json.Unmarshal([]byte(jsonData), &response) + require.NoError(t, err) + require.Equal(t, Model{ + ID: "model123", + Tags: []string{"tag1", "tag2"}, + Config: Config{ + Format: "format1", + Quantization: "quantization1", + Parameters: "parameters1", + Architecture: "architecture1", + Size: "size1", + }, + Created: time.Unix(1682179200, 0), + }, response) +} + +func TestUnmarshalJSONError(t *testing.T) { + // Invalid JSON with malformed created timestamp + invalidJSON := `{ + "id": "model123", + "tags": ["tag1", "tag2"], + "config": { + "format": "format1", + "quantization": "quantization1", + "parameters": "parameters1", + "architecture": "architecture1", + "size": "size1" + }, + "created": "not-a-number" + }` + + var response Model + err := json.Unmarshal([]byte(invalidJSON), &response) + require.Error(t, err) + require.Contains(t, err.Error(), "unmarshal model response") +} + +func TestMarshalJSON(t *testing.T) { + response := Model{ + ID: "model123", + Tags: []string{"tag1", "tag2"}, + Config: Config{ + Format: "format1", + Quantization: "quantization1", + Parameters: "parameters1", + Architecture: "architecture1", + Size: "size1", + }, + Created: 
time.Unix(1682179200, 0), + } + + expectedJSON := `{"id":"model123","tags":["tag1","tag2"],"config":{"format":"format1","quantization":"quantization1","parameters":"parameters1","architecture":"architecture1","size":"size1"},"created":1682179200}` + + jsonData, err := json.Marshal(response) + require.NoError(t, err, "Failed to marshal JSON") + require.JSONEq(t, expectedJSON, string(jsonData), "Unexpected JSON output") +} From 9d5463645314d9658edd94894a778adc4d628028 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20de=20la=20Pe=C3=B1a?= Date: Thu, 24 Apr 2025 11:39:51 +0200 Subject: [PATCH 17/17] chore: mod tidy --- dockermodelrunner/go.sum | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dockermodelrunner/go.sum b/dockermodelrunner/go.sum index a2aa25d63..e48aae2f5 100644 --- a/dockermodelrunner/go.sum +++ b/dockermodelrunner/go.sum @@ -1,8 +1,23 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 
+github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=