docs: demonstrate how to define hooks using depinject (backport #21892) (#22132)

Co-authored-by: Eric Mokaya <4112301+ziscky@users.noreply.github.com>
This commit is contained in:
mergify[bot] 2024-10-04 17:42:25 +02:00 committed by GitHub
parent 0768b9b0bd
commit 10c19a2592
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 169 additions and 5 deletions

View File

@ -0,0 +1,57 @@
---
sidebar_position: 1
---
# Hooks
Hooks are functions that are called before and/or after certain events in the module's lifecycle.
## Defining Hooks
1. Define the hook interface and a wrapper implementing `depinject.OnePerModuleType`:
```go reference
https://github.com/cosmos/cosmos-sdk/blob/71c603a2a5a103df00f216d78ec8b108ed64ae28/testutil/x/counter/types/expected_keepers.go#L5-L12
```
2. Add a `CounterHooks` field to the keeper:
```go reference
https://github.com/cosmos/cosmos-sdk/blob/71c603a2a5a103df00f216d78ec8b108ed64ae28/testutil/x/counter/keeper/keeper.go#L25
```
3. Create a `depinject` invoker function
```go reference
https://github.com/cosmos/cosmos-sdk/blob/71c603a2a5a103df00f216d78ec8b108ed64ae28/testutil/x/counter/depinject.go#L53-L75
```
4. Inject the hooks during app initialization:
```go
appConfig = appconfig.Compose(&appv1alpha1.Config{
Modules: []*appv1alpha1.ModuleConfig{
// ....
{
Name: types.ModuleName,
Config: appconfig.WrapAny(&types.Module{}),
},
}
})
appConfig = depinject.Configs(
AppConfig(),
runtime.DefaultServiceBindings(),
depinject.Supply(
logger,
viper,
map[string]types.CounterHooksWrapper{
"counter": types.CounterHooksWrapper{&types.Hooks{}},
},
))
```
## Examples in the SDK
For examples of hooks implementation in the Cosmos SDK, refer to the [Epochs Hooks documentation](https://docs.cosmos.network/main/build/modules/epochs#hooks) and [Distribution Hooks Documentation](https://docs.cosmos.network/main/build/modules/distribution#hooks).

View File

@ -1,6 +1,10 @@
package counter
import (
"fmt"
"maps"
"slices"
"cosmossdk.io/core/appmodule"
"cosmossdk.io/depinject"
"cosmossdk.io/depinject/appconfig"
@ -18,6 +22,7 @@ func init() {
appconfig.RegisterModule(
&types.Module{},
appconfig.Provide(ProvideModule),
appconfig.Invoke(InvokeSetHooks),
)
}
@ -31,7 +36,7 @@ type ModuleInputs struct {
type ModuleOutputs struct {
depinject.Out
Keeper keeper.Keeper
Keeper *keeper.Keeper
Module appmodule.AppModule
}
@ -44,3 +49,27 @@ func ProvideModule(in ModuleInputs) ModuleOutputs {
Module: m,
}
}
func InvokeSetHooks(keeper *keeper.Keeper, counterHooks map[string]types.CounterHooksWrapper) error {
if keeper == nil {
return fmt.Errorf("keeper is nil")
}
if counterHooks == nil {
return fmt.Errorf("counterHooks is nil")
}
// Default ordering is lexical by module name.
// Explicit ordering can be added to the module config if required.
modNames := slices.Sorted(maps.Keys(counterHooks))
var multiHooks types.MultiCounterHooks
for _, modName := range modNames {
hook, ok := counterHooks[modName]
if !ok {
return fmt.Errorf("can't find hooks for module %s", modName)
}
multiHooks = append(multiHooks, hook)
}
keeper.SetHooks(multiHooks)
return nil
}

View File

@ -0,0 +1,14 @@
package keeper
import (
"context"
)
type Hooks struct {
AfterCounterIncreased bool
}
func (h *Hooks) AfterIncreaseCount(ctx context.Context, n int64) error {
h.AfterCounterIncreased = true
return nil
}

View File

@ -21,11 +21,13 @@ type Keeper struct {
appmodule.Environment
CountStore collections.Item[int64]
hooks types.CounterHooks
}
func NewKeeper(env appmodule.Environment) Keeper {
func NewKeeper(env appmodule.Environment) *Keeper {
sb := collections.NewSchemaBuilder(env.KVStoreService)
return Keeper{
return &Keeper{
Environment: env,
CountStore: collections.NewItem(sb, collections.NewPrefix(0), "count", collections.Int64Value),
}
@ -67,6 +69,10 @@ func (k Keeper) IncreaseCount(ctx context.Context, msg *types.MsgIncreaseCounter
return nil, err
}
if err := k.Hooks().AfterIncreaseCount(ctx, num+msg.Count); err != nil {
return nil, err
}
if err := k.EventService.EventManager(ctx).EmitKV(
"increase_counter",
event.NewAttribute("signer", msg.Signer),
@ -78,3 +84,23 @@ func (k Keeper) IncreaseCount(ctx context.Context, msg *types.MsgIncreaseCounter
NewCount: num + msg.Count,
}, nil
}
// Hooks gets the hooks for counter Keeper
func (k *Keeper) Hooks() types.CounterHooks {
if k.hooks == nil {
// return a no-op implementation if no hooks are set
return types.MultiCounterHooks{}
}
return k.hooks
}
// SetHooks sets the hooks for counter
func (k *Keeper) SetHooks(gh types.CounterHooks) *Keeper {
if k.hooks != nil {
panic("cannot set governance hooks twice")
}
k.hooks = gh
return k
}

View File

@ -17,7 +17,7 @@ var (
// AppModule implements an application module
type AppModule struct {
keeper keeper.Keeper
keeper *keeper.Keeper
}
// IsAppModule implements the appmodule.AppModule interface.
@ -31,7 +31,7 @@ func (am AppModule) RegisterServices(registrar grpc.ServiceRegistrar) error {
}
// NewAppModule creates a new AppModule object
func NewAppModule(keeper keeper.Keeper) AppModule {
func NewAppModule(keeper *keeper.Keeper) AppModule {
return AppModule{
keeper: keeper,
}

View File

@ -0,0 +1,12 @@
package types
import "context"
type CounterHooks interface {
AfterIncreaseCount(ctx context.Context, newCount int64) error
}
type CounterHooksWrapper struct{ CounterHooks }
// IsOnePerModuleType implements the depinject.OnePerModuleType interface.
func (CounterHooksWrapper) IsOnePerModuleType() {}

View File

@ -0,0 +1,26 @@
package types
import (
"context"
"errors"
)
var _ CounterHooks = MultiCounterHooks{}
// MultiCounterHooks is a slice of hooks to be called in sequence.
type MultiCounterHooks []CounterHooks
// NewMultiCounterHooks returns a MultiCounterHooks from a list of CounterHooks
func NewMultiCounterHooks(hooks ...CounterHooks) MultiCounterHooks {
return hooks
}
// AfterIncreaseCount calls AfterIncreaseCount on all hooks and collects the errors if any.
func (ch MultiCounterHooks) AfterIncreaseCount(ctx context.Context, newCount int64) error {
var errs error
for i := range ch {
errs = errors.Join(errs, ch[i].AfterIncreaseCount(ctx, newCount))
}
return errs
}