refactor: Start DRY'ing filtered paginate code (#16099)

Co-authored-by: testinginprod <98415576+testinginprod@users.noreply.github.com>
This commit is contained in:
Dev Ojha 2023-05-15 21:40:58 +02:00 committed by GitHub
parent e34a3e0559
commit 660e906eb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 119 additions and 174 deletions

View File

@ -53,9 +53,7 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
predicateFunc func(key K, value V) (include bool, err error),
opts ...func(opt *CollectionsPaginateOptions[K]),
) ([]collections.KeyValue[K, V], *PageResponse, error) {
if pageReq == nil {
pageReq = &PageRequest{}
}
pageReq = initPageRequestDefaults(pageReq)
offset := pageReq.Offset
key := pageReq.Key
@ -67,11 +65,6 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
return nil, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}
if limit == 0 {
limit = DefaultLimit
countTotal = true
}
var (
results []collections.KeyValue[K, V]
pageRes *PageResponse

View File

@ -22,55 +22,33 @@ func FilteredPaginate(
pageRequest *PageRequest,
onResult func(key, value []byte, accumulate bool) (bool, error),
) (*PageResponse, error) {
// if the PageRequest is nil, use default PageRequest
if pageRequest == nil {
pageRequest = &PageRequest{}
}
pageRequest = initPageRequestDefaults(pageRequest)
offset := pageRequest.Offset
key := pageRequest.Key
limit := pageRequest.Limit
countTotal := pageRequest.CountTotal
reverse := pageRequest.Reverse
if offset > 0 && key != nil {
if pageRequest.Offset > 0 && pageRequest.Key != nil {
return nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}
if limit == 0 {
limit = DefaultLimit
var (
numHits uint64
nextKey []byte
err error
)
// count total results when the limit is zero/not supplied
countTotal = true
}
if len(key) != 0 {
iterator := getIterator(prefixStore, key, reverse)
defer iterator.Close()
var (
numHits uint64
nextKey []byte
)
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
defer iterator.Close()
if len(pageRequest.Key) != 0 {
accumulateFn := func(_ uint64) bool { return true }
for ; iterator.Valid(); iterator.Next() {
if numHits == limit {
if numHits == pageRequest.Limit {
nextKey = iterator.Key()
break
}
if iterator.Error() != nil {
return nil, iterator.Error()
}
hit, err := onResult(iterator.Key(), iterator.Value(), true)
numHits, err = processResult(iterator, numHits, onResult, accumulateFn)
if err != nil {
return nil, err
}
if hit {
numHits++
}
}
return &PageResponse{
@ -78,50 +56,81 @@ func FilteredPaginate(
}, nil
}
iterator := getIterator(prefixStore, nil, reverse)
defer iterator.Close()
end := offset + limit
var (
numHits uint64
nextKey []byte
)
end := pageRequest.Offset + pageRequest.Limit
accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end }
for ; iterator.Valid(); iterator.Next() {
if iterator.Error() != nil {
return nil, iterator.Error()
}
accumulate := numHits >= offset && numHits < end
hit, err := onResult(iterator.Key(), iterator.Value(), accumulate)
numHits, err = processResult(iterator, numHits, onResult, accumulateFn)
if err != nil {
return nil, err
}
if hit {
numHits++
}
if numHits == end+1 {
if nextKey == nil {
nextKey = iterator.Key()
}
if !countTotal {
if !pageRequest.CountTotal {
break
}
}
}
res := &PageResponse{NextKey: nextKey}
if countTotal {
if pageRequest.CountTotal {
res.Total = numHits
}
return res, nil
}
func processResult(iterator types.Iterator, numHits uint64, onResult func(key, value []byte, accumulate bool) (bool, error), accumulateFn func(numHits uint64) bool) (uint64, error) {
if iterator.Error() != nil {
return numHits, iterator.Error()
}
accumulate := accumulateFn(numHits)
hit, err := onResult(iterator.Key(), iterator.Value(), accumulate)
if err != nil {
return numHits, err
}
if hit {
numHits++
}
return numHits, nil
}
func genericProcessResult[T, F proto.Message](iterator types.Iterator, numHits uint64, onResult func(key []byte, value T) (F, error), accumulateFn func(numHits uint64) bool,
constructor func() T, cdc codec.BinaryCodec, results []F,
) ([]F, uint64, error) {
if iterator.Error() != nil {
return results, numHits, iterator.Error()
}
protoMsg := constructor()
err := cdc.Unmarshal(iterator.Value(), protoMsg)
if err != nil {
return results, numHits, err
}
val, err := onResult(iterator.Key(), protoMsg)
if err != nil {
return results, numHits, err
}
if proto.Size(val) != 0 {
// Previously this was the "accumulate" flag
if accumulateFn(numHits) {
results = append(results, val)
}
numHits++
}
return results, numHits, nil
}
// GenericFilteredPaginate does pagination of all the results in the PrefixStore based on the
// provided PageRequest. `onResult` should be used to filter or transform the results.
// `c` is a constructor function that needs to return a new instance of the type T (this is to
@ -137,64 +146,34 @@ func GenericFilteredPaginate[T, F proto.Message](
onResult func(key []byte, value T) (F, error),
constructor func() T,
) ([]F, *PageResponse, error) {
// if the PageRequest is nil, use default PageRequest
if pageRequest == nil {
pageRequest = &PageRequest{}
}
offset := pageRequest.Offset
key := pageRequest.Key
limit := pageRequest.Limit
countTotal := pageRequest.CountTotal
reverse := pageRequest.Reverse
pageRequest = initPageRequestDefaults(pageRequest)
results := []F{}
if offset > 0 && key != nil {
if pageRequest.Offset > 0 && pageRequest.Key != nil {
return results, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}
if limit == 0 {
limit = DefaultLimit
var (
numHits uint64
nextKey []byte
err error
)
// count total results when the limit is zero/not supplied
countTotal = true
}
if len(key) != 0 {
iterator := getIterator(prefixStore, key, reverse)
defer iterator.Close()
var (
numHits uint64
nextKey []byte
)
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
defer iterator.Close()
if len(pageRequest.Key) != 0 {
accumulateFn := func(_ uint64) bool { return true }
for ; iterator.Valid(); iterator.Next() {
if numHits == limit {
if numHits == pageRequest.Limit {
nextKey = iterator.Key()
break
}
if iterator.Error() != nil {
return nil, nil, iterator.Error()
}
protoMsg := constructor()
err := cdc.Unmarshal(iterator.Value(), protoMsg)
results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results)
if err != nil {
return nil, nil, err
}
val, err := onResult(iterator.Key(), protoMsg)
if err != nil {
return nil, nil, err
}
if proto.Size(val) != 0 {
results = append(results, val)
numHits++
}
}
return results, &PageResponse{
@ -202,54 +181,28 @@ func GenericFilteredPaginate[T, F proto.Message](
}, nil
}
iterator := getIterator(prefixStore, nil, reverse)
defer iterator.Close()
end := offset + limit
var (
numHits uint64
nextKey []byte
)
end := pageRequest.Offset + pageRequest.Limit
accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end }
for ; iterator.Valid(); iterator.Next() {
if iterator.Error() != nil {
return nil, nil, iterator.Error()
}
protoMsg := constructor()
err := cdc.Unmarshal(iterator.Value(), protoMsg)
results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results)
if err != nil {
return nil, nil, err
}
val, err := onResult(iterator.Key(), protoMsg)
if err != nil {
return nil, nil, err
}
if proto.Size(val) != 0 {
// Previously this was the "accumulate" flag
if numHits >= offset && numHits < end {
results = append(results, val)
}
numHits++
}
if numHits == end+1 {
if nextKey == nil {
nextKey = iterator.Key()
}
if !countTotal {
if !pageRequest.CountTotal {
break
}
}
}
res := &PageResponse{NextKey: nextKey}
if countTotal {
if pageRequest.CountTotal {
res.Total = numHits
}

View File

@ -54,38 +54,21 @@ func Paginate(
pageRequest *PageRequest,
onResult func(key, value []byte) error,
) (*PageResponse, error) {
// if the PageRequest is nil, use default PageRequest
if pageRequest == nil {
pageRequest = &PageRequest{}
}
pageRequest = initPageRequestDefaults(pageRequest)
offset := pageRequest.Offset
key := pageRequest.Key
limit := pageRequest.Limit
countTotal := pageRequest.CountTotal
reverse := pageRequest.Reverse
if offset > 0 && key != nil {
if pageRequest.Offset > 0 && pageRequest.Key != nil {
return nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}
if limit == 0 {
limit = DefaultLimit
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
defer iterator.Close()
// count total results when the limit is zero/not supplied
countTotal = true
}
if len(key) != 0 {
iterator := getIterator(prefixStore, key, reverse)
defer iterator.Close()
var count uint64
var nextKey []byte
var count uint64
var nextKey []byte
if len(pageRequest.Key) != 0 {
for ; iterator.Valid(); iterator.Next() {
if count == limit {
if count == pageRequest.Limit {
nextKey = iterator.Key()
break
}
@ -105,18 +88,12 @@ func Paginate(
}, nil
}
iterator := getIterator(prefixStore, nil, reverse)
defer iterator.Close()
end := offset + limit
var count uint64
var nextKey []byte
end := pageRequest.Offset + pageRequest.Limit
for ; iterator.Valid(); iterator.Next() {
count++
if count <= offset {
if count <= pageRequest.Offset {
continue
}
if count <= end {
@ -127,7 +104,7 @@ func Paginate(
} else if count == end+1 {
nextKey = iterator.Key()
if !countTotal {
if !pageRequest.CountTotal {
break
}
}
@ -137,7 +114,7 @@ func Paginate(
}
res := &PageResponse{NextKey: nextKey}
if countTotal {
if pageRequest.CountTotal {
res.Total = count
}
@ -159,3 +136,25 @@ func getIterator(prefixStore types.KVStore, start []byte, reverse bool) db.Itera
}
return prefixStore.Iterator(start, nil)
}
// initPageRequestDefaults initializes a PageRequest's defaults when those are not set.
func initPageRequestDefaults(pageRequest *PageRequest) *PageRequest {
// if the PageRequest is nil, use default PageRequest
if pageRequest == nil {
pageRequest = &PageRequest{}
}
pageRequestCopy := *pageRequest
if len(pageRequestCopy.Key) == 0 {
pageRequestCopy.Key = nil
}
if pageRequestCopy.Limit == 0 {
pageRequestCopy.Limit = DefaultLimit
// count total results when the limit is zero/not supplied
pageRequestCopy.CountTotal = true
}
return &pageRequestCopy
}