diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 558fdb9980..604556692c 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -138,6 +138,7 @@ var ( Version: version, Host: viper.GetString("host"), Port: viper.GetInt("port"), + ListenHost: viper.GetString("listen-host"), BaseURL: viper.GetString("base-url"), ResourcePath: viper.GetString("base-path"), ExportTranslations: viper.GetBool("export-translations"), @@ -184,6 +185,7 @@ func init() { // HTTP-specific flags httpCmd.Flags().Int("port", 8082, "HTTP server port") + httpCmd.Flags().String("listen-host", "", "Host the HTTP server binds to (e.g. 127.0.0.1). Empty binds to all interfaces.") httpCmd.Flags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)") httpCmd.Flags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") httpCmd.Flags().Bool("scope-challenge", false, "Enable OAuth scope challenge responses") @@ -204,6 +206,7 @@ func init() { _ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) _ = viper.BindPFlag("port", httpCmd.Flags().Lookup("port")) + _ = viper.BindPFlag("listen-host", httpCmd.Flags().Lookup("listen-host")) _ = viper.BindPFlag("base-url", httpCmd.Flags().Lookup("base-url")) _ = viper.BindPFlag("base-path", httpCmd.Flags().Lookup("base-path")) _ = viper.BindPFlag("scope-challenge", httpCmd.Flags().Lookup("scope-challenge")) diff --git a/pkg/http/server.go b/pkg/http/server.go index 3c9d7679e4..36d3e111bc 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -5,9 +5,11 @@ import ( "fmt" "io" "log/slog" + "net" "net/http" "os" "os/signal" + "strconv" "syscall" "time" @@ -32,9 +34,13 @@ type ServerConfig struct { // GitHub Host to target for API requests (e.g. github.com or github.enterprise.com) Host string - // Port to listen on (default: 8082) + // Port to listen on (default: 8082). Port int + // ListenHost is the host the HTTP server binds to (e.g. "127.0.0.1"). + // When empty, the server binds to all interfaces. Combined with Port. + ListenHost string + // BaseURL is the publicly accessible URL of this server for OAuth resource metadata. // If not set, the server will derive the URL from incoming request headers. BaseURL string @@ -192,7 +198,7 @@ func RunHTTPServer(cfg ServerConfig) error { }) logger.Info("OAuth protected resource endpoints registered", "baseURL", cfg.BaseURL) - addr := fmt.Sprintf(":%d", cfg.Port) + addr := resolveListenAddress(cfg.ListenHost, cfg.Port) httpSvr := http.Server{ Addr: addr, Handler: r, @@ -223,6 +229,16 @@ func RunHTTPServer(cfg ServerConfig) error { return nil } +// resolveListenAddress returns the address string passed to http.Server. +// When host is empty the server binds to all interfaces on the given port; +// otherwise host and port are joined into a single address. +func resolveListenAddress(host string, port int) string { + if host == "" { + return fmt.Sprintf(":%d", port) + } + return net.JoinHostPort(host, strconv.Itoa(port)) +} + func initGlobalToolScopeMap(t translations.TranslationHelperFunc) error { // Build inventory with all tools to extract scope information inv, err := inventory.NewBuilder(). diff --git a/pkg/http/server_test.go b/pkg/http/server_test.go index 1804134651..b509876d9e 100644 --- a/pkg/http/server_test.go +++ b/pkg/http/server_test.go @@ -125,6 +125,47 @@ func TestCreateHTTPFeatureChecker(t *testing.T) { } } +func TestResolveListenAddress(t *testing.T) { + tests := []struct { + name string + host string + port int + want string + }{ + { + name: "empty host falls back to :port", + host: "", + port: 8082, + want: ":8082", + }, + { + name: "ipv4 host is joined with port", + host: "127.0.0.1", + port: 9090, + want: "127.0.0.1:9090", + }, + { + name: "ipv6 host is bracketed and joined with port", + host: "::1", + port: 9090, + want: "[::1]:9090", + }, + { + name: "hostname is joined with port", + host: "localhost", + port: 8082, + want: "localhost:8082", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resolveListenAddress(tt.host, tt.port) + assert.Equal(t, tt.want, got) + }) + } +} + func TestHeaderAllowedFeatureFlagsMatchesAllowed(t *testing.T) { // Ensure HeaderAllowedFeatureFlags delegates to AllowedFeatureFlags allowed := github.HeaderAllowedFeatureFlags()