Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ require (
github.com/pkg/errors v0.9.1
github.com/prometheus/client_model v0.6.2
github.com/prometheus/common v0.67.5
github.com/sirupsen/logrus v1.9.4
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.10
github.com/stretchr/testify v1.11.1
Expand Down
7 changes: 5 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d h1:3VwvTjiRPA7cqtgOWddEL+JrcijMlXUmj99c/6YyZoY=
github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
Expand All @@ -270,6 +270,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
Expand Down Expand Up @@ -353,6 +354,7 @@ golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down Expand Up @@ -393,6 +395,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
Expand Down
110 changes: 66 additions & 44 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"context"
"crypto/tls"
"log/slog"
"net"
"net/http"
"os"
Expand All @@ -24,28 +25,37 @@ import (
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/inference/platform"
"github.com/docker/model-runner/pkg/inference/scheduling"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/metrics"
"github.com/docker/model-runner/pkg/middleware"
"github.com/docker/model-runner/pkg/ollama"
"github.com/docker/model-runner/pkg/responses"
"github.com/docker/model-runner/pkg/routing"
modeltls "github.com/docker/model-runner/pkg/tls"
"github.com/sirupsen/logrus"
)

const (
// DefaultTLSPort is the default TLS port for Moby. It is declared as a
// string constant rather than an integer.
DefaultTLSPort = "12444"
)

var log = logrus.New()
// initLogger builds the application's *slog.Logger. The level comes from the
// LOG_LEVEL environment variable, which logging.ParseLevel converts before
// the result is handed to logging.NewLogger.
func initLogger() *slog.Logger {
	return logging.NewLogger(logging.ParseLevel(os.Getenv("LOG_LEVEL")))
}

// log is the package-level logger, built once by initLogger when the package
// is initialized.
var log = initLogger()

// Log is the logger used by the application, exported for testing purposes.
// It refers to the same logger value as log.
var Log = log

// testLog is the logger used by createLlamaCppConfigFromEnv. It is a separate
// variable, initially the same logger as log, so tests can replace it.
var testLog = log

// exitFunc is the function called for fatal errors. It defaults to os.Exit
// and is overridable in tests so an exit code can be captured instead of
// terminating the process.
var exitFunc = os.Exit

func main() {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
Expand All @@ -57,7 +67,8 @@ func main() {

userHomeDir, err := os.UserHomeDir()
if err != nil {
log.Fatalf("Failed to get user home directory: %v", err)
log.Error("Failed to get user home directory", "error", err)
os.Exit(1)
}

modelPath := os.Getenv("MODELS_PATH")
Expand Down Expand Up @@ -101,27 +112,27 @@ func main() {

clientConfig := models.ClientConfig{
StoreRootPath: modelPath,
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
Logger: log.With("component", "model-manager"),
Transport: baseTransport,
}
modelManager := models.NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), clientConfig)
modelManager := models.NewManager(log.With("component", "model-manager"), clientConfig)
modelHandler := models.NewHTTPHandler(
log,
modelManager,
nil,
)
log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
log.Info("LLAMA_SERVER_PATH", "path", llamaServerPath)
if vllmServerPath != "" {
log.Infof("VLLM_SERVER_PATH: %s", vllmServerPath)
log.Info("VLLM_SERVER_PATH", "path", vllmServerPath)
}
if sglangServerPath != "" {
log.Infof("SGLANG_SERVER_PATH: %s", sglangServerPath)
log.Info("SGLANG_SERVER_PATH", "path", sglangServerPath)
}
if mlxServerPath != "" {
log.Infof("MLX_SERVER_PATH: %s", mlxServerPath)
log.Info("MLX_SERVER_PATH", "path", mlxServerPath)
}
if vllmMetalServerPath != "" {
log.Infof("VLLM_METAL_SERVER_PATH: %s", vllmMetalServerPath)
log.Info("VLLM_METAL_SERVER_PATH", "path", vllmMetalServerPath)
}

// Create llama.cpp configuration from environment variables
Expand All @@ -130,7 +141,7 @@ func main() {
llamaCppBackend, err := llamacpp.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": llamacpp.Name}),
log.With("component", llamacpp.Name),
llamaServerPath,
func() string {
wd, _ := os.Getwd()
Expand All @@ -141,58 +152,63 @@ func main() {
llamaCppConfig,
)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
log.Error("Unable to initialize backend", "backend", llamacpp.Name, "error", err)
os.Exit(1)
}

vllmBackend, err := initVLLMBackend(log, modelManager, vllmServerPath)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", vllm.Name, err)
log.Error("Unable to initialize backend", "backend", vllm.Name, "error", err)
os.Exit(1)
}

mlxBackend, err := mlx.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": mlx.Name}),
log.With("component", mlx.Name),
nil,
mlxServerPath,
)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", mlx.Name, err)
log.Error("Unable to initialize backend", "backend", mlx.Name, "error", err)
os.Exit(1)
}

sglangBackend, err := sglang.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": sglang.Name}),
log.With("component", sglang.Name),
nil,
sglangServerPath,
)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err)
log.Error("Unable to initialize backend", "backend", sglang.Name, "error", err)
os.Exit(1)
}

diffusersBackend, err := diffusers.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": diffusers.Name}),
log.With("component", diffusers.Name),
nil,
diffusersServerPath,
)

if err != nil {
log.Fatalf("unable to initialize diffusers backend: %v", err)
log.Error("Unable to initialize backend", "backend", diffusers.Name, "error", err)
os.Exit(1)
}

var vllmMetalBackend inference.Backend
if platform.SupportsVLLMMetal() {
vllmMetalBackend, err = vllmmetal.New(
log,
modelManager,
log.WithFields(logrus.Fields{"component": vllmmetal.Name}),
log.With("component", vllmmetal.Name),
vllmMetalServerPath,
)
if err != nil {
log.Warnf("Failed to initialize vllm-metal backend: %v", err)
log.Warn("Failed to initialize vllm-metal backend", "error", err)
}
}

Expand Down Expand Up @@ -222,7 +238,7 @@ func main() {
http.DefaultClient,
metrics.NewTracker(
http.DefaultClient,
log.WithField("component", "metrics"),
log.With("component", "metrics"),
"",
false,
),
Expand Down Expand Up @@ -278,7 +294,7 @@ func main() {
// Add metrics endpoint if enabled
if os.Getenv("DISABLE_METRICS") != "1" {
metricsHandler := metrics.NewAggregatedMetricsHandler(
log.WithField("component", "metrics"),
log.With("component", "metrics"),
schedulerHTTP,
)
router.Handle("/metrics", metricsHandler)
Expand All @@ -302,7 +318,7 @@ func main() {
if tcpPort != "" {
// Use TCP port
addr := ":" + tcpPort
log.Infof("Listening on TCP port %s", tcpPort)
log.Info("Listening on TCP port", "port", tcpPort)
server.Addr = addr
go func() {
serverErrors <- server.ListenAndServe()
Expand All @@ -311,12 +327,14 @@ func main() {
// Use Unix socket
if err := os.Remove(sockName); err != nil {
if !os.IsNotExist(err) {
log.Fatalf("Failed to remove existing socket: %v", err)
log.Error("Failed to remove existing socket", "error", err)
os.Exit(1)
}
}
ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockName, Net: "unix"})
if err != nil {
log.Fatalf("Failed to listen on socket: %v", err)
log.Error("Failed to listen on socket", "error", err)
os.Exit(1)
}
go func() {
serverErrors <- server.Serve(ln)
Expand All @@ -341,19 +359,22 @@ func main() {
var err error
certPath, keyPath, err = modeltls.EnsureCertificates("", "")
if err != nil {
log.Fatalf("Failed to ensure TLS certificates: %v", err)
log.Error("Failed to ensure TLS certificates", "error", err)
os.Exit(1)
}
log.Infof("Using TLS certificate: %s", certPath)
log.Infof("Using TLS key: %s", keyPath)
log.Info("Using TLS certificate", "path", certPath)
log.Info("Using TLS key", "path", keyPath)
} else {
log.Fatal("TLS enabled but no certificate provided and auto-cert is disabled")
log.Error("TLS enabled but no certificate provided and auto-cert is disabled")
os.Exit(1)
}
}

// Load TLS configuration
tlsConfig, err := modeltls.LoadTLSConfig(certPath, keyPath)
if err != nil {
log.Fatalf("Failed to load TLS configuration: %v", err)
log.Error("Failed to load TLS configuration", "error", err)
os.Exit(1)
}

tlsServer = &http.Server{
Expand All @@ -363,7 +384,7 @@ func main() {
ReadHeaderTimeout: 10 * time.Second,
}

log.Infof("Listening on TLS port %s", tlsPort)
log.Info("Listening on TLS port", "port", tlsPort)
go func() {
// Use ListenAndServeTLS with empty strings since TLSConfig already has the certs
ln, err := tls.Listen("tcp", tlsServer.Addr, tlsConfig)
Expand Down Expand Up @@ -391,30 +412,30 @@ func main() {
select {
case err := <-serverErrors:
if err != nil {
log.Errorf("Server error: %v", err)
log.Error("Server error", "error", err)
}
case err := <-tlsServerErrorsChan:
if err != nil {
log.Errorf("TLS server error: %v", err)
log.Error("TLS server error", "error", err)
}
case <-ctx.Done():
log.Infoln("Shutdown signal received")
log.Infoln("Shutting down the server")
log.Info("Shutdown signal received")
log.Info("Shutting down the server")
if err := server.Close(); err != nil {
log.Errorf("Server shutdown error: %v", err)
log.Error("Server shutdown error", "error", err)
}
if tlsServer != nil {
log.Infoln("Shutting down the TLS server")
log.Info("Shutting down the TLS server")
if err := tlsServer.Close(); err != nil {
log.Errorf("TLS server shutdown error: %v", err)
log.Error("TLS server shutdown error", "error", err)
}
}
log.Infoln("Waiting for the scheduler to stop")
log.Info("Waiting for the scheduler to stop")
if err := <-schedulerErrors; err != nil {
log.Errorf("Scheduler error: %v", err)
log.Error("Scheduler error", "error", err)
}
}
log.Infoln("Docker Model Runner stopped")
log.Info("Docker Model Runner stopped")
}

// createLlamaCppConfigFromEnv creates a LlamaCppConfig from environment variables
Expand All @@ -435,12 +456,13 @@ func createLlamaCppConfigFromEnv() config.BackendConfig {
for _, arg := range args {
for _, disallowed := range disallowedArgs {
if arg == disallowed {
testLog.Fatalf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed)
testLog.Error("LLAMA_ARGS cannot override argument", "arg", disallowed)
exitFunc(1)
}
}
}

testLog.Infof("Using custom arguments: %v", args)
testLog.Info("Using custom arguments", "args", args)
return &llamacpp.Config{
Args: args,
}
Expand Down
12 changes: 4 additions & 8 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"testing"

"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
"github.com/sirupsen/logrus"
)

func TestCreateLlamaCppConfigFromEnv(t *testing.T) {
Expand Down Expand Up @@ -61,17 +60,14 @@ func TestCreateLlamaCppConfigFromEnv(t *testing.T) {
t.Setenv("LLAMA_ARGS", tt.llamaArgs)
}

// Create a test logger that captures fatal errors
originalLog := testLog
defer func() { testLog = originalLog }()
// Override exitFunc to capture exit calls instead of actually exiting
originalExitFunc := exitFunc
defer func() { exitFunc = originalExitFunc }()

// Create a new logger that will exit with a special exit code
newTestLog := logrus.New()
var exitCode int
newTestLog.ExitFunc = func(code int) {
exitFunc = func(code int) {
exitCode = code
}
testLog = newTestLog

config := createLlamaCppConfigFromEnv()

Expand Down
Loading