cosmos-sdk/schema/testing/statesim/module.go

97 lines
2.9 KiB
Go

package statesim
import (
"fmt"
"github.com/stretchr/testify/require"
"github.com/tidwall/btree"
"pgregory.net/rapid"
"cosmossdk.io/schema"
"cosmossdk.io/schema/view"
)
// Module is a collection of object collections corresponding to a module's schema for testing purposes.
// Instances are built by NewModule, which fills in every field.
type Module struct {
// name is the module name handed to NewModule and reported by ModuleName.
name string
// moduleSchema is the schema handed to NewModule and reported by ModuleSchema.
moduleSchema schema.ModuleSchema
// objectCollections holds one ObjectCollection per object type in the schema,
// keyed by the object type's name.
objectCollections *btree.Map[string, *ObjectCollection]
// updateGen draws an object type name and then an update from that type's
// collection generator; it is exposed through UpdateGen.
updateGen *rapid.Generator[schema.ObjectUpdate]
}
// NewModule creates a new Module for the given module schema.
//
// One ObjectCollection is created (with the supplied options) for every object
// type that moduleSchema.ObjectTypes yields, and stored under the object type's
// name. The module's update generator first picks one of those type names and
// then draws an update from the matching collection's own generator.
//
// A schema without any object types is accepted: construction succeeds, and
// only an attempt to draw from UpdateGen fails the running rapid test.
func NewModule(name string, moduleSchema schema.ModuleSchema, options Options) *Module {
	objectCollections := &btree.Map[string, *ObjectCollection]{}
	var objectTypeNames []string

	moduleSchema.ObjectTypes(func(objectType schema.ObjectType) bool {
		objectCollection := NewObjectCollection(objectType, options)
		objectCollections.Set(objectType.Name, objectCollection)
		objectTypeNames = append(objectTypeNames, objectType.Name)
		return true // keep iterating over all object types
	})

	// rapid.SampledFrom panics when given an empty slice, so the selector is
	// only built when there is at least one object type to sample from.
	var objectTypeSelector *rapid.Generator[string]
	if len(objectTypeNames) > 0 {
		objectTypeSelector = rapid.SampledFrom(objectTypeNames)
	}

	updateGen := rapid.Custom(func(t *rapid.T) schema.ObjectUpdate {
		// Fail the draw (not module construction) when nothing can be generated.
		require.NotEmpty(t, objectTypeNames, "module %q has no object types to generate updates for", name)

		objectType := objectTypeSelector.Draw(t, "objectType")
		objectColl, ok := objectCollections.Get(objectType)
		// Every sampled name was inserted above, so a miss is an invariant violation.
		require.True(t, ok)
		return objectColl.UpdateGen().Draw(t, "update")
	})

	return &Module{
		name:              name,
		moduleSchema:      moduleSchema,
		updateGen:         updateGen,
		objectCollections: objectCollections,
	}
}
// ApplyUpdate forwards the given object update to the object collection
// registered under update.TypeName. An error is returned when the module has no
// collection for that type name; otherwise the collection's own result is
// returned unchanged.
func (o *Module) ApplyUpdate(update schema.ObjectUpdate) error {
	if coll, found := o.objectCollections.Get(update.TypeName); found {
		return coll.ApplyUpdate(update)
	}
	return fmt.Errorf("object type %s not found in module", update.TypeName)
}
// UpdateGen returns the module's generator of object updates, i.e. the
// generator that was assembled when the module was created. According to the
// existing documentation the generator is stateful: part of what it produces
// are updates and deletes of objects that already exist in the module.
func (o *Module) UpdateGen() *rapid.Generator[schema.ObjectUpdate] {
	gen := o.updateGen
	return gen
}
// ModuleName reports the name this module was created with.
func (o *Module) ModuleName() string {
	return o.name
}
// ModuleSchema reports the module schema this module was created with.
func (o *Module) ModuleSchema() schema.ModuleSchema {
	return o.moduleSchema
}
// GetObjectCollection looks up the object collection registered under the given
// object type name. When no such collection exists it returns (nil, nil); the
// error result is always nil in this implementation.
func (o *Module) GetObjectCollection(objectType string) (view.ObjectCollection, error) {
	if coll, found := o.objectCollections.Get(objectType); found {
		return coll, nil
	}
	// Return a literal nil rather than the zero *ObjectCollection from the
	// lookup, so callers comparing the interface value against nil see nil.
	return nil, nil
}
// ObjectCollections calls f once for each object collection in the module,
// always with a nil error. Iteration stops as soon as f returns false.
func (o *Module) ObjectCollections(f func(value view.ObjectCollection, err error) bool) {
	// The type-name key is not needed by callers; only the collection is passed on.
	visit := func(_ string, coll *ObjectCollection) bool {
		return f(coll, nil)
	}
	o.objectCollections.Scan(visit)
}
// NumObjectCollections reports how many object collections the module holds.
// The error result is always nil in this implementation.
func (o *Module) NumObjectCollections() (int, error) {
	count := o.objectCollections.Len()
	return count, nil
}