117 lines
2.8 KiB
Go
117 lines
2.8 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
)
|
|
|
|
// insertUpdate inserts or updates the row with the provided key and value.
|
|
func (tm *objectIndexer) insertUpdate(ctx context.Context, conn dbConn, key, value interface{}) error {
|
|
exists, err := tm.exists(ctx, conn, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
buf := new(strings.Builder)
|
|
var params []interface{}
|
|
if exists {
|
|
if len(tm.typ.ValueFields) == 0 {
|
|
// special case where there are no value fields, so we can't update anything
|
|
return nil
|
|
}
|
|
|
|
params, err = tm.updateSql(buf, key, value)
|
|
} else {
|
|
params, err = tm.insertSql(buf, key, value)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sqlStr := buf.String()
|
|
if tm.options.logger != nil {
|
|
tm.options.logger.Debug("Insert or Update", "sql", sqlStr, "params", params)
|
|
}
|
|
_, err = conn.ExecContext(ctx, sqlStr, params...)
|
|
return err
|
|
}
|
|
|
|
// insertSql generates an INSERT statement and binding parameters for the provided key and value.
|
|
func (tm *objectIndexer) insertSql(w io.Writer, key, value interface{}) ([]interface{}, error) {
|
|
keyParams, keyCols, err := tm.bindKeyParams(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
valueParams, valueCols, err := tm.bindValueParams(value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var allParams []interface{}
|
|
allParams = append(allParams, keyParams...)
|
|
allParams = append(allParams, valueParams...)
|
|
|
|
allCols := make([]string, 0, len(keyCols)+len(valueCols))
|
|
allCols = append(allCols, keyCols...)
|
|
allCols = append(allCols, valueCols...)
|
|
|
|
var paramBindings []string
|
|
for i := 1; i <= len(allCols); i++ {
|
|
paramBindings = append(paramBindings, fmt.Sprintf("$%d", i))
|
|
}
|
|
|
|
_, err = fmt.Fprintf(w, "INSERT INTO %q (%s) VALUES (%s);", tm.tableName(),
|
|
strings.Join(allCols, ", "),
|
|
strings.Join(paramBindings, ", "),
|
|
)
|
|
return allParams, err
|
|
}
|
|
|
|
// updateSql generates an UPDATE statement and binding parameters for the provided key and value.
|
|
func (tm *objectIndexer) updateSql(w io.Writer, key, value interface{}) ([]interface{}, error) {
|
|
_, err := fmt.Fprintf(w, "UPDATE %q SET ", tm.tableName())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
valueParams, valueCols, err := tm.bindValueParams(value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
paramIdx := 1
|
|
for i, col := range valueCols {
|
|
if i > 0 {
|
|
_, err = fmt.Fprintf(w, ", ")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
_, err = fmt.Fprintf(w, "%s = $%d", col, paramIdx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
paramIdx++
|
|
}
|
|
|
|
if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions {
|
|
_, err = fmt.Fprintf(w, ", _deleted = FALSE")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
_, keyParams, err := tm.whereSqlAndParams(w, key, paramIdx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
allParams := append(valueParams, keyParams...)
|
|
_, err = fmt.Fprintf(w, ";")
|
|
return allParams, err
|
|
}
|