feat(bank/v2): Introduce global send restriction (#21925)
This commit is contained in:
parent
90f362d782
commit
0102077fb2
@ -14,4 +14,10 @@ message Module {
|
||||
|
||||
// authority defines the custom module authority. If not set, defaults to the governance module.
|
||||
string authority = 1;
|
||||
|
||||
// restrictions_order specifies the order of send restrictions and should be
|
||||
// a list of module names which provide a send restriction instance. If no
|
||||
// order is provided, then restrictions will be applied in alphabetical order
|
||||
// of module names.
|
||||
repeated string restrictions_order = 2;
|
||||
}
|
||||
|
||||
@ -1,6 +1,11 @@
|
||||
package bankv2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
"sort"
|
||||
|
||||
"cosmossdk.io/core/address"
|
||||
"cosmossdk.io/core/appmodule"
|
||||
"cosmossdk.io/depinject"
|
||||
@ -22,6 +27,7 @@ func init() {
|
||||
appconfig.RegisterModule(
|
||||
&moduletypes.Module{},
|
||||
appconfig.Provide(ProvideModule),
|
||||
appconfig.Invoke(InvokeSetSendRestrictions),
|
||||
)
|
||||
}
|
||||
|
||||
@ -61,3 +67,39 @@ func ProvideModule(in ModuleInputs) ModuleOutputs {
|
||||
Module: m,
|
||||
}
|
||||
}
|
||||
|
||||
func InvokeSetSendRestrictions(
|
||||
config *moduletypes.Module,
|
||||
keeper keeper.Keeper,
|
||||
restrictions map[string]types.SendRestrictionFn,
|
||||
) error {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
modules := slices.Collect(maps.Keys(restrictions))
|
||||
order := config.RestrictionsOrder
|
||||
if len(order) == 0 {
|
||||
order = modules
|
||||
sort.Strings(order)
|
||||
}
|
||||
|
||||
if len(order) != len(modules) {
|
||||
return fmt.Errorf("len(restrictions order: %v) != len(restriction modules: %v)", order, modules)
|
||||
}
|
||||
|
||||
if len(modules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, module := range order {
|
||||
restriction, ok := restrictions[module]
|
||||
if !ok {
|
||||
return fmt.Errorf("can't find send restriction for module %s", module)
|
||||
}
|
||||
|
||||
keeper.AppendGlobalSendRestriction(restriction)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -28,18 +28,21 @@ type Keeper struct {
|
||||
params collections.Item[types.Params]
|
||||
balances *collections.IndexedMap[collections.Pair[[]byte, string], math.Int, BalancesIndexes]
|
||||
supply collections.Map[string, math.Int]
|
||||
|
||||
sendRestriction *sendRestriction
|
||||
}
|
||||
|
||||
func NewKeeper(authority []byte, addressCodec address.Codec, env appmodulev2.Environment, cdc codec.BinaryCodec) *Keeper {
|
||||
sb := collections.NewSchemaBuilder(env.KVStoreService)
|
||||
|
||||
k := &Keeper{
|
||||
Environment: env,
|
||||
authority: authority,
|
||||
addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment?
|
||||
params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)),
|
||||
balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)),
|
||||
supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue),
|
||||
Environment: env,
|
||||
authority: authority,
|
||||
addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment?
|
||||
params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)),
|
||||
balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)),
|
||||
supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue),
|
||||
sendRestriction: newSendRestriction(),
|
||||
}
|
||||
|
||||
schema, err := sb.Build()
|
||||
@ -94,7 +97,10 @@ func (k Keeper) SendCoins(ctx context.Context, from, to []byte, amt sdk.Coins) e
|
||||
}
|
||||
|
||||
var err error
|
||||
// TODO: Send restriction
|
||||
to, err = k.sendRestriction.apply(ctx, from, to, amt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = k.subUnlockedCoins(ctx, from, amt)
|
||||
if err != nil {
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
package keeper_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -184,3 +186,53 @@ func (suite *KeeperTestSuite) TestSendCoins_Module_To_Module() {
|
||||
mintBarBalance := suite.bankKeeper.GetBalance(ctx, mintAcc.GetAddress(), barDenom)
|
||||
require.Equal(mintBarBalance.Amount, math.NewInt(0))
|
||||
}
|
||||
|
||||
func (suite *KeeperTestSuite) TestSendCoins_WithRestriction() {
|
||||
ctx := suite.ctx
|
||||
require := suite.Require()
|
||||
balances := sdk.NewCoins(newFooCoin(100), newBarCoin(50))
|
||||
sendAmt := sdk.NewCoins(newFooCoin(10), newBarCoin(10))
|
||||
|
||||
require.NoError(banktestutil.FundAccount(ctx, suite.bankKeeper, accAddrs[0], balances))
|
||||
|
||||
// Add first restriction
|
||||
addrRestrictFunc := func(ctx context.Context, from, to []byte, amount sdk.Coins) ([]byte, error) {
|
||||
if bytes.Equal(from, to) {
|
||||
return nil, fmt.Errorf("Can not send to same address")
|
||||
}
|
||||
return to, nil
|
||||
}
|
||||
suite.bankKeeper.AppendGlobalSendRestriction(addrRestrictFunc)
|
||||
|
||||
err := suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[0], sendAmt)
|
||||
require.Error(err)
|
||||
require.Contains(err.Error(), "Can not send to same address")
|
||||
|
||||
// Add second restriction
|
||||
amtRestrictFunc := func(ctx context.Context, from, to []byte, amount sdk.Coins) ([]byte, error) {
|
||||
if len(amount) > 1 {
|
||||
return nil, fmt.Errorf("Allow only one denom per one send")
|
||||
}
|
||||
return to, nil
|
||||
}
|
||||
suite.bankKeeper.AppendGlobalSendRestriction(amtRestrictFunc)
|
||||
|
||||
// Pass the 1st but failt at the 2nd
|
||||
err = suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[1], sendAmt)
|
||||
require.Error(err)
|
||||
require.Contains(err.Error(), "Allow only one denom per one send")
|
||||
|
||||
// Pass both 2 restrictions
|
||||
err = suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[1], sdk.NewCoins(newFooCoin(10)))
|
||||
require.NoError(err)
|
||||
|
||||
// Check balances
|
||||
acc0FooBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[0], fooDenom)
|
||||
require.Equal(acc0FooBalance.Amount, math.NewInt(90))
|
||||
acc0BarBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[0], barDenom)
|
||||
require.Equal(acc0BarBalance.Amount, math.NewInt(50))
|
||||
acc1FooBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[1], fooDenom)
|
||||
require.Equal(acc1FooBalance.Amount, math.NewInt(10))
|
||||
acc1BarBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[1], barDenom)
|
||||
require.Equal(acc1BarBalance.Amount, math.ZeroInt())
|
||||
}
|
||||
|
||||
62
x/bank/v2/keeper/restriction.go
Normal file
62
x/bank/v2/keeper/restriction.go
Normal file
@ -0,0 +1,62 @@
|
||||
package keeper
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cosmossdk.io/x/bank/v2/types"
|
||||
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
)
|
||||
|
||||
// sendRestriction is a struct that houses a SendRestrictionFn.
|
||||
// It exists so that the SendRestrictionFn can be updated in the SendKeeper without needing to have a pointer receiver.
|
||||
type sendRestriction struct {
|
||||
fn types.SendRestrictionFn
|
||||
}
|
||||
|
||||
// newSendRestriction creates a new sendRestriction with nil send restriction.
|
||||
func newSendRestriction() *sendRestriction {
|
||||
return &sendRestriction{
|
||||
fn: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// append adds the provided restriction to this, to be run after the existing function.
|
||||
func (r *sendRestriction) append(restriction types.SendRestrictionFn) {
|
||||
r.fn = r.fn.Then(restriction)
|
||||
}
|
||||
|
||||
// prepend adds the provided restriction to this, to be run before the existing function.
|
||||
func (r *sendRestriction) prepend(restriction types.SendRestrictionFn) {
|
||||
r.fn = restriction.Then(r.fn)
|
||||
}
|
||||
|
||||
// clear removes the send restriction (sets it to nil).
|
||||
func (r *sendRestriction) clear() {
|
||||
r.fn = nil
|
||||
}
|
||||
|
||||
var _ types.SendRestrictionFn = (*sendRestriction)(nil).apply
|
||||
|
||||
// apply applies the send restriction if there is one. If not, it's a no-op.
|
||||
func (r *sendRestriction) apply(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) {
|
||||
if r == nil || r.fn == nil {
|
||||
return toAddr, nil
|
||||
}
|
||||
return r.fn(ctx, fromAddr, toAddr, amt)
|
||||
}
|
||||
|
||||
// AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions.
|
||||
func (k Keeper) AppendGlobalSendRestriction(restriction types.SendRestrictionFn) {
|
||||
k.sendRestriction.append(restriction)
|
||||
}
|
||||
|
||||
// PrependSendRestriction adds the provided SendRestrictionFn to run before previously provided restrictions.
|
||||
func (k Keeper) PrependGlobalSendRestriction(restriction types.SendRestrictionFn) {
|
||||
k.sendRestriction.prepend(restriction)
|
||||
}
|
||||
|
||||
// ClearSendRestriction removes the send restriction (if there is one).
|
||||
func (k Keeper) ClearGlobalSendRestriction() {
|
||||
k.sendRestriction.clear()
|
||||
}
|
||||
@ -27,6 +27,11 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
|
||||
type Module struct {
|
||||
// authority defines the custom module authority. If not set, defaults to the governance module.
|
||||
Authority string `protobuf:"bytes,1,opt,name=authority,proto3" json:"authority,omitempty"`
|
||||
// restrictions_order specifies the order of send restrictions and should be
|
||||
// a list of module names which provide a send restriction instance. If no
|
||||
// order is provided, then restrictions will be applied in alphabetical order
|
||||
// of module names.
|
||||
RestrictionsOrder []string `protobuf:"bytes,2,rep,name=restrictions_order,json=restrictionsOrder,proto3" json:"restrictions_order,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Module) Reset() { *m = Module{} }
|
||||
@ -69,6 +74,13 @@ func (m *Module) GetAuthority() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Module) GetRestrictionsOrder() []string {
|
||||
if m != nil {
|
||||
return m.RestrictionsOrder
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*Module)(nil), "cosmos.bank.module.v2.Module")
|
||||
}
|
||||
@ -78,19 +90,21 @@ func init() {
|
||||
}
|
||||
|
||||
var fileDescriptor_34a109a905e2a25b = []byte{
|
||||
// 184 bytes of a gzipped FileDescriptorProto
|
||||
// 219 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x52, 0x4a, 0xce, 0x2f, 0xce,
|
||||
0xcd, 0x2f, 0xd6, 0x4f, 0x4a, 0xcc, 0xcb, 0xd6, 0xcf, 0xcd, 0x4f, 0x29, 0xcd, 0x49, 0xd5, 0x2f,
|
||||
0x33, 0x82, 0xb2, 0xf4, 0x0a, 0x8a, 0xf2, 0x4b, 0xf2, 0x85, 0x44, 0x21, 0x6a, 0xf4, 0x40, 0x6a,
|
||||
0xf4, 0xa0, 0x32, 0x65, 0x46, 0x52, 0x0a, 0x50, 0xad, 0x89, 0x05, 0x05, 0xfa, 0x65, 0x86, 0x89,
|
||||
0x39, 0x05, 0x19, 0x89, 0x86, 0x28, 0x1a, 0x95, 0xdc, 0xb8, 0xd8, 0x7c, 0xc1, 0x7c, 0x21, 0x19,
|
||||
0x39, 0x05, 0x19, 0x89, 0x86, 0x28, 0x1a, 0x95, 0x4a, 0xb9, 0xd8, 0x7c, 0xc1, 0x7c, 0x21, 0x19,
|
||||
0x2e, 0xce, 0xc4, 0xd2, 0x92, 0x8c, 0xfc, 0xa2, 0xcc, 0x92, 0x4a, 0x09, 0x46, 0x05, 0x46, 0x0d,
|
||||
0xce, 0x20, 0x84, 0x80, 0x95, 0xdc, 0xae, 0x03, 0xd3, 0x6e, 0x31, 0x4a, 0x70, 0x89, 0x41, 0x4c,
|
||||
0x2c, 0x4e, 0xc9, 0xd6, 0xcb, 0xcc, 0xd7, 0xaf, 0x80, 0x38, 0xaa, 0xcc, 0xc8, 0xc9, 0xf6, 0xc4,
|
||||
0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58, 0x8e, 0xe1,
|
||||
0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x94, 0xb1, 0xeb, 0xd0, 0x2f, 0xa9, 0x2c,
|
||||
0x48, 0x2d, 0x86, 0x3a, 0x26, 0x89, 0x0d, 0xec, 0x1a, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff,
|
||||
0x5e, 0xfc, 0x10, 0x2c, 0xec, 0x00, 0x00, 0x00,
|
||||
0xce, 0x20, 0x84, 0x80, 0x90, 0x2e, 0x97, 0x50, 0x51, 0x6a, 0x71, 0x49, 0x51, 0x66, 0x72, 0x49,
|
||||
0x66, 0x7e, 0x5e, 0x71, 0x7c, 0x7e, 0x51, 0x4a, 0x6a, 0x91, 0x04, 0x93, 0x02, 0xb3, 0x06, 0x67,
|
||||
0x90, 0x20, 0xb2, 0x8c, 0x3f, 0x48, 0xc2, 0x4a, 0x6e, 0xd7, 0x81, 0x69, 0xb7, 0x18, 0x25, 0xb8,
|
||||
0xc4, 0x20, 0x0e, 0x28, 0x4e, 0xc9, 0xd6, 0xcb, 0xcc, 0xd7, 0xaf, 0x80, 0xf8, 0xa1, 0xcc, 0xc8,
|
||||
0xc9, 0xf6, 0xc4, 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0,
|
||||
0x58, 0x8e, 0xe1, 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x94, 0xb1, 0xeb, 0xd0,
|
||||
0x2f, 0xa9, 0x2c, 0x48, 0x2d, 0x86, 0xba, 0x3d, 0x89, 0x0d, 0xec, 0x78, 0x63, 0x40, 0x00, 0x00,
|
||||
0x00, 0xff, 0xff, 0x69, 0x6d, 0xb0, 0x10, 0x1b, 0x01, 0x00, 0x00,
|
||||
}
|
||||
|
||||
func (m *Module) Marshal() (dAtA []byte, err error) {
|
||||
@ -113,6 +127,15 @@ func (m *Module) MarshalToSizedBuffer(dAtA []byte) (int, error) {
|
||||
_ = i
|
||||
var l int
|
||||
_ = l
|
||||
if len(m.RestrictionsOrder) > 0 {
|
||||
for iNdEx := len(m.RestrictionsOrder) - 1; iNdEx >= 0; iNdEx-- {
|
||||
i -= len(m.RestrictionsOrder[iNdEx])
|
||||
copy(dAtA[i:], m.RestrictionsOrder[iNdEx])
|
||||
i = encodeVarintModule(dAtA, i, uint64(len(m.RestrictionsOrder[iNdEx])))
|
||||
i--
|
||||
dAtA[i] = 0x12
|
||||
}
|
||||
}
|
||||
if len(m.Authority) > 0 {
|
||||
i -= len(m.Authority)
|
||||
copy(dAtA[i:], m.Authority)
|
||||
@ -144,6 +167,12 @@ func (m *Module) Size() (n int) {
|
||||
if l > 0 {
|
||||
n += 1 + l + sovModule(uint64(l))
|
||||
}
|
||||
if len(m.RestrictionsOrder) > 0 {
|
||||
for _, s := range m.RestrictionsOrder {
|
||||
l = len(s)
|
||||
n += 1 + l + sovModule(uint64(l))
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
@ -214,6 +243,38 @@ func (m *Module) Unmarshal(dAtA []byte) error {
|
||||
}
|
||||
m.Authority = string(dAtA[iNdEx:postIndex])
|
||||
iNdEx = postIndex
|
||||
case 2:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field RestrictionsOrder", wireType)
|
||||
}
|
||||
var stringLen uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowModule
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
stringLen |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
intStringLen := int(stringLen)
|
||||
if intStringLen < 0 {
|
||||
return ErrInvalidLengthModule
|
||||
}
|
||||
postIndex := iNdEx + intStringLen
|
||||
if postIndex < 0 {
|
||||
return ErrInvalidLengthModule
|
||||
}
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.RestrictionsOrder = append(m.RestrictionsOrder, string(dAtA[iNdEx:postIndex]))
|
||||
iNdEx = postIndex
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := skipModule(dAtA[iNdEx:])
|
||||
|
||||
57
x/bank/v2/types/restrictions.go
Normal file
57
x/bank/v2/types/restrictions.go
Normal file
@ -0,0 +1,57 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
)
|
||||
|
||||
// A SendRestrictionFn can restrict sends and/or provide a new receiver address.
|
||||
type SendRestrictionFn func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) (newToAddr []byte, err error)
|
||||
|
||||
// IsOnePerModuleType implements the depinject.OnePerModuleType interface.
|
||||
func (SendRestrictionFn) IsOnePerModuleType() {}
|
||||
|
||||
var _ SendRestrictionFn = NoOpSendRestrictionFn
|
||||
|
||||
// NoOpSendRestrictionFn is a no-op SendRestrictionFn.
|
||||
func NoOpSendRestrictionFn(_ context.Context, _, toAddr []byte, _ sdk.Coins) ([]byte, error) {
|
||||
return toAddr, nil
|
||||
}
|
||||
|
||||
// Then creates a composite restriction that runs this one then the provided second one.
|
||||
func (r SendRestrictionFn) Then(second SendRestrictionFn) SendRestrictionFn {
|
||||
return ComposeSendRestrictions(r, second)
|
||||
}
|
||||
|
||||
// ComposeSendRestrictions combines multiple SendRestrictionFn into one.
|
||||
// nil entries are ignored.
|
||||
// If all entries are nil, nil is returned.
|
||||
// If exactly one entry is not nil, it is returned.
|
||||
// Otherwise, a new SendRestrictionFn is returned that runs the non-nil restrictions in the order they are given.
|
||||
// The composition runs each send restriction until an error is encountered and returns that error,
|
||||
// otherwise it returns the toAddr of the last send restriction.
|
||||
func ComposeSendRestrictions(restrictions ...SendRestrictionFn) SendRestrictionFn {
|
||||
toRun := make([]SendRestrictionFn, 0, len(restrictions))
|
||||
for _, r := range restrictions {
|
||||
if r != nil {
|
||||
toRun = append(toRun, r)
|
||||
}
|
||||
}
|
||||
switch len(toRun) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
return toRun[0]
|
||||
}
|
||||
return func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) {
|
||||
var err error
|
||||
for _, r := range toRun {
|
||||
toAddr, err = r(ctx, fromAddr, toAddr, amt)
|
||||
if err != nil {
|
||||
return toAddr, err
|
||||
}
|
||||
}
|
||||
return toAddr, err
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user