-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathhandshake.go
More file actions
205 lines (168 loc) · 6.4 KB
/
handshake.go
File metadata and controls
205 lines (168 loc) · 6.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
package wire
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/stackql/psql-wire/internal/buffer"
"github.com/stackql/psql-wire/internal/types"
"go.uber.org/zap"
)
// Handshake performs the connection handshake on conn. It reads the start-up
// protocol version and, unless the client sent a cancel request, gives the
// connection the opportunity to be upgraded to TLS.
//
// It returns the (possibly upgraded) connection, the protocol version that
// was read last, and the buffered reader to use for the messages sent by the
// client afterwards. On failure the values gathered so far are returned
// together with the error.
func (srv *Server) Handshake(conn net.Conn) (_ net.Conn, version types.Version, reader buffer.Reader, err error) {
	reader = buffer.NewReader(conn, srv.BufferedMsgSize)

	if version, err = srv.readVersion(reader); err != nil {
		return conn, version, reader, err
	}

	// Cancel requests are handed back as-is; no upgrade is attempted for them.
	if version == types.VersionCancel {
		return conn, version, reader, nil
	}

	// TODO(Jeroen): support GSS encryption
	conn, reader, version, err = srv.potentialConnUpgrade(conn, reader, version)
	return conn, version, reader, err
}
// readVersion reads one untyped message from reader and decodes its leading
// uint32 as the start-up protocol version. The zero version is returned
// alongside any read or decode error.
func (srv *Server) readVersion(reader buffer.Reader) (_ types.Version, err error) {
	if _, err = reader.ReadUntypedMsg(); err != nil {
		return 0, err
	}

	raw, err := reader.GetUint32()
	if err != nil {
		return 0, err
	}

	return types.Version(raw), nil
}
// readyForQuery tells the client that the server is ready to receive queries.
// It writes a types.ServerReady message whose single-byte payload is the given
// server status, and returns the error reported when the message is ended.
func readyForQuery(writer buffer.Writer, status types.ServerStatus) error {
	payload := byte(status)

	writer.Start(types.ServerReady)
	writer.AddByte(payload)

	return writer.End()
}
// readParameters reads the key/value connection parameters sent by the client.
// Pairs are consumed from reader until an empty key is encountered. The
// collected parameters are attached to ctx and the resulting context is
// returned. If a key or value cannot be read, a nil context and the error are
// returned.
func (srv *Server) readParameters(ctx context.Context, reader buffer.Reader) (_ context.Context, err error) {
	srv.logger.Debug("reading client parameters")

	collected := make(Parameters)

	var name, setting string
	for {
		if name, err = reader.GetString(); err != nil {
			return nil, err
		}

		// an empty key marks the end of the connection parameters
		if name == "" {
			return setClientParameters(ctx, collected), nil
		}

		if setting, err = reader.GetString(); err != nil {
			return nil, err
		}

		srv.logger.Debug("client parameter", zap.String("key", name), zap.String("value", setting))
		collected[ParameterStatus(name)] = setting
	}
}
// writeParameters writes the server parameters, such as the client encoding,
// to the client as a series of ParameterStatus messages.
//
// The given params act as a base set; the encoding, conforming-strings,
// superuser, session-authorization and server-version entries are always
// overwritten with the values computed here. The params map itself is never
// modified: the values are merged into a private copy, so a map shared between
// connections is neither raced on nor polluted with per-session values.
//
// The written parameters are attached to ctx and the resulting context is
// returned. If a message cannot be written, the original ctx is returned
// together with the error.
// https://www.postgresql.org/docs/10/libpq-status.html
func (srv *Server) writeParameters(ctx context.Context, writer buffer.Writer, params Parameters) (_ context.Context, err error) {
	srv.logger.Debug("writing server parameters")

	// Number of entries that are always set below; only used as a size hint.
	const fixedEntries = 6

	// Copy instead of mutating the argument. Ranging over a nil map is a
	// no-op, so a nil params simply yields the fixed entries.
	announced := make(Parameters, len(params)+fixedEntries)
	for key, value := range params {
		announced[key] = value
	}

	announced[ParamServerEncoding] = "UTF8"
	announced[ParamClientEncoding] = "UTF8"
	announced[ParamStandardConformingStrings] = "on"
	announced[ParamIsSuperuser] = buffer.EncodeBoolean(IsSuperUser(ctx))
	announced[ParamSessionAuthorization] = AuthenticatedUsername(ctx)
	// Numeric version encoding: 15.1.2 => 15*10000 + 1*100 + 2 => 150102.
	// The value announced here corresponds to 15.0.0 => 150000.
	announced[ParamServerVersion] = fmt.Sprintf("%d", 15*10000)

	// Each parameter is sent as its own message: key and value, both
	// null-terminated.
	for key, value := range announced {
		srv.logger.Debug("server parameter", zap.String("key", string(key)), zap.String("value", value))

		writer.Start(types.ServerParameterStatus)
		writer.AddString(string(key))
		writer.AddNullTerminate()
		writer.AddString(value)
		writer.AddNullTerminate()

		if err = writer.End(); err != nil {
			return ctx, err
		}
	}

	return setServerParameters(ctx, announced), nil
}
// isMandatoryTLS reports whether the given client authentication policy makes
// TLS mandatory. Only tls.RequireAndVerifyClientCert is treated as mandatory.
func (srv *Server) isMandatoryTLS(clientAuth tls.ClientAuthType) bool {
	return clientAuth == tls.RequireAndVerifyClientCert
}
// potentialConnUpgrade upgrades the given connection to TLS when the client
// requests it, and rejects the connection when the server mandates TLS but it
// cannot be established.
//
// Outcomes:
//   - the client did not send an SSL request: the inputs are returned
//     unchanged, or an error if TLS is mandatory;
//   - the server has no certificates: the upgrade is declined through
//     sslUnsupported, or an error is returned if TLS is mandatory;
//   - otherwise the sslSupported response is written, the connection is
//     wrapped in a TLS server connection and the protocol version is read
//     again through a new buffered reader on the wrapped connection.
//
// The returned connection, reader and version replace the ones passed in.
func (srv *Server) potentialConnUpgrade(conn net.Conn, reader buffer.Reader, version types.Version) (_ net.Conn, _ buffer.Reader, _ types.Version, err error) {
	// server to enforce secure connections as appropriate
	tlsRequired := srv.isMandatoryTLS(srv.ClientAuth)

	if version != types.VersionSSLRequest {
		if !tlsRequired {
			return conn, reader, version, nil
		}

		srv.logger.Warn("client is requesting nil TLS, but the server mandates TLS")
		return conn, reader, version, fmt.Errorf("client is requesting nil TLS, but the server mandates TLS")
	}

	srv.logger.Debug("attempting to upgrade the client to a TLS connection")

	if len(srv.Certificates) == 0 {
		if !tlsRequired {
			srv.logger.Debug("no TLS certificates available continuing with a insecure connection")
			return srv.sslUnsupported(conn, reader, version)
		}

		srv.logger.Warn("server mandates TLS, but does not possess the requisite certificates")
		return conn, reader, version, fmt.Errorf("server mandates TLS, but does not possess the requisite certificates")
	}

	// Announce the upgrade on the plain connection before wrapping it.
	if _, err = conn.Write(sslSupported); err != nil {
		return conn, reader, version, err
	}

	// NOTE(Jeroen): initialize the TLS connection and construct a new buffered
	// reader for the constructed TLS connection.
	conn = tls.Server(conn, &tls.Config{
		Certificates: srv.Certificates,
		ClientAuth:   srv.ClientAuth,
		ClientCAs:    srv.ClientCAs,
	})
	reader = buffer.NewReader(conn, srv.BufferedMsgSize)

	// The version is read again, this time through the TLS connection.
	if version, err = srv.readVersion(reader); err != nil {
		return conn, reader, version, err
	}

	srv.logger.Debug("connection has been upgraded successfully")
	return conn, reader, version, nil
}
// sslUnsupported tells the PostgreSQL client that the connection cannot be
// upgraded to a secure connection at this time by writing the sslUnsupported
// response. The protocol version is then read again from the same reader.
// A cancel version at this point is reported as an error.
func (srv *Server) sslUnsupported(conn net.Conn, reader buffer.Reader, version types.Version) (_ net.Conn, _ buffer.Reader, _ types.Version, err error) {
	if _, err = conn.Write(sslUnsupported); err != nil {
		return conn, reader, version, err
	}

	if version, err = srv.readVersion(reader); err != nil {
		return conn, reader, version, err
	}

	if version != types.VersionCancel {
		return conn, reader, version, nil
	}

	return conn, reader, version, errors.New("unexpected cancel version after upgrading the client connection")
}