11package api
22
33import (
4- "context"
5- "crypto/ecdsa"
6- "crypto/elliptic"
7- "crypto/rand"
84 "crypto/tls"
9- "crypto/x509"
10- "crypto/x509/pkix"
115 "encoding/base64"
126 "fmt"
137 "io"
14- "math/big"
158 "net"
169 "net/http"
1710 "net/http/httptest"
@@ -21,38 +14,6 @@ import (
2114 "time"
2215)
2316
24- // waitForServerReady polls the server until it's ready to accept connections
25- func waitForServerReady (t * testing.T , addr string , useTLS bool , timeout time.Duration ) {
26- t .Helper ()
27-
28- ctx , cancel := context .WithTimeout (context .Background (), timeout )
29- defer cancel ()
30-
31- for {
32- select {
33- case <- ctx .Done ():
34- t .Fatalf ("server at %s did not become ready within %v" , addr , timeout )
35- default :
36- }
37-
38- var conn net.Conn
39- var err error
40-
41- if useTLS {
42- conn , err = tls .Dial ("tcp" , addr , & tls.Config {InsecureSkipVerify : true })
43- } else {
44- conn , err = net .Dial ("tcp" , addr )
45- }
46-
47- if err == nil {
48- conn .Close ()
49- return // Server is ready
50- }
51-
52- time .Sleep (1 * time .Millisecond )
53- }
54- }
55-
5617// startCONNECTProxy starts an HTTP or HTTPS CONNECT proxy on a random port.
5718// It returns the proxy URL and a channel that receives the protocol observed by
5819// the proxy handler for each CONNECT request.
@@ -61,7 +22,7 @@ func startCONNECTProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-ch
6122
6223 ch := make (chan string , 10 )
6324
64- handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
25+ srv := httptest . NewUnstartedServer ( http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
6526 select {
6627 case ch <- r .Proto :
6728 default :
@@ -96,32 +57,16 @@ func startCONNECTProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-ch
9657 go func () { io .Copy (destConn , clientConn ); done <- struct {}{} }()
9758 go func () { io .Copy (clientConn , destConn ); done <- struct {}{} }()
9859 <- done
99- })
100-
101- ln , err := net .Listen ("tcp" , "127.0.0.1:0" )
102- if err != nil {
103- t .Fatalf ("proxy listen: %v" , err )
104- }
105-
106- srv := & http.Server {Handler : handler }
60+ }))
10761
10862 if useTLS {
109- cert := generateTestCert (t , "127.0.0.1" )
110- srv .TLSConfig = & tls.Config {Certificates : []tls.Certificate {cert }}
111- go srv .ServeTLS (ln , "" , "" )
63+ srv .StartTLS ()
11264 } else {
113- go srv .Serve ( ln )
65+ srv .Start ( )
11466 }
115- t .Cleanup (func () { srv .Close () })
116-
117- // Wait for the server to be ready
118- waitForServerReady (t , ln .Addr ().String (), useTLS , 5 * time .Second )
67+ t .Cleanup (srv .Close )
11968
120- scheme := "http"
121- if useTLS {
122- scheme = "https"
123- }
124- pURL , _ := url .Parse (fmt .Sprintf ("%s://%s" , scheme , ln .Addr ().String ()))
69+ pURL , _ := url .Parse (srv .URL )
12570 return pURL , ch
12671}
12772
@@ -130,7 +75,7 @@ func startCONNECTProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-ch
13075func startCONNECTProxyWithAuth (t * testing.T , useTLS bool , wantUser , wantPass string ) (proxyURL * url.URL ) {
13176 t .Helper ()
13277
133- handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
78+ srv := httptest . NewUnstartedServer ( http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
13479 if r .Method != http .MethodConnect {
13580 http .Error (w , "expected CONNECT" , http .StatusMethodNotAllowed )
13681 return
@@ -167,61 +112,20 @@ func startCONNECTProxyWithAuth(t *testing.T, useTLS bool, wantUser, wantPass str
167112 go func () { io .Copy (destConn , clientConn ); done <- struct {}{} }()
168113 go func () { io .Copy (clientConn , destConn ); done <- struct {}{} }()
169114 <- done
170- })
171-
172- ln , err := net .Listen ("tcp" , "127.0.0.1:0" )
173- if err != nil {
174- t .Fatalf ("proxy listen: %v" , err )
175- }
176-
177- srv := & http.Server {Handler : handler }
115+ }))
178116
179117 if useTLS {
180- cert := generateTestCert (t , "127.0.0.1" )
181- srv .TLSConfig = & tls.Config {Certificates : []tls.Certificate {cert }}
182- go srv .ServeTLS (ln , "" , "" )
118+ srv .StartTLS ()
183119 } else {
184- go srv .Serve ( ln )
120+ srv .Start ( )
185121 }
186- t .Cleanup (func () { srv .Close () })
187-
188- // Wait for the server to be ready
189- waitForServerReady (t , ln .Addr ().String (), useTLS , 5 * time .Second )
122+ t .Cleanup (srv .Close )
190123
191- scheme := "http"
192- if useTLS {
193- scheme = "https"
194- }
195- pURL , _ := url .Parse (fmt .Sprintf ("%s://%s@%s" , scheme , url .UserPassword (wantUser , wantPass ).String (), ln .Addr ().String ()))
124+ pURL , _ := url .Parse (srv .URL )
125+ pURL .User = url .UserPassword (wantUser , wantPass )
196126 return pURL
197127}
198128
199- func generateTestCert (t * testing.T , host string ) tls.Certificate {
200- t .Helper ()
201-
202- key , err := ecdsa .GenerateKey (elliptic .P256 (), rand .Reader )
203- if err != nil {
204- t .Fatalf ("generate key: %v" , err )
205- }
206- template := & x509.Certificate {
207- SerialNumber : big .NewInt (1 ),
208- Subject : pkix.Name {CommonName : host },
209- NotBefore : time .Now (),
210- NotAfter : time .Now ().Add (1 * time .Hour ),
211- KeyUsage : x509 .KeyUsageDigitalSignature | x509 .KeyUsageKeyEncipherment ,
212- ExtKeyUsage : []x509.ExtKeyUsage {x509 .ExtKeyUsageServerAuth },
213- IPAddresses : []net.IP {net .ParseIP (host )},
214- }
215- certDER , err := x509 .CreateCertificate (rand .Reader , template , template , & key .PublicKey , key )
216- if err != nil {
217- t .Fatalf ("create cert: %v" , err )
218- }
219- return tls.Certificate {
220- Certificate : [][]byte {certDER },
221- PrivateKey : key ,
222- }
223- }
224-
225129// newTestTransport creates a base transport suitable for proxy tests.
226130func newTestTransport () * http.Transport {
227131 transport := http .DefaultTransport .(* http.Transport ).Clone ()
@@ -368,39 +272,31 @@ func TestWithProxyTransport_ProxyRejectsConnect(t *testing.T) {
368272 name string
369273 statusCode int
370274 body string
371- wantStatus string
275+ wantErr string
372276 }{
373- {"407 proxy auth required" , http .StatusProxyAuthRequired , "proxy auth required" , "407 Proxy Authentication Required" },
374- {"403 forbidden" , http .StatusForbidden , "access denied by policy" , "403 Forbidden" },
375- {"502 bad gateway" , http .StatusBadGateway , "upstream unreachable" , "502 Bad Gateway" },
277+ {"407 proxy auth required" , http .StatusProxyAuthRequired , "proxy auth required" , "Proxy Authentication Required" },
278+ {"403 forbidden" , http .StatusForbidden , "access denied by policy" , "Forbidden" },
279+ {"502 bad gateway" , http .StatusBadGateway , "upstream unreachable" , "Bad Gateway" },
376280 }
377281
378282 for _ , tt := range tests {
379283 t .Run (tt .name , func (t * testing.T ) {
380284 // Start a proxy that always rejects CONNECT with the given status.
381- ln , err := net .Listen ("tcp" , "127.0.0.1:0" )
382- if err != nil {
383- t .Fatalf ("listen: %v" , err )
384- }
385- srv := & http.Server {Handler : http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
285+ srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
386286 http .Error (w , tt .body , tt .statusCode )
387- })}
388- go srv .Serve (ln )
389- t .Cleanup (func () { srv .Close () })
287+ }))
288+ t .Cleanup (srv .Close )
390289
391- proxyURL , _ := url .Parse (fmt . Sprintf ( "http://%s" , ln . Addr (). String ()) )
290+ proxyURL , _ := url .Parse (srv . URL )
392291 transport := withProxyTransport (newTestTransport (), proxyURL , "" )
393292 client := & http.Client {Transport : transport , Timeout : 10 * time .Second }
394293
395- _ , err = client .Get ("https://example.com" )
294+ _ , err : = client .Get ("https://example.com" )
396295 if err == nil {
397296 t .Fatal ("expected error, got nil" )
398297 }
399- if ! strings .Contains (err .Error (), tt .wantStatus ) {
400- t .Errorf ("error should contain status %q, got: %v" , tt .wantStatus , err )
401- }
402- if ! strings .Contains (err .Error (), tt .body ) {
403- t .Errorf ("error should contain body %q, got: %v" , tt .body , err )
298+ if ! strings .Contains (err .Error (), tt .wantErr ) {
299+ t .Errorf ("error should contain %q, got: %v" , tt .wantErr , err )
404300 }
405301 })
406302 }
0 commit comments