Refactor config to return an error instead of aborting
This commit is contained in:
parent
61cf7af2ec
commit
aa52088ba7
@ -17,6 +17,14 @@ func parseEnvironment(context *do.Context) string {
|
||||
return environment
|
||||
}
|
||||
|
||||
func loadConfig(environment string) config.Config {
|
||||
cfg, err := config.NewConfig(environment)
|
||||
if err != nil {
|
||||
log.Fatalf("Error loading config\n%v", err)
|
||||
}
|
||||
return *cfg
|
||||
}
|
||||
|
||||
func tasks(p *do.Project) {
|
||||
|
||||
p.Task("run", nil, func(context *do.Context) {
|
||||
@ -37,7 +45,7 @@ func tasks(p *do.Project) {
|
||||
|
||||
p.Task("migrate", nil, func(context *do.Context) {
|
||||
environment := parseEnvironment(context)
|
||||
cfg := config.NewConfig(environment)
|
||||
cfg := loadConfig(environment)
|
||||
connectString := config.DbConnectionString(cfg.Database)
|
||||
migrate := fmt.Sprintf("migrate -database '%s' -path ./db/migrations up", connectString)
|
||||
dumpSchema := fmt.Sprintf("pg_dump -O -s %s > ./db/schema.sql", cfg.Database.Name)
|
||||
@ -47,7 +55,7 @@ func tasks(p *do.Project) {
|
||||
|
||||
p.Task("rollback", nil, func(context *do.Context) {
|
||||
environment := parseEnvironment(context)
|
||||
cfg := config.NewConfig(environment)
|
||||
cfg := loadConfig(environment)
|
||||
connectString := config.DbConnectionString(cfg.Database)
|
||||
migrate := fmt.Sprintf("migrate -database '%s' -path ./db/migrations down 1", connectString)
|
||||
dumpSchema := fmt.Sprintf("pg_dump -O -s %s > ./db/schema.sql", cfg.Database.Name)
|
||||
|
@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
"log"
|
||||
|
||||
"fmt"
|
||||
@ -18,7 +17,10 @@ func main() {
|
||||
environment := flag.String("environment", "", "Environment name")
|
||||
startingBlockNumber := flag.Int("starting-number", -1, "First block to fill from")
|
||||
flag.Parse()
|
||||
cfg := config.NewConfig(*environment)
|
||||
cfg, err := config.NewConfig(*environment)
|
||||
if err != nil {
|
||||
log.Fatalf("Error loading config\n%v", err)
|
||||
}
|
||||
|
||||
blockchain := geth.NewGethBlockchain(cfg.Client.IPCPath)
|
||||
connectString := config.DbConnectionString(cfg.Database)
|
||||
|
@ -18,7 +18,10 @@ import (
|
||||
func main() {
|
||||
environment := flag.String("environment", "", "Environment name")
|
||||
flag.Parse()
|
||||
cfg := config.NewConfig(*environment)
|
||||
cfg, err := config.NewConfig(*environment)
|
||||
if err != nil {
|
||||
log.Fatalf("Error loading config\n%v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Client Path ", cfg.Client.IPCPath)
|
||||
blockchain := geth.NewGethBlockchain(cfg.Client.IPCPath)
|
||||
|
@ -18,7 +18,7 @@ var _ = Describe("Reading from the Geth blockchain", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
observer = fakes.NewFakeBlockchainObserver()
|
||||
cfg := config.NewConfig("private")
|
||||
cfg, _ := config.NewConfig("private")
|
||||
blockchain = geth.NewGethBlockchain(cfg.Client.IPCPath)
|
||||
observers := []core.BlockchainObserver{observer}
|
||||
listener = blockchain_listener.NewBlockchainListener(blockchain, observers)
|
||||
|
@ -11,6 +11,8 @@ import (
|
||||
"path"
|
||||
"runtime"
|
||||
|
||||
"errors"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
@ -19,14 +21,22 @@ type Config struct {
|
||||
Client Client
|
||||
}
|
||||
|
||||
func NewConfig(environment string) Config {
|
||||
var NewErrConfigFileNotFound = func(environment string) error {
|
||||
return errors.New(fmt.Sprintf("No configuration found for environment: %v", environment))
|
||||
}
|
||||
|
||||
func NewConfig(environment string) (*Config, error) {
|
||||
filenameWithExtension := fmt.Sprintf("%s.toml", environment)
|
||||
absolutePath := filepath.Join(ProjectRoot(), "pkg", "config", "environments", filenameWithExtension)
|
||||
config := parseConfigFile(absolutePath)
|
||||
config, err := parseConfigFile(absolutePath)
|
||||
if err != nil {
|
||||
return nil, NewErrConfigFileNotFound(environment)
|
||||
} else {
|
||||
if !filepath.IsAbs(config.Client.IPCPath) {
|
||||
config.Client.IPCPath = filepath.Join(ProjectRoot(), config.Client.IPCPath)
|
||||
}
|
||||
return config
|
||||
return config, nil
|
||||
}
|
||||
}
|
||||
|
||||
func ProjectRoot() string {
|
||||
@ -34,15 +44,15 @@ func ProjectRoot() string {
|
||||
return path.Join(path.Dir(filename), "..", "..")
|
||||
}
|
||||
|
||||
func parseConfigFile(configfile string) Config {
|
||||
func parseConfigFile(filePath string) (*Config, error) {
|
||||
var cfg Config
|
||||
_, err := os.Stat(configfile)
|
||||
_, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
log.Fatal("Config file is missing: ", configfile)
|
||||
}
|
||||
|
||||
if _, err := toml.DecodeFile(configfile, &cfg); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
if _, err := toml.DecodeFile(filePath, &cfg); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return cfg
|
||||
return &cfg, err
|
||||
}
|
||||
}
|
||||
|
@ -11,8 +11,9 @@ import (
|
||||
var _ = Describe("Loading the config", func() {
|
||||
|
||||
It("reads the private config using the environment", func() {
|
||||
privateConfig := config.NewConfig("private")
|
||||
privateConfig, err := config.NewConfig("private")
|
||||
|
||||
Expect(err).To(BeNil())
|
||||
Expect(privateConfig.Database.Hostname).To(Equal("localhost"))
|
||||
Expect(privateConfig.Database.Name).To(Equal("vulcanize_private"))
|
||||
Expect(privateConfig.Database.Port).To(Equal(5432))
|
||||
@ -20,4 +21,11 @@ var _ = Describe("Loading the config", func() {
|
||||
Expect(privateConfig.Client.IPCPath).To(Equal(expandedPath))
|
||||
})
|
||||
|
||||
It("returns an error when there is no matching config file", func() {
|
||||
config, err := config.NewConfig("bad-config")
|
||||
|
||||
Expect(config).To(BeNil())
|
||||
Expect(err).NotTo(BeNil())
|
||||
})
|
||||
|
||||
})
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
|
||||
"errors"
|
||||
|
||||
"github.com/8thlight/vulcanizedb/pkg/core"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
|
@ -217,7 +217,8 @@ var _ = Describe("Repositories", func() {
|
||||
|
||||
Describe("Postgres repository", func() {
|
||||
It("connects to the database", func() {
|
||||
pgConfig := config.DbConnectionString(config.NewConfig("private").Database)
|
||||
cfg, _ := config.NewConfig("private")
|
||||
pgConfig := config.DbConnectionString(cfg.Database)
|
||||
db, err := sqlx.Connect("postgres", pgConfig)
|
||||
Expect(err).Should(BeNil())
|
||||
Expect(db).ShouldNot(BeNil())
|
||||
@ -232,7 +233,8 @@ var _ = Describe("Repositories", func() {
|
||||
Nonce: badNonce,
|
||||
Transactions: []core.Transaction{},
|
||||
}
|
||||
pgConfig := config.DbConnectionString(config.NewConfig("private").Database)
|
||||
cfg, _ := config.NewConfig("private")
|
||||
pgConfig := config.DbConnectionString(cfg.Database)
|
||||
db, _ := sqlx.Connect("postgres", pgConfig)
|
||||
Expect(db).ShouldNot(BeNil())
|
||||
repository := repositories.NewPostgres(db)
|
||||
@ -249,7 +251,8 @@ var _ = Describe("Repositories", func() {
|
||||
//badHash violates db To field length
|
||||
badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100))
|
||||
badTransaction := core.Transaction{To: badHash}
|
||||
pgConfig := config.DbConnectionString(config.NewConfig("private").Database)
|
||||
cfg, _ := config.NewConfig("private")
|
||||
pgConfig := config.DbConnectionString(cfg.Database)
|
||||
block := core.Block{
|
||||
Number: 123,
|
||||
Transactions: []core.Transaction{badTransaction},
|
||||
@ -266,7 +269,8 @@ var _ = Describe("Repositories", func() {
|
||||
})
|
||||
|
||||
AssertRepositoryBehavior(func() repositories.Repository {
|
||||
pgConfig := config.DbConnectionString(config.NewConfig("private").Database)
|
||||
cfg, _ := config.NewConfig("private")
|
||||
pgConfig := config.DbConnectionString(cfg.Database)
|
||||
db, _ := sqlx.Connect("postgres", pgConfig)
|
||||
db.MustExec("DELETE FROM transactions")
|
||||
db.MustExec("DELETE FROM blocks")
|
||||
|
Loading…
Reference in New Issue
Block a user