342 lines
9.3 KiB
Go
342 lines
9.3 KiB
Go
package harmonydb
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
logging "github.com/ipfs/go-log/v2"
|
|
"github.com/yugabyte/pgx/v5"
|
|
"github.com/yugabyte/pgx/v5/pgconn"
|
|
"github.com/yugabyte/pgx/v5/pgxpool"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/filecoin-project/lotus/node/config"
|
|
)
|
|
|
|
// ITestID names one integration-test run. New turns a non-empty ITestID into
// a dedicated "itest_<id>" schema.
type ITestID string

// ITestNewID returns a pseudo-random ITestID: the decimal form of an integer
// drawn from [0, 99999). See ITestWithID doc.
func ITestNewID() ITestID {
	n := rand.Intn(99999)
	return ITestID(strconv.Itoa(n))
}
|
|
|
|
type DB struct {
|
|
pgx *pgxpool.Pool
|
|
cfg *pgxpool.Config
|
|
schema string
|
|
hostnames []string
|
|
BTFPOnce sync.Once
|
|
BTFP atomic.Uintptr // BeginTransactionFramePointer
|
|
}
|
|
|
|
// logger is the package-wide logger, registered under the name "harmonydb".
var logger = logging.Logger("harmonydb")
|
|
|
|
// NewFromConfig is a convenience function.
|
|
// In usage:
|
|
//
|
|
// db, err := NewFromConfig(config.HarmonyDB) // in binary init
|
|
func NewFromConfig(cfg config.HarmonyDB) (*DB, error) {
|
|
return New(
|
|
cfg.Hosts,
|
|
cfg.Username,
|
|
cfg.Password,
|
|
cfg.Database,
|
|
cfg.Port,
|
|
"",
|
|
)
|
|
}
|
|
|
|
func NewFromConfigWithITestID(cfg config.HarmonyDB) func(id ITestID) (*DB, error) {
|
|
return func(id ITestID) (*DB, error) {
|
|
return New(
|
|
cfg.Hosts,
|
|
cfg.Username,
|
|
cfg.Password,
|
|
cfg.Database,
|
|
cfg.Port,
|
|
id,
|
|
)
|
|
}
|
|
}
|
|
|
|
// New is to be called once per binary to establish the pool.
|
|
// log() is for errors. It returns an upgraded database's connection.
|
|
// This entry point serves both production and integration tests, so it's more DI.
|
|
func New(hosts []string, username, password, database, port string, itestID ITestID) (*DB, error) {
|
|
itest := string(itestID)
|
|
connString := ""
|
|
if len(hosts) > 0 {
|
|
connString = "host=" + hosts[0] + " "
|
|
}
|
|
for k, v := range map[string]string{"user": username, "password": password, "dbname": database, "port": port} {
|
|
if strings.TrimSpace(v) != "" {
|
|
connString += k + "=" + v + " "
|
|
}
|
|
}
|
|
|
|
schema := "curio"
|
|
if itest != "" {
|
|
schema = "itest_" + itest
|
|
}
|
|
|
|
if err := ensureSchemaExists(connString, schema); err != nil {
|
|
return nil, err
|
|
}
|
|
cfg, err := pgxpool.ParseConfig(connString + "search_path=" + schema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// enable multiple fallback hosts.
|
|
for _, h := range hosts[1:] {
|
|
cfg.ConnConfig.Fallbacks = append(cfg.ConnConfig.Fallbacks, &pgconn.FallbackConfig{Host: h})
|
|
}
|
|
|
|
cfg.ConnConfig.OnNotice = func(conn *pgconn.PgConn, n *pgconn.Notice) {
|
|
logger.Debug("database notice: " + n.Message + ": " + n.Detail)
|
|
DBMeasures.Errors.M(1)
|
|
}
|
|
|
|
db := DB{cfg: cfg, schema: schema, hostnames: hosts} // pgx populated in AddStatsAndConnect
|
|
if err := db.addStatsAndConnect(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &db, db.upgrade()
|
|
}
|
|
|
|
// tracer is a stateless value installed as the pgx query tracer in
// addStatsAndConnect; its methods record per-query metrics and debug logs.
type tracer struct{}
|
|
|
|
// ctxkey is the private type of the context keys below, so they cannot
// collide with plain string keys set by other packages.
type ctxkey string

// Context keys under which TraceQueryStart stores per-query data for
// TraceQueryEnd.
const (
	// SQL_START holds the time.Time at which the query was started.
	SQL_START = ctxkey("sqlStart")
	// SQL_STRING holds the SQL text of the query.
	SQL_STRING = ctxkey("sqlString")
)
|
|
|
|
func (t tracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
|
|
return context.WithValue(context.WithValue(ctx, SQL_START, time.Now()), SQL_STRING, data.SQL)
|
|
}
|
|
func (t tracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
|
|
DBMeasures.Hits.M(1)
|
|
ms := time.Since(ctx.Value(SQL_START).(time.Time)).Milliseconds()
|
|
DBMeasures.TotalWait.M(ms)
|
|
DBMeasures.Waits.Observe(float64(ms))
|
|
if data.Err != nil {
|
|
DBMeasures.Errors.M(1)
|
|
}
|
|
logger.Debugw("SQL run",
|
|
"query", ctx.Value(SQL_STRING).(string),
|
|
"err", data.Err,
|
|
"rowCt", data.CommandTag.RowsAffected(),
|
|
"milliseconds", ms)
|
|
}
|
|
|
|
func (db *DB) GetRoutableIP() (string, error) {
|
|
tx, err := db.pgx.Begin(context.Background())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer func() { _ = tx.Rollback(context.Background()) }()
|
|
local := tx.Conn().PgConn().Conn().LocalAddr()
|
|
addr, ok := local.(*net.TCPAddr)
|
|
if !ok {
|
|
return "", fmt.Errorf("could not get local addr from %v", addr)
|
|
}
|
|
return addr.IP.String(), nil
|
|
}
|
|
|
|
// addStatsAndConnect connects a prometheus logger. Be sure to run this before using the DB.
|
|
func (db *DB) addStatsAndConnect() error {
|
|
|
|
db.cfg.ConnConfig.Tracer = tracer{}
|
|
|
|
hostnameToIndex := map[string]float64{}
|
|
for i, h := range db.hostnames {
|
|
hostnameToIndex[h] = float64(i)
|
|
}
|
|
db.cfg.AfterConnect = func(ctx context.Context, c *pgx.Conn) error {
|
|
s := db.pgx.Stat()
|
|
DBMeasures.OpenConnections.M(int64(s.TotalConns()))
|
|
DBMeasures.WhichHost.Observe(hostnameToIndex[c.Config().Host])
|
|
|
|
//FUTURE place for any connection seasoning
|
|
return nil
|
|
}
|
|
|
|
// Timeout the first connection so we know if the DB is down.
|
|
ctx, ctxClose := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
|
defer ctxClose()
|
|
var err error
|
|
db.pgx, err = pgxpool.NewWithConfig(ctx, db.cfg)
|
|
if err != nil {
|
|
logger.Error(fmt.Sprintf("Unable to connect to database: %v\n", err))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ITestDeleteAll will delete everything created for "this" integration test.
|
|
// This must be called at the end of each integration test.
|
|
func (db *DB) ITestDeleteAll() {
|
|
if !strings.HasPrefix(db.schema, "itest_") {
|
|
fmt.Println("Warning: this should never be called on anything but an itest schema.")
|
|
return
|
|
}
|
|
defer db.pgx.Close()
|
|
_, err := db.pgx.Exec(context.Background(), "DROP SCHEMA "+db.schema+" CASCADE")
|
|
if err != nil {
|
|
fmt.Println("warning: unclean itest shutdown: cannot delete schema: " + err.Error())
|
|
return
|
|
}
|
|
}
|
|
|
|
// Schema names are spliced into SQL text, so only ASCII letters, digits and
// underscores are accepted.
var (
	// schemaREString is the pattern source; it is also quoted in the
	// validation error message.
	schemaREString = "^[A-Za-z0-9_]+$"
	// schemaRE is the compiled form of schemaREString.
	schemaRE = regexp.MustCompile(schemaREString)
)
|
|
|
|
func ensureSchemaExists(connString, schema string) error {
|
|
// FUTURE allow using fallback DBs for start-up.
|
|
ctx, cncl := context.WithDeadline(context.Background(), time.Now().Add(3*time.Second))
|
|
p, err := pgx.Connect(ctx, connString)
|
|
defer cncl()
|
|
if err != nil {
|
|
return xerrors.Errorf("unable to connect to db: %s, err: %v", connString, err)
|
|
}
|
|
defer func() { _ = p.Close(context.Background()) }()
|
|
|
|
if len(schema) < 5 || !schemaRE.MatchString(schema) {
|
|
return xerrors.New("schema must be of the form " + schemaREString + "\n Got: " + schema)
|
|
}
|
|
_, err = p.Exec(context.Background(), "CREATE SCHEMA IF NOT EXISTS "+schema)
|
|
if err != nil {
|
|
return xerrors.Errorf("cannot create schema: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
//go:embed sql
|
|
var fs embed.FS
|
|
|
|
func (db *DB) upgrade() error {
|
|
// Does the version table exist? if not, make it.
|
|
// NOTE: This cannot change except via the next sql file.
|
|
_, err := db.Exec(context.Background(), `CREATE TABLE IF NOT EXISTS base (
|
|
id SERIAL PRIMARY KEY,
|
|
entry CHAR(12),
|
|
applied TIMESTAMP DEFAULT current_timestamp
|
|
)`)
|
|
if err != nil {
|
|
logger.Error("Upgrade failed.")
|
|
return xerrors.Errorf("Cannot create base table %w", err)
|
|
}
|
|
|
|
// __Run scripts in order.__
|
|
|
|
landed := map[string]bool{}
|
|
{
|
|
var landedEntries []struct{ Entry string }
|
|
err = db.Select(context.Background(), &landedEntries, "SELECT entry FROM base")
|
|
if err != nil {
|
|
logger.Error("Cannot read entries: " + err.Error())
|
|
return xerrors.Errorf("cannot read entries: %w", err)
|
|
}
|
|
for _, l := range landedEntries {
|
|
landed[l.Entry[:8]] = true
|
|
}
|
|
}
|
|
dir, err := fs.ReadDir("sql")
|
|
if err != nil {
|
|
logger.Error("Cannot read fs entries: " + err.Error())
|
|
return err
|
|
}
|
|
sort.Slice(dir, func(i, j int) bool { return dir[i].Name() < dir[j].Name() })
|
|
|
|
if len(dir) == 0 {
|
|
logger.Error("No sql files found.")
|
|
}
|
|
for _, e := range dir {
|
|
name := e.Name()
|
|
if !strings.HasSuffix(name, ".sql") {
|
|
logger.Debug("Must have only SQL files here, found: " + name)
|
|
continue
|
|
}
|
|
if landed[name[:8]] {
|
|
logger.Debug("DB Schema " + name + " already applied.")
|
|
continue
|
|
}
|
|
file, err := fs.ReadFile("sql/" + name)
|
|
if err != nil {
|
|
logger.Error("weird embed file read err")
|
|
return err
|
|
}
|
|
|
|
logger.Infow("Upgrading", "file", name, "size", len(file))
|
|
|
|
for _, s := range parseSQLStatements(string(file)) { // Implement the changes.
|
|
if len(strings.TrimSpace(s)) == 0 {
|
|
continue
|
|
}
|
|
_, err = db.pgx.Exec(context.Background(), s)
|
|
if err != nil {
|
|
msg := fmt.Sprintf("Could not upgrade! File %s, Query: %s, Returned: %s", name, s, err.Error())
|
|
logger.Error(msg)
|
|
return xerrors.New(msg) // makes devs lives easier by placing message at the end.
|
|
}
|
|
}
|
|
|
|
// Mark Completed.
|
|
_, err = db.Exec(context.Background(), "INSERT INTO base (entry) VALUES ($1)", name[:8])
|
|
if err != nil {
|
|
logger.Error("Cannot update base: " + err.Error())
|
|
return xerrors.Errorf("cannot insert into base: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// parseSQLStatements splits the text of a SQL script into individual
// statements, working line by line.
//
// Blank lines and lines starting with "--" are dropped. A statement ends on a
// line whose trimmed text ends with ";", unless that line lies inside a
// "$$"-quoted body (a function or DO block), where semicolons belong to the
// body. A line that closes a "$$" body also ends the statement, with or
// without a trailing semicolon. Every kept line is appended unmodified,
// followed by "\n". Text left over after the last terminator is returned as a
// final statement.
//
// Limitations: only the bare "$$" delimiter is recognised (not tagged forms
// such as "$body$"), and a ";" followed by a trailing "--" comment on the same
// line does not end a statement.
func parseSQLStatements(sqlContent string) []string {
	var statements []string
	var currentStatement strings.Builder
	var inFunction bool

	for _, line := range strings.Split(sqlContent, "\n") {
		trimmedLine := strings.TrimSpace(line)
		if trimmedLine == "" || strings.HasPrefix(trimmedLine, "--") {
			// Skip empty lines and comments.
			continue
		}

		// Track "$$" bodies. Only an odd number of delimiters on a line
		// changes the state: a line such as "DO $$ ... $$;" opens and closes
		// its body at once and must not leave us "inside a function".
		wasInFunction := inFunction
		if strings.Count(trimmedLine, "$$")%2 == 1 {
			inFunction = !inFunction
		}
		closedFunction := wasInFunction && !inFunction

		// Add the line to the current statement.
		currentStatement.WriteString(line + "\n")

		// Outside a "$$" body a statement ends at a trailing semicolon or on
		// the line that just closed a body.
		if !inFunction && (closedFunction || strings.HasSuffix(trimmedLine, ";")) {
			statements = append(statements, currentStatement.String())
			currentStatement.Reset()
		}
	}

	// Add any remaining statement not followed by a semicolon (should not
	// happen in well-formed SQL but just in case).
	if currentStatement.Len() > 0 {
		statements = append(statements, currentStatement.String())
	}

	return statements
}
|