diff --git a/types/query/collections_pagination.go b/types/query/collections_pagination.go index 9e8397d9ea..4ef389038d 100644 --- a/types/query/collections_pagination.go +++ b/types/query/collections_pagination.go @@ -53,9 +53,7 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]]( predicateFunc func(key K, value V) (include bool, err error), opts ...func(opt *CollectionsPaginateOptions[K]), ) ([]collections.KeyValue[K, V], *PageResponse, error) { - if pageReq == nil { - pageReq = &PageRequest{} - } + pageReq = initPageRequestDefaults(pageReq) offset := pageReq.Offset key := pageReq.Key @@ -67,11 +65,6 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]]( return nil, nil, fmt.Errorf("invalid request, either offset or key is expected, got both") } - if limit == 0 { - limit = DefaultLimit - countTotal = true - } - var ( results []collections.KeyValue[K, V] pageRes *PageResponse diff --git a/types/query/filtered_pagination.go b/types/query/filtered_pagination.go index dcb3553039..b5730693c2 100644 --- a/types/query/filtered_pagination.go +++ b/types/query/filtered_pagination.go @@ -22,55 +22,33 @@ func FilteredPaginate( pageRequest *PageRequest, onResult func(key, value []byte, accumulate bool) (bool, error), ) (*PageResponse, error) { - // if the PageRequest is nil, use default PageRequest - if pageRequest == nil { - pageRequest = &PageRequest{} - } + pageRequest = initPageRequestDefaults(pageRequest) - offset := pageRequest.Offset - key := pageRequest.Key - limit := pageRequest.Limit - countTotal := pageRequest.CountTotal - reverse := pageRequest.Reverse - - if offset > 0 && key != nil { + if pageRequest.Offset > 0 && pageRequest.Key != nil { return nil, fmt.Errorf("invalid request, either offset or key is expected, got both") } - if limit == 0 { - limit = DefaultLimit + var ( + numHits uint64 + nextKey []byte + err error + ) - // count total results when the limit is zero/not supplied - countTotal = true - } - - if len(key) != 0 { - iterator := getIterator(prefixStore, key, reverse) - defer iterator.Close() - - var ( - numHits uint64 - nextKey []byte - ) + iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse) + defer iterator.Close() + if len(pageRequest.Key) != 0 { + accumulateFn := func(_ uint64) bool { return true } for ; iterator.Valid(); iterator.Next() { - if numHits == limit { + if numHits == pageRequest.Limit { nextKey = iterator.Key() break } - if iterator.Error() != nil { - return nil, iterator.Error() - } - - hit, err := onResult(iterator.Key(), iterator.Value(), true) + numHits, err = processResult(iterator, numHits, onResult, accumulateFn) if err != nil { return nil, err } - - if hit { - numHits++ - } } return &PageResponse{ @@ -78,50 +56,81 @@ func FilteredPaginate( }, nil } - iterator := getIterator(prefixStore, nil, reverse) - defer iterator.Close() - - end := offset + limit - - var ( - numHits uint64 - nextKey []byte - ) + end := pageRequest.Offset + pageRequest.Limit + accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end } for ; iterator.Valid(); iterator.Next() { - if iterator.Error() != nil { - return nil, iterator.Error() - } - - accumulate := numHits >= offset && numHits < end - hit, err := onResult(iterator.Key(), iterator.Value(), accumulate) + numHits, err = processResult(iterator, numHits, onResult, accumulateFn) if err != nil { return nil, err } - - if hit { - numHits++ - } - if numHits == end+1 { if nextKey == nil { nextKey = iterator.Key() } - if !countTotal { + if !pageRequest.CountTotal { break } } } res := &PageResponse{NextKey: nextKey} - if countTotal { + if pageRequest.CountTotal { res.Total = numHits } return res, nil } +func processResult(iterator types.Iterator, numHits uint64, onResult func(key, value []byte, accumulate bool) (bool, error), accumulateFn func(numHits uint64) bool) (uint64, error) { + if iterator.Error() != nil { + return numHits, iterator.Error() + } + + accumulate := accumulateFn(numHits) + hit, err := onResult(iterator.Key(), iterator.Value(), accumulate) + if err != nil { + return numHits, err + } + + if hit { + numHits++ + } + + return numHits, nil +} + +func genericProcessResult[T, F proto.Message](iterator types.Iterator, numHits uint64, onResult func(key []byte, value T) (F, error), accumulateFn func(numHits uint64) bool, + constructor func() T, cdc codec.BinaryCodec, results []F, +) ([]F, uint64, error) { + if iterator.Error() != nil { + return results, numHits, iterator.Error() + } + + protoMsg := constructor() + + err := cdc.Unmarshal(iterator.Value(), protoMsg) + if err != nil { + return results, numHits, err + } + + val, err := onResult(iterator.Key(), protoMsg) + if err != nil { + return results, numHits, err + } + + if proto.Size(val) != 0 { + // Previously this was the "accumulate" flag + if accumulateFn(numHits) { + results = append(results, val) + } + numHits++ + } + + return results, numHits, nil +} + // GenericFilteredPaginate does pagination of all the results in the PrefixStore based on the // provided PageRequest. `onResult` should be used to filter or transform the results. // `c` is a constructor function that needs to return a new instance of the type T (this is to @@ -137,64 +146,34 @@ func GenericFilteredPaginate[T, F proto.Message]( onResult func(key []byte, value T) (F, error), constructor func() T, ) ([]F, *PageResponse, error) { - // if the PageRequest is nil, use default PageRequest - if pageRequest == nil { - pageRequest = &PageRequest{} - } - - offset := pageRequest.Offset - key := pageRequest.Key - limit := pageRequest.Limit - countTotal := pageRequest.CountTotal - reverse := pageRequest.Reverse + pageRequest = initPageRequestDefaults(pageRequest) results := []F{} - if offset > 0 && key != nil { + if pageRequest.Offset > 0 && pageRequest.Key != nil { return results, nil, fmt.Errorf("invalid request, either offset or key is expected, got both") } - if limit == 0 { - limit = DefaultLimit + var ( + numHits uint64 + nextKey []byte + err error + ) - // count total results when the limit is zero/not supplied - countTotal = true - } - - if len(key) != 0 { - iterator := getIterator(prefixStore, key, reverse) - defer iterator.Close() - - var ( - numHits uint64 - nextKey []byte - ) + iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse) + defer iterator.Close() + if len(pageRequest.Key) != 0 { + accumulateFn := func(_ uint64) bool { return true } for ; iterator.Valid(); iterator.Next() { - if numHits == limit { + if numHits == pageRequest.Limit { nextKey = iterator.Key() break } - if iterator.Error() != nil { - return nil, nil, iterator.Error() - } - - protoMsg := constructor() - - err := cdc.Unmarshal(iterator.Value(), protoMsg) + results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results) if err != nil { return nil, nil, err } - - val, err := onResult(iterator.Key(), protoMsg) - if err != nil { - return nil, nil, err - } - - if proto.Size(val) != 0 { - results = append(results, val) - numHits++ - } } return results, &PageResponse{ @@ -202,54 +181,28 @@ func GenericFilteredPaginate[T, F proto.Message]( }, nil } - iterator := getIterator(prefixStore, nil, reverse) - defer iterator.Close() - - end := offset + limit - - var ( - numHits uint64 - nextKey []byte - ) + end := pageRequest.Offset + pageRequest.Limit + accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end } for ; iterator.Valid(); iterator.Next() { - if iterator.Error() != nil { - return nil, nil, iterator.Error() - } - - protoMsg := constructor() - - err := cdc.Unmarshal(iterator.Value(), protoMsg) + results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results) if err != nil { return nil, nil, err } - val, err := onResult(iterator.Key(), protoMsg) - if err != nil { - return nil, nil, err - } - - if proto.Size(val) != 0 { - // Previously this was the "accumulate" flag - if numHits >= offset && numHits < end { - results = append(results, val) - } - numHits++ - } - if numHits == end+1 { if nextKey == nil { nextKey = iterator.Key() } - if !countTotal { + if !pageRequest.CountTotal { break } } } res := &PageResponse{NextKey: nextKey} - if countTotal { + if pageRequest.CountTotal { res.Total = numHits } diff --git a/types/query/pagination.go b/types/query/pagination.go index 61e3c0679e..c7d70c404b 100644 --- a/types/query/pagination.go +++ b/types/query/pagination.go @@ -54,38 +54,21 @@ func Paginate( pageRequest *PageRequest, onResult func(key, value []byte) error, ) (*PageResponse, error) { - // if the PageRequest is nil, use default PageRequest - if pageRequest == nil { - pageRequest = &PageRequest{} - } + pageRequest = initPageRequestDefaults(pageRequest) - offset := pageRequest.Offset - key := pageRequest.Key - limit := pageRequest.Limit - countTotal := pageRequest.CountTotal - reverse := pageRequest.Reverse - - if offset > 0 && key != nil { + if pageRequest.Offset > 0 && pageRequest.Key != nil { return nil, fmt.Errorf("invalid request, either offset or key is expected, got both") } - if limit == 0 { - limit = DefaultLimit + iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse) + defer iterator.Close() - // count total results when the limit is zero/not supplied - countTotal = true - } - - if len(key) != 0 { - iterator := getIterator(prefixStore, key, reverse) - defer iterator.Close() - - var count uint64 - var nextKey []byte + var count uint64 + var nextKey []byte + if len(pageRequest.Key) != 0 { for ; iterator.Valid(); iterator.Next() { - - if count == limit { + if count == pageRequest.Limit { nextKey = iterator.Key() break } @@ -105,18 +88,12 @@ func Paginate( }, nil } - iterator := getIterator(prefixStore, nil, reverse) - defer iterator.Close() - - end := offset + limit - - var count uint64 - var nextKey []byte + end := pageRequest.Offset + pageRequest.Limit for ; iterator.Valid(); iterator.Next() { count++ - if count <= offset { + if count <= pageRequest.Offset { continue } if count <= end { @@ -127,7 +104,7 @@ func Paginate( } else if count == end+1 { nextKey = iterator.Key() - if !countTotal { + if !pageRequest.CountTotal { break } } @@ -137,7 +114,7 @@ func Paginate( } res := &PageResponse{NextKey: nextKey} - if countTotal { + if pageRequest.CountTotal { res.Total = count } @@ -159,3 +136,25 @@ func getIterator(prefixStore types.KVStore, start []byte, reverse bool) db.Itera } return prefixStore.Iterator(start, nil) } + +// initPageRequestDefaults initializes a PageRequest's defaults when those are not set. +func initPageRequestDefaults(pageRequest *PageRequest) *PageRequest { + // if the PageRequest is nil, use default PageRequest + if pageRequest == nil { + pageRequest = &PageRequest{} + } + + pageRequestCopy := *pageRequest + if len(pageRequestCopy.Key) == 0 { + pageRequestCopy.Key = nil + } + + if pageRequestCopy.Limit == 0 { + pageRequestCopy.Limit = DefaultLimit + + // count total results when the limit is zero/not supplied + pageRequestCopy.CountTotal = true + } + + return &pageRequestCopy +}