From 75b2244fb41f830904c7606815340c70dace3568 Mon Sep 17 00:00:00 2001 From: Prathamesh Musale Date: Thu, 29 Aug 2024 19:39:19 +0530 Subject: [PATCH] Add a method to return paginated records with filters --- x/registry/keeper/keeper.go | 102 ++++++++++++++++++++++++++++++ x/registry/keeper/query_server.go | 7 +- 2 files changed, 108 insertions(+), 1 deletion(-) diff --git a/x/registry/keeper/keeper.go b/x/registry/keeper/keeper.go index a59d70ca..d292b66c 100644 --- a/x/registry/keeper/keeper.go +++ b/x/registry/keeper/keeper.go @@ -302,6 +302,53 @@ func (k Keeper) RecordsFromAttributes( return records, nil } +// PaginatedRecordsFromAttributes gets a list of records whose attributes match all provided values +// with pagination. +func (k Keeper) PaginatedRecordsFromAttributes( + ctx sdk.Context, + attributes []*registrytypes.QueryRecordsRequest_KeyValueInput, + all bool, + pagination *query.PageRequest, +) ([]registrytypes.Record, *query.PageResponse, error) { + resultRecordIds := []string{} + for i, attr := range attributes { + suffix, err := QueryValueToJSON(attr.Value) + if err != nil { + return nil, nil, err + } + mapKey := collections.Join(attr.Key, string(suffix)) + recordIds, err := k.getAttributeMapping(ctx, mapKey) + if err != nil { + return nil, nil, err + } + + if i == 0 { + resultRecordIds = recordIds + } else { + resultRecordIds = getIntersection(recordIds, resultRecordIds) + } + } + + paginatedResultRecordIds, pageRes := paginate(resultRecordIds, pagination) + + records := []registrytypes.Record{} + for _, id := range paginatedResultRecordIds { + record, err := k.GetRecordById(ctx, id) + if err != nil { + return nil, nil, err + } + if record.Deleted { + continue + } + if !all && len(record.Names) == 0 { + continue + } + records = append(records, record) + } + + return records, pageRes, nil +} + // TODO not recursive, and only should be if we want to support querying with whole sub-objects, // which seems unnecessary. func QueryValueToJSON(input *registrytypes.QueryRecordsRequest_ValueInput) ([]byte, error) { @@ -715,6 +762,38 @@ func (k Keeper) tryTakeRecordRent(ctx sdk.Context, record registrytypes.Record) return k.SaveRecord(ctx, record) } +// paginate implements basic pagination over a list of objects +func paginate[T any](data []T, pagination *query.PageRequest) ([]T, *query.PageResponse) { + pageReq := initPageRequestDefaults(pagination) + + offset := pageReq.Offset + limit := pageReq.Limit + countTotal := pageReq.CountTotal + + totalItems := uint64(len(data)) + start := offset + end := offset + limit + + if start > totalItems { + if countTotal { + return []T{}, &query.PageResponse{Total: 0} + } else { + return []T{}, nil + } + } + if end > totalItems { + end = totalItems + } + + paginatedItems := data[start:end] + + if countTotal { + return paginatedItems, &query.PageResponse{Total: end - start} + } else { + return paginatedItems, nil + } +} + func getIntersection(a []string, b []string) []string { result := []string{} if len(a) < len(b) { @@ -741,3 +820,26 @@ func contains(arr []string, str string) bool { } return false } + +// https://github.com/cosmos/cosmos-sdk/blob/v0.50.3/types/query/pagination.go#L141 +// initPageRequestDefaults initializes a PageRequest's defaults when those are not set. +func initPageRequestDefaults(pageRequest *query.PageRequest) *query.PageRequest { + // if the PageRequest is nil, use default PageRequest + if pageRequest == nil { + pageRequest = &query.PageRequest{} + } + + pageRequestCopy := *pageRequest + if len(pageRequestCopy.Key) == 0 { + pageRequestCopy.Key = nil + } + + if pageRequestCopy.Limit == 0 { + pageRequestCopy.Limit = query.DefaultLimit + + // count total results when the limit is zero/not supplied + pageRequestCopy.CountTotal = true + } + + return &pageRequestCopy +} diff --git a/x/registry/keeper/query_server.go b/x/registry/keeper/query_server.go index f978ceaa..724e2454 100644 --- a/x/registry/keeper/query_server.go +++ b/x/registry/keeper/query_server.go @@ -43,7 +43,12 @@ func (qs queryServer) Records(c context.Context, req *registrytypes.QueryRecords var pageResp *query.PageResponse var err error if len(attributes) > 0 { - records, err = qs.k.RecordsFromAttributes(ctx, attributes, all) + if req.Pagination != nil { + records, pageResp, err = qs.k.PaginatedRecordsFromAttributes(ctx, attributes, all, req.Pagination) + } else { + records, err = qs.k.RecordsFromAttributes(ctx, attributes, all) + } + if err != nil { return nil, err }