feat(collections): Add Clear method on Map and KeySet (#16618)

Co-authored-by: unknown unknown <unknown@unknown>
This commit is contained in:
testinginprod 2023-06-20 15:44:06 +02:00 committed by GitHub
parent e078f1a49e
commit d4f1e88b65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 118 additions and 13 deletions

View File

@ -31,6 +31,10 @@ Ref: https://keepachangelog.com/en/1.0.0/
## [Unreleased]
### Features
* [#16074](https://github.com/cosmos/cosmos-sdk/pull/16607) - Introduces `Clear` method for `Map` and `KeySet`
## [v0.2.0](https://github.com/cosmos/cosmos-sdk/releases/tag/collections%2Fv0.2.0)
### Features

View File

@ -130,39 +130,48 @@ func (r *Range[K]) RangeValues() (start, end *RangeKey[K], order Order, err erro
return r.start, r.end, r.order, nil
}
// iteratorFromRanger generates an Iterator instance, with the proper prefixing and ranging.
// a nil Ranger can be seen as an ascending iteration over all the possible keys.
func iteratorFromRanger[K, V any](ctx context.Context, m Map[K, V], r Ranger[K]) (iter Iterator[K, V], err error) {
// parseRangeInstruction converts a Ranger into start bytes, end bytes and order of a store iteration.
func parseRangeInstruction[K any](prefix []byte, keyCodec codec.KeyCodec[K], r Ranger[K]) ([]byte, []byte, Order, error) {
var (
start *RangeKey[K]
end *RangeKey[K]
order = OrderAscending
err error
)
if r != nil {
start, end, order, err = r.RangeValues()
if err != nil {
return iter, err
return nil, nil, 0, err
}
}
startBytes := m.prefix
startBytes := prefix
if start != nil {
startBytes, err = encodeRangeBound(m.prefix, m.kc, start)
startBytes, err = encodeRangeBound(prefix, keyCodec, start)
if err != nil {
return iter, err
return nil, nil, 0, err
}
}
var endBytes []byte
if end != nil {
endBytes, err = encodeRangeBound(m.prefix, m.kc, end)
endBytes, err = encodeRangeBound(prefix, keyCodec, end)
if err != nil {
return iter, err
return nil, nil, 0, err
}
} else {
endBytes = nextBytesPrefixKey(m.prefix)
endBytes = nextBytesPrefixKey(prefix)
}
return startBytes, endBytes, order, nil
}
// iteratorFromRanger generates an Iterator instance, with the proper prefixing and ranging.
// a nil Ranger can be seen as an ascending iteration over all the possible keys.
func iteratorFromRanger[K, V any](ctx context.Context, m Map[K, V], r Ranger[K]) (iter Iterator[K, V], err error) {
startBytes, endBytes, order, err := parseRangeInstruction(m.prefix, m.kc, r)
if err != nil {
return Iterator[K, V]{}, err
}
return newIterator(ctx, startBytes, endBytes, order, m)
}

View File

@ -57,6 +57,12 @@ func (k KeySet[K]) Walk(ctx context.Context, ranger Ranger[K], walkFunc func(key
return (Map[K, NoValue])(k).Walk(ctx, ranger, func(key K, value NoValue) (bool, error) { return walkFunc(key) })
}
// Clear clears the KeySet using the provided Ranger. Refer to Map.Clear for
// behavioral documentation.
func (k KeySet[K]) Clear(ctx context.Context, ranger Ranger[K]) error {
return (Map[K, NoValue])(k).Clear(ctx, ranger)
}
func (k KeySet[K]) KeyCodec() codec.KeyCodec[K] { return (Map[K, NoValue])(k).KeyCodec() }
func (k KeySet[K]) ValueCodec() codec.ValueCodec[NoValue] { return (Map[K, NoValue])(k).ValueCodec() }

View File

@ -5,7 +5,6 @@ import (
"fmt"
"cosmossdk.io/collections/codec"
"cosmossdk.io/core/store"
)
@ -65,8 +64,7 @@ func (m Map[K, V]) Set(ctx context.Context, key K, value V) error {
}
kvStore := m.sa(ctx)
kvStore.Set(bytesKey, valueBytes)
return nil
return kvStore.Set(bytesKey, valueBytes)
}
// Get returns the value associated with the provided key,
@ -150,6 +148,59 @@ func (m Map[K, V]) Walk(ctx context.Context, ranger Ranger[K], walkFunc func(key
return nil
}
// Clear clears the collection contained within the provided key range.
// A nil ranger equals to clearing the whole collection. In case the collection
// is empty no error will be returned.
// NOTE: this API needs to be used with care, considering that as of today
// cosmos-sdk stores the deletion records to be committed in a memory cache,
// clearing a lot of data might make the node go OOM.
func (m Map[K, V]) Clear(ctx context.Context, ranger Ranger[K]) error {
startBytes, endBytes, _, err := parseRangeInstruction(m.prefix, m.kc, ranger)
if err != nil {
return err
}
return deleteDomain(m.sa(ctx), startBytes, endBytes)
}
const clearBatchSize = 10000
// deleteDomain deletes the domain of an iterator, the key difference
// is that it uses batches to clear the store meaning that it will read
// the keys within the domain close the iterator and then delete them.
func deleteDomain(s store.KVStore, start, end []byte) error {
for {
iter, err := s.Iterator(start, end)
if err != nil {
return err
}
keys := make([][]byte, 0, clearBatchSize)
for ; iter.Valid() && len(keys) < clearBatchSize; iter.Next() {
keys = append(keys, iter.Key())
}
// we close the iterator here instead of deferring
err = iter.Close()
if err != nil {
return err
}
for _, key := range keys {
err = s.Delete(key)
if err != nil {
return err
}
}
// If we've retrieved less than the batchSize, we're done.
if len(keys) < clearBatchSize {
break
}
}
return nil
}
// IterateRaw iterates over the collection. The iteration range is untyped, it uses raw
// bytes. The resulting Iterator is typed.
// A nil start iterates from the first key contained in the collection.

View File

@ -1,6 +1,7 @@
package collections
import (
"context"
"testing"
"github.com/stretchr/testify/require"
@ -36,6 +37,40 @@ func TestMap(t *testing.T) {
require.False(t, has)
}
func TestMap_Clear(t *testing.T) {
makeTest := func() (context.Context, Map[uint64, uint64]) {
sk, ctx := deps()
m := NewMap(NewSchemaBuilder(sk), NewPrefix(0), "test", Uint64Key, Uint64Value)
for i := uint64(0); i < clearBatchSize*2; i++ {
require.NoError(t, m.Set(ctx, i, i))
}
return ctx, m
}
t.Run("nil ranger", func(t *testing.T) {
ctx, m := makeTest()
err := m.Clear(ctx, nil)
require.NoError(t, err)
_, err = m.Iterate(ctx, nil)
require.ErrorIs(t, err, ErrInvalidIterator)
})
t.Run("custom ranger", func(t *testing.T) {
ctx, m := makeTest()
// delete from 0 to 100
err := m.Clear(ctx, new(Range[uint64]).StartInclusive(0).EndInclusive(100))
require.NoError(t, err)
iter, err := m.Iterate(ctx, nil)
require.NoError(t, err)
keys, err := iter.Keys()
require.NoError(t, err)
require.Len(t, keys, clearBatchSize*2-101)
require.Equal(t, keys[0], uint64(101))
require.Equal(t, keys[len(keys)-1], uint64(clearBatchSize*2-1))
})
}
func TestMap_IterateRaw(t *testing.T) {
sk, ctx := deps()
// safety check to ensure prefix boundaries are not crossed