feat(bank/v2): Introduce global send restriction (#21925)

This commit is contained in:
Hieu Vu 2024-10-02 13:03:04 +07:00 committed by GitHub
parent 90f362d782
commit 0102077fb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 301 additions and 15 deletions

View File

@ -14,4 +14,10 @@ message Module {
// authority defines the custom module authority. If not set, defaults to the governance module.
string authority = 1;
// restrictions_order specifies the order of send restrictions and should be
// a list of module names which provide a send restriction instance. If no
// order is provided, then restrictions will be applied in alphabetical order
// of module names.
repeated string restrictions_order = 2;
}

View File

@ -1,6 +1,11 @@
package bankv2
import (
"fmt"
"maps"
"slices"
"sort"
"cosmossdk.io/core/address"
"cosmossdk.io/core/appmodule"
"cosmossdk.io/depinject"
@ -22,6 +27,7 @@ func init() {
appconfig.RegisterModule(
&moduletypes.Module{},
appconfig.Provide(ProvideModule),
appconfig.Invoke(InvokeSetSendRestrictions),
)
}
@ -61,3 +67,39 @@ func ProvideModule(in ModuleInputs) ModuleOutputs {
Module: m,
}
}
func InvokeSetSendRestrictions(
config *moduletypes.Module,
keeper keeper.Keeper,
restrictions map[string]types.SendRestrictionFn,
) error {
if config == nil {
return nil
}
modules := slices.Collect(maps.Keys(restrictions))
order := config.RestrictionsOrder
if len(order) == 0 {
order = modules
sort.Strings(order)
}
if len(order) != len(modules) {
return fmt.Errorf("len(restrictions order: %v) != len(restriction modules: %v)", order, modules)
}
if len(modules) == 0 {
return nil
}
for _, module := range order {
restriction, ok := restrictions[module]
if !ok {
return fmt.Errorf("can't find send restriction for module %s", module)
}
keeper.AppendGlobalSendRestriction(restriction)
}
return nil
}

View File

@ -28,18 +28,21 @@ type Keeper struct {
params collections.Item[types.Params]
balances *collections.IndexedMap[collections.Pair[[]byte, string], math.Int, BalancesIndexes]
supply collections.Map[string, math.Int]
sendRestriction *sendRestriction
}
func NewKeeper(authority []byte, addressCodec address.Codec, env appmodulev2.Environment, cdc codec.BinaryCodec) *Keeper {
sb := collections.NewSchemaBuilder(env.KVStoreService)
k := &Keeper{
Environment: env,
authority: authority,
addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment?
params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)),
balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)),
supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue),
Environment: env,
authority: authority,
addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment?
params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)),
balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)),
supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue),
sendRestriction: newSendRestriction(),
}
schema, err := sb.Build()
@ -94,7 +97,10 @@ func (k Keeper) SendCoins(ctx context.Context, from, to []byte, amt sdk.Coins) e
}
var err error
// TODO: Send restriction
to, err = k.sendRestriction.apply(ctx, from, to, amt)
if err != nil {
return err
}
err = k.subUnlockedCoins(ctx, from, amt)
if err != nil {

View File

@ -1,7 +1,9 @@
package keeper_test
import (
"bytes"
"context"
"fmt"
"testing"
"time"
@ -184,3 +186,53 @@ func (suite *KeeperTestSuite) TestSendCoins_Module_To_Module() {
mintBarBalance := suite.bankKeeper.GetBalance(ctx, mintAcc.GetAddress(), barDenom)
require.Equal(mintBarBalance.Amount, math.NewInt(0))
}
func (suite *KeeperTestSuite) TestSendCoins_WithRestriction() {
ctx := suite.ctx
require := suite.Require()
balances := sdk.NewCoins(newFooCoin(100), newBarCoin(50))
sendAmt := sdk.NewCoins(newFooCoin(10), newBarCoin(10))
require.NoError(banktestutil.FundAccount(ctx, suite.bankKeeper, accAddrs[0], balances))
// Add first restriction
addrRestrictFunc := func(ctx context.Context, from, to []byte, amount sdk.Coins) ([]byte, error) {
if bytes.Equal(from, to) {
return nil, fmt.Errorf("Can not send to same address")
}
return to, nil
}
suite.bankKeeper.AppendGlobalSendRestriction(addrRestrictFunc)
err := suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[0], sendAmt)
require.Error(err)
require.Contains(err.Error(), "Can not send to same address")
// Add second restriction
amtRestrictFunc := func(ctx context.Context, from, to []byte, amount sdk.Coins) ([]byte, error) {
if len(amount) > 1 {
return nil, fmt.Errorf("Allow only one denom per one send")
}
return to, nil
}
suite.bankKeeper.AppendGlobalSendRestriction(amtRestrictFunc)
// Pass the 1st but failt at the 2nd
err = suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[1], sendAmt)
require.Error(err)
require.Contains(err.Error(), "Allow only one denom per one send")
// Pass both 2 restrictions
err = suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[1], sdk.NewCoins(newFooCoin(10)))
require.NoError(err)
// Check balances
acc0FooBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[0], fooDenom)
require.Equal(acc0FooBalance.Amount, math.NewInt(90))
acc0BarBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[0], barDenom)
require.Equal(acc0BarBalance.Amount, math.NewInt(50))
acc1FooBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[1], fooDenom)
require.Equal(acc1FooBalance.Amount, math.NewInt(10))
acc1BarBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[1], barDenom)
require.Equal(acc1BarBalance.Amount, math.ZeroInt())
}

View File

@ -0,0 +1,62 @@
package keeper
import (
"context"
"cosmossdk.io/x/bank/v2/types"
sdk "github.com/cosmos/cosmos-sdk/types"
)
// sendRestriction is a struct that houses a SendRestrictionFn.
// It exists so that the SendRestrictionFn can be updated in the SendKeeper without needing to have a pointer receiver.
type sendRestriction struct {
fn types.SendRestrictionFn
}
// newSendRestriction creates a new sendRestriction with nil send restriction.
func newSendRestriction() *sendRestriction {
return &sendRestriction{
fn: nil,
}
}
// append adds the provided restriction to this, to be run after the existing function.
func (r *sendRestriction) append(restriction types.SendRestrictionFn) {
r.fn = r.fn.Then(restriction)
}
// prepend adds the provided restriction to this, to be run before the existing function.
func (r *sendRestriction) prepend(restriction types.SendRestrictionFn) {
r.fn = restriction.Then(r.fn)
}
// clear removes the send restriction (sets it to nil).
func (r *sendRestriction) clear() {
r.fn = nil
}
var _ types.SendRestrictionFn = (*sendRestriction)(nil).apply
// apply applies the send restriction if there is one. If not, it's a no-op.
func (r *sendRestriction) apply(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) {
if r == nil || r.fn == nil {
return toAddr, nil
}
return r.fn(ctx, fromAddr, toAddr, amt)
}
// AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions.
func (k Keeper) AppendGlobalSendRestriction(restriction types.SendRestrictionFn) {
k.sendRestriction.append(restriction)
}
// PrependSendRestriction adds the provided SendRestrictionFn to run before previously provided restrictions.
func (k Keeper) PrependGlobalSendRestriction(restriction types.SendRestrictionFn) {
k.sendRestriction.prepend(restriction)
}
// ClearSendRestriction removes the send restriction (if there is one).
func (k Keeper) ClearGlobalSendRestriction() {
k.sendRestriction.clear()
}

View File

@ -27,6 +27,11 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type Module struct {
// authority defines the custom module authority. If not set, defaults to the governance module.
Authority string `protobuf:"bytes,1,opt,name=authority,proto3" json:"authority,omitempty"`
// restrictions_order specifies the order of send restrictions and should be
// a list of module names which provide a send restriction instance. If no
// order is provided, then restrictions will be applied in alphabetical order
// of module names.
RestrictionsOrder []string `protobuf:"bytes,2,rep,name=restrictions_order,json=restrictionsOrder,proto3" json:"restrictions_order,omitempty"`
}
func (m *Module) Reset() { *m = Module{} }
@ -69,6 +74,13 @@ func (m *Module) GetAuthority() string {
return ""
}
func (m *Module) GetRestrictionsOrder() []string {
if m != nil {
return m.RestrictionsOrder
}
return nil
}
func init() {
proto.RegisterType((*Module)(nil), "cosmos.bank.module.v2.Module")
}
@ -78,19 +90,21 @@ func init() {
}
var fileDescriptor_34a109a905e2a25b = []byte{
// 184 bytes of a gzipped FileDescriptorProto
// 219 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x52, 0x4a, 0xce, 0x2f, 0xce,
0xcd, 0x2f, 0xd6, 0x4f, 0x4a, 0xcc, 0xcb, 0xd6, 0xcf, 0xcd, 0x4f, 0x29, 0xcd, 0x49, 0xd5, 0x2f,
0x33, 0x82, 0xb2, 0xf4, 0x0a, 0x8a, 0xf2, 0x4b, 0xf2, 0x85, 0x44, 0x21, 0x6a, 0xf4, 0x40, 0x6a,
0xf4, 0xa0, 0x32, 0x65, 0x46, 0x52, 0x0a, 0x50, 0xad, 0x89, 0x05, 0x05, 0xfa, 0x65, 0x86, 0x89,
0x39, 0x05, 0x19, 0x89, 0x86, 0x28, 0x1a, 0x95, 0xdc, 0xb8, 0xd8, 0x7c, 0xc1, 0x7c, 0x21, 0x19,
0x39, 0x05, 0x19, 0x89, 0x86, 0x28, 0x1a, 0x95, 0x4a, 0xb9, 0xd8, 0x7c, 0xc1, 0x7c, 0x21, 0x19,
0x2e, 0xce, 0xc4, 0xd2, 0x92, 0x8c, 0xfc, 0xa2, 0xcc, 0x92, 0x4a, 0x09, 0x46, 0x05, 0x46, 0x0d,
0xce, 0x20, 0x84, 0x80, 0x95, 0xdc, 0xae, 0x03, 0xd3, 0x6e, 0x31, 0x4a, 0x70, 0x89, 0x41, 0x4c,
0x2c, 0x4e, 0xc9, 0xd6, 0xcb, 0xcc, 0xd7, 0xaf, 0x80, 0x38, 0xaa, 0xcc, 0xc8, 0xc9, 0xf6, 0xc4,
0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58, 0x8e, 0xe1,
0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x94, 0xb1, 0xeb, 0xd0, 0x2f, 0xa9, 0x2c,
0x48, 0x2d, 0x86, 0x3a, 0x26, 0x89, 0x0d, 0xec, 0x1a, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff,
0x5e, 0xfc, 0x10, 0x2c, 0xec, 0x00, 0x00, 0x00,
0xce, 0x20, 0x84, 0x80, 0x90, 0x2e, 0x97, 0x50, 0x51, 0x6a, 0x71, 0x49, 0x51, 0x66, 0x72, 0x49,
0x66, 0x7e, 0x5e, 0x71, 0x7c, 0x7e, 0x51, 0x4a, 0x6a, 0x91, 0x04, 0x93, 0x02, 0xb3, 0x06, 0x67,
0x90, 0x20, 0xb2, 0x8c, 0x3f, 0x48, 0xc2, 0x4a, 0x6e, 0xd7, 0x81, 0x69, 0xb7, 0x18, 0x25, 0xb8,
0xc4, 0x20, 0x0e, 0x28, 0x4e, 0xc9, 0xd6, 0xcb, 0xcc, 0xd7, 0xaf, 0x80, 0xf8, 0xa1, 0xcc, 0xc8,
0xc9, 0xf6, 0xc4, 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0,
0x58, 0x8e, 0xe1, 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x94, 0xb1, 0xeb, 0xd0,
0x2f, 0xa9, 0x2c, 0x48, 0x2d, 0x86, 0xba, 0x3d, 0x89, 0x0d, 0xec, 0x78, 0x63, 0x40, 0x00, 0x00,
0x00, 0xff, 0xff, 0x69, 0x6d, 0xb0, 0x10, 0x1b, 0x01, 0x00, 0x00,
}
func (m *Module) Marshal() (dAtA []byte, err error) {
@ -113,6 +127,15 @@ func (m *Module) MarshalToSizedBuffer(dAtA []byte) (int, error) {
_ = i
var l int
_ = l
if len(m.RestrictionsOrder) > 0 {
for iNdEx := len(m.RestrictionsOrder) - 1; iNdEx >= 0; iNdEx-- {
i -= len(m.RestrictionsOrder[iNdEx])
copy(dAtA[i:], m.RestrictionsOrder[iNdEx])
i = encodeVarintModule(dAtA, i, uint64(len(m.RestrictionsOrder[iNdEx])))
i--
dAtA[i] = 0x12
}
}
if len(m.Authority) > 0 {
i -= len(m.Authority)
copy(dAtA[i:], m.Authority)
@ -144,6 +167,12 @@ func (m *Module) Size() (n int) {
if l > 0 {
n += 1 + l + sovModule(uint64(l))
}
if len(m.RestrictionsOrder) > 0 {
for _, s := range m.RestrictionsOrder {
l = len(s)
n += 1 + l + sovModule(uint64(l))
}
}
return n
}
@ -214,6 +243,38 @@ func (m *Module) Unmarshal(dAtA []byte) error {
}
m.Authority = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field RestrictionsOrder", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowModule
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthModule
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthModule
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.RestrictionsOrder = append(m.RestrictionsOrder, string(dAtA[iNdEx:postIndex]))
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipModule(dAtA[iNdEx:])

View File

@ -0,0 +1,57 @@
package types
import (
"context"
sdk "github.com/cosmos/cosmos-sdk/types"
)
// A SendRestrictionFn can restrict sends and/or provide a new receiver address.
type SendRestrictionFn func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) (newToAddr []byte, err error)
// IsOnePerModuleType implements the depinject.OnePerModuleType interface.
func (SendRestrictionFn) IsOnePerModuleType() {}
var _ SendRestrictionFn = NoOpSendRestrictionFn
// NoOpSendRestrictionFn is a no-op SendRestrictionFn.
func NoOpSendRestrictionFn(_ context.Context, _, toAddr []byte, _ sdk.Coins) ([]byte, error) {
return toAddr, nil
}
// Then creates a composite restriction that runs this one then the provided second one.
func (r SendRestrictionFn) Then(second SendRestrictionFn) SendRestrictionFn {
return ComposeSendRestrictions(r, second)
}
// ComposeSendRestrictions combines multiple SendRestrictionFn into one.
// nil entries are ignored.
// If all entries are nil, nil is returned.
// If exactly one entry is not nil, it is returned.
// Otherwise, a new SendRestrictionFn is returned that runs the non-nil restrictions in the order they are given.
// The composition runs each send restriction until an error is encountered and returns that error,
// otherwise it returns the toAddr of the last send restriction.
func ComposeSendRestrictions(restrictions ...SendRestrictionFn) SendRestrictionFn {
toRun := make([]SendRestrictionFn, 0, len(restrictions))
for _, r := range restrictions {
if r != nil {
toRun = append(toRun, r)
}
}
switch len(toRun) {
case 0:
return nil
case 1:
return toRun[0]
}
return func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) {
var err error
for _, r := range toRun {
toAddr, err = r(ctx, fromAddr, toAddr, amt)
if err != nil {
return toAddr, err
}
}
return toAddr, err
}
}