diff --git a/depinject/config.go b/depinject/config.go index ff947bc627..f43a4bac16 100644 --- a/depinject/config.go +++ b/depinject/config.go @@ -36,7 +36,7 @@ func ProvideInModule(moduleName string, providers ...interface{}) Config { return errors.Errorf("expected non-empty module name") } - return provide(ctr, ctr.createOrGetModuleKey(moduleName), providers) + return provide(ctr, ctr.moduleKeyContext.createOrGetModuleKey(moduleName), providers) }) } @@ -83,7 +83,7 @@ func InvokeInModule(moduleName string, invokers ...interface{}) Config { return errors.Errorf("expected non-empty module name") } - return invoke(ctr, ctr.createOrGetModuleKey(moduleName), invokers) + return invoke(ctr, ctr.moduleKeyContext.createOrGetModuleKey(moduleName), invokers) }) } diff --git a/depinject/container.go b/depinject/container.go index d2998f5858..148a8c28cc 100644 --- a/depinject/container.go +++ b/depinject/container.go @@ -17,7 +17,7 @@ type container struct { interfaceBindings map[string]interfaceBinding invokers []invoker - moduleKeys map[string]*moduleKey + moduleKeyContext *ModuleKeyContext resolveStack []resolveFrame callerStack []Location @@ -48,7 +48,7 @@ func newContainer(cfg *debugConfig) *container { return &container{ debugConfig: cfg, resolvers: map[string]resolver{}, - moduleKeys: map[string]*moduleKey{}, + moduleKeyContext: &ModuleKeyContext{}, interfaceBindings: map[string]interfaceBinding{}, callerStack: nil, callerMap: map[Location]bool{}, @@ -498,15 +498,6 @@ func (c *container) build(loc Location, outputs ...interface{}) error { return nil } -func (c container) createOrGetModuleKey(name string) *moduleKey { - if s, ok := c.moduleKeys[name]; ok { - return s - } - s := &moduleKey{name} - c.moduleKeys[name] = s - return s -} - func (c container) formatResolveStack() string { buf := &bytes.Buffer{} _, _ = fmt.Fprintf(buf, "\twhile resolving:\n") diff --git a/depinject/module_key.go b/depinject/module_key.go index ff8902a442..8fba64f943 100644 --- a/depinject/module_key.go +++ b/depinject/module_key.go @@ -25,10 +25,18 @@ type moduleKey struct { name string } +// Name returns the module key's name. func (k ModuleKey) Name() string { return k.name } +// Equals checks if the module key is equal to another module key. Module keys +// will be equal only if they have the same name and come from the same +// ModuleKeyContext. +func (k ModuleKey) Equals(other ModuleKey) bool { + return k.moduleKey == other.moduleKey +} + var moduleKeyType = reflect.TypeOf(ModuleKey{}) // OwnModuleKey is a type which can be used in a module to retrieve its own @@ -36,3 +44,34 @@ var moduleKeyType = reflect.TypeOf(ModuleKey{}) type OwnModuleKey ModuleKey var ownModuleKeyType = reflect.TypeOf((*OwnModuleKey)(nil)).Elem() + +// ModuleKeyContext defines a context for non-forgeable module keys. +// All module keys with the same name from the same context should be equal +// and module keys with the same name but from different contexts should be +// not equal. +// +// Usage: +// moduleKeyCtx := &ModuleKeyContext{} +// fooKey := moduleKeyCtx.For("foo") +type ModuleKeyContext struct { + moduleKeys map[string]*moduleKey +} + +// For returns a new or existing module key for the given name within the context. +func (c *ModuleKeyContext) For(moduleName string) ModuleKey { + return ModuleKey{c.createOrGetModuleKey(moduleName)} +} + +func (c *ModuleKeyContext) createOrGetModuleKey(moduleName string) *moduleKey { + if c.moduleKeys == nil { + c.moduleKeys = map[string]*moduleKey{} + } + + if k, ok := c.moduleKeys[moduleName]; ok { + return k + } + + k := &moduleKey{moduleName} + c.moduleKeys[moduleName] = k + return k +} diff --git a/depinject/module_key_test.go b/depinject/module_key_test.go new file mode 100644 index 0000000000..5a226c0e36 --- /dev/null +++ b/depinject/module_key_test.go @@ -0,0 +1,25 @@ +package depinject + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +func TestModuleKeyEquals(t *testing.T) { + ctx := &ModuleKeyContext{} + + fooKey := ctx.For("foo") + fooKey2 := ctx.For("foo") + // two foo keys from the same context should be equal + assert.Assert(t, fooKey.Equals(fooKey2)) + + barKey := ctx.For("bar") + // foo and bar keys should be not equal + assert.Assert(t, !fooKey.Equals(barKey)) + + ctx2 := &ModuleKeyContext{} + fooKeyFromAnotherCtx := ctx2.For("foo") + // foo keys from different context should be not equal + assert.Assert(t, !fooKey.Equals(fooKeyFromAnotherCtx)) +}