diff --git a/go.mod b/go.mod index 453f8d0..71764f7 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.20 require ( github.com/antlr4-go/antlr/v4 v4.13.0 - golang.org/x/sync v0.3.0 github.com/coder/websocket v1.8.12 + golang.org/x/sync v0.3.0 ) require golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect diff --git a/go.sum b/go.sum index f100ff3..6ff70a9 100644 --- a/go.sum +++ b/go.sum @@ -6,5 +6,3 @@ golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88p golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= -nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/libsql/internal/http/driver.go b/libsql/internal/http/driver.go index d8836e7..059fb3a 100644 --- a/libsql/internal/http/driver.go +++ b/libsql/internal/http/driver.go @@ -6,6 +6,6 @@ import ( "github.com/tursodatabase/libsql-client-go/libsql/internal/http/hranaV2" ) -func Connect(url, jwt, host string, schemaDb bool) driver.Conn { - return hranaV2.Connect(url, jwt, host, schemaDb) +func Connect(url, jwt, host string, schemaDb bool, remoteEncryptionKey string) driver.Conn { + return hranaV2.Connect(url, jwt, host, schemaDb, remoteEncryptionKey) } diff --git a/libsql/internal/http/hranaV2/hranaV2.go b/libsql/internal/http/hranaV2/hranaV2.go index de84102..5bcd052 100644 --- a/libsql/internal/http/hranaV2/hranaV2.go +++ b/libsql/internal/http/hranaV2/hranaV2.go @@ -8,13 +8,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/tursodatabase/libsql-client-go/sqliteparserutils" "io" "net/http" net_url "net/url" "runtime/debug" "strings" + "github.com/tursodatabase/libsql-client-go/sqliteparserutils" + "github.com/tursodatabase/libsql-client-go/libsql/internal/hrana" "github.com/tursodatabase/libsql-client-go/libsql/internal/http/shared" ) @@ -36,8 +37,8 @@ func init() { commitHash = "unknown" } -func Connect(url, jwt, host string, schemaDb bool) driver.Conn { - return &hranaV2Conn{url, jwt, host, schemaDb, "", false, 0} +func Connect(url, jwt, host string, schemaDb bool, encryptionKey string) driver.Conn { + return &hranaV2Conn{url, jwt, host, schemaDb, encryptionKey, "", false, 0} } type hranaV2Stmt struct { @@ -82,13 +83,14 @@ func (s *hranaV2Stmt) QueryContext(ctx context.Context, args []driver.NamedValue } type hranaV2Conn struct { - url string - jwt string - host string - schemaDb bool - baton string - streamClosed bool - replicationIndex uint64 + url string + jwt string + host string + schemaDb bool + remoteEncryptionKey string + baton string + streamClosed bool + replicationIndex uint64 } func (h *hranaV2Conn) Ping() error { @@ -121,11 +123,11 @@ func (h *hranaV2Conn) PrepareContext(ctx context.Context, query string) (driver. func (h *hranaV2Conn) Close() error { if h.baton != "" { - go func(baton, url, jwt, host string) { + go func(baton, url, jwt, host, encryptionKey string) { msg := hrana.PipelineRequest{Baton: baton} msg.Add(hrana.CloseStream()) - _, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host) - }(h.baton, h.url, h.jwt, h.host) + _, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey) + }(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey) } return nil } @@ -173,7 +175,7 @@ func (h *hranaV2Conn) sendPipelineRequest(ctx context.Context, msg *hrana.Pipeli if h.replicationIndex > 0 { addReplicationIndex(msg, h.replicationIndex) } - result, streamClosed, err := sendPipelineRequest(ctx, msg, h.url, h.jwt, h.host) + result, streamClosed, err := sendPipelineRequest(ctx, msg, h.url, h.jwt, h.host, h.remoteEncryptionKey) if streamClosed { h.streamClosed = true } @@ -230,7 +232,7 @@ func getReplicationIndex(response *hrana.PipelineResponse) uint64 { return replicationIndex } -func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url string, jwt string, host string) (result hrana.PipelineResponse, streamClosed bool, err error) { +func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url string, jwt string, host string, remoteEncryptionKey string) (result hrana.PipelineResponse, streamClosed bool, err error) { reqBody, err := json.Marshal(msg) if err != nil { return hrana.PipelineResponse{}, false, err @@ -247,6 +249,10 @@ func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url st req.Header.Set("Authorization", "Bearer "+jwt) } req.Header.Set("x-libsql-client-version", "libsql-remote-go-"+commitHash) + if remoteEncryptionKey != "" { + req.Header.Set("x-turso-encryption-key", remoteEncryptionKey) + } + req.Host = host resp, err := http.DefaultClient.Do(req) if err != nil { @@ -591,11 +597,11 @@ func (h *hranaV2Conn) QueryContext(ctx context.Context, query string, args []dri func (h *hranaV2Conn) closeStream() { if h.baton != "" { - go func(baton, url, jwt, host string) { + go func(baton, url, jwt, host, encryptionKey string) { msg := hrana.PipelineRequest{Baton: baton} msg.Add(hrana.CloseStream()) - _, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host) - }(h.baton, h.url, h.jwt, h.host) + _, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey) + }(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey) h.baton = "" } } diff --git a/libsql/sql.go b/libsql/sql.go index 4a18e8a..ec12f6a 100644 --- a/libsql/sql.go +++ b/libsql/sql.go @@ -14,10 +14,11 @@ import ( ) type config struct { - authToken *string - tls *bool - proxy *string - schemaDb *bool + authToken *string + tls *bool + proxy *string + schemaDb *bool + remoteEncryptionKey *string } type Option interface { @@ -76,6 +77,19 @@ func WithSchemaDb(schemaDb bool) Option { }) } +func WithRemoteEncryptionKey(key string) Option { + return option(func(o *config) error { + if o.remoteEncryptionKey != nil { + return fmt.Errorf("remoteEncryptionKey already set") + } + if key == "" { + return fmt.Errorf("remoteEncryptionKey must not be empty") + } + o.remoteEncryptionKey = &key + return nil + }) +} + func (c config) connector(dbPath string) (driver.Connector, error) { u, err := url.Parse(dbPath) if err != nil { @@ -139,6 +153,10 @@ func (c config) connector(dbPath string) (driver.Connector, error) { if c.authToken != nil { authToken = *c.authToken } + encryptionKey := "" + if c.remoteEncryptionKey != nil { + encryptionKey = *c.remoteEncryptionKey + } host := u.Host if c.proxy != nil { @@ -164,7 +182,7 @@ func (c config) connector(dbPath string) (driver.Connector, error) { return wsConnector{url: u.String(), authToken: authToken}, nil } if u.Scheme == "https" || u.Scheme == "http" { - return httpConnector{url: u.String(), authToken: authToken, host: host, schemaDb: schemaDb}, nil + return httpConnector{url: u.String(), authToken: authToken, host: host, schemaDb: schemaDb, remoteEncryptionKey: encryptionKey}, nil } return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https://, http://, wss:// and ws://", u.Scheme) @@ -185,14 +203,15 @@ func NewConnector(dbPath string, opts ...Option) (driver.Connector, error) { } type httpConnector struct { - url string - authToken string - host string - schemaDb bool + url string + authToken string + host string + schemaDb bool + remoteEncryptionKey string } func (c httpConnector) Connect(_ctx context.Context) (driver.Conn, error) { - return http.Connect(c.url, c.authToken, c.host, c.schemaDb), nil + return http.Connect(c.url, c.authToken, c.host, c.schemaDb, c.remoteEncryptionKey), nil } func (c httpConnector) Driver() driver.Driver { @@ -341,7 +360,7 @@ func (d Driver) Open(dbUrl string) (driver.Conn, error) { return ws.Connect(u.String(), jwt) } if u.Scheme == "https" || u.Scheme == "http" { - return http.Connect(u.String(), jwt, u.Host, false), nil + return http.Connect(u.String(), jwt, u.Host, false, ""), nil } return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https://, http://, wss:// and ws://", u.Scheme)