diff --git a/cmd/populate_blocks/main.go b/cmd/populate_blocks/main.go index 23c3d8da..bdeff355 100644 --- a/cmd/populate_blocks/main.go +++ b/cmd/populate_blocks/main.go @@ -8,7 +8,6 @@ import ( "github.com/8thlight/vulcanizedb/cmd" "github.com/8thlight/vulcanizedb/pkg/geth" "github.com/8thlight/vulcanizedb/pkg/history" - "github.com/8thlight/vulcanizedb/pkg/repositories" ) func main() { @@ -16,9 +15,8 @@ func main() { startingBlockNumber := flag.Int("starting-number", -1, "First block to fill from") flag.Parse() config := cmd.LoadConfig(*environment) - blockchain := geth.NewGethBlockchain(config.Client.IPCPath) - repository := repositories.NewPostgres(config.Database) + repository := cmd.LoadPostgres(config.Database) numberOfBlocksCreated := history.PopulateBlocks(blockchain, repository, int64(*startingBlockNumber)) fmt.Printf("Populated %d blocks", numberOfBlocksCreated) } diff --git a/cmd/run/main.go b/cmd/run/main.go index 7b9c094c..2a10fd4f 100644 --- a/cmd/run/main.go +++ b/cmd/run/main.go @@ -10,15 +10,13 @@ import ( "github.com/8thlight/vulcanizedb/pkg/core" "github.com/8thlight/vulcanizedb/pkg/geth" "github.com/8thlight/vulcanizedb/pkg/observers" - "github.com/8thlight/vulcanizedb/pkg/repositories" ) func main() { environment := flag.String("environment", "", "Environment name") flag.Parse() config := cmd.LoadConfig(*environment) - - repository := repositories.NewPostgres(config.Database) + repository := cmd.LoadPostgres(config.Database) fmt.Printf("Creating Geth Blockchain to: %s\n", config.Client.IPCPath) listener := blockchain_listener.NewBlockchainListener( geth.NewGethBlockchain(config.Client.IPCPath), diff --git a/cmd/show_contract_summary/main.go b/cmd/show_contract_summary/main.go index 98d6d55b..1bb5af9f 100644 --- a/cmd/show_contract_summary/main.go +++ b/cmd/show_contract_summary/main.go @@ -9,7 +9,6 @@ import ( "github.com/8thlight/vulcanizedb/cmd" "github.com/8thlight/vulcanizedb/pkg/geth" - "github.com/8thlight/vulcanizedb/pkg/repositories" "github.com/8thlight/vulcanizedb/pkg/watched_contracts" ) @@ -18,9 +17,8 @@ func main() { contractHash := flag.String("contract-hash", "", "Contract hash to show summary") flag.Parse() config := cmd.LoadConfig(*environment) - blockchain := geth.NewGethBlockchain(config.Client.IPCPath) - repository := repositories.NewPostgres(config.Database) + repository := cmd.LoadPostgres(config.Database) contractSummary, err := watched_contracts.NewSummary(blockchain, repository, *contractHash) if err != nil { log.Fatalln(err) diff --git a/cmd/subscribe_contract/main.go b/cmd/subscribe_contract/main.go index 9d15b3fb..04fa4552 100644 --- a/cmd/subscribe_contract/main.go +++ b/cmd/subscribe_contract/main.go @@ -5,7 +5,6 @@ import ( "github.com/8thlight/vulcanizedb/cmd" "github.com/8thlight/vulcanizedb/pkg/core" - "github.com/8thlight/vulcanizedb/pkg/repositories" ) func main() { @@ -13,6 +12,6 @@ func main() { contractHash := flag.String("contract-hash", "", "contract-hash=x1234") flag.Parse() config := cmd.LoadConfig(*environment) - repository := repositories.NewPostgres(config.Database) + repository := cmd.LoadPostgres(config.Database) repository.CreateWatchedContract(core.WatchedContract{Hash: *contractHash}) } diff --git a/cmd/utils.go b/cmd/utils.go index 0a7a0a0e..58df140e 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -4,6 +4,7 @@ import ( "log" "github.com/8thlight/vulcanizedb/pkg/config" + "github.com/8thlight/vulcanizedb/pkg/repositories" ) func LoadConfig(environment string) config.Config { @@ -13,3 +14,11 @@ func LoadConfig(environment string) config.Config { } return *cfg } + +func LoadPostgres(database config.Database) repositories.Postgres { + repository, err := repositories.NewPostgres(database) + if err != nil { + log.Fatalf("Error loading postgres\n%v", err) + } + return repository +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 34543aa3..dfbf62e9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,7 +1,6 @@ package config import ( - "log" "os" "fmt" @@ -50,8 +49,9 @@ func parseConfigFile(filePath string) (*Config, error) { if err != nil { return nil, err } else { - if _, err := toml.DecodeFile(filePath, &cfg); err != nil { - log.Fatal(err) + _, err := toml.DecodeFile(filePath, &cfg) + if err != nil { + return nil, err } return &cfg, err } diff --git a/pkg/repositories/postgres.go b/pkg/repositories/postgres.go index 417168e2..7d0d4d36 100644 --- a/pkg/repositories/postgres.go +++ b/pkg/repositories/postgres.go @@ -2,7 +2,6 @@ package repositories import ( "database/sql" - "log" "context" @@ -19,16 +18,17 @@ type Postgres struct { } var ( - ErrDBInsertFailed = errors.New("postgres: insert failed") + ErrDBInsertFailed = errors.New("postgres: insert failed") + ErrDBConnectionFailed = errors.New("postgres: db connection failed") ) -func NewPostgres(databaseConfig config.Database) Postgres { +func NewPostgres(databaseConfig config.Database) (Postgres, error) { connectString := config.DbConnectionString(databaseConfig) db, err := sqlx.Connect("postgres", connectString) if err != nil { - log.Fatalf("Error connecting to DB: %v\n", err) + return Postgres{}, ErrDBConnectionFailed } - return Postgres{Db: db} + return Postgres{Db: db}, nil } func (repository Postgres) CreateWatchedContract(contract core.WatchedContract) error { @@ -42,11 +42,8 @@ func (repository Postgres) CreateWatchedContract(contract core.WatchedContract) func (repository Postgres) IsWatchedContract(contractHash string) bool { var exists bool - err := repository.Db.QueryRow( + repository.Db.QueryRow( `SELECT exists(SELECT 1 FROM watched_contracts WHERE contract_hash=$1) FROM watched_contracts`, contractHash).Scan(&exists) - if err != nil && err != sql.ErrNoRows { - log.Fatalf("error checking if row exists %v", err) - } return exists } diff --git a/pkg/repositories/postgres_test.go b/pkg/repositories/postgres_test.go index adfcef6d..e50c8728 100644 --- a/pkg/repositories/postgres_test.go +++ b/pkg/repositories/postgres_test.go @@ -26,7 +26,7 @@ var _ = Describe("Postgres repository", func() { testing.AssertRepositoryBehavior(func() repositories.Repository { cfg, _ := config.NewConfig("private") - repository := repositories.NewPostgres(cfg.Database) + repository, _ := repositories.NewPostgres(cfg.Database) testing.ClearData(repository) return repository }) @@ -40,7 +40,7 @@ var _ = Describe("Postgres repository", func() { Transactions: []core.Transaction{}, } cfg, _ := config.NewConfig("private") - repository := repositories.NewPostgres(cfg.Database) + repository, _ := repositories.NewPostgres(cfg.Database) err := repository.CreateBlock(badBlock) savedBlock := repository.FindBlockByNumber(123) @@ -49,6 +49,12 @@ var _ = Describe("Postgres repository", func() { Expect(savedBlock).To(BeNil()) }) + It("throws error when can't connect to the database", func() { + invalidDatabase := config.Database{} + _, err := repositories.NewPostgres(invalidDatabase) + Expect(err).To(Equal(repositories.ErrDBConnectionFailed)) + }) + It("does not commit block or transactions if transaction is invalid", func() { //badHash violates db To field length badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100)) @@ -58,7 +64,7 @@ var _ = Describe("Postgres repository", func() { Transactions: []core.Transaction{badTransaction}, } cfg, _ := config.NewConfig("private") - repository := repositories.NewPostgres(cfg.Database) + repository, _ := repositories.NewPostgres(cfg.Database) err := repository.CreateBlock(block) savedBlock := repository.FindBlockByNumber(123)