From 29cd974ccae02641cabb16d4d7ad070ab33b50f0 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Fri, 5 Jun 2026 17:51:03 -0400 Subject: [PATCH] tskagent: fix incorrect comment parsing from key files The code to pull comments out of key data made some heuristic (but incorrect) assumptions about the organization of the key. In particular, for keys where no comment is defined, it could incorrectly grab an earlier field (e.g., one of the key parameters) which would confuse a client expecting printable text. Correctly process the fields we care about, so that when an empty comment is set we don't grab the wrong thing by mistake. Also, update the tests to cover RSA as well as ED25519 keys, and generate test keys randomly. I ran the new tests 50,000 times under "stress" and was not able to produce any new failures. Thanks to @andrew-d for reporting this and for providing a failing test to debug from. Co-authored-by: Andrew Dunham --- internal_test.go | 87 ++++++++++++++++++++++++------- tskagent.go | 132 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 174 insertions(+), 45 deletions(-) diff --git a/internal_test.go b/internal_test.go index 979b0f9..db3c75e 100644 --- a/internal_test.go +++ b/internal_test.go @@ -4,34 +4,85 @@ package tskagent import ( + "crypto" + "crypto/ed25519" + crand "crypto/rand" + "crypto/rsa" + "encoding/pem" "testing" + "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" - - _ "embed" ) var _ agent.Agent = &Server{} -// The test data key is a throwaway generated for testing, and is not used -// anywhere else. To generate a new test key, run: -// -// ssh-keygen -C "Dummy key for testing" -t ed25519 -f testdata/test.key - -//go:embed testdata/test.key -var testPrivKey []byte - func TestKeyParse(t *testing.T) { - key, err := parseStoredKey("foo", 1, testPrivKey) - if err != nil { - t.Fatalf("Parsing stored key: %v", err) + tests := []struct { + name string + input []byte + comment string + keyType string + }{ + { + name: "ED2559/Comment", + input: mustGenerateKey(t, genED25519, "elliptic justice"), + comment: "elliptic justice", + keyType: "ssh-ed25519", + }, + { + name: "ED2559/NoComment", + input: mustGenerateKey(t, genED25519, ""), + comment: "", + keyType: "ssh-ed25519", + }, + { + name: "RSA/Comment", + input: mustGenerateKey(t, genRSA, "what year is it"), + comment: "what year is it", + keyType: "ssh-rsa", + }, + { + name: "RSA/NoComment", + input: mustGenerateKey(t, genRSA, ""), + comment: "", + keyType: "ssh-rsa", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + key, err := parseStoredKey(tc.name, 1, tc.input) + if err != nil { + t.Fatalf("parsing stored key: %v", err) + } + if key.Comment != tc.comment { + t.Errorf("Comment: got %q, want %q", key.Comment, tc.comment) + } + if got := key.Signer.PublicKey().Type(); got != tc.keyType { + t.Errorf("Key type: got %q, want %q", got, tc.keyType) + } + }) } +} - const wantComment = "Dummy key for testing" - if key.Comment != wantComment { - t.Errorf("Comment: got %q, want %q", key.Comment, wantComment) +func mustGenerateKey(t *testing.T, gen func() (crypto.PrivateKey, error), comment string) []byte { + t.Helper() + key, err := gen() + if err != nil { + t.Fatalf("Generating key: %v", err) } - if got, want := key.Signer.PublicKey().Type(), "ssh-ed25519"; got != want { - t.Errorf("Key type: got %q, want %q", got, want) + enc, err := ssh.MarshalPrivateKey(key, comment) + if err != nil { + t.Fatalf("Marshaling key: %v", err) } + return pem.EncodeToMemory(enc) +} + +func genED25519() (crypto.PrivateKey, error) { + _, key, err := ed25519.GenerateKey(crand.Reader) + return key, err +} + +func genRSA() (crypto.PrivateKey, error) { + return rsa.GenerateKey(crand.Reader, 1024) } diff --git a/tskagent.go b/tskagent.go index 2d3e1dc..f44064c 100644 --- a/tskagent.go +++ b/tskagent.go @@ -303,43 +303,121 @@ func parseStoredKey(name string, version api.SecretVersion, data []byte) (*sshKe func parseComment(key []byte) string { blk, _ := pem.Decode(key) - // The OpenSSH key format begins with a header followed by a public and a - // private key. Cut off the headers and skip the public key to find the - // private key, where the comment resides. The header is separated from the - // keys by a hard-coded uint32 key count of 1 (big-endian). - _, keys, ok := bytes.Cut(blk.Bytes, []byte("\x00\x00\x00\x01")) - if !ok { + // See: https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key + s := newScanner(blk.Bytes) + + // Check magic format header. + if err := s.scanLiteral("openssh-key-v1\x00"); err != nil { + return "" // not a key file, or some antique version + } + cipher, err := s.scanString() + if err != nil || string(cipher) != "none" { + return "" // encrypted contents, we can't read them + } + // Skip kdfname, kdfoptions, which we don't care about. + if err := s.skipStrings(2); err != nil { + return "" + } + // The next field is the number of keys. This could in theory be any value, + // but OpenSSH hardcodes it to 1. + if nk, err := s.scanUint32(); err != nil || nk != 1 { + return "" + } + // Skip the public keys, as the comment (if any) is with the private key. + if err := s.skipStrings(1); err != nil { return "" } - // Skip the public key... - pubLen := int(binary.BigEndian.Uint32(keys)) - if 4+pubLen > len(keys) { + // The rest of the packet should be a bundle of private keys. + // Because we know cipher is "none", it is plaintext, but there may + // be some padding at the end. + pkeys, err := s.scanString() + if err != nil { return "" } - keys = keys[4+pubLen:] - // Extract the private key... - privLen := int(binary.BigEndian.Uint32(keys)) - if 4+privLen > len(keys) { + pk := newScanner(pkeys) + // Skip the two 32-bit validity nonces. + if err := pk.skipBytes(8); err != nil { return "" } + // The rest of the bundle depends on what type of key this is, but + // the last string field will be the comment (if any). + var last string + for !pk.atEOF() { + s, err := pk.scanString() + if err != nil { + break + } + last = string(s) + } + return last +} + +// A scanner is a minimal scanner for a slice of bytes representing an OpenSSH +// key file. The methods of this type alias (but do not modify) the input. +type scanner struct { + buf []byte +} + +func newScanner(data []byte) *scanner { + return &scanner{buf: data} +} - // Remove padding at end (pad bytes are 0x01-0x07) - for n := len(keys) - 1; keys[n] < 0x08; n-- { - keys = keys[:n] +// atEOF reports whether s has any further contents. +func (s *scanner) atEOF() bool { return len(s.buf) == 0 } + +// skipBytes advances past the first n bytes of the input. +func (s *scanner) skipBytes(n int) error { + if len(s.buf) < n { + return fmt.Errorf("got %d bytes, want %d", len(s.buf), n) } - keys = keys[4:] // remove length prefix (checked above) - keys = keys[8:] // remove checksum (not used) - - // The comment is the last length-prefixed field of the private key. - // Skip past all the others. - for len(keys) >= 4 { - n := int(binary.BigEndian.Uint32(keys)) - if 4+n == len(keys) { - return string(keys[4:]) + s.buf = s.buf[n:] + return nil +} + +// skipStrings advances past the next n length-prefixed strings. +func (s *scanner) skipStrings(n int) error { + for n > 0 { + if _, err := s.scanString(); err != nil { + return err } - keys = keys[4+n:] + n-- + } + return nil +} + +// scanLiteral advances past the specified prefix of the input. +func (s *scanner) scanLiteral(want string) error { + rest, ok := bytes.CutPrefix(s.buf, []byte(want)) + if !ok { + return fmt.Errorf("missing %q", want) } - return "" + s.buf = rest + return nil +} + +// scanString consumes and returns a length-prefixed string. +func (s *scanner) scanString() ([]byte, error) { + n32, err := s.scanUint32() + if err != nil { + return nil, err + } + n := int(n32) + if n > len(s.buf) { + return nil, fmt.Errorf("got %d bytes, want %d", len(s.buf), n) + } + out := s.buf[:n] + s.buf = s.buf[n:] + return out, nil +} + +// scanUint32 consumes and returns a big-endian 32-bit integer. +func (s *scanner) scanUint32() (uint32, error) { + if len(s.buf) < 4 { + return 0, fmt.Errorf("got %d bytes, want 4", len(s.buf)) + } + out := binary.BigEndian.Uint32(s.buf) + s.buf = s.buf[4:] + return out, nil }