lotus/curiosrc/proof/treed_build.go

293 lines
6.5 KiB
Go
Raw Normal View History

package proof
import (
"io"
"math/bits"
"os"
"runtime"
"sync"
"time"
2024-01-12 10:03:37 +00:00
"github.com/hashicorp/go-multierror"
"github.com/ipfs/go-cid"
pool "github.com/libp2p/go-buffer-pool"
"github.com/minio/sha256-simd"
"golang.org/x/xerrors"
commcid "github.com/filecoin-project/go-fil-commcid"
"github.com/filecoin-project/go-state-types/abi"
2024-01-22 16:42:29 +00:00
"github.com/filecoin-project/lotus/storage/sealer/fr32"
)
const (
	// nodeSize is the size in bytes of one tree node (a single digest).
	nodeSize = 32
	// threadChunkSize is the largest amount of padded data, in bytes, that a
	// single worker hashes as one subtree.
	threadChunkSize = 1 << 20
)
func hashChunk(data [][]byte) {
l1Nodes := len(data[0]) / nodeSize / 2
d := sha256.New()
2023-12-23 14:42:24 +00:00
sumBuf := make([]byte, nodeSize)
for i := 0; i < l1Nodes; i++ {
levels := bits.TrailingZeros(^uint(i)) + 1
inNode := i * 2 // at level 0
outNode := i
for l := 0; l < levels; l++ {
d.Reset()
inNodeData := data[l][inNode*nodeSize : (inNode+2)*nodeSize]
d.Write(inNodeData)
2023-12-23 14:42:24 +00:00
copy(data[l+1][outNode*nodeSize:(outNode+1)*nodeSize], d.Sum(sumBuf[:0]))
// set top bits to 00
data[l+1][outNode*nodeSize+nodeSize-1] &= 0x3f
inNode--
inNode >>= 1
outNode >>= 1
}
}
}
2024-02-17 13:51:02 +00:00
func BuildTreeD(data io.Reader, unpaddedData bool, outPath string, size abi.PaddedPieceSize) (_ cid.Cid, err error) {
out, err := os.Create(outPath)
if err != nil {
return cid.Undef, err
}
2024-02-17 13:51:02 +00:00
defer func() {
cerr := out.Close()
if err != nil {
// remove the file, it's probably bad
rerr := os.Remove(outPath)
if rerr != nil {
err = multierror.Append(err, rerr)
}
}
2024-02-17 13:51:02 +00:00
if cerr != nil {
err = multierror.Append(err, cerr)
}
}()
outSize := treeSize(size)
// allocate space for the tree
err = out.Truncate(int64(outSize))
if err != nil {
return cid.Undef, err
}
// setup buffers
maxThreads := int64(size) / threadChunkSize
2023-12-23 14:42:24 +00:00
if maxThreads > int64(runtime.NumCPU())*15/10 {
maxThreads = int64(runtime.NumCPU()) * 15 / 10
}
if maxThreads < 1 {
maxThreads = 1
}
// allocate buffers
var bufLk sync.Mutex
workerBuffers := make([][][]byte, maxThreads) // [worker][level][levelSize]
for i := range workerBuffers {
workerBuffer := make([][]byte, 1)
bottomBufSize := int64(threadChunkSize)
if bottomBufSize > int64(size) {
bottomBufSize = int64(size)
}
2023-12-23 14:42:24 +00:00
workerBuffer[0] = pool.Get(int(bottomBufSize))
// append levels until we get to a 32 byte level
for len(workerBuffer[len(workerBuffer)-1]) > 32 {
2023-12-23 14:42:24 +00:00
newLevel := pool.Get(len(workerBuffer[len(workerBuffer)-1]) / 2)
workerBuffer = append(workerBuffer, newLevel)
}
workerBuffers[i] = workerBuffer
}
// prepare apex buffer
var apexBuf [][]byte
2023-12-23 14:14:41 +00:00
{
apexBottomSize := uint64(size) / uint64(len(workerBuffers[0][0]))
if apexBottomSize == 0 {
apexBottomSize = 1
}
apexBuf = make([][]byte, 1)
2023-12-23 14:42:24 +00:00
apexBuf[0] = pool.Get(int(apexBottomSize * nodeSize))
for len(apexBuf[len(apexBuf)-1]) > 32 {
2023-12-23 14:42:24 +00:00
newLevel := pool.Get(len(apexBuf[len(apexBuf)-1]) / 2)
apexBuf = append(apexBuf, newLevel)
}
}
2023-12-23 14:42:24 +00:00
// defer free pool buffers
defer func() {
for _, workerBuffer := range workerBuffers {
for _, level := range workerBuffer {
pool.Put(level)
}
}
for _, level := range apexBuf {
pool.Put(level)
}
}()
// start processing
var processed uint64
var workWg sync.WaitGroup
var errLock sync.Mutex
var oerr error
for processed < uint64(size) {
// get a buffer
bufLk.Lock()
if len(workerBuffers) == 0 {
bufLk.Unlock()
time.Sleep(50 * time.Microsecond)
continue
}
// pop last
workBuffer := workerBuffers[len(workerBuffers)-1]
workerBuffers = workerBuffers[:len(workerBuffers)-1]
bufLk.Unlock()
// before reading check that we didn't get a write error
errLock.Lock()
if oerr != nil {
errLock.Unlock()
return cid.Undef, oerr
}
errLock.Unlock()
// read data into the bottom level
// note: the bottom level will never be too big; data is power of two
// size, and if it's smaller than a single buffer, we only have one
// smaller buffer
2024-01-22 16:33:59 +00:00
processedSize := uint64(len(workBuffer[0]))
if unpaddedData {
workBuffer[0] = workBuffer[0][:abi.PaddedPieceSize(len(workBuffer[0])).Unpadded()]
}
_, err := io.ReadFull(data, workBuffer[0])
if err != nil && err != io.EOF {
return cid.Undef, err
}
// start processing
workWg.Add(1)
go func(startOffset uint64) {
2024-02-21 12:33:49 +00:00
defer workWg.Done()
2024-01-22 16:33:59 +00:00
if unpaddedData {
paddedBuf := pool.Get(int(abi.UnpaddedPieceSize(len(workBuffer[0])).Padded()))
fr32.PadSingle(workBuffer[0], paddedBuf)
pool.Put(workBuffer[0])
workBuffer[0] = paddedBuf
}
hashChunk(workBuffer)
2023-12-23 14:14:41 +00:00
// persist apex
{
apexHash := workBuffer[len(workBuffer)-1]
2023-12-23 14:14:41 +00:00
hashPos := startOffset / uint64(len(workBuffer[0])) * nodeSize
copy(apexBuf[0][hashPos:hashPos+nodeSize], apexHash)
}
// write results
offsetInLayer := startOffset
for layer, layerData := range workBuffer {
// layerOff is outSize:bits[most significant bit - layer]
layerOff := layerOffset(uint64(size), layer)
dataOff := offsetInLayer + layerOff
offsetInLayer /= 2
_, werr := out.WriteAt(layerData, int64(dataOff))
if werr != nil {
errLock.Lock()
oerr = multierror.Append(oerr, werr)
errLock.Unlock()
return
}
}
// return buffer
bufLk.Lock()
workerBuffers = append(workerBuffers, workBuffer)
bufLk.Unlock()
}(processed)
2024-01-22 16:33:59 +00:00
processed += processedSize
}
workWg.Wait()
if oerr != nil {
return cid.Undef, oerr
}
2023-12-23 14:14:41 +00:00
threadLayers := bits.Len(uint(len(workerBuffers[0][0])) / nodeSize)
if len(apexBuf) > 0 {
// hash the apex
hashChunk(apexBuf)
// write apex
for apexLayer, layerData := range apexBuf {
2023-12-23 14:14:41 +00:00
if apexLayer == 0 {
continue
}
layer := apexLayer + threadLayers - 1
layerOff := layerOffset(uint64(size), layer)
_, werr := out.WriteAt(layerData, int64(layerOff))
if werr != nil {
return cid.Undef, xerrors.Errorf("write apex: %w", werr)
}
}
}
var commp [32]byte
2023-12-23 14:14:41 +00:00
copy(commp[:], apexBuf[len(apexBuf)-1])
commCid, err := commcid.DataCommitmentV1ToCID(commp[:])
if err != nil {
return cid.Undef, err
}
return commCid, nil
}
// treeSize returns the number of bytes needed to store every level of the tree
// built over data bytes of padded data. That is the bottom level itself plus
// each parent level, every one half the size of the level below it. The sum
// stops with the first level that is at most nodeSize bytes.
func treeSize(data abi.PaddedPieceSize) uint64 {
	var total uint64
	level := uint64(data)
	for {
		total += level
		if level <= nodeSize {
			return total
		}
		level /= 2
	}
}
// layerOffset returns the byte offset at which tree level `layer` starts in
// the file written by BuildTreeD, for a tree whose bottom level is size bytes.
//
// size is expected to be a power of two. Level 0 starts at 0, and every level
// starts right after the one below it, so the offset is
// size + size/2 + ... (`layer` terms). In binary that is `layer` consecutive
// one-bits whose highest bit is size's only set bit.
func layerOffset(size uint64, layer int) uint64 {
	ones := ^uint64(0) >> uint64(64-layer) // lowest `layer` bits set
	return ones << (bits.Len64(size) - layer)
}