From 4b61c87b553421b4ff5c1962d5fec7aa9b5be0de Mon Sep 17 00:00:00 2001 From: Elizabeth Engelman Date: Fri, 2 Aug 2019 08:52:14 -0500 Subject: [PATCH] Get or create address record in a transaction --- .../repositories/address_repository.go | 18 ++- .../repositories/address_repository_test.go | 150 ++++++++++++------ 2 files changed, 121 insertions(+), 47 deletions(-) diff --git a/pkg/datastore/postgres/repositories/address_repository.go b/pkg/datastore/postgres/repositories/address_repository.go index ff554112..7e9d4724 100644 --- a/pkg/datastore/postgres/repositories/address_repository.go +++ b/pkg/datastore/postgres/repositories/address_repository.go @@ -17,12 +17,13 @@ package repositories import ( "database/sql" "github.com/ethereum/go-ethereum/common" + "github.com/jmoiron/sqlx" "github.com/vulcanize/vulcanizedb/pkg/datastore/postgres" ) -type AddressRepository struct {} +type AddressRepository struct{} -func (repo AddressRepository) CreateOrGetAddress(db *postgres.DB, address string) (int, error) { +func (repo AddressRepository) GetOrCreateAddress(db *postgres.DB, address string) (int, error) { stringAddressToCommonAddress := common.HexToAddress(address) hexAddress := stringAddressToCommonAddress.Hex() @@ -36,3 +37,16 @@ func (repo AddressRepository) CreateOrGetAddress(db *postgres.DB, address string return addressId, getErr } +func (repo AddressRepository) GetOrCreateAddressInTransaction(tx *sqlx.Tx, address string) (int, error) { + stringAddressToCommonAddress := common.HexToAddress(address) + hexAddress := stringAddressToCommonAddress.Hex() + + var addressId int + getErr := tx.Get(&addressId, `SELECT id FROM public.addresses WHERE address = $1`, hexAddress) + if getErr == sql.ErrNoRows { + insertErr := tx.QueryRow(`INSERT INTO public.addresses (address) VALUES($1) RETURNING id`, hexAddress).Scan(&addressId) + return addressId, insertErr + } + + return addressId, getErr +} diff --git a/pkg/datastore/postgres/repositories/address_repository_test.go b/pkg/datastore/postgres/repositories/address_repository_test.go index db13673e..b370ea18 100644 --- a/pkg/datastore/postgres/repositories/address_repository_test.go +++ b/pkg/datastore/postgres/repositories/address_repository_test.go @@ -15,20 +15,20 @@ package repositories_test import ( + "github.com/jmoiron/sqlx" . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" "github.com/vulcanize/vulcanizedb/pkg/datastore/postgres" "github.com/vulcanize/vulcanizedb/pkg/datastore/postgres/repositories" "github.com/vulcanize/vulcanizedb/pkg/fakes" "github.com/vulcanize/vulcanizedb/test_config" "strings" - - . "github.com/onsi/gomega" ) -var _ = Describe("address repository", func() { +var _ = Describe("address lookup", func() { var ( - db *postgres.DB - repo repositories.AddressRepository + db *postgres.DB + repo repositories.AddressRepository address = fakes.FakeAddress.Hex() ) BeforeEach(func() { @@ -38,52 +38,112 @@ var _ = Describe("address repository", func() { }) type dbAddress struct { - Id int + Id int Address string } - It("creates an address record", func() { - addressId, createErr := repo.CreateOrGetAddress(db, address) - Expect(createErr).NotTo(HaveOccurred()) + Describe("GetOrCreateAddress", func() { + It("creates an address record", func() { + addressId, createErr := repo.GetOrCreateAddress(db, address) + Expect(createErr).NotTo(HaveOccurred()) - var actualAddress dbAddress - getErr := db.Get(&actualAddress, `SELECT id, address FROM public.addresses LIMIT 1`) - Expect(getErr).NotTo(HaveOccurred()) - expectedAddress := dbAddress{Id: addressId, Address: address} - Expect(actualAddress).To(Equal(expectedAddress)) + var actualAddress dbAddress + getErr := db.Get(&actualAddress, `SELECT id, address FROM public.addresses LIMIT 1`) + Expect(getErr).NotTo(HaveOccurred()) + expectedAddress := dbAddress{Id: addressId, Address: address} + Expect(actualAddress).To(Equal(expectedAddress)) + }) + + It("returns the existing record id if the address already exists", func() { + _, createErr := repo.GetOrCreateAddress(db, address) + Expect(createErr).NotTo(HaveOccurred()) + + _, getErr := repo.GetOrCreateAddress(db, address) + Expect(getErr).NotTo(HaveOccurred()) + + var addressCount int + addressErr := db.Get(&addressCount, `SELECT count(*) FROM public.addresses`) + Expect(addressErr).NotTo(HaveOccurred()) + }) + + It("gets upper-cased addresses", func() { + upperAddress := strings.ToUpper(address) + upperAddressId, createErr := repo.GetOrCreateAddress(db, upperAddress) + Expect(createErr).NotTo(HaveOccurred()) + + mixedCaseAddressId, getErr := repo.GetOrCreateAddress(db, address) + Expect(getErr).NotTo(HaveOccurred()) + Expect(upperAddressId).To(Equal(mixedCaseAddressId)) + }) + + It("gets lower-cased addresses", func() { + lowerAddress := strings.ToLower(address) + upperAddressId, createErr := repo.GetOrCreateAddress(db, lowerAddress) + Expect(createErr).NotTo(HaveOccurred()) + + mixedCaseAddressId, getErr := repo.GetOrCreateAddress(db, address) + Expect(getErr).NotTo(HaveOccurred()) + Expect(upperAddressId).To(Equal(mixedCaseAddressId)) + }) }) - It("returns the existing record id if the address already exists", func() { - _, createErr := repo.CreateOrGetAddress(db, address) - Expect(createErr).NotTo(HaveOccurred()) + Describe("GetOrCreateAddressInTransaction", func() { + var ( + tx *sqlx.Tx + txErr error + ) + BeforeEach(func() { + tx, txErr = db.Beginx() + Expect(txErr).NotTo(HaveOccurred()) + }) - _, getErr := repo.CreateOrGetAddress(db, address) - Expect(getErr).NotTo(HaveOccurred()) + It("creates an address record", func() { + addressId, createErr := repo.GetOrCreateAddressInTransaction(tx, address) + Expect(createErr).NotTo(HaveOccurred()) + tx.Commit() - var addressCount int - addressErr := db.Get(&addressCount, `SELECT count(*) FROM public.addresses`) - Expect(addressErr).NotTo(HaveOccurred()) + var actualAddress dbAddress + getErr := db.Get(&actualAddress, `SELECT id, address FROM public.addresses LIMIT 1`) + Expect(getErr).NotTo(HaveOccurred()) + expectedAddress := dbAddress{Id: addressId, Address: address} + Expect(actualAddress).To(Equal(expectedAddress)) + }) + + It("returns the existing record id if the address already exists", func() { + _, createErr := repo.GetOrCreateAddressInTransaction(tx, address) + Expect(createErr).NotTo(HaveOccurred()) + + _, getErr := repo.GetOrCreateAddressInTransaction(tx, address) + Expect(getErr).NotTo(HaveOccurred()) + tx.Commit() + + var addressCount int + addressErr := db.Get(&addressCount, `SELECT count(*) FROM public.addresses`) + Expect(addressErr).NotTo(HaveOccurred()) + }) + + It("gets upper-cased addresses", func() { + upperAddress := strings.ToUpper(address) + upperAddressId, createErr := repo.GetOrCreateAddressInTransaction(tx, upperAddress) + Expect(createErr).NotTo(HaveOccurred()) + + mixedCaseAddressId, getErr := repo.GetOrCreateAddressInTransaction(tx, address) + Expect(getErr).NotTo(HaveOccurred()) + tx.Commit() + + Expect(upperAddressId).To(Equal(mixedCaseAddressId)) + }) + + It("gets lower-cased addresses", func() { + lowerAddress := strings.ToLower(address) + upperAddressId, createErr := repo.GetOrCreateAddressInTransaction(tx, lowerAddress) + Expect(createErr).NotTo(HaveOccurred()) + + mixedCaseAddressId, getErr := repo.GetOrCreateAddressInTransaction(tx, address) + Expect(getErr).NotTo(HaveOccurred()) + tx.Commit() + + Expect(upperAddressId).To(Equal(mixedCaseAddressId)) + }) }) - - It("gets upper-cased addresses", func() { - //insert it as all upper - upperAddress := strings.ToUpper(address) - upperAddressId, createErr := repo.CreateOrGetAddress(db, upperAddress) - Expect(createErr).NotTo(HaveOccurred()) - - mixedCaseAddressId, getErr := repo.CreateOrGetAddress(db, address) - Expect(getErr).NotTo(HaveOccurred()) - Expect(upperAddressId).To(Equal(mixedCaseAddressId)) - }) - - It("gets lower-cased addresses", func() { - //insert it as all upper - lowerAddress := strings.ToLower(address) - upperAddressId, createErr := repo.CreateOrGetAddress(db, lowerAddress) - Expect(createErr).NotTo(HaveOccurred()) - - mixedCaseAddressId, getErr := repo.CreateOrGetAddress(db, address) - Expect(getErr).NotTo(HaveOccurred()) - Expect(upperAddressId).To(Equal(mixedCaseAddressId)) - }) -}) \ No newline at end of file +})