From a76c9f067ca46ece91b3b58ced47f1fd2ff2ab80 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Tue, 31 Mar 2026 12:30:16 +0100 Subject: [PATCH] Add audio, moderations, and tokenize/detokenize endpoint support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Register previously missing OpenAI-compatible routes so they are no longer rejected with 404: /v1/audio/transcriptions, /v1/audio/translations (multipart/form-data) /v1/audio/speech (JSON) /v1/moderations /tokenize and /detokenize (vLLM extension) Add BackendModeAudio and handleAudioInference which extracts the model field from multipart form data. vLLM passes audio requests through natively; llama.cpp returns a descriptive error directing users to the chat completions input_audio content-part instead. Address review feedback: extract shared scheduleInference helper to restore parity (auto-install progress, preload-only, recorder, origin tracking) between handleOpenAIInference and handleAudioInference; fix multipart temp-file leak (defer MultipartForm.RemoveAll); tighten /v1/audio/ path matching to explicit HasSuffix checks; make Content-Type check case-insensitive. Replace the Go routing layer with a Rust reverse proxy (axum + tokio) compiled as a CGo-linked staticlib (router/). The Rust router owns: - All route registration and path aliasing (/v1/ -> /engines/, etc.) - CORS middleware matching the Go CorsMiddleware semantics - Path normalisation (NormalizePathLayer) - Static routes: GET / and GET /version The deleted Go files (pkg/routing/router.go, pkg/routing/routing.go, pkg/middleware/alias.go) are fully replaced. pkg/routing/service.go is trimmed; main.go registers Ollama, Anthropic, and Responses handlers directly on the backend mux. 
The in-process CGo callback path (pkg/router/handler.go) replaces the network proxy hop: Rust calls Go's http.Handler directly via a streaming protocol — dmr_write_chunk()/dmr_close_stream() push response chunks into a tokio::sync::mpsc channel as Go writes them, so streaming endpoints like POST /models/create deliver progress in real time without buffering the full response. Rust shared utilities are extracted into a dmr-common workspace crate (init_tracing, unix_now_secs). model-cli Rust code is deduplicated: shared send_and_check free function, request_timeout helper, apply_azure_version helper, build_app/run_gateway_async extracted to handlers.rs. Build system: - Cargo workspace root (Cargo.toml) unifies all Rust crates - make build-router-lib compiles the Rust staticlib before go build - Dockerfile installs Rust and builds libdmr_router.a in the builder stage - CI test job installs Rust toolchain and builds the library so go test -race works with CGo enabled - Platform-split CGo LDFLAGS: Darwin keeps -framework flags, Linux uses plain -lpthread/-ldl/-lm - pkg/router/router_stub.go provides a no-op implementation for CGO_ENABLED=0 builds (lint, cross-compilation) Signed-off-by: Eric Curtin --- .github/workflows/ci.yml | 16 + .gitignore | 4 + Cargo.toml | 10 + Dockerfile | 16 +- Makefile | 10 +- cmd/cli/Makefile | 2 +- dmr-common/Cargo.toml | 8 + dmr-common/src/lib.rs | 30 + main.go | 186 ++-- model-cli/Cargo.toml | 5 +- model-cli/src/handlers.rs | 113 +++ model-cli/src/lib.rs | 123 +-- model-cli/src/main.rs | 117 +-- model-cli/src/providers/anthropic.rs | 118 +-- model-cli/src/providers/mod.rs | 33 + model-cli/src/providers/openai.rs | 114 +-- model-cli/src/router.rs | 5 +- pkg/inference/backend.go | 7 + .../backends/diffusers/diffusers_config.go | 3 +- .../backends/llamacpp/llamacpp_config.go | 5 + pkg/inference/backends/mlx/mlx_config.go | 2 + .../backends/sglang/sglang_config.go | 2 + pkg/inference/backends/vllm/vllm_config.go | 3 + 
pkg/inference/backends/vllm/vllm_metal.go | 3 + pkg/inference/scheduling/api.go | 20 + pkg/inference/scheduling/http_handler.go | 166 ++- pkg/middleware/alias.go | 20 - pkg/router/handler.go | 183 ++++ pkg/router/handler_bridge.c | 27 + pkg/router/router.go | 133 +++ pkg/router/router_stub.go | 41 + pkg/routing/router.go | 84 -- pkg/routing/routing.go | 24 - pkg/routing/service.go | 31 +- router/.gitignore | 1 + router/Cargo.lock | 951 ++++++++++++++++++ router/Cargo.toml | 24 + router/dmr_router.h | 101 ++ router/src/cors.rs | 132 +++ router/src/lib.rs | 287 ++++++ router/src/proxy.rs | 456 +++++++++ router/src/routes.rs | 109 ++ 42 files changed, 3025 insertions(+), 700 deletions(-) create mode 100644 Cargo.toml create mode 100644 dmr-common/Cargo.toml create mode 100644 dmr-common/src/lib.rs delete mode 100644 pkg/middleware/alias.go create mode 100644 pkg/router/handler.go create mode 100644 pkg/router/handler_bridge.c create mode 100644 pkg/router/router.go create mode 100644 pkg/router/router_stub.go delete mode 100644 pkg/routing/router.go delete mode 100644 pkg/routing/routing.go create mode 100644 router/.gitignore create mode 100644 router/Cargo.lock create mode 100644 router/Cargo.toml create mode 100644 router/dmr_router.h create mode 100644 router/src/cors.rs create mode 100644 router/src/lib.rs create mode 100644 router/src/proxy.rs create mode 100644 router/src/routes.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d5f59c76a..5477b6f89 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,6 +53,22 @@ jobs: go mod tidy git diff --exit-code go.mod go.sum + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache Rust build artifacts + uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ runner.os }}-cargo- + + - name: Build 
Rust router library + run: make build-router-lib + - name: Run tests with race detection run: go test -race ./... diff --git a/.gitignore b/.gitignore index 0a7f7eaeb..a13deabe8 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,7 @@ llamacpp/install vllm-metal-macos-arm64-*.tar.gz .DS_Store + +# Cargo workspace build output +/target/ +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000..31e89068b --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,10 @@ +[workspace] +members = [ + "dmr-common", + "model-cli", + "router", +] +resolver = "2" + +[profile.release] +lto = true diff --git a/Dockerfile b/Dockerfile index c4f35c57e..78cffc603 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,15 +15,17 @@ FROM docker.io/library/golang:${GO_VERSION}-bookworm AS builder ARG VERSION -# Install git for go mod download if needed -RUN apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/* +# Install git and the Rust toolchain (needed to build libdmr_router.a via CGo). +RUN apt-get update && apt-get install -y --no-install-recommends git curl && rm -rf /var/lib/apt/lists/* +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal --default-toolchain stable +ENV PATH="/root/.cargo/bin:${PATH}" WORKDIR /app # Copy go mod/sum first for better caching COPY --link go.mod go.sum ./ -# Download dependencies (with cache mounts) +# Download Go dependencies (with cache mounts) RUN --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ go mod download @@ -31,9 +33,13 @@ RUN --mount=type=cache,target=/go/pkg/mod \ # Copy the rest of the source code COPY --link . . -# Build the Go binary (static build) -RUN --mount=type=cache,target=/go/pkg/mod \ +# Build the Rust dmr-router static library, then the Go binary. +# Both steps share a single RUN so the Rust output is available to the linker. 
+RUN --mount=type=cache,target=/root/.cargo/registry \ + --mount=type=cache,target=/root/.cargo/git \ --mount=type=cache,target=/root/.cache/go-build \ + --mount=type=cache,target=/go/pkg/mod \ + cargo build --release --manifest-path router/Cargo.toml && \ CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w -X main.Version=${VERSION}" -o model-runner . # Build the Go binary for SGLang (without vLLM) diff --git a/Makefile b/Makefile index 6bd8cf0f2..ad04eb7d1 100644 --- a/Makefile +++ b/Makefile @@ -23,19 +23,24 @@ DOCKER_BUILD_ARGS := \ -t $(DOCKER_IMAGE) # Phony targets grouped by category -.PHONY: build build-cli build-dmr build-llamacpp install-cli run clean test integration-tests e2e +.PHONY: build build-cli build-dmr build-llamacpp build-router-lib install-cli run clean test integration-tests e2e .PHONY: validate validate-all lint help .PHONY: docker-build docker-build-multiplatform docker-run docker-run-impl .PHONY: docker-build-vllm docker-run-vllm docker-build-sglang docker-run-sglang .PHONY: test-docker-ce-installation .PHONY: vllm-metal-build vllm-metal-install vllm-metal-dev vllm-metal-clean .PHONY: diffusers-build diffusers-install diffusers-dev diffusers-clean + # Default target: build server, CLI plugin, and dmr convenience wrapper .DEFAULT_GOAL := build build: build-server build-cli build-dmr -build-server: +# Build the Rust dmr-router static library. +build-router-lib: + cargo build --release --manifest-path router/Cargo.toml + +build-server: build-router-lib CGO_ENABLED=1 go build -ldflags="-s -w -X main.Version=$(shell git describe --tags --always --dirty --match 'v*')" -o $(APP_NAME) . build-cli: @@ -68,6 +73,7 @@ clean: rm -f $(APP_NAME) rm -f dmr rm -f model-runner.sock + cargo clean --manifest-path router/Cargo.toml # Run tests test: diff --git a/cmd/cli/Makefile b/cmd/cli/Makefile index 9737c2d82..d3cfb1c3b 100644 --- a/cmd/cli/Makefile +++ b/cmd/cli/Makefile @@ -20,7 +20,7 @@ build-gateway: @echo "Building gateway static library (Rust)..." 
@mkdir -p $(GATEWAY_LIB_DIR) cargo build --release --manifest-path $(GATEWAY_RUST_DIR)/Cargo.toml - @cp $(GATEWAY_RUST_DIR)/target/release/libmodel_cli_gateway.a $(GATEWAY_LIB_DIR)/libgateway.a + @cp ../../target/release/libmodel_cli_gateway.a $(GATEWAY_LIB_DIR)/libgateway.a @echo "Gateway library staged at $(GATEWAY_LIB_DIR)/libgateway.a" build: build-gateway diff --git a/dmr-common/Cargo.toml b/dmr-common/Cargo.toml new file mode 100644 index 000000000..4351e5723 --- /dev/null +++ b/dmr-common/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "dmr-common" +version = "0.1.0" +edition = "2021" +description = "Shared utilities for Docker Model Runner Rust crates" + +[dependencies] +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/dmr-common/src/lib.rs b/dmr-common/src/lib.rs new file mode 100644 index 000000000..7ea81e263 --- /dev/null +++ b/dmr-common/src/lib.rs @@ -0,0 +1,30 @@ +//! Shared utilities for Docker Model Runner Rust crates. + +/// Return the current time as seconds since the Unix epoch. +/// +/// Used when constructing OpenAI-format response objects that require a +/// `created` timestamp. Returns 0 on the (extremely unlikely) event that +/// the system clock predates 1970. +pub fn unix_now_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +/// Initialise the global tracing subscriber. +/// +/// Reads `RUST_LOG` from the environment; falls back to `fallback` if unset +/// or invalid. Silently ignores subsequent calls (e.g. when called from both +/// a library entry point and a binary entry point in the same process). +/// +/// # Arguments +/// * `fallback` – default filter string, e.g. `"info"` or `"myapp=debug"`. 
+pub fn init_tracing(fallback: &str) { + let _ = tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(fallback)), + ) + .try_init(); +} diff --git a/main.go b/main.go index f54dc88f7..5c496736f 100644 --- a/main.go +++ b/main.go @@ -3,10 +3,8 @@ package main import ( "context" "crypto/tls" - "encoding/json" "fmt" "log/slog" - "net" "net/http" "os" "os/signal" @@ -15,6 +13,7 @@ import ( "syscall" "time" + "github.com/docker/model-runner/pkg/anthropic" "github.com/docker/model-runner/pkg/envconfig" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" @@ -24,6 +23,9 @@ import ( "github.com/docker/model-runner/pkg/logging" dmrlogs "github.com/docker/model-runner/pkg/logs" "github.com/docker/model-runner/pkg/metrics" + "github.com/docker/model-runner/pkg/ollama" + "github.com/docker/model-runner/pkg/responses" + "github.com/docker/model-runner/pkg/router" "github.com/docker/model-runner/pkg/routing" modeltls "github.com/docker/model-runner/pkg/tls" ) @@ -42,6 +44,7 @@ func main() { ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() + // sockName is the public-facing socket the Rust router listens on. 
sockName := envconfig.SocketPath() modelPath, err := envconfig.ModelsPath() if err != nil { @@ -145,101 +148,99 @@ func main() { "", false, ), - AllowedOrigins: envconfig.AllowedOrigins(), - IncludeResponsesAPI: true, - ExtraRoutes: func(r *routing.NormalizedServeMux, s *routing.Service) { - // Root handler – only catches exact "/" requests - r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { - if req.URL.Path != "/" { - http.NotFound(w, req) - return - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("Docker Model Runner is running")) - }) - - // Version endpoint - r.HandleFunc("/version", func(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]string{"version": Version}); err != nil { - log.Warn("failed to write version response", "error", err) - } - }) - - // Logs endpoint (Docker Desktop mode only). - if logDir := envconfig.LogDir(); logDir != "" { - r.HandleFunc( - "GET /logs", - dmrlogs.NewHTTPHandler(logDir), - ) - log.Info("Logs endpoint enabled at /logs", "dir", logDir) - } - - // Metrics endpoint - if !envconfig.DisableMetrics() { - metricsHandler := metrics.NewAggregatedMetricsHandler( - log.With("component", "metrics"), - s.SchedulerHTTP, - ) - r.Handle("/metrics", metricsHandler) - log.Info("Metrics endpoint enabled at /metrics") - } else { - log.Info("Metrics endpoint disabled") - } - }, + AllowedOrigins: envconfig.AllowedOrigins(), }) if err != nil { log.Error("failed to initialize service", "error", err) exitFunc(1) } - server := &http.Server{ - Handler: svc.Router, - ReadHeaderTimeout: 10 * time.Second, + // Build the backend HTTP mux. Routing (path aliasing, CORS, path + // normalisation, /version, /) is handled by the in-process Rust dmr-router + // that sits in front of this mux. We only need to register the + // inference endpoints and the observability endpoints that the Rust router + // forwards to it. 
+ mux := http.NewServeMux() + mux.Handle(inference.InferencePrefix+"/", svc.SchedulerHTTP) + mux.Handle(inference.ModelsPrefix+"/", svc.ModelHandler) + mux.Handle(inference.ModelsPrefix, svc.ModelHandler) + + // Ollama API compatibility layer (/api/). + ollamaHandler := ollama.NewHTTPHandler(log, svc.Scheduler, svc.SchedulerHTTP, envconfig.AllowedOrigins(), svc.ModelManager) + mux.Handle(ollama.APIPrefix+"/", ollamaHandler) + + // Anthropic Messages API compatibility layer (/anthropic/). + anthropicHandler := anthropic.NewHandler(log, svc.SchedulerHTTP, envconfig.AllowedOrigins(), svc.ModelManager) + mux.Handle(anthropic.APIPrefix+"/", anthropicHandler) + + // OpenAI Responses API compatibility layer (/responses, /v1/responses, /engines/responses). + responsesHandler := responses.NewHTTPHandler(log, svc.SchedulerHTTP, envconfig.AllowedOrigins()) + mux.Handle(responses.APIPrefix+"/", responsesHandler) + mux.Handle(responses.APIPrefix, responsesHandler) + mux.Handle("/v1"+responses.APIPrefix+"/", responsesHandler) + mux.Handle("/v1"+responses.APIPrefix, responsesHandler) + mux.Handle(inference.InferencePrefix+responses.APIPrefix+"/", responsesHandler) + mux.Handle(inference.InferencePrefix+responses.APIPrefix, responsesHandler) + + // Logs endpoint (Docker Desktop mode only). + if logDir := envconfig.LogDir(); logDir != "" { + mux.HandleFunc("GET /logs", dmrlogs.NewHTTPHandler(logDir)) + log.Info("Logs endpoint enabled at /logs", "dir", logDir) } - serverErrors := make(chan error, 1) - // TLS server (optional) + // Metrics endpoint. 
+ if !envconfig.DisableMetrics() { + metricsHandler := metrics.NewAggregatedMetricsHandler( + log.With("component", "metrics"), + svc.SchedulerHTTP, + ) + mux.Handle("/metrics", metricsHandler) + log.Info("Metrics endpoint enabled at /metrics") + } else { + log.Info("Metrics endpoint disabled") + } + + // ── Register the Go mux as the in-process Rust router backend ──────────── + // The Rust router calls Go's http.Handler directly via CGo for every + // inference request — no second socket needed. The streaming writer in + // handler.go pushes chunks to Rust via dmr_write_chunk() as they are + // written, so streaming endpoints like POST /models/create work correctly. + handlerFn, handlerCtx := router.RegisterHandler(mux) + + // routerCfg is populated below depending on TCP vs Unix socket mode. + routerCfg := router.Config{ + HandlerFn: handlerFn, + HandlerCtx: handlerCtx, + AllowedOrigins: envconfig.AllowedOrigins(), + Version: Version, + } + + serverErrors := make(chan error, 1) // never fires in in-process mode + + // TLS server (optional) — serves the mux directly on a TCP port. 
var tlsServer *http.Server tlsServerErrors := make(chan error, 1) - // Check if we should use TCP port instead of Unix socket tcpPort := envconfig.TCPPort() if tcpPort != "" { - // Use TCP port - addr := ":" + tcpPort - log.Info("Listening on TCP port", "port", tcpPort) - server.Addr = addr - go func() { - serverErrors <- server.ListenAndServe() - }() - } else { - // Use Unix socket - if err := os.Remove(sockName); err != nil { - if !os.IsNotExist(err) { - log.Error("Failed to remove existing socket", "error", err) - exitFunc(1) - } - } - ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockName, Net: "unix"}) + backendPort, err := parsePort(tcpPort) if err != nil { - log.Error("Failed to listen on socket", "error", err) + log.Error("Invalid TCP_PORT", "error", err) exitFunc(1) } - go func() { - serverErrors <- server.Serve(ln) - }() + routerCfg.ListenPort = uint16(backendPort) + log.Info("Rust router listening on TCP port", "port", backendPort) + } else { + routerCfg.ListenSock = sockName + log.Info("Rust router listening on Unix socket", "path", sockName) } - // Start TLS server if enabled + // ── TLS server (optional) ───────────────────────────────────────────────── if envconfig.TLSEnabled() { tlsPort := envconfig.TLSPort() - - // Get certificate paths certPath := envconfig.TLSCert() keyPath := envconfig.TLSKey() - // Auto-generate certificates if not provided and auto-cert is not disabled if certPath == "" || keyPath == "" { if envconfig.TLSAutoCert(true) { log.Info("Auto-generating TLS certificates...") @@ -257,7 +258,6 @@ func main() { } } - // Load TLS configuration tlsConfig, err := modeltls.LoadTLSConfig(certPath, keyPath) if err != nil { log.Error("Failed to load TLS configuration", "error", err) @@ -266,14 +266,13 @@ func main() { tlsServer = &http.Server{ Addr: ":" + tlsPort, - Handler: svc.Router, + Handler: mux, TLSConfig: tlsConfig, ReadHeaderTimeout: 10 * time.Second, } log.Info("Listening on TLS port", "port", tlsPort) go func() { - // Use 
ListenAndServeTLS with empty strings since TLSConfig already has the certs ln, err := tls.Listen("tcp", tlsServer.Addr, tlsConfig) if err != nil { tlsServerErrors <- err @@ -283,6 +282,13 @@ func main() { }() } + // ── Rust router ─────────────────────────────────────────────────────────── + // router.Start launches the axum server in a background goroutine and + // returns a StopFunc for graceful shutdown plus an error channel. + stopRouter, routerErrors := router.Start(routerCfg) + log.Info("Rust router started", "listen", routerCfg.ListenSock) + + // ── Scheduler ───────────────────────────────────────────────────────────── schedulerErrors := make(chan error, 1) go func() { schedulerErrors <- svc.Scheduler.Run(ctx) @@ -292,25 +298,30 @@ func main() { if envconfig.TLSEnabled() { tlsServerErrorsChan = tlsServerErrors } else { - // Use a nil channel which will block forever when TLS is disabled tlsServerErrorsChan = nil } select { - case err := <-serverErrors: + case err := <-routerErrors: if err != nil { - log.Error("Server error", "error", err) + log.Error("Rust router error", "error", err) } case err := <-tlsServerErrorsChan: if err != nil { log.Error("TLS server error", "error", err) } + case err := <-serverErrors: + if err != nil { + log.Error("Backend server error", "error", err) + } case <-ctx.Done(): log.Info("Shutdown signal received") - log.Info("Shutting down the server") - if err := server.Close(); err != nil { - log.Error("Server shutdown error", "error", err) + + log.Info("Stopping Rust router") + if stopRouter != nil { + stopRouter() } + if tlsServer != nil { log.Info("Shutting down the TLS server") if err := tlsServer.Close(); err != nil { @@ -325,6 +336,15 @@ func main() { log.Info("Docker Model Runner stopped") } +// parsePort parses a decimal port string and returns an int. 
+func parsePort(s string) (int, error) { + var p int + if _, err := fmt.Sscanf(s, "%d", &p); err != nil { + return 0, fmt.Errorf("invalid port %q: %w", s, err) + } + return p, nil +} + // createLlamaCppConfigFromEnv creates a LlamaCppConfig from environment variables. // Returns nil config (use defaults) when LLAMA_ARGS is unset, or an error if // the args contain disallowed flags. diff --git a/model-cli/Cargo.toml b/model-cli/Cargo.toml index da2eb736a..63d45a6f5 100644 --- a/model-cli/Cargo.toml +++ b/model-cli/Cargo.toml @@ -16,6 +16,7 @@ path = "src/lib.rs" crate-type = ["staticlib"] [dependencies] +dmr-common = { path = "../dmr-common" } axum = { version = "0.8", features = ["macros"] } clap = { version = "4", features = ["derive"] } futures = "0.3" @@ -28,12 +29,8 @@ subtle = "2" tokio-stream = "0.1" tower-http = { version = "0.6", features = ["cors", "trace"] } tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } bytes = "1" http = "1" async-stream = "0.3" pin-project-lite = "0.2" async-trait = "0.1" - -[profile.release] -lto = true diff --git a/model-cli/src/handlers.rs b/model-cli/src/handlers.rs index 95b49674a..df8a196ba 100644 --- a/model-cli/src/handlers.rs +++ b/model-cli/src/handlers.rs @@ -1,10 +1,16 @@ +use std::net::SocketAddr; +use std::path::PathBuf; use std::sync::Arc; use axum::body::Body; use axum::extract::State; use axum::http::header; use axum::response::{IntoResponse, Json, Response}; +use axum::routing::{get, post}; +use axum::Router as AxumRouter; use tokio_stream::StreamExt; +use tower_http::cors::CorsLayer; +use tower_http::trace::TraceLayer; use crate::error::AppError; use crate::router::Router; @@ -16,6 +22,113 @@ pub struct AppState { pub master_key: Option, } +/// Build the axum application with all routes, CORS, and tracing layers. +/// +/// This single definition is shared by both the standalone binary (`main.rs`) +/// and the CGo static-library entry point (`lib.rs`). 
+pub fn build_app(state: Arc) -> AxumRouter { + let auth_layer = axum::middleware::from_fn_with_state( + state.master_key.clone(), + crate::auth::auth_middleware, + ); + + let protected_routes = AxumRouter::new() + .route("/v1/chat/completions", post(chat_completion_handler)) + .route("/chat/completions", post(chat_completion_handler)) + .route("/v1/embeddings", post(embeddings_handler)) + .route("/embeddings", post(embeddings_handler)) + .route("/v1/models", get(list_models_handler)) + .route("/models", get(list_models_handler)) + .layer(auth_layer); + + let public_routes = AxumRouter::new() + .route("/health", get(health_handler)) + .route("/", get(health_handler)); + + public_routes + .merge(protected_routes) + .layer( + CorsLayer::new() + .allow_origin(tower_http::cors::Any) + .allow_methods([ + axum::http::Method::GET, + axum::http::Method::POST, + axum::http::Method::OPTIONS, + ]) + .allow_headers([ + axum::http::header::CONTENT_TYPE, + axum::http::header::AUTHORIZATION, + axum::http::HeaderName::from_static("x-api-key"), + ]), + ) + .layer(TraceLayer::new_for_http()) + .with_state(state) +} + +/// Core async gateway logic shared between the binary and the CGo library. +/// +/// Loads config, builds the router and app, and serves until the process exits. 
+pub async fn run_gateway_async(config: PathBuf, host: String, port: u16, verbose: bool) { + let log_filter = if verbose { + "model_cli=debug,tower_http=debug" + } else { + "model_cli=info,tower_http=info" + }; + + dmr_common::init_tracing(log_filter); + + tracing::info!("Loading config from: {}", config.display()); + let cfg = match crate::config::Config::load(&config) { + Ok(c) => c, + Err(e) => { + tracing::error!("Failed to load config: {}", e); + std::process::exit(1); + } + }; + + let model_count = cfg.model_list.len(); + let model_names: Vec<&str> = cfg.model_list.iter().map(|m| m.model_name.as_str()).collect(); + tracing::info!("Loaded {} model deployment(s): {:?}", model_count, model_names); + + let master_key = cfg.general_settings.master_key.clone(); + if master_key.is_some() { + tracing::info!("Authentication enabled (master_key is set)"); + } else { + tracing::warn!("No master_key configured — API is open to all requests"); + } + + let llm_router = match crate::router::Router::from_config(&cfg) { + Ok(r) => r, + Err(e) => { + tracing::error!("Failed to build router: {}", e); + std::process::exit(1); + } + }; + + let state = Arc::new(AppState { + router: llm_router, + master_key: master_key.clone(), + }); + + let app = build_app(state); + + let addr: SocketAddr = format!("{}:{}", host, port) + .parse() + .expect("Invalid host:port"); + + tracing::info!( + "model-cli gateway v{} listening on {}", + env!("CARGO_PKG_VERSION"), + addr + ); + tracing::info!(" Chat completions: http://{}/v1/chat/completions", addr); + tracing::info!(" Models: http://{}/v1/models", addr); + tracing::info!(" Health: http://{}/health", addr); + + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + axum::serve(listener, app).await.unwrap(); +} + // ── Health check ── pub async fn health_handler(State(state): State>) -> Json { diff --git a/model-cli/src/lib.rs b/model-cli/src/lib.rs index f2723c4c4..06d292df6 100644 --- a/model-cli/src/lib.rs +++ 
b/model-cli/src/lib.rs @@ -10,14 +10,6 @@ use std::ffi::CStr; use std::os::raw::{c_char, c_int}; use std::path::PathBuf; -use axum::routing::{get, post}; -use axum::Router as AxumRouter; -use tower_http::cors::CorsLayer; -use tower_http::trace::TraceLayer; -use tracing_subscriber::EnvFilter; - -use handlers::AppState; - /// C-callable entry point invoked by the Go CLI's `gateway` subcommand. /// /// # Safety @@ -103,119 +95,6 @@ pub unsafe extern "C" fn run_gateway(argc: c_int, argv: *const *const c_char) -> } }; - rt.block_on(async_run_gateway(config, host, port, verbose)); + rt.block_on(handlers::run_gateway_async(config, host, port, verbose)); 0 } - -async fn async_run_gateway(config: PathBuf, host: String, port: u16, verbose: bool) { - use std::net::SocketAddr; - use std::sync::Arc; - - let log_filter = if verbose { - "model_cli=debug,tower_http=debug" - } else { - "model_cli=info,tower_http=info" - }; - - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(log_filter)), - ) - .init(); - - tracing::info!("Loading config from: {}", config.display()); - let cfg = match config::Config::load(&config) { - Ok(c) => c, - Err(e) => { - tracing::error!("Failed to load config: {}", e); - std::process::exit(1); - } - }; - - let model_count = cfg.model_list.len(); - let model_names: Vec<&str> = cfg.model_list.iter().map(|m| m.model_name.as_str()).collect(); - tracing::info!("Loaded {} model deployment(s): {:?}", model_count, model_names); - - let master_key = cfg.general_settings.master_key.clone(); - if master_key.is_some() { - tracing::info!("Authentication enabled (master_key is set)"); - } else { - tracing::warn!("No master_key configured — API is open to all requests"); - } - - let llm_router = match router::Router::from_config(&cfg) { - Ok(r) => r, - Err(e) => { - tracing::error!("Failed to build router: {}", e); - std::process::exit(1); - } - }; - - let state = Arc::new(AppState { - router: llm_router, 
- master_key: master_key.clone(), - }); - - let app = build_app(state); - - let addr: SocketAddr = format!("{}:{}", host, port) - .parse() - .expect("Invalid host:port"); - - tracing::info!( - "model-cli gateway v{} listening on {}", - env!("CARGO_PKG_VERSION"), - addr - ); - tracing::info!(" Chat completions: http://{}/v1/chat/completions", addr); - tracing::info!(" Models: http://{}/v1/models", addr); - tracing::info!(" Health: http://{}/health", addr); - - let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); - axum::serve(listener, app).await.unwrap(); -} - -fn build_app(state: std::sync::Arc) -> AxumRouter { - let auth_layer = axum::middleware::from_fn_with_state( - state.master_key.clone(), - auth::auth_middleware, - ); - - let protected_routes = AxumRouter::new() - .route( - "/v1/chat/completions", - post(handlers::chat_completion_handler), - ) - .route( - "/chat/completions", - post(handlers::chat_completion_handler), - ) - .route("/v1/embeddings", post(handlers::embeddings_handler)) - .route("/embeddings", post(handlers::embeddings_handler)) - .route("/v1/models", get(handlers::list_models_handler)) - .route("/models", get(handlers::list_models_handler)) - .layer(auth_layer); - - let public_routes = AxumRouter::new() - .route("/health", get(handlers::health_handler)) - .route("/", get(handlers::health_handler)); - - public_routes - .merge(protected_routes) - .layer( - CorsLayer::new() - .allow_origin(tower_http::cors::Any) - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST, - axum::http::Method::OPTIONS, - ]) - .allow_headers([ - axum::http::header::CONTENT_TYPE, - axum::http::header::AUTHORIZATION, - axum::http::HeaderName::from_static("x-api-key"), - ]), - ) - .layer(TraceLayer::new_for_http()) - .with_state(state) -} diff --git a/model-cli/src/main.rs b/model-cli/src/main.rs index 170e17c1c..e1ba0b341 100644 --- a/model-cli/src/main.rs +++ b/model-cli/src/main.rs @@ -6,18 +6,9 @@ mod providers; mod router; mod types; 
-use std::net::SocketAddr; use std::path::PathBuf; -use std::sync::Arc; -use axum::routing::{get, post}; -use axum::Router as AxumRouter; use clap::{Parser, Subcommand}; -use tower_http::cors::CorsLayer; -use tower_http::trace::TraceLayer; -use tracing_subscriber::EnvFilter; - -use handlers::AppState; /// model-cli: CLI tool for Docker Model Runner and compatible LLM providers. #[derive(Parser, Debug)] @@ -65,111 +56,5 @@ async fn main() { } async fn run_gateway(args: GatewayArgs) { - let log_filter = if args.verbose { - "model_cli=debug,tower_http=debug" - } else { - "model_cli=info,tower_http=info" - }; - - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(log_filter)), - ) - .init(); - - tracing::info!("Loading config from: {}", args.config.display()); - let cfg = match config::Config::load(&args.config) { - Ok(c) => c, - Err(e) => { - tracing::error!("Failed to load config: {}", e); - std::process::exit(1); - } - }; - - let model_count = cfg.model_list.len(); - let model_names: Vec<&str> = cfg.model_list.iter().map(|m| m.model_name.as_str()).collect(); - tracing::info!("Loaded {} model deployment(s): {:?}", model_count, model_names); - - let master_key = cfg.general_settings.master_key.clone(); - if master_key.is_some() { - tracing::info!("Authentication enabled (master_key is set)"); - } else { - tracing::warn!("No master_key configured — API is open to all requests"); - } - - let llm_router = match router::Router::from_config(&cfg) { - Ok(r) => r, - Err(e) => { - tracing::error!("Failed to build router: {}", e); - std::process::exit(1); - } - }; - - let state = Arc::new(AppState { - router: llm_router, - master_key: master_key.clone(), - }); - - let app = build_app(state); - - let addr: SocketAddr = format!("{}:{}", args.host, args.port) - .parse() - .expect("Invalid host:port"); - - tracing::info!( - "model-cli gateway v{} listening on {}", - env!("CARGO_PKG_VERSION"), - addr - ); - 
tracing::info!(" Chat completions: http://{}/v1/chat/completions", addr); - tracing::info!(" Models: http://{}/v1/models", addr); - tracing::info!(" Health: http://{}/health", addr); - - let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); - axum::serve(listener, app).await.unwrap(); -} - -fn build_app(state: Arc) -> AxumRouter { - let auth_layer = axum::middleware::from_fn_with_state( - state.master_key.clone(), - auth::auth_middleware, - ); - - let protected_routes = AxumRouter::new() - .route( - "/v1/chat/completions", - post(handlers::chat_completion_handler), - ) - .route( - "/chat/completions", - post(handlers::chat_completion_handler), - ) - .route("/v1/embeddings", post(handlers::embeddings_handler)) - .route("/embeddings", post(handlers::embeddings_handler)) - .route("/v1/models", get(handlers::list_models_handler)) - .route("/models", get(handlers::list_models_handler)) - .layer(auth_layer); - - let public_routes = AxumRouter::new() - .route("/health", get(handlers::health_handler)) - .route("/", get(handlers::health_handler)); - - public_routes - .merge(protected_routes) - .layer( - CorsLayer::new() - .allow_origin(tower_http::cors::Any) - .allow_methods([ - axum::http::Method::GET, - axum::http::Method::POST, - axum::http::Method::OPTIONS, - ]) - .allow_headers([ - axum::http::header::CONTENT_TYPE, - axum::http::header::AUTHORIZATION, - axum::http::HeaderName::from_static("x-api-key"), - ]), - ) - .layer(TraceLayer::new_for_http()) - .with_state(state) + handlers::run_gateway_async(args.config, args.host, args.port, args.verbose).await; } diff --git a/model-cli/src/providers/anthropic.rs b/model-cli/src/providers/anthropic.rs index e15e7c72c..6b1e6ceed 100644 --- a/model-cli/src/providers/anthropic.rs +++ b/model-cli/src/providers/anthropic.rs @@ -1,4 +1,3 @@ -use axum::http::StatusCode; use bytes::Bytes; use futures::StreamExt; use reqwest::Client; @@ -11,7 +10,7 @@ use crate::types::{ ChatMessage, ChunkChoice, EmbeddingRequest, 
EmbeddingResponse, Usage, }; -use super::{parse_upstream_error, resolve_api_key, ByteStream, Provider}; +use super::{request_timeout, resolve_api_key, send_and_check, ByteStream, Provider}; const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1"; const ANTHROPIC_VERSION: &str = "2023-06-01"; @@ -114,6 +113,31 @@ struct AnthropicStreamMessage { model: String, } +/// Build an `AnthropicRequest` from an OpenAI `ChatCompletionRequest`. +/// +/// `stream` controls whether the upstream request asks for SSE streaming. +fn build_anthropic_request( + request: &ChatCompletionRequest, + actual_model: &str, + stream: bool, +) -> AnthropicRequest { + let (system, messages) = convert_messages(&request.messages); + let stop_sequences = request.stop.as_ref().map(|s| match s { + crate::types::StopSequence::Single(s) => vec![s.clone()], + crate::types::StopSequence::Multiple(v) => v.clone(), + }); + AnthropicRequest { + model: actual_model.to_string(), + messages, + max_tokens: request.max_tokens.unwrap_or(4096), + system, + temperature: request.temperature, + top_p: request.top_p, + stop_sequences, + stream: Some(stream), + } +} + /// Convert OpenAI messages to Anthropic format. /// /// Returns (system_prompt, user/assistant messages). 
@@ -183,51 +207,21 @@ impl Provider for AnthropicProvider { let url = format!("{}/messages", base_url); let (_, actual_model) = crate::config::parse_provider_model(&params.model); - let (system, messages) = convert_messages(&request.messages); - - let stop_sequences = request.stop.as_ref().map(|s| match s { - crate::types::StopSequence::Single(s) => vec![s.clone()], - crate::types::StopSequence::Multiple(v) => v.clone(), - }); - - let anthropic_req = AnthropicRequest { - model: actual_model.to_string(), - messages, - max_tokens: request.max_tokens.unwrap_or(4096), - system, - temperature: request.temperature, - top_p: request.top_p, - stop_sequences, - stream: Some(false), - }; + let anthropic_req = build_anthropic_request(request, actual_model, false); let api_key = resolve_api_key("anthropic", params) .ok_or_else(|| AppError::Unauthorized("Missing Anthropic API key".to_string()))?; - let timeout_secs = params.timeout.unwrap_or(600.0); - let response = self + let req = self .client .post(&url) .header("x-api-key", &api_key) .header("anthropic-version", ANTHROPIC_VERSION) .header("content-type", "application/json") .json(&anthropic_req) - .timeout(std::time::Duration::from_secs_f64(timeout_secs)) - .send() - .await - .map_err(|e| AppError::ProviderError { - status: StatusCode::BAD_GATEWAY, - message: format!("Failed to reach Anthropic: {}", e), - })?; - - let status = response.status(); - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - return Err(parse_upstream_error( - StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY), - &body, - )); - } + .timeout(request_timeout(params)); + + let response = send_and_check(req, "Anthropic").await?; let anthropic_resp: AnthropicResponse = response.json().await.map_err(|e| { AppError::Internal(format!("Failed to parse Anthropic response: {}", e)) @@ -243,10 +237,7 @@ impl Provider for AnthropicProvider { .collect::>() .join(""); - let created = std::time::SystemTime::now() -
.duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); + let created = dmr_common::unix_now_secs(); Ok(ChatCompletionResponse { id: format!("chatcmpl-{}", anthropic_resp.id), @@ -282,60 +273,27 @@ impl Provider for AnthropicProvider { let url = format!("{}/messages", base_url); let (_, actual_model) = crate::config::parse_provider_model(&params.model); - let (system, messages) = convert_messages(&request.messages); - - let stop_sequences = request.stop.as_ref().map(|s| match s { - crate::types::StopSequence::Single(s) => vec![s.clone()], - crate::types::StopSequence::Multiple(v) => v.clone(), - }); - - let anthropic_req = AnthropicRequest { - model: actual_model.to_string(), - messages, - max_tokens: request.max_tokens.unwrap_or(4096), - system, - temperature: request.temperature, - top_p: request.top_p, - stop_sequences, - stream: Some(true), - }; + let anthropic_req = build_anthropic_request(request, actual_model, true); let api_key = resolve_api_key("anthropic", params) .ok_or_else(|| AppError::Unauthorized("Missing Anthropic API key".to_string()))?; - let timeout_secs = params.timeout.unwrap_or(600.0); - let response = self + let req = self .client .post(&url) .header("x-api-key", &api_key) .header("anthropic-version", ANTHROPIC_VERSION) .header("content-type", "application/json") .json(&anthropic_req) - .timeout(std::time::Duration::from_secs_f64(timeout_secs)) - .send() - .await - .map_err(|e| AppError::ProviderError { - status: StatusCode::BAD_GATEWAY, - message: format!("Failed to reach Anthropic: {}", e), - })?; - - let status = response.status(); - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - return Err(parse_upstream_error( - StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY), - &body, - )); - } + .timeout(request_timeout(params)); + + let response = send_and_check(req, "Anthropic").await?; // Translate Anthropic SSE events into OpenAI-compatible SSE events.
let model = actual_model.to_string(); let byte_stream = response.bytes_stream(); - let created = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); + let created = dmr_common::unix_now_secs(); let mut buffer = String::new(); let stream = async_stream::stream! { diff --git a/model-cli/src/providers/mod.rs b/model-cli/src/providers/mod.rs index fee014ac7..6912129a3 100644 --- a/model-cli/src/providers/mod.rs +++ b/model-cli/src/providers/mod.rs @@ -10,6 +10,34 @@ use crate::config::ModelParams; use crate::error::AppError; use crate::types::{ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse}; +/// Send `req`, check the HTTP status, and surface upstream errors as `AppError`. +/// +/// `provider_label` appears in the network-failure message (e.g. `"provider"` +/// or `"Anthropic"`). Returns the raw `reqwest::Response` on success. +/// +/// Both `OpenAIProvider` and `AnthropicProvider` previously contained +/// identical copies of this logic as inherent methods. +pub async fn send_and_check( + req: reqwest::RequestBuilder, + provider_label: &str, +) -> Result { + let response = req.send().await.map_err(|e| AppError::ProviderError { + status: StatusCode::BAD_GATEWAY, + message: format!("Failed to reach {provider_label}: {e}"), + })?; + + let status = response.status(); + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + return Err(parse_upstream_error( + StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY), + &body, + )); + } + + Ok(response) +} + /// A boxed stream of bytes (SSE chunks) or errors. pub type ByteStream = Pin> + Send>>; @@ -150,6 +178,11 @@ pub fn resolve_api_key(provider_name: &str, params: &ModelParams) -> Option std::time::Duration { + std::time::Duration::from_secs_f64(params.timeout.unwrap_or(600.0)) +} + /// Parse upstream provider error responses into AppError. 
pub fn parse_upstream_error(status: StatusCode, body: &str) -> AppError { if let Ok(json) = serde_json::from_str::(body) { diff --git a/model-cli/src/providers/openai.rs b/model-cli/src/providers/openai.rs index 5e0e248a8..f3a5205bf 100644 --- a/model-cli/src/providers/openai.rs +++ b/model-cli/src/providers/openai.rs @@ -1,4 +1,3 @@ -use axum::http::StatusCode; use futures::StreamExt; use reqwest::Client; @@ -8,7 +7,7 @@ use crate::types::{ ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse, }; -use super::{build_api_url, parse_upstream_error, resolve_api_key, ByteStream, Provider}; +use super::{build_api_url, request_timeout, resolve_api_key, send_and_check, ByteStream, Provider}; /// OpenAI-compatible provider. /// @@ -25,11 +24,6 @@ impl OpenAIProvider { client: Client::new(), } } - - fn provider_name_for_model<'a>(&self, params: &'a ModelParams) -> &'a str { - let (provider, _) = crate::config::parse_provider_model(&params.model); - provider - } } #[async_trait::async_trait] @@ -39,49 +33,24 @@ impl Provider for OpenAIProvider { request: &ChatCompletionRequest, params: &ModelParams, ) -> Result { - let provider_name = self.provider_name_for_model(params); + let (provider_name, actual_model) = crate::config::parse_provider_model(&params.model); let url = build_api_url(provider_name, params, "/chat/completions"); - // Override model and stream fields before forwarding to the upstream provider.
- let (_, actual_model) = crate::config::parse_provider_model(&params.model); let mut outgoing = request.clone(); outgoing.model = actual_model.to_string(); outgoing.stream = Some(false); let mut req = self.client.post(&url).json(&outgoing); - if let Some(api_key) = resolve_api_key(provider_name, params) { req = req.bearer_auth(&api_key); } + req = apply_azure_version(req, provider_name, params); + req = req.timeout(request_timeout(params)); - // Azure-specific query parameter - if provider_name == "azure" || provider_name == "azure_ai" { - if let Some(ref version) = params.api_version { - req = req.query(&[("api-version", version.as_str())]); - } - } - - let timeout_secs = params.timeout.unwrap_or(600.0); - req = req.timeout(std::time::Duration::from_secs_f64(timeout_secs)); - - let response = req.send().await.map_err(|e| AppError::ProviderError { - status: StatusCode::BAD_GATEWAY, - message: format!("Failed to reach provider: {}", e), - })?; - - let status = response.status(); - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - return Err(parse_upstream_error( - StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY), - &body, - )); - } - + let response = send_and_check(req, "provider").await?; let resp: ChatCompletionResponse = response.json().await.map_err(|e| { AppError::Internal(format!("Failed to parse provider response: {}", e)) })?; - Ok(resp) } @@ -90,49 +59,24 @@ impl Provider for OpenAIProvider { request: &ChatCompletionRequest, params: &ModelParams, ) -> Result { - let provider_name = self.provider_name_for_model(params); + let (provider_name, actual_model) = crate::config::parse_provider_model(&params.model); let url = build_api_url(provider_name, params, "/chat/completions"); - let (_, actual_model) = crate::config::parse_provider_model(&params.model); let mut outgoing = request.clone(); outgoing.model = actual_model.to_string(); outgoing.stream = Some(true); let mut req = self.client.post(&url).json(&outgoing);
- if let Some(api_key) = resolve_api_key(provider_name, params) { req = req.bearer_auth(&api_key); } + req = apply_azure_version(req, provider_name, params); + req = req.timeout(request_timeout(params)); - if provider_name == "azure" || provider_name == "azure_ai" { - if let Some(ref version) = params.api_version { - req = req.query(&[("api-version", version.as_str())]); - } - } - - let timeout_secs = params.timeout.unwrap_or(600.0); - req = req.timeout(std::time::Duration::from_secs_f64(timeout_secs)); - - let response = req.send().await.map_err(|e| AppError::ProviderError { - status: StatusCode::BAD_GATEWAY, - message: format!("Failed to reach provider: {}", e), - })?; - - let status = response.status(); - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - return Err(parse_upstream_error( - StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY), - &body, - )); - } - - // Stream the SSE response directly — the provider already speaks OpenAI SSE format. 
- let byte_stream = response.bytes_stream(); - let stream = byte_stream.map(|chunk| { + let response = send_and_check(req, "provider").await?; + let stream = response.bytes_stream().map(|chunk| { chunk.map_err(|e| AppError::Internal(format!("Stream error: {}", e))) }); - Ok(Box::pin(stream)) } @@ -141,40 +85,36 @@ impl Provider for OpenAIProvider { request: &EmbeddingRequest, params: &ModelParams, ) -> Result { - let provider_name = self.provider_name_for_model(params); + let (provider_name, actual_model) = crate::config::parse_provider_model(&params.model); let url = build_api_url(provider_name, params, "/embeddings"); - let (_, actual_model) = crate::config::parse_provider_model(&params.model); let mut outgoing = request.clone(); outgoing.model = actual_model.to_string(); let mut req = self.client.post(&url).json(&outgoing); - if let Some(api_key) = resolve_api_key(provider_name, params) { req = req.bearer_auth(&api_key); } + req = req.timeout(request_timeout(params)); - let timeout_secs = params.timeout.unwrap_or(600.0); - req = req.timeout(std::time::Duration::from_secs_f64(timeout_secs)); - - let response = req.send().await.map_err(|e| AppError::ProviderError { - status: StatusCode::BAD_GATEWAY, - message: format!("Failed to reach provider: {}", e), - })?; - - let status = response.status(); - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - return Err(parse_upstream_error( - StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY), - &body, - )); - } - + let response = send_and_check(req, "provider").await?; let resp: EmbeddingResponse = response.json().await.map_err(|e| { AppError::Internal(format!("Failed to parse provider response: {}", e)) })?; - Ok(resp) } } + +/// Append the Azure `api-version` query parameter when the provider is Azure.
+fn apply_azure_version( + mut req: reqwest::RequestBuilder, + provider_name: &str, + params: &ModelParams, +) -> reqwest::RequestBuilder { + if provider_name == "azure" || provider_name == "azure_ai" { + if let Some(ref version) = params.api_version { + req = req.query(&[("api-version", version.as_str())]); + } + } + req +} diff --git a/model-cli/src/router.rs b/model-cli/src/router.rs index 08ed333c4..881a16766 100644 --- a/model-cli/src/router.rs +++ b/model-cli/src/router.rs @@ -84,10 +84,7 @@ impl Router { /// List models in OpenAI format. pub fn list_models(&self) -> ModelListResponse { - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); + let now = dmr_common::unix_now_secs(); let mut models: Vec = self .deployments diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go index 95f300743..77fb8ba1e 100644 --- a/pkg/inference/backend.go +++ b/pkg/inference/backend.go @@ -23,6 +23,9 @@ const ( // BackendModeImageGeneration indicates that the backend should run in // image generation mode. BackendModeImageGeneration + // BackendModeAudio indicates that the backend should run in audio + // processing mode (transcription, translation, speech synthesis). + BackendModeAudio ) // Backend status constants for standardized status reporting. 
@@ -131,6 +134,8 @@ func (m BackendMode) String() string { return "reranking" case BackendModeImageGeneration: return "image-generation" + case BackendModeAudio: + return "audio" default: return "unknown" } @@ -168,6 +173,8 @@ func ParseBackendMode(mode string) (BackendMode, bool) { return BackendModeReranking, true case "image-generation": return BackendModeImageGeneration, true + case "audio": + return BackendModeAudio, true default: return BackendModeCompletion, false } diff --git a/pkg/inference/backends/diffusers/diffusers_config.go b/pkg/inference/backends/diffusers/diffusers_config.go index 010445e65..1061c0a7f 100644 --- a/pkg/inference/backends/diffusers/diffusers_config.go +++ b/pkg/inference/backends/diffusers/diffusers_config.go @@ -40,7 +40,8 @@ func (c *Config) GetArgs(model string, socket string, mode inference.BackendMode switch mode { case inference.BackendModeImageGeneration: // Default mode for diffusers - image generation - case inference.BackendModeCompletion, inference.BackendModeEmbedding, inference.BackendModeReranking: + case inference.BackendModeCompletion, inference.BackendModeEmbedding, inference.BackendModeReranking, + inference.BackendModeAudio: return nil, fmt.Errorf("unsupported backend mode %q for diffusers", mode) } diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go index 87816410c..0791285e8 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config.go @@ -67,6 +67,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference args = append(args, "--embeddings", "--reranking") case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) + case inference.BackendModeAudio: + // llama.cpp does not expose a dedicated audio transcription/translation + // endpoint. 
Use the chat completions endpoint with an input_audio content + // part instead (see OpenAI multimodal message format). + return nil, fmt.Errorf("audio endpoint not supported by llama.cpp; use /v1/chat/completions with an input_audio content part") } if budget := GetReasoningBudget(config); budget != nil { diff --git a/pkg/inference/backends/mlx/mlx_config.go b/pkg/inference/backends/mlx/mlx_config.go index 29f98638f..0a9aac5a5 100644 --- a/pkg/inference/backends/mlx/mlx_config.go +++ b/pkg/inference/backends/mlx/mlx_config.go @@ -51,6 +51,8 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference return nil, fmt.Errorf("reranking mode not supported by MLX backend") case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) + case inference.BackendModeAudio: + return nil, fmt.Errorf("unsupported backend mode %q", mode) } // Add max-tokens if specified in model config or backend config diff --git a/pkg/inference/backends/sglang/sglang_config.go b/pkg/inference/backends/sglang/sglang_config.go index 814a516f2..46fbb6de0 100644 --- a/pkg/inference/backends/sglang/sglang_config.go +++ b/pkg/inference/backends/sglang/sglang_config.go @@ -53,6 +53,8 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference // SGLang does not have a specific flag for reranking case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) + case inference.BackendModeAudio: + return nil, fmt.Errorf("unsupported backend mode %q", mode) } // Add context-length if specified in model config or backend config diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go index 3ad9e230d..c671a58bf 100644 --- a/pkg/inference/backends/vllm/vllm_config.go +++ b/pkg/inference/backends/vllm/vllm_config.go @@ -51,6 +51,9 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference // vLLM does not have a 
specific flag for reranking case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) + case inference.BackendModeAudio: + // vLLM natively supports audio endpoints (/v1/audio/transcriptions, + // /v1/audio/translations, /v1/audio/speech) — no extra flags needed. } // Add max-model-len if specified in model config or backend config diff --git a/pkg/inference/backends/vllm/vllm_metal.go b/pkg/inference/backends/vllm/vllm_metal.go index db1abda95..3aad577e6 100644 --- a/pkg/inference/backends/vllm/vllm_metal.go +++ b/pkg/inference/backends/vllm/vllm_metal.go @@ -259,6 +259,9 @@ func (v *vllmMetal) buildArgs(bundle interface{ SafetensorsPath() string }, sock return nil, fmt.Errorf("reranking mode not supported by vllm-metal backend") case inference.BackendModeImageGeneration: return nil, fmt.Errorf("image generation mode not supported by vllm-metal backend") + case inference.BackendModeAudio: + // vllm-metal natively supports audio endpoints (/v1/audio/transcriptions, + // /v1/audio/translations, /v1/audio/speech) — no extra flags needed. } // Register model aliases so the model-runner can address the model by its diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go index 2dd11df46..73b8130c8 100644 --- a/pkg/inference/scheduling/api.go +++ b/pkg/inference/scheduling/api.go @@ -14,6 +14,11 @@ const ( // DoS attacks. maximumOpenAIInferenceRequestSize = 10 * 1024 * 1024 + // maximumAudioInferenceRequestSize is the maximum size for audio API + // requests (transcriptions, translations). Audio files can be large, so + // we allow up to 25MB which matches the OpenAI API limit. + maximumAudioInferenceRequestSize = 25 * 1024 * 1024 + // modelCLIUserAgentPrefix is the user-agent prefix set by the model CLI. 
modelCLIUserAgentPrefix = "docker-model-cli/" ) @@ -27,6 +32,10 @@ func trimRequestPathToOpenAIRoot(path string) string { return path[index:] } else if index = strings.Index(path, "/score"); index != -1 { return path[index:] + } else if index = strings.Index(path, "/tokenize"); index != -1 { + return path[index:] + } else if index = strings.Index(path, "/detokenize"); index != -1 { + return path[index:] } return path } @@ -47,6 +56,17 @@ func backendModeForRequest(path string) (inference.BackendMode, bool) { } else if strings.HasSuffix(path, "/v1/images/generations") { // OpenAI Images API - image generation mode return inference.BackendModeImageGeneration, true + } else if strings.HasSuffix(path, "/v1/audio/transcriptions") || + strings.HasSuffix(path, "/v1/audio/translations") || + strings.HasSuffix(path, "/v1/audio/speech") { + // OpenAI Audio API - audio processing mode + return inference.BackendModeAudio, true + } else if strings.HasSuffix(path, "/v1/moderations") { + // OpenAI Moderations API - treated as completion mode + return inference.BackendModeCompletion, true + } else if strings.HasSuffix(path, "/tokenize") || strings.HasSuffix(path, "/detokenize") { + // vLLM tokenize/detokenize endpoints - treated as completion mode + return inference.BackendModeCompletion, true } return inference.BackendMode(0), false } diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go index 769bdc7ec..233766cf5 100644 --- a/pkg/inference/scheduling/http_handler.go +++ b/pkg/inference/scheduling/http_handler.go @@ -93,6 +93,14 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { // Image generation routes "POST " + inference.InferencePrefix + "/{backend}/v1/images/generations", "POST " + inference.InferencePrefix + "/v1/images/generations", + // Moderations routes + "POST " + inference.InferencePrefix + "/{backend}/v1/moderations", + "POST " + inference.InferencePrefix + "/v1/moderations", + // Tokenize/detokenize 
routes (vLLM extension) + "POST " + inference.InferencePrefix + "/{backend}/tokenize", + "POST " + inference.InferencePrefix + "/tokenize", + "POST " + inference.InferencePrefix + "/{backend}/detokenize", + "POST " + inference.InferencePrefix + "/detokenize", } // Anthropic Messages API routes @@ -103,10 +111,25 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { "POST " + inference.InferencePrefix + "/v1/messages/count_tokens", } + // Audio routes use multipart/form-data so they require a separate handler + // that extracts the model field from the form rather than the JSON body. + audioRoutes := []string{ + "POST " + inference.InferencePrefix + "/{backend}/v1/audio/transcriptions", + "POST " + inference.InferencePrefix + "/v1/audio/transcriptions", + "POST " + inference.InferencePrefix + "/{backend}/v1/audio/translations", + "POST " + inference.InferencePrefix + "/v1/audio/translations", + // Speech synthesis uses JSON but is still audio mode + "POST " + inference.InferencePrefix + "/{backend}/v1/audio/speech", + "POST " + inference.InferencePrefix + "/v1/audio/speech", + } + m := make(map[string]http.HandlerFunc) for _, route := range append(openAIRoutes, anthropicRoutes...) { m[route] = h.handleOpenAIInference } + for _, route := range audioRoutes { + m[route] = h.handleAudioInference + } // Register /v1/models routes - these delegate to the model manager m["GET "+inference.InferencePrefix+"/{backend}/v1/models"] = h.handleModels @@ -136,18 +159,6 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { // - POST /{backend}/rerank // - POST /{backend}/score func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Request) { - // Determine the requested backend and ensure that it's valid. 
- var backend inference.Backend - if b := r.PathValue("backend"); b == "" { - backend = h.scheduler.defaultBackend - } else { - backend = h.scheduler.backends[b] - } - if backend == nil { - http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound) - return - } - // Read the entire request body. We put some basic size constraints in place // to avoid DoS attacks. We do this early to avoid client write timeouts. body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize) @@ -155,13 +166,6 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque return } - // Determine the backend operation mode. - backendMode, ok := backendModeForRequest(r.URL.Path) - if !ok { - http.Error(w, "unknown request path", http.StatusInternalServerError) - return - } - // Set origin header for Anthropic Messages API requests if not already set. // This enables proper response format detection in the recorder. if strings.HasSuffix(r.URL.Path, "/v1/messages") && r.Header.Get(inference.RequestOriginHeader) == "" { @@ -179,9 +183,100 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque return } + h.scheduleInference(w, r, request.Model, body, true) +} + +// handleAudioInference handles audio API requests such as: +// - POST /{backend}/v1/audio/transcriptions (multipart/form-data) +// - POST /{backend}/v1/audio/translations (multipart/form-data) +// - POST /{backend}/v1/audio/speech (JSON) +// +// Audio transcription and translation use multipart/form-data instead of JSON, +// so the model name is extracted from the form field rather than the JSON body. +func (h *HTTPHandler) handleAudioInference(w http.ResponseWriter, r *http.Request) { + // Extract model name and buffer the full body. + // Audio endpoints may use multipart/form-data (transcriptions, translations) + // or JSON (speech synthesis). 
+ var modelName string + var upstreamBody []byte + + contentType := strings.ToLower(r.Header.Get("Content-Type")) + if strings.HasPrefix(contentType, "multipart/form-data") { + // Read the entire body for buffering before proxying. + body, ok := readRequestBody(w, r, maximumAudioInferenceRequestSize) + if !ok { + return + } + upstreamBody = body + + // Parse a clone of the request to extract the model field. + // We clean up any temp files immediately after reading the field. + r2 := r.Clone(r.Context()) + r2.Body = io.NopCloser(bytes.NewReader(body)) + r2.ContentLength = int64(len(body)) + if err := r2.ParseMultipartForm(maximumAudioInferenceRequestSize); err != nil { + http.Error(w, "failed to parse multipart form", http.StatusBadRequest) + return + } + if r2.MultipartForm != nil { + defer r2.MultipartForm.RemoveAll() //nolint:errcheck + } + modelName = r2.FormValue("model") + } else { + // JSON body (e.g., /v1/audio/speech). + body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize) + if !ok { + return + } + upstreamBody = body + + var request OpenAIInferenceRequest + if err := json.Unmarshal(body, &request); err != nil { + http.Error(w, "invalid request", http.StatusBadRequest) + return + } + modelName = request.Model + } + + if modelName == "" { + http.Error(w, "model is required", http.StatusBadRequest) + return + } + + h.scheduleInference(w, r, modelName, upstreamBody, false) +} + +// scheduleInference is the shared scheduling core used by handleOpenAIInference +// and handleAudioInference. It resolves the backend, looks up the model, tracks +// usage, streams auto-install progress, waits for installation, loads a runner, +// and serves the upstream request. +// +// recordRequest controls whether the request/response pair is captured in the +// OpenAI recorder (appropriate for JSON inference requests; not for multipart +// audio bodies). 
+func (h *HTTPHandler) scheduleInference(w http.ResponseWriter, r *http.Request, modelName string, body []byte, recordRequest bool) { + // Determine the requested backend and ensure that it's valid. + var backend inference.Backend + if b := r.PathValue("backend"); b == "" { + backend = h.scheduler.defaultBackend + } else { + backend = h.scheduler.backends[b] + } + if backend == nil { + http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound) + return + } + + // Determine the backend operation mode. + backendMode, ok := backendModeForRequest(r.URL.Path) + if !ok { + http.Error(w, "unknown request path", http.StatusInternalServerError) + return + } + // Check if the shared model manager has the requested model available. if !backend.UsesExternalModelManagement() { - model, err := h.scheduler.modelManager.GetLocal(request.Model) + model, err := h.scheduler.modelManager.GetLocal(modelName) if err != nil { if errors.Is(err, distribution.ErrModelNotFound) { http.Error(w, err.Error(), http.StatusNotFound) @@ -190,24 +285,20 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque } return } - // Determine the action for tracking + + // Determine the action for tracking. + // Only trust whitelisted origin values to prevent header spoofing. action := "inference/" + backendMode.String() - // Check if there's a request origin header to provide more specific tracking - // Only trust whitelisted values to prevent header spoofing if origin := r.Header.Get(inference.RequestOriginHeader); origin != "" { switch origin { case inference.OriginOllamaCompletion: action = origin - // If an unknown origin is provided, ignore it and use the default action - // This prevents untrusted clients from spoofing tracking data } } - - // Non-blocking call to track the model usage. h.scheduler.tracker.TrackModel(model, r.UserAgent(), action) // Automatically select backend for given model. 
- backend = h.scheduler.selectBackendForModel(model, backend, request.Model) + backend = h.scheduler.selectBackendForModel(model, backend, modelName) } // If a deferred backend needs on-demand installation and the request @@ -258,10 +349,10 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque } } - modelID := h.scheduler.modelManager.ResolveID(request.Model) + modelID := h.scheduler.modelManager.ResolveID(modelName) // Request a runner to execute the request and defer its release. - runner, err := h.scheduler.loader.load(r.Context(), backend.Name(), modelID, request.Model, backendMode) + runner, err := h.scheduler.loader.load(r.Context(), backend.Name(), modelID, modelName, backendMode) if err != nil { http.Error(w, fmt.Errorf("unable to load runner: %w", err).Error(), http.StatusInternalServerError) return @@ -274,13 +365,14 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque return } - // Record the request in the OpenAI recorder. - recordID := h.scheduler.openAIRecorder.RecordRequest(request.Model, r, body) - w = h.scheduler.openAIRecorder.NewResponseRecorder(w) - defer func() { - // Record the response in the OpenAI recorder. - h.scheduler.openAIRecorder.RecordResponse(recordID, request.Model, w) - }() + // Record the request/response in the OpenAI recorder when appropriate. + if recordRequest { + recordID := h.scheduler.openAIRecorder.RecordRequest(modelName, r, body) + w = h.scheduler.openAIRecorder.NewResponseRecorder(w) + defer func() { + h.scheduler.openAIRecorder.RecordResponse(recordID, modelName, w) + }() + } // Create a request with the body replaced for forwarding upstream. 
// Set ContentLength explicitly so the backend always receives a Content-Length diff --git a/pkg/middleware/alias.go b/pkg/middleware/alias.go deleted file mode 100644 index 1d255e47d..000000000 --- a/pkg/middleware/alias.go +++ /dev/null @@ -1,20 +0,0 @@ -package middleware - -import ( - "net/http" - - "github.com/docker/model-runner/pkg/inference" -) - -// AliasHandler provides path aliasing by prepending the inference prefix to incoming request paths. -type AliasHandler struct { - Handler http.Handler -} - -func (h *AliasHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Clone the request with modified path, prepending the inference prefix. - r2 := r.Clone(r.Context()) - r2.URL.Path = inference.InferencePrefix + r.URL.Path - - h.Handler.ServeHTTP(w, r2) -}
diff --git a/pkg/router/handler.go b/pkg/router/handler.go
new file mode 100644
index 000000000..3ac0a5d28
--- /dev/null
+++ b/pkg/router/handler.go
@@ -0,0 +1,223 @@
+//go:build cgo
+
+package router
+
+// handler.go — CGo bridge that lets the Rust router call Go's http.Handler
+// directly, with full streaming support.
+//
+// The streaming protocol (mirrors dmr_router.h):
+//   - Rust sets resp.stream_ctx to a raw *ChunkSender before calling Go.
+//   - Go calls dmr_write_chunk(stream_ctx, data, len) for each non-empty Write.
+//   - Go calls dmr_close_stream(stream_ctx) after ServeHTTP returns.
+//   - Rust reads from the mpsc receiver and forwards chunks to the axum client.
+//
+// Memory contract for headers:
+//   - resp.header_block is allocated by Go via dmr_c_alloc (malloc); Rust frees it.
+
+/*
+#include "dmr_router.h"
+#include <string.h>
+
+// dmr_go_handler_fn and dmr_c_alloc are defined in handler_bridge.c which
+// includes _cgo_export.h (generated by CGo) so it can safely reference
+// goHandleRequest without a conflicting forward declaration.
+DmrHandlerFn dmr_go_handler_fn(void);
+uint8_t *dmr_c_alloc(size_t n);
+*/
+import "C"
+
+import (
+	"bytes"
+	"fmt"
+	"log"
+	"net/http"
+	"runtime/debug"
+	"strings"
+	"sync/atomic"
+	"unsafe"
+)
+
+// globalHandler is set by RegisterHandler and read on every request.
+var globalHandler atomic.Value // stores handlerBox
+
+// handlerBox gives globalHandler a single concrete type to store.
+// atomic.Value panics when Store is given nil or a value whose concrete type
+// differs from the previous Store, which a bare http.Handler could trigger.
+type handlerBox struct {
+	h http.Handler
+}
+
+// RegisterHandler stores h as the in-process backend and returns the
+// (handlerFn, handlerCtx) pair to embed in DmrRouterConfig.
+// Must be called before router.Start.
+func RegisterHandler(h http.Handler) (handlerFn unsafe.Pointer, handlerCtx unsafe.Pointer) {
+	globalHandler.Store(handlerBox{h: h})
+	return unsafe.Pointer(C.dmr_go_handler_fn()), nil
+}
+
+// copyCBytes returns a Go-owned copy of the n bytes at ptr, or nil when the
+// buffer is absent or empty. Unlike C.GoBytes it takes the length as size_t,
+// so buffers larger than a C int are not truncated.
+func copyCBytes(ptr unsafe.Pointer, n C.size_t) []byte {
+	if ptr == nil || n == 0 {
+		return nil
+	}
+	return append([]byte(nil), unsafe.Slice((*byte)(ptr), int(n))...)
+}
+
+// streamingResponseWriter implements http.ResponseWriter and http.Flusher.
+// Each non-empty Write pushes a chunk to Rust via dmr_write_chunk; the status
+// and headers are stored in resp by the first WriteHeader, Write or Flush.
+type streamingResponseWriter struct {
+	header      http.Header
+	statusCode  int
+	wroteHeader bool
+	resp        *C.DmrResponse
+}
+
+func newStreamingResponseWriter(resp *C.DmrResponse) *streamingResponseWriter {
+	return &streamingResponseWriter{
+		header:     make(http.Header),
+		statusCode: http.StatusOK,
+		resp:       resp,
+	}
+}
+
+// Header returns the header map that WriteHeader will serialise.
+func (w *streamingResponseWriter) Header() http.Header {
+	return w.header
+}
+
+// WriteHeader records the status code and publishes the headers set so far
+// into resp. Only the first call has any effect.
+func (w *streamingResponseWriter) WriteHeader(code int) {
+	if w.wroteHeader {
+		return
+	}
+	w.wroteHeader = true
+	w.statusCode = code
+	w.resp.status = C.uint16_t(code)
+
+	// Serialise headers into flat "Name: Value\0…\0" block and store in resp.
+	var hdrBuf bytes.Buffer
+	for name, vals := range w.header {
+		for _, val := range vals {
+			hdrBuf.WriteString(name)
+			hdrBuf.WriteString(": ")
+			hdrBuf.WriteString(val)
+			hdrBuf.WriteByte(0)
+		}
+	}
+	hdrBuf.WriteByte(0) // end-of-list sentinel; the block is never empty
+	hdrBytes := hdrBuf.Bytes()
+
+	p := C.dmr_c_alloc(C.size_t(len(hdrBytes)))
+	if p == nil {
+		// Allocation failed: leave header_block untouched rather than
+		// memcpy into NULL. The status code is still reported.
+		return
+	}
+	C.memcpy(unsafe.Pointer(p), unsafe.Pointer(&hdrBytes[0]), C.size_t(len(hdrBytes)))
+	w.resp.header_block.ptr = p
+	w.resp.header_block.len = C.size_t(len(hdrBytes))
+}
+
+// Write sends data to Rust as one chunk, committing a 200 status first if
+// the handler has not called WriteHeader. A non-zero return code from
+// dmr_write_chunk is reported as a client-disconnect error.
+func (w *streamingResponseWriter) Write(data []byte) (int, error) {
+	if !w.wroteHeader {
+		w.WriteHeader(http.StatusOK)
+	}
+	if len(data) == 0 {
+		return 0, nil
+	}
+	rc := C.dmr_write_chunk(w.resp.stream_ctx, (*C.uint8_t)(unsafe.Pointer(&data[0])), C.size_t(len(data)))
+	if rc != 0 {
+		return 0, fmt.Errorf("dmr_write_chunk: client disconnected")
+	}
+	return len(data), nil
+}
+
+// Flush implements http.Flusher. Chunks are sent immediately on Write, so
+// there is no body data to push; it only commits the status and headers when
+// the handler flushes before its first Write, as net/http's own writer does.
+func (w *streamingResponseWriter) Flush() {
+	if !w.wroteHeader {
+		w.WriteHeader(http.StatusOK)
+	}
+}
+
+// goHandleRequest is invoked by Rust (via the C shim) for every proxied
+// inference request. It reconstructs an http.Request from the DmrRequest
+// and dispatches it to the registered http.Handler using the streaming writer.
+//
+// The stream is closed on every exit path, and a panicking handler is
+// recovered here because a Go panic must not unwind into the C caller.
+//
+//export goHandleRequest
+func goHandleRequest(_ unsafe.Pointer, req *C.DmrRequest, resp *C.DmrResponse) {
+	var w *streamingResponseWriter
+	defer func() {
+		if rec := recover(); rec != nil {
+			// Mirror net/http: ErrAbortHandler aborts quietly, anything else is logged.
+			if rec != http.ErrAbortHandler {
+				log.Printf("router: panic serving request: %v\n%s", rec, debug.Stack())
+			}
+			if w == nil || !w.wroteHeader {
+				resp.status = http.StatusInternalServerError
+			}
+		}
+		// Signal end-of-body to Rust (drops the mpsc Sender, closing the channel).
+		C.dmr_close_stream(resp.stream_ctx)
+	}()
+
+	box, _ := globalHandler.Load().(handlerBox)
+	if box.h == nil {
+		resp.status = http.StatusServiceUnavailable
+		return
+	}
+
+	// ── Reconstruct http.Request ─────────────────────────────────────────
+
+	method := C.GoString(req.method)
+	path := C.GoString(req.path)
+	body := copyCBytes(unsafe.Pointer(req.body.ptr), C.size_t(req.body.len))
+
+	httpReq, err := http.NewRequest(method, path, bytes.NewReader(body))
+	if err != nil {
+		resp.status = http.StatusBadRequest
+		return
+	}
+
+	// Parse flat "Name: Value\0…\0" header block.
+	raw := copyCBytes(unsafe.Pointer(req.header_block.ptr), C.size_t(req.header_block.len))
+	for len(raw) > 0 {
+		end := bytes.IndexByte(raw, 0)
+		if end <= 0 {
+			break // missing terminator, or the empty end-of-list entry
+		}
+		name, value, found := strings.Cut(string(raw[:end]), ": ")
+		raw = raw[end+1:]
+		if !found {
+			continue
+		}
+		if strings.EqualFold(name, "Host") {
+			// net/http promotes Host to Request.Host for incoming requests
+			// instead of keeping it in the Header map.
+			httpReq.Host = value
+			continue
+		}
+		httpReq.Header.Add(name, value)
+	}
+
+	// ── Invoke Go handler with streaming writer ───────────────────────────
+
+	w = newStreamingResponseWriter(resp)
+	box.h.ServeHTTP(w, httpReq)
+
+	// Ensure WriteHeader was called even if the handler wrote nothing.
+	if !w.wroteHeader {
+		w.WriteHeader(http.StatusOK)
+	}
+}
diff --git a/pkg/router/handler_bridge.c b/pkg/router/handler_bridge.c new file mode 100644 index 000000000..8384fdf48 --- /dev/null +++ b/pkg/router/handler_bridge.c @@ -0,0 +1,27 @@ +/* + * handler_bridge.c — C helper for the Go→Rust handler callback.
+ *
+ * This file is compiled by CGo as part of the router package. It includes
+ * _cgo_export.h (generated by CGo, contains the correct prototype for
+ * goHandleRequest) so we can safely take its address.
+ *
+ * We cannot do this in the Go file's inline C preamble because the preamble
+ * is compiled before _cgo_export.h exists; only a separate .c file can
+ * include it.
+ */
+
+#include "_cgo_export.h"
+#include "dmr_router.h"
+#include <stdint.h>
+#include <stdlib.h>
+
+/* Return goHandleRequest as a DmrHandlerFn. Called from Go via CGo. */
+DmrHandlerFn dmr_go_handler_fn(void) {
+    return (DmrHandlerFn)goHandleRequest;
+}
+
+/* Allocate n bytes via the C heap. Rust frees with its global allocator
+ * (which also calls free(3)). Called from Go via CGo. */
+uint8_t *dmr_c_alloc(size_t n) {
+    return (uint8_t *)malloc(n);
+}
diff --git a/pkg/router/router.go b/pkg/router/router.go
new file mode 100644
index 000000000..3e0318909
--- /dev/null
+++ b/pkg/router/router.go
@@ -0,0 +1,140 @@
+//go:build cgo
+
+// Package router provides the CGo bridge to the Rust dmr-router static
+// library. It exposes Start, which launches the axum HTTP router in a
+// background goroutine and returns a StopFunc for graceful shutdown.
+//
+// The Rust library (target/release/libdmr_router.a) is compiled from
+// router/src/lib.rs and linked at Go build time via the CGo directives below.
+package router
+
+/*
+#cgo CFLAGS: -I${SRCDIR}/../../router
+#cgo darwin LDFLAGS: -L${SRCDIR}/../../target/release -ldmr_router -framework Security -framework CoreFoundation -framework SystemConfiguration -lpthread -lresolv -ldl -lm
+#cgo linux LDFLAGS: -L${SRCDIR}/../../target/release -ldmr_router -lpthread -lresolv -ldl -lm
+#include "dmr_router.h"
+#include <stdlib.h>
+*/
+import "C"
+
+import (
+	"fmt"
+	"strings"
+	"sync"
+	"unsafe"
+)
+
+// Config holds the parameters passed to the Rust router.
+type Config struct {
+	// ListenSock is the Unix socket path the router listens on.
+	// Leave empty to use ListenPort instead.
+	ListenSock string
+	// ListenPort is the TCP port to listen on when ListenSock is empty.
+	ListenPort uint16
+
+	// HandlerFn and HandlerCtx are the in-process Go handler callback
+	// obtained from RegisterHandler. When HandlerFn is non-nil requests
+	// are dispatched directly to Go's http.Handler with no socket hop.
+	// BackendSock and BackendPort are ignored in this mode.
+	HandlerFn  unsafe.Pointer
+	HandlerCtx unsafe.Pointer
+
+	// BackendSock is the Unix socket path of the Go inference backend.
+	// Used only when HandlerFn is nil.
+	BackendSock string
+	// BackendPort is the TCP port of the Go backend.
+	// Used only when HandlerFn is nil and BackendSock is empty.
+	BackendPort uint16
+
+	// AllowedOrigins is a slice of allowed CORS origins.
+	AllowedOrigins []string
+	// Version is the version string returned by GET /version.
+	Version string
+}
+
+// StopFunc stops the running router gracefully when called.
+// It is safe to call from any goroutine and is idempotent.
+type StopFunc func()
+
+// Start launches the Rust HTTP router in a background goroutine.
+//
+// It returns a StopFunc and an error channel. Call StopFunc to request a
+// graceful shutdown. The error channel receives exactly one value when the
+// router exits: nil on clean shutdown, non-nil on error.
+func Start(cfg Config) (StopFunc, <-chan error) {
+	errCh := make(chan error, 1)
+
+	// Pre-allocate the stop handle BEFORE spawning the goroutine. This
+	// eliminates the race where stopRouter() was called before Rust had
+	// written handle_out (which only happens after block_on returns, i.e.
+	// after the router has already shut down — making stop a no-op and
+	// leaving the process hung waiting for the router to exit).
+	//
+	// Rust wires the oneshot sender into this handle at the start of
+	// dmr_router_serve, before the event loop blocks.
+	handle := C.dmr_router_new_handle()
+
+	// Build C strings. They are freed inside the goroutine after
+	// dmr_router_serve returns (i.e. after the router has shut down).
+	var cListenSock, cBackendSock, cOrigins, cVersion *C.char
+	if cfg.ListenSock != "" {
+		cListenSock = C.CString(cfg.ListenSock)
+	}
+	if cfg.BackendSock != "" {
+		cBackendSock = C.CString(cfg.BackendSock)
+	}
+	if len(cfg.AllowedOrigins) > 0 {
+		cOrigins = C.CString(strings.Join(cfg.AllowedOrigins, ","))
+	}
+	cVersion = C.CString(cfg.Version)
+
+	ccfg := C.DmrRouterConfig{
+		listen_sock:     cListenSock,
+		listen_port:     C.uint16_t(cfg.ListenPort),
+		handler_fn:      (*[0]byte)(cfg.HandlerFn),
+		handler_ctx:     cfg.HandlerCtx,
+		backend_sock:    cBackendSock,
+		backend_port:    C.uint16_t(cfg.BackendPort),
+		allowed_origins: cOrigins,
+		version:         cVersion,
+	}
+
+	go func() {
+		// Pass the pre-allocated handle; Rust wires stop_tx into it before
+		// blocking, so dmr_router_stop can be called at any point.
+		rc := C.dmr_router_serve(&ccfg, &handle)
+
+		// Free C strings now that the blocking call has returned.
+		if cListenSock != nil {
+			C.free(unsafe.Pointer(cListenSock))
+		}
+		if cBackendSock != nil {
+			C.free(unsafe.Pointer(cBackendSock))
+		}
+		if cOrigins != nil {
+			C.free(unsafe.Pointer(cOrigins))
+		}
+		C.free(unsafe.Pointer(cVersion))
+
+		if rc != 0 {
+			errCh <- fmt.Errorf("dmr-router exited with code %d", int(rc))
+		} else {
+			errCh <- nil
+		}
+	}()
+
+	// Guard the FFI call with sync.Once so the idempotence promised by the
+	// StopFunc contract holds on the Go side, however the Rust library
+	// treats a repeated stop request.
+	var stopOnce sync.Once
+	stop := func() {
+		// handle was pre-allocated before the goroutine started, and Rust
+		// wired the oneshot sender into it before blocking — so this is
+		// always safe to call, even immediately after Start returns.
+		stopOnce.Do(func() {
+			C.dmr_router_stop(handle)
+		})
+	}
+
+	return stop, errCh
+}
diff --git a/pkg/router/router_stub.go b/pkg/router/router_stub.go
new file mode 100644
index 000000000..7071b6f03
--- /dev/null
+++ b/pkg/router/router_stub.go
@@ -0,0 +1,41 @@
+//go:build !cgo
+
+// Package router provides a stub implementation of the Rust dmr-router bridge
+// for environments where CGo is disabled (e.g. lint passes, cross-compilation
+// targets that do not support the static library).
+//
+// The real implementation lives in router.go and handler.go and requires CGo.
+package router
+
+import (
+	"fmt"
+	"net/http"
+	"unsafe"
+)
+
+// Config holds the parameters passed to the Rust router.
+type Config struct {
+	ListenSock     string
+	ListenPort     uint16
+	HandlerFn      unsafe.Pointer
+	HandlerCtx     unsafe.Pointer
+	BackendSock    string
+	BackendPort    uint16
+	AllowedOrigins []string
+	Version        string
+}
+
+// StopFunc stops the running router gracefully when called.
+type StopFunc func()
+
+// Start is a stub that returns an error immediately when CGo is disabled.
+func Start(_ Config) (StopFunc, <-chan error) {
+	errCh := make(chan error, 1)
+	errCh <- fmt.Errorf("dmr-router: CGo is required but was disabled at build time")
+	return func() {}, errCh
+}
+
+// RegisterHandler is a stub that returns nil pointers when CGo is disabled.
+func RegisterHandler(_ http.Handler) (handlerFn unsafe.Pointer, handlerCtx unsafe.Pointer) {
+	return nil, nil
+}
diff --git a/pkg/routing/router.go b/pkg/routing/router.go deleted file mode 100644 index f5249237f..000000000 --- a/pkg/routing/router.go +++ /dev/null @@ -1,84 +0,0 @@ -package routing - -import ( - "net/http" - - "github.com/docker/model-runner/pkg/anthropic" - "github.com/docker/model-runner/pkg/inference" - "github.com/docker/model-runner/pkg/inference/models" - "github.com/docker/model-runner/pkg/inference/scheduling" - "github.com/docker/model-runner/pkg/logging" - "github.com/docker/model-runner/pkg/middleware" - "github.com/docker/model-runner/pkg/ollama" - "github.com/docker/model-runner/pkg/responses" -) - -// RouterConfig holds the dependencies needed to build the standard -// model-runner HTTP route structure. -type RouterConfig struct { - Log logging.Logger - Scheduler *scheduling.Scheduler - SchedulerHTTP *scheduling.HTTPHandler - ModelHandler *models.HTTPHandler - ModelManager *models.Manager - - // AllowedOrigins is forwarded to the Ollama and Anthropic handlers - // for CORS support. It may be nil.
- AllowedOrigins []string - - // ModelHandlerMiddleware optionally wraps the model handler before - // registration (e.g. pinata uses this for access restrictions). - // If nil the model handler is registered directly. - ModelHandlerMiddleware func(http.Handler) http.Handler - - // IncludeResponsesAPI enables the OpenAI Responses API compatibility - // layer, registering it under /responses, /v1/responses, and - // /engines/responses prefixes. Requires SchedulerHTTP to be set. - IncludeResponsesAPI bool -} - -// NewRouter builds a NormalizedServeMux with the standard model-runner -// route structure: models endpoints, scheduler/inference endpoints, -// path aliases (/v1/, /rerank, /score), Ollama compatibility, and -// Anthropic compatibility. -func NewRouter(cfg RouterConfig) *NormalizedServeMux { - router := NewNormalizedServeMux() - - // Models endpoints – optionally wrapped by middleware. - var modelEndpoint http.Handler = cfg.ModelHandler - if cfg.ModelHandlerMiddleware != nil { - modelEndpoint = cfg.ModelHandlerMiddleware(cfg.ModelHandler) - } - router.Handle(inference.ModelsPrefix, modelEndpoint) - router.Handle(inference.ModelsPrefix+"/", modelEndpoint) - - // Scheduler / inference endpoints. - router.Handle(inference.InferencePrefix+"/", cfg.SchedulerHTTP) - - // Path aliases: /v1 → /engines/v1, /rerank → /engines/rerank, /score → /engines/score. - aliasHandler := &middleware.AliasHandler{Handler: cfg.SchedulerHTTP} - router.Handle("/v1/", aliasHandler) - router.Handle("/rerank", aliasHandler) - router.Handle("/score", aliasHandler) - - // Ollama API compatibility layer. - ollamaHandler := ollama.NewHTTPHandler(cfg.Log, cfg.Scheduler, cfg.SchedulerHTTP, cfg.AllowedOrigins, cfg.ModelManager) - router.Handle(ollama.APIPrefix+"/", ollamaHandler) - - // Anthropic Messages API compatibility layer. 
- anthropicHandler := anthropic.NewHandler(cfg.Log, cfg.SchedulerHTTP, cfg.AllowedOrigins, cfg.ModelManager) - router.Handle(anthropic.APIPrefix+"/", anthropicHandler) - - // OpenAI Responses API compatibility layer. - if cfg.IncludeResponsesAPI { - responsesHandler := responses.NewHTTPHandler(cfg.Log, cfg.SchedulerHTTP, cfg.AllowedOrigins) - router.Handle(responses.APIPrefix+"/", responsesHandler) - router.Handle(responses.APIPrefix, responsesHandler) - router.Handle("/v1"+responses.APIPrefix+"/", responsesHandler) - router.Handle("/v1"+responses.APIPrefix, responsesHandler) - router.Handle(inference.InferencePrefix+responses.APIPrefix+"/", responsesHandler) - router.Handle(inference.InferencePrefix+responses.APIPrefix, responsesHandler) - } - - return router -} diff --git a/pkg/routing/routing.go b/pkg/routing/routing.go deleted file mode 100644 index e715ff6ab..000000000 --- a/pkg/routing/routing.go +++ /dev/null @@ -1,24 +0,0 @@ -package routing - -import ( - "net/http" - "path" - "strings" -) - -type NormalizedServeMux struct { - *http.ServeMux -} - -func NewNormalizedServeMux() *NormalizedServeMux { - return &NormalizedServeMux{http.NewServeMux()} -} - -func (nm *NormalizedServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "//") { - normalizedPath := path.Clean(r.URL.Path) - r.URL.Path = normalizedPath - } - - nm.ServeMux.ServeHTTP(w, r) -} diff --git a/pkg/routing/service.go b/pkg/routing/service.go index 207e2e995..8fad5883e 100644 --- a/pkg/routing/service.go +++ b/pkg/routing/service.go @@ -24,7 +24,7 @@ type BackendDef struct { } // ServiceConfig holds the parameters needed to build the full inference -// service stack: model manager, model handler, scheduler, and router. +// service stack: model manager, model handler, and scheduler. 
type ServiceConfig struct { Log logging.Logger ClientConfig models.ClientConfig @@ -51,19 +51,6 @@ type ServiceConfig struct { // AllowedOrigins is forwarded to model, scheduler, Ollama, and // Anthropic handlers for CORS support. It may be nil. AllowedOrigins []string - - // ModelHandlerMiddleware optionally wraps the model handler before - // route registration (e.g. for access restrictions). - ModelHandlerMiddleware func(http.Handler) http.Handler - - // IncludeResponsesAPI enables the OpenAI Responses API compatibility - // layer in the router. - IncludeResponsesAPI bool - - // ExtraRoutes is called after the standard routes are registered. - // The Service fields (except Router) are fully populated when this - // is called, so the callback can reference them. - ExtraRoutes func(*NormalizedServeMux, *Service) } // Service is the assembled inference service stack. @@ -72,7 +59,6 @@ type Service struct { ModelHandler *models.HTTPHandler Scheduler *scheduling.Scheduler SchedulerHTTP *scheduling.HTTPHandler - Router *NormalizedServeMux Backends map[string]inference.Backend } @@ -109,21 +95,6 @@ func NewService(cfg ServiceConfig) (*Service, error) { Backends: backends, } - svc.Router = NewRouter(RouterConfig{ - Log: cfg.Log, - Scheduler: scheduler, - SchedulerHTTP: schedulerHTTP, - ModelHandler: modelHandler, - ModelManager: modelManager, - AllowedOrigins: cfg.AllowedOrigins, - ModelHandlerMiddleware: cfg.ModelHandlerMiddleware, - IncludeResponsesAPI: cfg.IncludeResponsesAPI, - }) - - if cfg.ExtraRoutes != nil { - cfg.ExtraRoutes(svc.Router, svc) - } - return svc, nil } diff --git a/router/.gitignore b/router/.gitignore new file mode 100644 index 000000000..b83d22266 --- /dev/null +++ b/router/.gitignore @@ -0,0 +1 @@ +/target/ diff --git a/router/Cargo.lock b/router/Cargo.lock new file mode 100644 index 000000000..a5927f5f7 --- /dev/null +++ b/router/Cargo.lock @@ -0,0 +1,951 @@ +# This file is automatically @generated by Cargo. 
+# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "axum-macros", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-macros" +version = "0.5.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "dmr-router" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum", + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "tokio", + "tower", + "tower-http", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "slab", +] + +[[package]] +name = "h2" 
+version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "hdrhistogram" +version = "7.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" +dependencies = [ + "byteorder", + "num-traits", +] + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" 
+version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "system-configuration", + "tokio", + "tower-layer", + "tower-service", + "tracing", + "windows-registry", +] + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.183" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" + +[[package]] +name = "lock_api" +version = "0.4.14" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" 
+dependencies = [ + "bitflags", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + +[[package]] +name = "system-configuration" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" +dependencies = [ + "bitflags", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies 
= [ + "futures-core", + "futures-util", + "hdrhistogram", + "indexmap", + "pin-project-lite", + "slab", + "sync_wrapper", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags", + "bytes", + "http", + "http-body", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + 
"once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = 
"windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/router/Cargo.toml b/router/Cargo.toml new file mode 100644 index 000000000..28c8dd80b --- /dev/null +++ b/router/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "dmr-router" +version = "0.1.0" +edition = "2021" +description = "HTTP routing layer for Docker Model Runner (CGo-linked staticlib)" + +[lib] +name = "dmr_router" +crate-type = ["staticlib"] +path = "src/lib.rs" + +[dependencies] +dmr-common = { path = "../dmr-common" } +axum = { version = "0.8", features = ["http2", "macros"] } +tokio = { version = "1", features = ["full"] } +hyper = { version = "1", features = ["full"] } +hyper-util = { version = "0.1", features = ["full"] } +http-body-util = "0.1" +tower = { version = "0.5", features = ["full"] } +tower-http = { version = "0.6", features = ["normalize-path", "cors", "trace"] } +anyhow = "1" +bytes = "1" +tokio-stream = "0.1" +tracing = "0.1" diff --git a/router/dmr_router.h b/router/dmr_router.h new file mode 100644 index 000000000..3772ba663 --- /dev/null +++ b/router/dmr_router.h @@ -0,0 +1,101 @@ +/* + * dmr_router.h — C header for the dmr-router static 
library.
+ */
+
+#ifndef DMR_ROUTER_H
+#define DMR_ROUTER_H
+
+#include <stddef.h> /* size_t */
+#include <stdint.h> /* uint8_t, uint16_t, int32_t */
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* ── Byte slice ───────────────────────────────────────────────────────────── */
+
+typedef struct DmrBytes {
+    uint8_t *ptr;
+    size_t   len;
+} DmrBytes;
+
+void dmr_free_bytes(DmrBytes b);
+
+/* ── Streaming write callback ─────────────────────────────────────────────── *
+ *
+ * Rust creates an mpsc channel and passes the raw sender pointer as
+ * stream_ctx. Go calls dmr_write_chunk() for every chunk written to the
+ * ResponseWriter (and on every Flush()). Rust reads from the channel
+ * receiver and forwards chunks to the axum response stream.
+ *
+ * After ServeHTTP returns Go calls dmr_close_stream() exactly once to
+ * signal end-of-body.
+ */
+
+/*
+ * Send one body chunk to the Rust stream.
+ * data/len: chunk bytes. len==0 is a flush hint.
+ * Returns 0 on success, non-zero if the chunk was not accepted: stream_ctx
+ * is NULL, the client disconnected, or the bounded channel is full (the
+ * Rust side never blocks the caller).
+ */
+int32_t dmr_write_chunk(void *stream_ctx, const uint8_t *data, size_t len);
+
+/*
+ * Signal end-of-body and release the stream_ctx sender.
+ * Must be called exactly once, after the last dmr_write_chunk.
+ */
+void dmr_close_stream(void *stream_ctx);
+
+/* ── Request / response structs ───────────────────────────────────────────── */
+
+typedef struct DmrRequest {
+    const char *method;       /* NUL-terminated; valid for duration of call */
+    const char *path;         /* NUL-terminated; valid for duration of call */
+    DmrBytes    header_block; /* "Name: Value\0…\0"; Rust allocates, Go frees */
+    DmrBytes    body;         /* request body; Rust allocates, Go frees */
+} DmrRequest;
+
+typedef struct DmrResponse {
+    uint16_t status;
+    DmrBytes header_block; /* "Name: Value\0…\0"; Go allocates via C malloc */
+    /*
+     * stream_ctx is set by Rust before calling the handler. Go must call
+     * dmr_write_chunk(stream_ctx, ...) for each body chunk and
+     * dmr_close_stream(stream_ctx) when done. DmrResponse has no body
+     * field: the response body travels only through the stream.
+     */
+    void *stream_ctx;
+} DmrResponse;
+
+/* ── Handler callback ─────────────────────────────────────────────────────── */
+
+typedef void (*DmrHandlerFn)(void *ctx,
+                             const DmrRequest *req,
+                             DmrResponse *resp);
+
+/* ── Configuration ────────────────────────────────────────────────────────── */
+
+typedef struct DmrRouterConfig {
+    const char *listen_sock;
+    uint16_t listen_port;
+    DmrHandlerFn handler_fn;
+    void *handler_ctx;
+    const char *backend_sock;
+    uint16_t backend_port;
+    const char *allowed_origins;
+    const char *version;
+} DmrRouterConfig;
+
+/* ── Opaque handle ────────────────────────────────────────────────────────── */
+
+typedef struct DmrRouterHandle DmrRouterHandle;
+
+DmrRouterHandle *dmr_router_new_handle(void);
+void dmr_router_free_handle(DmrRouterHandle *handle);
+int dmr_router_serve(const DmrRouterConfig *cfg,
+                     DmrRouterHandle **handle_out);
+void dmr_router_stop(DmrRouterHandle *handle);
+
+#ifdef __cplusplus
+} /* extern "C" */
+#endif
+
+#endif /* DMR_ROUTER_H */
diff --git a/router/src/cors.rs b/router/src/cors.rs
new file mode 100644
index 000000000..1f4247f8e
--- /dev/null
+++ b/router/src/cors.rs
@@ -0,0 +1,132 @@
+//! CORS middleware matching the Go CorsMiddleware behaviour.
+//!
+//! Rules (from pkg/middleware/cors.go):
+//! - If allowedOrigins contains "*", all origins are allowed.
+//! - Otherwise, only origins in the allowed set receive CORS headers.
+//! - Requests from origins not in the allowed set receive 403.
+//! - Requests with no Origin header pass through unchanged.
+//! - OPTIONS preflight from an allowed origin: respond 204 with CORS headers
+//!   without calling the next handler. OPTIONS with no Origin header falls
+//!   through to the next handler.
+
+use std::collections::HashSet;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use axum::body::Body;
+use axum::http::{header, Method, Request, Response, StatusCode};
+use tower::{Layer, Service};
+
+/// Builds a CORS layer from an origin allow-list.
+#[derive(Clone)]
+pub struct CorsLayer {
+    /// True when the allow-list contains the wildcard `"*"`.
+    allow_all: bool,
+    /// Exact origin strings that are accepted; shared by every middleware
+    /// instance this layer produces.
+    allowed: Arc<HashSet<String>>,
+}
+
+impl CorsLayer {
+    /// Create a layer from a list of allowed origins.
+    ///
+    /// An entry equal to `"*"` allows every origin. All other entries are
+    /// matched exactly against the request's `Origin` header.
+    pub fn new(origins: Vec<String>) -> Self {
+        let allow_all = origins.iter().any(|o| o == "*");
+        let allowed = Arc::new(origins.into_iter().collect());
+        Self { allow_all, allowed }
+    }
+}
+
+impl<S> Layer<S> for CorsLayer {
+    type Service = CorsMiddleware<S>;
+
+    /// Wrap `inner` so that the CORS rules run before it.
+    fn layer(&self, inner: S) -> Self::Service {
+        CorsMiddleware {
+            inner,
+            allow_all: self.allow_all,
+            allowed: Arc::clone(&self.allowed),
+        }
+    }
+}
+
+/// Service that enforces the origin allow-list around an inner service.
+#[derive(Clone)]
+pub struct CorsMiddleware<S> {
+    inner: S,
+    allow_all: bool,
+    allowed: Arc<HashSet<String>>,
+}
+
+impl<S> Service<Request<Body>> for CorsMiddleware<S>
+where
+    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
+    S::Future: Send + 'static,
+{
+    type Response = Response<Body>;
+    type Error = S::Error;
+    type Future = std::pin::Pin<
+        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
+    >;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    /// Apply the CORS rules, then forward the request to the inner service.
+    ///
+    /// - Origin present but not allowed: 403 with body "Origin not allowed".
+    /// - Origin allowed and method is OPTIONS: 204 with the preflight headers;
+    ///   the inner service is not called.
+    /// - Otherwise the inner service runs, and `Access-Control-Allow-Origin`
+    ///   is added to its response when the request carried an Origin header.
+    ///
+    /// An Origin header that is not valid visible ASCII is treated as absent.
+    fn call(&mut self, req: Request<Body>) -> Self::Future {
+        let allow_all = self.allow_all;
+        let allowed = Arc::clone(&self.allowed);
+
+        // `poll_ready` was driven on `self.inner`, so that exact instance must
+        // serve this request; a fresh clone is not guaranteed to be ready.
+        // Take the ready service and leave the clone behind for the next call.
+        let clone = self.inner.clone();
+        let mut inner = std::mem::replace(&mut self.inner, clone);
+
+        Box::pin(async move {
+            // Keep the validated header value next to its string form so it
+            // can be echoed back without re-parsing (and without `unwrap`).
+            let origin = req
+                .headers()
+                .get(header::ORIGIN)
+                .and_then(|v| v.to_str().ok().map(|s| (s.to_owned(), v.clone())));
+
+            if let Some((name, value)) = &origin {
+                if !(allow_all || allowed.contains(name.as_str())) {
+                    let mut resp = Response::new(Body::from("Origin not allowed"));
+                    *resp.status_mut() = StatusCode::FORBIDDEN;
+                    return Ok(resp);
+                }
+
+                // Handle OPTIONS preflight.
+                if req.method() == Method::OPTIONS {
+                    let mut resp = Response::new(Body::empty());
+                    *resp.status_mut() = StatusCode::NO_CONTENT;
+                    let h = resp.headers_mut();
+                    h.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value.clone());
+                    h.insert(
+                        header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
+                        axum::http::HeaderValue::from_static("true"),
+                    );
+                    h.insert(
+                        header::ACCESS_CONTROL_ALLOW_METHODS,
+                        axum::http::HeaderValue::from_static("GET, POST, DELETE"),
+                    );
+                    h.insert(
+                        header::ACCESS_CONTROL_ALLOW_HEADERS,
+                        axum::http::HeaderValue::from_static("*"),
+                    );
+                    return Ok(resp);
+                }
+            }
+
+            // Pass through to inner handler.
+            let mut resp = inner.call(req).await?;
+
+            // Reaching this point with an Origin means it passed the
+            // allow-list check above, so echo it back.
+            if let Some((_, value)) = origin {
+                resp.headers_mut()
+                    .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value);
+            }
+
+            Ok(resp)
+        })
+    }
+}
diff --git a/router/src/lib.rs b/router/src/lib.rs
new file mode 100644
index 000000000..bd22cc7be
--- /dev/null
+++ b/router/src/lib.rs
@@ -0,0 +1,287 @@
+//! dmr-router: HTTP routing layer for Docker Model Runner, compiled as a
+//! CGo-linked static library.
+//!
+//! Exposes the following C functions to Go:
+//!
+//!   dmr_router_new_handle() -> *DmrRouterHandle
+//!     Allocates the stop handle. Called before dmr_router_serve so Go holds
+//!     a valid handle from the start.
+//!
+//!   dmr_router_free_handle(handle) -> void
+//!     Frees a handle that was never passed to dmr_router_serve.
+//!
+//!   dmr_router_serve(cfg, handle_out) -> i32
+//!     Blocks until the router shuts down. Returns 0 on clean shutdown,
+//!     non-zero on error. Must be called from a dedicated goroutine.
+//!
+//!   dmr_router_stop(handle) -> void
+//!     Signals the router to shut down gracefully and frees the handle.
+//!     Safe to call from any goroutine.
+//!
+//!   dmr_free_bytes(b: DmrBytes) -> void
+//!     Frees a byte buffer allocated on the Rust side of the FFI boundary.
+//!     Go calls this to release DmrRequest.header_block and DmrRequest.body
+//!     after it has finished reading them.
+
+mod cors;
+mod proxy;
+mod routes;
+
+use std::ffi::CStr;
+use std::net::SocketAddr;
+use std::os::raw::{c_char, c_int, c_void};
+use std::path::PathBuf;
+
+use tokio::sync::oneshot;
+use tracing::info;
+
+use crate::proxy::{BackendClient, DmrBytes, DmrHandlerFn};
+use crate::routes::build_router;
+
+// ── Default CORS origins ─────────────────────────────────────────────────────
+
+/// Origins always added to the CORS allow-list, matching
+/// Go's envconfig.AllowedOrigins() baseline.
+const DEFAULT_ORIGINS: &[&str] = &[
+    "http://localhost",
+    "http://127.0.0.1",
+    "http://0.0.0.0",
+];
+
+// ── C-facing configuration struct ───────────────────────────────────────────
+
+/// Configuration passed from Go to `dmr_router_serve`.
+/// All string fields are NUL-terminated C strings owned by the caller; they
+/// must remain valid for the duration of the `dmr_router_serve` call.
+///
+/// The field order and types mirror `struct DmrRouterConfig` in
+/// `dmr_router.h`; the two declarations must be changed together.
+#[repr(C)]
+pub struct DmrRouterConfig {
+    /// NUL-terminated path of the Unix socket the router listens on.
+    /// Pass NULL to use a TCP port instead.
+    pub listen_sock: *const c_char,
+    /// TCP port the router listens on. Ignored when listen_sock is non-NULL.
+    pub listen_port: u16,
+
+    /// In-process Go handler (no network hop).
+    /// When non-NULL, backend_sock and backend_port are ignored.
+    /// `Option` of a function pointer is how the nullable C `DmrHandlerFn`
+    /// is represented: `None` corresponds to NULL.
+    pub handler_fn: Option<DmrHandlerFn>,
+    /// Opaque context pointer handed back to `handler_fn` on every call.
+    pub handler_ctx: *mut c_void,
+
+    /// Unix socket path of the Go backend (used when handler_fn is NULL).
+    pub backend_sock: *const c_char,
+    /// TCP port of the Go backend (used when handler_fn is NULL).
+    pub backend_port: u16,
+
+    /// NUL-terminated comma-separated allowed CORS origins. May be NULL.
+    pub allowed_origins: *const c_char,
+    /// NUL-terminated version string served at GET /version. May be NULL.
+    pub version: *const c_char,
+}
+
+/// Opaque stop handle returned to Go so it can call `dmr_router_stop`.
+pub struct DmrRouterHandle {
+    /// Stop signal sender. `None` until `dmr_router_serve` installs it, and
+    /// `None` again after `dmr_router_stop` has taken it.
+    tx: Option<oneshot::Sender<()>>,
+}
+
+// ── Parsed (safe-Rust) configuration ────────────────────────────────────────
+
+/// A network address: either a Unix domain socket path or a TCP port.
+enum Addr {
+    Unix(PathBuf),
+    Tcp(u16),
+}
+
+/// Owned, validated form of `DmrRouterConfig`; holds no caller pointers
+/// except whatever `BackendClient` keeps for the in-process handler.
+struct Config {
+    listen: Addr,
+    backend: BackendClient,
+    allowed_origins: Vec<String>,
+    version: String,
+}
+
+// ── C string helpers ─────────────────────────────────────────────────────────
+
+/// Copy a C string into an owned `String`, replacing invalid UTF-8 lossily.
+/// Returns `None` for a NULL pointer.
+///
+/// # Safety
+/// `ptr` must be NULL or point to a valid NUL-terminated C string.
+unsafe fn cstr_to_string(ptr: *const c_char) -> Option<String> {
+    if ptr.is_null() {
+        None
+    } else {
+        // SAFETY: `ptr` is non-null and, per the caller contract, points to a
+        // NUL-terminated string that stays valid for this call.
+        Some(unsafe { CStr::from_ptr(ptr) }.to_string_lossy().into_owned())
+    }
+}
+
+/// Choose a Unix socket address when `sock_ptr` is non-NULL, otherwise the
+/// TCP `port`.
+///
+/// # Safety
+/// `sock_ptr` must be NULL or point to a valid NUL-terminated C string.
+unsafe fn parse_addr(sock_ptr: *const c_char, port: u16) -> Addr {
+    // SAFETY: forwarded caller contract on `sock_ptr`.
+    match unsafe { cstr_to_string(sock_ptr) } {
+        Some(path) => Addr::Unix(PathBuf::from(path)),
+        None => Addr::Tcp(port),
+    }
+}
+
+/// Parse a `DmrRouterConfig` into a safe `Config`.
+///
+/// Backend selection: the in-process handler wins when `handler_fn` is set;
+/// otherwise `backend_sock` is used when non-NULL, else `backend_port`.
+/// The CORS allow-list is `DEFAULT_ORIGINS` followed by the non-empty,
+/// trimmed entries of the comma-separated `allowed_origins` string.
+/// A NULL `version` becomes `"unknown"`.
+///
+/// # Safety
+/// All pointer fields in `cfg` must be valid NUL-terminated C strings or NULL.
+unsafe fn parse_config(cfg: &DmrRouterConfig) -> Config {
+    // SAFETY (all blocks below): each pointer comes straight from `cfg`,
+    // whose fields the caller guarantees are NULL or valid C strings.
+    let listen = unsafe { parse_addr(cfg.listen_sock, cfg.listen_port) };
+
+    let backend = if let Some(f) = cfg.handler_fn {
+        // In-process mode: call Go directly, no socket.
+        unsafe { BackendClient::new_go(f, cfg.handler_ctx) }
+    } else {
+        match unsafe { cstr_to_string(cfg.backend_sock) } {
+            Some(path) => BackendClient::new_unix(PathBuf::from(path)),
+            None => BackendClient::new_tcp(cfg.backend_port),
+        }
+    };
+
+    let mut allowed_origins: Vec<String> =
+        DEFAULT_ORIGINS.iter().map(|s| s.to_string()).collect();
+    if let Some(raw) = unsafe { cstr_to_string(cfg.allowed_origins) } {
+        // Trim and drop empty entries before allocating, so inputs such as
+        // "a, ,b," contribute only "a" and "b".
+        allowed_origins.extend(
+            raw.split(',')
+                .map(str::trim)
+                .filter(|o| !o.is_empty())
+                .map(str::to_owned),
+        );
+    }
+
+    let version = unsafe { cstr_to_string(cfg.version) }
+        .unwrap_or_else(|| "unknown".to_string());
+
+    Config { listen, backend, allowed_origins, version }
+}
+
+// ── Public C API ─────────────────────────────────────────────────────────────
+
+/// Allocate a new `DmrRouterHandle` and return it to Go.
+///
+/// Go calls this **before** spawning the `dmr_router_serve` goroutine so that
+/// it holds a valid stop handle from the very start — before `block_on` ever
+/// runs. The same handle pointer is then passed to `dmr_router_serve` via
+/// `handle_out` so Rust can write the `oneshot::Sender` into it.
+///
+/// The returned pointer must be freed exactly once, either by `dmr_router_stop`
+/// or by `dmr_router_free_handle` if the router never started.
+#[unsafe(no_mangle)]
+pub extern "C" fn dmr_router_new_handle() -> *mut DmrRouterHandle {
+    Box::into_raw(Box::new(DmrRouterHandle { tx: None }))
+}
+
+/// Free a `DmrRouterHandle` that was never passed to `dmr_router_serve`.
+/// Use `dmr_router_stop` for handles that were passed to `dmr_router_serve`.
+///
+/// # Safety
+/// `handle` must be a pointer obtained from `dmr_router_new_handle` that has
+/// not yet been passed to `dmr_router_serve` and has not been freed already.
+#[unsafe(no_mangle)]
+pub unsafe extern "C" fn dmr_router_free_handle(handle: *mut DmrRouterHandle) {
+    if handle.is_null() {
+        return;
+    }
+    // SAFETY: non-null and, per the contract above, a live allocation from
+    // `dmr_router_new_handle` that nobody else will free.
+    let owned = unsafe { Box::from_raw(handle) };
+    drop(owned);
+}
+
+/// Release a `DmrBytes` buffer that Rust handed to Go, such as
+/// `DmrRequest.header_block` or `DmrRequest.body`, once Go has read it.
+/// A NULL pointer or a zero length is a no-op.
+///
+/// # Safety
+/// `b.ptr` must be NULL or the start of a `Vec<u8>` allocation made by the
+/// Rust global allocator. The buffer is rebuilt with capacity `b.len`, so the
+/// allocation's capacity must equal `b.len`.
+#[unsafe(no_mangle)]
+pub unsafe extern "C" fn dmr_free_bytes(b: DmrBytes) {
+    if b.ptr.is_null() || b.len == 0 {
+        return;
+    }
+    // SAFETY: see the function contract; length and capacity are both `b.len`.
+    let buf = unsafe { Vec::from_raw_parts(b.ptr, b.len, b.len) };
+    drop(buf);
+}
+
+/// Run the router on the calling thread until `dmr_router_stop` fires or the
+/// server fails. Returns 0 after a clean shutdown and 1 on any error
+/// (NULL `cfg`, runtime creation failure, bind or serve failure).
+///
+/// When `handle_out` is non-NULL, the stop sender is stored in the handle it
+/// points to. If `*handle_out` is NULL, a new handle is allocated and written
+/// back instead.
+///
+/// # Safety
+/// `cfg` must point to a valid `DmrRouterConfig` whose string fields are NULL
+/// or valid NUL-terminated C strings for the duration of this call.
+/// `handle_out` must be NULL or point to a readable and writable
+/// `*mut DmrRouterHandle` that is either NULL or a live handle.
+#[unsafe(no_mangle)]
+pub unsafe extern "C" fn dmr_router_serve(
+    cfg: *const DmrRouterConfig,
+    handle_out: *mut *mut DmrRouterHandle,
+) -> c_int {
+    if cfg.is_null() {
+        return 1;
+    }
+    // SAFETY: `cfg` is non-null and valid per the contract above.
+    let config = unsafe { parse_config(&*cfg) };
+
+    dmr_common::init_tracing("info");
+
+    let runtime = match tokio::runtime::Runtime::new() {
+        Ok(runtime) => runtime,
+        Err(e) => {
+            eprintln!("dmr-router: failed to create tokio runtime: {e}");
+            return 1;
+        }
+    };
+
+    let (stop_tx, stop_rx) = oneshot::channel::<()>();
+
+    // Hand the stop sender to Go. Normally Go pre-allocated the handle with
+    // dmr_router_new_handle() and we only fill in its `tx` field; the pointer
+    // itself is left untouched.
+    if !handle_out.is_null() {
+        // SAFETY: `handle_out` is non-null and readable per the contract.
+        let existing: *mut DmrRouterHandle = unsafe { *handle_out };
+        if existing.is_null() {
+            // Legacy / direct call with no pre-allocated handle: create one.
+            let fresh = Box::new(DmrRouterHandle { tx: Some(stop_tx) });
+            // SAFETY: `handle_out` is writable per the contract.
+            unsafe { *handle_out = Box::into_raw(fresh) };
+        } else {
+            // SAFETY: `existing` is a live handle per the contract.
+            unsafe { (*existing).tx = Some(stop_tx) };
+        }
+    }
+
+    if let Err(e) = runtime.block_on(serve(config, stop_rx)) {
+        eprintln!("dmr-router: {e}");
+        return 1;
+    }
+    0
+}
+
+/// Ask the router to shut down gracefully, then free the handle.
+/// A NULL `handle` is ignored. If no stop sender was installed the handle is
+/// simply freed.
+///
+/// # Safety
+/// `handle` must be NULL or a live handle, obtained from
+/// `dmr_router_new_handle` or written by `dmr_router_serve` through
+/// `handle_out`, that has not been freed already. It must not be used again
+/// after this call.
+#[unsafe(no_mangle)]
+pub unsafe extern "C" fn dmr_router_stop(handle: *mut DmrRouterHandle) {
+    if handle.is_null() {
+        return;
+    }
+    // SAFETY: non-null and live per the contract; ownership moves to `owned`,
+    // which frees the allocation when it goes out of scope.
+    let mut owned = unsafe { Box::from_raw(handle) };
+    let Some(stop_tx) = owned.tx.take() else {
+        return;
+    };
+    // A send error only means the server has already exited.
+    let _ = stop_tx.send(());
+}
+
+// ── Async server core ────────────────────────────────────────────────────────
+
+/// Build the axum app and serve it on the configured listener until
+/// `stop_rx` resolves (a value is sent, or the sender is dropped).
+async fn serve(cfg: Config, stop_rx: oneshot::Receiver<()>) -> anyhow::Result<()> {
+    let Config { listen, backend, allowed_origins, version } = cfg;
+    let app = build_router(backend, allowed_origins, version);
+
+    let shutdown = async move {
+        let _ = stop_rx.await;
+        info!("dmr-router: shutdown signal received");
+    };
+
+    match listen {
+        Addr::Unix(path) => {
+            // Best-effort removal of a stale socket file from a previous run.
+            let _ = std::fs::remove_file(&path);
+            let listener = tokio::net::UnixListener::bind(&path)?;
+            info!(path = %path.display(), "dmr-router listening on Unix socket");
+            axum::serve(listener, app)
+                .with_graceful_shutdown(shutdown)
+                .await?;
+        }
+        Addr::Tcp(port) => {
+            // NOTE(review): this binds every interface (0.0.0.0); confirm that
+            // matches the listener the Go routing layer used to create.
+            let addr: SocketAddr = format!("0.0.0.0:{port}").parse()?;
+            info!(%addr, "dmr-router listening on TCP");
+            let listener = tokio::net::TcpListener::bind(addr).await?;
+            axum::serve(listener, app)
+                .with_graceful_shutdown(shutdown)
+                .await?;
+        }
+    }
+
+    Ok(())
+}
diff --git a/router/src/proxy.rs b/router/src/proxy.rs
new file mode 100644
index 000000000..23ed24bf1
--- /dev/null
+++ b/router/src/proxy.rs
@@ -0,0 +1,456 @@
+//! Backend dispatch: network reverse-proxy or direct in-process call into
+//! Go's http.Handler via a streaming C callback.
+//!
+//!
The GoHandler path:
+//! 1. Rust creates a tokio::sync::mpsc channel.
+//! 2. The raw sender is cast to *mut c_void and written into DmrResponse.stream_ctx.
+//! 3. Go's ServeHTTP runs on a spawn_blocking thread; every Write/Flush call
+//!    invokes dmr_write_chunk() which sends the chunk into the channel.
+//! 4. When ServeHTTP returns, Go calls dmr_close_stream(), which drops the sender.
+//! 5. Rust polls the mpsc receiver as a streaming Body, forwarding chunks to the
+//!    axum client as they arrive — full streaming with no buffering.
+
+use std::os::raw::{c_void, c_int};
+use std::path::PathBuf;
+use std::sync::Arc;
+
+use axum::body::Body;
+use axum::extract::Request;
+use axum::http::{HeaderName, HeaderValue, StatusCode, Uri};
+use axum::response::{IntoResponse, Response};
+use bytes::Bytes;
+use http_body_util::{BodyExt, Full};
+use hyper::body::Incoming;
+use hyper_util::client::legacy::connect::HttpConnector;
+use hyper_util::client::legacy::Client;
+use hyper_util::rt::TokioExecutor;
+use tokio::net::UnixStream;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+use tokio_stream::StreamExt as TokioStreamExt;
+use tracing::warn;
+
+// ── FFI types (must match dmr_router.h exactly) ──────────────────────────────
+
+/// Pointer-plus-length byte slice crossing the FFI boundary.
+/// Mirrors `struct DmrBytes` in dmr_router.h (`uint8_t *ptr; size_t len;`).
+/// Which side allocated the buffer, and therefore which side must free it,
+/// depends on the field that carries it; see the header comments.
+#[repr(C)]
+pub struct DmrBytes {
+    pub ptr: *mut u8,
+    pub len: usize,
+}
+
+/// Request handed to the Go handler. Mirrors `struct DmrRequest` in
+/// dmr_router.h, where `method` and `path` are NUL-terminated and valid for
+/// the duration of the call, `header_block` is "Name: Value\0…\0", and both
+/// `header_block` and `body` are allocated by Rust and freed by Go.
+#[repr(C)]
+pub struct DmrRequest<'a> {
+    pub method: *const std::os::raw::c_char,
+    pub path: *const std::os::raw::c_char,
+    pub header_block: DmrBytes,
+    pub body: DmrBytes,
+    // Zero-sized marker with no counterpart in the C struct; it only gives
+    // the raw pointers above a Rust lifetime.
+    _lifetime: std::marker::PhantomData<&'a ()>,
+}
+
+/// Response struct passed to Go.
+/// Go sets `status` and `header_block`, then calls `dmr_write_chunk` /
+/// `dmr_close_stream` using `stream_ctx` for the body.
+#[repr(C)]
+pub struct DmrResponse {
+    /// HTTP status code, set by Go.
+    pub status: u16,
+    /// Response headers as "Name: Value\0…\0", set by Go.
+    pub header_block: DmrBytes,
+    /// Opaque `ChunkSender` pointer set by Rust before the handler runs.
+    pub stream_ctx: *mut c_void,
+}
+
+/// Signature of the Go handler callback (`DmrHandlerFn` in dmr_router.h).
+pub type DmrHandlerFn =
+    unsafe extern "C" fn(ctx: *mut c_void, req: *const DmrRequest<'_>, resp: *mut DmrResponse);
+
+// ── Streaming C exports ───────────────────────────────────────────────────────
+
+/// Sending half of the per-response body channel.
+///
+/// Items are `Option<Bytes>`: `Some(chunk)` is a body chunk.
+/// `dmr_close_stream` does not send `None`; it drops the sender, which closes
+/// the channel.
+/// NOTE(review): confirm whether any producer sends `None` as an explicit
+/// end-of-stream marker; none is visible in this part of the file.
+type ChunkSender = mpsc::Sender<Option<Bytes>>;
+
+/// Send one body chunk from Go into the Rust mpsc channel.
+///
+/// Returns 0 when the chunk was queued or when `len == 0` (flush hint).
+/// Returns -1 when `stream_ctx` is NULL, when `data` is NULL with a non-zero
+/// `len`, when the receiver is gone (client disconnected), or when the
+/// channel is full. This function never blocks.
+///
+/// # Safety
+/// `stream_ctx` must be NULL or a raw pointer obtained from
+/// `Box::into_raw::<ChunkSender>` that has not yet been passed to
+/// `dmr_close_stream`. When `len > 0`, `data` must point to at least `len`
+/// readable bytes.
+#[unsafe(no_mangle)]
+pub unsafe extern "C" fn dmr_write_chunk(
+    stream_ctx: *mut c_void,
+    data: *const u8,
+    len: usize,
+) -> c_int {
+    if stream_ctx.is_null() {
+        return -1;
+    }
+    if len == 0 {
+        // Flush hint — no data, just a signal; nothing to send.
+        return 0;
+    }
+    if data.is_null() {
+        // Building a slice from a NULL pointer is undefined behaviour even if
+        // it is never read; reject the call instead.
+        return -1;
+    }
+    // SAFETY: non-null and, per the contract, a live boxed `ChunkSender`.
+    let tx: &ChunkSender = unsafe { &*(stream_ctx as *const ChunkSender) };
+    // SAFETY: `data` is non-null and the caller guarantees `len` readable bytes.
+    let chunk = Bytes::copy_from_slice(unsafe { std::slice::from_raw_parts(data, len) });
+    // try_send so we never block (the channel has capacity 64).
+    match tx.try_send(Some(chunk)) {
+        Ok(()) => 0,
+        Err(mpsc::error::TrySendError::Closed(_)) => -1, // client disconnected
+        Err(mpsc::error::TrySendError::Full(_)) => {
+            // The chunk is lost and the caller will see a write failure;
+            // log it so a truncated response can be told apart from a
+            // genuine disconnect.
+            warn!(len, "dmr-router: response channel full; chunk rejected");
+            -1
+        }
+    }
+}
+
+/// Signal end-of-body and free the sender.
+///
+/// # Safety
+/// `stream_ctx` must be a pointer previously obtained from
+/// `Box::into_raw::<ChunkSender>` and must not have been freed already.
+#[unsafe(no_mangle)]
+pub unsafe extern "C" fn dmr_close_stream(stream_ctx: *mut c_void) {
+    if stream_ctx.is_null() {
+        return;
+    }
+    // Dropping the Box drops the Sender, closing the channel.
+    // SAFETY: non-null and, per the contract, a live boxed `ChunkSender`
+    // that is released exactly once, here.
+    drop(unsafe { Box::from_raw(stream_ctx as *mut ChunkSender) });
+}
+
+/// Free a DmrBytes buffer allocated by the Rust side.
+unsafe fn dmr_free_bytes(b: DmrBytes) { + if !b.ptr.is_null() && b.len > 0 { + drop(unsafe { Vec::from_raw_parts(b.ptr, b.len, b.len) }); + } +} + +// ── GoHandlerInner ──────────────────────────────────────────────────────────── + +struct GoHandlerInner { + handler_fn: DmrHandlerFn, + handler_ctx: *mut c_void, +} + +// SAFETY: handler_ctx is a Go cgo.Handle (uintptr_t cast to pointer). +unsafe impl Send for GoHandlerInner {} +unsafe impl Sync for GoHandlerInner {} + +// ── BackendClient ───────────────────────────────────────────────────────────── + +#[derive(Clone)] +enum BackendAddr { + GoHandler(Arc), + Unix(PathBuf), + Tcp(u16), +} + +#[derive(Clone)] +pub struct BackendClient { + addr: BackendAddr, + tcp_client: Option>>, +} + +impl BackendClient { + /// # Safety + /// `handler_fn` must remain valid for the lifetime of this `BackendClient`. + pub unsafe fn new_go(handler_fn: DmrHandlerFn, handler_ctx: *mut c_void) -> Self { + Self { + addr: BackendAddr::GoHandler(Arc::new(GoHandlerInner { handler_fn, handler_ctx })), + tcp_client: None, + } + } + + pub fn new_unix(path: PathBuf) -> Self { + Self { addr: BackendAddr::Unix(path), tcp_client: None } + } + + pub fn new_tcp(port: u16) -> Self { + let connector = HttpConnector::new(); + let client = Client::builder(TokioExecutor::new()).build(connector); + Self { addr: BackendAddr::Tcp(port), tcp_client: Some(Arc::new(client)) } + } + + pub async fn proxy(&self, req: Request, target_path: &str) -> Response { + match &self.addr { + BackendAddr::GoHandler(inner) => call_go_handler(inner.clone(), req, target_path).await, + BackendAddr::Tcp(port) => self.proxy_tcp(req, target_path, *port).await, + BackendAddr::Unix(sock) => self.proxy_unix(req, target_path, sock.clone()).await, + } + } + + async fn proxy_tcp(&self, req: Request, target_path: &str, port: u16) -> Response { + let client = self.tcp_client.as_ref().unwrap(); + let uri = match build_uri(&format!("http://127.0.0.1:{port}"), target_path, req.uri()) { + Ok(u) 
=> u, + Err(e) => { + warn!("failed to build upstream URI: {e}"); + return (StatusCode::INTERNAL_SERVER_ERROR, "bad gateway").into_response(); + } + }; + let (parts, body) = req.into_parts(); + let mut upstream = hyper::Request::builder() + .method(parts.method).uri(uri).version(hyper::Version::HTTP_11); + for (k, v) in &parts.headers { upstream = upstream.header(k, v); } + let upstream_req = match upstream.body(body) { + Ok(r) => r, + Err(e) => { + warn!("failed to build upstream request: {e}"); + return (StatusCode::INTERNAL_SERVER_ERROR, "bad gateway").into_response(); + } + }; + match client.request(upstream_req).await { + Ok(resp) => strip_cors(resp.map(Body::new)), + Err(e) => { warn!("upstream error: {e}"); (StatusCode::BAD_GATEWAY, "bad gateway").into_response() } + } + } + + async fn proxy_unix(&self, req: Request, target_path: &str, sock: PathBuf) -> Response { + let stream = match UnixStream::connect(&sock).await { + Ok(s) => s, + Err(e) => { + warn!("failed to connect to backend socket {}: {e}", sock.display()); + return (StatusCode::BAD_GATEWAY, "backend unavailable").into_response(); + } + }; + let (mut sender, conn) = + match hyper::client::conn::http1::handshake(hyper_util::rt::TokioIo::new(stream)).await { + Ok(p) => p, + Err(e) => { + warn!("HTTP handshake failed: {e}"); + return (StatusCode::BAD_GATEWAY, "bad gateway").into_response(); + } + }; + tokio::spawn(async move { if let Err(e) = conn.await { warn!("backend connection error: {e}"); } }); + + let uri = match build_uri("http://localhost", target_path, req.uri()) { + Ok(u) => u, + Err(e) => { + warn!("failed to build upstream URI: {e}"); + return (StatusCode::INTERNAL_SERVER_ERROR, "bad gateway").into_response(); + } + }; + let (parts, body) = req.into_parts(); + let body_bytes: Bytes = match body.collect().await { + Ok(c) => c.to_bytes(), + Err(e) => { + warn!("failed to read request body: {e}"); + return (StatusCode::INTERNAL_SERVER_ERROR, "bad gateway").into_response(); + } + }; + let 
content_length = body_bytes.len(); + let mut upstream = hyper::Request::builder() + .method(parts.method).uri(uri).version(hyper::Version::HTTP_11); + for (k, v) in &parts.headers { + if k != axum::http::header::CONTENT_LENGTH { upstream = upstream.header(k, v); } + } + upstream = upstream + .header(axum::http::header::CONTENT_LENGTH, content_length) + .header(axum::http::header::HOST, "localhost"); + let upstream_req = match upstream.body(Full::new(body_bytes)) { + Ok(r) => r, + Err(e) => { + warn!("failed to build upstream request: {e}"); + return (StatusCode::INTERNAL_SERVER_ERROR, "bad gateway").into_response(); + } + }; + match sender.send_request(upstream_req).await { + Ok(resp) => strip_cors(resp.map(|b: Incoming| Body::new(b))), + Err(e) => { warn!("upstream send error: {e}"); (StatusCode::BAD_GATEWAY, "bad gateway").into_response() } + } + } +} + +// ── In-process Go handler dispatch (streaming) ─────────────────────────────── + +/// Dispatch a request directly to Go's http.Handler via the C callback. +/// +/// Architecture: +/// - An mpsc channel (capacity 64) is created; the Sender is heap-allocated +/// and its raw pointer is written into DmrResponse.stream_ctx. +/// - Go's ServeHTTP runs on a spawn_blocking thread; every Write/Flush call +/// invokes dmr_write_chunk(), sending chunks into the channel. +/// - When ServeHTTP returns, Go calls dmr_close_stream(), dropping the Sender +/// and closing the channel. +/// - The Receiver is wrapped in a ReceiverStream and returned as the axum +/// response Body, so chunks flow to the client as they are produced. +async fn call_go_handler( + inner: Arc, + req: Request, + target_path: &str, +) -> Response { + // ── 1. 
Collect request body ─────────────────────────────────────────── + let method_str = req.method().to_string(); + let path_str = target_path.to_owned(); + let headers = req.headers().clone(); + + let (_, body) = req.into_parts(); + let body_bytes: Bytes = match body.collect().await { + Ok(c) => c.to_bytes(), + Err(e) => { + warn!("failed to read request body: {e}"); + return (StatusCode::INTERNAL_SERVER_ERROR, "bad gateway").into_response(); + } + }; + + // ── 2. Serialise request headers ────────────────────────────────────── + let mut hdr_block: Vec = Vec::new(); + for (name, value) in &headers { + hdr_block.extend_from_slice(name.as_str().as_bytes()); + hdr_block.extend_from_slice(b": "); + hdr_block.extend_from_slice(value.as_bytes()); + hdr_block.push(0); + } + hdr_block.push(0); + + // ── 3. Create streaming channel ─────────────────────────────────────── + // Capacity 64: enough to buffer several chunks so Go is rarely blocked. + let (tx, rx) = mpsc::channel::>(64); + + // Heap-allocate the sender; the raw pointer goes to Go via stream_ctx. + // Go frees it by calling dmr_close_stream(). + let tx_ptr = Box::into_raw(Box::new(tx)) as *mut c_void; + + // ── 4. Launch Go handler on a blocking thread ───────────────────────── + // We need the header/status before we can build the axum response, so + // we use a oneshot channel to get that metadata back from the thread. + let (meta_tx, meta_rx) = tokio::sync::oneshot::channel::<(u16, Vec<(String, String)>)>(); + + let body_len = body_bytes.len(); + let body_ptr = { let mut v = body_bytes.to_vec(); let p = v.as_mut_ptr(); std::mem::forget(v); p }; + let hdr_len = hdr_block.len(); + let hdr_ptr = { let mut hb = hdr_block; let p = hb.as_mut_ptr(); std::mem::forget(hb); p }; + let method_c = std::ffi::CString::new(method_str).unwrap_or_default(); + let path_c = std::ffi::CString::new(path_str).unwrap_or_default(); + + // Wrap raw pointers for Send across the thread boundary. 
+ // Both handler_ctx (Go cgo.Handle) and tx_ptr (Box) are + // safe to send to another thread. + struct SendPtr(*mut c_void); + unsafe impl Send for SendPtr {} + + let ctx_send = SendPtr(inner.handler_ctx); + let tx_ptr_send = SendPtr(tx_ptr); + // hdr_ptr and body_ptr also need wrapping since they're *mut u8. + struct SendU8Ptr(*mut u8); + unsafe impl Send for SendU8Ptr {} + let hdr_ptr_send = SendU8Ptr(hdr_ptr); + let body_ptr_send = SendU8Ptr(body_ptr); + + let handler_fn = inner.handler_fn; + + // All raw pointers used inside the closure are safe to send: + // - handler_ctx: Go cgo.Handle (uintptr_t) + // - tx_ptr: Box allocated on this thread + // - hdr_ptr / body_ptr: Vec allocations from this thread + // - method_c / path_c: CString allocated on this thread + // Wrap the closure in a SendClosure to assert this to the compiler. + struct SendClosure(F); + unsafe impl Send for SendClosure {} + impl SendClosure { + fn call(self) { (self.0)() } + } + + let closure = SendClosure(move || { + let handler_ctx = ctx_send.0; + let stream_ctx = tx_ptr_send.0; + let hdr_ptr = hdr_ptr_send.0; + let body_ptr = body_ptr_send.0; + + let req_ffi = DmrRequest { + method: method_c.as_ptr(), + path: path_c.as_ptr(), + header_block: DmrBytes { ptr: hdr_ptr, len: hdr_len }, + body: DmrBytes { ptr: body_ptr, len: body_len }, + _lifetime: std::marker::PhantomData, + }; + let mut resp_ffi = DmrResponse { + status: 500, + header_block: DmrBytes { ptr: std::ptr::null_mut(), len: 0 }, + stream_ctx, + }; + + // Call Go. Go calls dmr_write_chunk() for each chunk and + // dmr_close_stream() when ServeHTTP returns. + unsafe { handler_fn(handler_ctx, &req_ffi, &mut resp_ffi) }; + + // Free request buffers. + unsafe { + dmr_free_bytes(DmrBytes { ptr: hdr_ptr, len: hdr_len }); + dmr_free_bytes(DmrBytes { ptr: body_ptr, len: body_len }); + } + + // Parse response headers and send metadata to the async side. 
+ let status = resp_ffi.status; + let mut parsed_headers: Vec<(String, String)> = Vec::new(); + if !resp_ffi.header_block.ptr.is_null() && resp_ffi.header_block.len > 0 { + let raw = unsafe { + std::slice::from_raw_parts(resp_ffi.header_block.ptr, resp_ffi.header_block.len) + }; + let mut pos = 0; + while pos < raw.len() { + let end = raw[pos..].iter().position(|&b| b == 0) + .map(|i| pos + i).unwrap_or(raw.len()); + if end == pos { break; } + if let Ok(entry) = std::str::from_utf8(&raw[pos..end]) { + if let Some(sep) = entry.find(": ") { + parsed_headers.push((entry[..sep].to_owned(), entry[sep + 2..].to_owned())); + } + } + pos = end + 1; + } + } + unsafe { dmr_free_bytes(resp_ffi.header_block); } + + // Ignore error: if receiver dropped, the client already disconnected. + let _ = meta_tx.send((status, parsed_headers)); + }); + tokio::task::spawn_blocking(move || closure.call()); + + // ── 5. Wait for status + headers, then stream body ──────────────────── + // Use a generous timeout so a misbehaving handler never blocks the + // Tokio executor indefinitely. + let (status_code, parsed_headers) = match tokio::time::timeout( + std::time::Duration::from_secs(300), + meta_rx, + ).await { + Ok(Ok(m)) => m, + Ok(Err(_)) => { + warn!("go handler thread dropped meta sender"); + return (StatusCode::INTERNAL_SERVER_ERROR, "handler error").into_response(); + } + Err(_) => { + warn!("go handler timed out waiting for response headers"); + return (StatusCode::GATEWAY_TIMEOUT, "handler timeout").into_response(); + } + }; + + let status = StatusCode::from_u16(status_code) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + // Build a streaming Body from the mpsc receiver. + // None items close the stream; Bytes items are forwarded as chunks. + // Flatten Option items: None closes the stream, Some(b) yields b. 
+ let stream = ReceiverStream::new(rx) + .take_while(|item: &Option| item.is_some()) + .map(|item: Option| Ok::(item.unwrap())); + let streaming_body = Body::from_stream(stream); + + let mut builder = axum::http::response::Builder::new().status(status); + for (name, value) in &parsed_headers { + if let (Ok(n), Ok(v)) = ( + HeaderName::from_bytes(name.as_bytes()), + HeaderValue::from_str(value), + ) { + builder = builder.header(n, v); + } + } + + builder.body(streaming_body).unwrap_or_else(|_| { + (StatusCode::INTERNAL_SERVER_ERROR, "response build error").into_response() + }) +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +fn build_uri(base: &str, target_path: &str, original: &Uri) -> anyhow::Result { + let pq = match original.query() { + Some(q) => format!("{target_path}?{q}"), + None => target_path.to_string(), + }; + Ok(format!("{base}{pq}").parse::()?) +} + +fn strip_cors(resp: Response) -> Response { + let (mut parts, body) = resp.into_parts(); + parts.headers.remove(axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN); + Response::from_parts(parts, body) +} diff --git a/router/src/routes.rs b/router/src/routes.rs new file mode 100644 index 000000000..a70aa6005 --- /dev/null +++ b/router/src/routes.rs @@ -0,0 +1,109 @@ +//! Route table mirroring pkg/routing/router.go. +//! +//! Design: +//! - All inference routes proxy to the Go backend unchanged. +//! - Alias routes (/v1/, /rerank, /score, /tokenize, /detokenize) prepend +//! "/engines" to the path before proxying (mirrors AliasHandler). +//! - Path normalisation (double-slash collapsing) is applied globally via +//! tower-http NormalizePath middleware. +//! - CORS is applied globally via our custom CorsLayer. 
+ +use axum::extract::{Request, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::routing::{any, get}; +use axum::Router; +use tower::ServiceBuilder; +use tower_http::normalize_path::NormalizePathLayer; + +use crate::cors::CorsLayer; +use crate::proxy::BackendClient; + +/// Shared application state. +#[derive(Clone)] +struct AppState { + backend: BackendClient, + version: String, +} + +/// Build the full axum Router. +pub fn build_router( + backend: BackendClient, + allowed_origins: Vec, + version: String, +) -> Router { + let state = AppState { backend, version }; + + let router = Router::new() + // ── Static / informational routes ────────────────────────────────── + .route("/", get(handle_health)) + .route("/version", get(handle_version)) + // ── Model management ─────────────────────────────────────────────── + .route("/models", any(proxy_direct)) + .route("/models/{*path}", any(proxy_direct)) + // ── Inference engine (direct) ────────────────────────────────────── + .route("/engines/{*path}", any(proxy_direct)) + // ── Path aliases → prepend /engines ─────────────────────────────── + // These mirror the deleted AliasHandler in pkg/middleware/alias.go. 
+ .route("/v1/{*path}", any(proxy_alias)) + .route("/rerank", any(proxy_alias)) + .route("/score", any(proxy_alias)) + .route("/tokenize", any(proxy_alias)) + .route("/detokenize", any(proxy_alias)) + // ── Ollama compatibility layer (/api/) ───────────────────────────── + .route("/api/{*path}", any(proxy_direct)) + // ── Anthropic compatibility layer (/anthropic/) ──────────────────── + .route("/anthropic/{*path}", any(proxy_direct)) + // ── OpenAI Responses API ─────────────────────────────────────────── + .route("/responses", any(proxy_direct)) + .route("/responses/{*path}", any(proxy_direct)) + .route("/v1/responses", any(proxy_direct)) + .route("/v1/responses/{*path}", any(proxy_direct)) + .route("/engines/responses", any(proxy_direct)) + .route("/engines/responses/{*path}", any(proxy_direct)) + // ── Observability ────────────────────────────────────────────────── + .route("/logs", any(proxy_direct)) + .route("/metrics", any(proxy_direct)) + .with_state(state); + + // Apply path normalisation and CORS as outer layers. + let cors = CorsLayer::new(allowed_origins); + let svc = ServiceBuilder::new() + .layer(NormalizePathLayer::trim_trailing_slash()) + .layer(cors); + + router.layer(svc) +} + +// ── Handlers ──────────────────────────────────────────────────────────────── + +/// GET / → "Docker Model Runner is running" +async fn handle_health() -> impl IntoResponse { + (StatusCode::OK, "Docker Model Runner is running") +} + +/// GET /version → {"version":""} +async fn handle_version(State(state): State) -> impl IntoResponse { + let body = format!(r#"{{"version":"{}"}}"#, state.version); + (StatusCode::OK, [("content-type", "application/json")], body) +} + +/// Proxy request to the backend with the path unchanged. 
+async fn proxy_direct(State(state): State, req: Request) -> Response { + let path = req.uri().path_and_query().map_or_else( + || req.uri().path().to_owned(), + |pq| pq.as_str().to_owned(), + ); + state.backend.proxy(req, &path).await +} + +/// Alias handler: prepend "/engines" to the path, then proxy. +/// Mirrors the deleted pkg/middleware/alias.go AliasHandler. +async fn proxy_alias(State(state): State, req: Request) -> Response { + let original = req.uri().path_and_query().map_or_else( + || req.uri().path().to_owned(), + |pq| pq.as_str().to_owned(), + ); + let aliased = format!("/engines{original}"); + state.backend.proxy(req, &aliased).await +}