Address more PR feedback

This commit is contained in:
Elizabeth Engelman 2019-08-19 16:55:38 -05:00
parent 1b3786338f
commit edc0bdf668
2 changed files with 19 additions and 9 deletions

View File

@ -27,11 +27,11 @@ import (
type AddressRepository struct{} type AddressRepository struct{}
func (AddressRepository) GetOrCreateAddress(db *postgres.DB, address string) (int, error) { func (AddressRepository) GetOrCreateAddress(db *postgres.DB, address string) (int64, error) {
stringAddressToCommonAddress := common.HexToAddress(address) stringAddressToCommonAddress := common.HexToAddress(address)
hexAddress := stringAddressToCommonAddress.Hex() hexAddress := stringAddressToCommonAddress.Hex()
var addressId int var addressId int64
getErr := db.Get(&addressId, `SELECT id FROM public.addresses WHERE address = $1`, hexAddress) getErr := db.Get(&addressId, `SELECT id FROM public.addresses WHERE address = $1`, hexAddress)
if getErr == sql.ErrNoRows { if getErr == sql.ErrNoRows {
insertErr := db.QueryRow(`INSERT INTO public.addresses (address) VALUES($1) RETURNING id`, hexAddress).Scan(&addressId) insertErr := db.QueryRow(`INSERT INTO public.addresses (address) VALUES($1) RETURNING id`, hexAddress).Scan(&addressId)
@ -41,11 +41,11 @@ func (AddressRepository) GetOrCreateAddress(db *postgres.DB, address string) (in
return addressId, getErr return addressId, getErr
} }
func (AddressRepository) GetOrCreateAddressInTransaction(tx *sqlx.Tx, address string) (int, error) { func (AddressRepository) GetOrCreateAddressInTransaction(tx *sqlx.Tx, address string) (int64, error) {
stringAddressToCommonAddress := common.HexToAddress(address) stringAddressToCommonAddress := common.HexToAddress(address)
hexAddress := stringAddressToCommonAddress.Hex() hexAddress := stringAddressToCommonAddress.Hex()
var addressId int var addressId int64
getErr := tx.Get(&addressId, `SELECT id FROM public.addresses WHERE address = $1`, hexAddress) getErr := tx.Get(&addressId, `SELECT id FROM public.addresses WHERE address = $1`, hexAddress)
if getErr == sql.ErrNoRows { if getErr == sql.ErrNoRows {
insertErr := tx.QueryRow(`INSERT INTO public.addresses (address) VALUES($1) RETURNING id`, hexAddress).Scan(&addressId) insertErr := tx.QueryRow(`INSERT INTO public.addresses (address) VALUES($1) RETURNING id`, hexAddress).Scan(&addressId)
@ -55,7 +55,7 @@ func (AddressRepository) GetOrCreateAddressInTransaction(tx *sqlx.Tx, address st
return addressId, getErr return addressId, getErr
} }
func (AddressRepository) GetAddressById(db *postgres.DB, id int) (string, error) { func (AddressRepository) GetAddressById(db *postgres.DB, id int64) (string, error) {
var address string var address string
getErr := db.Get(&address, `SELECT address FROM public.addresses WHERE id = $1`, id) getErr := db.Get(&address, `SELECT address FROM public.addresses WHERE id = $1`, id)
return address, getErr return address, getErr

View File

@ -41,8 +41,12 @@ var _ = Describe("address lookup", func() {
repo = repositories.AddressRepository{} repo = repositories.AddressRepository{}
}) })
AfterEach(func() {
test_config.CleanTestDB(db)
})
type dbAddress struct { type dbAddress struct {
Id int Id int64
Address string Address string
} }
@ -59,16 +63,17 @@ var _ = Describe("address lookup", func() {
}) })
It("returns the existing record id if the address already exists", func() { It("returns the existing record id if the address already exists", func() {
_, createErr := repo.GetOrCreateAddress(db, address) createId, createErr := repo.GetOrCreateAddress(db, address)
Expect(createErr).NotTo(HaveOccurred()) Expect(createErr).NotTo(HaveOccurred())
_, getErr := repo.GetOrCreateAddress(db, address) getId, getErr := repo.GetOrCreateAddress(db, address)
Expect(getErr).NotTo(HaveOccurred()) Expect(getErr).NotTo(HaveOccurred())
var addressCount int var addressCount int
addressErr := db.Get(&addressCount, `SELECT count(*) FROM public.addresses`) addressErr := db.Get(&addressCount, `SELECT count(*) FROM public.addresses`)
Expect(addressErr).NotTo(HaveOccurred()) Expect(addressErr).NotTo(HaveOccurred())
Expect(addressCount).To(Equal(1)) Expect(addressCount).To(Equal(1))
Expect(createId).To(Equal(getId))
}) })
It("gets upper-cased addresses", func() { It("gets upper-cased addresses", func() {
@ -102,10 +107,15 @@ var _ = Describe("address lookup", func() {
Expect(txErr).NotTo(HaveOccurred()) Expect(txErr).NotTo(HaveOccurred())
}) })
AfterEach(func() {
tx.Rollback()
})
It("creates an address record", func() { It("creates an address record", func() {
addressId, createErr := repo.GetOrCreateAddressInTransaction(tx, address) addressId, createErr := repo.GetOrCreateAddressInTransaction(tx, address)
Expect(createErr).NotTo(HaveOccurred()) Expect(createErr).NotTo(HaveOccurred())
tx.Commit() commitErr := tx.Commit()
Expect(commitErr).NotTo(HaveOccurred())
var actualAddress dbAddress var actualAddress dbAddress
getErr := db.Get(&actualAddress, `SELECT id, address FROM public.addresses LIMIT 1`) getErr := db.Get(&actualAddress, `SELECT id, address FROM public.addresses LIMIT 1`)