From 0b9930dd9c303c696d6f3c75315d9bccf9335073 Mon Sep 17 00:00:00 2001 From: Matt Krump Date: Mon, 13 Nov 2017 15:41:32 -0600 Subject: [PATCH] Add transactions for a watched contract --- pkg/core/contract.go | 3 ++- pkg/repositories/in_memory.go | 13 +++++++++++- pkg/repositories/postgres.go | 33 ++++++++++++++++++----------- pkg/repositories/repository_test.go | 29 +++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 14 deletions(-) diff --git a/pkg/core/contract.go b/pkg/core/contract.go index a01b8d46..8547d20d 100644 --- a/pkg/core/contract.go +++ b/pkg/core/contract.go @@ -1,5 +1,6 @@ package core type WatchedContract struct { - Hash string + Hash string + Transactions []Transaction } diff --git a/pkg/repositories/in_memory.go b/pkg/repositories/in_memory.go index 65f281e3..a86df78f 100644 --- a/pkg/repositories/in_memory.go +++ b/pkg/repositories/in_memory.go @@ -20,7 +20,18 @@ func (repository *InMemory) IsWatchedContract(contractHash string) bool { } func (repository *InMemory) FindWatchedContract(contractHash string) *core.WatchedContract { - return repository.watchedContracts[contractHash] + var transactions []core.Transaction + if _, ok := repository.watchedContracts[contractHash]; !ok { + return nil + } + for _, block := range repository.blocks { + for _, transaction := range block.Transactions { + if transaction.To == contractHash { + transactions = append(transactions, transaction) + } + } + } + return &core.WatchedContract{Hash: contractHash, Transactions: transactions} } func (repository *InMemory) MissingBlockNumbers(startingBlockNumber int64, endingBlockNumber int64) []int64 { diff --git a/pkg/repositories/postgres.go b/pkg/repositories/postgres.go index f2d145df..cb9e0646 100644 --- a/pkg/repositories/postgres.go +++ b/pkg/repositories/postgres.go @@ -43,7 +43,7 @@ func (repository Postgres) CreateWatchedContract(contract core.WatchedContract) func (repository Postgres) IsWatchedContract(contractHash string) bool { var exists bool err := repository.Db.QueryRow( - `SELECT exists(select 1 from watched_contracts where contract_hash=$1) FROM watched_contracts`, contractHash).Scan(&exists) + `SELECT exists(SELECT 1 FROM watched_contracts WHERE contract_hash=$1) FROM watched_contracts`, contractHash).Scan(&exists) if err != nil && err != sql.ErrNoRows { log.Fatalf("error checking if row exists %v", err) } @@ -53,13 +53,8 @@ func (repository Postgres) IsWatchedContract(contractHash string) bool { func (repository Postgres) FindWatchedContract(contractHash string) *core.WatchedContract { var savedContracts []core.WatchedContract contractRows, _ := repository.Db.Query( - `select contract_hash from watched_contracts where contract_hash=$1`, contractHash) - for contractRows.Next() { - var savedContractHash string - contractRows.Scan(&savedContractHash) - savedContract := core.WatchedContract{Hash: savedContractHash} - savedContracts = append(savedContracts, savedContract) - } + `SELECT contract_hash FROM watched_contracts WHERE contract_hash=$1`, contractHash) + savedContracts = repository.loadContract(contractRows) if len(savedContracts) > 0 { return &savedContracts[0] } else { @@ -104,7 +99,7 @@ func (repository Postgres) FindBlockByNumber(blockNumber int64) *core.Block { func (repository Postgres) BlockCount() int { var count int - repository.Db.Get(&count, "SELECT COUNT(*) FROM blocks") + repository.Db.Get(&count, `SELECT COUNT(*) FROM blocks`) return count } @@ -158,7 +153,8 @@ func (repository Postgres) loadBlock(blockRows *sql.Rows) core.Block { var gasUsed float64 var uncleHash string blockRows.Scan(&blockId, &blockNumber, &gasLimit, &gasUsed, &blockTime, &difficulty, &blockHash, &blockNonce, &blockParentHash, &blockSize, &uncleHash) - transactions := repository.loadTransactions(blockId) + transactionRows, _ := repository.Db.Query(`SELECT tx_hash, tx_nonce, tx_to, tx_from, tx_gaslimit, tx_gasprice, tx_value FROM transactions WHERE block_id = $1`, blockId) + transactions := repository.loadTransactions(transactionRows) return core.Block{ Difficulty: difficulty, GasLimit: int64(gasLimit), @@ -173,8 +169,8 @@ func (repository Postgres) loadBlock(blockRows *sql.Rows) core.Block { UncleHash: uncleHash, } } -func (repository Postgres) loadTransactions(blockId int64) []core.Transaction { - transactionRows, _ := repository.Db.Query(`SELECT tx_hash, tx_nonce, tx_to, tx_from, tx_gaslimit, tx_gasprice, tx_value FROM transactions`) + +func (repository Postgres) loadTransactions(transactionRows *sql.Rows) []core.Transaction { var transactions []core.Transaction for transactionRows.Next() { var hash string @@ -198,3 +194,16 @@ func (repository Postgres) loadTransactions(blockId int64) []core.Transaction { } return transactions } + +func (repository Postgres) loadContract(contractRows *sql.Rows) []core.WatchedContract { + var savedContracts []core.WatchedContract + for contractRows.Next() { + var savedContractHash string + contractRows.Scan(&savedContractHash) + transactionRows, _ := repository.Db.Query(`SELECT tx_hash, tx_nonce, tx_to, tx_from, tx_gaslimit, tx_gasprice, tx_value FROM transactions WHERE tx_to = $1`, savedContractHash) + transactions := repository.loadTransactions(transactionRows) + savedContract := core.WatchedContract{Hash: savedContractHash, Transactions: transactions} + savedContracts = append(savedContracts, savedContract) + } + return savedContracts +} diff --git a/pkg/repositories/repository_test.go b/pkg/repositories/repository_test.go index f09dd326..55b31f1b 100644 --- a/pkg/repositories/repository_test.go +++ b/pkg/repositories/repository_test.go @@ -227,6 +227,35 @@ var _ = Describe("Repositories", func() { watchedContract := repository.FindWatchedContract("x123") Expect(watchedContract).To(BeNil()) }) + + It("returns empty array when no transactions 'To' a watched contract", func() { + repository.CreateWatchedContract(core.WatchedContract{Hash: "x123"}) + watchedContract := repository.FindWatchedContract("x123") + Expect(watchedContract).ToNot(BeNil()) + Expect(watchedContract.Transactions).To(BeEmpty()) + + }) + + It("returns transactions 'To' a watched contract", func() { + block := core.Block{ + Number: 123, + Transactions: []core.Transaction{ + {Hash: "TRANSACTION1", To: "x123"}, + {Hash: "TRANSACTION2", To: "x345"}, + {Hash: "TRANSACTION3", To: "x123"}, + }, + } + repository.CreateBlock(block) + + repository.CreateWatchedContract(core.WatchedContract{Hash: "x123"}) + watchedContract := repository.FindWatchedContract("x123") + Expect(watchedContract).ToNot(BeNil()) + Expect(watchedContract.Transactions).To( + Equal([]core.Transaction{ + {Hash: "TRANSACTION1", To: "x123"}, + {Hash: "TRANSACTION3", To: "x123"}, + })) + }) }) }