From fb78bbfbbb7cbd253fc6c40742d393defc761c1b Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Mon, 4 Oct 2021 16:36:41 -0400 Subject: [PATCH] feat: implement low-level dependency injection container (#9666) ## Description closes #9775 needs #9658 --- ### Author Checklist *All items are required. Please add a note to the item if the item is not applicable and please add links to any relevant follow up issues.* I have... - [ ] included the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title - [ ] added `!` to the type prefix if API or client breaking change - [ ] targeted the correct branch (see [PR Targeting](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#pr-targeting)) - [ ] provided a link to the relevant issue or specification - [ ] followed the guidelines for [building modules](https://github.com/cosmos/cosmos-sdk/blob/master/docs/building-modules) - [ ] included the necessary unit and integration [tests](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#testing) - [ ] added a changelog entry to `CHANGELOG.md` - [ ] included comments for [documenting Go code](https://blog.golang.org/godoc) - [ ] updated the relevant documentation or specification - [ ] reviewed "Files changed" and left comments if necessary - [ ] confirmed all CI checks have passed ### Reviewers Checklist *All items are required. Please add a note if the item is not applicable and please add your handle next to the items reviewed if you only reviewed selected items.* I have... - [ ] confirmed the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title - [ ] confirmed `!` in the type prefix if API or client breaking change - [ ] confirmed all author checklist items have been addressed - [ ] reviewed state machine logic - [ ] reviewed API design and naming - [ ] reviewed documentation is accurate - [ ] reviewed tests and test coverage - [ ] manually tested (if applicable) --- container/config.go | 173 +++++++++++ container/constructor_info.go | 22 -- container/container.go | 406 ++++++++++++++++++++++++++ container/container_test.go | 501 +++++++++++++++++++++++++++++--- container/errors.go | 12 + container/go.mod | 9 +- container/go.sum | 21 ++ container/group.go | 95 ++++++ container/location.go | 118 +++++++- container/one_per_scope.go | 106 +++++++ container/option.go | 169 +++++++++-- container/provider_desc.go | 104 +++++++ container/provider_desc_test.go | 106 +++++++ container/resolver.go | 14 + container/run.go | 28 +- container/scope.go | 18 +- container/scope_dep.go | 57 ++++ container/simple.go | 67 +++++ container/struct_args.go | 189 +++++++++++- container/supply.go | 31 ++ 20 files changed, 2139 insertions(+), 107 deletions(-) create mode 100644 container/config.go delete mode 100644 container/constructor_info.go create mode 100644 container/container.go create mode 100644 container/errors.go create mode 100644 container/group.go create mode 100644 container/one_per_scope.go create mode 100644 container/provider_desc.go create mode 100644 container/provider_desc_test.go create mode 100644 container/resolver.go create mode 100644 container/scope_dep.go create mode 100644 container/simple.go create mode 100644 container/supply.go diff --git a/container/config.go b/container/config.go new file mode 100644 index 0000000000..12437d7c00 --- /dev/null +++ b/container/config.go @@ -0,0 +1,173 @@ +package container + +import ( + "bytes" + "fmt" + "path/filepath" + "reflect" + + "github.com/pkg/errors" + + "github.com/goccy/go-graphviz" + "github.com/goccy/go-graphviz/cgraph" +) + +type config struct { + // logging + loggers []func(string) + indentStr string + + // graphing + graphviz *graphviz.Graphviz + graph *cgraph.Graph + visualizers []func(string) + logVisualizer bool +} + +func newConfig() (*config, error) { + g := graphviz.New() + graph, err := g.Graph() + if err != nil { + return nil, errors.Wrap(err, "error initializing graph") + } + + return &config{ + graphviz: g, + graph: graph, + }, nil +} + +func (c *config) indentLogger() { + c.indentStr = c.indentStr + " " +} + +func (c *config) dedentLogger() { + if len(c.indentStr) > 0 { + c.indentStr = c.indentStr[1:] + } +} + +func (c config) logf(format string, args ...interface{}) { + s := fmt.Sprintf(c.indentStr+format, args...) + for _, logger := range c.loggers { + logger(s) + } +} + +func (c *config) generateGraph() { + buf := &bytes.Buffer{} + err := c.graphviz.Render(c.graph, graphviz.XDOT, buf) + if err != nil { + c.logf("Error rendering DOT graph: %+v", err) + return + } + + dot := buf.String() + if c.logVisualizer { + c.logf("DOT Graph: %s", dot) + } + + for _, v := range c.visualizers { + v(dot) + } + + err = c.graph.Close() + if err != nil { + c.logf("Error closing graph: %+v", err) + return + } + + err = c.graphviz.Close() + if err != nil { + c.logf("Error closing graphviz: %+v", err) + } +} + +func (c *config) addFuncVisualizer(f func(string)) { + c.visualizers = append(c.visualizers, func(dot string) { + f(dot) + }) +} + +func (c *config) enableLogVisualizer() { + c.logVisualizer = true +} + +func (c *config) addFileVisualizer(filename string, format string) { + c.visualizers = append(c.visualizers, func(_ string) { + err := c.graphviz.RenderFilename(c.graph, graphviz.Format(format), filename) + if err != nil { + c.logf("Error saving graphviz file %s with format %s: %+v", filename, format, err) + } else { + path, err := filepath.Abs(filename) + if err == nil { + c.logf("Saved graph of container to %s", path) + } + } + }) +} + +func (c *config) locationGraphNode(location Location, scope Scope) (*cgraph.Node, error) { + graph := c.scopeSubGraph(scope) + node, found, err := c.findOrCreateGraphNode(graph, location.Name()) + if err != nil { + return nil, err + } + + if found { + return node, nil + } + + node.SetShape(cgraph.BoxShape) + node.SetColor("lightgrey") + return node, nil +} + +func (c *config) typeGraphNode(typ reflect.Type) (*cgraph.Node, error) { + node, found, err := c.findOrCreateGraphNode(c.graph, typ.String()) + if err != nil { + return nil, err + } + + if found { + return node, nil + } + + node.SetColor("lightgrey") + return node, err +} + +func (c *config) findOrCreateGraphNode(subGraph *cgraph.Graph, name string) (node *cgraph.Node, found bool, err error) { + node, err = c.graph.Node(name) + if err != nil { + return nil, false, errors.Wrapf(err, "error finding graph node %s", name) + } + + if node != nil { + return node, true, nil + } + + node, err = subGraph.CreateNode(name) + if err != nil { + return nil, false, errors.Wrapf(err, "error creating graph node %s", name) + } + + return node, false, nil +} + +func (c *config) scopeSubGraph(scope Scope) *cgraph.Graph { + graph := c.graph + if scope != nil { + gname := fmt.Sprintf("cluster_%s", scope.Name()) + graph = c.graph.SubGraph(gname, 1) + graph.SetLabel(fmt.Sprintf("Scope: %s", scope.Name())) + } + return graph +} + +func (c *config) addGraphEdge(from *cgraph.Node, to *cgraph.Node) { + _, err := c.graph.CreateEdge("", from, to) + if err != nil { + c.logf("error creating graph edge") + } +} diff --git a/container/constructor_info.go b/container/constructor_info.go deleted file mode 100644 index b3b5e7b500..0000000000 --- a/container/constructor_info.go +++ /dev/null @@ -1,22 +0,0 @@ -package container - -import "reflect" - -// ConstructorInfo defines a special constructor type that is defined by -// reflection. It should be passed as a value to the Provide function. -// Ex: -// option.Provide(ConstructorInfo{ ... }) -type ConstructorInfo struct { - // In defines the in parameter types to Fn. - In []reflect.Type - - // Out defines the out parameter types to Fn. - Out []reflect.Type - - // Fn defines the constructor function. - Fn func([]reflect.Value) []reflect.Value - - // Location defines the source code location to be used for this constructor - // in error messages. - Location Location -} diff --git a/container/container.go b/container/container.go new file mode 100644 index 0000000000..e528726d60 --- /dev/null +++ b/container/container.go @@ -0,0 +1,406 @@ +package container + +import ( + "bytes" + "fmt" + "reflect" + + "github.com/goccy/go-graphviz/cgraph" + "github.com/pkg/errors" +) + +type container struct { + *config + + resolvers map[reflect.Type]resolver + + scopes map[string]Scope + + resolveStack []resolveFrame + callerStack []Location + callerMap map[Location]bool +} + +type resolveFrame struct { + loc Location + typ reflect.Type +} + +func newContainer(cfg *config) *container { + return &container{ + config: cfg, + resolvers: map[reflect.Type]resolver{}, + scopes: map[string]Scope{}, + callerStack: nil, + callerMap: map[Location]bool{}, + } +} + +func (c *container) call(constructor *ProviderDescriptor, scope Scope) ([]reflect.Value, error) { + loc := constructor.Location + graphNode, err := c.locationGraphNode(loc, scope) + if err != nil { + return nil, err + } + markGraphNodeAsFailed(graphNode) + + if c.callerMap[loc] { + return nil, errors.Errorf("cyclic dependency: %s -> %s", loc.Name(), loc.Name()) + } + + c.callerMap[loc] = true + c.callerStack = append(c.callerStack, loc) + + c.logf("Resolving dependencies for %s", loc) + c.indentLogger() + inVals := make([]reflect.Value, len(constructor.Inputs)) + for i, in := range constructor.Inputs { + val, err := c.resolve(in, scope, loc) + if err != nil { + return nil, err + } + inVals[i] = val + } + c.dedentLogger() + c.logf("Calling %s", loc) + + delete(c.callerMap, loc) + c.callerStack = c.callerStack[0 : len(c.callerStack)-1] + + out, err := constructor.Fn(inVals) + if err != nil { + return nil, errors.Wrapf(err, "error calling constructor %s", loc) + } + + markGraphNodeAsUsed(graphNode) + + return out, nil +} + +func (c *container) getResolver(typ reflect.Type) (resolver, error) { + if vr, ok := c.resolvers[typ]; ok { + return vr, nil + } + + elemType := typ + if isAutoGroupSliceType(elemType) || isOnePerScopeMapType(elemType) { + elemType = elemType.Elem() + } + + var typeGraphNode *cgraph.Node + var err error + + if isAutoGroupType(elemType) { + c.logf("Registering resolver for auto-group type %v", elemType) + sliceType := reflect.SliceOf(elemType) + + typeGraphNode, err = c.typeGraphNode(sliceType) + if err != nil { + return nil, err + } + typeGraphNode.SetComment("auto-group") + + r := &groupResolver{ + typ: elemType, + sliceType: sliceType, + graphNode: typeGraphNode, + } + + c.resolvers[elemType] = r + c.resolvers[sliceType] = &sliceGroupResolver{r} + } else if isOnePerScopeType(elemType) { + c.logf("Registering resolver for one-per-scope type %v", elemType) + mapType := reflect.MapOf(stringType, elemType) + + typeGraphNode, err = c.typeGraphNode(mapType) + if err != nil { + return nil, err + } + typeGraphNode.SetComment("one-per-scope") + + r := &onePerScopeResolver{ + typ: elemType, + mapType: mapType, + providers: map[Scope]*simpleProvider{}, + idxMap: map[Scope]int{}, + graphNode: typeGraphNode, + } + + c.resolvers[elemType] = r + c.resolvers[mapType] = &mapOfOnePerScopeResolver{r} + } + + return c.resolvers[typ], nil +} + +func (c *container) addNode(constructor *ProviderDescriptor, scope Scope) (interface{}, error) { + constructorGraphNode, err := c.locationGraphNode(constructor.Location, scope) + if err != nil { + return nil, err + } + + hasScopeParam := false + for _, in := range constructor.Inputs { + typ := in.Type + if typ == scopeType { + hasScopeParam = true + } + + if isAutoGroupType(typ) { + return nil, fmt.Errorf("auto-group type %v can't be used as an input parameter", typ) + } else if isOnePerScopeType(typ) { + return nil, fmt.Errorf("one-per-scope type %v can't be used as an input parameter", typ) + } + + vr, err := c.getResolver(typ) + if err != nil { + return nil, err + } + + var typeGraphNode *cgraph.Node + if vr != nil { + typeGraphNode = vr.typeGraphNode() + } else { + typeGraphNode, err = c.typeGraphNode(typ) + if err != nil { + return nil, err + } + } + + c.addGraphEdge(typeGraphNode, constructorGraphNode) + } + + if scope != nil || !hasScopeParam { + c.logf("Registering %s", constructor.Location.String()) + c.indentLogger() + defer c.dedentLogger() + + sp := &simpleProvider{ + provider: constructor, + scope: scope, + } + + for i, out := range constructor.Outputs { + typ := out.Type + + // one-per-scope maps can't be used as a return type + if isOnePerScopeMapType(typ) { + return nil, fmt.Errorf("%v cannot be used as a return type because %v is a one-per-scope type", + typ, typ.Elem()) + } + + // auto-group slices of auto-group types + if isAutoGroupSliceType(typ) { + typ = typ.Elem() + } + + vr, err := c.getResolver(typ) + if err != nil { + return nil, err + } + + if vr != nil { + c.logf("Found resolver for %v: %T", typ, vr) + err := vr.addNode(sp, i) + if err != nil { + return nil, err + } + } else { + c.logf("Registering resolver for simple type %v", typ) + + typeGraphNode, err := c.typeGraphNode(typ) + if err != nil { + return nil, err + } + + vr = &simpleResolver{ + node: sp, + typ: typ, + graphNode: typeGraphNode, + } + c.resolvers[typ] = vr + } + + c.addGraphEdge(constructorGraphNode, vr.typeGraphNode()) + } + + return sp, nil + } else { + c.logf("Registering scope provider: %s", constructor.Location.String()) + c.indentLogger() + defer c.dedentLogger() + + node := &scopeDepProvider{ + provider: constructor, + calledForScope: map[Scope]bool{}, + valueMap: map[Scope][]reflect.Value{}, + } + + for i, out := range constructor.Outputs { + typ := out.Type + + c.logf("Registering resolver for scoped type %v", typ) + + existing, ok := c.resolvers[typ] + if ok { + return nil, errors.Errorf("duplicate provision of type %v by scoped provider %s\n\talready provided by %s", + typ, constructor.Location, existing.describeLocation()) + } + + typeGraphNode, err := c.typeGraphNode(typ) + if err != nil { + return reflect.Value{}, err + } + + c.resolvers[typ] = &scopeDepResolver{ + typ: typ, + idxInValues: i, + node: node, + valueMap: map[Scope]reflect.Value{}, + graphNode: typeGraphNode, + } + + c.addGraphEdge(constructorGraphNode, typeGraphNode) + } + + return node, nil + } +} + +func (c *container) supply(value reflect.Value, location Location) error { + typ := value.Type() + locGrapNode, err := c.locationGraphNode(location, nil) + if err != nil { + return err + } + markGraphNodeAsUsed(locGrapNode) + + typeGraphNode, err := c.typeGraphNode(typ) + if err != nil { + return err + } + + c.addGraphEdge(locGrapNode, typeGraphNode) + + if existing, ok := c.resolvers[typ]; ok { + return duplicateDefinitionError(typ, location, existing.describeLocation()) + } + + c.resolvers[typ] = &supplyResolver{ + typ: typ, + value: value, + loc: location, + graphNode: typeGraphNode, + } + + return nil +} + +func (c *container) resolve(in ProviderInput, scope Scope, caller Location) (reflect.Value, error) { + c.resolveStack = append(c.resolveStack, resolveFrame{loc: caller, typ: in.Type}) + + typeGraphNode, err := c.typeGraphNode(in.Type) + if err != nil { + return reflect.Value{}, err + } + + if in.Type == scopeType { + if scope == nil { + return reflect.Value{}, errors.Errorf("trying to resolve %T for %s but not inside of any scope", scope, caller) + } + c.logf("Providing Scope %s", scope.Name()) + markGraphNodeAsUsed(typeGraphNode) + return reflect.ValueOf(scope), nil + } + + vr, err := c.getResolver(in.Type) + if err != nil { + return reflect.Value{}, err + } + + if vr == nil { + if in.Optional { + c.logf("Providing zero value for optional dependency %v", in.Type) + return reflect.Zero(in.Type), nil + } + + markGraphNodeAsFailed(typeGraphNode) + return reflect.Value{}, errors.Errorf("can't resolve type %v for %s:\n%s", + in.Type, caller, c.formatResolveStack()) + } + + res, err := vr.resolve(c, scope, caller) + if err != nil { + markGraphNodeAsFailed(typeGraphNode) + return reflect.Value{}, err + } + + markGraphNodeAsUsed(typeGraphNode) + + c.resolveStack = c.resolveStack[:len(c.resolveStack)-1] + + return res, nil +} + +func (c *container) run(invoker interface{}) error { + rctr, err := ExtractProviderDescriptor(invoker) + if err != nil { + return err + } + + if len(rctr.Outputs) > 0 { + return errors.Errorf("invoker function cannot have return values other than error: %s", rctr.Location) + } + + c.logf("Registering invoker") + c.indentLogger() + + node, err := c.addNode(&rctr, nil) + if err != nil { + return err + } + + c.dedentLogger() + + sn, ok := node.(*simpleProvider) + if !ok { + return errors.Errorf("cannot run scoped provider as an invoker") + } + + c.logf("Building container") + _, err = sn.resolveValues(c) + if err != nil { + return err + } + c.logf("Done building container") + + return nil +} + +func (c container) createOrGetScope(name string) Scope { + if s, ok := c.scopes[name]; ok { + return s + } + s := newScope(name) + c.scopes[name] = s + return s +} + +func (c container) formatResolveStack() string { + buf := &bytes.Buffer{} + _, _ = fmt.Fprintf(buf, "\twhile resolving:\n") + n := len(c.resolveStack) + for i := n - 1; i >= 0; i-- { + rk := c.resolveStack[i] + _, _ = fmt.Fprintf(buf, "\t\t%v for %s\n", rk.typ, rk.loc) + } + return buf.String() +} + +func markGraphNodeAsUsed(node *cgraph.Node) { + node.SetColor("black") +} + +func markGraphNodeAsFailed(node *cgraph.Node) { + node.SetColor("red") +} diff --git a/container/container_test.go b/container/container_test.go index 091ab3ab29..9bb072c9cd 100644 --- a/container/container_test.go +++ b/container/container_test.go @@ -1,6 +1,9 @@ package container_test import ( + "fmt" + "io/ioutil" + "os" "reflect" "testing" @@ -32,16 +35,20 @@ type Handler struct { Handle func() } +func (Handler) IsOnePerScopeType() {} + type Command struct { Run func() } +func (Command) IsAutoGroupType() {} + func ProvideKVStoreKey(scope container.Scope) KVStoreKey { return KVStoreKey{name: scope.Name()} } -func ProvideModuleKey(scope container.Scope) ModuleKey { - return ModuleKey(scope.Name()) +func ProvideModuleKey(scope container.Scope) (ModuleKey, error) { + return ModuleKey(scope.Name()), nil } func ProvideMsgClientA(_ container.Scope, key ModuleKey) MsgClientA { @@ -57,72 +64,480 @@ func (ModuleA) Provide(key KVStoreKey) (KeeperA, Handler, Command) { type ModuleB struct{} type BDependencies struct { - container.StructArgs + container.In Key KVStoreKey A MsgClientA } type BProvides struct { + container.Out + KeeperB KeeperB - Handler Handler Commands []Command } -func (ModuleB) Provide(dependencies BDependencies) BProvides { +func (ModuleB) Provide(dependencies BDependencies, _ container.Scope) (BProvides, Handler, error) { return BProvides{ KeeperB: KeeperB{ key: dependencies.Key, msgClientA: dependencies.A, }, - Handler: Handler{}, Commands: []Command{{}, {}}, - } + }, Handler{}, nil } -func TestRun(t *testing.T) { - t.Skip("Expecting this test to fail for now") +func TestScenario(t *testing.T) { require.NoError(t, container.Run( - func(handlers map[container.Scope]Handler, commands []Command, a KeeperA, b KeeperB) { - // TODO: - // require one Handler for module a and a scopes - // require 3 commands - // require KeeperA have store key a - // require KeeperB have store key b and MsgClientA - }), - container.AutoGroupTypes(reflect.TypeOf(Command{})), - container.OnePerScopeTypes(reflect.TypeOf(Handler{})), + func(handlers map[string]Handler, commands []Command, a KeeperA, b KeeperB) { + require.Len(t, handlers, 2) + require.Equal(t, Handler{}, handlers["a"]) + require.Equal(t, Handler{}, handlers["b"]) + require.Len(t, commands, 3) + require.Equal(t, KeeperA{ + key: KVStoreKey{name: "a"}, + }, a) + require.Equal(t, KeeperB{ + key: KVStoreKey{name: "b"}, + msgClientA: MsgClientA{ + key: "b", + }, + }, b) + }, + container.Provide( + ProvideKVStoreKey, + ProvideModuleKey, + ProvideMsgClientA, + ), + container.ProvideWithScope("a", wrapMethod0(ModuleA{})), + container.ProvideWithScope("b", wrapMethod0(ModuleB{})), + )) +} + +func wrapMethod0(module interface{}) interface{} { + methodFn := reflect.TypeOf(module).Method(0).Func.Interface() + ctrInfo, err := container.ExtractProviderDescriptor(methodFn) + if err != nil { + panic(err) + } + + ctrInfo.Inputs = ctrInfo.Inputs[1:] + fn := ctrInfo.Fn + ctrInfo.Fn = func(values []reflect.Value) ([]reflect.Value, error) { + return fn(append([]reflect.Value{reflect.ValueOf(module)}, values...)) + } + return ctrInfo +} + +func TestResolveError(t *testing.T) { + require.Error(t, container.Run( + func(x string) {}, container.Provide( - ProvideKVStoreKey, - ProvideModuleKey, - ProvideMsgClientA, + func(x float64) string { return fmt.Sprintf("%f", x) }, + func(x int) float64 { return float64(x) }, + func(x float32) int { return int(x) }, + ), + )) +} + +func TestCyclic(t *testing.T) { + require.Error(t, container.Run( + func(x string) {}, + container.Provide( + func(x int) float64 { return float64(x) }, + func(x float64) (int, string) { return int(x), "hi" }, + ), + )) +} + +func TestErrorOption(t *testing.T) { + err := container.Run(func() {}, container.Error(fmt.Errorf("an error"))) + require.Error(t, err) +} + +func TestBadCtr(t *testing.T) { + _, err := container.ExtractProviderDescriptor(KeeperA{}) + require.Error(t, err) +} + +func TestInvoker(t *testing.T) { + require.NoError(t, container.Run(func() {})) + require.NoError(t, container.Run(func() error { return nil })) + require.Error(t, container.Run(func() error { return fmt.Errorf("error") })) + require.Error(t, container.Run(func() int { return 0 })) +} + +func TestErrorFunc(t *testing.T) { + _, err := container.ExtractProviderDescriptor( + func() (error, int) { return nil, 0 }, + ) + require.Error(t, err) + + _, err = container.ExtractProviderDescriptor( + func() (int, error) { return 0, nil }, + ) + require.NoError(t, err) + + require.Error(t, + container.Run( + func(x int) { + }, + container.Provide(func() (int, error) { + return 0, fmt.Errorf("the error") + }), + )) + + require.Error(t, + container.Run(func() error { + return fmt.Errorf("the error") + }), "the error") +} + +func TestSimple(t *testing.T) { + require.NoError(t, + container.Run( + func(x int) { + require.Equal(t, 1, x) + }, + container.Provide( + func() int { return 1 }, + ), + ), + ) + + require.Error(t, + container.Run(func(int) {}, + container.Provide( + func() int { return 0 }, + func() int { return 1 }, + ), ), - container.ProvideWithScope(container.NewScope("a"), wrapProvideMethod(ModuleA{})), - container.ProvideWithScope(container.NewScope("b"), wrapProvideMethod(ModuleB{})), ) } -func wrapProvideMethod(module interface{}) container.ConstructorInfo { - method := reflect.TypeOf(module).Method(0) - methodTy := method.Type - var in []reflect.Type - var out []reflect.Type +func TestScoped(t *testing.T) { + require.Error(t, + container.Run(func(int) {}, + container.Provide( + func(container.Scope) int { return 0 }, + ), + ), + ) - for i := 1; i < methodTy.NumIn(); i++ { - in = append(in, methodTy.In(i)) - } - for i := 0; i < methodTy.NumOut(); i++ { - out = append(out, methodTy.Out(i)) - } + require.Error(t, + container.Run(func(float64) {}, + container.Provide( + func(container.Scope) int { return 0 }, + func() int { return 1 }, + ), + container.ProvideWithScope("a", + func(x int) float64 { return float64(x) }, + ), + ), + ) - return container.ConstructorInfo{ - In: in, - Out: out, - Fn: func(values []reflect.Value) []reflect.Value { - values = append([]reflect.Value{reflect.ValueOf(module)}, values...) - return method.Func.Call(values) - }, - Location: container.LocationFromPC(method.Func.Pointer()), - } + require.Error(t, + container.Run(func(float64) {}, + container.Provide( + func() int { return 0 }, + func(container.Scope) int { return 1 }, + ), + container.ProvideWithScope("a", + func(x int) float64 { return float64(x) }, + ), + ), + ) + + require.Error(t, + container.Run(func(float64) {}, + container.Provide( + func(container.Scope) int { return 0 }, + func(container.Scope) int { return 1 }, + ), + container.ProvideWithScope("a", + func(x int) float64 { return float64(x) }, + ), + ), + ) + + require.NoError(t, + container.Run(func(float64) {}, + container.Provide( + func(container.Scope) int { return 0 }, + ), + container.ProvideWithScope("a", + func(x int) float64 { return float64(x) }, + ), + ), + ) + + require.Error(t, + container.Run(func(float64) {}, + container.Provide( + func(container.Scope) int { return 0 }, + ), + container.ProvideWithScope("", + func(x int) float64 { return float64(x) }, + ), + ), + ) + + require.NoError(t, + container.Run(func(float64, float32) {}, + container.Provide( + func(container.Scope) int { return 0 }, + ), + container.ProvideWithScope("a", + func(x int) float64 { return float64(x) }, + func(x int) float32 { return float32(x) }, + ), + ), + "use scope dep twice", + ) +} + +type OnePerScopeInt int + +func (OnePerScopeInt) IsOnePerScopeType() {} + +func TestOnePerScope(t *testing.T) { + require.Error(t, + container.Run( + func(OnePerScopeInt) {}, + ), + "bad input type", + ) + + require.NoError(t, + container.Run( + func(x map[string]OnePerScopeInt, y string) { + require.Equal(t, map[string]OnePerScopeInt{ + "a": 3, + "b": 4, + }, x) + require.Equal(t, "7", y) + }, + container.ProvideWithScope("a", + func() OnePerScopeInt { return 3 }, + ), + container.ProvideWithScope("b", + func() OnePerScopeInt { return 4 }, + ), + container.Provide(func(x map[string]OnePerScopeInt) string { + sum := 0 + for _, v := range x { + sum += int(v) + } + return fmt.Sprintf("%d", sum) + }), + ), + ) + + require.Error(t, + container.Run( + func(map[string]OnePerScopeInt) {}, + container.ProvideWithScope("a", + func() OnePerScopeInt { return 0 }, + func() OnePerScopeInt { return 0 }, + ), + ), + "duplicate", + ) + + require.Error(t, + container.Run( + func(map[string]OnePerScopeInt) {}, + container.Provide( + func() OnePerScopeInt { return 0 }, + ), + ), + "out of scope", + ) + + require.Error(t, + container.Run( + func(map[string]OnePerScopeInt) {}, + container.Provide( + func() map[string]OnePerScopeInt { return nil }, + ), + ), + "bad return type", + ) + + require.NoError(t, + container.Run( + func(map[string]OnePerScopeInt) {}, + ), + "no providers", + ) +} + +type AutoGroupInt int + +func (AutoGroupInt) IsAutoGroupType() {} + +func TestAutoGroup(t *testing.T) { + require.NoError(t, + container.Run( + func(xs []AutoGroupInt, sum string) { + require.Len(t, xs, 2) + require.Contains(t, xs, AutoGroupInt(4)) + require.Contains(t, xs, AutoGroupInt(9)) + require.Equal(t, "13", sum) + }, + container.Provide( + func() AutoGroupInt { return 4 }, + func() AutoGroupInt { return 9 }, + func(xs []AutoGroupInt) string { + sum := 0 + for _, x := range xs { + sum += int(x) + } + return fmt.Sprintf("%d", sum) + }, + ), + ), + ) + + require.Error(t, + container.Run( + func(AutoGroupInt) {}, + container.Provide( + func() AutoGroupInt { return 0 }, + ), + ), + "bad input type", + ) + + require.NoError(t, + container.Run( + func([]AutoGroupInt) {}, + ), + "no providers", + ) +} + +func TestSupply(t *testing.T) { + require.NoError(t, + container.Run(func(x int) { + require.Equal(t, 3, x) + }, + container.Supply(3), + ), + ) + + require.Error(t, + container.Run(func(x int) {}, + container.Supply(3), + container.Provide(func() int { return 4 }), + ), + "can't supply then provide", + ) + + require.Error(t, + container.Run(func(x int) {}, + container.Supply(3), + container.Provide(func() int { return 4 }), + ), + "can't provide then supply", + ) + + require.Error(t, + container.Run(func(x int) {}, + container.Supply(3, 4), + ), + "can't supply twice", + ) +} + +type TestInput struct { + container.In + + X int `optional:"true"` + Y float64 +} + +type TestOutput struct { + container.Out + + X string +} + +func TestStructArgs(t *testing.T) { + require.Error(t, container.Run( + func(input TestInput) {}, + )) + + require.NoError(t, container.Run( + func(input TestInput) { + require.Equal(t, 0, input.X) + require.Equal(t, 1.3, input.Y) + }, + container.Supply(1.3), + )) + + require.NoError(t, container.Run( + func(input TestInput) { + require.Equal(t, 1, input.X) + require.Equal(t, 1.3, input.Y) + }, + container.Supply(1.3, 1), + )) + + require.NoError(t, container.Run( + func(x string) { + require.Equal(t, "A", x) + }, + container.Provide(func() (TestOutput, error) { + return TestOutput{X: "A"}, nil + }), + )) + + require.Error(t, container.Run( + func(x string) {}, + container.Provide(func() (TestOutput, error) { + return TestOutput{}, fmt.Errorf("error") + }), + )) +} + +func TestLogging(t *testing.T) { + var logOut string + var dotGraph string + + outfile, err := ioutil.TempFile("", "out") + require.NoError(t, err) + stdout := os.Stdout + os.Stdout = outfile + defer func() { os.Stdout = stdout }() + defer os.Remove(outfile.Name()) + + graphfile, err := ioutil.TempFile("", "graph") + require.NoError(t, err) + defer os.Remove(graphfile.Name()) + + require.NoError(t, container.Run( + func() {}, + container.Logger(func(s string) { + logOut += s + }), + container.Visualizer(func(g string) { + dotGraph = g + }), + container.LogVisualizer(), + container.FileVisualizer(graphfile.Name(), "svg"), + container.StdoutLogger(), + )) + + require.Contains(t, logOut, "digraph") + require.Contains(t, dotGraph, "digraph") + + outfileContents, err := ioutil.ReadFile(outfile.Name()) + require.NoError(t, err) + require.Contains(t, string(outfileContents), "digraph") + + graphfileContents, err := ioutil.ReadFile(graphfile.Name()) + require.NoError(t, err) + require.Contains(t, string(graphfileContents), "= 0 { + idx = i + } + if i := strings.Index(function[idx:], "."); i >= 0 { + idx += i + } + pname, fname = function[:idx], function[idx+1:] + + // The package may be vendored. + if i := strings.Index(pname, _vendor); i > 0 { + pname = pname[i+len(_vendor):] + } + + // Package names are URL-encoded to avoid ambiguity in the case where the + // package name contains ".git". Otherwise, "foo/bar.git.MyFunction" would + // mean that "git" is the top-level function and "MyFunction" is embedded + // inside it. + if unescaped, err := url.QueryUnescape(pname); err == nil { + pname = unescaped + } + + return } diff --git a/container/one_per_scope.go b/container/one_per_scope.go new file mode 100644 index 0000000000..e153884dce --- /dev/null +++ b/container/one_per_scope.go @@ -0,0 +1,106 @@ +package container + +import ( + "fmt" + "reflect" + + "github.com/goccy/go-graphviz/cgraph" + + "github.com/pkg/errors" +) + +// OnePerScopeType marks a type which +// can have up to one value per scope. All of the values for a one-per-scope type T +// and their respective scopes, can be retrieved by declaring an input parameter map[string]T. +type OnePerScopeType interface { + // IsOnePerScopeType is a marker function just indicates that this is a one-per-scope type. + IsOnePerScopeType() +} + +var onePerScopeTypeType = reflect.TypeOf((*OnePerScopeType)(nil)).Elem() + +func isOnePerScopeType(t reflect.Type) bool { + return t.Implements(onePerScopeTypeType) +} + +func isOnePerScopeMapType(typ reflect.Type) bool { + return typ.Kind() == reflect.Map && isOnePerScopeType(typ.Elem()) && typ.Key().Kind() == reflect.String +} + +type onePerScopeResolver struct { + typ reflect.Type + mapType reflect.Type + providers map[Scope]*simpleProvider + idxMap map[Scope]int + resolved bool + values reflect.Value + graphNode *cgraph.Node +} + +type mapOfOnePerScopeResolver struct { + *onePerScopeResolver +} + +func (o *onePerScopeResolver) resolve(_ *container, _ Scope, _ Location) (reflect.Value, error) { + return reflect.Value{}, errors.Errorf("%v is a one-per-scope type and thus can't be used as an input parameter, instead use %v", o.typ, o.mapType) +} + +func (o *onePerScopeResolver) describeLocation() string { + return fmt.Sprintf("one-per-scope type %v", o.typ) +} + +func (o *mapOfOnePerScopeResolver) resolve(c *container, _ Scope, caller Location) (reflect.Value, error) { + // Log + c.logf("Providing one-per-scope type map %v to %s from:", o.mapType, caller.Name()) + c.indentLogger() + for scope, node := range o.providers { + c.logf("%s: %s", scope.Name(), node.provider.Location) + } + c.dedentLogger() + + // Resolve + if !o.resolved { + res := reflect.MakeMap(o.mapType) + for scope, node := range o.providers { + values, err := node.resolveValues(c) + if err != nil { + return reflect.Value{}, err + } + idx := o.idxMap[scope] + if len(values) < idx { + return reflect.Value{}, errors.Errorf("expected value of type %T at index %d", o.typ, idx) + } + value := values[idx] + res.SetMapIndex(reflect.ValueOf(scope.Name()), value) + } + + o.values = res + o.resolved = true + } + + return o.values, nil +} + +func (o *onePerScopeResolver) addNode(n *simpleProvider, i int) error { + if n.scope == nil { + return errors.Errorf("cannot define a constructor with one-per-scope dependency %v which isn't provided in a scope", o.typ) + } + + if existing, ok := o.providers[n.scope]; ok { + return errors.Errorf("duplicate provision for one-per-scope type %v in scope %s: %s\n\talready provided by %s", + o.typ, n.scope.Name(), n.provider.Location, existing.provider.Location) + } + + o.providers[n.scope] = n + o.idxMap[n.scope] = i + + return nil +} + +func (o *mapOfOnePerScopeResolver) addNode(s *simpleProvider, _ int) error { + return errors.Errorf("%v is a one-per-scope type and thus %v can't be used as an output parameter in %s", o.typ, o.mapType, s.provider.Location) +} + +func (o onePerScopeResolver) typeGraphNode() *cgraph.Node { + return o.graphNode +} diff --git a/container/option.go b/container/option.go index e138872864..9388ab6f64 100644 --- a/container/option.go +++ b/container/option.go @@ -1,10 +1,17 @@ package container -import "reflect" +import ( + "fmt" + "os" + "reflect" + + "github.com/pkg/errors" +) // Option is a functional option for a container. type Option interface { - isOption() + applyConfig(*config) error + applyContainer(*container) error } // Provide creates a container option which registers the provided dependency @@ -12,39 +19,163 @@ type Option interface { // exception of scoped constructors which are called at most once per scope // (see Scope). func Provide(constructors ...interface{}) Option { - panic("TODO") + return containerOption(func(ctr *container) error { + return provide(ctr, nil, constructors) + }) } // ProvideWithScope creates a container option which registers the provided dependency // injection constructors that are to be run in the provided scope. Each constructor // will be called at most once. -func ProvideWithScope(scope Scope, constructors ...interface{}) Option { - panic("TODO") +func ProvideWithScope(scopeName string, constructors ...interface{}) Option { + return containerOption(func(ctr *container) error { + if scopeName == "" { + return errors.Errorf("expected non-empty scope name") + } + + return provide(ctr, ctr.createOrGetScope(scopeName), constructors) + }) } -// AutoGroupTypes creates an option which registers the provided types as types which -// will automatically get grouped together. For a given type T, T and []T can -// be declared as output parameters for constructors as many times within the container -// as desired. All of the provided values for T can be retrieved by declaring an -// []T input parameter. -func AutoGroupTypes(types ...reflect.Type) Option { - panic("TODO") +func provide(ctr *container, scope Scope, constructors []interface{}) error { + for _, c := range constructors { + rc, err := ExtractProviderDescriptor(c) + if err != nil { + return errors.WithStack(err) + } + _, err = ctr.addNode(&rc, scope) + if err != nil { + return errors.WithStack(err) + } + } + return nil } -// OnePerScopeTypes creates an option which registers the provided types as types which -// can have up to one value per scope. All of the values for a one-per-scope type T -// and their respective scopes, can be retrieved by declaring an input parameter map[Scope]T. -func OnePerScopeTypes(types ...reflect.Type) Option { - panic("TODO") +func Supply(values ...interface{}) Option { + loc := LocationFromCaller(1) + return containerOption(func(ctr *container) error { + for _, v := range values { + err := ctr.supply(reflect.ValueOf(v), loc) + if err != nil { + return errors.WithStack(err) + } + } + return nil + }) +} + +// Logger creates an option which provides a logger function which will +// receive all log messages from the container. +func Logger(logger func(string)) Option { + return configOption(func(c *config) error { + logger("Initializing logger") + c.loggers = append(c.loggers, logger) + return nil + }) +} + +func StdoutLogger() Option { + return Logger(func(s string) { + _, _ = fmt.Fprintln(os.Stdout, s) + }) +} + +// Visualizer creates an option which provides a visualizer function which +// will receive a rendering of the container in the Graphiz DOT format +// whenever the container finishes building or fails due to an error. The +// graph is color-coded to aid debugging. +func Visualizer(visualizer func(dotGraph string)) Option { + return configOption(func(c *config) error { + c.addFuncVisualizer(visualizer) + return nil + }) +} + +func LogVisualizer() Option { + return configOption(func(c *config) error { + c.enableLogVisualizer() + return nil + }) +} + +func FileVisualizer(filename, format string) Option { + return configOption(func(c *config) error { + c.addFileVisualizer(filename, format) + return nil + }) +} + +func Debug() Option { + return Options( + StdoutLogger(), + LogVisualizer(), + FileVisualizer("container_dump.svg", "svg"), + ) } // Error creates an option which causes the dependency injection container to // fail immediately. func Error(err error) Option { - panic("TODO") + return configOption(func(*config) error { + return errors.WithStack(err) + }) } // Options creates an option which bundles together other options. func Options(opts ...Option) Option { - panic("TODO") + return option{ + configOption: func(cfg *config) error { + for _, opt := range opts { + err := opt.applyConfig(cfg) + if err != nil { + return errors.WithStack(err) + } + } + return nil + }, + containerOption: func(ctr *container) error { + for _, opt := range opts { + err := opt.applyContainer(ctr) + if err != nil { + return errors.WithStack(err) + } + } + return nil + }, + } } + +type configOption func(*config) error + +func (c configOption) applyConfig(cfg *config) error { + return c(cfg) +} + +func (c configOption) applyContainer(*container) error { + return nil +} + +type containerOption func(*container) error + +func (c containerOption) applyConfig(*config) error { + return nil +} + +func (c containerOption) applyContainer(ctr *container) error { + return c(ctr) +} + +type option struct { + configOption + containerOption +} + +func (o option) applyConfig(c *config) error { + return o.configOption(c) +} + +func (o option) applyContainer(c *container) error { + return o.containerOption(c) +} + +var _, _, _ Option = (*configOption)(nil), (*containerOption)(nil), option{} diff --git a/container/provider_desc.go b/container/provider_desc.go new file mode 100644 index 0000000000..76e1cd2ca8 --- /dev/null +++ b/container/provider_desc.go @@ -0,0 +1,104 @@ +package container + +import ( + "reflect" + + "github.com/pkg/errors" +) + +// ProviderDescriptor defines a special constructor type that is defined by +// reflection. It should be passed as a value to the Provide function. +// Ex: +// option.Provide(ProviderDescriptor{ ... }) +type ProviderDescriptor struct { + // Inputs defines the in parameter types to Fn. + Inputs []ProviderInput + + // Outputs defines the out parameter types to Fn. + Outputs []ProviderOutput + + // Fn defines the constructor function. + Fn func([]reflect.Value) ([]reflect.Value, error) + + // Location defines the source code location to be used for this constructor + // in error messages. + Location Location +} + +type ProviderInput struct { + Type reflect.Type + Optional bool +} + +type ProviderOutput struct { + Type reflect.Type +} + +func ExtractProviderDescriptor(provider interface{}) (ProviderDescriptor, error) { + rctr, ok := provider.(ProviderDescriptor) + if !ok { + var err error + rctr, err = doExtractProviderDescriptor(provider) + if err != nil { + return ProviderDescriptor{}, err + } + } + + return expandStructArgsConstructor(rctr) +} + +func doExtractProviderDescriptor(ctr interface{}) (ProviderDescriptor, error) { + val := reflect.ValueOf(ctr) + typ := val.Type() + if typ.Kind() != reflect.Func { + return ProviderDescriptor{}, errors.Errorf("expected a Func type, got %v", typ) + } + + loc := LocationFromPC(val.Pointer()) + + if typ.IsVariadic() { + return ProviderDescriptor{}, errors.Errorf("variadic function can't be used as a constructor: %s", loc) + } + + numIn := typ.NumIn() + in := make([]ProviderInput, numIn) + for i := 0; i < numIn; i++ { + in[i] = ProviderInput{ + Type: typ.In(i), + } + } + + errIdx := -1 + numOut := typ.NumOut() + var out []ProviderOutput + for i := 0; i < numOut; i++ { + t := typ.Out(i) + if t == errType { + if i != numOut-1 { + return ProviderDescriptor{}, errors.Errorf("output error parameter is not last parameter in function %s", loc) + } + errIdx = i + } else { + out = append(out, ProviderOutput{Type: t}) + } + } + + return ProviderDescriptor{ + Inputs: in, + Outputs: out, + Fn: func(values []reflect.Value) ([]reflect.Value, error) { + res := val.Call(values) + if errIdx >= 0 { + err := res[errIdx] + if !err.IsZero() { + return nil, err.Interface().(error) + } + return res[0:errIdx], nil + } + return res, nil + }, + Location: loc, + }, nil +} + +var errType = reflect.TypeOf((*error)(nil)).Elem() diff --git a/container/provider_desc_test.go b/container/provider_desc_test.go new file mode 100644 index 0000000000..b7dcae8964 --- /dev/null +++ b/container/provider_desc_test.go @@ -0,0 +1,106 @@ +package container_test + +import ( + "reflect" + "testing" + + "github.com/cosmos/cosmos-sdk/container" +) + +type StructIn struct { + container.In + X int + Y float64 `optional:"true"` +} + +type BadOptional struct { + container.In + X int `optional:"foo"` +} + +type StructOut struct { + container.Out + X string + Y []byte +} + +func TestExtractConstructorInfo(t *testing.T) { + var ( + intType = reflect.TypeOf(0) + int16Type = reflect.TypeOf(int16(0)) + int32Type = reflect.TypeOf(int32(0)) + float32Type = reflect.TypeOf(float32(0.0)) + float64Type = reflect.TypeOf(0.0) + stringType = reflect.TypeOf("") + byteTyp = reflect.TypeOf(byte(0)) + bytesTyp = reflect.TypeOf([]byte{}) + ) + + tests := []struct { + name string + ctr interface{} + wantIn []container.ProviderInput + wantOut []container.ProviderOutput + wantErr bool + }{ + { + "simple args", + func(x int, y float64) (string, []byte) { return "", nil }, + []container.ProviderInput{{Type: intType}, {Type: float64Type}}, + []container.ProviderOutput{{Type: stringType}, {Type: bytesTyp}}, + false, + }, + { + "simple args with error", + func(x int, y float64) (string, []byte, error) { return "", nil, nil }, + []container.ProviderInput{{Type: intType}, {Type: float64Type}}, + []container.ProviderOutput{{Type: stringType}, {Type: bytesTyp}}, + false, + }, + { + "struct in and out", + func(_ float32, _ StructIn, _ byte) (int16, StructOut, int32, error) { + return int16(0), StructOut{}, int32(0), nil + }, + []container.ProviderInput{{Type: float32Type}, {Type: intType}, {Type: float64Type, Optional: true}, {Type: byteTyp}}, + []container.ProviderOutput{{Type: int16Type}, {Type: stringType}, {Type: bytesTyp}, {Type: int32Type}}, + false, + }, + { + "error bad position", + func() (error, int) { return nil, 0 }, + nil, + nil, + true, + }, + { + "bad optional", + func(_ BadOptional) int { return 0 }, + nil, + nil, + true, + }, + { + "variadic", + func(...float64) int { return 0 }, + nil, + nil, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := container.ExtractProviderDescriptor(tt.ctr) + if (err != nil) != tt.wantErr { + t.Errorf("ExtractProviderDescriptor() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got.Inputs, tt.wantIn) { + t.Errorf("ExtractProviderDescriptor() got = %v, want %v", got.Inputs, tt.wantIn) + } + if !reflect.DeepEqual(got.Outputs, tt.wantOut) { + t.Errorf("ExtractProviderDescriptor() got = %v, want %v", got.Outputs, tt.wantOut) + } + }) + } +} diff --git a/container/resolver.go b/container/resolver.go new file mode 100644 index 0000000000..4549f6ba3e --- /dev/null +++ b/container/resolver.go @@ -0,0 +1,14 @@ +package container + +import ( + "reflect" + + "github.com/goccy/go-graphviz/cgraph" +) + +type resolver interface { + addNode(*simpleProvider, int) error + resolve(*container, Scope, Location) (reflect.Value, error) + describeLocation() string + typeGraphNode() *cgraph.Node +} diff --git a/container/run.go b/container/run.go index 9c410b76a2..0e680d45fb 100644 --- a/container/run.go +++ b/container/run.go @@ -1,7 +1,5 @@ package container -import "fmt" - // Run runs the provided invoker function with values provided by the provided // options. It is the single entry point for building and running a dependency // injection container. Invoker should be a function taking one or more @@ -10,5 +8,29 @@ import "fmt" // Ex: // Run(func (x int) error { println(x) }, Provide(func() int { return 1 })) func Run(invoker interface{}, opts ...Option) error { - return fmt.Errorf("not implemented") + opt := Options(opts...) + + cfg, err := newConfig() + if err != nil { + return err + } + + defer cfg.generateGraph() // always generate graph on exit + + err = opt.applyConfig(cfg) + if err != nil { + return err + } + + cfg.logf("Registering providers") + cfg.indentLogger() + ctr := newContainer(cfg) + err = opt.applyContainer(ctr) + if err != nil { + cfg.logf("Failed registering providers because of: %+v", err) + return err + } + cfg.dedentLogger() + + return ctr.run(invoker) } diff --git a/container/scope.go b/container/scope.go index f854201ef6..52a087e507 100644 --- a/container/scope.go +++ b/container/scope.go @@ -1,14 +1,18 @@ package container +import ( + "reflect" +) + // Scope is a special type used to define a provider scope. // // Special scoped constructors can be used with Provide by declaring a -// constructor with its first input parameter of type Scope. These constructors +// constructor with an input parameter of type Scope. These constructors // should construct an unique value for each dependency based on scope and will // be called at most once per scope. // // Constructors passed to ProvideWithScope can also declare an input parameter -// of type Scope to retrieve their scope. +// of type Scope to retrieve their scope but these constructors will be called at most once. type Scope interface { isScope() @@ -18,7 +22,7 @@ type Scope interface { // NewScope creates a new scope with the provided name. Only one scope with a // given name can be created per container. -func NewScope(name string) Scope { +func newScope(name string) Scope { return &scope{name: name} } @@ -26,8 +30,12 @@ type scope struct { name string } -func (s *scope) isScope() {} - func (s *scope) Name() string { return s.name } + +func (s *scope) isScope() {} + +var scopeType = reflect.TypeOf((*Scope)(nil)).Elem() + +var stringType = reflect.TypeOf("") diff --git a/container/scope_dep.go b/container/scope_dep.go new file mode 100644 index 0000000000..e136581403 --- /dev/null +++ b/container/scope_dep.go @@ -0,0 +1,57 @@ +package container + +import ( + "reflect" + + "github.com/goccy/go-graphviz/cgraph" +) + +type scopeDepProvider struct { + provider *ProviderDescriptor + calledForScope map[Scope]bool + valueMap map[Scope][]reflect.Value +} + +type scopeDepResolver struct { + typ reflect.Type + idxInValues int + node *scopeDepProvider + valueMap map[Scope]reflect.Value + graphNode *cgraph.Node +} + +func (s scopeDepResolver) describeLocation() string { + return s.node.provider.Location.String() +} + +func (s scopeDepResolver) resolve(ctr *container, scope Scope, caller Location) (reflect.Value, error) { + // Log + ctr.logf("Providing %v from %s to %s", s.typ, s.node.provider.Location, caller.Name()) + + // Resolve + if val, ok := s.valueMap[scope]; ok { + return val, nil + } + + if !s.node.calledForScope[scope] { + values, err := ctr.call(s.node.provider, scope) + if err != nil { + return reflect.Value{}, err + } + + s.node.valueMap[scope] = values + s.node.calledForScope[scope] = true + } + + value := s.node.valueMap[scope][s.idxInValues] + s.valueMap[scope] = value + return value, nil +} + +func (s scopeDepResolver) addNode(p *simpleProvider, _ int) error { + return duplicateDefinitionError(s.typ, p.provider.Location, s.node.provider.Location.String()) +} + +func (s scopeDepResolver) typeGraphNode() *cgraph.Node { + return s.graphNode +} diff --git a/container/simple.go b/container/simple.go new file mode 100644 index 0000000000..61bc3780a8 --- /dev/null +++ b/container/simple.go @@ -0,0 +1,67 @@ +package container + +import ( + "reflect" + + "github.com/goccy/go-graphviz/cgraph" +) + +type simpleProvider struct { + provider *ProviderDescriptor + called bool + values []reflect.Value + scope Scope +} + +type simpleResolver struct { + node *simpleProvider + idxInValues int + resolved bool + typ reflect.Type + value reflect.Value + graphNode *cgraph.Node +} + +func (s *simpleResolver) describeLocation() string { + return s.node.provider.Location.String() +} + +func (s *simpleProvider) resolveValues(ctr *container) ([]reflect.Value, error) { + if !s.called { + values, err := ctr.call(s.provider, s.scope) + if err != nil { + return nil, err + } + s.values = values + s.called = true + } + + return s.values, nil +} + +func (s *simpleResolver) resolve(c *container, _ Scope, caller Location) (reflect.Value, error) { + // Log + c.logf("Providing %v from %s to %s", s.typ, s.node.provider.Location, caller.Name()) + + // Resolve + if !s.resolved { + values, err := s.node.resolveValues(c) + if err != nil { + return reflect.Value{}, err + } + + value := values[s.idxInValues] + s.value = value + s.resolved = true + } + + return s.value, nil +} + +func (s simpleResolver) addNode(p *simpleProvider, _ int) error { + return duplicateDefinitionError(s.typ, p.provider.Location, s.node.provider.Location.String()) +} + +func (s simpleResolver) typeGraphNode() *cgraph.Node { + return s.graphNode +} diff --git a/container/struct_args.go b/container/struct_args.go index 8c1ae72e82..49df400c40 100644 --- a/container/struct_args.go +++ b/container/struct_args.go @@ -1,11 +1,184 @@ package container -// StructArgs is a type which can be embedded in another struct to alert the -// container that the fields of the struct are dependency inputs/outputs. That -// is, the container will not look to resolve a value with StructArgs embedded -// directly, but will instead use the struct's fields to resolve or populate -// dependencies. Types with embedded StructArgs can be used in both the input -// and output parameter positions. -type StructArgs struct{} +import ( + "reflect" -func (StructArgs) isStructArgs() {} + "github.com/pkg/errors" +) + +// In can be embedded in another struct to inform the container that the +// fields of the struct should be treated as dependency inputs. +// This allows a struct to be used to specify dependencies rather than +// positional parameters. +// +// Fields of the struct may support the following tags: +// optional if set to true, the dependency is optional and will +// be set to its default value if not found, rather than causing +// an error +type In struct{} + +func (In) isIn() {} + +type isIn interface{ isIn() } + +var isInType = reflect.TypeOf((*isIn)(nil)).Elem() + +// Out can be embedded in another struct to inform the container that the +// fields of the struct should be treated as dependency outputs. +// This allows a struct to be used to specify outputs rather than +// positional return values. +type Out struct{} + +func (Out) isOut() {} + +type isOut interface{ isOut() } + +var isOutType = reflect.TypeOf((*isOut)(nil)).Elem() + +func expandStructArgsConstructor(constructor ProviderDescriptor) (ProviderDescriptor, error) { + var foundStructArgs bool + var newIn []ProviderInput + + for _, in := range constructor.Inputs { + if in.Type.AssignableTo(isInType) { + foundStructArgs = true + inTypes, err := structArgsInTypes(in.Type) + if err != nil { + return ProviderDescriptor{}, err + } + newIn = append(newIn, inTypes...) + } else { + newIn = append(newIn, in) + } + } + + var newOut []ProviderOutput + for _, out := range constructor.Outputs { + if out.Type.AssignableTo(isOutType) { + foundStructArgs = true + newOut = append(newOut, structArgsOutTypes(out.Type)...) + } else { + newOut = append(newOut, out) + } + } + + if foundStructArgs { + return ProviderDescriptor{ + Inputs: newIn, + Outputs: newOut, + Fn: expandStructArgsFn(constructor), + Location: constructor.Location, + }, nil + } + + return constructor, nil +} + +func expandStructArgsFn(constructor ProviderDescriptor) func(inputs []reflect.Value) ([]reflect.Value, error) { + fn := constructor.Fn + inParams := constructor.Inputs + outParams := constructor.Outputs + return func(inputs []reflect.Value) ([]reflect.Value, error) { + j := 0 + inputs1 := make([]reflect.Value, len(inParams)) + for i, in := range inParams { + if in.Type.AssignableTo(isInType) { + v, n := buildIn(in.Type, inputs[j:]) + inputs1[i] = v + j += n + } else { + inputs1[i] = inputs[j] + j++ + } + } + + outputs, err := fn(inputs1) + if err != nil { + return nil, err + } + + var outputs1 []reflect.Value + for i, out := range outParams { + if out.Type.AssignableTo(isOutType) { + outputs1 = append(outputs1, extractFromOut(out.Type, outputs[i])...) + } else { + outputs1 = append(outputs1, outputs[i]) + } + } + + return outputs1, nil + } +} + +func structArgsInTypes(typ reflect.Type) ([]ProviderInput, error) { + n := typ.NumField() + var res []ProviderInput + for i := 0; i < n; i++ { + f := typ.Field(i) + if f.Type.AssignableTo(isInType) { + continue + } + + var optional bool + optTag, found := f.Tag.Lookup("optional") + if found { + if optTag == "true" { + optional = true + } else { + return nil, errors.Errorf("bad optional tag %q (should be \"true\") in %v", optTag, typ) + } + } + + res = append(res, ProviderInput{ + Type: f.Type, + Optional: optional, + }) + } + return res, nil +} + +func structArgsOutTypes(typ reflect.Type) []ProviderOutput { + n := typ.NumField() + var res []ProviderOutput + for i := 0; i < n; i++ { + f := typ.Field(i) + if f.Type.AssignableTo(isOutType) { + continue + } + + res = append(res, ProviderOutput{ + Type: f.Type, + }) + } + return res +} + +func buildIn(typ reflect.Type, values []reflect.Value) (reflect.Value, int) { + numFields := typ.NumField() + j := 0 + res := reflect.New(typ) + for i := 0; i < numFields; i++ { + f := typ.Field(i) + if f.Type.AssignableTo(isInType) { + continue + } + + res.Elem().Field(i).Set(values[j]) + j++ + } + return res.Elem(), j +} + +func extractFromOut(typ reflect.Type, value reflect.Value) []reflect.Value { + numFields := typ.NumField() + var res []reflect.Value + for i := 0; i < numFields; i++ { + f := typ.Field(i) + if f.Type.AssignableTo(isOutType) { + continue + } + + res = append(res, value.Field(i)) + } + return res +} diff --git a/container/supply.go b/container/supply.go new file mode 100644 index 0000000000..eec99f0c99 --- /dev/null +++ b/container/supply.go @@ -0,0 +1,31 @@ +package container + +import ( + "reflect" + + "github.com/goccy/go-graphviz/cgraph" +) + +type supplyResolver struct { + typ reflect.Type + value reflect.Value + loc Location + graphNode *cgraph.Node +} + +func (s supplyResolver) describeLocation() string { + return s.loc.String() +} + +func (s supplyResolver) addNode(provider *simpleProvider, _ int) error { + return duplicateDefinitionError(s.typ, provider.provider.Location, s.loc.String()) +} + +func (s supplyResolver) resolve(c *container, _ Scope, caller Location) (reflect.Value, error) { + c.logf("Supplying %v from %s to %s", s.typ, s.loc, caller.Name()) + return s.value, nil +} + +func (s supplyResolver) typeGraphNode() *cgraph.Node { + return s.graphNode +}