From 1373fe83a1a5d0293b4eb879dae4fa1062a5dd98 Mon Sep 17 00:00:00 2001 From: Elizabeth Engelman Date: Thu, 1 Aug 2019 15:44:13 -0500 Subject: [PATCH] Pass a db into GetOrCreateAddress --- .../repositories/address_repository.go | 10 ++++------ .../repositories/address_repository_test.go | 18 +++++++++--------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/pkg/datastore/postgres/repositories/address_repository.go b/pkg/datastore/postgres/repositories/address_repository.go index 9927d099..ff554112 100644 --- a/pkg/datastore/postgres/repositories/address_repository.go +++ b/pkg/datastore/postgres/repositories/address_repository.go @@ -20,18 +20,16 @@ import ( "github.com/vulcanize/vulcanizedb/pkg/datastore/postgres" ) -type AddressRepository struct { - *postgres.DB -} +type AddressRepository struct {} -func (repo AddressRepository) CreateOrGetAddress(address string) (int, error) { +func (repo AddressRepository) CreateOrGetAddress(db *postgres.DB, address string) (int, error) { stringAddressToCommonAddress := common.HexToAddress(address) hexAddress := stringAddressToCommonAddress.Hex() var addressId int - getErr := repo.DB.Get(&addressId, `SELECT id FROM public.addresses WHERE address = $1`, hexAddress) + getErr := db.Get(&addressId, `SELECT id FROM public.addresses WHERE address = $1`, hexAddress) if getErr == sql.ErrNoRows { - insertErr := repo.DB.QueryRow(`INSERT INTO public.addresses (address) VALUES($1) RETURNING id`, hexAddress).Scan(&addressId) + insertErr := db.QueryRow(`INSERT INTO public.addresses (address) VALUES($1) RETURNING id`, hexAddress).Scan(&addressId) return addressId, insertErr } diff --git a/pkg/datastore/postgres/repositories/address_repository_test.go b/pkg/datastore/postgres/repositories/address_repository_test.go index 417bc9ab..db13673e 100644 --- a/pkg/datastore/postgres/repositories/address_repository_test.go +++ b/pkg/datastore/postgres/repositories/address_repository_test.go @@ -34,7 +34,7 @@ var _ = Describe("address repository", func() { BeforeEach(func() { db = test_config.NewTestDB(test_config.NewTestNode()) test_config.CleanTestDB(db) - repo = repositories.AddressRepository{DB: db} + repo = repositories.AddressRepository{} }) type dbAddress struct { @@ -43,7 +43,7 @@ var _ = Describe("address repository", func() { } It("creates an address record", func() { - addressId, createErr := repo.CreateOrGetAddress(address) + addressId, createErr := repo.CreateOrGetAddress(db, address) Expect(createErr).NotTo(HaveOccurred()) var actualAddress dbAddress @@ -54,24 +54,24 @@ var _ = Describe("address repository", func() { }) It("returns the existing record id if the address already exists", func() { - _, createErr := repo.CreateOrGetAddress(address) + _, createErr := repo.CreateOrGetAddress(db, address) Expect(createErr).NotTo(HaveOccurred()) - _, getErr := repo.CreateOrGetAddress(address) + _, getErr := repo.CreateOrGetAddress(db, address) Expect(getErr).NotTo(HaveOccurred()) var addressCount int - addressErr := repo.DB.Get(&addressCount, `SELECT count(*) FROM public.addresses`) + addressErr := db.Get(&addressCount, `SELECT count(*) FROM public.addresses`) Expect(addressErr).NotTo(HaveOccurred()) }) It("gets upper-cased addresses", func() { //insert it as all upper upperAddress := strings.ToUpper(address) - upperAddressId, createErr := repo.CreateOrGetAddress(upperAddress) + upperAddressId, createErr := repo.CreateOrGetAddress(db, upperAddress) Expect(createErr).NotTo(HaveOccurred()) - mixedCaseAddressId, getErr := repo.CreateOrGetAddress(address) + mixedCaseAddressId, getErr := repo.CreateOrGetAddress(db, address) Expect(getErr).NotTo(HaveOccurred()) Expect(upperAddressId).To(Equal(mixedCaseAddressId)) }) @@ -79,10 +79,10 @@ var _ = Describe("address repository", func() { It("gets lower-cased addresses", func() { //insert it as all upper lowerAddress := strings.ToLower(address) - upperAddressId, createErr := repo.CreateOrGetAddress(lowerAddress) + upperAddressId, createErr := repo.CreateOrGetAddress(db, lowerAddress) Expect(createErr).NotTo(HaveOccurred()) - mixedCaseAddressId, getErr := repo.CreateOrGetAddress(address) + mixedCaseAddressId, getErr := repo.CreateOrGetAddress(db, address) Expect(getErr).NotTo(HaveOccurred()) Expect(upperAddressId).To(Equal(mixedCaseAddressId)) })