Skip to content

Commit 83744cb

Browse files
committed
More sqlite-createtable-parser.
1 parent a36d72c commit 83744cb

13 files changed

Lines changed: 300 additions & 103 deletions

File tree

internal/util/page.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package util
2+
3+
func ValidPageSize(s int) bool {
4+
return s&(s-1) == 0 && 512 <= s && s <= 65536
5+
}
Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,11 @@
1-
package sql3util
1+
package util
22

33
import (
44
"strconv"
55
"strings"
66
"time"
77
)
88

9-
// NamedArg splits an named arg into a key and value,
10-
// around an equals sign.
11-
// Spaces are trimmed around both key and value.
12-
func NamedArg(arg string) (key, val string) {
13-
key, val, _ = strings.Cut(arg, "=")
14-
key = strings.TrimSpace(key)
15-
val = strings.TrimSpace(val)
16-
return
17-
}
18-
19-
// Unquote unquotes a string.
20-
//
21-
// https://sqlite.org/lang_keywords.html
22-
func Unquote(val string) string {
23-
if len(val) < 2 {
24-
return val
25-
}
26-
fst := val[0]
27-
lst := val[len(val)-1]
28-
rst := val[1 : len(val)-1]
29-
if fst == '[' && lst == ']' {
30-
return rst
31-
}
32-
if fst != lst {
33-
return val
34-
}
35-
var old, new string
36-
switch fst {
37-
default:
38-
return val
39-
case '`':
40-
old, new = "``", "`"
41-
case '"':
42-
old, new = `""`, `"`
43-
case '\'':
44-
old, new = `''`, `'`
45-
}
46-
return strings.ReplaceAll(rst, old, new)
47-
}
48-
49-
// ParseBool parses a boolean.
50-
//
51-
// https://sqlite.org/pragma.html#syntax
529
func ParseBool(s string) (b, ok bool) {
5310
if len(s) == 0 {
5411
return false, false
@@ -68,7 +25,6 @@ func ParseBool(s string) (b, ok bool) {
6825
return false, false
6926
}
7027

71-
// ParseFloat parses a decimal floating point number.
7228
func ParseFloat(s string) (f float64, ok bool) {
7329
if strings.TrimLeft(s, "+-.0123456789Ee") != "" {
7430
return
@@ -77,10 +33,6 @@ func ParseFloat(s string) (f float64, ok bool) {
7733
return f, err == nil
7834
}
7935

80-
// ParseTimeShift parses a time shift modifier,
81-
// also the output of timediff.
82-
//
83-
// https://sqlite.org/lang_datefunc.html
8436
func ParseTimeShift(s string) (years, months, days int, duration time.Duration, ok bool) {
8537
// Sign part: ±
8638
neg := strings.HasPrefix(s, "-")
Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,14 @@
1-
package sql3util_test
1+
package util_test
22

33
import (
44
"testing"
55
"time"
66

77
"github.com/ncruces/go-sqlite3"
88
_ "github.com/ncruces/go-sqlite3/embed"
9-
"github.com/ncruces/go-sqlite3/util/sql3util"
9+
"github.com/ncruces/go-sqlite3/internal/util"
1010
)
1111

12-
func TestUnquote(t *testing.T) {
13-
tests := []struct {
14-
val string
15-
want string
16-
}{
17-
{"a", "a"},
18-
{"abc", "abc"},
19-
{"abba", "abba"},
20-
{"`ab``c`", "ab`c"},
21-
{"'ab''c'", "ab'c"},
22-
{"'ab``c'", "ab``c"},
23-
{"[ab``c]", "ab``c"},
24-
{`"ab""c"`, `ab"c`},
25-
}
26-
for _, tt := range tests {
27-
t.Run(tt.val, func(t *testing.T) {
28-
if got := sql3util.Unquote(tt.val); got != tt.want {
29-
t.Errorf("Unquote(%s) = %s, want %s", tt.val, got, tt.want)
30-
}
31-
})
32-
}
33-
}
34-
3512
func TestParseBool(t *testing.T) {
3613
tests := []struct {
3714
str string
@@ -49,7 +26,7 @@ func TestParseBool(t *testing.T) {
4926
}
5027
for _, tt := range tests {
5128
t.Run(tt.str, func(t *testing.T) {
52-
gotVal, gotOK := sql3util.ParseBool(tt.str)
29+
gotVal, gotOK := util.ParseBool(tt.str)
5330
if gotVal != tt.val || gotOK != tt.ok {
5431
t.Errorf("ParseBool(%q) = (%v, %v) want (%v, %v)", tt.str, gotVal, gotOK, tt.val, tt.ok)
5532
}
@@ -88,7 +65,7 @@ func TestParseTimeShift(t *testing.T) {
8865
}
8966
for _, tt := range tests {
9067
t.Run(tt.str, func(t *testing.T) {
91-
years, months, days, duration, gotOK := sql3util.ParseTimeShift(tt.str)
68+
years, months, days, duration, gotOK := util.ParseTimeShift(tt.str)
9269
gotVal := epoch.AddDate(years, months, days).Add(duration)
9370
if !gotVal.Equal(tt.val) || gotOK != tt.ok {
9471
t.Errorf("ParseTimeShift(%q) = (%v, %v) want (%v, %v)", tt.str, gotVal, gotOK, tt.val, tt.ok)
@@ -136,7 +113,7 @@ func FuzzParseTimeShift(f *testing.F) {
136113
epoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
137114

138115
f.Fuzz(func(t *testing.T, str string) {
139-
years, months, days, duration, ok := sql3util.ParseTimeShift(str)
116+
years, months, days, duration, ok := util.ParseTimeShift(str)
140117

141118
// Account for a full 400 year cycle.
142119
if years < -200 || years > +200 {

time.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"time"
88

99
"github.com/ncruces/go-sqlite3/internal/util"
10-
"github.com/ncruces/go-sqlite3/util/sql3util"
1110
"github.com/ncruces/julianday"
1211
)
1312

@@ -160,7 +159,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
160159
if s, ok := v.(string); ok {
161160
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
162161
v = i
163-
} else if f, ok := sql3util.ParseFloat(s); ok {
162+
} else if f, ok := util.ParseFloat(s); ok {
164163
v = f
165164
} else {
166165
return time.Time{}, util.TimeErr
@@ -237,7 +236,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
237236
v = i
238237
break
239238
}
240-
f, ok := sql3util.ParseFloat(s)
239+
f, ok := util.ParseFloat(s)
241240
if ok {
242241
v = f
243242
break

util/sql3util/const.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ const (
4949
DEFTYPE_NOTDEFERRABLE_INITIALLY_IMMEDIATE
5050
)
5151

52+
type ConstraintType uint32
53+
54+
const (
55+
TABLECONSTRAINT_PRIMARYKEY ConstraintType = iota
56+
TABLECONSTRAINT_UNIQUE
57+
TABLECONSTRAINT_CHECK
58+
TABLECONSTRAINT_FOREIGNKEY
59+
)
60+
5261
type StatementType uint32
5362

5463
const (

util/sql3util/parse.go

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/tetratelabs/wazero"
99
"github.com/tetratelabs/wazero/api"
1010

11+
"github.com/ncruces/go-sqlite3"
1112
"github.com/ncruces/go-sqlite3/internal/util"
1213
)
1314

@@ -29,6 +30,10 @@ var (
2930
// [CREATE]: https://sqlite.org/lang_createtable.html
3031
// [ALTER TABLE]: https://sqlite.org/lang_altertable.html
3132
func ParseTable(sql string) (_ *Table, err error) {
33+
if len(sql) > 8192 {
34+
return nil, sqlite3.TOOBIG
35+
}
36+
3237
once.Do(func() {
3338
ctx := context.Background()
3439
cfg := wazero.NewRuntimeConfigInterpreter()
@@ -81,6 +86,7 @@ type Table struct {
8186
IsWithoutRowID bool
8287
IsStrict bool
8388
Columns []Column
89+
Constraints []TableConstraint
8490
Type StatementType
8591
CurrentName string
8692
NewName string
@@ -96,9 +102,16 @@ func (t *Table) load(mod api.Module, ptr uint32, sql string) {
96102
t.IsWithoutRowID = loadBool(mod, ptr+26)
97103
t.IsStrict = loadBool(mod, ptr+27)
98104

99-
t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, ret *Column) {
105+
t.Columns = loadSlice(mod, ptr+28, func(ptr uint32, ret *Column) uint32 {
106+
p, _ := mod.Memory().ReadUint32Le(ptr)
107+
ret.load(mod, p, sql)
108+
return 4
109+
})
110+
111+
t.Constraints = loadSlice(mod, ptr+36, func(ptr uint32, ret *TableConstraint) uint32 {
100112
p, _ := mod.Memory().ReadUint32Le(ptr)
101113
ret.load(mod, p, sql)
114+
return 4
102115
})
103116

104117
t.Type = loadEnum[StatementType](mod, ptr+44)
@@ -159,6 +172,47 @@ func (c *Column) load(mod api.Module, ptr uint32, sql string) {
159172
c.GeneratedType = loadEnum[GenType](mod, ptr+96)
160173
}
161174

175+
// TableConstraint holds metadata about a table key constraint.
176+
type TableConstraint struct {
177+
Type ConstraintType
178+
Name string
179+
// Type is TABLECONSTRAINT_PRIMARYKEY or TABLECONSTRAINT_UNIQUE
180+
IndexedColumns []IdxColumn
181+
ConflictClause ConflictClause
182+
IsAutoIncrement bool
183+
// Type is TABLECONSTRAINT_CHECK
184+
Expr string
185+
// Type is TABLECONSTRAINT_FOREIGNKEY
186+
ForeignKeyNames []string
187+
ForeignKeyClause *ForeignKey
188+
}
189+
190+
func (c *TableConstraint) load(mod api.Module, ptr uint32, sql string) {
191+
c.Type = loadEnum[ConstraintType](mod, ptr+0)
192+
c.Name = loadString(mod, ptr+4, sql)
193+
switch c.Type {
194+
case TABLECONSTRAINT_PRIMARYKEY, TABLECONSTRAINT_UNIQUE:
195+
c.IndexedColumns = loadSlice(mod, ptr+12, func(ptr uint32, ret *IdxColumn) uint32 {
196+
ret.load(mod, ptr, sql)
197+
return 20
198+
})
199+
c.ConflictClause = loadEnum[ConflictClause](mod, ptr+20)
200+
c.IsAutoIncrement = loadBool(mod, ptr+24)
201+
case TABLECONSTRAINT_CHECK:
202+
c.Expr = loadString(mod, ptr+12, sql)
203+
case TABLECONSTRAINT_FOREIGNKEY:
204+
c.ForeignKeyNames = loadSlice(mod, ptr+12, func(ptr uint32, ret *string) uint32 {
205+
*ret = loadString(mod, ptr, sql)
206+
return 8
207+
})
208+
if ptr, _ := mod.Memory().ReadUint32Le(ptr + 20); ptr != 0 {
209+
c.ForeignKeyClause = &ForeignKey{}
210+
c.ForeignKeyClause.load(mod, ptr, sql)
211+
}
212+
}
213+
}
214+
215+
// ForeignKey holds metadata about a foreign key constraint.
162216
type ForeignKey struct {
163217
Table string
164218
Columns []string
@@ -171,8 +225,9 @@ type ForeignKey struct {
171225
func (f *ForeignKey) load(mod api.Module, ptr uint32, sql string) {
172226
f.Table = loadString(mod, ptr+0, sql)
173227

174-
f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, ret *string) {
228+
f.Columns = loadSlice(mod, ptr+8, func(ptr uint32, ret *string) uint32 {
175229
*ret = loadString(mod, ptr, sql)
230+
return 8
176231
})
177232

178233
f.OnDelete = loadEnum[FKAction](mod, ptr+16)
@@ -181,6 +236,19 @@ func (f *ForeignKey) load(mod api.Module, ptr uint32, sql string) {
181236
f.Deferrable = loadEnum[FKDefType](mod, ptr+32)
182237
}
183238

239+
// IdxColumn holds metadata about an indexed column.
240+
type IdxColumn struct {
241+
Name string
242+
CollateName string
243+
Order OrderClause
244+
}
245+
246+
func (c *IdxColumn) load(mod api.Module, ptr uint32, sql string) {
247+
c.Name = loadString(mod, ptr+0, sql)
248+
c.CollateName = loadString(mod, ptr+8, sql)
249+
c.Order = loadEnum[OrderClause](mod, ptr+16)
250+
}
251+
184252
func loadString(mod api.Module, ptr uint32, sql string) string {
185253
off, _ := mod.Memory().ReadUint32Le(ptr + 0)
186254
if off == 0 {
@@ -190,16 +258,15 @@ func loadString(mod api.Module, ptr uint32, sql string) string {
190258
return sql[off-sqlp : off+len-sqlp]
191259
}
192260

193-
func loadSlice[T any](mod api.Module, ptr uint32, fn func(uint32, *T)) []T {
261+
func loadSlice[T any](mod api.Module, ptr uint32, fn func(uint32, *T) uint32) []T {
194262
ref, _ := mod.Memory().ReadUint32Le(ptr + 4)
195263
if ref == 0 {
196264
return nil
197265
}
198266
len, _ := mod.Memory().ReadUint32Le(ptr + 0)
199267
ret := make([]T, len)
200268
for i := range ret {
201-
fn(ref, &ret[i])
202-
ref += 4
269+
ref += fn(ref, &ret[i])
203270
}
204271
return ret
205272
}

0 commit comments

Comments
 (0)