Skip to content

Commit 25accc8

Browse files
committed
Client code, streamlined handler examples
1 parent fe1983d commit 25accc8

File tree

3 files changed

+323
-74
lines changed

3 files changed

+323
-74
lines changed

client.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package httpsign
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
)
8+
9+
// Client represents an HTTP client that optionally signs requests and optionally verifies responses.
10+
// The signer may be nil to avoid signing, and so forth.
11+
// The fetchVerifier callback allows to generate a verifier based on the particular response.
12+
// Either verifier or fetchVerifier may be specified, but not both.
13+
// The client embeds an http.Client, which may be http.DefaultClient or any other.
14+
type Client struct {
15+
sigName string
16+
signer *Signer
17+
verifier *Verifier
18+
fetchVerifier func(res *http.Response, req *http.Request) (sigName string, verifier *Verifier)
19+
http.Client
20+
}
21+
22+
// NewClient constructs a new client, with the flexibility of including a custom http.Client.
23+
func NewClient(sigName string, signer *Signer, verifier *Verifier, fetchVerifier func(res *http.Response, req *http.Request) (sigName string, verifier *Verifier), client http.Client) *Client {
24+
return &Client{sigName: sigName, signer: signer, verifier: verifier, fetchVerifier: fetchVerifier, Client: client}
25+
}
26+
27+
// NewDefaultClient constructs a new client, based on the http.DefaultClient.
28+
func NewDefaultClient(sigName string, signer *Signer, verifier *Verifier, fetchVerifier func(res *http.Response, req *http.Request) (sigName string, verifier *Verifier)) *Client {
29+
return NewClient(sigName, signer, verifier, fetchVerifier, *http.DefaultClient)
30+
}
31+
32+
func validateClient(c *Client) error {
33+
if c == nil {
34+
return fmt.Errorf("nil client")
35+
}
36+
if c.verifier != nil && c.fetchVerifier != nil {
37+
return fmt.Errorf("at most one of \"verifier\" and \"fetchVerifier\" must be set")
38+
}
39+
return nil
40+
}
41+
42+
// Do sends an http.Request, with optional signing and/or verification. Errors may be produced by any of
43+
// these operations.
44+
func (c *Client) Do(req *http.Request) (*http.Response, error) {
45+
if err := validateClient(c); err != nil {
46+
return nil, err
47+
}
48+
if c.signer != nil {
49+
sigInput, sig, err := SignRequest(c.sigName, *c.signer, req)
50+
if err != nil {
51+
return nil, fmt.Errorf("failed to sign request: %v", err)
52+
}
53+
req.Header.Add("Signature", sig)
54+
req.Header.Add("Signature-Input", sigInput)
55+
}
56+
57+
// Send the request, receive response
58+
res, err := c.Client.Do(req)
59+
if err != nil {
60+
return res, err
61+
}
62+
63+
if c.verifier != nil {
64+
_, err := VerifyResponse(c.sigName, *c.verifier, res)
65+
if err != nil {
66+
return nil, err
67+
}
68+
} else if c.fetchVerifier != nil {
69+
sigName, verifier := c.fetchVerifier(res, req)
70+
if err != nil {
71+
return nil, err
72+
}
73+
_, err := VerifyResponse(sigName, *verifier, res)
74+
if err != nil {
75+
return nil, err
76+
}
77+
}
78+
return res, nil
79+
}
80+
81+
// Get sends an HTTP GET, a wrapper for Do.
82+
func (c *Client) Get(url string) (res *http.Response, err error) {
83+
req, err := http.NewRequest("GET", url, nil)
84+
if err != nil {
85+
return nil, err
86+
}
87+
return c.Do(req)
88+
}
89+
90+
// Head sends an HTTP HEAD, a wrapper for Do.
91+
func (c *Client) Head(url string) (res *http.Response, err error) {
92+
req, err := http.NewRequest("HEAD", url, nil)
93+
if err != nil {
94+
return nil, err
95+
}
96+
return c.Do(req)
97+
}
98+
99+
// Post sends an HTTP POST, a wrapper for Do.
100+
func (c *Client) Post(url, contentType string, body io.Reader) (res *http.Response, err error) {
101+
req, err := http.NewRequest("POST", url, body)
102+
if err != nil {
103+
return nil, err
104+
}
105+
req.Header.Set("Content-Type", contentType)
106+
return c.Do(req)
107+
}

client_test.go

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package httpsign
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
8+
"reflect"
9+
"testing"
10+
)
11+
12+
func TestClient_Get(t *testing.T) {
13+
type fields struct {
14+
sigName string
15+
signer *Signer
16+
verifier *Verifier
17+
fetchVerifier func(res *http.Response, req *http.Request) (sigName string, verifier *Verifier)
18+
Client http.Client
19+
}
20+
type args struct {
21+
url string
22+
}
23+
tests := []struct {
24+
name string
25+
fields fields
26+
args args
27+
wantRes string
28+
wantErr bool
29+
}{
30+
{
31+
name: "from Google",
32+
fields: fields{
33+
sigName: "sig1",
34+
signer: func() *Signer {
35+
signer, _ := NewHMACSHA256Signer("key1", bytes.Repeat([]byte{1}, 64), NewSignConfig(), HeaderList([]string{"@method"}))
36+
return signer
37+
}(),
38+
verifier: nil,
39+
fetchVerifier: nil,
40+
Client: *http.DefaultClient,
41+
},
42+
args: args{
43+
url: "",
44+
},
45+
wantRes: "200 OK",
46+
wantErr: false,
47+
},
48+
{
49+
name: "not found",
50+
fields: fields{
51+
sigName: "sig1",
52+
signer: func() *Signer {
53+
signer, _ := NewHMACSHA256Signer("key1", bytes.Repeat([]byte{1}, 64), NewSignConfig(), HeaderList([]string{"@method"}))
54+
return signer
55+
}(),
56+
verifier: nil,
57+
fetchVerifier: nil,
58+
Client: *http.DefaultClient,
59+
},
60+
args: args{
61+
url: "/thisaintaurl",
62+
},
63+
wantRes: "404 Not Found",
64+
wantErr: false,
65+
},
66+
{
67+
name: "bad signature name",
68+
fields: fields{
69+
sigName: "",
70+
signer: func() *Signer {
71+
signer, _ := NewHMACSHA256Signer("key1", bytes.Repeat([]byte{1}, 64), NewSignConfig(), HeaderList([]string{"@method"}))
72+
return signer
73+
}(),
74+
verifier: nil,
75+
fetchVerifier: nil,
76+
Client: *http.DefaultClient,
77+
},
78+
args: args{
79+
url: "",
80+
},
81+
wantRes: "",
82+
wantErr: true,
83+
},
84+
}
85+
86+
simpleHandler := func(w http.ResponseWriter, r *http.Request) {
87+
if r.RequestURI == "/" {
88+
w.WriteHeader(200)
89+
} else {
90+
w.WriteHeader(404)
91+
}
92+
fmt.Fprintln(w, "Hey client, good to see ya")
93+
}
94+
ts := httptest.NewServer(http.HandlerFunc(simpleHandler))
95+
defer ts.Close()
96+
97+
for _, tt := range tests {
98+
t.Run(tt.name, func(t *testing.T) {
99+
c := &Client{
100+
sigName: tt.fields.sigName,
101+
signer: tt.fields.signer,
102+
verifier: tt.fields.verifier,
103+
fetchVerifier: tt.fields.fetchVerifier,
104+
Client: tt.fields.Client,
105+
}
106+
res, err := c.Get(ts.URL + tt.args.url)
107+
var gotRes string
108+
if res != nil {
109+
gotRes = res.Status
110+
}
111+
if (err != nil) != tt.wantErr {
112+
t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr)
113+
return
114+
}
115+
if !reflect.DeepEqual(gotRes, tt.wantRes) {
116+
t.Errorf("Get() gotRes = %v, want %v", gotRes, tt.wantRes)
117+
}
118+
})
119+
}
120+
}
121+
122+
func TestClient_Head(t *testing.T) {
123+
type fields struct {
124+
sigName string
125+
signer *Signer
126+
verifier *Verifier
127+
fetchVerifier func(res *http.Response, req *http.Request) (sigName string, verifier *Verifier)
128+
Client http.Client
129+
}
130+
type args struct {
131+
url string
132+
}
133+
tests := []struct {
134+
name string
135+
fields fields
136+
args args
137+
wantRes string
138+
wantErr bool
139+
}{
140+
{
141+
name: "TLS",
142+
fields: fields{
143+
sigName: "sig1",
144+
signer: func() *Signer {
145+
signer, _ := NewHMACSHA256Signer("key1", bytes.Repeat([]byte{1}, 64), NewSignConfig(), HeaderList([]string{"@method"}))
146+
return signer
147+
}(),
148+
verifier: nil,
149+
fetchVerifier: nil,
150+
Client: *http.DefaultClient,
151+
},
152+
args: args{
153+
url: "https://www.google.com/",
154+
},
155+
wantRes: "200 OK",
156+
wantErr: false,
157+
},
158+
}
159+
for _, tt := range tests {
160+
t.Run(tt.name, func(t *testing.T) {
161+
c := &Client{
162+
sigName: tt.fields.sigName,
163+
signer: tt.fields.signer,
164+
verifier: tt.fields.verifier,
165+
fetchVerifier: tt.fields.fetchVerifier,
166+
Client: tt.fields.Client,
167+
}
168+
res, err := c.Head(tt.args.url)
169+
var gotRes string
170+
if res != nil {
171+
gotRes = res.Status
172+
}
173+
if (err != nil) != tt.wantErr {
174+
t.Errorf("Head() error = %v, wantErr %v", err, tt.wantErr)
175+
return
176+
}
177+
if !reflect.DeepEqual(gotRes, tt.wantRes) {
178+
t.Errorf("Head() gotRes = %v, want %v", gotRes, tt.wantRes)
179+
}
180+
})
181+
}
182+
}

0 commit comments

Comments
 (0)