@@ -11,6 +11,8 @@ import (
1111 "github.com/stackql/stackql/internal/stackql/acid/tsm_physio"
1212 "github.com/stackql/stackql/internal/stackql/handler"
1313 "github.com/stackql/stackql/internal/stackql/internal_data_transfer/internaldto"
14+ "github.com/stackql/stackql/internal/stackql/paramdecoder"
15+ "github.com/stackql/stackql/internal/stackql/queryshape"
1416 "github.com/stackql/stackql/internal/stackql/responsehandler"
1517 "github.com/stackql/stackql/internal/stackql/util"
1618 "github.com/stackql/stackql/pkg/txncounter"
@@ -19,9 +21,10 @@ import (
1921)
2022
2123var (
22- _ StackQLDriver = & basicStackQLDriver {}
23- _ sqlbackend.SQLBackendFactory = & basicStackQLDriverFactory {}
24- _ StackQLDriverFactory = & basicStackQLDriverFactory {}
24+ _ StackQLDriver = & basicStackQLDriver {}
25+ _ sqlbackend.IExtendedQueryBackend = & basicStackQLDriver {}
26+ _ sqlbackend.SQLBackendFactory = & basicStackQLDriverFactory {}
27+ _ StackQLDriverFactory = & basicStackQLDriverFactory {}
2528)
2629
2730type StackQLDriverFactory interface {
@@ -66,6 +69,10 @@ func (sdf *basicStackQLDriverFactory) newSQLDriver() (StackQLDriver, error) {
6669 debugBuf : buf ,
6770 handlerCtx : clonedCtx ,
6871 txnOrchestrator : txnOrchestrator ,
72+ shapeInferrer : queryshape .NewInferrer (clonedCtx ),
73+ paramDecoder : paramdecoder .NewDecoder (),
74+ stmtCache : make (map [string ]* stmtMeta ),
75+ portalCache : make (map [string ]* portalMeta ),
6976 }
7077 return rv , nil
7178}
@@ -125,10 +132,24 @@ func (dr *basicStackQLDriver) ProcessQuery(query string) {
125132 }
126133}
127134
135+ type stmtMeta struct {
136+ query string
137+ paramOIDs []uint32
138+ columns []sqldata.ISQLColumn
139+ }
140+
141+ type portalMeta struct {
142+ stmtName string
143+ }
144+
128145type basicStackQLDriver struct {
129146 debugBuf * bytes.Buffer
130147 handlerCtx handler.HandlerContext
131148 txnOrchestrator tsm_physio.Orchestrator
149+ shapeInferrer queryshape.Inferrer
150+ paramDecoder paramdecoder.Decoder
151+ stmtCache map [string ]* stmtMeta
152+ portalCache map [string ]* portalMeta
132153}
133154
134155func (dr * basicStackQLDriver ) GetDebugStr () string {
@@ -191,9 +212,88 @@ func NewStackQLDriver(handlerCtx handler.HandlerContext) (StackQLDriver, error)
191212 return & basicStackQLDriver {
192213 handlerCtx : handlerCtx ,
193214 txnOrchestrator : txnOrchestrator ,
215+ shapeInferrer : queryshape .NewInferrer (handlerCtx ),
216+ paramDecoder : paramdecoder .NewDecoder (),
217+ stmtCache : make (map [string ]* stmtMeta ),
218+ portalCache : make (map [string ]* portalMeta ),
194219 }, nil
195220}
196221
222+ func (dr * basicStackQLDriver ) HandleParse (
223+ ctx context.Context , stmtName string , query string , paramOIDs []uint32 ,
224+ ) ([]uint32 , error ) {
225+ // Infer result columns at parse time and cache for Describe/Execute.
226+ columns := dr .shapeInferrer .InferResultColumns (query )
227+ dr .stmtCache [stmtName ] = & stmtMeta {
228+ query : query ,
229+ paramOIDs : paramOIDs ,
230+ columns : columns ,
231+ }
232+ return paramOIDs , nil
233+ }
234+
235+ func (dr * basicStackQLDriver ) HandleBind (
236+ ctx context.Context , portalName string , stmtName string ,
237+ paramFormats []int16 , paramValues [][]byte , resultFormats []int16 ,
238+ ) error {
239+ dr .portalCache [portalName ] = & portalMeta {
240+ stmtName : stmtName ,
241+ }
242+ return nil
243+ }
244+
245+ func (dr * basicStackQLDriver ) HandleDescribeStatement (
246+ ctx context.Context , stmtName string , query string , paramOIDs []uint32 ,
247+ ) ([]uint32 , []sqldata.ISQLColumn , error ) {
248+ if cached , ok := dr .stmtCache [stmtName ]; ok {
249+ return cached .paramOIDs , cached .columns , nil
250+ }
251+ // Fallback: infer on the fly (shouldn't happen if Parse was called first).
252+ columns := dr .shapeInferrer .InferResultColumns (query )
253+ return paramOIDs , columns , nil
254+ }
255+
256+ func (dr * basicStackQLDriver ) HandleDescribePortal (
257+ ctx context.Context , portalName string , stmtName string , query string , paramOIDs []uint32 ,
258+ ) ([]sqldata.ISQLColumn , error ) {
259+ if portal , portalFound := dr .portalCache [portalName ]; portalFound {
260+ if cached , stmtFound := dr .stmtCache [portal .stmtName ]; stmtFound {
261+ return cached .columns , nil
262+ }
263+ }
264+ return dr .shapeInferrer .InferResultColumns (query ), nil
265+ }
266+
267+ func (dr * basicStackQLDriver ) HandleExecute (
268+ ctx context.Context , portalName string , stmtName string , query string ,
269+ paramFormats []int16 , paramValues [][]byte , resultFormats []int16 , maxRows int32 ,
270+ ) (sqldata.ISQLResultStream , error ) {
271+ // Look up cached param OIDs for format-aware decoding.
272+ var paramOIDs []uint32
273+ if portal , portalFound := dr .portalCache [portalName ]; portalFound {
274+ if cached , stmtFound := dr .stmtCache [portal .stmtName ]; stmtFound {
275+ paramOIDs = cached .paramOIDs
276+ }
277+ }
278+ // Decode params (handles both text and binary formats).
279+ decodedStrings , err := dr .paramDecoder .DecodeParams (paramOIDs , paramFormats , paramValues )
280+ if err != nil {
281+ return nil , fmt .Errorf ("parameter decoding error: %w" , err )
282+ }
283+ resolved := queryshape .SubstituteDecodedParams (query , decodedStrings )
284+ return dr .HandleSimpleQuery (ctx , resolved )
285+ }
286+
287+ func (dr * basicStackQLDriver ) HandleCloseStatement (ctx context.Context , stmtName string ) error {
288+ delete (dr .stmtCache , stmtName )
289+ return nil
290+ }
291+
292+ func (dr * basicStackQLDriver ) HandleClosePortal (ctx context.Context , portalName string ) error {
293+ delete (dr .portalCache , portalName )
294+ return nil
295+ }
296+
197297func (dr * basicStackQLDriver ) processQueryOrQueries (
198298 handlerCtx handler.HandlerContext ,
199299) ([]internaldto.ExecutorOutput , bool ) {
0 commit comments