Skip to content

Commit fde7228

Browse files
pg-wire-extended-query-support (#640)
Summary: - Support for the `postgres` extended query pattern. - Safe integer casting for runtime vs RDBMS type conversions. - Added robot test `PG Extended Query Column Descriptions Available`. - Added robot test `PG Extended Query Prepared Statement Returns Rows`. - Added robot test `PG Extended Query Prepared Statement NULL Param Returns Zero Rows`.
1 parent b7da38c commit fde7228

File tree

18 files changed

+1027
-14
lines changed

18 files changed

+1027
-14
lines changed

.golangci.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ linters:
196196
- lll
197197
- revive
198198
path: mcp_client\/cmd\/.*\.go
199+
- linters:
200+
- revive
201+
- unparam
202+
path: internal\/stackql\/driver\/.*\.go
203+
- linters:
204+
- stylecheck
205+
path: internal\/stackql\/queryshape\/.*\.go
199206
- linters:
200207
- revive
201208
path: internal\/stackql\/acid\/tsm_physio\/.*\.go

internal/stackql/driver/driver.go

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"github.com/stackql/stackql/internal/stackql/acid/tsm_physio"
1212
"github.com/stackql/stackql/internal/stackql/handler"
1313
"github.com/stackql/stackql/internal/stackql/internal_data_transfer/internaldto"
14+
"github.com/stackql/stackql/internal/stackql/paramdecoder"
15+
"github.com/stackql/stackql/internal/stackql/queryshape"
1416
"github.com/stackql/stackql/internal/stackql/responsehandler"
1517
"github.com/stackql/stackql/internal/stackql/util"
1618
"github.com/stackql/stackql/pkg/txncounter"
@@ -19,9 +21,10 @@ import (
1921
)
2022

2123
var (
22-
_ StackQLDriver = &basicStackQLDriver{}
23-
_ sqlbackend.SQLBackendFactory = &basicStackQLDriverFactory{}
24-
_ StackQLDriverFactory = &basicStackQLDriverFactory{}
24+
_ StackQLDriver = &basicStackQLDriver{}
25+
_ sqlbackend.IExtendedQueryBackend = &basicStackQLDriver{}
26+
_ sqlbackend.SQLBackendFactory = &basicStackQLDriverFactory{}
27+
_ StackQLDriverFactory = &basicStackQLDriverFactory{}
2528
)
2629

2730
type StackQLDriverFactory interface {
@@ -66,6 +69,10 @@ func (sdf *basicStackQLDriverFactory) newSQLDriver() (StackQLDriver, error) {
6669
debugBuf: buf,
6770
handlerCtx: clonedCtx,
6871
txnOrchestrator: txnOrchestrator,
72+
shapeInferrer: queryshape.NewInferrer(clonedCtx),
73+
paramDecoder: paramdecoder.NewDecoder(),
74+
stmtCache: make(map[string]*stmtMeta),
75+
portalCache: make(map[string]*portalMeta),
6976
}
7077
return rv, nil
7178
}
@@ -125,10 +132,24 @@ func (dr *basicStackQLDriver) ProcessQuery(query string) {
125132
}
126133
}
127134

135+
type stmtMeta struct {
136+
query string
137+
paramOIDs []uint32
138+
columns []sqldata.ISQLColumn
139+
}
140+
141+
type portalMeta struct {
142+
stmtName string
143+
}
144+
128145
type basicStackQLDriver struct {
129146
debugBuf *bytes.Buffer
130147
handlerCtx handler.HandlerContext
131148
txnOrchestrator tsm_physio.Orchestrator
149+
shapeInferrer queryshape.Inferrer
150+
paramDecoder paramdecoder.Decoder
151+
stmtCache map[string]*stmtMeta
152+
portalCache map[string]*portalMeta
132153
}
133154

134155
func (dr *basicStackQLDriver) GetDebugStr() string {
@@ -191,9 +212,88 @@ func NewStackQLDriver(handlerCtx handler.HandlerContext) (StackQLDriver, error)
191212
return &basicStackQLDriver{
192213
handlerCtx: handlerCtx,
193214
txnOrchestrator: txnOrchestrator,
215+
shapeInferrer: queryshape.NewInferrer(handlerCtx),
216+
paramDecoder: paramdecoder.NewDecoder(),
217+
stmtCache: make(map[string]*stmtMeta),
218+
portalCache: make(map[string]*portalMeta),
194219
}, nil
195220
}
196221

222+
func (dr *basicStackQLDriver) HandleParse(
223+
ctx context.Context, stmtName string, query string, paramOIDs []uint32,
224+
) ([]uint32, error) {
225+
// Infer result columns at parse time and cache for Describe/Execute.
226+
columns := dr.shapeInferrer.InferResultColumns(query)
227+
dr.stmtCache[stmtName] = &stmtMeta{
228+
query: query,
229+
paramOIDs: paramOIDs,
230+
columns: columns,
231+
}
232+
return paramOIDs, nil
233+
}
234+
235+
func (dr *basicStackQLDriver) HandleBind(
236+
ctx context.Context, portalName string, stmtName string,
237+
paramFormats []int16, paramValues [][]byte, resultFormats []int16,
238+
) error {
239+
dr.portalCache[portalName] = &portalMeta{
240+
stmtName: stmtName,
241+
}
242+
return nil
243+
}
244+
245+
func (dr *basicStackQLDriver) HandleDescribeStatement(
246+
ctx context.Context, stmtName string, query string, paramOIDs []uint32,
247+
) ([]uint32, []sqldata.ISQLColumn, error) {
248+
if cached, ok := dr.stmtCache[stmtName]; ok {
249+
return cached.paramOIDs, cached.columns, nil
250+
}
251+
// Fallback: infer on the fly (shouldn't happen if Parse was called first).
252+
columns := dr.shapeInferrer.InferResultColumns(query)
253+
return paramOIDs, columns, nil
254+
}
255+
256+
func (dr *basicStackQLDriver) HandleDescribePortal(
257+
ctx context.Context, portalName string, stmtName string, query string, paramOIDs []uint32,
258+
) ([]sqldata.ISQLColumn, error) {
259+
if portal, portalFound := dr.portalCache[portalName]; portalFound {
260+
if cached, stmtFound := dr.stmtCache[portal.stmtName]; stmtFound {
261+
return cached.columns, nil
262+
}
263+
}
264+
return dr.shapeInferrer.InferResultColumns(query), nil
265+
}
266+
267+
func (dr *basicStackQLDriver) HandleExecute(
268+
ctx context.Context, portalName string, stmtName string, query string,
269+
paramFormats []int16, paramValues [][]byte, resultFormats []int16, maxRows int32,
270+
) (sqldata.ISQLResultStream, error) {
271+
// Look up cached param OIDs for format-aware decoding.
272+
var paramOIDs []uint32
273+
if portal, portalFound := dr.portalCache[portalName]; portalFound {
274+
if cached, stmtFound := dr.stmtCache[portal.stmtName]; stmtFound {
275+
paramOIDs = cached.paramOIDs
276+
}
277+
}
278+
// Decode params (handles both text and binary formats).
279+
decodedStrings, err := dr.paramDecoder.DecodeParams(paramOIDs, paramFormats, paramValues)
280+
if err != nil {
281+
return nil, fmt.Errorf("parameter decoding error: %w", err)
282+
}
283+
resolved := queryshape.SubstituteDecodedParams(query, decodedStrings)
284+
return dr.HandleSimpleQuery(ctx, resolved)
285+
}
286+
287+
func (dr *basicStackQLDriver) HandleCloseStatement(ctx context.Context, stmtName string) error {
288+
delete(dr.stmtCache, stmtName)
289+
return nil
290+
}
291+
292+
func (dr *basicStackQLDriver) HandleClosePortal(ctx context.Context, portalName string) error {
293+
delete(dr.portalCache, portalName)
294+
return nil
295+
}
296+
197297
func (dr *basicStackQLDriver) processQueryOrQueries(
198298
handlerCtx handler.HandlerContext,
199299
) ([]internaldto.ExecutorOutput, bool) {
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// Package paramdecoder decodes parameter values from their wire format
2+
// (text or binary) into string representations suitable for SQL substitution.
3+
package paramdecoder
4+
5+
import (
6+
"encoding/binary"
7+
"fmt"
8+
"math"
9+
"strconv"
10+
"time"
11+
12+
"github.com/lib/pq/oid"
13+
)
14+
15+
// Decoder decodes raw parameter bytes according to their format codes
16+
// and OIDs, returning string representations for each.
17+
type Decoder interface {
18+
DecodeParams(paramOIDs []uint32, paramFormats []int16, paramValues [][]byte) ([]string, error)
19+
}
20+
21+
// NewDecoder creates a new parameter decoder.
22+
func NewDecoder() Decoder {
23+
return &standardDecoder{}
24+
}
25+
26+
type standardDecoder struct{}
27+
28+
func (d *standardDecoder) DecodeParams(
29+
paramOIDs []uint32, paramFormats []int16, paramValues [][]byte,
30+
) ([]string, error) {
31+
result := make([]string, len(paramValues))
32+
for i, val := range paramValues {
33+
if val == nil {
34+
result[i] = "NULL"
35+
continue
36+
}
37+
format := resolveFormat(paramFormats, i)
38+
paramOID := oid.Oid(0)
39+
if i < len(paramOIDs) {
40+
paramOID = oid.Oid(paramOIDs[i])
41+
}
42+
decoded, err := decodeParam(paramOID, format, val)
43+
if err != nil {
44+
return nil, fmt.Errorf("parameter $%d: %w", i+1, err)
45+
}
46+
result[i] = decoded
47+
}
48+
return result, nil
49+
}
50+
51+
// resolveFormat returns the format code for parameter at index i.
52+
// Per postgres protocol: empty = all text, length 1 = applies to all,
53+
// otherwise per-parameter.
54+
func resolveFormat(formats []int16, i int) int16 {
55+
if len(formats) == 0 {
56+
return 0 // text
57+
}
58+
if len(formats) == 1 {
59+
return formats[0]
60+
}
61+
if i < len(formats) {
62+
return formats[i]
63+
}
64+
return 0 // text
65+
}
66+
67+
// decodeParam decodes a single parameter value.
68+
// Format 0 = text (bytes are UTF-8), format 1 = binary (OID-specific encoding).
69+
func decodeParam(paramOID oid.Oid, format int16, val []byte) (string, error) {
70+
if format == 0 {
71+
// Text format: raw bytes are the UTF-8 string representation.
72+
return string(val), nil
73+
}
74+
// Binary format: decode based on OID.
75+
return decodeBinary(paramOID, val)
76+
}
77+
78+
// Binary wire sizes for fixed-width postgres types.
79+
const (
80+
boolSize = 1
81+
int2Size = 2
82+
int4Size = 4
83+
int8Size = 8
84+
float4Size = 4
85+
float8Size = 8
86+
timestampSize = 8
87+
)
88+
89+
// decodeBinary decodes a binary-encoded parameter value to its string representation.
90+
//
91+
//nolint:cyclop,exhaustive // switch over OIDs is inherently branchy; only common types handled
92+
func decodeBinary(paramOID oid.Oid, val []byte) (string, error) {
93+
switch paramOID {
94+
case oid.T_bool:
95+
if len(val) != boolSize {
96+
return "", fmt.Errorf("bool: expected %d byte, got %d", boolSize, len(val))
97+
}
98+
if val[0] != 0 {
99+
return "true", nil
100+
}
101+
return "false", nil
102+
case oid.T_int2:
103+
if len(val) != int2Size {
104+
return "", fmt.Errorf("int2: expected %d bytes, got %d", int2Size, len(val))
105+
}
106+
v := int16(binary.BigEndian.Uint16(val)) //nolint:gosec // deliberate narrowing
107+
return strconv.FormatInt(int64(v), 10), nil
108+
case oid.T_int4:
109+
if len(val) != int4Size {
110+
return "", fmt.Errorf("int4: expected %d bytes, got %d", int4Size, len(val))
111+
}
112+
v := int32(binary.BigEndian.Uint32(val)) //nolint:gosec // deliberate narrowing
113+
return strconv.FormatInt(int64(v), 10), nil
114+
case oid.T_int8:
115+
if len(val) != int8Size {
116+
return "", fmt.Errorf("int8: expected %d bytes, got %d", int8Size, len(val))
117+
}
118+
v := int64(binary.BigEndian.Uint64(val)) //nolint:gosec // deliberate conversion
119+
return strconv.FormatInt(v, 10), nil
120+
case oid.T_float4:
121+
if len(val) != float4Size {
122+
return "", fmt.Errorf("float4: expected %d bytes, got %d", float4Size, len(val))
123+
}
124+
bits := binary.BigEndian.Uint32(val)
125+
return strconv.FormatFloat(float64(math.Float32frombits(bits)), 'f', -1, 32), nil
126+
case oid.T_float8:
127+
if len(val) != float8Size {
128+
return "", fmt.Errorf("float8: expected %d bytes, got %d", float8Size, len(val))
129+
}
130+
bits := binary.BigEndian.Uint64(val)
131+
return strconv.FormatFloat(math.Float64frombits(bits), 'f', -1, 64), nil
132+
case oid.T_timestamp, oid.T_timestamptz:
133+
if len(val) != timestampSize {
134+
return "", fmt.Errorf("timestamp: expected %d bytes, got %d", timestampSize, len(val))
135+
}
136+
// Postgres binary timestamp: microseconds since 2000-01-01 00:00:00 UTC.
137+
microseconds := int64(binary.BigEndian.Uint64(val)) //nolint:gosec // deliberate conversion
138+
pgEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
139+
ts := pgEpoch.Add(time.Duration(microseconds) * time.Microsecond)
140+
return ts.Format("2006-01-02 15:04:05.999999"), nil
141+
case oid.T_text, oid.T_varchar, oid.T_name:
142+
return string(val), nil
143+
default:
144+
// Unknown OID: treat as text (safe fallback).
145+
return string(val), nil
146+
}
147+
}

0 commit comments

Comments
 (0)