package miner

import (
	"github.com/filecoin-project/go-state-types/abi"
	"github.com/filecoin-project/lotus/chain/actors/adt"
	cbg "github.com/whyrusleeping/cbor-gen"
)

func DiffPreCommits(pre, cur State) (*PreCommitChanges, error) {
	results := new(PreCommitChanges)

	prep, err := pre.precommits()
	if err != nil {
		return nil, err
	}

	curp, err := cur.precommits()
	if err != nil {
		return nil, err
	}

	err = adt.DiffAdtMap(prep, curp, &preCommitDiffer{results, pre, cur})
	if err != nil {
		return nil, err
	}

	return results, nil
}

type preCommitDiffer struct {
	Results    *PreCommitChanges
	pre, after State
}

func (m *preCommitDiffer) AsKey(key string) (abi.Keyer, error) {
	sector, err := abi.ParseUIntKey(key)
	if err != nil {
		return nil, err
	}
	return abi.UIntKey(sector), nil
}

func (m *preCommitDiffer) Add(key string, val *cbg.Deferred) error {
	sp, err := m.after.decodeSectorPreCommitOnChainInfo(val)
	if err != nil {
		return err
	}
	m.Results.Added = append(m.Results.Added, sp)
	return nil
}

func (m *preCommitDiffer) Modify(key string, from, to *cbg.Deferred) error {
	return nil
}

func (m *preCommitDiffer) Remove(key string, val *cbg.Deferred) error {
	sp, err := m.pre.decodeSectorPreCommitOnChainInfo(val)
	if err != nil {
		return err
	}
	m.Results.Removed = append(m.Results.Removed, sp)
	return nil
}

func DiffSectors(pre, cur State) (*SectorChanges, error) {
	results := new(SectorChanges)

	pres, err := pre.sectors()
	if err != nil {
		return nil, err
	}

	curs, err := cur.sectors()
	if err != nil {
		return nil, err
	}

	err = adt.DiffAdtArray(pres, curs, &sectorDiffer{results, pre, cur})
	if err != nil {
		return nil, err
	}

	return results, nil
}

type sectorDiffer struct {
	Results    *SectorChanges
	pre, after State
}

func (m *sectorDiffer) Add(key uint64, val *cbg.Deferred) error {
	si, err := m.after.decodeSectorOnChainInfo(val)
	if err != nil {
		return err
	}
	m.Results.Added = append(m.Results.Added, si)
	return nil
}

func (m *sectorDiffer) Modify(key uint64, from, to *cbg.Deferred) error {
	siFrom, err := m.pre.decodeSectorOnChainInfo(from)
	if err != nil {
		return err
	}

	siTo, err := m.after.decodeSectorOnChainInfo(to)
	if err != nil {
		return err
	}

	if siFrom.Expiration != siTo.Expiration {
		m.Results.Extended = append(m.Results.Extended, SectorExtensions{
			From: siFrom,
			To:   siTo,
		})
	}
	return nil
}

func (m *sectorDiffer) Remove(key uint64, val *cbg.Deferred) error {
	si, err := m.pre.decodeSectorOnChainInfo(val)
	if err != nil {
		return err
	}
	m.Results.Removed = append(m.Results.Removed, si)
	return nil
}