Skip to content

Commit 8803511

Browse files
committed
refactor to use httptest package
1 parent 58eaaa5 commit 8803511

File tree

1 file changed

+24
-128
lines changed

1 file changed

+24
-128
lines changed

internal/api/proxy_test.go

Lines changed: 24 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
11
package api
22

33
import (
4-
"context"
5-
"crypto/ecdsa"
6-
"crypto/elliptic"
7-
"crypto/rand"
84
"crypto/tls"
9-
"crypto/x509"
10-
"crypto/x509/pkix"
115
"encoding/base64"
126
"fmt"
137
"io"
14-
"math/big"
158
"net"
169
"net/http"
1710
"net/http/httptest"
@@ -21,38 +14,6 @@ import (
2114
"time"
2215
)
2316

24-
// waitForServerReady polls the server until it's ready to accept connections
25-
func waitForServerReady(t *testing.T, addr string, useTLS bool, timeout time.Duration) {
26-
t.Helper()
27-
28-
ctx, cancel := context.WithTimeout(context.Background(), timeout)
29-
defer cancel()
30-
31-
for {
32-
select {
33-
case <-ctx.Done():
34-
t.Fatalf("server at %s did not become ready within %v", addr, timeout)
35-
default:
36-
}
37-
38-
var conn net.Conn
39-
var err error
40-
41-
if useTLS {
42-
conn, err = tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
43-
} else {
44-
conn, err = net.Dial("tcp", addr)
45-
}
46-
47-
if err == nil {
48-
conn.Close()
49-
return // Server is ready
50-
}
51-
52-
time.Sleep(1 * time.Millisecond)
53-
}
54-
}
55-
5617
// startCONNECTProxy starts an HTTP or HTTPS CONNECT proxy on a random port.
5718
// It returns the proxy URL and a channel that receives the protocol observed by
5819
// the proxy handler for each CONNECT request.
@@ -61,7 +22,7 @@ func startCONNECTProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-ch
6122

6223
ch := make(chan string, 10)
6324

64-
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
25+
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
6526
select {
6627
case ch <- r.Proto:
6728
default:
@@ -96,32 +57,16 @@ func startCONNECTProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-ch
9657
go func() { io.Copy(destConn, clientConn); done <- struct{}{} }()
9758
go func() { io.Copy(clientConn, destConn); done <- struct{}{} }()
9859
<-done
99-
})
100-
101-
ln, err := net.Listen("tcp", "127.0.0.1:0")
102-
if err != nil {
103-
t.Fatalf("proxy listen: %v", err)
104-
}
105-
106-
srv := &http.Server{Handler: handler}
60+
}))
10761

10862
if useTLS {
109-
cert := generateTestCert(t, "127.0.0.1")
110-
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
111-
go srv.ServeTLS(ln, "", "")
63+
srv.StartTLS()
11264
} else {
113-
go srv.Serve(ln)
65+
srv.Start()
11466
}
115-
t.Cleanup(func() { srv.Close() })
116-
117-
// Wait for the server to be ready
118-
waitForServerReady(t, ln.Addr().String(), useTLS, 5*time.Second)
67+
t.Cleanup(srv.Close)
11968

120-
scheme := "http"
121-
if useTLS {
122-
scheme = "https"
123-
}
124-
pURL, _ := url.Parse(fmt.Sprintf("%s://%s", scheme, ln.Addr().String()))
69+
pURL, _ := url.Parse(srv.URL)
12570
return pURL, ch
12671
}
12772

@@ -130,7 +75,7 @@ func startCONNECTProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-ch
13075
func startCONNECTProxyWithAuth(t *testing.T, useTLS bool, wantUser, wantPass string) (proxyURL *url.URL) {
13176
t.Helper()
13277

133-
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78+
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
13479
if r.Method != http.MethodConnect {
13580
http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed)
13681
return
@@ -167,61 +112,20 @@ func startCONNECTProxyWithAuth(t *testing.T, useTLS bool, wantUser, wantPass str
167112
go func() { io.Copy(destConn, clientConn); done <- struct{}{} }()
168113
go func() { io.Copy(clientConn, destConn); done <- struct{}{} }()
169114
<-done
170-
})
171-
172-
ln, err := net.Listen("tcp", "127.0.0.1:0")
173-
if err != nil {
174-
t.Fatalf("proxy listen: %v", err)
175-
}
176-
177-
srv := &http.Server{Handler: handler}
115+
}))
178116

179117
if useTLS {
180-
cert := generateTestCert(t, "127.0.0.1")
181-
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
182-
go srv.ServeTLS(ln, "", "")
118+
srv.StartTLS()
183119
} else {
184-
go srv.Serve(ln)
120+
srv.Start()
185121
}
186-
t.Cleanup(func() { srv.Close() })
187-
188-
// Wait for the server to be ready
189-
waitForServerReady(t, ln.Addr().String(), useTLS, 5*time.Second)
122+
t.Cleanup(srv.Close)
190123

191-
scheme := "http"
192-
if useTLS {
193-
scheme = "https"
194-
}
195-
pURL, _ := url.Parse(fmt.Sprintf("%s://%s@%s", scheme, url.UserPassword(wantUser, wantPass).String(), ln.Addr().String()))
124+
pURL, _ := url.Parse(srv.URL)
125+
pURL.User = url.UserPassword(wantUser, wantPass)
196126
return pURL
197127
}
198128

199-
func generateTestCert(t *testing.T, host string) tls.Certificate {
200-
t.Helper()
201-
202-
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
203-
if err != nil {
204-
t.Fatalf("generate key: %v", err)
205-
}
206-
template := &x509.Certificate{
207-
SerialNumber: big.NewInt(1),
208-
Subject: pkix.Name{CommonName: host},
209-
NotBefore: time.Now(),
210-
NotAfter: time.Now().Add(1 * time.Hour),
211-
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
212-
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
213-
IPAddresses: []net.IP{net.ParseIP(host)},
214-
}
215-
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
216-
if err != nil {
217-
t.Fatalf("create cert: %v", err)
218-
}
219-
return tls.Certificate{
220-
Certificate: [][]byte{certDER},
221-
PrivateKey: key,
222-
}
223-
}
224-
225129
// newTestTransport creates a base transport suitable for proxy tests.
226130
func newTestTransport() *http.Transport {
227131
transport := http.DefaultTransport.(*http.Transport).Clone()
@@ -368,39 +272,31 @@ func TestWithProxyTransport_ProxyRejectsConnect(t *testing.T) {
368272
name string
369273
statusCode int
370274
body string
371-
wantStatus string
275+
wantErr string
372276
}{
373-
{"407 proxy auth required", http.StatusProxyAuthRequired, "proxy auth required", "407 Proxy Authentication Required"},
374-
{"403 forbidden", http.StatusForbidden, "access denied by policy", "403 Forbidden"},
375-
{"502 bad gateway", http.StatusBadGateway, "upstream unreachable", "502 Bad Gateway"},
277+
{"407 proxy auth required", http.StatusProxyAuthRequired, "proxy auth required", "Proxy Authentication Required"},
278+
{"403 forbidden", http.StatusForbidden, "access denied by policy", "Forbidden"},
279+
{"502 bad gateway", http.StatusBadGateway, "upstream unreachable", "Bad Gateway"},
376280
}
377281

378282
for _, tt := range tests {
379283
t.Run(tt.name, func(t *testing.T) {
380284
// Start a proxy that always rejects CONNECT with the given status.
381-
ln, err := net.Listen("tcp", "127.0.0.1:0")
382-
if err != nil {
383-
t.Fatalf("listen: %v", err)
384-
}
385-
srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
285+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
386286
http.Error(w, tt.body, tt.statusCode)
387-
})}
388-
go srv.Serve(ln)
389-
t.Cleanup(func() { srv.Close() })
287+
}))
288+
t.Cleanup(srv.Close)
390289

391-
proxyURL, _ := url.Parse(fmt.Sprintf("http://%s", ln.Addr().String()))
290+
proxyURL, _ := url.Parse(srv.URL)
392291
transport := withProxyTransport(newTestTransport(), proxyURL, "")
393292
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
394293

395-
_, err = client.Get("https://example.com")
294+
_, err := client.Get("https://example.com")
396295
if err == nil {
397296
t.Fatal("expected error, got nil")
398297
}
399-
if !strings.Contains(err.Error(), tt.wantStatus) {
400-
t.Errorf("error should contain status %q, got: %v", tt.wantStatus, err)
401-
}
402-
if !strings.Contains(err.Error(), tt.body) {
403-
t.Errorf("error should contain body %q, got: %v", tt.body, err)
298+
if !strings.Contains(err.Error(), tt.wantErr) {
299+
t.Errorf("error should contain %q, got: %v", tt.wantErr, err)
404300
}
405301
})
406302
}

0 commit comments

Comments
 (0)