cosmos-sdk/indexer/postgres/insert_update.go

117 lines
2.8 KiB
Go

package postgres
import (
"context"
"fmt"
"io"
"strings"
)
// insertUpdate writes the row identified by key to the database: it issues an
// UPDATE when tm.exists reports that the row is already present and an INSERT
// otherwise. The generated statement is logged at debug level when a logger is
// configured, then executed on conn.
func (tm *objectIndexer) insertUpdate(ctx context.Context, conn dbConn, key, value interface{}) error {
	found, err := tm.exists(ctx, conn, key)
	if err != nil {
		return err
	}

	var (
		stmt strings.Builder
		args []interface{}
	)
	switch {
	case !found:
		args, err = tm.insertSql(&stmt, key, value)
	case len(tm.typ.ValueFields) == 0:
		// The row is already there and the type has no value fields, so there
		// is nothing to update.
		// NOTE(review): this also skips the "_deleted = FALSE" reset that
		// updateSql emits for retained deletions — confirm whether exists can
		// report a soft-deleted row here.
		return nil
	default:
		args, err = tm.updateSql(&stmt, key, value)
	}
	if err != nil {
		return err
	}

	query := stmt.String()
	if logger := tm.options.logger; logger != nil {
		logger.Debug("Insert or Update", "sql", query, "params", args)
	}

	_, err = conn.ExecContext(ctx, query, args...)
	return err
}
// insertSql writes an INSERT statement for the provided key and value to w and
// returns the parameters to bind to it. Key columns and parameters come first,
// followed by value columns and parameters; placeholders are numbered $1..$n
// in that same order.
func (tm *objectIndexer) insertSql(w io.Writer, key, value interface{}) ([]interface{}, error) {
	keyArgs, keyNames, err := tm.bindKeyParams(key)
	if err != nil {
		return nil, err
	}

	valArgs, valNames, err := tm.bindValueParams(value)
	if err != nil {
		return nil, err
	}

	// Start from a nil slice so the result stays nil when there is nothing to bind.
	args := append(append([]interface{}(nil), keyArgs...), valArgs...)

	columns := make([]string, 0, len(keyNames)+len(valNames))
	columns = append(append(columns, keyNames...), valNames...)

	// One positional placeholder per column, 1-based.
	placeholders := make([]string, len(columns))
	for i := range placeholders {
		placeholders[i] = fmt.Sprintf("$%d", i+1)
	}

	_, err = fmt.Fprintf(
		w,
		"INSERT INTO %q (%s) VALUES (%s);",
		tm.tableName(),
		strings.Join(columns, ", "),
		strings.Join(placeholders, ", "),
	)
	return args, err
}
// updateSql writes an UPDATE statement for the provided key and value to w and
// returns the parameters to bind to it: the value parameters first (their
// placeholders are numbered from $1), followed by the key parameters returned
// by whereSqlAndParams, which is handed the next free placeholder index.
//
// When retained deletions are active for this type, the statement also sets
// _deleted to FALSE.
//
// If there are no value columns and retained deletions are not active, the SET
// clause is empty; insertUpdate avoids calling this when the type has no value
// fields.
func (tm *objectIndexer) updateSql(w io.Writer, key, value interface{}) ([]interface{}, error) {
	if _, err := fmt.Fprintf(w, "UPDATE %q SET ", tm.tableName()); err != nil {
		return nil, err
	}

	valueParams, valueCols, err := tm.bindValueParams(value)
	if err != nil {
		return nil, err
	}

	// sep is empty until the first assignment has been written, so a comma is
	// only ever emitted between two assignments — including before the
	// _deleted reset below, which must not get a leading comma when there are
	// no value columns.
	sep := ""
	paramIdx := 1
	for _, col := range valueCols {
		if _, err = fmt.Fprintf(w, "%s%s = $%d", sep, col, paramIdx); err != nil {
			return nil, err
		}
		sep = ", "
		paramIdx++
	}

	if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions {
		if _, err = fmt.Fprintf(w, "%s_deleted = FALSE", sep); err != nil {
			return nil, err
		}
	}

	_, keyParams, err := tm.whereSqlAndParams(w, key, paramIdx)
	if err != nil {
		return nil, err
	}

	// Build the result in a fresh slice rather than appending onto valueParams,
	// so spare capacity in the slice returned by bindValueParams is never
	// written through.
	allParams := make([]interface{}, 0, len(valueParams)+len(keyParams))
	allParams = append(allParams, valueParams...)
	allParams = append(allParams, keyParams...)

	_, err = fmt.Fprintf(w, ";")
	return allParams, err
}