Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ go 1.20

require (
github.com/antlr4-go/antlr/v4 v4.13.0
golang.org/x/sync v0.3.0
github.com/coder/websocket v1.8.12
golang.org/x/sync v0.3.0
)

require golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,3 @@ golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88p
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q=
nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
4 changes: 2 additions & 2 deletions libsql/internal/http/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ import (
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/hranaV2"
)

func Connect(url, jwt, host string, schemaDb bool) driver.Conn {
return hranaV2.Connect(url, jwt, host, schemaDb)
func Connect(url, jwt, host string, schemaDb bool, remoteEncryptionKey string) driver.Conn {
return hranaV2.Connect(url, jwt, host, schemaDb, remoteEncryptionKey)
}
42 changes: 24 additions & 18 deletions libsql/internal/http/hranaV2/hranaV2.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/tursodatabase/libsql-client-go/sqliteparserutils"
"io"
"net/http"
net_url "net/url"
"runtime/debug"
"strings"

"github.com/tursodatabase/libsql-client-go/sqliteparserutils"

"github.com/tursodatabase/libsql-client-go/libsql/internal/hrana"
"github.com/tursodatabase/libsql-client-go/libsql/internal/http/shared"
)
Expand All @@ -36,8 +37,8 @@ func init() {
commitHash = "unknown"
}

func Connect(url, jwt, host string, schemaDb bool) driver.Conn {
return &hranaV2Conn{url, jwt, host, schemaDb, "", false, 0}
func Connect(url, jwt, host string, schemaDb bool, encryptionKey string) driver.Conn {
return &hranaV2Conn{url, jwt, host, schemaDb, encryptionKey, "", false, 0}
}

type hranaV2Stmt struct {
Expand Down Expand Up @@ -82,13 +83,14 @@ func (s *hranaV2Stmt) QueryContext(ctx context.Context, args []driver.NamedValue
}

type hranaV2Conn struct {
url string
jwt string
host string
schemaDb bool
baton string
streamClosed bool
replicationIndex uint64
url string
jwt string
host string
schemaDb bool
remoteEncryptionKey string
baton string
streamClosed bool
replicationIndex uint64
}

func (h *hranaV2Conn) Ping() error {
Expand Down Expand Up @@ -121,11 +123,11 @@ func (h *hranaV2Conn) PrepareContext(ctx context.Context, query string) (driver.

func (h *hranaV2Conn) Close() error {
if h.baton != "" {
go func(baton, url, jwt, host string) {
go func(baton, url, jwt, host, encryptionKey string) {
msg := hrana.PipelineRequest{Baton: baton}
msg.Add(hrana.CloseStream())
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host)
}(h.baton, h.url, h.jwt, h.host)
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey)
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey)
}
return nil
}
Expand Down Expand Up @@ -173,7 +175,7 @@ func (h *hranaV2Conn) sendPipelineRequest(ctx context.Context, msg *hrana.Pipeli
if h.replicationIndex > 0 {
addReplicationIndex(msg, h.replicationIndex)
}
result, streamClosed, err := sendPipelineRequest(ctx, msg, h.url, h.jwt, h.host)
result, streamClosed, err := sendPipelineRequest(ctx, msg, h.url, h.jwt, h.host, h.remoteEncryptionKey)
if streamClosed {
h.streamClosed = true
}
Expand Down Expand Up @@ -230,7 +232,7 @@ func getReplicationIndex(response *hrana.PipelineResponse) uint64 {
return replicationIndex
}

func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url string, jwt string, host string) (result hrana.PipelineResponse, streamClosed bool, err error) {
func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url string, jwt string, host string, remoteEncryptionKey string) (result hrana.PipelineResponse, streamClosed bool, err error) {
reqBody, err := json.Marshal(msg)
if err != nil {
return hrana.PipelineResponse{}, false, err
Expand All @@ -247,6 +249,10 @@ func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url st
req.Header.Set("Authorization", "Bearer "+jwt)
}
req.Header.Set("x-libsql-client-version", "libsql-remote-go-"+commitHash)
if remoteEncryptionKey != "" {
req.Header.Set("x-turso-encryption-key", remoteEncryptionKey)
}

req.Host = host
resp, err := http.DefaultClient.Do(req)
if err != nil {
Expand Down Expand Up @@ -591,11 +597,11 @@ func (h *hranaV2Conn) QueryContext(ctx context.Context, query string, args []dri

func (h *hranaV2Conn) closeStream() {
if h.baton != "" {
go func(baton, url, jwt, host string) {
go func(baton, url, jwt, host, encryptionKey string) {
msg := hrana.PipelineRequest{Baton: baton}
msg.Add(hrana.CloseStream())
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host)
}(h.baton, h.url, h.jwt, h.host)
_, _, _ = sendPipelineRequest(context.Background(), &msg, url, jwt, host, encryptionKey)
}(h.baton, h.url, h.jwt, h.host, h.remoteEncryptionKey)
h.baton = ""
}
}
Expand Down
41 changes: 30 additions & 11 deletions libsql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ import (
)

type config struct {
authToken *string
tls *bool
proxy *string
schemaDb *bool
authToken *string
tls *bool
proxy *string
schemaDb *bool
remoteEncryptionKey *string
}

type Option interface {
Expand Down Expand Up @@ -76,6 +77,19 @@ func WithSchemaDb(schemaDb bool) Option {
})
}

func WithRemoteEncryptionKey(key string) Option {
return option(func(o *config) error {
if o.remoteEncryptionKey != nil {
return fmt.Errorf("remoteEncryptionKey already set")
}
if key == "" {
return fmt.Errorf("remoteEncryptionKey must not be empty")
}
o.remoteEncryptionKey = &key
return nil
})
}

func (c config) connector(dbPath string) (driver.Connector, error) {
u, err := url.Parse(dbPath)
if err != nil {
Expand Down Expand Up @@ -139,6 +153,10 @@ func (c config) connector(dbPath string) (driver.Connector, error) {
if c.authToken != nil {
authToken = *c.authToken
}
encryptionKey := ""
if c.remoteEncryptionKey != nil {
encryptionKey = *c.remoteEncryptionKey
}

host := u.Host
if c.proxy != nil {
Expand All @@ -164,7 +182,7 @@ func (c config) connector(dbPath string) (driver.Connector, error) {
return wsConnector{url: u.String(), authToken: authToken}, nil
}
if u.Scheme == "https" || u.Scheme == "http" {
return httpConnector{url: u.String(), authToken: authToken, host: host, schemaDb: schemaDb}, nil
return httpConnector{url: u.String(), authToken: authToken, host: host, schemaDb: schemaDb, remoteEncryptionKey: encryptionKey}, nil
}

return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https://, http://, wss:// and ws://", u.Scheme)
Expand All @@ -185,14 +203,15 @@ func NewConnector(dbPath string, opts ...Option) (driver.Connector, error) {
}

type httpConnector struct {
url string
authToken string
host string
schemaDb bool
url string
authToken string
host string
schemaDb bool
remoteEncryptionKey string
}

func (c httpConnector) Connect(_ctx context.Context) (driver.Conn, error) {
return http.Connect(c.url, c.authToken, c.host, c.schemaDb), nil
return http.Connect(c.url, c.authToken, c.host, c.schemaDb, c.remoteEncryptionKey), nil
}

func (c httpConnector) Driver() driver.Driver {
Expand Down Expand Up @@ -341,7 +360,7 @@ func (d Driver) Open(dbUrl string) (driver.Conn, error) {
return ws.Connect(u.String(), jwt)
}
if u.Scheme == "https" || u.Scheme == "http" {
return http.Connect(u.String(), jwt, u.Host, false), nil
return http.Connect(u.String(), jwt, u.Host, false, ""), nil
}

return nil, fmt.Errorf("unsupported URL scheme: %s\nThis driver supports only URLs that start with libsql://, file://, https://, http://, wss:// and ws://", u.Scheme)
Expand Down
Loading