Skip to content

Commit fe1983d

Browse files
committed
More tests, improve verification
1 parent 7a30373 commit fe1983d

File tree

4 files changed

+247
-66
lines changed

4 files changed

+247
-66
lines changed

config.go

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,16 @@ func (c *SignConfig) setFakeCreated(ts int64) *SignConfig {
4646
return c
4747
}
4848

49-
// setExpires adds an "expires" parameter containing an expiration deadline, as Unix time.
49+
// SetExpires adds an "expires" parameter containing an expiration deadline, as Unix time.
5050
// Default: 0 (do not add the parameter)
51-
func (c *SignConfig) setExpires(expires int64) *SignConfig {
51+
func (c *SignConfig) SetExpires(expires int64) *SignConfig {
5252
c.expires = expires
5353
return c
5454
}
5555

56-
// setNonce adds a "nonce" string parameter whose content should be unique per signed message.
56+
// SetNonce adds a "nonce" string parameter whose content should be unique per signed message.
5757
// Default: empty string (do not add the parameter)
58-
func (c *SignConfig) setNonce(nonce string) *SignConfig {
58+
func (c *SignConfig) SetNonce(nonce string) *SignConfig {
5959
c.nonce = nonce
6060
return c
6161
}
@@ -73,16 +73,8 @@ type VerifyConfig struct {
7373
verifyCreated bool
7474
notNewerThan time.Duration
7575
notOlderThan time.Duration
76-
verifyAlg bool
7776
allowedAlgs []string
78-
}
79-
80-
// SetAllowedAlgs defines what are the allowed values of the "alg" parameter.
81-
// This is useful if the actual algorithm used in verification is taken from the message - not a recommended practice.
82-
// Default: all supported asymmetric algorithms.
83-
func (v *VerifyConfig) SetAllowedAlgs(allowedAlgs []string) *VerifyConfig {
84-
v.allowedAlgs = allowedAlgs
85-
return v
77+
rejectExpired bool
8678
}
8779

8880
// SetNotNewerThan sets the window for messages that appear to be newer than the current time,
@@ -106,21 +98,29 @@ func (v *VerifyConfig) SetVerifyCreated(verifyCreated bool) *VerifyConfig {
10698
return v
10799
}
108100

109-
// SetVerifyAlg indicates that the "alg" parameter exist. Use SetAllowedAlgs to specify allowed values.
110-
// Default: false.
111-
func (v *VerifyConfig) SetVerifyAlg(verifyAlg bool) *VerifyConfig {
112-
v.verifyAlg = verifyAlg
101+
// SetRejectExpired indicates that expired messages (according to the "expires" parameter) must fail verification.
102+
// Default: true.
103+
func (v *VerifyConfig) SetRejectExpired(rejectExpired bool) *VerifyConfig {
104+
v.rejectExpired = rejectExpired
105+
return v
106+
}
107+
108+
// SetAllowedAlgs defines the allowed values of the "alg" parameter.
109+
// This is useful if the actual algorithm used in verification is taken from the message - not a recommended practice.
110+
// Default: an empty list, signifying all values are accepted.
111+
func (v *VerifyConfig) SetAllowedAlgs(allowedAlgs []string) *VerifyConfig {
112+
v.allowedAlgs = allowedAlgs
113113
return v
114114
}
115115

116116
// NewVerifyConfig generates a default configuration.
117117
func NewVerifyConfig() *VerifyConfig {
118118
return &VerifyConfig{
119119
verifyCreated: true,
120-
notNewerThan: 1_000 * time.Millisecond,
121-
notOlderThan: 10_000 * time.Millisecond,
122-
verifyAlg: false,
123-
allowedAlgs: []string{"rsa-v1_5-sha256", "rsa-pss-sha512", "ecdsa-p256-sha256"},
120+
notNewerThan: 2 * time.Second,
121+
notOlderThan: 10 * time.Second,
122+
rejectExpired: true,
123+
allowedAlgs: []string{},
124124
}
125125
}
126126

handler_test.go

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ func Test_WrapHandler(t *testing.T) {
5454
}
5555

5656
verifier, err := NewHMACSHA256Verifier("key", bytes.Repeat([]byte{0}, 64), NewVerifyConfig(), *NewFields())
57+
if err != nil {
58+
log.Fatal(err)
59+
}
5760
_, err = VerifyResponse("sig1", *verifier, res)
5861
if err != nil {
5962
log.Fatal(err)
@@ -115,7 +118,7 @@ func ExampleWrapHandler_clientSigns() {
115118
}
116119

117120
func ExampleWrapHandler_serverSigns() {
118-
// Callback to let the server locate its verification key and configuration
121+
// Callback to let the server locate its signing key and configuration
119122
fetchSigner := func(res http.Response, r *http.Request) (string, *Signer) {
120123
sigName := "sig1"
121124
signer, _ := NewHMACSHA256Signer("key", bytes.Repeat([]byte{0}, 64), nil,
@@ -151,3 +154,115 @@ func ExampleWrapHandler_serverSigns() {
151154
fmt.Println("verified: ", verified)
152155
// output: verified: true
153156
}
157+
158+
// test various failures
159+
func TestWrapHandlerServerSigns(t *testing.T) {
160+
serverSignsTestCase := func(t *testing.T, nilSigner, dontSignResponse, earlyExpires, noSigner, badKey, badAlgs bool, wantBody, wantStatus string) {
161+
// Callback to let the server locate its signing key and configuration
162+
var signConfig *SignConfig
163+
if !earlyExpires {
164+
signConfig = nil
165+
} else {
166+
signConfig = NewSignConfig().SetExpires(2000)
167+
}
168+
fetchSigner := func(res http.Response, r *http.Request) (string, *Signer) {
169+
sigName := "sig1"
170+
signer, _ := NewHMACSHA256Signer("key", bytes.Repeat([]byte{0}, 64), signConfig,
171+
HeaderList([]string{"@status", "bar", "date"}))
172+
return sigName, signer
173+
}
174+
badFetchSigner := func(res http.Response, r *http.Request) (string, *Signer) {
175+
return "just a name", nil
176+
}
177+
178+
simpleHandler := func(w http.ResponseWriter, r *http.Request) { // this handler gets wrapped
179+
w.WriteHeader(200)
180+
w.Header().Set("bar", "baz")
181+
fmt.Fprintln(w, "Hello, client")
182+
}
183+
184+
// Configure the wrapper and set it up
185+
var config *HandlerConfig
186+
if !nilSigner {
187+
if !noSigner {
188+
config = NewHandlerConfig().SetVerifyRequest(false).SetFetchSigner(fetchSigner)
189+
} else {
190+
config = NewHandlerConfig().SetVerifyRequest(false).SetFetchSigner(badFetchSigner)
191+
}
192+
193+
} else {
194+
config = NewHandlerConfig().SetVerifyRequest(false).SetFetchSigner(nil)
195+
196+
}
197+
if dontSignResponse {
198+
config = config.SetSignResponse(false)
199+
}
200+
ts := httptest.NewServer(WrapHandler(http.HandlerFunc(simpleHandler), config))
201+
defer ts.Close()
202+
203+
// HTTP client code
204+
res, err := http.Get(ts.URL)
205+
if err != nil {
206+
log.Fatal(err)
207+
}
208+
body, err := io.ReadAll(res.Body)
209+
if err != nil {
210+
log.Fatal(err)
211+
}
212+
res.Body.Close()
213+
214+
if string(body) != wantBody {
215+
t.Errorf("Status: got %s want %s", string(body), wantBody)
216+
}
217+
if res.Status != wantStatus {
218+
t.Errorf("Status: got %s want %s", res.Status, wantStatus)
219+
}
220+
221+
var key []byte
222+
if !badKey {
223+
key = bytes.Repeat([]byte{0}, 64)
224+
} else {
225+
key = bytes.Repeat([]byte{3}, 64)
226+
}
227+
verifyConfig := NewVerifyConfig()
228+
if badAlgs {
229+
verifyConfig = verifyConfig.SetAllowedAlgs([]string{"zuzu"})
230+
}
231+
verifier, _ := NewHMACSHA256Verifier("key", key, verifyConfig, *NewFields())
232+
verified, _ := VerifyResponse("sig1", *verifier, res)
233+
234+
if verified {
235+
t.Errorf("surprise! Verification successful")
236+
}
237+
}
238+
nilSigner := func(t *testing.T) {
239+
serverSignsTestCase(t, true, false, false, false, false, false, "Failed to sign response: could not fetch a signer\n",
240+
"500 Internal Server Error")
241+
}
242+
dontSignResponse := func(t *testing.T) {
243+
serverSignsTestCase(t, false, true, false, false, false, false, "Hello, client\n",
244+
"200 OK")
245+
}
246+
earlyExpires := func(t *testing.T) {
247+
serverSignsTestCase(t, false, false, true, false, false, false, "Hello, client\n",
248+
"200 OK")
249+
}
250+
noSigner := func(t *testing.T) {
251+
serverSignsTestCase(t, false, false, false, true, false, false, "Failed to sign response: could not fetch a signer, check key ID\n",
252+
"500 Internal Server Error")
253+
}
254+
badKey := func(t *testing.T) {
255+
serverSignsTestCase(t, false, false, false, false, true, false, "Hello, client\n",
256+
"200 OK")
257+
}
258+
badAlgs := func(t *testing.T) {
259+
serverSignsTestCase(t, false, false, false, false, false, true, "Hello, client\n",
260+
"200 OK")
261+
}
262+
t.Run("nil signer", nilSigner)
263+
t.Run("don't sign response", dontSignResponse)
264+
t.Run("early expires field", earlyExpires)
265+
t.Run("bad fetch signer", noSigner)
266+
t.Run("wrong verification key", badKey)
267+
t.Run("failed algorithm check", badAlgs)
268+
}

signatures.go

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -157,55 +157,62 @@ func VerifyRequest(signatureName string, verifier Verifier, req *http.Request) (
157157
return verifyMessage(*verifier.c, signatureName, verifier, *parsedMessage, verifier.f)
158158
}
159159

160-
// RequestKeyID parses a signed request and returns the key ID used in the given signature.
161-
func RequestKeyID(signatureName string, req *http.Request) (string, error) {
160+
// RequestDetails parses a signed request and returns the key ID and optionally the algorithm used in the given signature.
161+
func RequestDetails(signatureName string, req *http.Request) (keyID, alg string, err error) {
162162
if req == nil {
163-
return "", fmt.Errorf("nil request")
163+
return "", "", fmt.Errorf("nil request")
164164
}
165165
if signatureName == "" {
166-
return "", fmt.Errorf("empty signature name")
166+
return "", "", fmt.Errorf("empty signature name")
167167
}
168168
parsedMessage, err := parseRequest(req)
169169
if err != nil {
170-
return "", err
170+
return "", "", err
171171
}
172172
return messageKeyID(signatureName, *parsedMessage)
173173
}
174174

175-
// ResponseKeyID parses a signed response and returns the key ID used in the given signature.
176-
func ResponseKeyID(signatureName string, res *http.Response) (string, error) {
175+
// ResponseDetails parses a signed response and returns the key ID and optionally the algorithm used in the given signature.
176+
func ResponseDetails(signatureName string, res *http.Response) (keyID, alg string, err error) {
177177
if res == nil {
178-
return "", fmt.Errorf("nil response")
178+
return "", "", fmt.Errorf("nil response")
179179
}
180180
if signatureName == "" {
181-
return "", fmt.Errorf("empty signature name")
181+
return "", "", fmt.Errorf("empty signature name")
182182
}
183183
parsedMessage, err := parseResponse(res)
184184
if err != nil {
185-
return "", err
185+
return "", "", err
186186
}
187187
return messageKeyID(signatureName, *parsedMessage)
188188
}
189189

190-
func messageKeyID(signatureName string, parsedMessage parsedMessage) (string, error) {
190+
func messageKeyID(signatureName string, parsedMessage parsedMessage) (keyID, alg string, err error) {
191191
si, found := parsedMessage.components[*fromHeaderName("signature-input")]
192192
if !found {
193-
return "", fmt.Errorf("missing \"signature-input\" header")
193+
return "", "", fmt.Errorf("missing \"signature-input\" header")
194194
}
195195
signatureInput := si[0]
196196
psi, err := parseSignatureInput(signatureInput, signatureName)
197197
if err != nil {
198-
return "", err
198+
return
199199
}
200200
keyIDParam, ok := psi.params["keyid"]
201201
if !ok {
202-
return "", fmt.Errorf("missing \"keyid\" parameter")
202+
return "", "", fmt.Errorf("missing \"keyid\" parameter")
203203
}
204-
keyID, ok := keyIDParam.(string)
204+
keyID, ok = keyIDParam.(string)
205205
if !ok {
206-
return "", fmt.Errorf("malformed \"keyid\" parameter")
206+
return "", "", fmt.Errorf("malformed \"keyid\" parameter")
207207
}
208-
return keyID, nil
208+
algParam, ok := psi.params["alg"] // "alg" is optional
209+
if ok {
210+
alg, ok = algParam.(string)
211+
if !ok {
212+
return "", "", fmt.Errorf("malformed \"alg\" parameter")
213+
}
214+
}
215+
return keyID, alg, nil
209216
}
210217

211218
//
@@ -284,7 +291,7 @@ func applyVerificationPolicy(psi *psiSignature, config VerifyConfig) error {
284291
return fmt.Errorf("message is too old, check for replay")
285292
}
286293
}
287-
if config.verifyAlg && len(config.allowedAlgs) > 0 {
294+
if len(config.allowedAlgs) > 0 {
288295
algParam, ok := psi.params["alg"]
289296
if !ok {
290297
return fmt.Errorf("missing \"alg\" parameter")
@@ -303,6 +310,20 @@ func applyVerificationPolicy(psi *psiSignature, config VerifyConfig) error {
303310
return fmt.Errorf("\"alg\" parameter not allowed by policy")
304311
}
305312
}
313+
if config.rejectExpired {
314+
now := time.Now()
315+
expiresParam, ok := psi.params["expires"]
316+
if ok {
317+
expires, ok := expiresParam.(int64)
318+
if !ok {
319+
return fmt.Errorf("malformed \"expires\" parameter")
320+
}
321+
expiresTime := time.Unix(expires, 0)
322+
if now.After(expiresTime) {
323+
return fmt.Errorf("expired signature")
324+
}
325+
}
326+
}
306327
return nil
307328
}
308329

0 commit comments

Comments
 (0)