diff --git a/Gododir/main.go b/Gododir/main.go index b826f067..b20b69be 100644 --- a/Gododir/main.go +++ b/Gododir/main.go @@ -17,6 +17,14 @@ func parseEnvironment(context *do.Context) string { return environment } +func loadConfig(environment string) config.Config { + cfg, err := config.NewConfig(environment) + if err != nil { + log.Fatalf("Error loading config\n%v", err) + } + return *cfg +} + func tasks(p *do.Project) { p.Task("run", nil, func(context *do.Context) { @@ -37,7 +45,7 @@ func tasks(p *do.Project) { p.Task("migrate", nil, func(context *do.Context) { environment := parseEnvironment(context) - cfg := config.NewConfig(environment) + cfg := loadConfig(environment) connectString := config.DbConnectionString(cfg.Database) migrate := fmt.Sprintf("migrate -database '%s' -path ./db/migrations up", connectString) dumpSchema := fmt.Sprintf("pg_dump -O -s %s > ./db/schema.sql", cfg.Database.Name) @@ -47,7 +55,7 @@ func tasks(p *do.Project) { p.Task("rollback", nil, func(context *do.Context) { environment := parseEnvironment(context) - cfg := config.NewConfig(environment) + cfg := loadConfig(environment) connectString := config.DbConnectionString(cfg.Database) migrate := fmt.Sprintf("migrate -database '%s' -path ./db/migrations down 1", connectString) dumpSchema := fmt.Sprintf("pg_dump -O -s %s > ./db/schema.sql", cfg.Database.Name) diff --git a/cmd/populate_blocks/main.go b/cmd/populate_blocks/main.go index 37d38699..19250fa7 100644 --- a/cmd/populate_blocks/main.go +++ b/cmd/populate_blocks/main.go @@ -2,7 +2,6 @@ package main import ( "flag" - "log" "fmt" @@ -18,7 +17,10 @@ func main() { environment := flag.String("environment", "", "Environment name") startingBlockNumber := flag.Int("starting-number", -1, "First block to fill from") flag.Parse() - cfg := config.NewConfig(*environment) + cfg, err := config.NewConfig(*environment) + if err != nil { + log.Fatalf("Error loading config\n%v", err) + } blockchain := geth.NewGethBlockchain(cfg.Client.IPCPath) connectString := config.DbConnectionString(cfg.Database) diff --git a/cmd/run/main.go b/cmd/run/main.go index 489107d7..4b63e52f 100644 --- a/cmd/run/main.go +++ b/cmd/run/main.go @@ -18,7 +18,10 @@ import ( func main() { environment := flag.String("environment", "", "Environment name") flag.Parse() - cfg := config.NewConfig(*environment) + cfg, err := config.NewConfig(*environment) + if err != nil { + log.Fatalf("Error loading config\n%v", err) + } fmt.Println("Client Path ", cfg.Client.IPCPath) blockchain := geth.NewGethBlockchain(cfg.Client.IPCPath) diff --git a/integration_test/geth_blockchain_test.go b/integration_test/geth_blockchain_test.go index e1d2979b..2455f79a 100644 --- a/integration_test/geth_blockchain_test.go +++ b/integration_test/geth_blockchain_test.go @@ -18,7 +18,7 @@ var _ = Describe("Reading from the Geth blockchain", func() { BeforeEach(func() { observer = fakes.NewFakeBlockchainObserver() - cfg := config.NewConfig("private") + cfg, _ := config.NewConfig("private") blockchain = geth.NewGethBlockchain(cfg.Client.IPCPath) observers := []core.BlockchainObserver{observer} listener = blockchain_listener.NewBlockchainListener(blockchain, observers) diff --git a/pkg/config/config.go b/pkg/config/config.go index c2a5f100..a6692ab3 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -11,6 +11,8 @@ import ( "path" "runtime" + "errors" + "github.com/BurntSushi/toml" ) @@ -19,14 +21,22 @@ type Config struct { Client Client } -func NewConfig(environment string) Config { +var NewErrConfigFileNotFound = func(environment string) error { + return errors.New(fmt.Sprintf("No configuration found for environment: %v", environment)) +} + +func NewConfig(environment string) (*Config, error) { filenameWithExtension := fmt.Sprintf("%s.toml", environment) absolutePath := filepath.Join(ProjectRoot(), "pkg", "config", "environments", filenameWithExtension) - config := parseConfigFile(absolutePath) - if !filepath.IsAbs(config.Client.IPCPath) { - config.Client.IPCPath = filepath.Join(ProjectRoot(), config.Client.IPCPath) + config, err := parseConfigFile(absolutePath) + if err != nil { + return nil, NewErrConfigFileNotFound(environment) + } else { + if !filepath.IsAbs(config.Client.IPCPath) { + config.Client.IPCPath = filepath.Join(ProjectRoot(), config.Client.IPCPath) + } + return config, nil } - return config } func ProjectRoot() string { @@ -34,15 +44,15 @@ func ProjectRoot() string { return path.Join(path.Dir(filename), "..", "..") } -func parseConfigFile(configfile string) Config { +func parseConfigFile(filePath string) (*Config, error) { var cfg Config - _, err := os.Stat(configfile) + _, err := os.Stat(filePath) if err != nil { - log.Fatal("Config file is missing: ", configfile) + return nil, err + } else { + if _, err := toml.DecodeFile(filePath, &cfg); err != nil { + log.Fatal(err) + } + return &cfg, err } - - if _, err := toml.DecodeFile(configfile, &cfg); err != nil { - log.Fatal(err) - } - return cfg } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index c6236937..f7f10e58 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -11,8 +11,9 @@ import ( var _ = Describe("Loading the config", func() { It("reads the private config using the environment", func() { - privateConfig := config.NewConfig("private") + privateConfig, err := config.NewConfig("private") + Expect(err).To(BeNil()) Expect(privateConfig.Database.Hostname).To(Equal("localhost")) Expect(privateConfig.Database.Name).To(Equal("vulcanize_private")) Expect(privateConfig.Database.Port).To(Equal(5432)) @@ -20,4 +21,11 @@ var _ = Describe("Loading the config", func() { Expect(privateConfig.Client.IPCPath).To(Equal(expandedPath)) }) + It("returns an error when there is no matching config file", func() { + config, err := config.NewConfig("bad-config") + + Expect(config).To(BeNil()) + Expect(err).NotTo(BeNil()) + }) + }) diff --git a/pkg/repositories/postgres.go b/pkg/repositories/postgres.go index 9d1ea7c3..57b88b09 100644 --- a/pkg/repositories/postgres.go +++ b/pkg/repositories/postgres.go @@ -6,6 +6,7 @@ import ( "context" "errors" + "github.com/8thlight/vulcanizedb/pkg/core" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" diff --git a/pkg/repositories/repository_test.go b/pkg/repositories/repository_test.go index df7ea8e7..797697b0 100644 --- a/pkg/repositories/repository_test.go +++ b/pkg/repositories/repository_test.go @@ -217,7 +217,8 @@ var _ = Describe("Repositories", func() { Describe("Postgres repository", func() { It("connects to the database", func() { - pgConfig := config.DbConnectionString(config.NewConfig("private").Database) + cfg, _ := config.NewConfig("private") + pgConfig := config.DbConnectionString(cfg.Database) db, err := sqlx.Connect("postgres", pgConfig) Expect(err).Should(BeNil()) Expect(db).ShouldNot(BeNil()) @@ -232,7 +233,8 @@ var _ = Describe("Repositories", func() { Nonce: badNonce, Transactions: []core.Transaction{}, } - pgConfig := config.DbConnectionString(config.NewConfig("private").Database) + cfg, _ := config.NewConfig("private") + pgConfig := config.DbConnectionString(cfg.Database) db, _ := sqlx.Connect("postgres", pgConfig) Expect(db).ShouldNot(BeNil()) repository := repositories.NewPostgres(db) @@ -249,7 +251,8 @@ var _ = Describe("Repositories", func() { //badHash violates db To field length badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100)) badTransaction := core.Transaction{To: badHash} - pgConfig := config.DbConnectionString(config.NewConfig("private").Database) + cfg, _ := config.NewConfig("private") + pgConfig := config.DbConnectionString(cfg.Database) block := core.Block{ Number: 123, Transactions: []core.Transaction{badTransaction}, @@ -266,7 +269,8 @@ var _ = Describe("Repositories", func() { }) AssertRepositoryBehavior(func() repositories.Repository { - pgConfig := config.DbConnectionString(config.NewConfig("private").Database) + cfg, _ := config.NewConfig("private") + pgConfig := config.DbConnectionString(cfg.Database) db, _ := sqlx.Connect("postgres", pgConfig) db.MustExec("DELETE FROM transactions") db.MustExec("DELETE FROM blocks")