From a6844f5ed40f327ddf1050246e42126caff559af Mon Sep 17 00:00:00 2001 From: Marko Date: Thu, 13 Jun 2024 10:09:02 +0200 Subject: [PATCH] fix: remove recipient amount from map (#20625) --- x/protocolpool/keeper/keeper.go | 19 +++++++++++------ x/protocolpool/keeper/msg_server.go | 16 ++++++++++++++ x/protocolpool/keeper/msg_server_test.go | 27 +++++++++++++++++++----- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/x/protocolpool/keeper/keeper.go b/x/protocolpool/keeper/keeper.go index 5eb2ae4ec6..f9dd892674 100644 --- a/x/protocolpool/keeper/keeper.go +++ b/x/protocolpool/keeper/keeper.go @@ -254,11 +254,15 @@ func (k Keeper) hasPermission(addr []byte) (bool, error) { return bytes.Equal(authAcc, addr), nil } +type recipientFund struct { + RecipientAddr string + Percentage math.Int +} + func (k Keeper) iterateAndUpdateFundsDistribution(ctx context.Context, toDistributeAmount math.Int) error { totalPercentageToBeDistributed := math.ZeroInt() - // Create a map to store keys & values from RecipientFundPercentage during the first iteration - recipientFundMap := make(map[string]math.Int) + recipientFundList := []recipientFund{} // Calculate totalPercentageToBeDistributed and store values err := k.RecipientFundPercentage.Walk(ctx, nil, func(key sdk.AccAddress, value math.Int) (stop bool, err error) { @@ -267,7 +271,10 @@ func (k Keeper) iterateAndUpdateFundsDistribution(ctx context.Context, toDistrib return true, err } totalPercentageToBeDistributed = totalPercentageToBeDistributed.Add(value) - recipientFundMap[addr] = value + recipientFundList = append(recipientFundList, recipientFund{ + RecipientAddr: addr, + Percentage: value, + }) return false, nil }) if err != nil { @@ -287,14 +294,14 @@ func (k Keeper) iterateAndUpdateFundsDistribution(ctx context.Context, toDistrib totalAmountToBeDistributed := toDistributeDec.MulDec(math.LegacyNewDecFromIntWithPrec(totalPercentageToBeDistributed, 2)) totalDistrAmount := totalAmountToBeDistributed.AmountOf(denom) - for keyStr, value := range recipientFundMap { + for _, value := range recipientFundList { // Calculate the funds to be distributed based on the percentage - decValue := math.LegacyNewDecFromIntWithPrec(value, 2) + decValue := math.LegacyNewDecFromIntWithPrec(value.Percentage, 2) percentage := math.LegacyNewDecFromIntWithPrec(totalPercentageToBeDistributed, 2) recipientAmount := totalDistrAmount.Mul(decValue).Quo(percentage) recipientCoins := recipientAmount.TruncateInt() - key, err := k.authKeeper.AddressCodec().StringToBytes(keyStr) + key, err := k.authKeeper.AddressCodec().StringToBytes(value.RecipientAddr) if err != nil { return err } diff --git a/x/protocolpool/keeper/msg_server.go b/x/protocolpool/keeper/msg_server.go index 9111986318..3af5f91df8 100644 --- a/x/protocolpool/keeper/msg_server.go +++ b/x/protocolpool/keeper/msg_server.go @@ -110,6 +110,14 @@ func (k MsgServer) CreateContinuousFund(ctx context.Context, msg *types.MsgCreat return nil, err } + has, err := k.ContinuousFund.Has(ctx, recipient) + if err != nil { + return nil, err + } + if has { + return nil, fmt.Errorf("continuous fund already exists for recipient %s", msg.Recipient) + } + // Validate the message fields err = k.validateContinuousFund(ctx, *msg) if err != nil { @@ -201,6 +209,14 @@ func (k MsgServer) CancelContinuousFund(ctx context.Context, msg *types.MsgCance return nil, fmt.Errorf("failed to remove continuous fund for recipient %s: %w", msg.RecipientAddress, err) } + if err := k.RecipientFundPercentage.Remove(ctx, recipient); err != nil { + return nil, fmt.Errorf("failed to remove recipient fund percentage for recipient %s: %w", msg.RecipientAddress, err) + } + + if err := k.RecipientFundDistribution.Remove(ctx, recipient); err != nil { + return nil, fmt.Errorf("failed to remove recipient fund distribution for recipient %s: %w", msg.RecipientAddress, err) + } + return &types.MsgCancelContinuousFundResponse{ CanceledTime: canceledTime, CanceledHeight: uint64(canceledHeight), diff --git a/x/protocolpool/keeper/msg_server_test.go b/x/protocolpool/keeper/msg_server_test.go index fb9fda3ede..9b17b3cda2 100644 --- a/x/protocolpool/keeper/msg_server_test.go +++ b/x/protocolpool/keeper/msg_server_test.go @@ -810,6 +810,10 @@ func (suite *KeeperTestSuite) TestCancelContinuousFund() { recipient2 := sdk.AccAddress([]byte("recipientAddr2___________________")) recipient2StrAddr, err := codectestutil.CodecOptions{}.GetAddressCodec().BytesToString(recipient2) suite.Require().NoError(err) + recipient3 := sdk.AccAddress([]byte("recipientAddr3___________________")) + recipient3StrAddr, err := codectestutil.CodecOptions{}.GetAddressCodec().BytesToString(recipient3) + suite.Require().NoError(err) + testCases := map[string]struct { preRun func() recipientAddr sdk.AccAddress @@ -908,20 +912,26 @@ func (suite *KeeperTestSuite) TestCancelContinuousFund() { oneMonthInSeconds := int64(30 * 24 * 60 * 60) // Approximate number of seconds in 1 month expiry := suite.environment.HeaderService.HeaderInfo(suite.ctx).Time.Add(time.Duration(oneMonthInSeconds) * time.Second) cf := types.ContinuousFund{ - Recipient: recipientStrAddr, + Recipient: recipient3StrAddr, Percentage: percentage, Expiry: &expiry, } - err = suite.poolKeeper.ContinuousFund.Set(suite.ctx, recipientAddr, cf) + suite.mockWithdrawContinuousFund() + err = suite.poolKeeper.ContinuousFund.Set(suite.ctx, recipient3, cf) + suite.Require().NoError(err) + err = suite.poolKeeper.RecipientFundPercentage.Set(suite.ctx, recipient3, math.ZeroInt()) + suite.Require().NoError(err) + err = suite.poolKeeper.RecipientFundDistribution.Set(suite.ctx, recipient3, math.ZeroInt()) suite.Require().NoError(err) }, - recipientAddr: recipientAddr, + recipientAddr: recipient3, expErr: false, postRun: func() { - _, err := suite.poolKeeper.ContinuousFund.Get(suite.ctx, recipientAddr) + _, err := suite.poolKeeper.ContinuousFund.Get(suite.ctx, recipient3) suite.Require().Error(err) suite.Require().ErrorIs(err, collections.ErrNotFound) }, + withdrawnFunds: sdk.NewCoin(sdk.DefaultBondDenom, math.NewInt(0)), }, } @@ -943,7 +953,14 @@ func (suite *KeeperTestSuite) TestCancelContinuousFund() { suite.Require().Contains(err.Error(), tc.expErrMsg) } else { suite.Require().NoError(err) - suite.Require().Equal(resp.WithdrawnAllocatedFund, tc.withdrawnFunds) + suite.Require().Equal(tc.withdrawnFunds, resp.WithdrawnAllocatedFund) + // All items below should return error as they are removed from the store + _, err := suite.poolKeeper.RecipientFundPercentage.Get(suite.ctx, tc.recipientAddr) + suite.Require().Contains(err.Error(), "collections: not found") + _, err = suite.poolKeeper.ContinuousFund.Get(suite.ctx, tc.recipientAddr) + suite.Require().Contains(err.Error(), "collections: not found") + _, err = suite.poolKeeper.RecipientFundDistribution.Get(suite.ctx, tc.recipientAddr) + suite.Require().Contains(err.Error(), "collections: not found") } if tc.postRun != nil { tc.postRun()