diff --git a/db.go b/db.go
index 658963759..66dd570dc 100644
--- a/db.go
+++ b/db.go
@@ -733,9 +733,28 @@ func (db *DB) getMemTables() ([]*memTable, func()) {
 	}
 }
 
-// get returns the value in memtable or disk for given key.
-// Note that value will include meta byte.
-//
+func (db *DB) checkKeyInMemtables(tables []*memTable, key []byte, maxVs *y.ValueStruct, version uint64) bool {
+	y.NumGetsAdd(db.opt.MetricsEnabled, 1)
+	for i := 0; i < len(tables); i++ {
+		vs := tables[i].sl.Get(key)
+		y.NumMemtableGetsAdd(db.opt.MetricsEnabled, 1)
+		if vs.Meta == 0 && vs.Value == nil {
+			continue
+		}
+		// Found the required version of the key. Mark it as done; no need to
+		// process it any further.
+		if vs.Version == version {
+			y.NumGetsWithResultsAdd(db.opt.MetricsEnabled, 1)
+			*maxVs = vs
+			return true
+		}
+		if maxVs.Version < vs.Version {
+			*maxVs = vs
+		}
+	}
+	return false
+}
+
 // IMPORTANT: We should never write an entry with an older timestamp for the same key, We need to
 // maintain this invariant to search for the latest value of a key, or else we need to search in all
 // tables and find the max version among them. To maintain this invariant, we also need to ensure
@@ -747,31 +766,52 @@ func (db *DB) getMemTables() ([]*memTable, func()) {
 // do that. For every get("fooX") call where X is the version, we will search
 // for "fooX" in all the levels of the LSM tree. This is expensive but it
 // removes the overhead of handling move keys completely.
-func (db *DB) get(key []byte) (y.ValueStruct, error) {
+//
+// getBatch returns the values for a list of keys, in the same order as the keys.
+// Note that the values will include the meta byte.
+func (db *DB) getBatch(keys [][]byte, keysRead []bool, version uint64) ([]y.ValueStruct, error) {
 	if db.IsClosed() {
-		return y.ValueStruct{}, ErrDBClosed
+		return nil, ErrDBClosed
 	}
 	tables, decr := db.getMemTables() // Lock should be released.
 	defer decr()
 
-	var maxVs y.ValueStruct
-	version := y.ParseTs(key)
+	maxVs := make([]y.ValueStruct, len(keys))
 
-	y.NumGetsAdd(db.opt.MetricsEnabled, 1)
-	for i := 0; i < len(tables); i++ {
-		vs := tables[i].sl.Get(key)
-		y.NumMemtableGetsAdd(db.opt.MetricsEnabled, 1)
-		if vs.Meta == 0 && vs.Value == nil {
+	// For the memtables, we need to check every memtable for each key.
+	for j, key := range keys {
+		if keysRead[j] {
 			continue
 		}
-		// Found the required version of the key, return immediately.
-		if vs.Version == version {
-			y.NumGetsWithResultsAdd(db.opt.MetricsEnabled, 1)
-			return vs, nil
+		if db.checkKeyInMemtables(tables, key, &maxVs[j], version) {
+			keysRead[j] = true
 		}
-		if maxVs.Version < vs.Version {
-			maxVs = vs
+	}
+	return db.lc.getBatch(keys, maxVs, 0, keysRead, version)
+}
+
+// get returns the value in memtable or disk for given key.
+// Note that value will include meta byte.
+func (db *DB) get(key []byte) (y.ValueStruct, error) {
+	if db.opt.useGetBatch {
+		done := make([]bool, 1)
+		vals, err := db.getBatch([][]byte{key}, done, y.ParseTs(key))
+		if len(vals) != 0 {
+			return vals[0], err
 		}
+		return y.ValueStruct{}, err
+	}
+
+	if db.IsClosed() {
+		return y.ValueStruct{}, ErrDBClosed
+	}
+	tables, decr := db.getMemTables() // Lock should be released.
+	defer decr()
+
+	var maxVs y.ValueStruct
+	version := y.ParseTs(key)
+	if db.checkKeyInMemtables(tables, key, &maxVs, version) {
+		return maxVs, nil
 	}
 	return db.lc.get(key, maxVs, 0)
 }
diff --git a/level_handler.go b/level_handler.go
index 391803f5c..f6e00cb95 100644
--- a/level_handler.go
+++ b/level_handler.go
@@ -262,6 +262,95 @@ func (s *levelHandler) getTableForKey(key []byte) ([]*table.Table, func() error)
 	return []*table.Table{tbl}, tbl.DecrRef
 }
 
+// checkInsideIterator checks whether the key is present in the given iterator. It updates maxVs if a
+// newer version of the value is found.
+func (s *levelHandler) checkInsideIterator(key []byte, it *table.Iterator, maxVs *y.ValueStruct) {
+	y.NumLSMGetsAdd(s.db.opt.MetricsEnabled, s.strLevel, 1)
+	it.Seek(key)
+	if !it.Valid() {
+		return
+	}
+	if !y.SameKey(key, it.Key()) {
+		return
+	}
+	if version := y.ParseTs(it.Key()); maxVs.Version < version {
+		*maxVs = it.ValueCopy()
+		maxVs.Version = version
+	}
+}
+
+func (s *levelHandler) getBatch(keys [][]byte, keysRead []bool) ([]y.ValueStruct, error) {
+	// Find the table that the key belongs to, and then seek within it. There's a good chance that
+	// the next key to be searched is in the same table as well, so we keep the iterators we found.
+	// If we don't find a result in those tables, we have to search again. In the worst case, this
+	// function costs a little more than fetching the n keys in n separate get calls.
+	createIteratorsForEachTable := func(key []byte) (y.ValueStruct, func() error, []*table.Iterator) {
+		tables, decr := s.getTableForKey(key)
+		keyNoTs := y.ParseKey(key)
+		itrs := make([]*table.Iterator, 0)
+
+		hash := y.Hash(keyNoTs)
+		var maxVs y.ValueStruct
+		for _, th := range tables {
+			if th.DoesNotHave(hash) {
+				y.NumLSMBloomHitsAdd(s.db.opt.MetricsEnabled, s.strLevel, 1)
+				continue
+			}
+
+			it := th.NewIterator(0)
+			itrs = append(itrs, it)
+			s.checkInsideIterator(key, it, &maxVs)
+		}
+
+		return maxVs, decr, itrs
+	}
+
+	// Reuse the iterators from the last createIteratorsForEachTable call and search those tables.
+	findInIterators := func(key []byte, itrs []*table.Iterator) y.ValueStruct {
+		var maxVs y.ValueStruct
+		for _, it := range itrs {
+			s.checkInsideIterator(key, it, &maxVs)
+		}
+		return maxVs
+	}
+
+	results := make([]y.ValueStruct, len(keys))
+
+	decr := func() error { return nil }
+	var itrs []*table.Iterator
+
+	closeIters := func() {
+		for _, itr := range itrs {
+			itr.Close()
+		}
+	}
+
+	defer closeIters()
+
+	for i := 0; i < len(keys); i++ {
+		if keysRead[i] {
+			continue
+		}
+		// If there are no iterators present, create new iterators.
+		if len(itrs) == 0 {
+			results[i], decr, itrs = createIteratorsForEachTable(keys[i])
+		} else {
+			results[i] = findInIterators(keys[i], itrs)
+			// If we can't find the key in the current tables, the data may be in other
+			// tables. Close the iterators, call decr(), and then create new iterators.
+			if results[i].Meta == 0 && results[i].Value == nil {
+				closeIters()
+				if err := decr(); err != nil {
+					return nil, err
+				}
+				results[i], decr, itrs = createIteratorsForEachTable(keys[i])
+			}
+		}
+	}
+
+	return results, decr()
+}
+
 // get returns value for a given key or the key after that. If not found, return nil.
 func (s *levelHandler) get(key []byte) (y.ValueStruct, error) {
 	tables, decr := s.getTableForKey(key)
@@ -278,17 +367,7 @@ func (s *levelHandler) get(key []byte) (y.ValueStruct, error) {
 		it := th.NewIterator(0)
 		defer it.Close()
 
-		y.NumLSMGetsAdd(s.db.opt.MetricsEnabled, s.strLevel, 1)
-		it.Seek(key)
-		if !it.Valid() {
-			continue
-		}
-		if y.SameKey(key, it.Key()) {
-			if version := y.ParseTs(it.Key()); maxVs.Version < version {
-				maxVs = it.ValueCopy()
-				maxVs.Version = version
-			}
-		}
+		s.checkInsideIterator(key, it, &maxVs)
 	}
 	return maxVs, decr()
 }
diff --git a/levels.go b/levels.go
index 48a954316..7abef9888 100644
--- a/levels.go
+++ b/levels.go
@@ -1594,6 +1594,54 @@ func (s *levelsController) close() error {
 	return y.Wrap(err, "levelsController.Close")
 }
 
+func (s *levelsController) getBatch(keys [][]byte,
+	maxVs []y.ValueStruct, startLevel int, keysRead []bool, version uint64) ([]y.ValueStruct, error) {
+	if s.kv.IsClosed() {
+		return []y.ValueStruct{}, ErrDBClosed
+	}
+	// It's important that we iterate the levels from 0 on upward. The reason is, if we iterated
+	// in opposite order, or in parallel (naively calling all the h.RLock() in some order) we could
+	// read level L's tables post-compaction and level L+1's tables pre-compaction. (If we do
+	// parallelize this, we will need to call the h.RLock() function by increasing order of level
+	// number.)
+	for _, h := range s.levels {
+		// Ignore all levels below startLevel. This is useful for GC when L0 is kept in memory.
+		if h.level < startLevel {
+			continue
+		}
+		vs, err := h.getBatch(keys, keysRead) // Calls h.RLock() and h.RUnlock().
+		if err != nil {
+			return []y.ValueStruct{}, y.Wrapf(err, "get keys: %q", keys)
+		}
+
+		for i, v := range vs {
+			// keysRead is only updated by this function or by its caller in db.go. The
+			// levelHandler never updates it; nothing to do if it is already set.
+			if keysRead[i] {
+				continue
+			}
+			if v.Value == nil && v.Meta == 0 {
+				continue
+			}
+			y.NumBytesReadsLSMAdd(s.kv.opt.MetricsEnabled, int64(len(v.Value)))
+			if v.Version == version {
+				maxVs[i] = v
+				keysRead[i] = true
+			}
+			if maxVs[i].Version < v.Version {
+				maxVs[i] = v
+			}
+		}
+	}
+
+	for i := 0; i < len(maxVs); i++ {
+		if len(maxVs[i].Value) > 0 {
+			y.NumGetsWithResultsAdd(s.kv.opt.MetricsEnabled, 1)
+		}
+	}
+	return maxVs, nil
+}
+
 // get searches for a given key in all the levels of the LSM tree. It returns
 // key version <= the expected version (version in key). If not found,
 // it returns an empty y.ValueStruct.
diff --git a/options.go b/options.go
index 218b94772..7a2a95f7d 100644
--- a/options.go
+++ b/options.go
@@ -118,6 +118,9 @@ type Options struct {
 	maxBatchSize int64 // max batch size in bytes
 
 	maxValueThreshold float64
+
+	// useGetBatch routes gets through the experimental getBatch API instead of the regular get path.
+	useGetBatch bool
 }
 
 // DefaultOptions sets a list of recommended options for good performance.
@@ -176,6 +179,7 @@ func DefaultOptions(path string) Options {
 		EncryptionKeyRotationDuration: 10 * 24 * time.Hour, // Default 10 days.
 		DetectConflicts:               true,
 		NamespaceOffset:               -1,
+		useGetBatch:                   true,
 	}
 }
diff --git a/txn.go b/txn.go
index 50d17a5bc..0f0d9615b 100644
--- a/txn.go
+++ b/txn.go
@@ -429,6 +429,84 @@ func (txn *Txn) Delete(key []byte) error {
 	return txn.modify(e)
 }
 
+// GetBatch looks up the given keys and returns the corresponding Items, in the
+// same order as the keys. A nil Item indicates that the key was not found, was
+// deleted, or had expired.
+func (txn *Txn) GetBatch(keys [][]byte) (items []*Item, rerr error) {
+	if txn.discarded {
+		return nil, ErrDiscardedTxn
+	}
+
+	for _, key := range keys {
+		if len(key) == 0 {
+			return nil, ErrEmptyKey
+		}
+		if err := txn.db.isBanned(key); err != nil {
+			return nil, err
+		}
+	}
+
+	items = make([]*Item, len(keys))
+	done := make([]bool, len(keys))
+
+	if txn.update {
+		doneAll := 0
+		for i, key := range keys {
+			if e, has := txn.pendingWrites[string(key)]; has && bytes.Equal(key, e.Key) {
+				done[i] = true
+				doneAll++
+				if isDeletedOrExpired(e.meta, e.ExpiresAt) {
+					items[i] = nil
+					continue
+				}
+				// Fulfill from cache.
+				item := new(Item)
+				item.meta = e.meta
+				item.val = e.Value
+				item.userMeta = e.UserMeta
+				item.key = key
+				item.status = prefetched
+				item.version = txn.readTs
+				item.expiresAt = e.ExpiresAt
+				// We probably don't need to set db on item here.
+				items[i] = item
+				continue
+			}
+			// Only track reads if this is an update txn, and only when the txn could not
+			// service the read internally.
+			txn.addReadKey(key)
+		}
+		if doneAll == len(keys) {
+			return items, nil
+		}
+	}
+
+	seeks := make([][]byte, len(keys))
+	for i, key := range keys {
+		seeks[i] = y.KeyWithTs(key, txn.readTs)
+	}
+	vss, err := txn.db.getBatch(seeks, done, txn.readTs)
+	if err != nil {
+		return nil, y.Wrapf(err, "DB::Get keys: %q", keys)
+	}
+
+	for i, vs := range vss {
+		// Skip keys that were already serviced from pendingWrites above.
+		if done[i] {
+			continue
+		}
+		if (vs.Value == nil && vs.Meta == 0) || isDeletedOrExpired(vs.Meta, vs.ExpiresAt) {
+			items[i] = nil
+			continue
+		}
+
+		item := new(Item)
+		item.key = keys[i]
+		item.version = vs.Version
+		item.meta = vs.Meta
+		item.userMeta = vs.UserMeta
+		item.vptr = y.SafeCopy(item.vptr, vs.Value)
+		item.txn = txn
+		item.expiresAt = vs.ExpiresAt
+		items[i] = item
+	}
+
+	return items, nil
+}
+
 // Get looks for key and returns corresponding Item.
 // If key is not found, ErrKeyNotFound is returned.
 func (txn *Txn) Get(key []byte) (item *Item, rerr error) {
diff --git a/txn_test.go b/txn_test.go
index 9a776c4ef..90c7c2885 100644
--- a/txn_test.go
+++ b/txn_test.go
@@ -21,6 +21,45 @@ import (
 	"github.com/dgraph-io/ristretto/v2/z"
 )
 
+func TestTxnSimpleTsRead(t *testing.T) {
+	dir, err := os.MkdirTemp("", "badger-test")
+	require.NoError(t, err)
+	defer removeDir(dir)
+	opts := getTestOptions(dir)
+	opts.Dir = dir
+	opts.ValueDir = dir
+
+	opts.managedTxns = true
+
+	db, err := Open(opts)
+	require.NoError(t, err)
+	defer func() {
+		require.NoError(t, db.Close())
+	}()
+
+	for i := 0; i < 10; i++ {
+		txn := db.NewTransactionAt(uint64(i)+1, true)
+		k := []byte(fmt.Sprintf("key=%d", 1))
+		v := []byte(fmt.Sprintf("val=%d", i))
+		require.NoError(t, txn.SetEntry(NewEntry(k, v)))
+		err = txn.CommitAt(uint64(i)*3+1, nil)
+		require.NoError(t, err)
+	}
+
+	// Commit timestamps are 1, 4, 7, 10, ... so reads at ts=7, 8, and 9 should all
+	// see the version committed at ts=7, i.e. "val=2".
+	for i := 7; i < 10; i++ {
+		txn := db.NewTransactionAt(uint64(i), false)
+		item, err := txn.Get([]byte("key=1"))
+		require.NoError(t, err)
+
+		require.NoError(t, item.Value(func(val []byte) error {
+			require.Equal(t, []byte("val=2"), val)
+			return nil
+		}))
+
+		txn.Discard()
+	}
+}
+
 func TestTxnSimple(t *testing.T) {
 	runBadgerTest(t, nil, func(t *testing.T, db *DB) {
 		txn := db.NewTransaction(true)
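
For reviewers, a minimal usage sketch of the new Txn.GetBatch API added by this patch. The database path and key names are illustrative, and it assumes the badger v4 import path; per the diff, DefaultOptions enables useGetBatch, and GetBatch returns a nil Item for keys that are absent, deleted, or expired.

package main

import (
	"fmt"
	"log"

	badger "github.com/dgraph-io/badger/v4"
)

func main() {
	db, err := badger.Open(badger.DefaultOptions("/tmp/badger-getbatch"))
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Seed a couple of keys. "k2" is intentionally left unset to show the nil-Item case.
	if err := db.Update(func(txn *badger.Txn) error {
		if err := txn.Set([]byte("k1"), []byte("v1")); err != nil {
			return err
		}
		return txn.Set([]byte("k3"), []byte("v3"))
	}); err != nil {
		log.Fatal(err)
	}

	keys := [][]byte{[]byte("k1"), []byte("k2"), []byte("k3")}
	if err := db.View(func(txn *badger.Txn) error {
		// One call resolves all keys; items[i] corresponds to keys[i].
		items, err := txn.GetBatch(keys)
		if err != nil {
			return err
		}
		for i, item := range items {
			if item == nil {
				// Key not found, deleted, or expired.
				fmt.Printf("%s => <missing>\n", keys[i])
				continue
			}
			if err := item.Value(func(val []byte) error {
				fmt.Printf("%s => %s\n", keys[i], val)
				return nil
			}); err != nil {
				return err
			}
		}
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}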