Skip to content

Commit fe7673b

Browse files
committed
fix(sqlite): correct StmtLocation/StmtLen for non-ASCII characters in comments
ANTLR4-go stores its input stream as []rune, so all token positions returned by GetStart().GetStart() and GetStop().GetStop() are rune indices, not byte offsets. The SQLite parser was storing these values directly as StmtLocation and StmtLen, which are later consumed by source.Pluck() using byte-based Go string slicing (source[head:tail]). For source files that contain multi-byte UTF-8 characters (non-ASCII) in comments, the rune index diverges from the byte offset, causing the plucked query text to be truncated. Each 2-byte character (e.g. Ü, é) caused one byte to be dropped from the end of the query; each 3-byte character (e.g. ♥) caused two bytes to be dropped; and so on. Fix this by building a rune-index to byte-offset map from the source string before processing the ANTLR parse tree, then converting the ANTLR rune positions to byte offsets before storing them in the AST. The internal loc tracking variable continues to use rune indices (for consistency with the ANTLR token positions), while only the values written into StmtLocation and StmtLen are converted to byte offsets. Add TestParseNonASCIIComment covering 2-, 3-, and 4-byte characters in dash comments, multiple non-ASCII characters, and the multi-statement case where an incorrect loc for one statement would propagate and corrupt the StmtLocation of the following statement.
1 parent ce83d3f commit fe7673b

File tree

2 files changed

+125
-5
lines changed

2 files changed

+125
-5
lines changed

internal/engine/sqlite/parse.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,29 @@ func NewParser() *Parser {
3737
type Parser struct {
3838
}
3939

40+
// runeToByteOffsets builds a lookup table translating rune indices into byte
// offsets for s: element i holds the byte position at which the i-th rune
// starts. A sentinel entry equal to len(s) is appended, so indexing with the
// total rune count is valid and yields the byte length of s.
//
// The table is required because ANTLR4 keeps its input as []rune, meaning
// token positions (GetStart/GetStop) are rune indices, while Go string
// slicing operates on bytes.
func runeToByteOffsets(s string) []int {
	table := make([]int, 0, len(s)+1)
	for byteOff := range s {
		table = append(table, byteOff)
	}
	return append(table, len(s))
}
55+
4056
func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
4157
blob, err := io.ReadAll(r)
4258
if err != nil {
4359
return nil, err
4460
}
45-
input := antlr.NewInputStream(string(blob))
61+
src := string(blob)
62+
input := antlr.NewInputStream(src)
4663
lexer := parser.NewSQLiteLexer(input)
4764
stream := antlr.NewCommonTokenStream(lexer, 0)
4865
pp := parser.NewSQLiteParser(stream)
@@ -57,13 +74,17 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
5774
if !ok {
5875
return nil, fmt.Errorf("expected ParserContext; got %T\n", tree)
5976
}
77+
// ANTLR uses rune-based positions. Build a mapping from rune index to byte
78+
// offset so we can store byte-based StmtLocation/StmtLen, which is what
79+
// source.Pluck and the rest of the pipeline expect.
80+
runeByteMap := runeToByteOffsets(src)
6081
var stmts []ast.Statement
6182
for _, istmt := range pctx.AllSql_stmt_list() {
6283
list, ok := istmt.(*parser.Sql_stmt_listContext)
6384
if !ok {
6485
return nil, fmt.Errorf("expected Sql_stmt_listContext; got %T\n", istmt)
6586
}
66-
loc := 0
87+
loc := 0 // rune offset of the current statement's start
6788

6889
for _, stmt := range list.AllSql_stmt() {
6990
converter := &cc{}
@@ -72,12 +93,14 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
7293
loc = stmt.GetStop().GetStop() + 2
7394
continue
7495
}
75-
len := (stmt.GetStop().GetStop() + 1) - loc
96+
runeEnd := stmt.GetStop().GetStop() + 1
97+
byteStart := runeByteMap[loc]
98+
byteEnd := runeByteMap[runeEnd]
7699
stmts = append(stmts, ast.Statement{
77100
Raw: &ast.RawStmt{
78101
Stmt: out,
79-
StmtLocation: loc,
80-
StmtLen: len,
102+
StmtLocation: byteStart,
103+
StmtLen: byteEnd - byteStart,
81104
},
82105
})
83106
loc = stmt.GetStop().GetStop() + 2
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package sqlite
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/sqlc-dev/sqlc/internal/source"
8+
)
9+
10+
// TestParseNonASCIIComment verifies that non-ASCII characters in SQL comments
11+
// do not corrupt the plucked query text.
12+
//
13+
// ANTLR4 stores the input as []rune so all token positions are rune indices,
14+
// not byte offsets. source.Pluck (and the rest of the pipeline) treats
15+
// StmtLocation/StmtLen as byte offsets. For multi-byte UTF-8 characters the
16+
// two differ, which previously caused the plucked query to be truncated by one
17+
// byte per extra byte in each non-ASCII character.
18+
func TestParseNonASCIIComment(t *testing.T) {
19+
p := NewParser()
20+
21+
tests := []struct {
22+
name string
23+
sql string
24+
}{
25+
{
26+
name: "2-byte char (U+00DC Ü) in dash comment",
27+
sql: "-- name: GetUser :one\n-- Ünïcode comment\nSELECT id FROM users WHERE id = ?",
28+
},
29+
{
30+
name: "3-byte char (U+2665 ♥) in dash comment",
31+
sql: "-- name: GetUser :one\n-- ♥ love\nSELECT id FROM users WHERE id = ?",
32+
},
33+
{
34+
name: "4-byte char (U+1D11E 𝄞) in dash comment",
35+
sql: "-- name: GetUser :one\n-- 𝄞 music\nSELECT id FROM users WHERE id = ?",
36+
},
37+
{
38+
name: "multiple non-ASCII chars in comment",
39+
sql: "-- name: GetUser :one\n-- héllo wörld\nSELECT id FROM users WHERE id = ?",
40+
},
41+
{
42+
name: "non-ASCII only in first of two statements",
43+
sql: "-- name: Q1 :one\n-- Ü\nSELECT 1;\n\n-- name: Q2 :one\nSELECT 2",
44+
},
45+
}
46+
47+
for _, tc := range tests {
48+
t.Run(tc.name, func(t *testing.T) {
49+
stmts, err := p.Parse(strings.NewReader(tc.sql))
50+
if err != nil {
51+
t.Fatalf("Parse error: %v", err)
52+
}
53+
if len(stmts) == 0 {
54+
t.Fatal("expected at least one statement")
55+
}
56+
57+
// For every parsed statement, verify that the plucked text is a
58+
// valid substring of the original SQL (not truncated mid-character).
59+
for i, stmt := range stmts {
60+
raw := stmt.Raw
61+
plucked, err := source.Pluck(tc.sql, raw.StmtLocation, raw.StmtLen)
62+
if err != nil {
63+
t.Fatalf("stmt %d: Pluck error: %v", i, err)
64+
}
65+
if !strings.Contains(tc.sql, plucked) {
66+
t.Errorf("stmt %d: plucked text is not a substring of the input\ngot: %q\ninput: %q", i, plucked, tc.sql)
67+
}
68+
if plucked == "" {
69+
t.Errorf("stmt %d: plucked text is empty", i)
70+
}
71+
}
72+
73+
// For the single-statement cases the plucked text must equal the
74+
// full input, since there is exactly one statement and no trailing
75+
// semicolon to exclude.
76+
if len(stmts) == 1 {
77+
raw := stmts[0].Raw
78+
plucked, _ := source.Pluck(tc.sql, raw.StmtLocation, raw.StmtLen)
79+
if plucked != tc.sql {
80+
t.Errorf("single-statement pluck mismatch\ngot: %q\nwant: %q", plucked, tc.sql)
81+
}
82+
}
83+
84+
// For the two-statement case, verify each statement contains its
85+
// expected SELECT.
86+
if len(stmts) == 2 {
87+
for i, want := range []string{"SELECT 1", "SELECT 2"} {
88+
raw := stmts[i].Raw
89+
plucked, _ := source.Pluck(tc.sql, raw.StmtLocation, raw.StmtLen)
90+
if !strings.Contains(plucked, want) {
91+
t.Errorf("stmt %d: plucked text %q does not contain %q", i, plucked, want)
92+
}
93+
}
94+
}
95+
})
96+
}
97+
}

0 commit comments

Comments
 (0)