diff --git a/cmd/src/batch_remote.go b/cmd/src/batch_remote.go index 86ad3650d1..c1d7209675 100644 --- a/cmd/src/batch_remote.go +++ b/cmd/src/batch_remote.go @@ -5,7 +5,6 @@ import ( "flag" "fmt" cliLog "log" - "strings" "time" "github.com/sourcegraph/sourcegraph/lib/errors" @@ -155,13 +154,14 @@ Examples: } ui.ExecutingBatchSpecSuccess() - executionURL := fmt.Sprintf( - "%s/%s/batch-changes/%s/executions/%s", - cfg.endpointURL, - strings.TrimPrefix(namespace.URL, "/"), - batchChangeName, - batchSpecID, - ) + executionURL := cfg.endpointURL.JoinPath( + fmt.Sprintf( + "%s/batch-changes/%s/executions/%s", + namespace.URL, + batchChangeName, + batchSpecID, + ), + ).String() ui.RemoteSuccess(executionURL) return nil diff --git a/cmd/src/main.go b/cmd/src/main.go index 8bd1a5fe77..3f542b97c5 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -7,6 +7,7 @@ import ( "io" "log" "net" + "net/http" "net/url" "os" "path/filepath" @@ -169,7 +170,8 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { return api.NewClient(opts) } -// readConfig reads the config file from the given path. +// readConfig reads the config from the standard config file, the (deprecated) user-specified config file, +// the environment variables, and the (deprecated) command-line flags. func readConfig() (*config, error) { cfgFile := *configPath userSpecified := *configPath != "" @@ -282,12 +284,20 @@ func readConfig() (*config, error) { return nil, errors.Newf("invalid proxy configuration: %w", err) } if !isValidUDS { - return nil, errors.Newf("invalid proxy socket: %s", path) + return nil, errors.Newf("Invalid proxy socket: %s", path) } cfg.proxyPath = path } else { return nil, errors.Newf("invalid proxy endpoint: %s", proxyStr) } + } else { + // no SRC_PROXY; check for the standard proxy env variables HTTP_PROXY, HTTPS_PROXY, and NO_PROXY + if u, err := http.ProxyFromEnvironment(&http.Request{URL: cfg.endpointURL}); err != nil { + // when there's an error, the value for the env variable is not a legit URL + return nil, errors.Newf("invalid HTTP_PROXY or HTTPS_PROXY value: %w", err) + } else { + cfg.proxyURL = u + } } cfg.additionalHeaders = parseAdditionalHeaders() @@ -319,7 +329,7 @@ func isValidUnixSocket(path string) (bool, error) { if os.IsNotExist(err) { return false, nil } - return false, errors.Newf("not a UNIX Domain Socket: %v: %w", path, err) + return false, errors.Newf("not a UNIX domain socket: %v: %w", path, err) } defer conn.Close() diff --git a/internal/api/api.go b/internal/api/api.go index c38af19d25..f3dbc703a3 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -90,33 +90,36 @@ type ClientOpts struct { } func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper { - var transport http.RoundTripper - { - tp := http.DefaultTransport.(*http.Transport).Clone() + transport := http.DefaultTransport.(*http.Transport).Clone() - if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify { - tp.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - - if tp.TLSClientConfig == nil { - tp.TLSClientConfig = &tls.Config{} - } + if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } - if opts.ProxyURL != nil || opts.ProxyPath != "" { - tp = withProxyTransport(tp, opts.ProxyURL, opts.ProxyPath) - } + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } - transport = tp + if opts.ProxyPath != "" || opts.ProxyURL != nil { + // Use our custom dialer for proxied connections. + // A custom dialer is not always needed - the connection libraries will handle HTTP(S)_PROXY-defined proxies + // (Go supports http, https, socks5, and socks5h proxies via HTTP(S)_PROXY), + // but we're also supporting proxies defined via SRC_PROXY, which can include UDS proxies, + // and connecting to TLS-enabled proxies adds an additional wrinkle when using HTTP/2. + transport = withProxyTransport(transport, opts.ProxyURL, opts.ProxyPath) } + // For http:// and socks5:// proxies, the cloned + // transport's default Proxy handles them correctly without intervention. + + var rt http.RoundTripper = transport if opts.AccessToken == "" && opts.OAuthToken != nil { - transport = &oauth.Transport{ + rt = &oauth.Transport{ Base: transport, Token: opts.OAuthToken, } } - - return transport + return rt } // NewClient creates a new API client. diff --git a/internal/api/proxy.go b/internal/api/proxy.go index 9589b9beb5..605537f311 100644 --- a/internal/api/proxy.go +++ b/internal/api/proxy.go @@ -5,18 +5,46 @@ import ( "context" "crypto/tls" "encoding/base64" - "fmt" + "io" "net" "net/http" "net/url" + "sync" + + "github.com/sourcegraph/sourcegraph/lib/errors" ) +type connWithBufferedReader struct { + net.Conn + r *bufio.Reader + mu sync.Mutex +} + +func (c *connWithBufferedReader) Read(p []byte) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + return c.r.Read(p) +} + +// proxyDialAddr returns proxyURL.Host with a default port appended if one is +// not already present (443 for https, 80 for http). +func proxyDialAddr(proxyURL *url.URL) string { + // net.SplitHostPort returns an error when the input doesn't contain a port + if _, _, err := net.SplitHostPort(proxyURL.Host); err == nil { + return proxyURL.Host + } + if proxyURL.Scheme == "https" { + return net.JoinHostPort(proxyURL.Hostname(), "443") + } + return net.JoinHostPort(proxyURL.Hostname(), "80") +} + // withProxyTransport modifies the given transport to handle proxying of unix, socks5 and http connections. // // Note: baseTransport is considered to be a clone created with transport.Clone() // -// - If a the proxyPath is not empty, a unix socket proxy is created. -// - Otherwise, the proxyURL is used to determine if we should proxy socks5 / http connections +// - If proxyPath is not empty, a unix socket proxy is created. +// - Otherwise, proxyURL is used to determine if we should proxy socks5 / http connections func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyPath string) *http.Transport { handshakeTLS := func(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { // Extract the hostname (without the port) for TLS SNI @@ -24,13 +52,19 @@ func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyP if err != nil { return nil, err } - tlsConn := tls.Client(conn, &tls.Config{ - ServerName: host, - // Pull InsecureSkipVerify from the target host transport - // so that insecure-skip-verify flag settings are honored for the proxy server - InsecureSkipVerify: baseTransport.TLSClientConfig.InsecureSkipVerify, - }) + cfg := baseTransport.TLSClientConfig.Clone() + if cfg.ServerName == "" { + cfg.ServerName = host + } + // Preserve HTTP/2 negotiation to the origin when ForceAttemptHTTP2 + // is enabled. Without this, the manual TLS handshake would not + // advertise h2 via ALPN, silently forcing HTTP/1.1. + if baseTransport.ForceAttemptHTTP2 && len(cfg.NextProtos) == 0 { + cfg.NextProtos = []string{"h2", "http/1.1"} + } + tlsConn := tls.Client(conn, cfg) if err := tlsConn.HandshakeContext(ctx); err != nil { + tlsConn.Close() return nil, err } return tlsConn, nil @@ -54,67 +88,79 @@ func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyP baseTransport.Proxy = nil } else if proxyURL != nil { switch proxyURL.Scheme { - case "socks5", "socks5h": - // SOCKS proxies work out of the box - no need to manually dial + case "http", "socks5", "socks5h": + // HTTP and SOCKS proxies work out of the box - no need to manually dial baseTransport.Proxy = http.ProxyURL(proxyURL) - case "http", "https": + case "https": dial := func(ctx context.Context, network, addr string) (net.Conn, error) { - // Dial the proxy - d := net.Dialer{} - conn, err := d.DialContext(ctx, "tcp", proxyURL.Host) + // Dial the proxy. For https:// proxies, we TLS-connect to the + // proxy itself and force ALPN to HTTP/1.1 to prevent Go from + // negotiating HTTP/2 for the CONNECT tunnel. Many proxy servers + // don't support HTTP/2 CONNECT, and Go's default Transport.Proxy + // would negotiate h2 via ALPN when TLS-connecting to an https:// + // proxy, causing "bogus greeting" errors. For http:// proxies, + // CONNECT is always HTTP/1.1 over plain TCP so this isn't needed. + // The target connection (e.g. to sourcegraph.com) still negotiates + // HTTP/2 normally through the established tunnel. + proxyAddr := proxyDialAddr(proxyURL) + + var conn net.Conn + var err error + if proxyURL.Scheme == "https" { + raw, dialErr := (&net.Dialer{}).DialContext(ctx, "tcp", proxyAddr) + if dialErr != nil { + return nil, dialErr + } + cfg := baseTransport.TLSClientConfig.Clone() + cfg.NextProtos = []string{"http/1.1"} + if cfg.ServerName == "" { + cfg.ServerName = proxyURL.Hostname() + } + tlsConn := tls.Client(raw, cfg) + if err := tlsConn.HandshakeContext(ctx); err != nil { + raw.Close() + return nil, err + } + conn = tlsConn + } else { + conn, err = (&net.Dialer{}).DialContext(ctx, "tcp", proxyAddr) + } if err != nil { return nil, err } - // this is the whole point of manually dialing the HTTP(S) proxy: - // being able to force HTTP/1. - // When relying on Transport.Proxy, the protocol is always HTTP/2, - // but many proxy servers don't support HTTP/2. - // We don't want to disable HTTP/2 in general because we want to use it when - // connecting to the Sourcegraph API, using HTTP/1 for the proxy connection only. - protocol := "HTTP/1.1" - - // CONNECT is the HTTP method used to set up a tunneling connection with a proxy - method := "CONNECT" - - // Manually writing out the HTTP commands because it's not complicated, - // and http.Request has some janky behavior: - // - ignores the Proto field and hard-codes the protocol to HTTP/1.1 - // - ignores the Host Header (Header.Set("Host", host)) and uses URL.Host instead. - // - When the Host field is set, overrides the URL field - connectReq := fmt.Sprintf("%s %s %s\r\n", method, addr, protocol) - - // A Host header is required per RFC 2616, section 14.23 - connectReq += fmt.Sprintf("Host: %s\r\n", addr) - - // use authentication if proxy credentials are present + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: make(http.Header), + } if proxyURL.User != nil { password, _ := proxyURL.User.Password() auth := base64.StdEncoding.EncodeToString([]byte(proxyURL.User.Username() + ":" + password)) - connectReq += fmt.Sprintf("Proxy-Authorization: Basic %s\r\n", auth) + connectReq.Header.Set("Proxy-Authorization", "Basic "+auth) } - - // finish up with an extra carriage return + newline, as per RFC 7230, section 3 - connectReq += "\r\n" - - // Send the CONNECT request to the proxy to establish the tunnel - if _, err := conn.Write([]byte(connectReq)); err != nil { + if err := connectReq.Write(conn); err != nil { conn.Close() return nil, err } - // Read and check the response from the proxy - resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, nil) if err != nil { conn.Close() return nil, err } if resp.StatusCode != http.StatusOK { + // For non-200, it's safe/appropriate to close the body (it’s a real response body here). + // Try to read a bit (4k bytes) to include in the error message. + b, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10)) + resp.Body.Close() conn.Close() - return nil, fmt.Errorf("failed to connect to proxy %v: %v", proxyURL, resp.Status) + return nil, errors.Newf("failed to connect to proxy %s: %s: %q", proxyURL.Redacted(), resp.Status, b) } - resp.Body.Close() - return conn, nil + // 200 CONNECT: do NOT resp.Body.Close(); it would interfere with the tunnel. + return &connWithBufferedReader{Conn: conn, r: br}, nil } dialTLS := func(ctx context.Context, network, addr string) (net.Conn, error) { // Dial the underlying connection through the proxy @@ -126,7 +172,7 @@ func withProxyTransport(baseTransport *http.Transport, proxyURL *url.URL, proxyP } baseTransport.DialContext = dial baseTransport.DialTLSContext = dialTLS - // clear out any system proxy settings + // clear out the system proxy because we're defining our own dialers baseTransport.Proxy = nil } } diff --git a/internal/api/proxy_test.go b/internal/api/proxy_test.go new file mode 100644 index 0000000000..29b8f5345b --- /dev/null +++ b/internal/api/proxy_test.go @@ -0,0 +1,448 @@ +package api + +import ( + "crypto/tls" + "encoding/base64" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +// startProxy starts an HTTP or HTTPS CONNECT proxy on a random port. +// It returns the proxy URL and a channel that receives the protocol observed by +// the proxy handler for each CONNECT request. +func startProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-chan string) { + t.Helper() + + ch := make(chan string, 10) + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case ch <- r.Proto: + default: + } + + if r.Method != http.MethodConnect { + http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed) + return + } + + destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer destConn.Close() + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + clientConn, _, err := hijacker.Hijack() + if err != nil { + return + } + defer clientConn.Close() + + done := make(chan struct{}, 2) + go func() { io.Copy(destConn, clientConn); done <- struct{}{} }() + go func() { io.Copy(clientConn, destConn); done <- struct{}{} }() + <-done + // Close both sides so the remaining goroutine unblocks. + clientConn.Close() + destConn.Close() + <-done + })) + + if useTLS { + srv.StartTLS() + } else { + srv.Start() + } + t.Cleanup(srv.Close) + + pURL, _ := url.Parse(srv.URL) + return pURL, ch +} + +// startProxyWithAuth is like startProxy but requires +// Proxy-Authorization with the given username and password. +func startProxyWithAuth(t *testing.T, useTLS bool, wantUser, wantPass string) (proxyURL *url.URL) { + t.Helper() + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed) + return + } + + authHeader := r.Header.Get("Proxy-Authorization") + wantAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(wantUser+":"+wantPass)) + if authHeader != wantAuth { + http.Error(w, "proxy auth required", http.StatusProxyAuthRequired) + return + } + + destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer destConn.Close() + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijacking not supported", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + clientConn, _, err := hijacker.Hijack() + if err != nil { + return + } + defer clientConn.Close() + + done := make(chan struct{}, 2) + go func() { io.Copy(destConn, clientConn); done <- struct{}{} }() + go func() { io.Copy(clientConn, destConn); done <- struct{}{} }() + <-done + clientConn.Close() + destConn.Close() + <-done + })) + + if useTLS { + srv.StartTLS() + } else { + srv.Start() + } + t.Cleanup(srv.Close) + + pURL, _ := url.Parse(srv.URL) + pURL.User = url.UserPassword(wantUser, wantPass) + return pURL +} + +// newTestTransport creates a base transport suitable for proxy tests. +func newTestTransport() *http.Transport { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + return transport +} + +// startTargetServer starts an HTTPS server (with HTTP/2 enabled) that +// responds with "ok" to GET /. +func startTargetServer(t *testing.T) *httptest.Server { + t.Helper() + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "ok") + })) + srv.EnableHTTP2 = true + srv.StartTLS() + t.Cleanup(srv.Close) + return srv +} + +func TestWithProxyTransport_HTTPProxy(t *testing.T) { + target := startTargetServer(t) + proxyURL, obsCh := startProxy(t, false) + + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through http proxy: %v", err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if got := strings.TrimSpace(string(body)); got != "ok" { + t.Errorf("expected body 'ok', got %q", got) + } + + select { + case proto := <-obsCh: + if proto != "HTTP/1.1" { + t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto) + } + case <-time.After(2 * time.Second): + t.Fatal("proxy handler was never invoked") + } +} + +func TestWithProxyTransport_HTTPSProxy(t *testing.T) { + target := startTargetServer(t) + proxyURL, obsCh := startProxy(t, true) + + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through https proxy: %v", err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if got := strings.TrimSpace(string(body)); got != "ok" { + t.Errorf("expected body 'ok', got %q", got) + } + + select { + case proto := <-obsCh: + if proto != "HTTP/1.1" { + t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto) + } + case <-time.After(2 * time.Second): + t.Fatal("proxy handler was never invoked") + } +} + +func TestWithProxyTransport_ProxyAuth(t *testing.T) { + target := startTargetServer(t) + + t.Run("http proxy with auth", func(t *testing.T) { + proxyURL := startProxyWithAuth(t, false, "user", "pass") + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through authenticated http proxy: %v", err) + } + defer resp.Body.Close() + if _, err := io.ReadAll(resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("https proxy with auth", func(t *testing.T) { + // Under the race detector on resource-constrained CI hosts + // the TLS handshake to the proxy can sporadically fail with + // "first record does not look like a TLS handshake" / EOF. + // Retry with a fresh proxy + transport to tolerate this. + var resp *http.Response + var lastErr error + for attempt := range 3 { + proxyURL := startProxyWithAuth(t, true, "user", "s3cret") + transport := withProxyTransport(newTestTransport(), proxyURL, "") + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + resp, lastErr = client.Get(target.URL) + transport.CloseIdleConnections() + if lastErr == nil { + break + } + t.Logf("attempt %d: %v", attempt+1, lastErr) + } + if lastErr != nil { + t.Fatalf("GET through authenticated https proxy (after retries): %v", lastErr) + } + defer resp.Body.Close() + if _, err := io.ReadAll(resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) +} + +func TestWithProxyTransport_HTTPSProxy_HTTP2ToOrigin(t *testing.T) { + // Verify that when tunneling through an HTTPS proxy, the connection to + // the origin target still negotiates HTTP/2 (not downgraded to HTTP/1.1). + target := startTargetServer(t) + proxyURL, _ := startProxy(t, true) + + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + resp, err := client.Get(target.URL) + if err != nil { + t.Fatalf("GET through https proxy: %v", err) + } + defer resp.Body.Close() + if _, err := io.ReadAll(resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + + if resp.ProtoMajor != 2 { + t.Errorf("expected HTTP/2 to origin, got %s", resp.Proto) + } +} + +func TestWithProxyTransport_HandshakeFailureClosesConn(t *testing.T) { + // Verify that when the TLS handshake to the origin fails, the underlying + // tunnel connection is closed (regression test for tlsConn.Close on error). + // + // A plain TCP listener acts as the target. The proxy CONNECT succeeds + // (TCP-level), but the subsequent TLS handshake fails because the target + // is not a TLS server. If handshakeTLS properly closes tlsConn on failure, + // the tunnel tears down and the target sees the connection close. + connClosed := make(chan struct{}) + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + // Send non-TLS bytes so the client handshake fails immediately + // rather than waiting for a timeout. + conn.Write([]byte("not-tls\n")) + // Drain until the remote side closes the tunnel. + io.Copy(io.Discard, conn) + close(connClosed) + }() + + proxyURL, _ := startProxy(t, true) + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 5 * time.Second} + + _, err = client.Get("https://" + ln.Addr().String()) + if err == nil { + t.Fatal("expected TLS handshake error, got nil") + } + + select { + case <-connClosed: + // Connection was properly cleaned up. + case <-time.After(5 * time.Second): + t.Fatal("connection was not closed after TLS handshake failure") + } +} + +func TestWithProxyTransport_ProxyRejectsConnect(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantErr string + }{ + {"407 proxy auth required", http.StatusProxyAuthRequired, "proxy auth required", "Proxy Authentication Required"}, + {"403 forbidden", http.StatusForbidden, "access denied by policy", "Forbidden"}, + {"502 bad gateway", http.StatusBadGateway, "upstream unreachable", "Bad Gateway"}, + } + + // Use a local target so we never depend on external DNS. + target := startTargetServer(t) + + for _, tt := range tests { + t.Run("http proxy/"+tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, tt.body, tt.statusCode) + })) + t.Cleanup(srv.Close) + + proxyURL, _ := url.Parse(srv.URL) + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + _, err := client.Get(target.URL) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error should contain %q, got: %v", tt.wantErr, err) + } + }) + + t.Run("https proxy/"+tt.name, func(t *testing.T) { + // The HTTPS proxy path uses a custom dialer with its own error + // formatting that includes the status, body, and redacted proxy URL. + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, tt.body, tt.statusCode) + })) + srv.StartTLS() + t.Cleanup(srv.Close) + + proxyURL, _ := url.Parse(srv.URL) + transport := withProxyTransport(newTestTransport(), proxyURL, "") + t.Cleanup(transport.CloseIdleConnections) + client := &http.Client{Transport: transport, Timeout: 10 * time.Second} + + _, err := client.Get(target.URL) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), fmt.Sprintf("%d", tt.statusCode)) { + t.Errorf("error should contain status code %d, got: %v", tt.statusCode, err) + } + if !strings.Contains(err.Error(), tt.body) { + t.Errorf("error should contain body %q, got: %v", tt.body, err) + } + }) + } +} + +func TestProxyDialAddr(t *testing.T) { + tests := []struct { + name string + url string + want string + }{ + {"https with port", "https://proxy.example.com:8443", "proxy.example.com:8443"}, + {"https without port", "https://proxy.example.com", "proxy.example.com:443"}, + {"http with port", "http://proxy.example.com:8080", "proxy.example.com:8080"}, + {"http without port", "http://proxy.example.com", "proxy.example.com:80"}, + {"ipv4 with port", "http://192.168.1.100:3128", "192.168.1.100:3128"}, + {"ipv4 without port https", "https://10.0.0.1", "10.0.0.1:443"}, + {"ipv4 without port http", "http://172.16.0.5", "172.16.0.5:80"}, + {"ipv6 with port", "http://[::1]:8080", "[::1]:8080"}, + {"ipv6 without port https", "https://[2001:db8::1]", "[2001:db8::1]:443"}, + {"ipv6 without port http", "http://[fe80::1]", "[fe80::1]:80"}, + {"localhost with port", "http://localhost:9090", "localhost:9090"}, + {"localhost without port https", "https://localhost", "localhost:443"}, + {"localhost without port http", "http://localhost", "localhost:80"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, err := url.ParseRequestURI(tt.url) + if err != nil { + t.Fatalf("parse URL: %v", err) + } + got := proxyDialAddr(u) + if got != tt.want { + t.Errorf("proxyHostPort(%s) = %q, want %q", tt.url, got, tt.want) + } + }) + } +}