From d025cf09f801be261af3e46f0fff692ebbdc9492 Mon Sep 17 00:00:00 2001 From: Matt Kocubinski Date: Thu, 7 Jul 2022 16:40:32 -0500 Subject: [PATCH] fix(depinject): move non-thread safe write (#12484) * fix(depinject): move non-thread safe write * remove whitespace * Push invoker descriptor mutation down one more layer Co-authored-by: Aleksandr Bezobchuk --- core/appmodule/option.go | 2 +- depinject/config.go | 2 +- depinject/container.go | 6 ------ depinject/provider_desc.go | 20 ++++++++++++++++++++ 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/core/appmodule/option.go b/core/appmodule/option.go index 53121e703c..c55201bfd3 100644 --- a/core/appmodule/option.go +++ b/core/appmodule/option.go @@ -40,7 +40,7 @@ func Provide(providers ...interface{}) Option { func Invoke(invokers ...interface{}) Option { return funcOption(func(initializer *internal.ModuleInitializer) error { for _, invoker := range invokers { - desc, err := depinject.ExtractProviderDescriptor(invoker) + desc, err := depinject.ExtractInvokerDescriptor(invoker) if err != nil { return err } diff --git a/depinject/config.go b/depinject/config.go index 54deb9e8d2..0584c27b47 100644 --- a/depinject/config.go +++ b/depinject/config.go @@ -75,7 +75,7 @@ func InvokeInModule(moduleName string, invokers ...interface{}) Config { func invoke(ctr *container, key *moduleKey, invokers []interface{}) error { for _, c := range invokers { - rc, err := ExtractProviderDescriptor(c) + rc, err := ExtractInvokerDescriptor(c) if err != nil { return errors.WithStack(err) } diff --git a/depinject/container.go b/depinject/container.go index 08fa30363b..189405ece4 100644 --- a/depinject/container.go +++ b/depinject/container.go @@ -365,12 +365,6 @@ func (c *container) addInvoker(provider *ProviderDescriptor, key *moduleKey) err return fmt.Errorf("invoker function %s should not return any outputs", provider.Location) } - // make all inputs optional - for i, input := range provider.Inputs { - input.Optional = true - provider.Inputs[i] = input - } - c.invokers = append(c.invokers, invoker{ fn: provider, modKey: key, diff --git a/depinject/provider_desc.go b/depinject/provider_desc.go index bf39e2a318..16a8f3c797 100644 --- a/depinject/provider_desc.go +++ b/depinject/provider_desc.go @@ -47,6 +47,26 @@ func ExtractProviderDescriptor(provider interface{}) (ProviderDescriptor, error) return expandStructArgsProvider(rctr) } +func ExtractInvokerDescriptor(provider interface{}) (ProviderDescriptor, error) { + rctr, ok := provider.(ProviderDescriptor) + if !ok { + var err error + rctr, err = doExtractProviderDescriptor(provider) + + // mark all inputs as optional + for i, input := range rctr.Inputs { + input.Optional = true + rctr.Inputs[i] = input + } + + if err != nil { + return ProviderDescriptor{}, err + } + } + + return expandStructArgsProvider(rctr) +} + func doExtractProviderDescriptor(ctr interface{}) (ProviderDescriptor, error) { val := reflect.ValueOf(ctr) typ := val.Type()