Skip to content

Commit 1c6bbb2

Browse files
committed
add QueryBinary, an alloc-free way to read all rows into a buffer
WIP; goal is alloc-free reads of a query into a Go-provided buffer. Go code can then parse the simple binary format and alloc if needed (doing its own cache lookups, including alloc-free m[string([]byte)] lookups, and returning existing Views if data is unmodified) Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
1 parent 2d70ae2 commit 1c6bbb2

File tree

8 files changed

+437
-3
lines changed

8 files changed

+437
-3
lines changed

binary.go

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// Copyright (c) 2023 Tailscale Inc & AUTHORS All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package sqlite
6+
7+
import (
8+
"context"
9+
"encoding/binary"
10+
"errors"
11+
"fmt"
12+
"math"
13+
"reflect"
14+
"sync"
15+
16+
"github.com/tailscale/sqlite/sqliteh"
17+
"golang.org/x/sys/cpu"
18+
)
19+
20+
type driverConnRawCall struct {
21+
f func(driverConn any) error
22+
23+
// results
24+
dc *conn
25+
ok bool
26+
}
27+
28+
var driverConnRawCallPool = &sync.Pool{
29+
New: func() any {
30+
c := new(driverConnRawCall)
31+
c.f = func(driverConn any) error {
32+
c.dc, c.ok = driverConn.(*conn)
33+
return nil
34+
}
35+
return c
36+
},
37+
}
38+
39+
func getDriverConn(sc SQLConn) (dc *conn, ok bool) {
40+
c := driverConnRawCallPool.Get().(*driverConnRawCall)
41+
defer driverConnRawCallPool.Put(c)
42+
err := sc.Raw(c.f)
43+
if err != nil {
44+
return nil, false
45+
}
46+
return c.dc, c.ok
47+
}
48+
49+
func QueryBinary(ctx context.Context, sqlconn SQLConn, optScratch []byte, query string, args ...any) (BinaryResults, error) {
50+
c, ok := getDriverConn(sqlconn)
51+
if !ok {
52+
return nil, errors.New("sqlconn is not of expected type")
53+
}
54+
st, err := c.prepare(ctx, query, IsPersist(ctx))
55+
if err != nil {
56+
return nil, err
57+
}
58+
buf := optScratch
59+
if len(buf) == 0 {
60+
buf = make([]byte, 128)
61+
}
62+
for {
63+
st.stmt.ResetAndClear()
64+
65+
// Bind args.
66+
for colIdx, a := range args {
67+
rv := reflect.ValueOf(a)
68+
switch rv.Kind() {
69+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
70+
if err := st.stmt.BindInt64(colIdx+1, rv.Int()); err != nil {
71+
return nil, fmt.Errorf("binding col idx %d to %T (%v): %w", colIdx, a, rv.Int(), err)
72+
}
73+
default:
74+
// TODO(bradfitz): more types, at least strings for stable IDs.
75+
return nil, fmt.Errorf("unsupported arg type %T", a)
76+
}
77+
}
78+
79+
n, err := st.stmt.StepAllBinary(buf)
80+
if err == nil {
81+
return BinaryResults(buf[:n]), nil
82+
}
83+
if e, ok := err.(sqliteh.BufferSizeTooSmallError); ok {
84+
buf = make([]byte, e.EncodedSize)
85+
continue
86+
}
87+
return nil, err
88+
}
89+
}
90+
91+
// BinaryResults is the result of QueryBinary.
92+
//
93+
// You should not depend on its specific format and parse it via its methods
94+
// instead.
95+
type BinaryResults []byte
96+
97+
type BinaryToken struct {
98+
StartRow bool
99+
EndRow bool
100+
EndRows bool
101+
IsInt bool // if so, use Int() method
102+
IsFloat bool // if so, use Float() method
103+
IsNull bool
104+
IsBytes bool
105+
Error bool
106+
107+
x uint64
108+
Bytes []byte
109+
}
110+
111+
func (t *BinaryToken) String() string {
112+
switch {
113+
case t.StartRow:
114+
return "start-row"
115+
case t.EndRow:
116+
return "end-row"
117+
case t.EndRows:
118+
return "end-rows"
119+
case t.IsNull:
120+
return "null"
121+
case t.IsInt:
122+
return fmt.Sprintf("int: %v", t.Int())
123+
case t.IsFloat:
124+
return fmt.Sprintf("float: %g", t.Float())
125+
case t.IsBytes:
126+
return fmt.Sprintf("bytes: %q", t.Bytes)
127+
case t.Error:
128+
return "error"
129+
default:
130+
return "unknown"
131+
}
132+
}
133+
134+
func (t *BinaryToken) Int() int64 { return int64(t.x) }
135+
func (t *BinaryToken) Float() float64 { return math.Float64frombits(t.x) }
136+
137+
func (r *BinaryResults) Next() BinaryToken {
138+
if len(*r) == 0 {
139+
return BinaryToken{Error: true}
140+
}
141+
first := (*r)[0]
142+
*r = (*r)[1:]
143+
switch first {
144+
default:
145+
return BinaryToken{Error: true}
146+
case '(':
147+
return BinaryToken{StartRow: true}
148+
case ')':
149+
return BinaryToken{EndRow: true}
150+
case 'E':
151+
return BinaryToken{EndRows: true}
152+
case 'n':
153+
return BinaryToken{IsNull: true}
154+
case 'i', 'f':
155+
if len(*r) < 8 {
156+
return BinaryToken{Error: true}
157+
}
158+
t := BinaryToken{IsInt: first == 'i', IsFloat: first == 'f'}
159+
if cpu.IsBigEndian {
160+
t.x = binary.BigEndian.Uint64((*r)[:8])
161+
} else {
162+
t.x = binary.LittleEndian.Uint64((*r)[:8])
163+
}
164+
*r = (*r)[8:]
165+
return t
166+
case 'b':
167+
if len(*r) < 8 {
168+
return BinaryToken{Error: true}
169+
}
170+
t := BinaryToken{IsBytes: true}
171+
var n int64
172+
if cpu.IsBigEndian {
173+
n = int64(binary.BigEndian.Uint64((*r)[:8]))
174+
} else {
175+
n = int64(binary.LittleEndian.Uint64((*r)[:8]))
176+
}
177+
*r = (*r)[8:]
178+
if int64(len(*r)) < n {
179+
return BinaryToken{Error: true}
180+
}
181+
t.Bytes = (*r)[:n]
182+
*r = (*r)[n:]
183+
return t
184+
}
185+
}

binary_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Copyright (c) 2023 Tailscale Inc & AUTHORS All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package sqlite
6+
7+
import (
8+
"context"
9+
"math"
10+
"reflect"
11+
"testing"
12+
13+
"github.com/google/go-cmp/cmp"
14+
)
15+
16+
func TestQueryBinary(t *testing.T) {
17+
ctx := WithPersist(context.Background())
18+
db := openTestDB(t)
19+
exec(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, f REAL, txt TEXT, blb BLOB)")
20+
exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", math.MinInt64, 1.0, "text-a", "blob-a")
21+
exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", -1, -1.0, "text-b", "blob-b")
22+
exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 0, 0, "text-c", "blob-c")
23+
exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 20, 2, "text-d", "blob-d")
24+
exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", math.MaxInt64, nil, "text-e", "blob-e")
25+
exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 42, 0.25, "text-f", nil)
26+
exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 43, 1.75, "text-g", nil)
27+
28+
conn, err := db.Conn(ctx)
29+
if err != nil {
30+
t.Fatal(err)
31+
}
32+
33+
buf, err := QueryBinary(ctx, conn, make([]byte, 100), "SELECT * FROM t ORDER BY id")
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
t.Logf("Got %d bytes: %q", len(buf), buf)
38+
39+
var got []string
40+
iter := buf
41+
for len(iter) > 0 {
42+
t := iter.Next()
43+
got = append(got, t.String())
44+
if t.Error {
45+
break
46+
}
47+
}
48+
want := []string{
49+
"start-row", "int: -9223372036854775808", "float: 1", "bytes: \"text-a\"", "bytes: \"blob-a\"", "end-row",
50+
"start-row", "int: -1", "float: -1", "bytes: \"text-b\"", "bytes: \"blob-b\"", "end-row",
51+
"start-row", "int: 0", "float: 0", "bytes: \"text-c\"", "bytes: \"blob-c\"", "end-row",
52+
"start-row", "int: 20", "float: 2", "bytes: \"text-d\"", "bytes: \"blob-d\"", "end-row",
53+
"start-row", "int: 42", "float: 0.25", "bytes: \"text-f\"", "null", "end-row",
54+
"start-row", "int: 43", "float: 1.75", "bytes: \"text-g\"", "null", "end-row",
55+
"start-row", "int: 9223372036854775807", "null", "bytes: \"text-e\"", "bytes: \"blob-e\"", "end-row",
56+
"end-rows",
57+
}
58+
if !reflect.DeepEqual(got, want) {
59+
t.Errorf("wrong results\n got: %q\nwant: %q\n\ndiff:\n%s", got, want, cmp.Diff(want, got))
60+
}
61+
62+
allocs := int(testing.AllocsPerRun(10000, func() {
63+
_, err := QueryBinary(ctx, conn, buf, "SELECT * FROM t")
64+
if err != nil {
65+
t.Fatal(err)
66+
}
67+
}))
68+
const maxAllocs = 5 // as of Go 1.20
69+
if allocs > maxAllocs {
70+
t.Errorf("allocs = %v; want max %v", allocs, maxAllocs)
71+
}
72+
}
73+
74+
func BenchmarkQueryBinaryParallel(b *testing.B) {
75+
ctx := WithPersist(context.Background())
76+
db := openTestDB(b)
77+
exec(b, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, f REAL, txt TEXT, blb BLOB)")
78+
exec(b, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 42, 0.25, "text-f", "some big big big big blob so big like so many bytes even")
79+
80+
b.ResetTimer()
81+
b.ReportAllocs()
82+
b.RunParallel(func(pb *testing.PB) {
83+
conn, err := db.Conn(ctx)
84+
if err != nil {
85+
b.Error(err)
86+
return
87+
}
88+
89+
var buf = make([]byte, 250)
90+
91+
for pb.Next() {
92+
res, err := QueryBinary(ctx, conn, buf, "SELECT id, f, txt, blb FROM t WHERE id=?", 42)
93+
if err != nil {
94+
b.Error(err)
95+
return
96+
}
97+
t := res.Next()
98+
if !t.StartRow {
99+
b.Errorf("didn't get start row; got %v", t)
100+
return
101+
}
102+
t = res.Next()
103+
if t.Int() != 42 {
104+
b.Errorf("got %v; want 42", t)
105+
return
106+
}
107+
}
108+
})
109+
110+
}

cgosqlite/cgosqlite.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ package cgosqlite
5454
// #include "cgosqlite.h"
5555
import "C"
5656
import (
57+
"errors"
5758
"sync"
5859
"time"
5960
"unsafe"
@@ -120,6 +121,7 @@ type Stmt struct {
120121
// used as scratch space when calling into cgo
121122
rowid, changes C.sqlite3_int64
122123
duration C.int64_t
124+
encodedSize C.int
123125
}
124126

125127
// Open implements sqliteh.OpenFunc.
@@ -420,6 +422,23 @@ func (stmt *Stmt) ColumnDeclType(col int) string {
420422
return res
421423
}
422424

425+
func (stmt *Stmt) StepAllBinary(dstBuf []byte) (n int, err error) {
426+
if len(dstBuf) == 0 {
427+
return 0, errors.New("zero-length buffer to StepAllBinary")
428+
}
429+
ret := C.ts_sqlite_step_all(stmt.stmt.int(), (*C.char)(unsafe.Pointer(&dstBuf[0])), C.int(len(dstBuf)), &stmt.encodedSize)
430+
431+
if int(stmt.encodedSize) > len(dstBuf) {
432+
return 0, sqliteh.BufferSizeTooSmallError{
433+
EncodedSize: int(stmt.encodedSize),
434+
}
435+
}
436+
if err := errCode(ret); err != nil {
437+
return 0, err
438+
}
439+
return int(stmt.encodedSize), nil
440+
}
441+
423442
var emptyCStr = C.CString("")
424443

425444
func errCode(code C.int) error { return sqliteh.CodeAsError(sqliteh.Code(code)) }

0 commit comments

Comments
 (0)