package harmonydb

import (
	"context"
	"errors"
	"runtime"
	"time"

	"github.com/georgysavva/scany/v2/pgxscan"
	"github.com/jackc/pgerrcode"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/samber/lo"
)

var errTx = errors.New("Cannot use a non-transaction func in a transaction")

// rawStringOnly is _intentionally_private_ to force only basic strings in SQL queries.
// In any package, raw strings will satisfy compilation.  Ex:
//
//	harmonydb.Exec("INSERT INTO version (number) VALUES (1)")
//
// This prevents SQL injection attacks where the input contains query fragments.
type rawStringOnly string

// Exec executes changes (INSERT, DELETE,  or UPDATE).
// Note, for CREATE & DROP please keep these permanent and express
// them in the ./sql/ files (next number).
func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) {
	if db.usedInTransaction() {
		return 0, errTx
	}
	res, err := db.pgx.Exec(ctx, string(sql), arguments...)
	return int(res.RowsAffected()), err
}

type Qry interface {
	Next() bool
	Err() error
	Close()
	Scan(...any) error
	Values() ([]any, error)
}

// Query offers Next/Err/Close/Scan/Values/StructScan
type Query struct {
	Qry
}

// Query allows iterating returned values to save memory consumption
// with the downside of needing to `defer q.Close()`. For a simpler interface,
// try Select()
// Next() must be called to advance the row cursor, including the first time:
// Ex:
// q, err := db.Query(ctx, "SELECT id, name FROM users")
// handleError(err)
// defer q.Close()
//
//	for q.Next() {
//		  var id int
//	   var name string
//	   handleError(q.Scan(&id, &name))
//	   fmt.Println(id, name)
//	}
func (db *DB) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) {
	if db.usedInTransaction() {
		return &Query{}, errTx
	}
	q, err := db.pgx.Query(ctx, string(sql), arguments...)
	return &Query{q}, err
}
func (q *Query) StructScan(s any) error {
	return pgxscan.ScanRow(s, q.Qry.(pgx.Rows))
}

type Row interface {
	Scan(...any) error
}

type rowErr struct{}

func (rowErr) Scan(_ ...any) error { return errTx }

// QueryRow gets 1 row using column order matching.
// This is a timesaver for the special case of wanting the first row returned only.
// EX:
//
//	var name, pet string
//	var ID = 123
//	err := db.QueryRow(ctx, "SELECT name, pet FROM users WHERE ID=?", ID).Scan(&name, &pet)
func (db *DB) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row {
	if db.usedInTransaction() {
		return rowErr{}
	}
	return db.pgx.QueryRow(ctx, string(sql), arguments...)
}

/*
Select multiple rows into a slice using name matching
Ex:

	type user struct {
		Name string
		ID int
		Number string `db:"tel_no"`
	}

	var users []user
	pet := "cat"
	err := db.Select(ctx, &users, "SELECT name, id, tel_no FROM customers WHERE pet=?", pet)
*/
func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
	if db.usedInTransaction() {
		return errTx
	}
	return pgxscan.Select(ctx, db.pgx, sliceOfStructPtr, string(sql), arguments...)
}

type Tx struct {
	pgx.Tx
	ctx context.Context
}

// usedInTransaction is a helper to prevent nesting transactions
// & non-transaction calls in transactions. It only checks 20 frames.
// Fast: This memory should all be in CPU Caches.
func (db *DB) usedInTransaction() bool {
	var framePtrs = (&[20]uintptr{})[:]                   // 20 can be stack-local (no alloc)
	framePtrs = framePtrs[:runtime.Callers(3, framePtrs)] // skip past our caller.
	return lo.Contains(framePtrs, db.BTFP.Load())         // Unsafe read @ beginTx overlap, but 'return false' is correct there.
}

type TransactionOptions struct {
	RetrySerializationError            bool
	InitialSerializationErrorRetryWait time.Duration
}

type TransactionOption func(*TransactionOptions)

func OptionRetry() TransactionOption {
	return func(o *TransactionOptions) {
		o.RetrySerializationError = true
	}
}

func OptionSerialRetryTime(d time.Duration) TransactionOption {
	return func(o *TransactionOptions) {
		o.InitialSerializationErrorRetryWait = d
	}
}

// BeginTransaction is how you can access transactions using this library.
// The entire transaction happens in the function passed in.
// The return must be true or a rollback will occur.
// Be sure to test the error for IsErrSerialization() if you want to retry
//
//	when there is a DB serialization error.
//
//go:noinline
func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool, err error), opt ...TransactionOption) (didCommit bool, retErr error) {
	db.BTFPOnce.Do(func() {
		fp := make([]uintptr, 20)
		runtime.Callers(1, fp)
		db.BTFP.Store(fp[0])
	})
	if db.usedInTransaction() {
		return false, errTx
	}

	opts := TransactionOptions{
		RetrySerializationError:            false,
		InitialSerializationErrorRetryWait: 10 * time.Millisecond,
	}

	for _, o := range opt {
		o(&opts)
	}

retry:
	comm, err := db.transactionInner(ctx, f)
	if err != nil && opts.RetrySerializationError && IsErrSerialization(err) {
		time.Sleep(opts.InitialSerializationErrorRetryWait)
		opts.InitialSerializationErrorRetryWait *= 2
		goto retry
	}

	return comm, err
}

func (db *DB) transactionInner(ctx context.Context, f func(*Tx) (commit bool, err error)) (didCommit bool, retErr error) {
	tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{})
	if err != nil {
		return false, err
	}
	var commit bool
	defer func() { // Panic clean-up.
		if !commit {
			if tmp := tx.Rollback(ctx); tmp != nil {
				retErr = tmp
			}
		}
	}()
	commit, err = f(&Tx{tx, ctx})
	if err != nil {
		return false, err
	}
	if commit {
		err = tx.Commit(ctx)
		if err != nil {
			return false, err
		}
		return true, nil
	}
	return false, nil
}

// Exec in a transaction.
func (t *Tx) Exec(sql rawStringOnly, arguments ...any) (count int, err error) {
	res, err := t.Tx.Exec(t.ctx, string(sql), arguments...)
	return int(res.RowsAffected()), err
}

// Query in a transaction.
func (t *Tx) Query(sql rawStringOnly, arguments ...any) (*Query, error) {
	q, err := t.Tx.Query(t.ctx, string(sql), arguments...)
	return &Query{q}, err
}

// QueryRow in a transaction.
func (t *Tx) QueryRow(sql rawStringOnly, arguments ...any) Row {
	return t.Tx.QueryRow(t.ctx, string(sql), arguments...)
}

// Select in a transaction.
func (t *Tx) Select(sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
	return pgxscan.Select(t.ctx, t.Tx, sliceOfStructPtr, string(sql), arguments...)
}

func IsErrUniqueContraint(err error) bool {
	var e2 *pgconn.PgError
	return errors.As(err, &e2) && e2.Code == pgerrcode.UniqueViolation
}

func IsErrSerialization(err error) bool {
	var e2 *pgconn.PgError
	return errors.As(err, &e2) && e2.Code == pgerrcode.SerializationFailure
}