refactor: Start DRY'ing filtered paginate code (#16099)
Co-authored-by: testinginprod <98415576+testinginprod@users.noreply.github.com>
This commit is contained in:
parent
e34a3e0559
commit
660e906eb8
@ -53,9 +53,7 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
|
||||
predicateFunc func(key K, value V) (include bool, err error),
|
||||
opts ...func(opt *CollectionsPaginateOptions[K]),
|
||||
) ([]collections.KeyValue[K, V], *PageResponse, error) {
|
||||
if pageReq == nil {
|
||||
pageReq = &PageRequest{}
|
||||
}
|
||||
pageReq = initPageRequestDefaults(pageReq)
|
||||
|
||||
offset := pageReq.Offset
|
||||
key := pageReq.Key
|
||||
@ -67,11 +65,6 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
|
||||
return nil, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
|
||||
}
|
||||
|
||||
if limit == 0 {
|
||||
limit = DefaultLimit
|
||||
countTotal = true
|
||||
}
|
||||
|
||||
var (
|
||||
results []collections.KeyValue[K, V]
|
||||
pageRes *PageResponse
|
||||
|
||||
@ -22,55 +22,33 @@ func FilteredPaginate(
|
||||
pageRequest *PageRequest,
|
||||
onResult func(key, value []byte, accumulate bool) (bool, error),
|
||||
) (*PageResponse, error) {
|
||||
// if the PageRequest is nil, use default PageRequest
|
||||
if pageRequest == nil {
|
||||
pageRequest = &PageRequest{}
|
||||
}
|
||||
pageRequest = initPageRequestDefaults(pageRequest)
|
||||
|
||||
offset := pageRequest.Offset
|
||||
key := pageRequest.Key
|
||||
limit := pageRequest.Limit
|
||||
countTotal := pageRequest.CountTotal
|
||||
reverse := pageRequest.Reverse
|
||||
|
||||
if offset > 0 && key != nil {
|
||||
if pageRequest.Offset > 0 && pageRequest.Key != nil {
|
||||
return nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
|
||||
}
|
||||
|
||||
if limit == 0 {
|
||||
limit = DefaultLimit
|
||||
var (
|
||||
numHits uint64
|
||||
nextKey []byte
|
||||
err error
|
||||
)
|
||||
|
||||
// count total results when the limit is zero/not supplied
|
||||
countTotal = true
|
||||
}
|
||||
|
||||
if len(key) != 0 {
|
||||
iterator := getIterator(prefixStore, key, reverse)
|
||||
defer iterator.Close()
|
||||
|
||||
var (
|
||||
numHits uint64
|
||||
nextKey []byte
|
||||
)
|
||||
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
|
||||
defer iterator.Close()
|
||||
|
||||
if len(pageRequest.Key) != 0 {
|
||||
accumulateFn := func(_ uint64) bool { return true }
|
||||
for ; iterator.Valid(); iterator.Next() {
|
||||
if numHits == limit {
|
||||
if numHits == pageRequest.Limit {
|
||||
nextKey = iterator.Key()
|
||||
break
|
||||
}
|
||||
|
||||
if iterator.Error() != nil {
|
||||
return nil, iterator.Error()
|
||||
}
|
||||
|
||||
hit, err := onResult(iterator.Key(), iterator.Value(), true)
|
||||
numHits, err = processResult(iterator, numHits, onResult, accumulateFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if hit {
|
||||
numHits++
|
||||
}
|
||||
}
|
||||
|
||||
return &PageResponse{
|
||||
@ -78,50 +56,81 @@ func FilteredPaginate(
|
||||
}, nil
|
||||
}
|
||||
|
||||
iterator := getIterator(prefixStore, nil, reverse)
|
||||
defer iterator.Close()
|
||||
|
||||
end := offset + limit
|
||||
|
||||
var (
|
||||
numHits uint64
|
||||
nextKey []byte
|
||||
)
|
||||
end := pageRequest.Offset + pageRequest.Limit
|
||||
accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end }
|
||||
|
||||
for ; iterator.Valid(); iterator.Next() {
|
||||
if iterator.Error() != nil {
|
||||
return nil, iterator.Error()
|
||||
}
|
||||
|
||||
accumulate := numHits >= offset && numHits < end
|
||||
hit, err := onResult(iterator.Key(), iterator.Value(), accumulate)
|
||||
numHits, err = processResult(iterator, numHits, onResult, accumulateFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if hit {
|
||||
numHits++
|
||||
}
|
||||
|
||||
if numHits == end+1 {
|
||||
if nextKey == nil {
|
||||
nextKey = iterator.Key()
|
||||
}
|
||||
|
||||
if !countTotal {
|
||||
if !pageRequest.CountTotal {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res := &PageResponse{NextKey: nextKey}
|
||||
if countTotal {
|
||||
if pageRequest.CountTotal {
|
||||
res.Total = numHits
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func processResult(iterator types.Iterator, numHits uint64, onResult func(key, value []byte, accumulate bool) (bool, error), accumulateFn func(numHits uint64) bool) (uint64, error) {
|
||||
if iterator.Error() != nil {
|
||||
return numHits, iterator.Error()
|
||||
}
|
||||
|
||||
accumulate := accumulateFn(numHits)
|
||||
hit, err := onResult(iterator.Key(), iterator.Value(), accumulate)
|
||||
if err != nil {
|
||||
return numHits, err
|
||||
}
|
||||
|
||||
if hit {
|
||||
numHits++
|
||||
}
|
||||
|
||||
return numHits, nil
|
||||
}
|
||||
|
||||
func genericProcessResult[T, F proto.Message](iterator types.Iterator, numHits uint64, onResult func(key []byte, value T) (F, error), accumulateFn func(numHits uint64) bool,
|
||||
constructor func() T, cdc codec.BinaryCodec, results []F,
|
||||
) ([]F, uint64, error) {
|
||||
if iterator.Error() != nil {
|
||||
return results, numHits, iterator.Error()
|
||||
}
|
||||
|
||||
protoMsg := constructor()
|
||||
|
||||
err := cdc.Unmarshal(iterator.Value(), protoMsg)
|
||||
if err != nil {
|
||||
return results, numHits, err
|
||||
}
|
||||
|
||||
val, err := onResult(iterator.Key(), protoMsg)
|
||||
if err != nil {
|
||||
return results, numHits, err
|
||||
}
|
||||
|
||||
if proto.Size(val) != 0 {
|
||||
// Previously this was the "accumulate" flag
|
||||
if accumulateFn(numHits) {
|
||||
results = append(results, val)
|
||||
}
|
||||
numHits++
|
||||
}
|
||||
|
||||
return results, numHits, nil
|
||||
}
|
||||
|
||||
// GenericFilteredPaginate does pagination of all the results in the PrefixStore based on the
|
||||
// provided PageRequest. `onResult` should be used to filter or transform the results.
|
||||
// `c` is a constructor function that needs to return a new instance of the type T (this is to
|
||||
@ -137,64 +146,34 @@ func GenericFilteredPaginate[T, F proto.Message](
|
||||
onResult func(key []byte, value T) (F, error),
|
||||
constructor func() T,
|
||||
) ([]F, *PageResponse, error) {
|
||||
// if the PageRequest is nil, use default PageRequest
|
||||
if pageRequest == nil {
|
||||
pageRequest = &PageRequest{}
|
||||
}
|
||||
|
||||
offset := pageRequest.Offset
|
||||
key := pageRequest.Key
|
||||
limit := pageRequest.Limit
|
||||
countTotal := pageRequest.CountTotal
|
||||
reverse := pageRequest.Reverse
|
||||
pageRequest = initPageRequestDefaults(pageRequest)
|
||||
results := []F{}
|
||||
|
||||
if offset > 0 && key != nil {
|
||||
if pageRequest.Offset > 0 && pageRequest.Key != nil {
|
||||
return results, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
|
||||
}
|
||||
|
||||
if limit == 0 {
|
||||
limit = DefaultLimit
|
||||
var (
|
||||
numHits uint64
|
||||
nextKey []byte
|
||||
err error
|
||||
)
|
||||
|
||||
// count total results when the limit is zero/not supplied
|
||||
countTotal = true
|
||||
}
|
||||
|
||||
if len(key) != 0 {
|
||||
iterator := getIterator(prefixStore, key, reverse)
|
||||
defer iterator.Close()
|
||||
|
||||
var (
|
||||
numHits uint64
|
||||
nextKey []byte
|
||||
)
|
||||
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
|
||||
defer iterator.Close()
|
||||
|
||||
if len(pageRequest.Key) != 0 {
|
||||
accumulateFn := func(_ uint64) bool { return true }
|
||||
for ; iterator.Valid(); iterator.Next() {
|
||||
if numHits == limit {
|
||||
if numHits == pageRequest.Limit {
|
||||
nextKey = iterator.Key()
|
||||
break
|
||||
}
|
||||
|
||||
if iterator.Error() != nil {
|
||||
return nil, nil, iterator.Error()
|
||||
}
|
||||
|
||||
protoMsg := constructor()
|
||||
|
||||
err := cdc.Unmarshal(iterator.Value(), protoMsg)
|
||||
results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
val, err := onResult(iterator.Key(), protoMsg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if proto.Size(val) != 0 {
|
||||
results = append(results, val)
|
||||
numHits++
|
||||
}
|
||||
}
|
||||
|
||||
return results, &PageResponse{
|
||||
@ -202,54 +181,28 @@ func GenericFilteredPaginate[T, F proto.Message](
|
||||
}, nil
|
||||
}
|
||||
|
||||
iterator := getIterator(prefixStore, nil, reverse)
|
||||
defer iterator.Close()
|
||||
|
||||
end := offset + limit
|
||||
|
||||
var (
|
||||
numHits uint64
|
||||
nextKey []byte
|
||||
)
|
||||
end := pageRequest.Offset + pageRequest.Limit
|
||||
accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end }
|
||||
|
||||
for ; iterator.Valid(); iterator.Next() {
|
||||
if iterator.Error() != nil {
|
||||
return nil, nil, iterator.Error()
|
||||
}
|
||||
|
||||
protoMsg := constructor()
|
||||
|
||||
err := cdc.Unmarshal(iterator.Value(), protoMsg)
|
||||
results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
val, err := onResult(iterator.Key(), protoMsg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if proto.Size(val) != 0 {
|
||||
// Previously this was the "accumulate" flag
|
||||
if numHits >= offset && numHits < end {
|
||||
results = append(results, val)
|
||||
}
|
||||
numHits++
|
||||
}
|
||||
|
||||
if numHits == end+1 {
|
||||
if nextKey == nil {
|
||||
nextKey = iterator.Key()
|
||||
}
|
||||
|
||||
if !countTotal {
|
||||
if !pageRequest.CountTotal {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res := &PageResponse{NextKey: nextKey}
|
||||
if countTotal {
|
||||
if pageRequest.CountTotal {
|
||||
res.Total = numHits
|
||||
}
|
||||
|
||||
|
||||
@ -54,38 +54,21 @@ func Paginate(
|
||||
pageRequest *PageRequest,
|
||||
onResult func(key, value []byte) error,
|
||||
) (*PageResponse, error) {
|
||||
// if the PageRequest is nil, use default PageRequest
|
||||
if pageRequest == nil {
|
||||
pageRequest = &PageRequest{}
|
||||
}
|
||||
pageRequest = initPageRequestDefaults(pageRequest)
|
||||
|
||||
offset := pageRequest.Offset
|
||||
key := pageRequest.Key
|
||||
limit := pageRequest.Limit
|
||||
countTotal := pageRequest.CountTotal
|
||||
reverse := pageRequest.Reverse
|
||||
|
||||
if offset > 0 && key != nil {
|
||||
if pageRequest.Offset > 0 && pageRequest.Key != nil {
|
||||
return nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
|
||||
}
|
||||
|
||||
if limit == 0 {
|
||||
limit = DefaultLimit
|
||||
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
|
||||
defer iterator.Close()
|
||||
|
||||
// count total results when the limit is zero/not supplied
|
||||
countTotal = true
|
||||
}
|
||||
|
||||
if len(key) != 0 {
|
||||
iterator := getIterator(prefixStore, key, reverse)
|
||||
defer iterator.Close()
|
||||
|
||||
var count uint64
|
||||
var nextKey []byte
|
||||
var count uint64
|
||||
var nextKey []byte
|
||||
|
||||
if len(pageRequest.Key) != 0 {
|
||||
for ; iterator.Valid(); iterator.Next() {
|
||||
|
||||
if count == limit {
|
||||
if count == pageRequest.Limit {
|
||||
nextKey = iterator.Key()
|
||||
break
|
||||
}
|
||||
@ -105,18 +88,12 @@ func Paginate(
|
||||
}, nil
|
||||
}
|
||||
|
||||
iterator := getIterator(prefixStore, nil, reverse)
|
||||
defer iterator.Close()
|
||||
|
||||
end := offset + limit
|
||||
|
||||
var count uint64
|
||||
var nextKey []byte
|
||||
end := pageRequest.Offset + pageRequest.Limit
|
||||
|
||||
for ; iterator.Valid(); iterator.Next() {
|
||||
count++
|
||||
|
||||
if count <= offset {
|
||||
if count <= pageRequest.Offset {
|
||||
continue
|
||||
}
|
||||
if count <= end {
|
||||
@ -127,7 +104,7 @@ func Paginate(
|
||||
} else if count == end+1 {
|
||||
nextKey = iterator.Key()
|
||||
|
||||
if !countTotal {
|
||||
if !pageRequest.CountTotal {
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -137,7 +114,7 @@ func Paginate(
|
||||
}
|
||||
|
||||
res := &PageResponse{NextKey: nextKey}
|
||||
if countTotal {
|
||||
if pageRequest.CountTotal {
|
||||
res.Total = count
|
||||
}
|
||||
|
||||
@ -159,3 +136,25 @@ func getIterator(prefixStore types.KVStore, start []byte, reverse bool) db.Itera
|
||||
}
|
||||
return prefixStore.Iterator(start, nil)
|
||||
}
|
||||
|
||||
// initPageRequestDefaults initializes a PageRequest's defaults when those are not set.
|
||||
func initPageRequestDefaults(pageRequest *PageRequest) *PageRequest {
|
||||
// if the PageRequest is nil, use default PageRequest
|
||||
if pageRequest == nil {
|
||||
pageRequest = &PageRequest{}
|
||||
}
|
||||
|
||||
pageRequestCopy := *pageRequest
|
||||
if len(pageRequestCopy.Key) == 0 {
|
||||
pageRequestCopy.Key = nil
|
||||
}
|
||||
|
||||
if pageRequestCopy.Limit == 0 {
|
||||
pageRequestCopy.Limit = DefaultLimit
|
||||
|
||||
// count total results when the limit is zero/not supplied
|
||||
pageRequestCopy.CountTotal = true
|
||||
}
|
||||
|
||||
return &pageRequestCopy
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user