vat.fold: add repository tests for MarkHeaderChecked

This commit is contained in:
David Terry 2018-10-10 13:41:26 +03:00
parent 3778d7ac06
commit 1273acb733
3 changed files with 76 additions and 14 deletions

View File

@ -24,7 +24,7 @@ import (
type MockVatFoldRepository struct { type MockVatFoldRepository struct {
createErr error createErr error
markHeaderCheckedErr error markHeaderCheckedErr error
markHeaderCheckedPassedHeaderID int64 MarkHeaderCheckedPassedHeaderID int64
missingHeaders []core.Header missingHeaders []core.Header
missingHeadersErr error missingHeadersErr error
PassedStartingBlockNumber int64 PassedStartingBlockNumber int64
@ -40,7 +40,7 @@ func (repository *MockVatFoldRepository) Create(headerID int64, models []vat_fol
} }
func (repository *MockVatFoldRepository) MarkHeaderChecked(headerID int64) error { func (repository *MockVatFoldRepository) MarkHeaderChecked(headerID int64) error {
repository.markHeaderCheckedPassedHeaderID = headerID repository.MarkHeaderCheckedPassedHeaderID = headerID
return repository.markHeaderCheckedErr return repository.markHeaderCheckedErr
} }
@ -67,5 +67,5 @@ func (repository *MockVatFoldRepository) SetCreateError(e error) {
} }
func (repository *MockVatFoldRepository) AssertMarkHeaderCheckedCalledWith(i int64) { func (repository *MockVatFoldRepository) AssertMarkHeaderCheckedCalledWith(i int64) {
Expect(repository.markHeaderCheckedPassedHeaderID).To(Equal(i)) Expect(repository.MarkHeaderCheckedPassedHeaderID).To(Equal(i))
} }

View File

@ -26,17 +26,17 @@ type Repository interface {
} }
type VatFoldRepository struct { type VatFoldRepository struct {
db *postgres.DB DB *postgres.DB
} }
func NewVatFoldRepository(db *postgres.DB) VatFoldRepository { func NewVatFoldRepository(db *postgres.DB) VatFoldRepository {
return VatFoldRepository{ return VatFoldRepository{
db: db, DB: db,
} }
} }
func (repository VatFoldRepository) Create(headerID int64, models []VatFoldModel) error { func (repository VatFoldRepository) Create(headerID int64, models []VatFoldModel) error {
tx, err := repository.db.Begin() tx, err := repository.DB.Begin()
if err != nil { if err != nil {
return err return err
} }
@ -46,6 +46,13 @@ func (repository VatFoldRepository) Create(headerID int64, models []VatFoldModel
VALUES($1, $2, $3, $4::NUMERIC, $5, $6)`, VALUES($1, $2, $3, $4::NUMERIC, $5, $6)`,
headerID, model.Ilk, model.Urn, model.Rate, model.TransactionIndex, model.Raw, headerID, model.Ilk, model.Urn, model.Rate, model.TransactionIndex, model.Raw,
) )
_, err = tx.Exec(
`INSERT INTO public.checked_headers (header_id, vat_fold_checked)
VALUES($1, $2)
ON CONFLICT (header_id) DO
UPDATE SET vat_fold_checked = $2`,
headerID, true,
)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
@ -56,16 +63,19 @@ func (repository VatFoldRepository) Create(headerID int64, models []VatFoldModel
} }
func (repository VatFoldRepository) MarkHeaderChecked(headerID int64) error { func (repository VatFoldRepository) MarkHeaderChecked(headerID int64) error {
_, err := repository.db.Exec(`INSERT INTO public.checked_headers (header_id, vat_fold_checked) _, err := repository.DB.Exec(
`INSERT INTO public.checked_headers (header_id, vat_fold_checked)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (header_id) DO ON CONFLICT (header_id) DO
UPDATE SET vat_fold_checked = $2`, headerID, true) UPDATE SET vat_fold_checked = $2`,
headerID, true,
)
return err return err
} }
func (repository VatFoldRepository) MissingHeaders(startingBlockNumber, endingBlockNumber int64) ([]core.Header, error) { func (repository VatFoldRepository) MissingHeaders(startingBlockNumber, endingBlockNumber int64) ([]core.Header, error) {
var result []core.Header var result []core.Header
err := repository.db.Select( err := repository.DB.Select(
&result, &result,
`SELECT headers.id, headers.block_number FROM headers `SELECT headers.id, headers.block_number FROM headers
LEFT JOIN checked_headers on headers.id = header_id LEFT JOIN checked_headers on headers.id = header_id
@ -75,7 +85,7 @@ func (repository VatFoldRepository) MissingHeaders(startingBlockNumber, endingBl
AND headers.eth_node_fingerprint = $3`, AND headers.eth_node_fingerprint = $3`,
startingBlockNumber, startingBlockNumber,
endingBlockNumber, endingBlockNumber,
repository.db.Node.ID, repository.DB.Node.ID,
) )
return result, err return result, err
} }

View File

@ -28,7 +28,7 @@ import (
"github.com/vulcanize/vulcanizedb/test_config" "github.com/vulcanize/vulcanizedb/test_config"
) )
var _ = Describe("", func() { var _ = Describe("Vat.fold repository", func() {
Describe("Create", func() { Describe("Create", func() {
@ -80,6 +80,58 @@ var _ = Describe("", func() {
}) })
}) })
Describe("MarkHeaderChecked", func() {
var db *postgres.DB
var headerID int64
var repository vat_fold.VatFoldRepository
type CheckedHeaderResult struct {
VatFoldChecked bool `db:"vat_fold_checked"`
}
BeforeEach(func() {
node := test_config.NewTestNode()
db = test_config.NewTestDB(node)
test_config.CleanTestDB(db)
headerRepository := repositories.NewHeaderRepository(db)
id, err := headerRepository.CreateOrUpdateHeader(core.Header{})
Expect(err).NotTo(HaveOccurred())
headerID = id
repository = vat_fold.NewVatFoldRepository(db)
Expect(err).NotTo(HaveOccurred())
})
It("creates a new checked header record", func() {
err := repository.MarkHeaderChecked(headerID)
Expect(err).NotTo(HaveOccurred())
var checkedHeaderResult = CheckedHeaderResult{}
err = db.Get(&checkedHeaderResult, `SELECT vat_fold_checked FROM checked_headers WHERE header_id = $1`, headerID)
Expect(err).NotTo(HaveOccurred())
Expect(checkedHeaderResult.VatFoldChecked).To(BeTrue())
})
It("updates an existing checked header", func() {
_, err := repository.DB.Exec(`INSERT INTO checked_headers (header_id) VALUES($1)`, headerID)
Expect(err).NotTo(HaveOccurred())
var checkedHeaderResult CheckedHeaderResult
err = db.Get(&checkedHeaderResult, `SELECT vat_fold_checked FROM checked_headers WHERE header_id = $1`, headerID)
Expect(err).NotTo(HaveOccurred())
Expect(checkedHeaderResult.VatFoldChecked).To(BeFalse())
err = repository.MarkHeaderChecked(headerID)
Expect(err).NotTo(HaveOccurred())
err = db.Get(&checkedHeaderResult, `SELECT vat_fold_checked FROM checked_headers WHERE header_id = $1`, headerID)
Expect(err).NotTo(HaveOccurred())
Expect(checkedHeaderResult.VatFoldChecked).To(BeTrue())
})
})
Describe("MissingHeaders", func() { Describe("MissingHeaders", func() {
It("returns headers that haven't been checked", func() { It("returns headers that haven't been checked", func() {