From b0edb6d3ef94489419b89e2a146e7f1f23a40229 Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Tue, 10 May 2022 09:24:09 -0400 Subject: [PATCH] refactor(container)!: use build instead of run model to get dependencies (#11916) ## Description Closes: #11907 --- ### Author Checklist *All items are required. Please add a note to the item if the item is not applicable and please add links to any relevant follow up issues.* I have... - [ ] included the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title - [ ] added `!` to the type prefix if API or client breaking change - [ ] targeted the correct branch (see [PR Targeting](https://github.com/cosmos/cosmos-sdk/blob/main/CONTRIBUTING.md#pr-targeting)) - [ ] provided a link to the relevant issue or specification - [ ] followed the guidelines for [building modules](https://github.com/cosmos/cosmos-sdk/blob/main/docs/building-modules) - [ ] included the necessary unit and integration [tests](https://github.com/cosmos/cosmos-sdk/blob/main/CONTRIBUTING.md#testing) - [ ] added a changelog entry to `CHANGELOG.md` - [ ] included comments for [documenting Go code](https://blog.golang.org/godoc) - [ ] updated the relevant documentation or specification - [ ] reviewed "Files changed" and left comments if necessary - [ ] confirmed all CI checks have passed ### Reviewers Checklist *All items are required. Please add a note if the item is not applicable and please add your handle next to the items reviewed if you only reviewed selected items.* I have... - [ ] confirmed the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title - [ ] confirmed `!` in the type prefix if API or client breaking change - [ ] confirmed all author checklist items have been addressed - [ ] reviewed state machine logic - [ ] reviewed API design and naming - [ ] reviewed documentation is accurate - [ ] reviewed tests and test coverage - [ ] manually tested (if applicable) --- container/build.go | 50 +++++ container/container.go | 40 +++- container/container_test.go | 371 ++++++++++++++++++++---------------- container/run.go | 44 ----- container/struct_args.go | 31 +-- go.mod | 3 +- 6 files changed, 309 insertions(+), 230 deletions(-) create mode 100644 container/build.go delete mode 100644 container/run.go diff --git a/container/build.go b/container/build.go new file mode 100644 index 0000000000..92edaaff20 --- /dev/null +++ b/container/build.go @@ -0,0 +1,50 @@ +package container + +// Build builds the container specified by containerOption and extracts the +// requested outputs from the container or returns an error. It is the single +// entry point for building and running a dependency injection container. +// Each of the values specified as outputs must be pointers to types that +// can be provided by the container. +// +// Ex: +// var x int +// Build(Provide(func() int { return 1 }), &x) +func Build(containerOption Option, outputs ...interface{}) error { + loc := LocationFromCaller(1) + return build(loc, nil, containerOption, outputs...) +} + +// BuildDebug is a version of Build which takes an optional DebugOption for +// logging and visualization. +func BuildDebug(debugOpt DebugOption, option Option, outputs ...interface{}) error { + loc := LocationFromCaller(1) + return build(loc, debugOpt, option, outputs...) +} + +func build(loc Location, debugOpt DebugOption, option Option, outputs ...interface{}) error { + cfg, err := newDebugConfig() + if err != nil { + return err + } + + defer cfg.generateGraph() // always generate graph on exit + + if debugOpt != nil { + err = debugOpt.applyConfig(cfg) + if err != nil { + return err + } + } + + cfg.logf("Registering providers") + cfg.indentLogger() + ctr := newContainer(cfg) + err = option.apply(ctr) + if err != nil { + cfg.logf("Failed registering providers because of: %+v", err) + return err + } + cfg.dedentLogger() + + return ctr.build(loc, outputs...) +} diff --git a/container/container.go b/container/container.go index 1866bff401..deb52dc51f 100644 --- a/container/container.go +++ b/container/container.go @@ -364,20 +364,44 @@ func (c *container) resolve(in ProviderInput, moduleKey *moduleKey, caller Locat return res, nil } -func (c *container) run(invoker interface{}) error { - rctr, err := ExtractProviderDescriptor(invoker) +func (c *container) build(loc Location, outputs ...interface{}) error { + var providerIn []ProviderInput + for _, output := range outputs { + typ := reflect.TypeOf(output) + if typ.Kind() != reflect.Pointer { + return fmt.Errorf("output type must be a pointer, %s is invalid", typ) + } + + providerIn = append(providerIn, ProviderInput{Type: typ.Elem()}) + } + + desc := ProviderDescriptor{ + Inputs: providerIn, + Outputs: nil, + Fn: func(values []reflect.Value) ([]reflect.Value, error) { + if len(values) != len(outputs) { + return nil, fmt.Errorf("internal error, unexpected number of values") + } + + for i, output := range outputs { + val := reflect.ValueOf(output) + val.Elem().Set(values[i]) + } + + return nil, nil + }, + Location: loc, + } + + desc, err := expandStructArgsProvider(desc) if err != nil { return err } - if len(rctr.Outputs) > 0 { - return errors.Errorf("invoker function cannot have return values other than error: %s", rctr.Location) - } - - c.logf("Registering invoker") + c.logf("Registering outputs") c.indentLogger() - node, err := c.addNode(&rctr, nil) + node, err := c.addNode(&desc, nil) if err != nil { return err } diff --git a/container/container_test.go b/container/container_test.go index ce8494c52d..6cc8399cfb 100644 --- a/container/container_test.go +++ b/container/container_test.go @@ -82,29 +82,40 @@ func (ModuleB) Provide(dependencies BDependencies) (BProvides, Handler, error) { } func TestScenario(t *testing.T) { + var ( + handlers map[string]Handler + commands []Command + a KeeperA + b KeeperB + ) require.NoError(t, - container.Run( - func(handlers map[string]Handler, commands []Command, a KeeperA, b KeeperB) { - require.Len(t, handlers, 2) - require.Equal(t, Handler{}, handlers["a"]) - require.Equal(t, Handler{}, handlers["b"]) - require.Len(t, commands, 3) - require.Equal(t, KeeperA{ - key: KVStoreKey{name: "a"}, - name: "a", - }, a) - require.Equal(t, KeeperB{ - key: KVStoreKey{name: "b"}, - msgClientA: MsgClientA{ - key: "b", - }, - }, b) - }, - container.Provide(ProvideMsgClientA), - container.ProvideInModule("runtime", ProvideKVStoreKey), - container.ProvideInModule("a", wrapMethod0(ModuleA{})), - container.ProvideInModule("b", wrapMethod0(ModuleB{})), + container.Build( + container.Options( + container.Provide(ProvideMsgClientA), + container.ProvideInModule("runtime", ProvideKVStoreKey), + container.ProvideInModule("a", wrapMethod0(ModuleA{})), + container.ProvideInModule("b", wrapMethod0(ModuleB{})), + ), + &handlers, + &commands, + &a, + &b, )) + + require.Len(t, handlers, 2) + require.Equal(t, Handler{}, handlers["a"]) + require.Equal(t, Handler{}, handlers["b"]) + require.Len(t, commands, 3) + require.Equal(t, KeeperA{ + key: KVStoreKey{name: "a"}, + name: "a", + }, a) + require.Equal(t, KeeperB{ + key: KVStoreKey{name: "b"}, + msgClientA: MsgClientA{ + key: "b", + }, + }, b) } func wrapMethod0(module interface{}) interface{} { @@ -123,28 +134,30 @@ func wrapMethod0(module interface{}) interface{} { } func TestResolveError(t *testing.T) { - require.Error(t, container.Run( - func(x string) {}, + var x string + require.Error(t, container.Build( container.Provide( func(x float64) string { return fmt.Sprintf("%f", x) }, func(x int) float64 { return float64(x) }, func(x float32) int { return int(x) }, ), + &x, )) } func TestCyclic(t *testing.T) { - require.Error(t, container.Run( - func(x string) {}, + var x string + require.Error(t, container.Build( container.Provide( func(x int) float64 { return float64(x) }, func(x float64) (int, string) { return int(x), "hi" }, ), + &x, )) } func TestErrorOption(t *testing.T) { - err := container.Run(func() {}, container.Error(fmt.Errorf("an error"))) + err := container.Build(container.Error(fmt.Errorf("an error"))) require.Error(t, err) } @@ -153,11 +166,8 @@ func TestBadCtr(t *testing.T) { require.Error(t, err) } -func TestInvoker(t *testing.T) { - require.NoError(t, container.Run(func() {})) - require.NoError(t, container.Run(func() error { return nil })) - require.Error(t, container.Run(func() error { return fmt.Errorf("error") })) - require.Error(t, container.Run(func() int { return 0 })) +func TestTrivial(t *testing.T) { + require.NoError(t, container.Build(container.Options())) } func TestErrorFunc(t *testing.T) { @@ -171,119 +181,136 @@ func TestErrorFunc(t *testing.T) { ) require.NoError(t, err) + var x int require.Error(t, - container.Run( - func(x int) { - }, + container.Build( container.Provide(func() (int, error) { return 0, fmt.Errorf("the error") }), + &x, )) - - require.Error(t, - container.Run(func() error { - return fmt.Errorf("the error") - }), "the error") } func TestSimple(t *testing.T) { + var x int require.NoError(t, - container.Run( - func(x int) { - require.Equal(t, 1, x) - }, + container.Build( container.Provide( func() int { return 1 }, ), + &x, ), ) require.Error(t, - container.Run(func(int) {}, + container.Build( container.Provide( func() int { return 0 }, func() int { return 1 }, ), + &x, ), ) } func TestModuleScoped(t *testing.T) { + var x int require.Error(t, - container.Run(func(int) {}, + container.Build( container.Provide( func(container.ModuleKey) int { return 0 }, ), + &x, + ), + ) + + var y float64 + require.Error(t, + container.Build( + container.Options( + container.Provide( + func(container.ModuleKey) int { return 0 }, + func() int { return 1 }, + ), + container.ProvideInModule("a", + func(x int) float64 { return float64(x) }, + ), + ), + &y, ), ) require.Error(t, - container.Run(func(float64) {}, - container.Provide( - func(container.ModuleKey) int { return 0 }, - func() int { return 1 }, - ), - container.ProvideInModule("a", - func(x int) float64 { return float64(x) }, + container.Build( + container.Options( + container.Provide( + func() int { return 0 }, + func(container.ModuleKey) int { return 1 }, + ), + container.ProvideInModule("a", + func(x int) float64 { return float64(x) }, + ), ), + &y, ), ) require.Error(t, - container.Run(func(float64) {}, - container.Provide( - func() int { return 0 }, - func(container.ModuleKey) int { return 1 }, - ), - container.ProvideInModule("a", - func(x int) float64 { return float64(x) }, - ), - ), - ) - - require.Error(t, - container.Run(func(float64) {}, - container.Provide( - func(container.ModuleKey) int { return 0 }, - func(container.ModuleKey) int { return 1 }, - ), - container.ProvideInModule("a", - func(x int) float64 { return float64(x) }, + container.Build( + container.Options( + container.Provide( + func(container.ModuleKey) int { return 0 }, + func(container.ModuleKey) int { return 1 }, + ), + container.ProvideInModule("a", + func(x int) float64 { return float64(x) }, + ), ), + &y, ), ) require.NoError(t, - container.Run(func(float64) {}, - container.Provide( - func(container.ModuleKey) int { return 0 }, - ), - container.ProvideInModule("a", - func(x int) float64 { return float64(x) }, + container.Build( + container.Options( + container.Provide( + func(container.ModuleKey) int { return 0 }, + ), + container.ProvideInModule("a", + func(x int) float64 { return float64(x) }, + ), ), + &y, ), ) require.Error(t, - container.Run(func(float64) {}, - container.Provide( - func(container.ModuleKey) int { return 0 }, - ), - container.ProvideInModule("", - func(x int) float64 { return float64(x) }, + container.Build( + container.Options( + container.Provide( + func(container.ModuleKey) int { return 0 }, + ), + container.ProvideInModule("", + func(x int) float64 { return float64(x) }, + ), ), + &y, ), ) + var z float32 require.NoError(t, - container.Run(func(float64, float32) {}, - container.Provide( - func(container.ModuleKey) int { return 0 }, - ), - container.ProvideInModule("a", - func(x int) float64 { return float64(x) }, - func(x int) float32 { return float32(x) }, + container.Build( + container.Options( + container.Provide( + func(container.ModuleKey) int { return 0 }, + ), + container.ProvideInModule("a", + func(x int) float64 { return float64(x) }, + func(x int) float32 { return float32(x) }, + ), ), + &y, &z, ), "use module dep twice", ) @@ -294,72 +321,78 @@ type OnePerModuleInt int func (OnePerModuleInt) IsOnePerModuleType() {} func TestOnePerModule(t *testing.T) { + var x OnePerModuleInt require.Error(t, - container.Run( - func(OnePerModuleInt) {}, - ), + container.Build(container.Options(), &x), "bad input type", ) + var y map[string]OnePerModuleInt + var z string require.NoError(t, - container.Run( - func(x map[string]OnePerModuleInt, y string) { - require.Equal(t, map[string]OnePerModuleInt{ - "a": 3, - "b": 4, - }, x) - require.Equal(t, "7", y) - }, - container.ProvideInModule("a", - func() OnePerModuleInt { return 3 }, + container.Build( + container.Options( + container.ProvideInModule("a", + func() OnePerModuleInt { return 3 }, + ), + container.ProvideInModule("b", + func() OnePerModuleInt { return 4 }, + ), + container.Provide(func(x map[string]OnePerModuleInt) string { + sum := 0 + for _, v := range x { + sum += int(v) + } + return fmt.Sprintf("%d", sum) + }), ), - container.ProvideInModule("b", - func() OnePerModuleInt { return 4 }, - ), - container.Provide(func(x map[string]OnePerModuleInt) string { - sum := 0 - for _, v := range x { - sum += int(v) - } - return fmt.Sprintf("%d", sum) - }), + &y, + &z, ), ) + require.Equal(t, map[string]OnePerModuleInt{ + "a": 3, + "b": 4, + }, y) + require.Equal(t, "7", z) + + var m map[string]OnePerModuleInt require.Error(t, - container.Run( - func(map[string]OnePerModuleInt) {}, + container.Build( container.ProvideInModule("a", func() OnePerModuleInt { return 0 }, func() OnePerModuleInt { return 0 }, ), + &m, ), "duplicate", ) require.Error(t, - container.Run( - func(map[string]OnePerModuleInt) {}, + container.Build( container.Provide( func() OnePerModuleInt { return 0 }, ), + &m, ), "out of scope", ) require.Error(t, - container.Run( - func(map[string]OnePerModuleInt) {}, + container.Build( container.Provide( func() map[string]OnePerModuleInt { return nil }, ), + &m, ), "bad return type", ) require.NoError(t, - container.Run( - func(map[string]OnePerModuleInt) {}, + container.Build( + container.Options(), + &m, ), "no providers", ) @@ -370,14 +403,10 @@ type AutoGroupInt int func (AutoGroupInt) IsAutoGroupType() {} func TestAutoGroup(t *testing.T) { + var xs []AutoGroupInt + var sum string require.NoError(t, - container.Run( - func(xs []AutoGroupInt, sum string) { - require.Len(t, xs, 2) - require.Contains(t, xs, AutoGroupInt(4)) - require.Contains(t, xs, AutoGroupInt(9)) - require.Equal(t, "13", sum) - }, + container.Build( container.Provide( func() AutoGroupInt { return 4 }, func() AutoGroupInt { return 9 }, @@ -389,55 +418,71 @@ func TestAutoGroup(t *testing.T) { return fmt.Sprintf("%d", sum) }, ), + &xs, + &sum, ), ) + require.Len(t, xs, 2) + require.Contains(t, xs, AutoGroupInt(4)) + require.Contains(t, xs, AutoGroupInt(9)) + require.Equal(t, "13", sum) + var z AutoGroupInt require.Error(t, - container.Run( - func(AutoGroupInt) {}, + container.Build( container.Provide( func() AutoGroupInt { return 0 }, ), + &z, ), "bad input type", ) require.NoError(t, - container.Run( - func([]AutoGroupInt) {}, + container.Build( + container.Options(), + &xs, ), "no providers", ) } func TestSupply(t *testing.T) { + var x int require.NoError(t, - container.Run(func(x int) { - require.Equal(t, 3, x) - }, + container.Build( container.Supply(3), + &x, ), ) + require.Equal(t, 3, x) require.Error(t, - container.Run(func(x int) {}, - container.Supply(3), - container.Provide(func() int { return 4 }), + container.Build( + container.Options( + container.Supply(3), + container.Provide(func() int { return 4 }), + ), + &x, ), "can't supply then provide", ) require.Error(t, - container.Run(func(x int) {}, - container.Supply(3), - container.Provide(func() int { return 4 }), + container.Build( + container.Options( + container.Supply(3), + container.Provide(func() int { return 4 }), + ), + &x, ), "can't provide then supply", ) require.Error(t, - container.Run(func(x int) {}, + container.Build( container.Supply(3, 4), + &x, ), "can't supply twice", ) @@ -458,41 +503,39 @@ type TestOutput struct { } func TestStructArgs(t *testing.T) { - require.Error(t, container.Run( - func(input TestInput) {}, - )) + var input TestInput + require.Error(t, container.Build(container.Options(), &input)) - require.NoError(t, container.Run( - func(input TestInput) { - require.Equal(t, 0, input.X) - require.Equal(t, 1.3, input.Y) - }, + require.NoError(t, container.Build( container.Supply(1.3), + &input, )) + require.Equal(t, 0, input.X) + require.Equal(t, 1.3, input.Y) - require.NoError(t, container.Run( - func(input TestInput) { - require.Equal(t, 1, input.X) - require.Equal(t, 1.3, input.Y) - }, + require.NoError(t, container.Build( container.Supply(1.3, 1), + &input, )) + require.Equal(t, 1, input.X) + require.Equal(t, 1.3, input.Y) - require.NoError(t, container.Run( - func(x string, y int64) { - require.Equal(t, "A", x) - require.Equal(t, int64(-10), y) - }, + var x string + var y int64 + require.NoError(t, container.Build( container.Provide(func() (TestOutput, error) { return TestOutput{X: "A", Y: -10}, nil }), + &x, &y, )) + require.Equal(t, "A", x) + require.Equal(t, int64(-10), y) - require.Error(t, container.Run( - func(x string) {}, + require.Error(t, container.Build( container.Provide(func() (TestOutput, error) { return TestOutput{}, fmt.Errorf("error") }), + &x, )) } @@ -511,8 +554,7 @@ func TestLogging(t *testing.T) { require.NoError(t, err) defer os.Remove(graphfile.Name()) - require.NoError(t, container.RunDebug( - func() {}, + require.NoError(t, container.BuildDebug( container.DebugOptions( container.Logger(func(s string) { logOut += s @@ -524,6 +566,7 @@ func TestLogging(t *testing.T) { container.FileVisualizer(graphfile.Name(), "svg"), container.StdoutLogger(), ), + container.Options(), )) require.Contains(t, logOut, "digraph") diff --git a/container/run.go b/container/run.go deleted file mode 100644 index 12d1b9e2f1..0000000000 --- a/container/run.go +++ /dev/null @@ -1,44 +0,0 @@ -package container - -// Run runs the provided invoker function with values provided by the provided -// options. It is the single entry point for building and running a dependency -// injection container. Invoker should be a function taking one or more -// dependencies from the container, optionally returning an error. -// -// Ex: -// Run(func (x int) error { println(x) }, Provide(func() int { return 1 })) -func Run(invoker interface{}, opts ...Option) error { - return RunDebug(invoker, nil, opts...) -} - -// RunDebug is a version of Run which takes an optional DebugOption for -// logging and visualization. -func RunDebug(invoker interface{}, debugOpt DebugOption, opts ...Option) error { - opt := Options(opts...) - - cfg, err := newDebugConfig() - if err != nil { - return err - } - - defer cfg.generateGraph() // always generate graph on exit - - if debugOpt != nil { - err = debugOpt.applyConfig(cfg) - if err != nil { - return err - } - } - - cfg.logf("Registering providers") - cfg.indentLogger() - ctr := newContainer(cfg) - err = opt.apply(ctr) - if err != nil { - cfg.logf("Failed registering providers because of: %+v", err) - return err - } - cfg.dedentLogger() - - return ctr.run(invoker) -} diff --git a/container/struct_args.go b/container/struct_args.go index 985d950d46..e3e52128d1 100644 --- a/container/struct_args.go +++ b/container/struct_args.go @@ -36,12 +36,11 @@ type isOut interface{ isOut() } var isOutType = reflect.TypeOf((*isOut)(nil)).Elem() func expandStructArgsProvider(provider ProviderDescriptor) (ProviderDescriptor, error) { - var foundStructArgs bool + var structArgsInInput bool var newIn []ProviderInput - for _, in := range provider.Inputs { if in.Type.AssignableTo(isInType) { - foundStructArgs = true + structArgsInInput = true inTypes, err := structArgsInTypes(in.Type) if err != nil { return ProviderDescriptor{}, err @@ -52,17 +51,9 @@ func expandStructArgsProvider(provider ProviderDescriptor) (ProviderDescriptor, } } - var newOut []ProviderOutput - for _, out := range provider.Outputs { - if out.Type.AssignableTo(isOutType) { - foundStructArgs = true - newOut = append(newOut, structArgsOutTypes(out.Type)...) - } else { - newOut = append(newOut, out) - } - } + newOut, structArgsInOutput := expandStructArgsOutTypes(provider.Outputs) - if foundStructArgs { + if structArgsInInput || structArgsInOutput { return ProviderDescriptor{ Inputs: newIn, Outputs: newOut, @@ -137,6 +128,20 @@ func structArgsInTypes(typ reflect.Type) ([]ProviderInput, error) { return res, nil } +func expandStructArgsOutTypes(outputs []ProviderOutput) ([]ProviderOutput, bool) { + foundStructArgs := false + var newOut []ProviderOutput + for _, out := range outputs { + if out.Type.AssignableTo(isOutType) { + foundStructArgs = true + newOut = append(newOut, structArgsOutTypes(out.Type)...) + } else { + newOut = append(newOut, out) + } + } + return newOut, foundStructArgs +} + func structArgsOutTypes(typ reflect.Type) []ProviderOutput { n := typ.NumField() var res []ProviderOutput diff --git a/go.mod b/go.mod index 015429f404..4c742dd79a 100644 --- a/go.mod +++ b/go.mod @@ -60,6 +60,8 @@ require ( sigs.k8s.io/yaml v1.3.0 ) +require github.com/google/uuid v1.3.0 + require ( cloud.google.com/go v0.100.2 // indirect cloud.google.com/go/compute v1.5.0 // indirect @@ -93,7 +95,6 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/google/btree v1.0.1 // indirect github.com/google/orderedcode v0.0.1 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/googleapis/gax-go/v2 v2.3.0 // indirect github.com/gorilla/websocket v1.5.0 // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect