Merge pull request #65 from 8thlight/config

Refactor config to return an error instead of aborting
This commit is contained in:
Matt K 2017-11-09 12:44:21 -06:00 committed by GitHub
commit 54ed808a61
8 changed files with 60 additions and 24 deletions

View File

@ -17,6 +17,14 @@ func parseEnvironment(context *do.Context) string {
return environment return environment
} }
func loadConfig(environment string) config.Config {
cfg, err := config.NewConfig(environment)
if err != nil {
log.Fatalf("Error loading config\n%v", err)
}
return *cfg
}
func tasks(p *do.Project) { func tasks(p *do.Project) {
p.Task("run", nil, func(context *do.Context) { p.Task("run", nil, func(context *do.Context) {
@ -37,7 +45,7 @@ func tasks(p *do.Project) {
p.Task("migrate", nil, func(context *do.Context) { p.Task("migrate", nil, func(context *do.Context) {
environment := parseEnvironment(context) environment := parseEnvironment(context)
cfg := config.NewConfig(environment) cfg := loadConfig(environment)
connectString := config.DbConnectionString(cfg.Database) connectString := config.DbConnectionString(cfg.Database)
migrate := fmt.Sprintf("migrate -database '%s' -path ./db/migrations up", connectString) migrate := fmt.Sprintf("migrate -database '%s' -path ./db/migrations up", connectString)
dumpSchema := fmt.Sprintf("pg_dump -O -s %s > ./db/schema.sql", cfg.Database.Name) dumpSchema := fmt.Sprintf("pg_dump -O -s %s > ./db/schema.sql", cfg.Database.Name)
@ -47,7 +55,7 @@ func tasks(p *do.Project) {
p.Task("rollback", nil, func(context *do.Context) { p.Task("rollback", nil, func(context *do.Context) {
environment := parseEnvironment(context) environment := parseEnvironment(context)
cfg := config.NewConfig(environment) cfg := loadConfig(environment)
connectString := config.DbConnectionString(cfg.Database) connectString := config.DbConnectionString(cfg.Database)
migrate := fmt.Sprintf("migrate -database '%s' -path ./db/migrations down 1", connectString) migrate := fmt.Sprintf("migrate -database '%s' -path ./db/migrations down 1", connectString)
dumpSchema := fmt.Sprintf("pg_dump -O -s %s > ./db/schema.sql", cfg.Database.Name) dumpSchema := fmt.Sprintf("pg_dump -O -s %s > ./db/schema.sql", cfg.Database.Name)

View File

@ -2,7 +2,6 @@ package main
import ( import (
"flag" "flag"
"log" "log"
"fmt" "fmt"
@ -18,7 +17,10 @@ func main() {
environment := flag.String("environment", "", "Environment name") environment := flag.String("environment", "", "Environment name")
startingBlockNumber := flag.Int("starting-number", -1, "First block to fill from") startingBlockNumber := flag.Int("starting-number", -1, "First block to fill from")
flag.Parse() flag.Parse()
cfg := config.NewConfig(*environment) cfg, err := config.NewConfig(*environment)
if err != nil {
log.Fatalf("Error loading config\n%v", err)
}
blockchain := geth.NewGethBlockchain(cfg.Client.IPCPath) blockchain := geth.NewGethBlockchain(cfg.Client.IPCPath)
connectString := config.DbConnectionString(cfg.Database) connectString := config.DbConnectionString(cfg.Database)

View File

@ -18,7 +18,10 @@ import (
func main() { func main() {
environment := flag.String("environment", "", "Environment name") environment := flag.String("environment", "", "Environment name")
flag.Parse() flag.Parse()
cfg := config.NewConfig(*environment) cfg, err := config.NewConfig(*environment)
if err != nil {
log.Fatalf("Error loading config\n%v", err)
}
fmt.Println("Client Path ", cfg.Client.IPCPath) fmt.Println("Client Path ", cfg.Client.IPCPath)
blockchain := geth.NewGethBlockchain(cfg.Client.IPCPath) blockchain := geth.NewGethBlockchain(cfg.Client.IPCPath)

View File

@ -18,7 +18,7 @@ var _ = Describe("Reading from the Geth blockchain", func() {
BeforeEach(func() { BeforeEach(func() {
observer = fakes.NewFakeBlockchainObserver() observer = fakes.NewFakeBlockchainObserver()
cfg := config.NewConfig("private") cfg, _ := config.NewConfig("private")
blockchain = geth.NewGethBlockchain(cfg.Client.IPCPath) blockchain = geth.NewGethBlockchain(cfg.Client.IPCPath)
observers := []core.BlockchainObserver{observer} observers := []core.BlockchainObserver{observer}
listener = blockchain_listener.NewBlockchainListener(blockchain, observers) listener = blockchain_listener.NewBlockchainListener(blockchain, observers)

View File

@ -11,6 +11,8 @@ import (
"path" "path"
"runtime" "runtime"
"errors"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
) )
@ -19,14 +21,22 @@ type Config struct {
Client Client Client Client
} }
func NewConfig(environment string) Config { var NewErrConfigFileNotFound = func(environment string) error {
return errors.New(fmt.Sprintf("No configuration found for environment: %v", environment))
}
func NewConfig(environment string) (*Config, error) {
filenameWithExtension := fmt.Sprintf("%s.toml", environment) filenameWithExtension := fmt.Sprintf("%s.toml", environment)
absolutePath := filepath.Join(ProjectRoot(), "pkg", "config", "environments", filenameWithExtension) absolutePath := filepath.Join(ProjectRoot(), "pkg", "config", "environments", filenameWithExtension)
config := parseConfigFile(absolutePath) config, err := parseConfigFile(absolutePath)
if !filepath.IsAbs(config.Client.IPCPath) { if err != nil {
config.Client.IPCPath = filepath.Join(ProjectRoot(), config.Client.IPCPath) return nil, NewErrConfigFileNotFound(environment)
} else {
if !filepath.IsAbs(config.Client.IPCPath) {
config.Client.IPCPath = filepath.Join(ProjectRoot(), config.Client.IPCPath)
}
return config, nil
} }
return config
} }
func ProjectRoot() string { func ProjectRoot() string {
@ -34,15 +44,15 @@ func ProjectRoot() string {
return path.Join(path.Dir(filename), "..", "..") return path.Join(path.Dir(filename), "..", "..")
} }
func parseConfigFile(configfile string) Config { func parseConfigFile(filePath string) (*Config, error) {
var cfg Config var cfg Config
_, err := os.Stat(configfile) _, err := os.Stat(filePath)
if err != nil { if err != nil {
log.Fatal("Config file is missing: ", configfile) return nil, err
} else {
if _, err := toml.DecodeFile(filePath, &cfg); err != nil {
log.Fatal(err)
}
return &cfg, err
} }
if _, err := toml.DecodeFile(configfile, &cfg); err != nil {
log.Fatal(err)
}
return cfg
} }

View File

@ -11,8 +11,9 @@ import (
var _ = Describe("Loading the config", func() { var _ = Describe("Loading the config", func() {
It("reads the private config using the environment", func() { It("reads the private config using the environment", func() {
privateConfig := config.NewConfig("private") privateConfig, err := config.NewConfig("private")
Expect(err).To(BeNil())
Expect(privateConfig.Database.Hostname).To(Equal("localhost")) Expect(privateConfig.Database.Hostname).To(Equal("localhost"))
Expect(privateConfig.Database.Name).To(Equal("vulcanize_private")) Expect(privateConfig.Database.Name).To(Equal("vulcanize_private"))
Expect(privateConfig.Database.Port).To(Equal(5432)) Expect(privateConfig.Database.Port).To(Equal(5432))
@ -20,4 +21,11 @@ var _ = Describe("Loading the config", func() {
Expect(privateConfig.Client.IPCPath).To(Equal(expandedPath)) Expect(privateConfig.Client.IPCPath).To(Equal(expandedPath))
}) })
It("returns an error when there is no matching config file", func() {
config, err := config.NewConfig("bad-config")
Expect(config).To(BeNil())
Expect(err).NotTo(BeNil())
})
}) })

View File

@ -6,6 +6,7 @@ import (
"context" "context"
"errors" "errors"
"github.com/8thlight/vulcanizedb/pkg/core" "github.com/8thlight/vulcanizedb/pkg/core"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
_ "github.com/lib/pq" _ "github.com/lib/pq"

View File

@ -217,7 +217,8 @@ var _ = Describe("Repositories", func() {
Describe("Postgres repository", func() { Describe("Postgres repository", func() {
It("connects to the database", func() { It("connects to the database", func() {
pgConfig := config.DbConnectionString(config.NewConfig("private").Database) cfg, _ := config.NewConfig("private")
pgConfig := config.DbConnectionString(cfg.Database)
db, err := sqlx.Connect("postgres", pgConfig) db, err := sqlx.Connect("postgres", pgConfig)
Expect(err).Should(BeNil()) Expect(err).Should(BeNil())
Expect(db).ShouldNot(BeNil()) Expect(db).ShouldNot(BeNil())
@ -232,7 +233,8 @@ var _ = Describe("Repositories", func() {
Nonce: badNonce, Nonce: badNonce,
Transactions: []core.Transaction{}, Transactions: []core.Transaction{},
} }
pgConfig := config.DbConnectionString(config.NewConfig("private").Database) cfg, _ := config.NewConfig("private")
pgConfig := config.DbConnectionString(cfg.Database)
db, _ := sqlx.Connect("postgres", pgConfig) db, _ := sqlx.Connect("postgres", pgConfig)
Expect(db).ShouldNot(BeNil()) Expect(db).ShouldNot(BeNil())
repository := repositories.NewPostgres(db) repository := repositories.NewPostgres(db)
@ -249,7 +251,8 @@ var _ = Describe("Repositories", func() {
//badHash violates db To field length //badHash violates db To field length
badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100)) badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100))
badTransaction := core.Transaction{To: badHash} badTransaction := core.Transaction{To: badHash}
pgConfig := config.DbConnectionString(config.NewConfig("private").Database) cfg, _ := config.NewConfig("private")
pgConfig := config.DbConnectionString(cfg.Database)
block := core.Block{ block := core.Block{
Number: 123, Number: 123,
Transactions: []core.Transaction{badTransaction}, Transactions: []core.Transaction{badTransaction},
@ -266,7 +269,8 @@ var _ = Describe("Repositories", func() {
}) })
AssertRepositoryBehavior(func() repositories.Repository { AssertRepositoryBehavior(func() repositories.Repository {
pgConfig := config.DbConnectionString(config.NewConfig("private").Database) cfg, _ := config.NewConfig("private")
pgConfig := config.DbConnectionString(cfg.Database)
db, _ := sqlx.Connect("postgres", pgConfig) db, _ := sqlx.Connect("postgres", pgConfig)
db.MustExec("DELETE FROM transactions") db.MustExec("DELETE FROM transactions")
db.MustExec("DELETE FROM blocks") db.MustExec("DELETE FROM blocks")