Add a method to return paginated records with filters

This commit is contained in:
Prathamesh Musale 2024-08-29 19:39:19 +05:30 committed by nabarun
parent 2bab2eeb9b
commit 6dd6a19e55
2 changed files with 108 additions and 1 deletions

View File

@ -321,6 +321,53 @@ func (k Keeper) RecordsFromAttributes(
return records, nil
}
// PaginatedRecordsFromAttributes gets a list of records whose attributes match all provided values
// with pagination.
func (k Keeper) PaginatedRecordsFromAttributes(
ctx sdk.Context,
attributes []*registrytypes.QueryRecordsRequest_KeyValueInput,
all bool,
pagination *query.PageRequest,
) ([]registrytypes.Record, *query.PageResponse, error) {
resultRecordIds := []string{}
for i, attr := range attributes {
suffix, err := QueryValueToJSON(attr.Value)
if err != nil {
return nil, nil, err
}
mapKey := collections.Join(attr.Key, string(suffix))
recordIds, err := k.getAttributeMapping(ctx, mapKey)
if err != nil {
return nil, nil, err
}
if i == 0 {
resultRecordIds = recordIds
} else {
resultRecordIds = getIntersection(recordIds, resultRecordIds)
}
}
paginatedResultRecordIds, pageRes := paginate(resultRecordIds, pagination)
records := []registrytypes.Record{}
for _, id := range paginatedResultRecordIds {
record, err := k.GetRecordById(ctx, id)
if err != nil {
return nil, nil, err
}
if record.Deleted {
continue
}
if !all && len(record.Names) == 0 {
continue
}
records = append(records, record)
}
return records, pageRes, nil
}
// TODO not recursive, and only should be if we want to support querying with whole sub-objects,
// which seems unnecessary.
func QueryValueToJSON(input *registrytypes.QueryRecordsRequest_ValueInput) ([]byte, error) {
@ -734,6 +781,38 @@ func (k Keeper) tryTakeRecordRent(ctx sdk.Context, record registrytypes.Record)
return k.SaveRecord(ctx, record)
}
// paginate implements basic pagination over a list of objects
func paginate[T any](data []T, pagination *query.PageRequest) ([]T, *query.PageResponse) {
pageReq := initPageRequestDefaults(pagination)
offset := pageReq.Offset
limit := pageReq.Limit
countTotal := pageReq.CountTotal
totalItems := uint64(len(data))
start := offset
end := offset + limit
if start > totalItems {
if countTotal {
return []T{}, &query.PageResponse{Total: 0}
} else {
return []T{}, nil
}
}
if end > totalItems {
end = totalItems
}
paginatedItems := data[start:end]
if countTotal {
return paginatedItems, &query.PageResponse{Total: end - start}
} else {
return paginatedItems, nil
}
}
func getIntersection(a []string, b []string) []string {
result := []string{}
if len(a) < len(b) {
@ -760,3 +839,26 @@ func contains(arr []string, str string) bool {
}
return false
}
// https://github.com/cosmos/cosmos-sdk/blob/v0.50.3/types/query/pagination.go#L141
// initPageRequestDefaults initializes a PageRequest's defaults when those are not set.
func initPageRequestDefaults(pageRequest *query.PageRequest) *query.PageRequest {
// if the PageRequest is nil, use default PageRequest
if pageRequest == nil {
pageRequest = &query.PageRequest{}
}
pageRequestCopy := *pageRequest
if len(pageRequestCopy.Key) == 0 {
pageRequestCopy.Key = nil
}
if pageRequestCopy.Limit == 0 {
pageRequestCopy.Limit = query.DefaultLimit
// count total results when the limit is zero/not supplied
pageRequestCopy.CountTotal = true
}
return &pageRequestCopy
}

View File

@ -43,7 +43,12 @@ func (qs queryServer) Records(c context.Context, req *registrytypes.QueryRecords
var pageResp *query.PageResponse
var err error
if len(attributes) > 0 {
records, err = qs.k.RecordsFromAttributes(ctx, attributes, all)
if req.Pagination != nil {
records, pageResp, err = qs.k.PaginatedRecordsFromAttributes(ctx, attributes, all, req.Pagination)
} else {
records, err = qs.k.RecordsFromAttributes(ctx, attributes, all)
}
if err != nil {
return nil, err
}