feat(indexer/postgres): add insert/update/delete functionality (#21186)

This commit is contained in:
Aaron Craelius 2024-09-04 08:06:49 -04:00 committed by GitHub
parent 4b78f15f65
commit 292d7b49c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 951 additions and 1 deletions

View File

@ -0,0 +1,61 @@
package postgres
import (
"context"
"fmt"
"io"
"strings"
)
// delete deletes the row with the provided key from the table.
func (tm *objectIndexer) delete(ctx context.Context, conn dbConn, key interface{}) error {
buf := new(strings.Builder)
var params []interface{}
var err error
if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions {
params, err = tm.retainDeleteSqlAndParams(buf, key)
} else {
params, err = tm.deleteSqlAndParams(buf, key)
}
if err != nil {
return err
}
sqlStr := buf.String()
tm.options.logger.Info("Delete", "sql", sqlStr, "params", params)
_, err = conn.ExecContext(ctx, sqlStr, params...)
return err
}
// deleteSqlAndParams generates a DELETE statement and binding parameters for the provided key.
func (tm *objectIndexer) deleteSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) {
_, err := fmt.Fprintf(w, "DELETE FROM %q", tm.tableName())
if err != nil {
return nil, err
}
_, keyParams, err := tm.whereSqlAndParams(w, key, 1)
if err != nil {
return nil, err
}
_, err = fmt.Fprintf(w, ";")
return keyParams, err
}
// retainDeleteSqlAndParams generates an UPDATE statement to set the _deleted column to true for the provided key
// which is used when the table is set to retain deletions mode.
func (tm *objectIndexer) retainDeleteSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) {
_, err := fmt.Fprintf(w, "UPDATE %q SET _deleted = TRUE", tm.tableName())
if err != nil {
return nil, err
}
_, keyParams, err := tm.whereSqlAndParams(w, key, 1)
if err != nil {
return nil, err
}
_, err = fmt.Fprintf(w, ";")
return keyParams, err
}

View File

@ -72,6 +72,7 @@ func StartIndexer(params indexer.InitParams) (indexer.InitResult, error) {
opts := options{
disableRetainDeletions: config.DisableRetainDeletions,
logger: params.Logger,
addressCodec: params.AddressCodec,
}
idx := &indexerImpl{
@ -85,6 +86,7 @@ func StartIndexer(params indexer.InitParams) (indexer.InitResult, error) {
return indexer.InitResult{
Listener: idx.listener(),
View: idx,
}, nil
}

View File

@ -0,0 +1,116 @@
package postgres
import (
"context"
"fmt"
"io"
"strings"
)
// insertUpdate inserts or updates the row with the provided key and value.
func (tm *objectIndexer) insertUpdate(ctx context.Context, conn dbConn, key, value interface{}) error {
exists, err := tm.exists(ctx, conn, key)
if err != nil {
return err
}
buf := new(strings.Builder)
var params []interface{}
if exists {
if len(tm.typ.ValueFields) == 0 {
// special case where there are no value fields, so we can't update anything
return nil
}
params, err = tm.updateSql(buf, key, value)
} else {
params, err = tm.insertSql(buf, key, value)
}
if err != nil {
return err
}
sqlStr := buf.String()
if tm.options.logger != nil {
tm.options.logger.Debug("Insert or Update", "sql", sqlStr, "params", params)
}
_, err = conn.ExecContext(ctx, sqlStr, params...)
return err
}
// insertSql generates an INSERT statement and binding parameters for the provided key and value.
func (tm *objectIndexer) insertSql(w io.Writer, key, value interface{}) ([]interface{}, error) {
keyParams, keyCols, err := tm.bindKeyParams(key)
if err != nil {
return nil, err
}
valueParams, valueCols, err := tm.bindValueParams(value)
if err != nil {
return nil, err
}
var allParams []interface{}
allParams = append(allParams, keyParams...)
allParams = append(allParams, valueParams...)
allCols := make([]string, 0, len(keyCols)+len(valueCols))
allCols = append(allCols, keyCols...)
allCols = append(allCols, valueCols...)
var paramBindings []string
for i := 1; i <= len(allCols); i++ {
paramBindings = append(paramBindings, fmt.Sprintf("$%d", i))
}
_, err = fmt.Fprintf(w, "INSERT INTO %q (%s) VALUES (%s);", tm.tableName(),
strings.Join(allCols, ", "),
strings.Join(paramBindings, ", "),
)
return allParams, err
}
// updateSql generates an UPDATE statement and binding parameters for the provided key and value.
func (tm *objectIndexer) updateSql(w io.Writer, key, value interface{}) ([]interface{}, error) {
_, err := fmt.Fprintf(w, "UPDATE %q SET ", tm.tableName())
if err != nil {
return nil, err
}
valueParams, valueCols, err := tm.bindValueParams(value)
if err != nil {
return nil, err
}
paramIdx := 1
for i, col := range valueCols {
if i > 0 {
_, err = fmt.Fprintf(w, ", ")
if err != nil {
return nil, err
}
}
_, err = fmt.Fprintf(w, "%s = $%d", col, paramIdx)
if err != nil {
return nil, err
}
paramIdx++
}
if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions {
_, err = fmt.Fprintf(w, ", _deleted = FALSE")
if err != nil {
return nil, err
}
}
_, keyParams, err := tm.whereSqlAndParams(w, key, paramIdx)
if err != nil {
return nil, err
}
allParams := append(valueParams, keyParams...)
_, err = fmt.Fprintf(w, ";")
return allParams, err
}

View File

@ -25,6 +25,34 @@ func (i *indexerImpl) listener() appdata.Listener {
_, err := i.tx.Exec("INSERT INTO block (number) VALUES ($1)", data.Height)
return err
},
OnObjectUpdate: func(data appdata.ObjectUpdateData) error {
module := data.ModuleName
mod, ok := i.modules[module]
if !ok {
return fmt.Errorf("module %s not initialized", module)
}
for _, update := range data.Updates {
if i.logger != nil {
i.logger.Debug("OnObjectUpdate", "module", module, "type", update.TypeName, "key", update.Key, "delete", update.Delete, "value", update.Value)
}
tm, ok := mod.tables[update.TypeName]
if !ok {
return fmt.Errorf("object type %s not found in schema for module %s", update.TypeName, module)
}
var err error
if update.Delete {
err = tm.delete(i.ctx, i.tx, update.Key)
} else {
err = tm.insertUpdate(i.ctx, i.tx, update.Key, update.Value)
}
if err != nil {
return err
}
}
return nil
},
Commit: func(data appdata.CommitData) (func() error, error) {
err := i.tx.Commit()
if err != nil {

View File

@ -1,6 +1,9 @@
package postgres
import "cosmossdk.io/schema/logutil"
import (
"cosmossdk.io/schema/addressutil"
"cosmossdk.io/schema/logutil"
)
// options are the options for module and object indexers.
type options struct {
@ -9,4 +12,7 @@ type options struct {
// logger is the logger for the indexer to use. It may be nil.
logger logutil.Logger
// addressCodec is the codec for encoding and decoding addresses. It is expected to be non-nil.
addressCodec addressutil.AddressCodec
}

116
indexer/postgres/params.go Normal file
View File

@ -0,0 +1,116 @@
package postgres
import (
"fmt"
"time"
"cosmossdk.io/schema"
)
// bindKeyParams binds the key to the key columns.
func (tm *objectIndexer) bindKeyParams(key interface{}) ([]interface{}, []string, error) {
n := len(tm.typ.KeyFields)
if n == 0 {
// singleton, set _id = 1
return []interface{}{1}, []string{"_id"}, nil
} else if n == 1 {
return tm.bindParams(tm.typ.KeyFields, []interface{}{key})
} else {
key, ok := key.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("expected key to be a slice")
}
return tm.bindParams(tm.typ.KeyFields, key)
}
}
func (tm *objectIndexer) bindValueParams(value interface{}) (params []interface{}, valueCols []string, err error) {
n := len(tm.typ.ValueFields)
if n == 0 {
return nil, nil, nil
} else if valueUpdates, ok := value.(schema.ValueUpdates); ok {
var e error
var fields []schema.Field
var params []interface{}
if err := valueUpdates.Iterate(func(name string, value interface{}) bool {
field, ok := tm.valueFields[name]
if !ok {
e = fmt.Errorf("unknown column %q", name)
return false
}
fields = append(fields, field)
params = append(params, value)
return true
}); err != nil {
return nil, nil, err
}
if e != nil {
return nil, nil, e
}
return tm.bindParams(fields, params)
} else if n == 1 {
return tm.bindParams(tm.typ.ValueFields, []interface{}{value})
} else {
values, ok := value.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("expected values to be a slice")
}
return tm.bindParams(tm.typ.ValueFields, values)
}
}
func (tm *objectIndexer) bindParams(fields []schema.Field, values []interface{}) ([]interface{}, []string, error) {
names := make([]string, 0, len(fields))
params := make([]interface{}, 0, len(fields))
for i, field := range fields {
if i >= len(values) {
return nil, nil, fmt.Errorf("missing value for field %q", field.Name)
}
param, err := tm.bindParam(field, values[i])
if err != nil {
return nil, nil, err
}
name, err := tm.updatableColumnName(field)
if err != nil {
return nil, nil, err
}
names = append(names, name)
params = append(params, param)
}
return params, names, nil
}
func (tm *objectIndexer) bindParam(field schema.Field, value interface{}) (param interface{}, err error) {
param = value
if value == nil {
if !field.Nullable {
return nil, fmt.Errorf("expected non-null value for field %q", field.Name)
}
} else if field.Kind == schema.TimeKind {
t, ok := value.(time.Time)
if !ok {
return nil, fmt.Errorf("expected time.Time value for field %q, got %T", field.Name, value)
}
param = t.UnixNano()
} else if field.Kind == schema.DurationKind {
t, ok := value.(time.Duration)
if !ok {
return nil, fmt.Errorf("expected time.Duration value for field %q, got %T", field.Name, value)
}
param = int64(t)
} else if field.Kind == schema.AddressKind {
param, err = tm.options.addressCodec.BytesToString(value.([]byte))
if err != nil {
return nil, fmt.Errorf("address encoding failed for field %q: %w", field.Name, err)
}
}
return
}

299
indexer/postgres/select.go Normal file
View File

@ -0,0 +1,299 @@
package postgres
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"strings"
"time"
"cosmossdk.io/schema"
)
// Count returns the number of rows in the table.
func (tm *objectIndexer) count(ctx context.Context, conn dbConn) (int, error) {
sqlStr := fmt.Sprintf("SELECT COUNT(*) FROM %q;", tm.tableName())
if tm.options.logger != nil {
tm.options.logger.Debug("Count", "sql", sqlStr)
}
row := conn.QueryRowContext(ctx, sqlStr)
var count int
err := row.Scan(&count)
return count, err
}
// exists checks if a row with the provided key exists in the table.
func (tm *objectIndexer) exists(ctx context.Context, conn dbConn, key interface{}) (bool, error) {
buf := new(strings.Builder)
params, err := tm.existsSqlAndParams(buf, key)
if err != nil {
return false, err
}
return tm.checkExists(ctx, conn, buf.String(), params)
}
// checkExists checks if a row exists in the table.
func (tm *objectIndexer) checkExists(ctx context.Context, conn dbConn, sqlStr string, params []interface{}) (bool, error) {
if tm.options.logger != nil {
tm.options.logger.Debug("Check exists", "sql", sqlStr, "params", params)
}
var res interface{}
err := conn.QueryRowContext(ctx, sqlStr, params...).Scan(&res)
switch err {
case nil:
return true, nil
case sql.ErrNoRows:
return false, nil
default:
return false, err
}
}
// existsSqlAndParams generates a SELECT statement to check if a row with the provided key exists in the table.
func (tm *objectIndexer) existsSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) {
_, err := fmt.Fprintf(w, "SELECT 1 FROM %q", tm.tableName())
if err != nil {
return nil, err
}
_, keyParams, err := tm.whereSqlAndParams(w, key, 1)
if err != nil {
return nil, err
}
_, err = fmt.Fprintf(w, ";")
return keyParams, err
}
func (tm *objectIndexer) get(ctx context.Context, conn dbConn, key interface{}) (schema.ObjectUpdate, bool, error) {
buf := new(strings.Builder)
params, err := tm.getSqlAndParams(buf, key)
if err != nil {
return schema.ObjectUpdate{}, false, err
}
sqlStr := buf.String()
if tm.options.logger != nil {
tm.options.logger.Debug("Get", "sql", sqlStr, "params", params)
}
row := conn.QueryRowContext(ctx, sqlStr, params...)
return tm.readRow(row)
}
func (tm *objectIndexer) selectAllSql(w io.Writer) error {
err := tm.selectAllClause(w)
if err != nil {
return err
}
_, err = fmt.Fprintf(w, ";")
return err
}
func (tm *objectIndexer) getSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) {
err := tm.selectAllClause(w)
if err != nil {
return nil, err
}
keyParams, keyCols, err := tm.bindKeyParams(key)
if err != nil {
return nil, err
}
_, keyParams, err = tm.whereSql(w, keyParams, keyCols, 1)
if err != nil {
return nil, err
}
_, err = fmt.Fprintf(w, ";")
return keyParams, err
}
func (tm *objectIndexer) selectAllClause(w io.Writer) error {
allFields := make([]string, 0, len(tm.typ.KeyFields)+len(tm.typ.ValueFields))
for _, field := range tm.typ.KeyFields {
colName, err := tm.updatableColumnName(field)
if err != nil {
return err
}
allFields = append(allFields, colName)
}
for _, field := range tm.typ.ValueFields {
colName, err := tm.updatableColumnName(field)
if err != nil {
return err
}
allFields = append(allFields, colName)
}
if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions {
allFields = append(allFields, "_deleted")
}
_, err := fmt.Fprintf(w, "SELECT %s FROM %q", strings.Join(allFields, ", "), tm.tableName())
if err != nil {
return err
}
return nil
}
func (tm *objectIndexer) readRow(row interface{ Scan(...interface{}) error }) (schema.ObjectUpdate, bool, error) {
var res []interface{}
for _, f := range tm.typ.KeyFields {
res = append(res, tm.colBindValue(f))
}
for _, f := range tm.typ.ValueFields {
res = append(res, tm.colBindValue(f))
}
if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions {
res = append(res, new(bool))
}
err := row.Scan(res...)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return schema.ObjectUpdate{}, false, err
}
return schema.ObjectUpdate{}, false, err
}
var keys []interface{}
for _, field := range tm.typ.KeyFields {
x, err := tm.readCol(field, res[0])
if err != nil {
return schema.ObjectUpdate{}, false, err
}
keys = append(keys, x)
res = res[1:]
}
var key interface{} = keys
if len(keys) == 1 {
key = keys[0]
}
var values []interface{}
for _, field := range tm.typ.ValueFields {
x, err := tm.readCol(field, res[0])
if err != nil {
return schema.ObjectUpdate{}, false, err
}
values = append(values, x)
res = res[1:]
}
var value interface{} = values
if len(values) == 1 {
value = values[0]
}
update := schema.ObjectUpdate{
TypeName: tm.typ.Name,
Key: key,
Value: value,
}
if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions {
deleted := res[0].(*bool)
if *deleted {
update.Delete = true
}
}
return update, true, nil
}
func (tm *objectIndexer) colBindValue(field schema.Field) interface{} {
switch field.Kind {
case schema.BytesKind:
return new(interface{})
default:
return new(sql.NullString)
}
}
func (tm *objectIndexer) readCol(field schema.Field, value interface{}) (interface{}, error) {
switch field.Kind {
case schema.BytesKind:
// for bytes types we either get []byte or nil
value = *value.(*interface{})
return value, nil
default:
}
nullStr := *value.(*sql.NullString)
if field.Nullable {
if !nullStr.Valid {
return nil, nil
}
}
str := nullStr.String
switch field.Kind {
case schema.StringKind, schema.EnumKind, schema.IntegerStringKind, schema.DecimalStringKind:
return str, nil
case schema.Uint8Kind:
value, err := strconv.ParseUint(str, 10, 8)
return uint8(value), err
case schema.Uint16Kind:
value, err := strconv.ParseUint(str, 10, 16)
return uint16(value), err
case schema.Uint32Kind:
value, err := strconv.ParseUint(str, 10, 32)
return uint32(value), err
case schema.Uint64Kind:
value, err := strconv.ParseUint(str, 10, 64)
return value, err
case schema.Int8Kind:
value, err := strconv.ParseInt(str, 10, 8)
return int8(value), err
case schema.Int16Kind:
value, err := strconv.ParseInt(str, 10, 16)
return int16(value), err
case schema.Int32Kind:
value, err := strconv.ParseInt(str, 10, 32)
return int32(value), err
case schema.Int64Kind:
value, err := strconv.ParseInt(str, 10, 64)
return value, err
case schema.Float32Kind:
value, err := strconv.ParseFloat(str, 32)
return float32(value), err
case schema.Float64Kind:
value, err := strconv.ParseFloat(str, 64)
return value, err
case schema.BoolKind:
value, err := strconv.ParseBool(str)
return value, err
case schema.JSONKind:
return json.RawMessage(str), nil
case schema.TimeKind:
value, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return nil, err
}
return time.Unix(0, value), nil
case schema.DurationKind:
value, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return nil, err
}
return time.Duration(value), nil
case schema.AddressKind:
return tm.options.addressCodec.StringToBytes(str)
default:
return value, nil
}
}

View File

@ -5,6 +5,7 @@ go 1.23
require (
cosmossdk.io/indexer/postgres v0.0.0-00010101000000-000000000000
cosmossdk.io/schema v0.1.1
cosmossdk.io/schema/testing v0.0.0
github.com/fergusstrange/embedded-postgres v1.29.0
github.com/hashicorp/consul/sdk v0.16.1
github.com/jackc/pgx/v5 v5.6.0
@ -13,6 +14,7 @@ require (
)
require (
github.com/cockroachdb/apd/v3 v3.2.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@ -22,14 +24,18 @@ require (
github.com/lib/pq v1.10.9 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/tidwall/btree v1.7.0 // indirect
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.23.0 // indirect
golang.org/x/text v0.17.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
pgregory.net/rapid v1.1.0 // indirect
)
replace cosmossdk.io/indexer/postgres => ../.
replace cosmossdk.io/schema => ../../../schema
replace cosmossdk.io/schema/testing => ../../../schema/testing

View File

@ -1,3 +1,5 @@
github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg=
github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
@ -32,6 +34,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI=
github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY=
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo=
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos=
go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA=
@ -52,3 +56,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw=
pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=

View File

@ -0,0 +1,99 @@
package tests
import (
"context"
"os"
"strings"
"testing"
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
"github.com/hashicorp/consul/sdk/freeport"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/require"
"cosmossdk.io/indexer/postgres"
"cosmossdk.io/schema/addressutil"
"cosmossdk.io/schema/indexer"
indexertesting "cosmossdk.io/schema/testing"
"cosmossdk.io/schema/testing/appdatasim"
"cosmossdk.io/schema/testing/statesim"
)
func TestPostgresIndexer(t *testing.T) {
t.Run("RetainDeletions", func(t *testing.T) {
testPostgresIndexer(t, true)
})
t.Run("NoRetainDeletions", func(t *testing.T) {
testPostgresIndexer(t, false)
})
}
func testPostgresIndexer(t *testing.T, retainDeletions bool) {
tempDir, err := os.MkdirTemp("", "postgres-indexer-test")
require.NoError(t, err)
dbPort := freeport.GetOne(t)
pgConfig := embeddedpostgres.DefaultConfig().
Port(uint32(dbPort)).
DataPath(tempDir)
dbUrl := pgConfig.GetConnectionURL()
pg := embeddedpostgres.NewDatabase(pgConfig)
require.NoError(t, pg.Start())
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(func() {
cancel()
require.NoError(t, pg.Stop())
err := os.RemoveAll(tempDir)
require.NoError(t, err)
})
cfg, err := postgresConfigToIndexerConfig(postgres.Config{
DatabaseURL: dbUrl,
DisableRetainDeletions: !retainDeletions,
})
require.NoError(t, err)
debugLog := &strings.Builder{}
pgIndexer, err := postgres.StartIndexer(indexer.InitParams{
Config: cfg,
Context: ctx,
Logger: &prettyLogger{debugLog},
AddressCodec: addressutil.HexAddressCodec{},
})
require.NoError(t, err)
sim, err := appdatasim.NewSimulator(appdatasim.Options{
Listener: pgIndexer.Listener,
AppSchema: indexertesting.ExampleAppSchema,
StateSimOptions: statesim.Options{
CanRetainDeletions: retainDeletions,
},
})
require.NoError(t, err)
blockDataGen := sim.BlockDataGenN(10, 100)
numBlocks := 200
if testing.Short() {
numBlocks = 10
}
for i := 0; i < numBlocks; i++ {
// using Example generates a deterministic data set based
// on a seed so that regression tests can be created OR rapid.Check can
// be used for fully random property-based testing
blockData := blockDataGen.Example(i)
// process the generated block data with the simulator which will also
// send it to the indexer
require.NoError(t, sim.ProcessBlockData(blockData), debugLog.String())
// compare the expected state in the simulator to the actual state in the indexer and expect the diff to be empty
require.Empty(t, appdatasim.DiffAppData(sim, pgIndexer.View), debugLog.String())
// reset the debug log after each successful block so that it doesn't get too long when debugging
debugLog.Reset()
}
}

151
indexer/postgres/view.go Normal file
View File

@ -0,0 +1,151 @@
package postgres
import (
"context"
"database/sql"
"strings"
"cosmossdk.io/schema"
"cosmossdk.io/schema/view"
)
var _ view.AppData = &indexerImpl{}
func (i *indexerImpl) AppState() view.AppState {
return i
}
func (i *indexerImpl) BlockNum() (uint64, error) {
var blockNum int64
err := i.tx.QueryRow("SELECT coalesce(max(number), 0) FROM block").Scan(&blockNum)
if err != nil {
return 0, err
}
return uint64(blockNum), nil
}
type moduleView struct {
moduleIndexer
ctx context.Context
conn dbConn
}
func (i *indexerImpl) GetModule(moduleName string) (view.ModuleState, error) {
mod, ok := i.modules[moduleName]
if !ok {
return nil, nil
}
return &moduleView{
moduleIndexer: *mod,
ctx: i.ctx,
conn: i.tx,
}, nil
}
func (i *indexerImpl) Modules(f func(modState view.ModuleState, err error) bool) {
for _, mod := range i.modules {
if !f(&moduleView{
moduleIndexer: *mod,
ctx: i.ctx,
conn: i.tx,
}, nil) {
return
}
}
}
func (i *indexerImpl) NumModules() (int, error) {
return len(i.modules), nil
}
func (m *moduleView) ModuleName() string {
return m.moduleName
}
func (m *moduleView) ModuleSchema() schema.ModuleSchema {
return m.schema
}
func (m *moduleView) GetObjectCollection(objectType string) (view.ObjectCollection, error) {
obj, ok := m.tables[objectType]
if !ok {
return nil, nil
}
return &objectView{
objectIndexer: *obj,
ctx: m.ctx,
conn: m.conn,
}, nil
}
func (m *moduleView) ObjectCollections(f func(value view.ObjectCollection, err error) bool) {
for _, obj := range m.tables {
if !f(&objectView{
objectIndexer: *obj,
ctx: m.ctx,
conn: m.conn,
}, nil) {
return
}
}
}
func (m *moduleView) NumObjectCollections() (int, error) {
return len(m.tables), nil
}
type objectView struct {
objectIndexer
ctx context.Context
conn dbConn
}
func (tm *objectView) ObjectType() schema.ObjectType {
return tm.typ
}
func (tm *objectView) GetObject(key interface{}) (update schema.ObjectUpdate, found bool, err error) {
return tm.get(tm.ctx, tm.conn, key)
}
func (tm *objectView) AllState(f func(schema.ObjectUpdate, error) bool) {
buf := new(strings.Builder)
err := tm.selectAllSql(buf)
if err != nil {
panic(err)
}
sqlStr := buf.String()
if tm.options.logger != nil {
tm.options.logger.Debug("Select", "sql", sqlStr)
}
rows, err := tm.conn.QueryContext(tm.ctx, sqlStr)
if err != nil {
panic(err)
}
defer func(rows *sql.Rows) {
err := rows.Close()
if err != nil {
panic(err)
}
}(rows)
for rows.Next() {
update, found, err := tm.readRow(rows)
if err == nil && !found {
err = sql.ErrNoRows
}
if !f(update, err) {
return
}
}
}
func (tm *objectView) Len() (int, error) {
n, err := tm.count(tm.ctx, tm.conn)
if err != nil {
return 0, err
}
return n, nil
}

60
indexer/postgres/where.go Normal file
View File

@ -0,0 +1,60 @@
package postgres
import (
"fmt"
"io"
)
// whereSqlAndParams generates a WHERE clause for the provided key and returns the parameters.
func (tm *objectIndexer) whereSqlAndParams(w io.Writer, key interface{}, startParamIdx int) (endParamIdx int, keyParams []interface{}, err error) {
var keyCols []string
keyParams, keyCols, err = tm.bindKeyParams(key)
if err != nil {
return
}
endParamIdx, keyParams, err = tm.whereSql(w, keyParams, keyCols, startParamIdx)
return
}
// whereSql generates a WHERE clause for the provided columns and returns the parameters.
func (tm *objectIndexer) whereSql(w io.Writer, params []interface{}, cols []string, startParamIdx int) (endParamIdx int, resParams []interface{}, err error) {
_, err = fmt.Fprintf(w, " WHERE ")
if err != nil {
return 0, nil, err
}
endParamIdx = startParamIdx
for i, col := range cols {
if i > 0 {
_, err = fmt.Fprintf(w, " AND ")
if err != nil {
return 0, nil, err
}
}
_, err = fmt.Fprintf(w, "%s ", col)
if err != nil {
return 0, nil, err
}
if params[i] == nil {
_, err = fmt.Fprintf(w, "IS NULL")
if err != nil {
return 0, nil, err
}
} else {
_, err = fmt.Fprintf(w, "= $%d", endParamIdx)
if err != nil {
return 0, nil, err
}
resParams = append(resParams, params[i])
endParamIdx++
}
}
return endParamIdx, resParams, nil
}