From 56ebd3aaa39aa44101a4a5f0042ab23320e16b1c Mon Sep 17 00:00:00 2001 From: Prathamesh Musale Date: Fri, 30 Aug 2024 11:10:52 +0530 Subject: [PATCH] Refactor code --- x/registry/keeper/genesis.go | 2 +- x/registry/keeper/keeper.go | 119 ++++++++++-------------------- x/registry/keeper/query_server.go | 14 +--- 3 files changed, 44 insertions(+), 91 deletions(-) diff --git a/x/registry/keeper/genesis.go b/x/registry/keeper/genesis.go index 8b6c2e74..36ec275d 100644 --- a/x/registry/keeper/genesis.go +++ b/x/registry/keeper/genesis.go @@ -62,7 +62,7 @@ func (k *Keeper) ExportGenesis(ctx sdk.Context) (*registry.GenesisState, error) return nil, err } - records, err := k.ListRecords(ctx) + records, _, err := k.PaginatedListRecords(ctx, nil) if err != nil { return nil, err } diff --git a/x/registry/keeper/keeper.go b/x/registry/keeper/keeper.go index ede73230..e1e14699 100644 --- a/x/registry/keeper/keeper.go +++ b/x/registry/keeper/keeper.go @@ -205,36 +205,35 @@ func (k Keeper) HasRecord(ctx sdk.Context, id string) (bool, error) { return has, nil } -// ListRecords - get all records. -func (k Keeper) ListRecords(ctx sdk.Context) ([]registrytypes.Record, error) { - var records []registrytypes.Record - - err := k.Records.Walk(ctx, nil, func(key string, value registrytypes.Record) (bool, error) { - if err := k.populateRecordNames(ctx, &value); err != nil { - return true, err - } - records = append(records, value) - - return false, nil - }) - if err != nil { - return nil, err - } - - return records, nil -} - -// PaginatedListRecords - get all records with pagination. +// PaginatedListRecords - get all records with optional pagination. func (k Keeper) PaginatedListRecords(ctx sdk.Context, pagination *query.PageRequest) ([]registrytypes.Record, *query.PageResponse, error) { - records, pageResp, err := query.CollectionPaginate(ctx, k.Records, pagination, func(key string, value registrytypes.Record) (registrytypes.Record, error) { - if err := k.populateRecordNames(ctx, &value); err != nil { - return registrytypes.Record{}, err - } + var records []registrytypes.Record + var pageResp *query.PageResponse - return value, nil - }) - if err != nil { - return nil, nil, err + if pagination == nil { + err := k.Records.Walk(ctx, nil, func(key string, value registrytypes.Record) (bool, error) { + if err := k.populateRecordNames(ctx, &value); err != nil { + return true, err + } + records = append(records, value) + + return false, nil + }) + if err != nil { + return nil, nil, err + } + } else { + var err error + records, pageResp, err = query.CollectionPaginate(ctx, k.Records, pagination, func(key string, value registrytypes.Record) (registrytypes.Record, error) { + if err := k.populateRecordNames(ctx, &value); err != nil { + return registrytypes.Record{}, err + } + + return value, nil + }) + if err != nil { + return nil, nil, err + } } return records, pageResp, nil @@ -278,58 +277,18 @@ func (k Keeper) GetRecordsByBondId(ctx sdk.Context, bondId string) ([]registryty return records, nil } -// RecordsFromAttributes gets a list of records whose attributes match all provided values -func (k Keeper) RecordsFromAttributes( - ctx sdk.Context, - attributes []*registrytypes.QueryRecordsRequest_KeyValueInput, - all bool, -) ([]registrytypes.Record, error) { - resultRecordIds := []string{} - for i, attr := range attributes { - suffix, err := QueryValueToJSON(attr.Value) - if err != nil { - return nil, err - } - mapKey := collections.Join(attr.Key, string(suffix)) - recordIds, err := k.getAttributeMapping(ctx, mapKey) - if err != nil { - return nil, err - } - - if i == 0 { - resultRecordIds = recordIds - } else { - resultRecordIds = getIntersection(recordIds, resultRecordIds) - } - } - - records := []registrytypes.Record{} - for _, id := range resultRecordIds { - record, err := k.GetRecordById(ctx, id) - if err != nil { - return nil, err - } - if record.Deleted { - continue - } - if !all && len(record.Names) == 0 { - continue - } - records = append(records, record) - } - - return records, nil -} - // PaginatedRecordsFromAttributes gets a list of records whose attributes match all provided values -// with pagination. +// with optional pagination. func (k Keeper) PaginatedRecordsFromAttributes( ctx sdk.Context, attributes []*registrytypes.QueryRecordsRequest_KeyValueInput, all bool, pagination *query.PageRequest, ) ([]registrytypes.Record, *query.PageResponse, error) { - resultRecordIds := []string{} + var resultRecordIds []string + var pageResp *query.PageResponse + + filteredRecordIds := []string{} for i, attr := range attributes { suffix, err := QueryValueToJSON(attr.Value) if err != nil { @@ -342,16 +301,20 @@ func (k Keeper) PaginatedRecordsFromAttributes( } if i == 0 { - resultRecordIds = recordIds + filteredRecordIds = recordIds } else { - resultRecordIds = getIntersection(recordIds, resultRecordIds) + filteredRecordIds = getIntersection(recordIds, filteredRecordIds) } } - paginatedResultRecordIds, pageRes := paginate(resultRecordIds, pagination) + if pagination != nil { + resultRecordIds, pageResp = paginate(filteredRecordIds, pagination) + } else { + resultRecordIds = filteredRecordIds + } records := []registrytypes.Record{} - for _, id := range paginatedResultRecordIds { + for _, id := range resultRecordIds { record, err := k.GetRecordById(ctx, id) if err != nil { return nil, nil, err @@ -365,7 +328,7 @@ func (k Keeper) PaginatedRecordsFromAttributes( records = append(records, record) } - return records, pageRes, nil + return records, pageResp, nil } // TODO not recursive, and only should be if we want to support querying with whole sub-objects, diff --git a/x/registry/keeper/query_server.go b/x/registry/keeper/query_server.go index 724e2454..b34e69ce 100644 --- a/x/registry/keeper/query_server.go +++ b/x/registry/keeper/query_server.go @@ -43,22 +43,12 @@ func (qs queryServer) Records(c context.Context, req *registrytypes.QueryRecords var pageResp *query.PageResponse var err error if len(attributes) > 0 { - if req.Pagination != nil { - records, pageResp, err = qs.k.PaginatedRecordsFromAttributes(ctx, attributes, all, req.Pagination) - } else { - records, err = qs.k.RecordsFromAttributes(ctx, attributes, all) - } - + records, pageResp, err = qs.k.PaginatedRecordsFromAttributes(ctx, attributes, all, req.Pagination) if err != nil { return nil, err } } else { - if req.Pagination != nil { - records, pageResp, err = qs.k.PaginatedListRecords(ctx, req.Pagination) - } else { - records, err = qs.k.ListRecords(ctx) - } - + records, pageResp, err = qs.k.PaginatedListRecords(ctx, req.Pagination) if err != nil { return nil, err }