diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 7d80659b0..4e638166f 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,6 +1,6 @@ Hi there, -please note that this is an issue tracker reserved for bug reports and feature requests. +Please note that this is an issue tracker reserved for bug reports and feature requests. For general questions please use the gitter channel or the Ethereum stack exchange at https://ethereum.stackexchange.com. diff --git a/.travis.yml b/.travis.yml index 3a40ff583..4acd00bc9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,7 +7,7 @@ jobs: - stage: lint os: linux dist: xenial - go: 1.12.x + go: 1.13.x env: - lint git: @@ -18,15 +18,15 @@ jobs: - stage: build os: linux dist: xenial - go: 1.10.x + go: 1.11.x script: - - go run build/ci.go install - - go run build/ci.go test -coverage $TEST_PACKAGES + - go run build/ci.go install + - go run build/ci.go test -coverage $TEST_PACKAGES - stage: build os: linux dist: xenial - go: 1.11.x + go: 1.12.x script: - go run build/ci.go install - go run build/ci.go test -coverage $TEST_PACKAGES @@ -35,14 +35,14 @@ jobs: - stage: build os: linux dist: xenial - go: 1.12.x + go: 1.13.x script: - go run build/ci.go install - go run build/ci.go test -coverage $TEST_PACKAGES - stage: build os: osx - go: 1.12.x + go: 1.13.x script: - echo "Increase the maximum number of open file descriptors on macOS" - NOFILE=20480 @@ -61,7 +61,7 @@ jobs: if: type = push os: linux dist: xenial - go: 1.12.x + go: 1.13.x env: - ubuntu-ppa git: @@ -75,9 +75,12 @@ jobs: - fakeroot - python-bzrlib - python-paramiko + cache: + directories: + - $HOME/.gobundle script: - echo '|1|7SiYPr9xl3uctzovOTj4gMwAC1M=|t6ReES75Bo/PxlOPJ6/GsGbTrM0= ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEA0aKz5UTUndYgIGG7dQBV+HaeuEZJ2xPHo2DS2iSKvUL4xNMSAY4UguNW+pX56nAQmZKIZZ8MaEvSj6zMEDiq6HFfn5JcTlM80UwlnyKe8B8p7Nk06PPQLrnmQt5fh0HmEcZx+JU9TZsfCHPnX7MNz4ELfZE6cFsclClrKim3BHUIGq//t93DllB+h4O9LHjEUsQ1Sr63irDLSutkLJD6RXchjROXkNirlcNVHH/jwLWR5RcYilNX7S5bIkK8NlWPjsn/8Ua5O7I9/YoE97PpO6i73DTGLh5H9JN/SITwCKBkgSDWUt61uPK3Y11Gty7o2lWsBjhBUm2Y38CBsoGmBw==' >> ~/.ssh/known_hosts - - go run build/ci.go debsrc -upload ethereum/ethereum -sftp-user geth-ci -signer "Go Ethereum Linux Builder " + - go run build/ci.go debsrc -upload ethereum/ethereum -sftp-user geth-ci -signer "Go Ethereum Linux Builder " -goversion 1.13.4 -gohash 95dbeab442ee2746b9acf0934c8e2fc26414a0565c008631b04addb8c02e7624 -gobundle $HOME/.gobundle/go.tar.gz # This builder does the Linux Azure uploads - stage: build @@ -85,7 +88,7 @@ jobs: os: linux dist: xenial sudo: required - go: 1.12.x + go: 1.13.x env: - azure-linux git: @@ -121,7 +124,7 @@ jobs: dist: xenial services: - docker - go: 1.12.x + go: 1.13.x env: - azure-linux-mips git: @@ -167,7 +170,7 @@ jobs: git: submodules: false # avoid cloning ethereum/tests before_install: - - curl https://dl.google.com/go/go1.12.linux-amd64.tar.gz | tar -xz + - curl https://dl.google.com/go/go1.13.linux-amd64.tar.gz | tar -xz - export PATH=`pwd`/go/bin:$PATH - export GOROOT=`pwd`/go - export GOPATH=$HOME/go @@ -185,7 +188,7 @@ jobs: - stage: build if: type = push os: osx - go: 1.12.x + go: 1.13.x env: - azure-osx - azure-ios @@ -216,7 +219,7 @@ jobs: if: type = cron os: linux dist: xenial - go: 1.12.x + go: 1.13.x env: - azure-purge git: diff --git a/Dockerfile b/Dockerfile index b9dcffb7c..114e76205 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Build Geth in a stock Go builder container -FROM golang:1.12-alpine as builder +FROM 
golang:1.13-alpine as builder RUN apk add --no-cache make gcc musl-dev linux-headers git diff --git a/Dockerfile.alltools b/Dockerfile.alltools index 721b79de3..2f661ba01 100644 --- a/Dockerfile.alltools +++ b/Dockerfile.alltools @@ -1,5 +1,5 @@ # Build Geth in a stock Go builder container -FROM golang:1.12-alpine as builder +FROM golang:1.13-alpine as builder RUN apk add --no-cache make gcc musl-dev linux-headers git diff --git a/Makefile b/Makefile index 4bf52f5c9..5d4a82de8 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ .PHONY: geth-darwin geth-darwin-386 geth-darwin-amd64 .PHONY: geth-windows geth-windows-386 geth-windows-amd64 -GOBIN = $(shell pwd)/build/bin +GOBIN = ./build/bin GO ?= latest geth: diff --git a/README.md b/README.md index 87e2328af..92a7125b4 100644 --- a/README.md +++ b/README.md @@ -233,8 +233,8 @@ aware of and agree upon. This consists of a small JSON file (e.g. call it `genes The above fields should be fine for most purposes, although we'd recommend changing the `nonce` to some random value so you prevent unknown remote nodes from being able -to connect to you. If you'd like to pre-fund some accounts for easier testing, you can -populate the `alloc` field with account configs: +to connect to you. If you'd like to pre-fund some accounts for easier testing, create +the accounts and populate the `alloc` field with their addresses. ```json "alloc": { @@ -303,7 +303,7 @@ ones either). To start a `geth` instance for mining, run it with all your usual by: ```shell -$ geth --mine --minerthreads=1 --etherbase=0x0000000000000000000000000000000000000000 +$ geth --mine --miner.threads=1 --etherbase=0x0000000000000000000000000000000000000000 ``` Which will start mining blocks and transactions on a single CPU thread, crediting all diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 7831a5ed3..603e956b9 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -75,9 +75,6 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) { // Unpack output in v according to the abi specification func (abi ABI) Unpack(v interface{}, name string, data []byte) (err error) { - if len(data) == 0 { - return fmt.Errorf("abi: unmarshalling empty output") - } // since there can't be naming collisions with contracts and events, // we need to decide whether we're calling a method or an event if method, ok := abi.Methods[name]; ok { @@ -94,9 +91,6 @@ func (abi ABI) Unpack(v interface{}, name string, data []byte) (err error) { // UnpackIntoMap unpacks a log into the provided map[string]interface{} func (abi ABI) UnpackIntoMap(v map[string]interface{}, name string, data []byte) (err error) { - if len(data) == 0 { - return fmt.Errorf("abi: unmarshalling empty output") - } // since there can't be naming collisions with contracts and events, // we need to decide whether we're calling a method or an event if method, ok := abi.Methods[name]; ok { diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go index 7a795e052..ca19c5801 100644 --- a/accounts/abi/abi_test.go +++ b/accounts/abi/abi_test.go @@ -57,7 +57,7 @@ const jsondata2 = ` ]` func TestReader(t *testing.T) { - Uint256, _ := NewType("uint256", nil) + Uint256, _ := NewType("uint256", "", nil) exp := ABI{ Methods: map[string]Method{ "balance": { @@ -172,7 +172,7 @@ func TestTestSlice(t *testing.T) { } func TestMethodSignature(t *testing.T) { - String, _ := NewType("string", nil) + String, _ := NewType("string", "", nil) m := Method{"foo", "foo", false, []Argument{{"bar", String, false}, {"baz", String, 
false}}, nil} exp := "foo(string,string)" if m.Sig() != exp { @@ -184,7 +184,7 @@ func TestMethodSignature(t *testing.T) { t.Errorf("expected ids to match %x != %x", m.ID(), idexp) } - uintt, _ := NewType("uint256", nil) + uintt, _ := NewType("uint256", "", nil) m = Method{"foo", "foo", false, []Argument{{"bar", uintt, false}}, nil} exp = "foo(uint256)" if m.Sig() != exp { @@ -192,7 +192,7 @@ func TestMethodSignature(t *testing.T) { } // Method with tuple arguments - s, _ := NewType("tuple", []ArgumentMarshaling{ + s, _ := NewType("tuple", "", []ArgumentMarshaling{ {Name: "a", Type: "int256"}, {Name: "b", Type: "int256[]"}, {Name: "c", Type: "tuple[]", Components: []ArgumentMarshaling{ @@ -602,9 +602,9 @@ func TestBareEvents(t *testing.T) { { "type" : "event", "name" : "tuple", "inputs" : [{ "indexed":false, "name":"t", "type":"tuple", "components":[{"name":"a", "type":"uint256"}] }, { "indexed":true, "name":"arg1", "type":"address" }] } ]` - arg0, _ := NewType("uint256", nil) - arg1, _ := NewType("address", nil) - tuple, _ := NewType("tuple", []ArgumentMarshaling{{Name: "a", Type: "uint256"}}) + arg0, _ := NewType("uint256", "", nil) + arg1, _ := NewType("address", "", nil) + tuple, _ := NewType("tuple", "", []ArgumentMarshaling{{Name: "a", Type: "uint256"}}) expectedEvents := map[string]struct { Anonymous bool diff --git a/accounts/abi/argument.go b/accounts/abi/argument.go index 4dae58653..f8ec11b9f 100644 --- a/accounts/abi/argument.go +++ b/accounts/abi/argument.go @@ -34,10 +34,11 @@ type Argument struct { type Arguments []Argument type ArgumentMarshaling struct { - Name string - Type string - Components []ArgumentMarshaling - Indexed bool + Name string + Type string + InternalType string + Components []ArgumentMarshaling + Indexed bool } // UnmarshalJSON implements json.Unmarshaler interface @@ -48,7 +49,7 @@ func (argument *Argument) UnmarshalJSON(data []byte) error { return fmt.Errorf("argument json err: %v", err) } - argument.Type, err = NewType(arg.Type, arg.Components) + argument.Type, err = NewType(arg.Type, arg.InternalType, arg.Components) if err != nil { return err } @@ -88,6 +89,13 @@ func (arguments Arguments) isTuple() bool { // Unpack performs the operation hexdata -> Go format func (arguments Arguments) Unpack(v interface{}, data []byte) error { + if len(data) == 0 { + if len(arguments) != 0 { + return fmt.Errorf("abi: attempting to unmarshal an empty string while arguments are expected") + } else { + return nil // Nothing to unmarshal + } + } // make sure the passed value is arguments pointer if reflect.Ptr != reflect.ValueOf(v).Kind() { return fmt.Errorf("abi: Unpack(non-pointer %T)", v) @@ -104,11 +112,17 @@ func (arguments Arguments) Unpack(v interface{}, data []byte) error { // UnpackIntoMap performs the operation hexdata -> mapping of argument name to argument value func (arguments Arguments) UnpackIntoMap(v map[string]interface{}, data []byte) error { + if len(data) == 0 { + if len(arguments) != 0 { + return fmt.Errorf("abi: attempting to unmarshal an empty string while arguments are expected") + } else { + return nil // Nothing to unmarshal + } + } marshalledValues, err := arguments.UnpackValues(data) if err != nil { return err } - return arguments.unpackIntoMap(v, marshalledValues) }
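With the empty-output check moved out of `ABI.Unpack`/`ABI.UnpackIntoMap` and into `Arguments.Unpack`/`Arguments.UnpackIntoMap`, calling a method that declares no outputs no longer fails on empty return data, while empty data for declared outputs still errors out. A minimal sketch of the new behavior, assuming the post-change `accounts/abi` package; the ABI definition below is made up for illustration:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/ethereum/go-ethereum/accounts/abi"
)

func main() {
	// Hypothetical ABI: one method without outputs, one with a single output.
	def := `[{"type":"function","name":"noOut","outputs":[]},
	         {"type":"function","name":"oneOut","outputs":[{"name":"x","type":"uint256"}]}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	var out interface{}
	// Empty return data, no outputs declared: nil error after this change.
	fmt.Println(parsed.Unpack(&out, "noOut", nil))
	// Empty return data, one output declared: still an error.
	fmt.Println(parsed.Unpack(&out, "oneOut", nil))
}
```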
diff --git a/accounts/abi/bind/base.go b/accounts/abi/bind/base.go index f74a0af21..499b4bda0 100644 --- a/accounts/abi/bind/base.go +++ b/accounts/abi/bind/base.go @@ -218,7 +218,7 @@ func (c *BoundContract) transact(opts *TransactOpts, contract *common.Address, i } } // If the contract surely has code (or code is not needed), estimate the transaction - msg := ethereum.CallMsg{From: opts.From, To: contract, Value: value, Data: input} + msg := ethereum.CallMsg{From: opts.From, To: contract, GasPrice: gasPrice, Value: value, Data: input} gasLimit, err = c.transactor.EstimateGas(ensureContext(opts.Context), msg) if err != nil { return nil, fmt.Errorf("failed to estimate gas needed: %v", err) }
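The `base.go` tweak forwards the caller's gas price into the `CallMsg` used for gas estimation, so the node estimates against the same price the final transaction will carry (which matters when it checks the sender's balance against the gas cost). A rough sketch of the equivalent call through `ethclient`; the endpoint and addresses are placeholders:

```go
package main

import (
	"context"
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/ethclient"
)

func main() {
	client, err := ethclient.Dial("http://localhost:8545") // placeholder endpoint
	if err != nil {
		panic(err)
	}
	to := common.HexToAddress("0x000000000000000000000000000000000000dEaD")
	msg := ethereum.CallMsg{
		From:     common.HexToAddress("0x0000000000000000000000000000000000000001"),
		To:       &to,
		GasPrice: big.NewInt(1e9), // estimate with the price the tx will actually use
		Value:    big.NewInt(0),
	}
	gas, err := client.EstimateGas(context.Background(), msg)
	fmt.Println(gas, err)
}
```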
if bound == "string" || bound == "[]byte" { bound = "common.Hash" } @@ -350,6 +354,13 @@ func bindTopicTypeGo(kind abi.Type, structs map[string]*tmplStruct) string { // funcionality as for simple types, but dynamic types get converted to hashes. func bindTopicTypeJava(kind abi.Type, structs map[string]*tmplStruct) string { bound := bindTypeJava(kind, structs) + + // todo(rjl493456442) according solidity documentation, indexed event + // parameters that are not value types i.e. arrays and structs are not + // stored directly but instead a keccak256-hash of an encoding is stored. + // + // We only convert stringS and bytes to hash, still need to deal with + // array(both fixed-size and dynamic-size) and struct. if bound == "String" || bound == "byte[]" { bound = "Hash" } @@ -369,7 +380,14 @@ var bindStructType = map[Lang]func(kind abi.Type, structs map[string]*tmplStruct func bindStructTypeGo(kind abi.Type, structs map[string]*tmplStruct) string { switch kind.T { case abi.TupleTy: - if s, exist := structs[kind.String()]; exist { + // We compose raw struct name and canonical parameter expression + // together here. The reason is before solidity v0.5.11, kind.TupleRawName + // is empty, so we use canonical parameter expression to distinguish + // different struct definition. From the consideration of backward + // compatibility, we concat these two together so that if kind.TupleRawName + // is not empty, it can have unique id. + id := kind.TupleRawName + kind.String() + if s, exist := structs[id]; exist { return s.Name } var fields []*tmplField @@ -377,8 +395,11 @@ func bindStructTypeGo(kind abi.Type, structs map[string]*tmplStruct) string { field := bindStructTypeGo(*elem, structs) fields = append(fields, &tmplField{Type: field, Name: capitalise(kind.TupleRawNames[i]), SolKind: *elem}) } - name := fmt.Sprintf("Struct%d", len(structs)) - structs[kind.String()] = &tmplStruct{ + name := kind.TupleRawName + if name == "" { + name = fmt.Sprintf("Struct%d", len(structs)) + } + structs[id] = &tmplStruct{ Name: name, Fields: fields, } @@ -398,7 +419,14 @@ func bindStructTypeGo(kind abi.Type, structs map[string]*tmplStruct) string { func bindStructTypeJava(kind abi.Type, structs map[string]*tmplStruct) string { switch kind.T { case abi.TupleTy: - if s, exist := structs[kind.String()]; exist { + // We compose raw struct name and canonical parameter expression + // together here. The reason is before solidity v0.5.11, kind.TupleRawName + // is empty, so we use canonical parameter expression to distinguish + // different struct definition. From the consideration of backward + // compatibility, we concat these two together so that if kind.TupleRawName + // is not empty, it can have unique id. + id := kind.TupleRawName + kind.String() + if s, exist := structs[id]; exist { return s.Name } var fields []*tmplField @@ -406,8 +434,11 @@ func bindStructTypeJava(kind abi.Type, structs map[string]*tmplStruct) string { field := bindStructTypeJava(*elem, structs) fields = append(fields, &tmplField{Type: field, Name: decapitalise(kind.TupleRawNames[i]), SolKind: *elem}) } - name := fmt.Sprintf("Class%d", len(structs)) - structs[kind.String()] = &tmplStruct{ + name := kind.TupleRawName + if name == "" { + name = fmt.Sprintf("Class%d", len(structs)) + } + structs[id] = &tmplStruct{ Name: name, Fields: fields, } @@ -497,6 +528,21 @@ func structured(args abi.Arguments) bool { return true } +// hasStruct returns an indicator whether the given type is struct, struct slice +// or struct array. 
+func hasStruct(t abi.Type) bool { + switch t.T { + case abi.SliceTy: + return hasStruct(*t.Elem) + case abi.ArrayTy: + return hasStruct(*t.Elem) + case abi.TupleTy: + return true + default: + return false + } +} + // resolveArgName converts a raw argument representation into a user friendly format. func resolveArgName(arg abi.Argument, structs map[string]*tmplStruct) string { var ( @@ -512,7 +558,7 @@ loop: case abi.ArrayTy: prefix += fmt.Sprintf("[%d]", typ.Size) default: - embedded = typ.String() + embedded = typ.TupleRawName + typ.String() break loop } typ = typ.Elem diff --git a/accounts/abi/bind/bind_test.go b/accounts/abi/bind/bind_test.go index 7dca3547c..1db568283 100644 --- a/accounts/abi/bind/bind_test.go +++ b/accounts/abi/bind/bind_test.go @@ -1085,7 +1085,10 @@ var bindTests = []struct { contract Tuple { struct S { uint a; uint[] b; T[] c; } struct T { uint x; uint y; } + struct P { uint8 x; uint8 y; } + struct Q { uint16 x; uint16 y; } event TupleEvent(S a, T[2][] b, T[][2] c, S[] d, uint[] e); + event TupleEvent2(P[]); function func1(S memory a, T[2][] memory b, T[][2] memory c, S[] memory d, uint[] memory e) public pure returns (S memory, T[2][] memory, T[][2] memory, S[] memory, uint[] memory) { return (a, b, c, d, e); @@ -1093,12 +1096,12 @@ var bindTests = []struct { function func2(S memory a, T[2][] memory b, T[][2] memory c, S[] memory d, uint[] memory e) public { emit TupleEvent(a, b, c, d, e); } + function func3(Q[] memory) public pure {} // call function, nothing to return } - `, - []string{`608060405234801561001057600080fd5b50610eb2806100206000396000f3fe608060405234801561001057600080fd5b50600436106100365760003560e01c8063443c79b41461003b578063d0062cdd1461006f575b600080fd5b61005560048036036100509190810190610624565b61008b565b604051610066959493929190610b28565b60405180910390f35b61008960048036036100849190810190610624565b6100bc565b005b610093610102565b606061009d610123565b6060808989898989945094509450945094509550955095509550959050565b7f18d6e66efa53739ca6d13626f35ebc700b31cced3eddb50c70bbe9c082c6cd0085858585856040516100f3959493929190610b28565b60405180910390a15050505050565b60405180606001604052806000815260200160608152602001606081525090565b60405180604001604052806002905b60608152602001906001900390816101325790505090565b600082601f83011261015b57600080fd5b813561016e61016982610bcb565b610b9e565b9150818183526020840193506020810190508385608084028201111561019357600080fd5b60005b838110156101c357816101a988826102a6565b845260208401935060808301925050600181019050610196565b5050505092915050565b600082601f8301126101de57600080fd5b60026101f16101ec82610bf3565b610b9e565b9150818360005b83811015610228578135860161020e888261031a565b8452602084019350602083019250506001810190506101f8565b5050505092915050565b600082601f83011261024357600080fd5b813561025661025182610c15565b610b9e565b9150818183526020840193506020810190508360005b8381101561029c578135860161028288826104a3565b84526020840193506020830192505060018101905061026c565b5050505092915050565b600082601f8301126102b757600080fd5b60026102ca6102c582610c3d565b610b9e565b915081838560408402820111156102e057600080fd5b60005b8381101561031057816102f688826105c3565b8452602084019350604083019250506001810190506102e3565b5050505092915050565b600082601f83011261032b57600080fd5b813561033e61033982610c5f565b610b9e565b9150818183526020840193506020810190508385604084028201111561036357600080fd5b60005b83811015610393578161037988826105c3565b845260208401935060408301925050600181019050610366565b5050505092915050565b600082601f8301126103ae57600080fd5b81356103c16103bc82610c87565b610b9e565b9150818183526020840193506
02081019050838560208402820111156103e657600080fd5b60005b8381101561041657816103fc888261060f565b8452602084019350602083019250506001810190506103e9565b5050505092915050565b600082601f83011261043157600080fd5b813561044461043f82610caf565b610b9e565b9150818183526020840193506020810190508385602084028201111561046957600080fd5b60005b83811015610499578161047f888261060f565b84526020840193506020830192505060018101905061046c565b5050505092915050565b6000606082840312156104b557600080fd5b6104bf6060610b9e565b905060006104cf8482850161060f565b600083015250602082013567ffffffffffffffff8111156104ef57600080fd5b6104fb8482850161039d565b602083015250604082013567ffffffffffffffff81111561051b57600080fd5b6105278482850161031a565b60408301525092915050565b60006060828403121561054557600080fd5b61054f6060610b9e565b9050600061055f8482850161060f565b600083015250602082013567ffffffffffffffff81111561057f57600080fd5b61058b8482850161039d565b602083015250604082013567ffffffffffffffff8111156105ab57600080fd5b6105b78482850161031a565b60408301525092915050565b6000604082840312156105d557600080fd5b6105df6040610b9e565b905060006105ef8482850161060f565b60008301525060206106038482850161060f565b60208301525092915050565b60008135905061061e81610e58565b92915050565b600080600080600060a0868803121561063c57600080fd5b600086013567ffffffffffffffff81111561065657600080fd5b61066288828901610533565b955050602086013567ffffffffffffffff81111561067f57600080fd5b61068b8882890161014a565b945050604086013567ffffffffffffffff8111156106a857600080fd5b6106b4888289016101cd565b935050606086013567ffffffffffffffff8111156106d157600080fd5b6106dd88828901610232565b925050608086013567ffffffffffffffff8111156106fa57600080fd5b61070688828901610420565b9150509295509295909350565b600061071f83836108cb565b60808301905092915050565b60006107378383610922565b905092915050565b600061074b8383610a93565b905092915050565b600061075f8383610aea565b60408301905092915050565b60006107778383610b19565b60208301905092915050565b600061078e82610d3b565b6107988185610de3565b93506107a383610cd7565b8060005b838110156107d45781516107bb8882610713565b97506107c683610d88565b9250506001810190506107a7565b5085935050505092915050565b60006107ec82610d46565b6107f68185610df4565b93508360208202850161080885610ce7565b8060005b858110156108445784840389528151610825858261072b565b945061083083610d95565b925060208a0199505060018101905061080c565b50829750879550505050505092915050565b600061086182610d51565b61086b8185610dff565b93508360208202850161087d85610cf1565b8060005b858110156108b9578484038952815161089a858261073f565b94506108a583610da2565b925060208a01995050600181019050610881565b50829750879550505050505092915050565b6108d481610d5c565b6108de8184610e10565b92506108e982610d01565b8060005b8381101561091a5781516109018782610753565b965061090c83610daf565b9250506001810190506108ed565b505050505050565b600061092d82610d67565b6109378185610e1b565b935061094283610d0b565b8060005b8381101561097357815161095a8882610753565b975061096583610dbc565b925050600181019050610946565b5085935050505092915050565b600061098b82610d7d565b6109958185610e3d565b93506109a083610d2b565b8060005b838110156109d15781516109b8888261076b565b97506109c383610dd6565b9250506001810190506109a4565b5085935050505092915050565b60006109e982610d72565b6109f38185610e2c565b93506109fe83610d1b565b8060005b83811015610a2f578151610a16888261076b565b9750610a2183610dc9565b925050600181019050610a02565b5085935050505092915050565b6000606083016000830151610a546000860182610b19565b5060208301518482036020860152610a6c82826109de565b91505060408301518482036040860152610a868282610922565b9150508091505092915050565b6000606083016000830151610aab6000860182610b19565b5060208301518482036020860152610ac3828261
09de565b91505060408301518482036040860152610add8282610922565b9150508091505092915050565b604082016000820151610b006000850182610b19565b506020820151610b136020850182610b19565b50505050565b610b2281610e4e565b82525050565b600060a0820190508181036000830152610b428188610a3c565b90508181036020830152610b568187610783565b90508181036040830152610b6a81866107e1565b90508181036060830152610b7e8185610856565b90508181036080830152610b928184610980565b90509695505050505050565b6000604051905081810181811067ffffffffffffffff82111715610bc157600080fd5b8060405250919050565b600067ffffffffffffffff821115610be257600080fd5b602082029050602081019050919050565b600067ffffffffffffffff821115610c0a57600080fd5b602082029050919050565b600067ffffffffffffffff821115610c2c57600080fd5b602082029050602081019050919050565b600067ffffffffffffffff821115610c5457600080fd5b602082029050919050565b600067ffffffffffffffff821115610c7657600080fd5b602082029050602081019050919050565b600067ffffffffffffffff821115610c9e57600080fd5b602082029050602081019050919050565b600067ffffffffffffffff821115610cc657600080fd5b602082029050602081019050919050565b6000819050602082019050919050565b6000819050919050565b6000819050602082019050919050565b6000819050919050565b6000819050602082019050919050565b6000819050602082019050919050565b6000819050602082019050919050565b600081519050919050565b600060029050919050565b600081519050919050565b600060029050919050565b600081519050919050565b600081519050919050565b600081519050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b600082825260208201905092915050565b600081905092915050565b600082825260208201905092915050565b600081905092915050565b600082825260208201905092915050565b600082825260208201905092915050565b600082825260208201905092915050565b6000819050919050565b610e6181610e4e565b8114610e6c57600080fd5b5056fea365627a7a72305820405a6336d8c302cee779de6788527018e5a2393892328fbf12b96065df2de00a6c6578706572696d656e74616cf564736f6c634300050a0040`}, + 
[]string{`60806040523480156100115760006000fd5b50610017565b6110b2806100266000396000f3fe60806040523480156100115760006000fd5b50600436106100465760003560e01c8063443c79b41461004c578063d0062cdd14610080578063e4d9a43b1461009c57610046565b60006000fd5b610066600480360361006191908101906107b8565b6100b8565b604051610077959493929190610ccb565b60405180910390f35b61009a600480360361009591908101906107b8565b6100ef565b005b6100b660048036036100b19190810190610775565b610136565b005b6100c061013a565b60606100ca61015e565b606060608989898989945094509450945094506100e2565b9550955095509550959050565b7f18d6e66efa53739ca6d13626f35ebc700b31cced3eddb50c70bbe9c082c6cd008585858585604051610126959493929190610ccb565b60405180910390a15b5050505050565b5b50565b60405180606001604052806000815260200160608152602001606081526020015090565b60405180604001604052806002905b606081526020019060019003908161016d57905050905661106e565b600082601f830112151561019d5760006000fd5b81356101b06101ab82610d6f565b610d41565b915081818352602084019350602081019050838560808402820111156101d65760006000fd5b60005b8381101561020757816101ec888261037a565b8452602084019350608083019250505b6001810190506101d9565b5050505092915050565b600082601f83011215156102255760006000fd5b600261023861023382610d98565b610d41565b9150818360005b83811015610270578135860161025588826103f3565b8452602084019350602083019250505b60018101905061023f565b5050505092915050565b600082601f830112151561028e5760006000fd5b81356102a161029c82610dbb565b610d41565b915081818352602084019350602081019050838560408402820111156102c75760006000fd5b60005b838110156102f857816102dd888261058b565b8452602084019350604083019250505b6001810190506102ca565b5050505092915050565b600082601f83011215156103165760006000fd5b813561032961032482610de4565b610d41565b9150818183526020840193506020810190508360005b83811015610370578135860161035588826105d8565b8452602084019350602083019250505b60018101905061033f565b5050505092915050565b600082601f830112151561038e5760006000fd5b60026103a161039c82610e0d565b610d41565b915081838560408402820111156103b85760006000fd5b60005b838110156103e957816103ce88826106fe565b8452602084019350604083019250505b6001810190506103bb565b5050505092915050565b600082601f83011215156104075760006000fd5b813561041a61041582610e30565b610d41565b915081818352602084019350602081019050838560408402820111156104405760006000fd5b60005b83811015610471578161045688826106fe565b8452602084019350604083019250505b600181019050610443565b5050505092915050565b600082601f830112151561048f5760006000fd5b81356104a261049d82610e59565b610d41565b915081818352602084019350602081019050838560208402820111156104c85760006000fd5b60005b838110156104f957816104de8882610760565b8452602084019350602083019250505b6001810190506104cb565b5050505092915050565b600082601f83011215156105175760006000fd5b813561052a61052582610e82565b610d41565b915081818352602084019350602081019050838560208402820111156105505760006000fd5b60005b8381101561058157816105668882610760565b8452602084019350602083019250505b600181019050610553565b5050505092915050565b60006040828403121561059e5760006000fd5b6105a86040610d41565b905060006105b88482850161074b565b60008301525060206105cc8482850161074b565b60208301525092915050565b6000606082840312156105eb5760006000fd5b6105f56060610d41565b9050600061060584828501610760565b600083015250602082013567ffffffffffffffff8111156106265760006000fd5b6106328482850161047b565b602083015250604082013567ffffffffffffffff8111156106535760006000fd5b61065f848285016103f3565b60408301525092915050565b60006060828403121561067e5760006000fd5b6106886060610d41565b9050600061069884828501610760565b600083015250602082013567ffffffffffffffff8111156106b95760006000fd5b6106c58482850161047b565b6
02083015250604082013567ffffffffffffffff8111156106e65760006000fd5b6106f2848285016103f3565b60408301525092915050565b6000604082840312156107115760006000fd5b61071b6040610d41565b9050600061072b84828501610760565b600083015250602061073f84828501610760565b60208301525092915050565b60008135905061075a8161103a565b92915050565b60008135905061076f81611054565b92915050565b6000602082840312156107885760006000fd5b600082013567ffffffffffffffff8111156107a35760006000fd5b6107af8482850161027a565b91505092915050565b6000600060006000600060a086880312156107d35760006000fd5b600086013567ffffffffffffffff8111156107ee5760006000fd5b6107fa8882890161066b565b955050602086013567ffffffffffffffff8111156108185760006000fd5b61082488828901610189565b945050604086013567ffffffffffffffff8111156108425760006000fd5b61084e88828901610211565b935050606086013567ffffffffffffffff81111561086c5760006000fd5b61087888828901610302565b925050608086013567ffffffffffffffff8111156108965760006000fd5b6108a288828901610503565b9150509295509295909350565b60006108bb8383610a6a565b60808301905092915050565b60006108d38383610ac2565b905092915050565b60006108e78383610c36565b905092915050565b60006108fb8383610c8d565b60408301905092915050565b60006109138383610cbc565b60208301905092915050565b600061092a82610f0f565b6109348185610fb7565b935061093f83610eab565b8060005b8381101561097157815161095788826108af565b975061096283610f5c565b9250505b600181019050610943565b5085935050505092915050565b600061098982610f1a565b6109938185610fc8565b9350836020820285016109a585610ebb565b8060005b858110156109e257848403895281516109c285826108c7565b94506109cd83610f69565b925060208a019950505b6001810190506109a9565b50829750879550505050505092915050565b60006109ff82610f25565b610a098185610fd3565b935083602082028501610a1b85610ec5565b8060005b85811015610a585784840389528151610a3885826108db565b9450610a4383610f76565b925060208a019950505b600181019050610a1f565b50829750879550505050505092915050565b610a7381610f30565b610a7d8184610fe4565b9250610a8882610ed5565b8060005b83811015610aba578151610aa087826108ef565b9650610aab83610f83565b9250505b600181019050610a8c565b505050505050565b6000610acd82610f3b565b610ad78185610fef565b9350610ae283610edf565b8060005b83811015610b14578151610afa88826108ef565b9750610b0583610f90565b9250505b600181019050610ae6565b5085935050505092915050565b6000610b2c82610f51565b610b368185611011565b9350610b4183610eff565b8060005b83811015610b73578151610b598882610907565b9750610b6483610faa565b9250505b600181019050610b45565b5085935050505092915050565b6000610b8b82610f46565b610b958185611000565b9350610ba083610eef565b8060005b83811015610bd2578151610bb88882610907565b9750610bc383610f9d565b9250505b600181019050610ba4565b5085935050505092915050565b6000606083016000830151610bf76000860182610cbc565b5060208301518482036020860152610c0f8282610b80565b91505060408301518482036040860152610c298282610ac2565b9150508091505092915050565b6000606083016000830151610c4e6000860182610cbc565b5060208301518482036020860152610c668282610b80565b91505060408301518482036040860152610c808282610ac2565b9150508091505092915050565b604082016000820151610ca36000850182610cbc565b506020820151610cb66020850182610cbc565b50505050565b610cc581611030565b82525050565b600060a0820190508181036000830152610ce58188610bdf565b90508181036020830152610cf9818761091f565b90508181036040830152610d0d818661097e565b90508181036060830152610d2181856109f4565b90508181036080830152610d358184610b21565b90509695505050505050565b6000604051905081810181811067ffffffffffffffff82111715610d655760006000fd5b8060405250919050565b600067ffffffffffffffff821115610d875760006000fd5b602082029050602081019050919050565b600067ffffffffffffffff821115610db05760006000fd5b60208202905091
9050565b600067ffffffffffffffff821115610dd35760006000fd5b602082029050602081019050919050565b600067ffffffffffffffff821115610dfc5760006000fd5b602082029050602081019050919050565b600067ffffffffffffffff821115610e255760006000fd5b602082029050919050565b600067ffffffffffffffff821115610e485760006000fd5b602082029050602081019050919050565b600067ffffffffffffffff821115610e715760006000fd5b602082029050602081019050919050565b600067ffffffffffffffff821115610e9a5760006000fd5b602082029050602081019050919050565b6000819050602082019050919050565b6000819050919050565b6000819050602082019050919050565b6000819050919050565b6000819050602082019050919050565b6000819050602082019050919050565b6000819050602082019050919050565b600081519050919050565b600060029050919050565b600081519050919050565b600060029050919050565b600081519050919050565b600081519050919050565b600081519050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b6000602082019050919050565b600082825260208201905092915050565b600081905092915050565b600082825260208201905092915050565b600081905092915050565b600082825260208201905092915050565b600082825260208201905092915050565b600082825260208201905092915050565b600061ffff82169050919050565b6000819050919050565b61104381611022565b811415156110515760006000fd5b50565b61105d81611030565b8114151561106b5760006000fd5b50565bfea365627a7a72315820d78c6ba7ee332581e6c4d9daa5fc07941841230f7ce49edf6e05b1b63853e8746c6578706572696d656e74616cf564736f6c634300050c0040`}, []string{` - [{"constant":true,"inputs":[{"components":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256[]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[]"}],"name":"a","type":"tuple"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"b","type":"tuple[2][]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[][2]"},{"components":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256[]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[]"}],"name":"d","type":"tuple[]"},{"name":"e","type":"uint256[]"}],"name":"func1","outputs":[{"components":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256[]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[]"}],"name":"","type":"tuple"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"","type":"tuple[2][]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"","type":"tuple[][2]"},{"components":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256[]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[]"}],"name":"","type":"tuple[]"},{"name":"","type":"uint256[]"}],"payable":false,"stateMutability":"pure","type":"function"},{"constant":false,"inputs":[{"components":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256[]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[]"}],"name":"a","type":"tuple"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"b","type":"tuple[2][]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[][2]"},{"components":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256[]"},{"components":[{"name":"x","type
":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[]"}],"name":"d","type":"tuple[]"},{"name":"e","type":"uint256[]"}],"name":"func2","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"anonymous":false,"inputs":[{"components":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256[]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[]"}],"indexed":false,"name":"a","type":"tuple"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"indexed":false,"name":"b","type":"tuple[2][]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"indexed":false,"name":"c","type":"tuple[][2]"},{"components":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256[]"},{"components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}],"name":"c","type":"tuple[]"}],"indexed":false,"name":"d","type":"tuple[]"},{"indexed":false,"name":"e","type":"uint256[]"}],"name":"TupleEvent","type":"event"}] +[{"anonymous":false,"inputs":[{"components":[{"internalType":"uint256","name":"a","type":"uint256"},{"internalType":"uint256[]","name":"b","type":"uint256[]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[]","name":"c","type":"tuple[]"}],"indexed":false,"internalType":"struct Tuple.S","name":"a","type":"tuple"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"indexed":false,"internalType":"struct Tuple.T[2][]","name":"b","type":"tuple[2][]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"indexed":false,"internalType":"struct Tuple.T[][2]","name":"c","type":"tuple[][2]"},{"components":[{"internalType":"uint256","name":"a","type":"uint256"},{"internalType":"uint256[]","name":"b","type":"uint256[]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[]","name":"c","type":"tuple[]"}],"indexed":false,"internalType":"struct Tuple.S[]","name":"d","type":"tuple[]"},{"indexed":false,"internalType":"uint256[]","name":"e","type":"uint256[]"}],"name":"TupleEvent","type":"event"},{"anonymous":false,"inputs":[{"components":[{"internalType":"uint8","name":"x","type":"uint8"},{"internalType":"uint8","name":"y","type":"uint8"}],"indexed":false,"internalType":"struct Tuple.P[]","name":"","type":"tuple[]"}],"name":"TupleEvent2","type":"event"},{"constant":true,"inputs":[{"components":[{"internalType":"uint256","name":"a","type":"uint256"},{"internalType":"uint256[]","name":"b","type":"uint256[]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[]","name":"c","type":"tuple[]"}],"internalType":"struct Tuple.S","name":"a","type":"tuple"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[2][]","name":"b","type":"tuple[2][]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct 
Tuple.T[][2]","name":"c","type":"tuple[][2]"},{"components":[{"internalType":"uint256","name":"a","type":"uint256"},{"internalType":"uint256[]","name":"b","type":"uint256[]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[]","name":"c","type":"tuple[]"}],"internalType":"struct Tuple.S[]","name":"d","type":"tuple[]"},{"internalType":"uint256[]","name":"e","type":"uint256[]"}],"name":"func1","outputs":[{"components":[{"internalType":"uint256","name":"a","type":"uint256"},{"internalType":"uint256[]","name":"b","type":"uint256[]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[]","name":"c","type":"tuple[]"}],"internalType":"struct Tuple.S","name":"","type":"tuple"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[2][]","name":"","type":"tuple[2][]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[][2]","name":"","type":"tuple[][2]"},{"components":[{"internalType":"uint256","name":"a","type":"uint256"},{"internalType":"uint256[]","name":"b","type":"uint256[]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[]","name":"c","type":"tuple[]"}],"internalType":"struct Tuple.S[]","name":"","type":"tuple[]"},{"internalType":"uint256[]","name":"","type":"uint256[]"}],"payable":false,"stateMutability":"pure","type":"function"},{"constant":false,"inputs":[{"components":[{"internalType":"uint256","name":"a","type":"uint256"},{"internalType":"uint256[]","name":"b","type":"uint256[]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[]","name":"c","type":"tuple[]"}],"internalType":"struct Tuple.S","name":"a","type":"tuple"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[2][]","name":"b","type":"tuple[2][]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[][2]","name":"c","type":"tuple[][2]"},{"components":[{"internalType":"uint256","name":"a","type":"uint256"},{"internalType":"uint256[]","name":"b","type":"uint256[]"},{"components":[{"internalType":"uint256","name":"x","type":"uint256"},{"internalType":"uint256","name":"y","type":"uint256"}],"internalType":"struct Tuple.T[]","name":"c","type":"tuple[]"}],"internalType":"struct Tuple.S[]","name":"d","type":"tuple[]"},{"internalType":"uint256[]","name":"e","type":"uint256[]"}],"name":"func2","outputs":[],"payable":false,"stateMutability":"nonpayable","type":"function"},{"constant":true,"inputs":[{"components":[{"internalType":"uint16","name":"x","type":"uint16"},{"internalType":"uint16","name":"y","type":"uint16"}],"internalType":"struct Tuple.Q[]","name":"","type":"tuple[]"}],"name":"func3","outputs":[],"payable":false,"stateMutability":"pure","type":"function"}] `}, ` "math/big" @@ -1129,10 +1132,10 @@ var bindTests = []struct { } } - a := Struct1{ + a := TupleS{ A: big.NewInt(1), B: 
[]*big.Int{big.NewInt(2), big.NewInt(3)}, - C: []Struct0{ + C: []TupleT{ { X: big.NewInt(4), Y: big.NewInt(5), @@ -1144,7 +1147,7 @@ var bindTests = []struct { }, } - b := [][2]Struct0{ + b := [][2]TupleT{ { { X: big.NewInt(8), @@ -1157,7 +1160,7 @@ var bindTests = []struct { }, } - c := [2][]Struct0{ + c := [2][]TupleT{ { { X: big.NewInt(12), @@ -1176,7 +1179,7 @@ var bindTests = []struct { }, } - d := []Struct1{a} + d := []TupleS{a} e := []*big.Int{big.NewInt(18), big.NewInt(19)} ret1, ret2, ret3, ret4, ret5, err := contract.Func1(nil, a, b, c, d, e) @@ -1207,6 +1210,11 @@ var bindTests = []struct { check(iter.Event.C, c, "field3 mismatch") check(iter.Event.D, d, "field4 mismatch") check(iter.Event.E, e, "field5 mismatch") + + err = contract.Func3(nil, nil) + if err != nil { + t.Fatalf("failed to call a function with no return value, err %v", err) + } `, nil, nil, diff --git a/accounts/abi/bind/template.go b/accounts/abi/bind/template.go index 4ec65474b..5b35b1bad 100644 --- a/accounts/abi/bind/template.go +++ b/accounts/abi/bind/template.go @@ -65,7 +65,7 @@ type tmplField struct { // tmplStruct is a wrapper around an abi.tuple that contains an auto-generated // struct name. type tmplStruct struct { - Name string // Auto-generated struct name(We can't obtain the raw struct name through abi) + Name string // Auto-generated struct name (before solidity v0.5.11) or the raw name. Fields []*tmplField // Struct fields definition depends on the binding language. } @@ -483,7 +483,7 @@ var ( // Parse{{.Normalized.Name}} is a log parse operation binding the contract event 0x{{printf "%x" .Original.ID}}. // - // Solidity: {{.Original.String}} + // Solidity: {{formatevent .Original $structs}} func (_{{$contract.Type}} *{{$contract.Type}}Filterer) Parse{{.Normalized.Name}}(log types.Log) (*{{$contract.Type}}{{.Normalized.Name}}, error) { event := new({{$contract.Type}}{{.Normalized.Name}}) if err := _{{$contract.Type}}.contract.UnpackLog(event, "{{.Original.Name}}", log); err != nil { diff --git a/accounts/abi/bind/topics.go b/accounts/abi/bind/topics.go index c7657b4a4..e27fa5484 100644 --- a/accounts/abi/bind/topics.go +++ b/accounts/abi/bind/topics.go @@ -80,15 +80,19 @@ func makeTopics(query ...[]interface{}) ([][]common.Hash, error) { copy(topic[:], hash[:]) default: + // todo(rjl493456442) according to the solidity documentation, indexed event + // parameters that are not value types, i.e. arrays and structs, are not + // stored directly but instead a keccak256-hash of their encoding is stored. + // + // We only convert strings and bytes to hashes; we still need to handle + // arrays (both fixed-size and dynamic-size) and structs.
+ // Attempt to generate the topic from funky types val := reflect.ValueOf(rule) - switch { - // static byte array case val.Kind() == reflect.Array && reflect.TypeOf(rule).Elem().Kind() == reflect.Uint8: reflect.Copy(reflect.ValueOf(topic[:val.Len()]), val) - default: return nil, fmt.Errorf("unsupported indexed type: %T", rule) } @@ -162,6 +166,7 @@ func parseTopics(out interface{}, fields abi.Arguments, topics []common.Hash) er default: // Ran out of plain primitive types, try custom types + switch field.Type() { case reflectHash: // Also covers all dynamic types field.Set(reflect.ValueOf(topics[0])) @@ -178,11 +183,9 @@ func parseTopics(out interface{}, fields abi.Arguments, topics []common.Hash) er default: // Ran out of custom types, try the crazies switch { - // static byte array case arg.Type.T == abi.FixedBytesTy: reflect.Copy(field, reflect.ValueOf(topics[0][:arg.Type.Size])) - default: return fmt.Errorf("unsupported indexed type: %v", arg.Type) } diff --git a/accounts/abi/bind/topics_test.go b/accounts/abi/bind/topics_test.go index ac865e5b4..f18e2d1bd 100644 --- a/accounts/abi/bind/topics_test.go +++ b/accounts/abi/bind/topics_test.go @@ -59,7 +59,7 @@ func TestParseTopics(t *testing.T) { type bytesStruct struct { StaticBytes [5]byte } - bytesType, _ := abi.NewType("bytes5", nil) + bytesType, _ := abi.NewType("bytes5", "", nil) type args struct { createObj func() interface{} resultObj func() interface{} diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go index f43e39056..cf649b480 100644 --- a/accounts/abi/pack_test.go +++ b/accounts/abi/pack_test.go @@ -613,7 +613,7 @@ func TestPack(t *testing.T) { "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // tuple[1].A[1] }, } { - typ, err := NewType(test.typ, test.components) + typ, err := NewType(test.typ, "", test.components) if err != nil { t.Fatalf("%v failed. Unexpected parse error: %v", i, err) } diff --git a/accounts/abi/type.go b/accounts/abi/type.go index 597d31439..4792283ee 100644 --- a/accounts/abi/type.go +++ b/accounts/abi/type.go @@ -53,6 +53,7 @@ type Type struct { stringKind string // holds the unparsed string for deriving signatures // Tuple relative fields + TupleRawName string // Raw struct name defined in source code, may be empty. TupleElems []*Type // Type information of all tuple fields TupleRawNames []string // Raw field name of all tuple fields } @@ -63,7 +64,7 @@ var ( ) // NewType creates a new reflection type of abi type given in t. -func NewType(t string, components []ArgumentMarshaling) (typ Type, err error) { +func NewType(t string, internalType string, components []ArgumentMarshaling) (typ Type, err error) { // check that array brackets are equal if they exist if strings.Count(t, "[") != strings.Count(t, "]") { return Type{}, fmt.Errorf("invalid arg type in abi") @@ -73,9 +74,14 @@ func NewType(t string, components []ArgumentMarshaling) (typ Type, err error) { // if there are brackets, get ready to go into slice/array mode and // recursively create the type if strings.Count(t, "[") != 0 { - i := strings.LastIndex(t, "[") + // Note internalType can be empty here. 
+ subInternal := internalType + if i := strings.LastIndex(internalType, "["); i != -1 { + subInternal = subInternal[:i] + } // recursively embed the type - embeddedType, err := NewType(t[:i], components) + i := strings.LastIndex(t, "[") + embeddedType, err := NewType(t[:i], subInternal, components) if err != nil { return Type{}, err } @@ -173,7 +179,7 @@ func NewType(t string, internalType string, components []ArgumentMarshaling) (typ Type, err error) { ) expression += "(" for idx, c := range components { - cType, err := NewType(c.Type, c.Components) + cType, err := NewType(c.Type, c.InternalType, c.Components) if err != nil { return Type{}, err } @@ -199,6 +205,17 @@ func NewType(t string, internalType string, components []ArgumentMarshaling) (typ Type, err error) { typ.TupleRawNames = names typ.T = TupleTy typ.stringKind = expression + + const structPrefix = "struct " + // After solidity 0.5.10, a new abi field, "internalType", + // was introduced. From it we can obtain the struct name + // that the user defined in the source code. + if internalType != "" && strings.HasPrefix(internalType, structPrefix) { + // A Foo.Bar type definition is not allowed in Go, + // so convert the format to FooBar. + typ.TupleRawName = strings.Replace(internalType[len(structPrefix):], ".", "", -1) + } + case "function": typ.Kind = reflect.Array typ.T = FunctionTy diff --git a/accounts/abi/type_test.go b/accounts/abi/type_test.go index 5023456ae..a2c78dc2e 100644 --- a/accounts/abi/type_test.go +++ b/accounts/abi/type_test.go @@ -106,7 +106,7 @@ func TestTypeRegexp(t *testing.T) { } for _, tt := range tests { - typ, err := NewType(tt.blob, tt.components) + typ, err := NewType(tt.blob, "", tt.components) if err != nil { t.Errorf("type %q: failed to parse type string: %v", tt.blob, err) } @@ -281,7 +281,7 @@ func TestTypeCheck(t *testing.T) { B *big.Int }{{big.NewInt(0), big.NewInt(0)}, {big.NewInt(0), big.NewInt(0)}}, ""}, } { - typ, err := NewType(test.typ, test.components) + typ, err := NewType(test.typ, "", test.components) if err != nil && len(test.err) == 0 { t.Fatal("unexpected parse error:", err) } else if err != nil && len(test.err) != 0 { diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go index c85b86d8c..dfea8db67 100644 --- a/accounts/abi/unpack_test.go +++ b/accounts/abi/unpack_test.go @@ -51,6 +51,7 @@ func (test unpackTest) checkError(err error) error { } var unpackTests = []unpackTest{ + // Bools { def: `[{ "type": "bool" }]`, enc: "0000000000000000000000000000000000000000000000000000000000000001", want: true, }, @@ -73,6 +74,7 @@ var unpackTests = []unpackTest{ want: false, err: "abi: improperly encoded boolean value", }, + // Integers { def: `[{"type": "uint32"}]`, enc: "0000000000000000000000000000000000000000000000000000000000000001", want: uint32(1), }, @@ -122,11 +124,13 @@ var unpackTests = []unpackTest{ enc: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", want: big.NewInt(-1), }, + // Address { def: `[{"type": "address"}]`, enc: "0000000000000000000000000100000000000000000000000000000000000000", want: common.Address{1}, }, + // Bytes { def: `[{"type": "bytes32"}]`, enc: "0100000000000000000000000000000000000000000000000000000000000000", @@ -154,23 +158,39 @@ var unpackTests = []unpackTest{ enc: "0100000000000000000000000000000000000000000000000000000000000000", want: [32]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, }, + // Functions { def: `[{"type": "function"}]`, enc: "0100000000000000000000000000000000000000000000000000000000000000", want: [24]byte{1}, }, - // slices + //
Slice and Array { def: `[{"type": "uint8[]"}]`, enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", want: []uint8{1, 2}, }, + { + def: `[{"type": "uint8[]"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000000", + want: []uint8{}, + }, + { + def: `[{"type": "uint256[]"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000000", + want: []*big.Int{}, + }, { def: `[{"type": "uint8[2]"}]`, enc: "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", want: [2]uint8{1, 2}, }, // multi dimensional, if these pass, all types that don't require length prefix should pass + { + def: `[{"type": "uint8[][]"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000000", + want: [][]uint8{}, + }, { def: `[{"type": "uint8[][]"}]`, enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", @@ -186,11 +206,21 @@ var unpackTests = []unpackTest{ enc: "0000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", want: [2][2]uint8{{1, 2}, {1, 2}}, }, + { + def: `[{"type": "uint8[][2]"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + want: [2][]uint8{{}, {}}, + }, { def: `[{"type": "uint8[][2]"}]`, enc: "0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000001", want: [2][]uint8{{1}, {1}}, }, + { + def: `[{"type": "uint8[2][]"}]`, + enc: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000000", + want: [][2]uint8{}, + }, { def: `[{"type": "uint8[2][]"}]`, enc: 
"0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002", @@ -420,7 +450,7 @@ func TestUnpack(t *testing.T) { } encb, err := hex.DecodeString(test.enc) if err != nil { - t.Fatalf("invalid hex: %s" + test.enc) + t.Fatalf("invalid hex %s: %v", test.enc, err) } outptr := reflect.New(reflect.TypeOf(test.want)) err = abi.Unpack(outptr.Interface(), "method", encb) diff --git a/appveyor.yml b/appveyor.yml index 473ee1b78..0f230bac1 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -23,8 +23,8 @@ environment: install: - git submodule update --init - rmdir C:\go /s /q - - appveyor DownloadFile https://dl.google.com/go/go1.12.9.windows-%GETH_ARCH%.zip - - 7z x go1.12.9.windows-%GETH_ARCH%.zip -y -oC:\ > NUL + - appveyor DownloadFile https://dl.google.com/go/go1.13.4.windows-%GETH_ARCH%.zip + - 7z x go1.13.4.windows-%GETH_ARCH%.zip -y -oC:\ > NUL - go version - gcc --version diff --git a/build/ci-notes.md b/build/ci-notes.md index 13e1fd230..edd9adc1c 100644 --- a/build/ci-notes.md +++ b/build/ci-notes.md @@ -22,19 +22,18 @@ variables `PPA_SIGNING_KEY` and `PPA_SSH_KEY` on Travis. We want to build go-ethereum with the most recent version of Go, irrespective of the Go version that is available in the main Ubuntu repository. In order to make this possible, -our PPA depends on the ~gophers/ubuntu/archive PPA. Our source package build-depends on -golang-1.11, which is co-installable alongside the regular golang package. PPA dependencies -can be edited at https://launchpad.net/%7Eethereum/+archive/ubuntu/ethereum/+edit-dependencies +we bundle the entire Go sources into our own source archive and start the built job by +compiling Go and then using that to build go-ethereum. On Trusty we have a special case +requiring the `~gophers/ubuntu/archive` PPA since Trusty can't even build Go itself. PPA +deps are set at https://launchpad.net/%7Eethereum/+archive/ubuntu/ethereum/+edit-dependencies ## Building Packages Locally (for testing) You need to run Ubuntu to do test packaging. -Add the gophers PPA and install Go 1.11 and Debian packaging tools: +Install any version of Go and Debian packaging tools: - $ sudo apt-add-repository ppa:gophers/ubuntu/archive - $ sudo apt-get update - $ sudo apt-get install build-essential golang-1.11 devscripts debhelper python-bzrlib python-paramiko + $ sudo apt-get install build-essential golang-go devscripts debhelper python-bzrlib python-paramiko Create the source packages: @@ -42,10 +41,10 @@ Create the source packages: Then go into the source package directory for your running distribution and build the package: - $ cd dist/ethereum-unstable-1.6.0+xenial + $ cd dist/ethereum-unstable-1.9.6+bionic $ dpkg-buildpackage Built packages are placed in the dist/ directory. $ cd .. - $ dpkg-deb -c geth-unstable_1.6.0+xenial_amd64.deb + $ dpkg-deb -c geth-unstable_1.9.6+bionic_amd64.deb diff --git a/build/ci.go b/build/ci.go index d4e2814ec..ac5c72b6b 100644 --- a/build/ci.go +++ b/build/ci.go @@ -58,6 +58,7 @@ import ( "strings" "time" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/internal/build" "github.com/ethereum/go-ethereum/params" ) @@ -138,7 +139,18 @@ var ( // Note: zesty is unsupported because it was officially deprecated on Launchpad. // Note: artful is unsupported because it was officially deprecated on Launchpad. 
// Note: cosmic is unsupported because it was officially deprecated on Launchpad. - debDistros = []string{"trusty", "xenial", "bionic", "disco", "eoan"} + debDistroGoBoots = map[string]string{ + "trusty": "golang-1.11", + "xenial": "golang-go", + "bionic": "golang-go", + "disco": "golang-go", + "eoan": "golang-go", + } + + debGoBootPaths = map[string]string{ + "golang-1.11": "/usr/lib/go-1.11", + "golang-go": "/usr/lib/go", + } ) var GOBIN, _ = filepath.Abs(filepath.Join("build", "bin")) @@ -214,7 +226,6 @@ func doInstall(cmdline []string) { if flag.NArg() > 0 { packages = flag.Args() } - packages = build.ExpandPackagesNoVendor(packages) if *arch == "" || *arch == runtime.GOARCH { goinstall := goTool("install", buildFlags(env)...) @@ -311,13 +322,12 @@ func doTest(cmdline []string) { if len(flag.CommandLine.Args()) > 0 { packages = flag.CommandLine.Args() } - packages = build.ExpandPackagesNoVendor(packages) // Run the actual tests. // Test a single package at a time. CI builders are slow // and some tests run into timeouts under load. gotest := goTool("test", buildFlags(env)...) - gotest.Args = append(gotest.Args, "-p", "1", "-timeout", "5m") + gotest.Args = append(gotest.Args, "-p", "1", "-timeout", "5m", "--short") if *coverage { gotest.Args = append(gotest.Args, "-covermode=atomic", "-cover") } @@ -461,11 +471,14 @@ func maybeSkipArchive(env build.Environment) { // Debian Packaging func doDebianSource(cmdline []string) { var ( - signer = flag.String("signer", "", `Signing key name, also used as package author`) - upload = flag.String("upload", "", `Where to upload the source package (usually "ethereum/ethereum")`) - sshUser = flag.String("sftp-user", "", `Username for SFTP upload (usually "geth-ci")`) - workdir = flag.String("workdir", "", `Output directory for packages (uses temp dir if unset)`) - now = time.Now() + goversion = flag.String("goversion", "", `Go version to build with (will be included in the source package)`) + gobundle = flag.String("gobundle", "/tmp/go.tar.gz", `Filesystem path to cache the downloaded Go bundles at`) + gohash = flag.String("gohash", "", `SHA256 checksum of the Go sources requested to build with`) + signer = flag.String("signer", "", `Signing key name, also used as package author`) + upload = flag.String("upload", "", `Where to upload the source package (usually "ethereum/ethereum")`) + sshUser = flag.String("sftp-user", "", `Username for SFTP upload (usually "geth-ci")`) + workdir = flag.String("workdir", "", `Output directory for packages (uses temp dir if unset)`) + now = time.Now() ) flag.CommandLine.Parse(cmdline) *workdir = makeWorkdir(*workdir) @@ -478,12 +491,25 @@ func doDebianSource(cmdline []string) { gpg.Stdin = bytes.NewReader(key) build.MustRun(gpg) } - + // Download and verify the Go source package + if err := build.EnsureGoSources(*goversion, hexutil.MustDecode("0x"+*gohash), *gobundle); err != nil { + log.Fatalf("Failed to ensure Go source package: %v", err) + } // Create Debian packages and upload them for _, pkg := range debPackages { - for _, distro := range debDistros { - meta := newDebMetadata(distro, *signer, env, now, pkg.Name, pkg.Version, pkg.Executables) + for distro, goboot := range debDistroGoBoots { + // Prepare the debian package with the go-ethereum sources + meta := newDebMetadata(distro, goboot, *signer, env, now, pkg.Name, pkg.Version, pkg.Executables) pkgdir := stageDebianSource(*workdir, meta) + + // Ship the Go sources along so we have a proper thing to build with + if err := build.ExtractTarballArchive(*gobundle, 
pkgdir); err != nil { + log.Fatalf("Failed to extract Go sources: %v", err) + } + if err := os.Rename(filepath.Join(pkgdir, "go"), filepath.Join(pkgdir, ".go")); err != nil { + log.Fatalf("Failed to rename Go source folder: %v", err) + } + // Run the packaging and upload to the PPA debuild := exec.Command("debuild", "-S", "-sa", "-us", "-uc", "-d", "-Zxz") debuild.Dir = pkgdir build.MustRun(debuild) @@ -563,7 +589,9 @@ type debPackage struct { } type debMetadata struct { - Env build.Environment + Env build.Environment + GoBootPackage string + GoBootPath string PackageName string @@ -592,19 +620,21 @@ func (d debExecutable) Package() string { return d.BinaryName } -func newDebMetadata(distro, author string, env build.Environment, t time.Time, name string, version string, exes []debExecutable) debMetadata { +func newDebMetadata(distro, goboot, author string, env build.Environment, t time.Time, name string, version string, exes []debExecutable) debMetadata { if author == "" { // No signing key, use default author. author = "Ethereum Builds " } return debMetadata{ - PackageName: name, - Env: env, - Author: author, - Distro: distro, - Version: version, - Time: t.Format(time.RFC1123Z), - Executables: exes, + GoBootPackage: goboot, + GoBootPath: debGoBootPaths[goboot], + PackageName: name, + Env: env, + Author: author, + Distro: distro, + Version: version, + Time: t.Format(time.RFC1123Z), + Executables: exes, } } @@ -669,7 +699,6 @@ func stageDebianSource(tmpdir string, meta debMetadata) (pkgdir string) { if err := os.Mkdir(pkgdir, 0755); err != nil { log.Fatal(err) } - // Copy the source code. build.MustRunCommand("git", "checkout-index", "-a", "--prefix", pkgdir+string(filepath.Separator)) @@ -687,7 +716,6 @@ func stageDebianSource(tmpdir string, meta debMetadata) (pkgdir string) { build.Render("build/deb/"+meta.PackageName+"/deb.install", install, 0644, exe) build.Render("build/deb/"+meta.PackageName+"/deb.docs", docs, 0644, exe) } - return pkgdir } diff --git a/build/deb/ethereum/deb.control b/build/deb/ethereum/deb.control index 5b3ff9354..501a32cb4 100644 --- a/build/deb/ethereum/deb.control +++ b/build/deb/ethereum/deb.control @@ -2,7 +2,7 @@ Source: {{.Name}} Section: science Priority: extra Maintainer: {{.Author}} -Build-Depends: debhelper (>= 8.0.0), golang-1.11 +Build-Depends: debhelper (>= 8.0.0), {{.GoBootPackage}} Standards-Version: 3.9.5 Homepage: https://ethereum.org Vcs-Git: git://github.com/ethereum/go-ethereum.git diff --git a/build/deb/ethereum/deb.rules b/build/deb/ethereum/deb.rules index 5280e0e55..1370a52f1 100644 --- a/build/deb/ethereum/deb.rules +++ b/build/deb/ethereum/deb.rules @@ -6,9 +6,11 @@ # Launchpad rejects Go's access to $HOME/.cache, use custom folder export GOCACHE=/tmp/go-build +export GOROOT_BOOTSTRAP={{.GoBootPath}} override_dh_auto_build: - build/env.sh /usr/lib/go-1.11/bin/go run build/ci.go install -git-commit={{.Env.Commit}} -git-branch={{.Env.Branch}} -git-tag={{.Env.Tag}} -buildnum={{.Env.Buildnum}} -pull-request={{.Env.IsPullRequest}} + (cd .go/src && ./make.bash) + build/env.sh .go/bin/go run build/ci.go install -git-commit={{.Env.Commit}} -git-branch={{.Env.Branch}} -git-tag={{.Env.Tag}} -buildnum={{.Env.Buildnum}} -pull-request={{.Env.IsPullRequest}} override_dh_auto_test: diff --git a/cmd/bootnode/main.go b/cmd/bootnode/main.go index 2f9bba111..f6e2a14c3 100644 --- a/cmd/bootnode/main.go +++ b/cmd/bootnode/main.go @@ -70,7 +70,9 @@ func main() { if err = crypto.SaveECDSA(*genKey, nodeKey); err != nil { utils.Fatalf("%v", err) } - return + if 
!*writeAddr { + return + } case *nodeKeyFile == "" && *nodeKeyHex == "": utils.Fatalf("Use -nodekey or -nodekeyhex to specify a private key") case *nodeKeyFile != "" && *nodeKeyHex != "": diff --git a/cmd/clef/main.go b/cmd/clef/main.go index 88d9eaaa5..d34f8c28d 100644 --- a/cmd/clef/main.go +++ b/cmd/clef/main.go @@ -404,6 +404,27 @@ func initialize(c *cli.Context) error { return nil } +// ipcEndpoint resolves an IPC endpoint based on a configured value, taking into +// account the set data folders as well as the designated platform we're currently +// running on. +func ipcEndpoint(ipcPath, datadir string) string { + // On windows we can only use plain top-level pipes + if runtime.GOOS == "windows" { + if strings.HasPrefix(ipcPath, `\\.\pipe\`) { + return ipcPath + } + return `\\.\pipe\` + ipcPath + } + // Resolve names into the data directory full paths otherwise + if filepath.Base(ipcPath) == ipcPath { + if datadir == "" { + return filepath.Join(os.TempDir(), ipcPath) + } + return filepath.Join(datadir, ipcPath) + } + return ipcPath +} + func signer(c *cli.Context) error { // If we have some unrecognized command, bail out if args := c.Args(); len(args) > 0 { @@ -532,12 +553,8 @@ func signer(c *cli.Context) error { }() } if !c.GlobalBool(utils.IPCDisabledFlag.Name) { - if c.IsSet(utils.IPCPathFlag.Name) { - ipcapiURL = c.GlobalString(utils.IPCPathFlag.Name) - } else { - ipcapiURL = filepath.Join(configDir, "clef.ipc") - } - + givenPath := c.GlobalString(utils.IPCPathFlag.Name) + ipcapiURL = ipcEndpoint(filepath.Join(givenPath, "clef.ipc"), configDir) listener, _, err := rpc.StartIPCEndpoint(ipcapiURL, rpcAPI) if err != nil { utils.Fatalf("Could not start IPC api: %v", err) @@ -547,7 +564,6 @@ func signer(c *cli.Context) error { listener.Close() log.Info("IPC endpoint closed", "url", ipcapiURL) }() - } if c.GlobalBool(testFlag.Name) { diff --git a/cmd/devp2p/crawl.go b/cmd/devp2p/crawl.go new file mode 100644 index 000000000..92aaad72a --- /dev/null +++ b/cmd/devp2p/crawl.go @@ -0,0 +1,152 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +type crawler struct { + input nodeSet + output nodeSet + disc *discover.UDPv4 + iters []enode.Iterator + inputIter enode.Iterator + ch chan *enode.Node + closed chan struct{} + + // settings + revalidateInterval time.Duration +} + +func newCrawler(input nodeSet, disc *discover.UDPv4, iters ...enode.Iterator) *crawler { + c := &crawler{ + input: input, + output: make(nodeSet, len(input)), + disc: disc, + iters: iters, + inputIter: enode.IterNodes(input.nodes()), + ch: make(chan *enode.Node), + closed: make(chan struct{}), + } + c.iters = append(c.iters, c.inputIter) + // Copy input to output initially. 
Any nodes that fail validation + // will be dropped from output during the run. + for id, n := range input { + c.output[id] = n + } + return c +} + +func (c *crawler) run(timeout time.Duration) nodeSet { + var ( + timeoutTimer = time.NewTimer(timeout) + timeoutCh <-chan time.Time + doneCh = make(chan enode.Iterator, len(c.iters)) + liveIters = len(c.iters) + ) + for _, it := range c.iters { + go c.runIterator(doneCh, it) + } + +loop: + for { + select { + case n := <-c.ch: + c.updateNode(n) + case it := <-doneCh: + if it == c.inputIter { + // Enable timeout when we're done revalidating the input nodes. + log.Info("Revalidation of input set is done", "len", len(c.input)) + if timeout > 0 { + timeoutCh = timeoutTimer.C + } + } + if liveIters--; liveIters == 0 { + break loop + } + case <-timeoutCh: + break loop + } + } + + close(c.closed) + for _, it := range c.iters { + it.Close() + } + for ; liveIters > 0; liveIters-- { + <-doneCh + } + return c.output +} + +func (c *crawler) runIterator(done chan<- enode.Iterator, it enode.Iterator) { + defer func() { done <- it }() + for it.Next() { + select { + case c.ch <- it.Node(): + case <-c.closed: + return + } + } +} + +func (c *crawler) updateNode(n *enode.Node) { + node, ok := c.output[n.ID()] + + // Skip validation of recently-seen nodes. + if ok && time.Since(node.LastCheck) < c.revalidateInterval { + return + } + + // Request the node record. + nn, err := c.disc.RequestENR(n) + node.LastCheck = truncNow() + if err != nil { + if node.Score == 0 { + // Node doesn't implement EIP-868. + log.Debug("Skipping node", "id", n.ID()) + return + } + node.Score /= 2 + } else { + node.N = nn + node.Seq = nn.Seq() + node.Score++ + if node.FirstResponse.IsZero() { + node.FirstResponse = node.LastCheck + } + node.LastResponse = node.LastCheck + } + + // Store/update node in output set. 
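The scoring policy at work here is easy to misread in diff form: a successful `RequestENR` bumps the score by one, a failure halves it, and, as the next hunk shows, the node is dropped from the output set once the score reaches zero. A toy simulation of that policy in isolation (the success/failure pattern is invented):

```go
package main

import "fmt"

func main() {
	// Mimic the crawler's liveness score: +1 per successful ENR request,
	// halved (integer division) per failure, removal once it hits zero.
	score := 0
	for i, ok := range []bool{true, true, true, false, false} {
		if ok {
			score++
		} else {
			score /= 2
		}
		fmt.Printf("check %d: ok=%v score=%d\n", i, ok, score)
	}
	// Three successes reach 3; two failures then halve it to 1 and 0, at
	// which point updateNode would delete the node from c.output.
	if score <= 0 {
		fmt.Println("node would be removed")
	}
}
```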
+ if node.Score <= 0 { + log.Info("Removing node", "id", n.ID()) + delete(c.output, n.ID()) + } else { + log.Info("Updating node", "id", n.ID(), "seq", n.Seq(), "score", node.Score) + c.output[n.ID()] = node + } +} + +func truncNow() time.Time { + return time.Now().UTC().Truncate(1 * time.Second) +} diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 1e56687a6..9525bec66 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -19,10 +19,10 @@ package main import ( "fmt" "net" - "sort" "strings" "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" @@ -38,36 +38,59 @@ var ( discv4PingCommand, discv4RequestRecordCommand, discv4ResolveCommand, + discv4ResolveJSONCommand, + discv4CrawlCommand, }, } discv4PingCommand = cli.Command{ - Name: "ping", - Usage: "Sends ping to a node", - Action: discv4Ping, + Name: "ping", + Usage: "Sends ping to a node", + Action: discv4Ping, + ArgsUsage: "", } discv4RequestRecordCommand = cli.Command{ - Name: "requestenr", - Usage: "Requests a node record using EIP-868 enrRequest", - Action: discv4RequestRecord, + Name: "requestenr", + Usage: "Requests a node record using EIP-868 enrRequest", + Action: discv4RequestRecord, + ArgsUsage: "", } discv4ResolveCommand = cli.Command{ - Name: "resolve", - Usage: "Finds a node in the DHT", - Action: discv4Resolve, - Flags: []cli.Flag{bootnodesFlag}, + Name: "resolve", + Usage: "Finds a node in the DHT", + Action: discv4Resolve, + ArgsUsage: "", + Flags: []cli.Flag{bootnodesFlag}, + } + discv4ResolveJSONCommand = cli.Command{ + Name: "resolve-json", + Usage: "Re-resolves nodes in a nodes.json file", + Action: discv4ResolveJSON, + Flags: []cli.Flag{bootnodesFlag}, + ArgsUsage: "", + } + discv4CrawlCommand = cli.Command{ + Name: "crawl", + Usage: "Updates a nodes.json file with random nodes found in the DHT", + Action: discv4Crawl, + Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag}, } ) -var bootnodesFlag = cli.StringFlag{ - Name: "bootnodes", - Usage: "Comma separated nodes used for bootstrapping", -} +var ( + bootnodesFlag = cli.StringFlag{ + Name: "bootnodes", + Usage: "Comma separated nodes used for bootstrapping", + } + crawlTimeoutFlag = cli.DurationFlag{ + Name: "timeout", + Usage: "Time limit for the crawl.", + Value: 30 * time.Minute, + } +) func discv4Ping(ctx *cli.Context) error { - n, disc, err := getNodeArgAndStartV4(ctx) - if err != nil { - return err - } + n := getNodeArg(ctx) + disc := startV4(ctx) defer disc.Close() start := time.Now() @@ -79,10 +102,8 @@ func discv4Ping(ctx *cli.Context) error { } func discv4RequestRecord(ctx *cli.Context) error { - n, disc, err := getNodeArgAndStartV4(ctx) - if err != nil { - return err - } + n := getNodeArg(ctx) + disc := startV4(ctx) defer disc.Close() respN, err := disc.RequestENR(n) @@ -94,33 +115,61 @@ func discv4RequestRecord(ctx *cli.Context) error { } func discv4Resolve(ctx *cli.Context) error { - n, disc, err := getNodeArgAndStartV4(ctx) - if err != nil { - return err - } + n := getNodeArg(ctx) + disc := startV4(ctx) defer disc.Close() fmt.Println(disc.Resolve(n).String()) return nil } -func getNodeArgAndStartV4(ctx *cli.Context) (*enode.Node, *discover.UDPv4, error) { - if ctx.NArg() != 1 { - return nil, nil, fmt.Errorf("missing node as command-line argument") +func discv4ResolveJSON(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need nodes file as argument") } - n, err := 
parseNode(ctx.Args()[0]) - if err != nil { - return nil, nil, err + nodesFile := ctx.Args().Get(0) + inputSet := make(nodeSet) + if common.FileExist(nodesFile) { + inputSet = loadNodesJSON(nodesFile) } - var bootnodes []*enode.Node - if commandHasFlag(ctx, bootnodesFlag) { - bootnodes, err = parseBootnodes(ctx) + + // Add extra nodes from command line arguments. + var nodeargs []*enode.Node + for i := 1; i < ctx.NArg(); i++ { + n, err := parseNode(ctx.Args().Get(i)) if err != nil { - return nil, nil, err + exit(err) } + nodeargs = append(nodeargs, n) } - disc, err := startV4(bootnodes) - return n, disc, err + + // Run the crawler. + disc := startV4(ctx) + defer disc.Close() + c := newCrawler(inputSet, disc, enode.IterNodes(nodeargs)) + c.revalidateInterval = 0 + output := c.run(0) + writeNodesJSON(nodesFile, output) + return nil +} + +func discv4Crawl(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need nodes file as argument") + } + nodesFile := ctx.Args().First() + var inputSet nodeSet + if common.FileExist(nodesFile) { + inputSet = loadNodesJSON(nodesFile) + } + + disc := startV4(ctx) + defer disc.Close() + c := newCrawler(inputSet, disc, disc.RandomNodes()) + c.revalidateInterval = 10 * time.Minute + output := c.run(ctx.Duration(crawlTimeoutFlag.Name)) + writeNodesJSON(nodesFile, output) + return nil } func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) { @@ -139,28 +188,39 @@ func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) { return nodes, nil } -// commandHasFlag returns true if the current command supports the given flag. -func commandHasFlag(ctx *cli.Context, flag cli.Flag) bool { - flags := ctx.FlagNames() - sort.Strings(flags) - i := sort.SearchStrings(flags, flag.GetName()) - return i != len(flags) && flags[i] == flag.GetName() +// startV4 starts an ephemeral discovery V4 node. +func startV4(ctx *cli.Context) *discover.UDPv4 { + socket, ln, cfg, err := listen() + if err != nil { + exit(err) + } + if commandHasFlag(ctx, bootnodesFlag) { + bn, err := parseBootnodes(ctx) + if err != nil { + exit(err) + } + cfg.Bootnodes = bn + } + disc, err := discover.ListenV4(socket, ln, cfg) + if err != nil { + exit(err) + } + return disc } -// startV4 starts an ephemeral discovery V4 node. -func startV4(bootnodes []*enode.Node) (*discover.UDPv4, error) { +func listen() (*net.UDPConn, *enode.LocalNode, discover.Config, error) { var cfg discover.Config - cfg.Bootnodes = bootnodes cfg.PrivateKey, _ = crypto.GenerateKey() db, _ := enode.OpenDB("") ln := enode.NewLocalNode(db, cfg.PrivateKey) socket, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{0, 0, 0, 0}}) if err != nil { - return nil, err + db.Close() + return nil, nil, cfg, err } addr := socket.LocalAddr().(*net.UDPAddr) ln.SetFallbackIP(net.IP{127, 0, 0, 1}) ln.SetFallbackUDP(addr.Port) - return discover.ListenUDP(socket, ln, cfg) + return socket, ln, cfg, nil } diff --git a/cmd/devp2p/dns_cloudflare.go b/cmd/devp2p/dns_cloudflare.go new file mode 100644 index 000000000..83279168c --- /dev/null +++ b/cmd/devp2p/dns_cloudflare.go @@ -0,0 +1,163 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "fmt" + "strings" + + "github.com/cloudflare/cloudflare-go" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/dnsdisc" + "gopkg.in/urfave/cli.v1" +) + +var ( + cloudflareTokenFlag = cli.StringFlag{ + Name: "token", + Usage: "CloudFlare API token", + EnvVar: "CLOUDFLARE_API_TOKEN", + } + cloudflareZoneIDFlag = cli.StringFlag{ + Name: "zoneid", + Usage: "CloudFlare Zone ID (optional)", + } +) + +type cloudflareClient struct { + *cloudflare.API + zoneID string +} + +// newCloudflareClient sets up a CloudFlare API client from command line flags. +func newCloudflareClient(ctx *cli.Context) *cloudflareClient { + token := ctx.String(cloudflareTokenFlag.Name) + if token == "" { + exit(fmt.Errorf("need cloudflare API token to proceed")) + } + api, err := cloudflare.NewWithAPIToken(token) + if err != nil { + exit(fmt.Errorf("can't create Cloudflare client: %v", err)) + } + return &cloudflareClient{ + API: api, + zoneID: ctx.String(cloudflareZoneIDFlag.Name), + } +} + +// deploy uploads the given tree to CloudFlare DNS. +func (c *cloudflareClient) deploy(name string, t *dnsdisc.Tree) error { + if err := c.checkZone(name); err != nil { + return err + } + records := t.ToTXT(name) + return c.uploadRecords(name, records) +} + +// checkZone verifies permissions on the CloudFlare DNS Zone for name. +func (c *cloudflareClient) checkZone(name string) error { + if c.zoneID == "" { + log.Info(fmt.Sprintf("Finding CloudFlare zone ID for %s", name)) + id, err := c.ZoneIDByName(name) + if err != nil { + return err + } + c.zoneID = id + } + log.Info(fmt.Sprintf("Checking Permissions on zone %s", c.zoneID)) + zone, err := c.ZoneDetails(c.zoneID) + if err != nil { + return err + } + if !strings.HasSuffix(name, "."+zone.Name) { + return fmt.Errorf("CloudFlare zone name %q does not match name %q to be deployed", zone.Name, name) + } + needPerms := map[string]bool{"#zone:edit": false, "#zone:read": false} + for _, perm := range zone.Permissions { + if _, ok := needPerms[perm]; ok { + needPerms[perm] = true + } + } + for _, ok := range needPerms { + if !ok { + return fmt.Errorf("wrong permissions on zone %s: %v", c.zoneID, needPerms) + } + } + return nil +} + +// uploadRecords updates the TXT records at a particular subdomain. All non-root records +// will have a TTL of "infinity" and all existing records not in the new map will be +// nuked! +func (c *cloudflareClient) uploadRecords(name string, records map[string]string) error { + // Convert all names to lowercase. + lrecords := make(map[string]string, len(records)) + for name, r := range records { + lrecords[strings.ToLower(name)] = r + } + records = lrecords + + log.Info(fmt.Sprintf("Retrieving existing TXT records on %s", name)) + entries, err := c.DNSRecords(c.zoneID, cloudflare.DNSRecord{Type: "TXT"}) + if err != nil { + return err + } + existing := make(map[string]cloudflare.DNSRecord) + for _, entry := range entries { + if !strings.HasSuffix(entry.Name, name) { + continue + } + existing[strings.ToLower(entry.Name)] = entry + } + + // Iterate over the new records and inject anything missing. 
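The loop that follows, together with the deletion pass below it, is a three-way reconciliation: create records that are missing, update ones whose content changed, and delete ones that are no longer wanted. A dependency-free sketch of the same decision structure over plain string maps (no Cloudflare client involved; names are illustrative):

```go
package main

import "fmt"

// reconcile reports which record paths must be created, updated or deleted.
// The real uploadRecords lowercases all names first; assume that was done.
func reconcile(desired, existing map[string]string) (create, update, remove []string) {
	for path, val := range desired {
		old, ok := existing[path]
		switch {
		case !ok:
			create = append(create, path)
		case old != val:
			update = append(update, path)
		}
	}
	for path := range existing {
		if _, ok := desired[path]; !ok {
			remove = append(remove, path)
		}
	}
	return create, update, remove
}

func main() {
	desired := map[string]string{"a.nodes.example.org": "x", "b.nodes.example.org": "y"}
	existing := map[string]string{"b.nodes.example.org": "old", "c.nodes.example.org": "z"}
	c, u, r := reconcile(desired, existing)
	fmt.Println("create:", c, "update:", u, "delete:", r)
}
```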
+ for path, val := range records { + old, exists := existing[path] + if !exists { + // Entry is unknown, push a new one to Cloudflare. + log.Info(fmt.Sprintf("Creating %s = %q", path, val)) + ttl := 1 + if path != name { + ttl = 2147483647 // Max TTL permitted by Cloudflare + } + _, err = c.CreateDNSRecord(c.zoneID, cloudflare.DNSRecord{Type: "TXT", Name: path, Content: val, TTL: ttl}) + } else if old.Content != val { + // Entry already exists, only change its content. + log.Info(fmt.Sprintf("Updating %s from %q to %q", path, old.Content, val)) + old.Content = val + err = c.UpdateDNSRecord(c.zoneID, old.ID, old) + } else { + log.Info(fmt.Sprintf("Skipping %s = %q", path, val)) + } + if err != nil { + return fmt.Errorf("failed to publish %s: %v", path, err) + } + } + + // Iterate over the old records and delete anything stale. + for path, entry := range existing { + if _, ok := records[path]; ok { + continue + } + // Stale entry, nuke it. + log.Info(fmt.Sprintf("Deleting %s = %q", path, entry.Content)) + if err := c.DeleteDNSRecord(c.zoneID, entry.ID); err != nil { + return fmt.Errorf("failed to delete %s: %v", path, err) + } + } + return nil +} diff --git a/cmd/devp2p/dnscmd.go b/cmd/devp2p/dnscmd.go new file mode 100644 index 000000000..eb15764b0 --- /dev/null +++ b/cmd/devp2p/dnscmd.go @@ -0,0 +1,361 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . 
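The new dnscmd.go below wires the dnsdisc primitives into CLI subcommands. The essential flow of the `sync`/`to-txt` pair, fetch a tree by URL and render it to TXT records, uses only the dnsdisc calls that appear in this patch (`NewClient`, `SyncTree`, `ToTXT`); a hedged sketch with a made-up tree URL that will not actually resolve:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/dnsdisc"
)

func main() {
	// An enrtree:// URL carries the signing key and the domain; this value
	// is purely illustrative.
	url := "enrtree://AKPYQIUQIL7PSIACI32J7FGZW56E5FKHEFCCOFHILBIMW3M6LWXS2@nodes.example.org"

	c, _ := dnsdisc.NewClient(dnsdisc.Config{}) // cannot fail: no URLs given
	tree, err := c.SyncTree(url)
	if err != nil {
		fmt.Println("sync failed:", err)
		return
	}
	// ToTXT renders the tree as the set of DNS TXT records that would serve it.
	for name, record := range tree.ToTXT("nodes.example.org") {
		fmt.Println(name, "=", record)
	}
}
```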
+ +package main + +import ( + "crypto/ecdsa" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "time" + + "github.com/ethereum/go-ethereum/accounts/keystore" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/console" + "github.com/ethereum/go-ethereum/p2p/dnsdisc" + "github.com/ethereum/go-ethereum/p2p/enode" + cli "gopkg.in/urfave/cli.v1" +) + +var ( + dnsCommand = cli.Command{ + Name: "dns", + Usage: "DNS Discovery Commands", + Subcommands: []cli.Command{ + dnsSyncCommand, + dnsSignCommand, + dnsTXTCommand, + dnsCloudflareCommand, + }, + } + dnsSyncCommand = cli.Command{ + Name: "sync", + Usage: "Download a DNS discovery tree", + ArgsUsage: "<url> [ <directory> ]", + Action: dnsSync, + Flags: []cli.Flag{dnsTimeoutFlag}, + } + dnsSignCommand = cli.Command{ + Name: "sign", + Usage: "Sign a DNS discovery tree", + ArgsUsage: "<tree-directory> <key-file>", + Action: dnsSign, + Flags: []cli.Flag{dnsDomainFlag, dnsSeqFlag}, + } + dnsTXTCommand = cli.Command{ + Name: "to-txt", + Usage: "Create DNS TXT records for a discovery tree", + ArgsUsage: "<tree-directory> <output-file>", + Action: dnsToTXT, + } + dnsCloudflareCommand = cli.Command{ + Name: "to-cloudflare", + Usage: "Deploy DNS TXT records to cloudflare", + ArgsUsage: "<tree-directory>", + Action: dnsToCloudflare, + Flags: []cli.Flag{cloudflareTokenFlag, cloudflareZoneIDFlag}, + } +) + +var ( + dnsTimeoutFlag = cli.DurationFlag{ + Name: "timeout", + Usage: "Timeout for DNS lookups", + } + dnsDomainFlag = cli.StringFlag{ + Name: "domain", + Usage: "Domain name of the tree", + } + dnsSeqFlag = cli.UintFlag{ + Name: "seq", + Usage: "New sequence number of the tree", + } +) + +// dnsSync performs dnsSyncCommand. +func dnsSync(ctx *cli.Context) error { + var ( + c = dnsClient(ctx) + url = ctx.Args().Get(0) + outdir = ctx.Args().Get(1) + ) + domain, _, err := dnsdisc.ParseURL(url) + if err != nil { + return err + } + if outdir == "" { + outdir = domain + } + + t, err := c.SyncTree(url) + if err != nil { + return err + } + def := treeToDefinition(url, t) + def.Meta.LastModified = time.Now() + writeTreeMetadata(outdir, def) + writeTreeNodes(outdir, def) + return nil +} + +func dnsSign(ctx *cli.Context) error { + if ctx.NArg() < 2 { + return fmt.Errorf("need tree definition directory and key file as arguments") + } + var ( + defdir = ctx.Args().Get(0) + keyfile = ctx.Args().Get(1) + def = loadTreeDefinition(defdir) + domain = directoryName(defdir) + ) + if def.Meta.URL != "" { + d, _, err := dnsdisc.ParseURL(def.Meta.URL) + if err != nil { + return fmt.Errorf("invalid 'url' field: %v", err) + } + domain = d + } + if ctx.IsSet(dnsDomainFlag.Name) { + domain = ctx.String(dnsDomainFlag.Name) + } + if ctx.IsSet(dnsSeqFlag.Name) { + def.Meta.Seq = ctx.Uint(dnsSeqFlag.Name) + } else { + def.Meta.Seq++ // Auto-bump sequence number if not supplied via flag. + } + t, err := dnsdisc.MakeTree(def.Meta.Seq, def.Nodes, def.Meta.Links) + if err != nil { + return err + } + + key := loadSigningKey(keyfile) + url, err := t.Sign(key, domain) + if err != nil { + return fmt.Errorf("can't sign: %v", err) + } + + def = treeToDefinition(url, t) + def.Meta.LastModified = time.Now() + writeTreeMetadata(defdir, def) + return nil +} + +func directoryName(dir string) string { + abs, err := filepath.Abs(dir) + if err != nil { + exit(err) + } + return filepath.Base(abs) +} + +// dnsToTXT performs dnsTXTCommand.
+func dnsToTXT(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need tree definition directory as argument") + } + output := ctx.Args().Get(1) + if output == "" { + output = "-" // default to stdout + } + domain, t, err := loadTreeDefinitionForExport(ctx.Args().Get(0)) + if err != nil { + return err + } + writeTXTJSON(output, t.ToTXT(domain)) + return nil +} + +// dnsToCloudflare performs dnsCloudflareCommand. +func dnsToCloudflare(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need tree definition directory as argument") + } + domain, t, err := loadTreeDefinitionForExport(ctx.Args().Get(0)) + if err != nil { + return err + } + client := newCloudflareClient(ctx) + return client.deploy(domain, t) +} + +// loadSigningKey loads a private key in Ethereum keystore format. +func loadSigningKey(keyfile string) *ecdsa.PrivateKey { + keyjson, err := ioutil.ReadFile(keyfile) + if err != nil { + exit(fmt.Errorf("failed to read the keyfile at '%s': %v", keyfile, err)) + } + password, _ := console.Stdin.PromptPassword("Please enter the password for '" + keyfile + "': ") + key, err := keystore.DecryptKey(keyjson, password) + if err != nil { + exit(fmt.Errorf("error decrypting key: %v", err)) + } + return key.PrivateKey +} + +// dnsClient configures the DNS discovery client from command line flags. +func dnsClient(ctx *cli.Context) *dnsdisc.Client { + var cfg dnsdisc.Config + if commandHasFlag(ctx, dnsTimeoutFlag) { + cfg.Timeout = ctx.Duration(dnsTimeoutFlag.Name) + } + c, _ := dnsdisc.NewClient(cfg) // cannot fail because no URLs given + return c +} + +// There are two file formats for DNS node trees on disk: +// +// The 'TXT' format is a single JSON file containing DNS TXT records +// as a JSON object where the keys are names and the values are objects +// containing the value of the record. +// +// The 'definition' format is a directory containing two files: +// +// enrtree-info.json -- contains sequence number & links to other trees +// nodes.json -- contains the nodes as a JSON array. +// +// This format exists because it's convenient to edit. nodes.json can be generated +// in multiple ways: it may be written by a DHT crawler or compiled by a human. + +type dnsDefinition struct { + Meta dnsMetaJSON + Nodes []*enode.Node +} + +type dnsMetaJSON struct { + URL string `json:"url,omitempty"` + Seq uint `json:"seq"` + Sig string `json:"signature,omitempty"` + Links []string `json:"links"` + LastModified time.Time `json:"lastModified"` +} + +func treeToDefinition(url string, t *dnsdisc.Tree) *dnsDefinition { + meta := dnsMetaJSON{ + URL: url, + Seq: t.Seq(), + Sig: t.Signature(), + Links: t.Links(), + } + if meta.Links == nil { + meta.Links = []string{} + } + return &dnsDefinition{Meta: meta, Nodes: t.Nodes()} +} + +// loadTreeDefinition loads a directory in 'definition' format. +func loadTreeDefinition(directory string) *dnsDefinition { + metaFile, nodesFile := treeDefinitionFiles(directory) + var def dnsDefinition + err := common.LoadJSON(metaFile, &def.Meta) + if err != nil && !os.IsNotExist(err) { + exit(err) + } + if def.Meta.Links == nil { + def.Meta.Links = []string{} + } + // Check link syntax. + for _, link := range def.Meta.Links { + if _, _, err := dnsdisc.ParseURL(link); err != nil { + exit(fmt.Errorf("invalid link %q: %v", link, err)) + } + } + // Check/convert nodes.
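To make the 'definition' format described above concrete, this is roughly what `writeTreeMetadata` would emit as enrtree-info.json. The sketch re-declares a trimmed copy of `dnsMetaJSON` so it runs standalone; all field values are invented:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// meta mirrors the JSON shape of dnsMetaJSON above.
type meta struct {
	URL          string    `json:"url,omitempty"`
	Seq          uint      `json:"seq"`
	Sig          string    `json:"signature,omitempty"`
	Links        []string  `json:"links"`
	LastModified time.Time `json:"lastModified"`
}

func main() {
	m := meta{
		URL:          "enrtree://AKPYQIUQ...@nodes.example.org", // placeholder
		Seq:          3,
		Links:        []string{}, // non-nil so it serializes as [] rather than null
		LastModified: time.Date(2019, 11, 1, 0, 0, 0, 0, time.UTC),
	}
	out, _ := json.MarshalIndent(&m, "", "  ")
	fmt.Println(string(out)) // signature is omitted while the tree is unsigned
}
```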
+ nodes := loadNodesJSON(nodesFile) + if err := nodes.verify(); err != nil { + exit(err) + } + def.Nodes = nodes.nodes() + return &def +} + +// loadTreeDefinitionForExport loads a DNS tree and ensures it is signed. +func loadTreeDefinitionForExport(dir string) (domain string, t *dnsdisc.Tree, err error) { + metaFile, _ := treeDefinitionFiles(dir) + def := loadTreeDefinition(dir) + if def.Meta.URL == "" { + return "", nil, fmt.Errorf("missing 'url' field in %v", metaFile) + } + domain, pubkey, err := dnsdisc.ParseURL(def.Meta.URL) + if err != nil { + return "", nil, fmt.Errorf("invalid 'url' field in %v: %v", metaFile, err) + } + if t, err = dnsdisc.MakeTree(def.Meta.Seq, def.Nodes, def.Meta.Links); err != nil { + return "", nil, err + } + if err := ensureValidTreeSignature(t, pubkey, def.Meta.Sig); err != nil { + return "", nil, err + } + return domain, t, nil +} + +// ensureValidTreeSignature checks that sig is valid for tree and assigns it as the +// tree's signature if valid. +func ensureValidTreeSignature(t *dnsdisc.Tree, pubkey *ecdsa.PublicKey, sig string) error { + if sig == "" { + return fmt.Errorf("missing signature, run 'devp2p dns sign' first") + } + if err := t.SetSignature(pubkey, sig); err != nil { + return fmt.Errorf("invalid signature on tree, run 'devp2p dns sign' to update it") + } + return nil +} + +// writeTreeMetadata writes a DNS node tree metadata file to the given directory. +func writeTreeMetadata(directory string, def *dnsDefinition) { + metaJSON, err := json.MarshalIndent(&def.Meta, "", jsonIndent) + if err != nil { + exit(err) + } + if err := os.Mkdir(directory, 0744); err != nil && !os.IsExist(err) { + exit(err) + } + metaFile, _ := treeDefinitionFiles(directory) + if err := ioutil.WriteFile(metaFile, metaJSON, 0644); err != nil { + exit(err) + } +} + +func writeTreeNodes(directory string, def *dnsDefinition) { + ns := make(nodeSet, len(def.Nodes)) + ns.add(def.Nodes...) + _, nodesFile := treeDefinitionFiles(directory) + writeNodesJSON(nodesFile, ns) +} + +func treeDefinitionFiles(directory string) (string, string) { + meta := filepath.Join(directory, "enrtree-info.json") + nodes := filepath.Join(directory, "nodes.json") + return meta, nodes +} + +// writeTXTJSON writes TXT records in JSON format. +func writeTXTJSON(file string, txt map[string]string) { + txtJSON, err := json.MarshalIndent(txt, "", jsonIndent) + if err != nil { + exit(err) + } + if file == "-" { + os.Stdout.Write(txtJSON) + fmt.Println() + return + } + if err := ioutil.WriteFile(file, txtJSON, 0644); err != nil { + exit(err) + } +} diff --git a/cmd/devp2p/main.go b/cmd/devp2p/main.go index 4532ab968..6faa65093 100644 --- a/cmd/devp2p/main.go +++ b/cmd/devp2p/main.go @@ -20,8 +20,10 @@ import ( "fmt" "os" "path/filepath" + "sort" "github.com/ethereum/go-ethereum/internal/debug" + "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" "gopkg.in/urfave/cli.v1" ) @@ -57,12 +59,39 @@ func init() { app.Commands = []cli.Command{ enrdumpCommand, discv4Command, + dnsCommand, + nodesetCommand, } } func main() { - if err := app.Run(os.Args); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } + exit(app.Run(os.Args)) +} + +// commandHasFlag returns true if the current command supports the given flag. 
+func commandHasFlag(ctx *cli.Context, flag cli.Flag) bool { + flags := ctx.FlagNames() + sort.Strings(flags) + i := sort.SearchStrings(flags, flag.GetName()) + return i != len(flags) && flags[i] == flag.GetName() +} + +// getNodeArg handles the common case of a single node descriptor argument. +func getNodeArg(ctx *cli.Context) *enode.Node { + if ctx.NArg() != 1 { + exit("missing node as command-line argument") + } + n, err := parseNode(ctx.Args()[0]) + if err != nil { + exit(err) + } + return n +} + +func exit(err interface{}) { + if err == nil { + os.Exit(0) + } + fmt.Fprintln(os.Stderr, err) + os.Exit(1) } diff --git a/cmd/devp2p/nodeset.go b/cmd/devp2p/nodeset.go new file mode 100644 index 000000000..2d86c3f65 --- /dev/null +++ b/cmd/devp2p/nodeset.go @@ -0,0 +1,102 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "sort" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +const jsonIndent = " " + +// nodeSet is the nodes.json file format. It holds a set of node records +// as a JSON object. +type nodeSet map[enode.ID]nodeJSON + +type nodeJSON struct { + Seq uint64 `json:"seq"` + N *enode.Node `json:"record"` + + // The score tracks how many liveness checks were performed. It is incremented by one + // every time the node passes a check, and halved every time it doesn't. + Score int `json:"score,omitempty"` + // These two track the time of last successful contact. + FirstResponse time.Time `json:"firstResponse,omitempty"` + LastResponse time.Time `json:"lastResponse,omitempty"` + // This one tracks the time of our last attempt to contact the node. + LastCheck time.Time `json:"lastCheck,omitempty"` +} + +func loadNodesJSON(file string) nodeSet { + var nodes nodeSet + if err := common.LoadJSON(file, &nodes); err != nil { + exit(err) + } + return nodes +} + +func writeNodesJSON(file string, nodes nodeSet) { + nodesJSON, err := json.MarshalIndent(nodes, "", jsonIndent) + if err != nil { + exit(err) + } + if file == "-" { + os.Stdout.Write(nodesJSON) + return + } + if err := ioutil.WriteFile(file, nodesJSON, 0644); err != nil { + exit(err) + } +} + +func (ns nodeSet) nodes() []*enode.Node { + result := make([]*enode.Node, 0, len(ns)) + for _, n := range ns { + result = append(result, n.N) + } + // Sort by ID. 
+ sort.Slice(result, func(i, j int) bool { + return bytes.Compare(result[i].ID().Bytes(), result[j].ID().Bytes()) < 0 + }) + return result +} + +func (ns nodeSet) add(nodes ...*enode.Node) { + for _, n := range nodes { + ns[n.ID()] = nodeJSON{Seq: n.Seq(), N: n} + } +} + +func (ns nodeSet) verify() error { + for id, n := range ns { + if n.N.ID() != id { + return fmt.Errorf("invalid node %v: ID does not match ID %v in record", id, n.N.ID()) + } + if n.N.Seq() != n.Seq { + return fmt.Errorf("invalid node %v: 'seq' does not match seq %d from record", id, n.N.Seq()) + } + } + return nil +} diff --git a/cmd/devp2p/nodesetcmd.go b/cmd/devp2p/nodesetcmd.go new file mode 100644 index 000000000..de8e6d45e --- /dev/null +++ b/cmd/devp2p/nodesetcmd.go @@ -0,0 +1,193 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "fmt" + "net" + "time" + + "github.com/ethereum/go-ethereum/core/forkid" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" + "gopkg.in/urfave/cli.v1" +) + +var ( + nodesetCommand = cli.Command{ + Name: "nodeset", + Usage: "Node set tools", + Subcommands: []cli.Command{ + nodesetInfoCommand, + nodesetFilterCommand, + }, + } + nodesetInfoCommand = cli.Command{ + Name: "info", + Usage: "Shows statistics about a node set", + Action: nodesetInfo, + ArgsUsage: "", + } + nodesetFilterCommand = cli.Command{ + Name: "filter", + Usage: "Filters a node set", + Action: nodesetFilter, + ArgsUsage: " filters..", + + SkipFlagParsing: true, + } +) + +func nodesetInfo(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need nodes file as argument") + } + + ns := loadNodesJSON(ctx.Args().First()) + fmt.Printf("Set contains %d nodes.\n", len(ns)) + return nil +} + +func nodesetFilter(ctx *cli.Context) error { + if ctx.NArg() < 1 { + return fmt.Errorf("need nodes file as argument") + } + ns := loadNodesJSON(ctx.Args().First()) + filter, err := andFilter(ctx.Args().Tail()) + if err != nil { + return err + } + + result := make(nodeSet) + for id, n := range ns { + if filter(n) { + result[id] = n + } + } + writeNodesJSON("-", result) + return nil +} + +type nodeFilter func(nodeJSON) bool + +type nodeFilterC struct { + narg int + fn func([]string) (nodeFilter, error) +} + +var filterFlags = map[string]nodeFilterC{ + "-ip": {1, ipFilter}, + "-min-age": {1, minAgeFilter}, + "-eth-network": {1, ethFilter}, + "-les-server": {0, lesFilter}, +} + +func parseFilters(args []string) ([]nodeFilter, error) { + var filters []nodeFilter + for len(args) > 0 { + fc, ok := filterFlags[args[0]] + if !ok { + return nil, fmt.Errorf("invalid filter %q", args[0]) + } + if len(args) < fc.narg { + return nil, fmt.Errorf("filter %q wants %d arguments, have %d", args[0], fc.narg, len(args)) + } + filter, err := fc.fn(args[1:]) + if err != nil { + return nil, 
fmt.Errorf("%s: %v", args[0], err) + } + filters = append(filters, filter) + args = args[fc.narg+1:] + } + return filters, nil +} + +func andFilter(args []string) (nodeFilter, error) { + checks, err := parseFilters(args) + if err != nil { + return nil, err + } + f := func(n nodeJSON) bool { + for _, filter := range checks { + if !filter(n) { + return false + } + } + return true + } + return f, nil +} + +func ipFilter(args []string) (nodeFilter, error) { + _, cidr, err := net.ParseCIDR(args[0]) + if err != nil { + return nil, err + } + f := func(n nodeJSON) bool { return cidr.Contains(n.N.IP()) } + return f, nil +} + +func minAgeFilter(args []string) (nodeFilter, error) { + minage, err := time.ParseDuration(args[0]) + if err != nil { + return nil, err + } + f := func(n nodeJSON) bool { + age := n.LastResponse.Sub(n.FirstResponse) + return age >= minage + } + return f, nil +} + +func ethFilter(args []string) (nodeFilter, error) { + var filter forkid.Filter + switch args[0] { + case "mainnet": + filter = forkid.NewStaticFilter(params.MainnetChainConfig, params.MainnetGenesisHash) + case "rinkeby": + filter = forkid.NewStaticFilter(params.RinkebyChainConfig, params.RinkebyGenesisHash) + case "goerli": + filter = forkid.NewStaticFilter(params.GoerliChainConfig, params.GoerliGenesisHash) + case "ropsten": + filter = forkid.NewStaticFilter(params.TestnetChainConfig, params.TestnetGenesisHash) + default: + return nil, fmt.Errorf("unknown network %q", args[0]) + } + + f := func(n nodeJSON) bool { + var eth struct { + ForkID forkid.ID + _ []rlp.RawValue `rlp:"tail"` + } + if n.N.Load(enr.WithEntry("eth", ð)) != nil { + return false + } + return filter(eth.ForkID) == nil + } + return f, nil +} + +func lesFilter(args []string) (nodeFilter, error) { + f := func(n nodeJSON) bool { + var les struct { + _ []rlp.RawValue `rlp:"tail"` + } + return n.N.Load(enr.WithEntry("les", &les)) == nil + } + return f, nil +} diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index 318aa222a..3cbdcad01 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -17,6 +17,7 @@ package main import ( + "bytes" "encoding/json" "fmt" "io/ioutil" @@ -145,6 +146,7 @@ func runCmd(ctx *cli.Context) error { } else { hexcode = []byte(codeFlag) } + hexcode = bytes.TrimSpace(hexcode) if len(hexcode)%2 != 0 { fmt.Printf("Invalid input length for hex data (%d)\n", len(hexcode)) os.Exit(1) @@ -198,6 +200,8 @@ func runCmd(ctx *cli.Context) error { if chainConfig != nil { runtimeConfig.ChainConfig = chainConfig + } else { + runtimeConfig.ChainConfig = params.AllEthashProtocolChanges } tstart := time.Now() var leftOverGas uint64 diff --git a/cmd/geth/consolecmd_test.go b/cmd/geth/consolecmd_test.go index 436045119..33c83b7ed 100644 --- a/cmd/geth/consolecmd_test.go +++ b/cmd/geth/consolecmd_test.go @@ -87,7 +87,7 @@ func TestIPCAttachWelcome(t *testing.T) { "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", "--etherbase", coinbase, "--shh", "--ipcpath", ipc) - time.Sleep(2 * time.Second) // Simple way to wait for the RPC endpoint to open + waitForEndpoint(t, ipc, 3*time.Second) testAttachWelcome(t, geth, "ipc:"+ipc, ipcAPIs) geth.Interrupt() @@ -101,8 +101,9 @@ func TestHTTPAttachWelcome(t *testing.T) { "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", "--etherbase", coinbase, "--rpc", "--rpcport", port) - time.Sleep(2 * time.Second) // Simple way to wait for the RPC endpoint to open - testAttachWelcome(t, geth, "http://localhost:"+port, httpAPIs) + endpoint := "http://127.0.0.1:" + port + waitForEndpoint(t, 
endpoint, 3*time.Second) + testAttachWelcome(t, geth, endpoint, httpAPIs) geth.Interrupt() geth.ExpectExit() @@ -116,8 +117,9 @@ func TestWSAttachWelcome(t *testing.T) { "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", "--etherbase", coinbase, "--ws", "--wsport", port) - time.Sleep(2 * time.Second) // Simple way to wait for the RPC endpoint to open - testAttachWelcome(t, geth, "ws://localhost:"+port, httpAPIs) + endpoint := "ws://127.0.0.1:" + port + waitForEndpoint(t, endpoint, 3*time.Second) + testAttachWelcome(t, geth, endpoint, httpAPIs) geth.Interrupt() geth.ExpectExit() diff --git a/cmd/geth/retesteth.go b/cmd/geth/retesteth.go index 9469c9f5f..b6aa3706b 100644 --- a/cmd/geth/retesteth.go +++ b/cmd/geth/retesteth.go @@ -508,7 +508,7 @@ func (api *RetestethAPI) mineBlock() error { statedb.Prepare(tx.Hash(), common.Hash{}, txCount) snap := statedb.Snapshot() - receipt, _, err := core.ApplyTransaction( + receipt, err := core.ApplyTransaction( api.chainConfig, api.blockchain, &api.author, diff --git a/cmd/geth/run_test.go b/cmd/geth/run_test.go index da82facac..f7b735b84 100644 --- a/cmd/geth/run_test.go +++ b/cmd/geth/run_test.go @@ -17,13 +17,16 @@ package main import ( + "context" "fmt" "io/ioutil" "os" "testing" + "time" "github.com/docker/docker/pkg/reexec" "github.com/ethereum/go-ethereum/internal/cmdtest" + "github.com/ethereum/go-ethereum/rpc" ) func tmpdir(t *testing.T) string { @@ -96,3 +99,28 @@ func runGeth(t *testing.T, args ...string) *testgeth { return tt } + +// waitForEndpoint attempts to connect to an RPC endpoint until it succeeds. +func waitForEndpoint(t *testing.T, endpoint string, timeout time.Duration) { + probe := func() bool { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + c, err := rpc.DialContext(ctx, endpoint) + if c != nil { + _, err = c.SupportedModules() + c.Close() + } + return err == nil + } + + start := time.Now() + for { + if probe() { + return + } + if time.Since(start) > timeout { + t.Fatal("endpoint", endpoint, "did not open within", timeout) + } + time.Sleep(200 * time.Millisecond) + } +} diff --git a/cmd/puppeth/genesis.go b/cmd/puppeth/genesis.go index ae7675cd9..44ad6c4cd 100644 --- a/cmd/puppeth/genesis.go +++ b/cmd/puppeth/genesis.go @@ -36,25 +36,27 @@ import ( type alethGenesisSpec struct { SealEngine string `json:"sealEngine"` Params struct { - AccountStartNonce math2.HexOrDecimal64 `json:"accountStartNonce"` - MaximumExtraDataSize hexutil.Uint64 `json:"maximumExtraDataSize"` - HomesteadForkBlock hexutil.Uint64 `json:"homesteadForkBlock"` - DaoHardforkBlock math2.HexOrDecimal64 `json:"daoHardforkBlock"` - EIP150ForkBlock hexutil.Uint64 `json:"EIP150ForkBlock"` - EIP158ForkBlock hexutil.Uint64 `json:"EIP158ForkBlock"` - ByzantiumForkBlock hexutil.Uint64 `json:"byzantiumForkBlock"` - ConstantinopleForkBlock hexutil.Uint64 `json:"constantinopleForkBlock"` - MinGasLimit hexutil.Uint64 `json:"minGasLimit"` - MaxGasLimit hexutil.Uint64 `json:"maxGasLimit"` - TieBreakingGas bool `json:"tieBreakingGas"` - GasLimitBoundDivisor math2.HexOrDecimal64 `json:"gasLimitBoundDivisor"` - MinimumDifficulty *hexutil.Big `json:"minimumDifficulty"` - DifficultyBoundDivisor *math2.HexOrDecimal256 `json:"difficultyBoundDivisor"` - DurationLimit *math2.HexOrDecimal256 `json:"durationLimit"` - BlockReward *hexutil.Big `json:"blockReward"` - NetworkID hexutil.Uint64 `json:"networkID"` - ChainID hexutil.Uint64 `json:"chainID"` - AllowFutureBlocks bool `json:"allowFutureBlocks"` + AccountStartNonce 
math2.HexOrDecimal64 `json:"accountStartNonce"` + MaximumExtraDataSize hexutil.Uint64 `json:"maximumExtraDataSize"` + HomesteadForkBlock *hexutil.Big `json:"homesteadForkBlock,omitempty"` + DaoHardforkBlock math2.HexOrDecimal64 `json:"daoHardforkBlock"` + EIP150ForkBlock *hexutil.Big `json:"EIP150ForkBlock,omitempty"` + EIP158ForkBlock *hexutil.Big `json:"EIP158ForkBlock,omitempty"` + ByzantiumForkBlock *hexutil.Big `json:"byzantiumForkBlock,omitempty"` + ConstantinopleForkBlock *hexutil.Big `json:"constantinopleForkBlock,omitempty"` + ConstantinopleFixForkBlock *hexutil.Big `json:"constantinopleFixForkBlock,omitempty"` + IstanbulForkBlock *hexutil.Big `json:"istanbulForkBlock,omitempty"` + MinGasLimit hexutil.Uint64 `json:"minGasLimit"` + MaxGasLimit hexutil.Uint64 `json:"maxGasLimit"` + TieBreakingGas bool `json:"tieBreakingGas"` + GasLimitBoundDivisor math2.HexOrDecimal64 `json:"gasLimitBoundDivisor"` + MinimumDifficulty *hexutil.Big `json:"minimumDifficulty"` + DifficultyBoundDivisor *math2.HexOrDecimal256 `json:"difficultyBoundDivisor"` + DurationLimit *math2.HexOrDecimal256 `json:"durationLimit"` + BlockReward *hexutil.Big `json:"blockReward"` + NetworkID hexutil.Uint64 `json:"networkID"` + ChainID hexutil.Uint64 `json:"chainID"` + AllowFutureBlocks bool `json:"allowFutureBlocks"` } `json:"params"` Genesis struct { @@ -74,7 +76,7 @@ type alethGenesisSpec struct { // alethGenesisSpecAccount is the prefunded genesis account and/or precompiled // contract definition. type alethGenesisSpecAccount struct { - Balance *math2.HexOrDecimal256 `json:"balance"` + Balance *math2.HexOrDecimal256 `json:"balance,omitempty"` Nonce uint64 `json:"nonce,omitempty"` Precompiled *alethGenesisSpecBuiltin `json:"precompiled,omitempty"` } @@ -82,7 +84,7 @@ type alethGenesisSpecAccount struct { // alethGenesisSpecBuiltin is the precompiled contract definition. type alethGenesisSpecBuiltin struct { Name string `json:"name,omitempty"` - StartingBlock hexutil.Uint64 `json:"startingBlock,omitempty"` + StartingBlock *hexutil.Big `json:"startingBlock,omitempty"` Linear *alethGenesisSpecLinearPricing `json:"linear,omitempty"` } @@ -106,21 +108,33 @@ func newAlethGenesisSpec(network string, genesis *core.Genesis) (*alethGenesisSp spec.Params.AccountStartNonce = 0 spec.Params.TieBreakingGas = false spec.Params.AllowFutureBlocks = false + + // Dao hardfork block is a special one. The fork block is listed as 0 in the + // config but aleth will sync with ETC clients up until the actual dao hard + // fork block. 
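The switch in this struct from plain `hexutil.Uint64` values to `*hexutil.Big` pointers with `omitempty` is what lets unscheduled forks vanish from the emitted spec instead of being serialized as block 0. A small demonstration of that marshaling behavior (field names are illustrative; the block number is mainnet's Byzantium block):

```go
package main

import (
	"encoding/json"
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/common/hexutil"
)

type forkBlocks struct {
	// A nil pointer is dropped by omitempty, so the fork reads as "not enabled".
	Byzantium *hexutil.Big `json:"byzantiumForkBlock,omitempty"`
	Istanbul  *hexutil.Big `json:"istanbulForkBlock,omitempty"`
}

func main() {
	f := forkBlocks{Byzantium: (*hexutil.Big)(big.NewInt(4370000))} // Istanbul unset
	out, _ := json.Marshal(&f)
	fmt.Println(string(out)) // {"byzantiumForkBlock":"0x42ae50"}
}
```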
spec.Params.DaoHardforkBlock = 0 - spec.Params.HomesteadForkBlock = (hexutil.Uint64)(genesis.Config.HomesteadBlock.Uint64()) - spec.Params.EIP150ForkBlock = (hexutil.Uint64)(genesis.Config.EIP150Block.Uint64()) - spec.Params.EIP158ForkBlock = (hexutil.Uint64)(genesis.Config.EIP158Block.Uint64()) - - // Byzantium + if num := genesis.Config.HomesteadBlock; num != nil { + spec.Params.HomesteadForkBlock = (*hexutil.Big)(num) + } + if num := genesis.Config.EIP150Block; num != nil { + spec.Params.EIP150ForkBlock = (*hexutil.Big)(num) + } + if num := genesis.Config.EIP158Block; num != nil { + spec.Params.EIP158ForkBlock = (*hexutil.Big)(num) + } if num := genesis.Config.ByzantiumBlock; num != nil { - spec.setByzantium(num) + spec.Params.ByzantiumForkBlock = (*hexutil.Big)(num) } - // Constantinople if num := genesis.Config.ConstantinopleBlock; num != nil { - spec.setConstantinople(num) + spec.Params.ConstantinopleForkBlock = (*hexutil.Big)(num) + } + if num := genesis.Config.PetersburgBlock; num != nil { + spec.Params.ConstantinopleFixForkBlock = (*hexutil.Big)(num) + } + if num := genesis.Config.IstanbulBlock; num != nil { + spec.Params.IstanbulForkBlock = (*hexutil.Big)(num) } - spec.Params.NetworkID = (hexutil.Uint64)(genesis.Config.ChainID.Uint64()) spec.Params.ChainID = (hexutil.Uint64)(genesis.Config.ChainID.Uint64()) spec.Params.MaximumExtraDataSize = (hexutil.Uint64)(params.MaximumExtraDataSize) @@ -157,15 +171,32 @@ func newAlethGenesisSpec(network string, genesis *core.Genesis) (*alethGenesisSp Linear: &alethGenesisSpecLinearPricing{Base: 15, Word: 3}}) if genesis.Config.ByzantiumBlock != nil { spec.setPrecompile(5, &alethGenesisSpecBuiltin{Name: "modexp", - StartingBlock: (hexutil.Uint64)(genesis.Config.ByzantiumBlock.Uint64())}) + StartingBlock: (*hexutil.Big)(genesis.Config.ByzantiumBlock)}) spec.setPrecompile(6, &alethGenesisSpecBuiltin{Name: "alt_bn128_G1_add", - StartingBlock: (hexutil.Uint64)(genesis.Config.ByzantiumBlock.Uint64()), + StartingBlock: (*hexutil.Big)(genesis.Config.ByzantiumBlock), Linear: &alethGenesisSpecLinearPricing{Base: 500}}) spec.setPrecompile(7, &alethGenesisSpecBuiltin{Name: "alt_bn128_G1_mul", - StartingBlock: (hexutil.Uint64)(genesis.Config.ByzantiumBlock.Uint64()), + StartingBlock: (*hexutil.Big)(genesis.Config.ByzantiumBlock), Linear: &alethGenesisSpecLinearPricing{Base: 40000}}) spec.setPrecompile(8, &alethGenesisSpecBuiltin{Name: "alt_bn128_pairing_product", - StartingBlock: (hexutil.Uint64)(genesis.Config.ByzantiumBlock.Uint64())}) + StartingBlock: (*hexutil.Big)(genesis.Config.ByzantiumBlock)}) + } + if genesis.Config.IstanbulBlock != nil { + if genesis.Config.ByzantiumBlock == nil { + return nil, errors.New("invalid genesis, istanbul fork is enabled while byzantium is not") + } + spec.setPrecompile(6, &alethGenesisSpecBuiltin{ + Name: "alt_bn128_G1_add", + StartingBlock: (*hexutil.Big)(genesis.Config.ByzantiumBlock), + }) // Aleth hardcoded the gas policy + spec.setPrecompile(7, &alethGenesisSpecBuiltin{ + Name: "alt_bn128_G1_mul", + StartingBlock: (*hexutil.Big)(genesis.Config.ByzantiumBlock), + }) // Aleth hardcoded the gas policy + spec.setPrecompile(9, &alethGenesisSpecBuiltin{ + Name: "blake2_compression", + StartingBlock: (*hexutil.Big)(genesis.Config.IstanbulBlock), + }) } return spec, nil } @@ -196,14 +227,6 @@ func (spec *alethGenesisSpec) setAccount(address common.Address, account core.Ge } -func (spec *alethGenesisSpec) setByzantium(num *big.Int) { - spec.Params.ByzantiumForkBlock = hexutil.Uint64(num.Uint64()) -} - -func (spec 
*alethGenesisSpec) setConstantinople(num *big.Int) { - spec.Params.ConstantinopleForkBlock = hexutil.Uint64(num.Uint64()) -} - // parityChainSpec is the chain specification format used by Parity. type parityChainSpec struct { Name string `json:"name"` @@ -223,29 +246,33 @@ type parityChainSpec struct { } `json:"engine"` Params struct { - AccountStartNonce hexutil.Uint64 `json:"accountStartNonce"` - MaximumExtraDataSize hexutil.Uint64 `json:"maximumExtraDataSize"` - MinGasLimit hexutil.Uint64 `json:"minGasLimit"` - GasLimitBoundDivisor math2.HexOrDecimal64 `json:"gasLimitBoundDivisor"` - NetworkID hexutil.Uint64 `json:"networkID"` - ChainID hexutil.Uint64 `json:"chainID"` - MaxCodeSize hexutil.Uint64 `json:"maxCodeSize"` - MaxCodeSizeTransition hexutil.Uint64 `json:"maxCodeSizeTransition"` - EIP98Transition hexutil.Uint64 `json:"eip98Transition"` - EIP150Transition hexutil.Uint64 `json:"eip150Transition"` - EIP160Transition hexutil.Uint64 `json:"eip160Transition"` - EIP161abcTransition hexutil.Uint64 `json:"eip161abcTransition"` - EIP161dTransition hexutil.Uint64 `json:"eip161dTransition"` - EIP155Transition hexutil.Uint64 `json:"eip155Transition"` - EIP140Transition hexutil.Uint64 `json:"eip140Transition"` - EIP211Transition hexutil.Uint64 `json:"eip211Transition"` - EIP214Transition hexutil.Uint64 `json:"eip214Transition"` - EIP658Transition hexutil.Uint64 `json:"eip658Transition"` - EIP145Transition hexutil.Uint64 `json:"eip145Transition"` - EIP1014Transition hexutil.Uint64 `json:"eip1014Transition"` - EIP1052Transition hexutil.Uint64 `json:"eip1052Transition"` - EIP1283Transition hexutil.Uint64 `json:"eip1283Transition"` - EIP1283DisableTransition hexutil.Uint64 `json:"eip1283DisableTransition"` + AccountStartNonce hexutil.Uint64 `json:"accountStartNonce"` + MaximumExtraDataSize hexutil.Uint64 `json:"maximumExtraDataSize"` + MinGasLimit hexutil.Uint64 `json:"minGasLimit"` + GasLimitBoundDivisor math2.HexOrDecimal64 `json:"gasLimitBoundDivisor"` + NetworkID hexutil.Uint64 `json:"networkID"` + ChainID hexutil.Uint64 `json:"chainID"` + MaxCodeSize hexutil.Uint64 `json:"maxCodeSize"` + MaxCodeSizeTransition hexutil.Uint64 `json:"maxCodeSizeTransition"` + EIP98Transition hexutil.Uint64 `json:"eip98Transition"` + EIP150Transition hexutil.Uint64 `json:"eip150Transition"` + EIP160Transition hexutil.Uint64 `json:"eip160Transition"` + EIP161abcTransition hexutil.Uint64 `json:"eip161abcTransition"` + EIP161dTransition hexutil.Uint64 `json:"eip161dTransition"` + EIP155Transition hexutil.Uint64 `json:"eip155Transition"` + EIP140Transition hexutil.Uint64 `json:"eip140Transition"` + EIP211Transition hexutil.Uint64 `json:"eip211Transition"` + EIP214Transition hexutil.Uint64 `json:"eip214Transition"` + EIP658Transition hexutil.Uint64 `json:"eip658Transition"` + EIP145Transition hexutil.Uint64 `json:"eip145Transition"` + EIP1014Transition hexutil.Uint64 `json:"eip1014Transition"` + EIP1052Transition hexutil.Uint64 `json:"eip1052Transition"` + EIP1283Transition hexutil.Uint64 `json:"eip1283Transition"` + EIP1283DisableTransition hexutil.Uint64 `json:"eip1283DisableTransition"` + EIP1283ReenableTransition hexutil.Uint64 `json:"eip1283ReenableTransition"` + EIP1344Transition hexutil.Uint64 `json:"eip1344Transition"` + EIP1884Transition hexutil.Uint64 `json:"eip1884Transition"` + EIP2028Transition hexutil.Uint64 `json:"eip2028Transition"` } `json:"params"` Genesis struct { @@ -278,17 +305,22 @@ type parityChainSpecAccount struct { // parityChainSpecBuiltin is the precompiled contract definition. 
type parityChainSpecBuiltin struct { - Name string `json:"name,omitempty"` - ActivateAt math2.HexOrDecimal64 `json:"activate_at,omitempty"` - Pricing *parityChainSpecPricing `json:"pricing,omitempty"` + Name string `json:"name"` // Each builtin should have its own name + Pricing *parityChainSpecPricing `json:"pricing"` // Each builtin should have its own pricing strategy + ActivateAt *hexutil.Big `json:"activate_at,omitempty"` // ActivateAt is omitted when nil; no value means no activation fork + EIP1108Transition *hexutil.Big `json:"eip1108_transition,omitempty"` // EIP1108Transition is omitted when nil; no value means no repricing fork } // parityChainSpecPricing represents the different pricing models that builtin // contracts might advertise using. type parityChainSpecPricing struct { - Linear *parityChainSpecLinearPricing `json:"linear,omitempty"` - ModExp *parityChainSpecModExpPricing `json:"modexp,omitempty"` - AltBnPairing *parityChainSpecAltBnPairingPricing `json:"alt_bn128_pairing,omitempty"` + Linear *parityChainSpecLinearPricing `json:"linear,omitempty"` + ModExp *parityChainSpecModExpPricing `json:"modexp,omitempty"` + AltBnPairing *parityChainSpecAltBnPairingPricing `json:"alt_bn128_pairing,omitempty"` + AltBnConstOperation *parityChainSpecAltBnConstOperationPricing `json:"alt_bn128_const_operations,omitempty"` + + // Blake2F is the price per round of Blake2 compression + Blake2F *parityChainSpecBlakePricing `json:"blake2_f,omitempty"` } type parityChainSpecLinearPricing struct { @@ -300,9 +332,20 @@ type parityChainSpecModExpPricing struct { Divisor uint64 `json:"divisor"` } +type parityChainSpecAltBnConstOperationPricing struct { + Price uint64 `json:"price"` + EIP1108TransitionPrice uint64 `json:"eip1108_transition_price,omitempty"` // Before the Istanbul fork this field is zero and omitted +} + type parityChainSpecAltBnPairingPricing struct { - Base uint64 `json:"base"` - Pair uint64 `json:"pair"` + Base uint64 `json:"base"` + Pair uint64 `json:"pair"` + EIP1108TransitionBase uint64 `json:"eip1108_transition_base,omitempty"` // Before the Istanbul fork this field is zero and omitted + EIP1108TransitionPair uint64 `json:"eip1108_transition_pair,omitempty"` // Before the Istanbul fork this field is zero and omitted +} + +type parityChainSpecBlakePricing struct { + GasPerRound uint64 `json:"gas_per_round"` } // newParityChainSpec converts a go-ethereum genesis block into a Parity specific @@ -352,7 +395,10 @@ func newParityChainSpec(network string, genesis *core.Genesis, bootnodes []strin if num := genesis.Config.PetersburgBlock; num != nil { spec.setConstantinopleFix(num) } - + // Istanbul + if num := genesis.Config.IstanbulBlock; num != nil { + spec.setIstanbul(num) + } spec.Params.MaximumExtraDataSize = (hexutil.Uint64)(params.MaximumExtraDataSize) spec.Params.MinGasLimit = (hexutil.Uint64)(params.MinGasLimit) spec.Params.GasLimitBoundDivisor = (math2.HexOrDecimal64)(params.GasLimitBoundDivisor) @@ -398,18 +444,34 @@ func newParityChainSpec(network string, genesis *core.Genesis, bootnodes []strin Name: "identity", Pricing: &parityChainSpecPricing{Linear: &parityChainSpecLinearPricing{Base: 15, Word: 3}}, }) if genesis.Config.ByzantiumBlock != nil { - blnum := math2.HexOrDecimal64(genesis.Config.ByzantiumBlock.Uint64()) spec.setPrecompile(5, &parityChainSpecBuiltin{ - Name: "modexp", ActivateAt: blnum, Pricing: &parityChainSpecPricing{ModExp: &parityChainSpecModExpPricing{Divisor: 20}}, + Name: "modexp", ActivateAt: (*hexutil.Big)(genesis.Config.ByzantiumBlock), Pricing: &parityChainSpecPricing{ModExp: 
&parityChainSpecModExpPricing{Divisor: 20}}, }) spec.setPrecompile(6, &parityChainSpecBuiltin{ - Name: "alt_bn128_add", ActivateAt: blnum, Pricing: &parityChainSpecPricing{Linear: &parityChainSpecLinearPricing{Base: 500}}, + Name: "alt_bn128_add", ActivateAt: (*hexutil.Big)(genesis.Config.ByzantiumBlock), Pricing: &parityChainSpecPricing{AltBnConstOperation: &parityChainSpecAltBnConstOperationPricing{Price: 500}}, }) spec.setPrecompile(7, &parityChainSpecBuiltin{ - Name: "alt_bn128_mul", ActivateAt: blnum, Pricing: &parityChainSpecPricing{Linear: &parityChainSpecLinearPricing{Base: 40000}}, + Name: "alt_bn128_mul", ActivateAt: (*hexutil.Big)(genesis.Config.ByzantiumBlock), Pricing: &parityChainSpecPricing{AltBnConstOperation: &parityChainSpecAltBnConstOperationPricing{Price: 40000}}, }) spec.setPrecompile(8, &parityChainSpecBuiltin{ - Name: "alt_bn128_pairing", ActivateAt: blnum, Pricing: &parityChainSpecPricing{AltBnPairing: &parityChainSpecAltBnPairingPricing{Base: 100000, Pair: 80000}}, + Name: "alt_bn128_pairing", ActivateAt: (*hexutil.Big)(genesis.Config.ByzantiumBlock), Pricing: &parityChainSpecPricing{AltBnPairing: &parityChainSpecAltBnPairingPricing{Base: 100000, Pair: 80000}}, + }) + } + if genesis.Config.IstanbulBlock != nil { + if genesis.Config.ByzantiumBlock == nil { + return nil, errors.New("invalid genesis, istanbul fork is enabled while byzantium is not") + } + spec.setPrecompile(6, &parityChainSpecBuiltin{ + Name: "alt_bn128_add", ActivateAt: (*hexutil.Big)(genesis.Config.ByzantiumBlock), EIP1108Transition: (*hexutil.Big)(genesis.Config.IstanbulBlock), Pricing: &parityChainSpecPricing{AltBnConstOperation: &parityChainSpecAltBnConstOperationPricing{Price: 500, EIP1108TransitionPrice: 150}}, + }) + spec.setPrecompile(7, &parityChainSpecBuiltin{ + Name: "alt_bn128_mul", ActivateAt: (*hexutil.Big)(genesis.Config.ByzantiumBlock), EIP1108Transition: (*hexutil.Big)(genesis.Config.IstanbulBlock), Pricing: &parityChainSpecPricing{AltBnConstOperation: &parityChainSpecAltBnConstOperationPricing{Price: 40000, EIP1108TransitionPrice: 6000}}, + }) + spec.setPrecompile(8, &parityChainSpecBuiltin{ + Name: "alt_bn128_pairing", ActivateAt: (*hexutil.Big)(genesis.Config.ByzantiumBlock), EIP1108Transition: (*hexutil.Big)(genesis.Config.IstanbulBlock), Pricing: &parityChainSpecPricing{AltBnPairing: &parityChainSpecAltBnPairingPricing{Base: 100000, Pair: 80000, EIP1108TransitionBase: 45000, EIP1108TransitionPair: 34000}}, + }) + spec.setPrecompile(9, &parityChainSpecBuiltin{ + Name: "blake2_f", ActivateAt: (*hexutil.Big)(genesis.Config.IstanbulBlock), Pricing: &parityChainSpecPricing{Blake2F: &parityChainSpecBlakePricing{GasPerRound: 1}}, }) } return spec, nil @@ -451,6 +513,15 @@ func (spec *parityChainSpec) setConstantinopleFix(num *big.Int) { spec.Params.EIP1283DisableTransition = hexutil.Uint64(num.Uint64()) } +func (spec *parityChainSpec) setIstanbul(num *big.Int) { + // spec.Params.EIP152Transition = hexutil.Uint64(num.Uint64()) + // spec.Params.EIP1108Transition = hexutil.Uint64(num.Uint64()) + spec.Params.EIP1344Transition = hexutil.Uint64(num.Uint64()) + spec.Params.EIP1884Transition = hexutil.Uint64(num.Uint64()) + spec.Params.EIP2028Transition = hexutil.Uint64(num.Uint64()) + spec.Params.EIP1283ReenableTransition = hexutil.Uint64(num.Uint64()) +} + // pyEthereumGenesisSpec represents the genesis specification format used by the // Python Ethereum implementation. 
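For context on the EIP1108Transition fields threaded through the builtins above: Istanbul (EIP-1108) reprices the alt_bn128 precompiles, so the spec now carries both the Byzantium constants and the post-transition ones. A rough sketch of how a client consuming such a spec might price a pairing call, using the numbers from the spec above (the function is illustrative, not Parity's actual code):

```go
package main

import "fmt"

// altBnPairingGas prices a pairing check as base + pair*n for n point
// pairs, switching to the cheaper EIP-1108 constants once the
// transition block is reached.
func altBnPairingGas(pairs, block, eip1108Block uint64) uint64 {
	base, pair := uint64(100000), uint64(80000)
	if block >= eip1108Block {
		base, pair = uint64(45000), uint64(34000)
	}
	return base + pair*pairs
}

func main() {
	fmt.Println(altBnPairingGas(2, 0x9c40, 0xc350)) // pre-Istanbul: 260000
	fmt.Println(altBnPairingGas(2, 0xc350, 0xc350)) // post-Istanbul: 113000
}
```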
type pyEthereumGenesisSpec struct { diff --git a/cmd/puppeth/genesis_test.go b/cmd/puppeth/genesis_test.go index f128da24f..1fd1b35c4 100644 --- a/cmd/puppeth/genesis_test.go +++ b/cmd/puppeth/genesis_test.go @@ -76,7 +76,7 @@ func TestParitySturebyConverter(t *testing.T) { if err := json.Unmarshal(blob, &genesis); err != nil { t.Fatalf("failed parsing genesis: %v", err) } - spec, err := newParityChainSpec("Stureby", &genesis, []string{}) + spec, err := newParityChainSpec("stureby", &genesis, []string{}) if err != nil { t.Fatalf("failed creating chainspec: %v", err) } diff --git a/cmd/puppeth/testdata/stureby_aleth.json b/cmd/puppeth/testdata/stureby_aleth.json index 1ef1d8ae1..d18ba3854 100644 --- a/cmd/puppeth/testdata/stureby_aleth.json +++ b/cmd/puppeth/testdata/stureby_aleth.json @@ -1,112 +1,113 @@ { - "sealEngine":"Ethash", - "params":{ - "accountStartNonce":"0x00", - "maximumExtraDataSize":"0x20", - "homesteadForkBlock":"0x2710", - "daoHardforkBlock":"0x00", - "EIP150ForkBlock":"0x3a98", - "EIP158ForkBlock":"0x59d8", - "byzantiumForkBlock":"0x7530", - "constantinopleForkBlock":"0x9c40", - "minGasLimit":"0x1388", - "maxGasLimit":"0x7fffffffffffffff", - "tieBreakingGas":false, - "gasLimitBoundDivisor":"0x0400", - "minimumDifficulty":"0x20000", - "difficultyBoundDivisor":"0x0800", - "durationLimit":"0x0d", - "blockReward":"0x4563918244F40000", - "networkID":"0x4cb2e", - "chainID":"0x4cb2e", - "allowFutureBlocks":false + "sealEngine": "Ethash", + "params": { + "accountStartNonce": "0x0", + "maximumExtraDataSize": "0x20", + "homesteadForkBlock": "0x2710", + "daoHardforkBlock": "0x0", + "EIP150ForkBlock": "0x3a98", + "EIP158ForkBlock": "0x59d8", + "byzantiumForkBlock": "0x7530", + "constantinopleForkBlock": "0x9c40", + "constantinopleFixForkBlock": "0x9c40", + "istanbulForkBlock": "0xc350", + "minGasLimit": "0x1388", + "maxGasLimit": "0x7fffffffffffffff", + "tieBreakingGas": false, + "gasLimitBoundDivisor": "0x400", + "minimumDifficulty": "0x20000", + "difficultyBoundDivisor": "0x800", + "durationLimit": "0xd", + "blockReward": "0x4563918244f40000", + "networkID": "0x4cb2e", + "chainID": "0x4cb2e", + "allowFutureBlocks": false }, - "genesis":{ - "nonce":"0x0000000000000000", - "difficulty":"0x20000", - "mixHash":"0x0000000000000000000000000000000000000000000000000000000000000000", - "author":"0x0000000000000000000000000000000000000000", - "timestamp":"0x59a4e76d", - "parentHash":"0x0000000000000000000000000000000000000000000000000000000000000000", - "extraData":"0x0000000000000000000000000000000000000000000000000000000b4dc0ffee", - "gasLimit":"0x47b760" + "genesis": { + "nonce": "0x0000000000000000", + "difficulty": "0x20000", + "mixHash": "0x0000000000000000000000000000000000000000000000000000000000000000", + "author": "0x0000000000000000000000000000000000000000", + "timestamp": "0x59a4e76d", + "parentHash": "0x0000000000000000000000000000000000000000000000000000000000000000", + "extraData": "0x0000000000000000000000000000000000000000000000000000000b4dc0ffee", + "gasLimit": "0x47b760" }, - "accounts":{ - "0000000000000000000000000000000000000001":{ - "balance":"1", - "precompiled":{ - "name":"ecrecover", - "linear":{ - "base":3000, - "word":0 + "accounts": { + "0000000000000000000000000000000000000001": { + "balance": "0x1", + "precompiled": { + "name": "ecrecover", + "linear": { + "base": 3000, + "word": 0 } } }, - "0000000000000000000000000000000000000002":{ - "balance":"1", - "precompiled":{ - "name":"sha256", - "linear":{ - "base":60, - "word":12 + 
"0000000000000000000000000000000000000002": { + "balance": "0x1", + "precompiled": { + "name": "sha256", + "linear": { + "base": 60, + "word": 12 } } }, - "0000000000000000000000000000000000000003":{ - "balance":"1", - "precompiled":{ - "name":"ripemd160", - "linear":{ - "base":600, - "word":120 + "0000000000000000000000000000000000000003": { + "balance": "0x1", + "precompiled": { + "name": "ripemd160", + "linear": { + "base": 600, + "word": 120 } } }, - "0000000000000000000000000000000000000004":{ - "balance":"1", - "precompiled":{ - "name":"identity", - "linear":{ - "base":15, - "word":3 + "0000000000000000000000000000000000000004": { + "balance": "0x1", + "precompiled": { + "name": "identity", + "linear": { + "base": 15, + "word": 3 } } }, - "0000000000000000000000000000000000000005":{ - "balance":"1", - "precompiled":{ - "name":"modexp", - "startingBlock":"0x7530" + "0000000000000000000000000000000000000005": { + "balance": "0x1", + "precompiled": { + "name": "modexp", + "startingBlock": "0x7530" } }, - "0000000000000000000000000000000000000006":{ - "balance":"1", - "precompiled":{ - "name":"alt_bn128_G1_add", - "startingBlock":"0x7530", - "linear":{ - "base":500, - "word":0 - } + "0000000000000000000000000000000000000006": { + "balance": "0x1", + "precompiled": { + "name": "alt_bn128_G1_add", + "startingBlock": "0x7530" } }, - "0000000000000000000000000000000000000007":{ - "balance":"1", - "precompiled":{ - "name":"alt_bn128_G1_mul", - "startingBlock":"0x7530", - "linear":{ - "base":40000, - "word":0 - } + "0000000000000000000000000000000000000007": { + "balance": "0x1", + "precompiled": { + "name": "alt_bn128_G1_mul", + "startingBlock": "0x7530" } }, - "0000000000000000000000000000000000000008":{ - "balance":"1", - "precompiled":{ - "name":"alt_bn128_pairing_product", - "startingBlock":"0x7530" + "0000000000000000000000000000000000000008": { + "balance": "0x1", + "precompiled": { + "name": "alt_bn128_pairing_product", + "startingBlock": "0x7530" + } + }, + "0000000000000000000000000000000000000009": { + "balance": "0x1", + "precompiled": { + "name": "blake2_compression", + "startingBlock": "0xc350" } } } -} +} \ No newline at end of file diff --git a/cmd/puppeth/testdata/stureby_geth.json b/cmd/puppeth/testdata/stureby_geth.json index c8c3b3c95..79f03469a 100644 --- a/cmd/puppeth/testdata/stureby_geth.json +++ b/cmd/puppeth/testdata/stureby_geth.json @@ -1,6 +1,5 @@ { "config": { - "ethash":{}, "chainId": 314158, "homesteadBlock": 10000, "eip150Block": 15000, @@ -8,11 +7,13 @@ "eip155Block": 23000, "eip158Block": 23000, "byzantiumBlock": 30000, - "constantinopleBlock": 40000 + "constantinopleBlock": 40000, + "petersburgBlock": 40000, + "istanbulBlock": 50000, + "ethash": {} }, "nonce": "0x0", "timestamp": "0x59a4e76d", - "parentHash": "0x0000000000000000000000000000000000000000000000000000000000000000", "extraData": "0x0000000000000000000000000000000000000000000000000000000b4dc0ffee", "gasLimit": "0x47b760", "difficulty": "0x20000", @@ -20,28 +21,34 @@ "coinbase": "0x0000000000000000000000000000000000000000", "alloc": { "0000000000000000000000000000000000000001": { - "balance": "0x01" + "balance": "0x1" }, "0000000000000000000000000000000000000002": { - "balance": "0x01" + "balance": "0x1" }, "0000000000000000000000000000000000000003": { - "balance": "0x01" + "balance": "0x1" }, "0000000000000000000000000000000000000004": { - "balance": "0x01" + "balance": "0x1" }, "0000000000000000000000000000000000000005": { - "balance": "0x01" + "balance": "0x1" }, 
"0000000000000000000000000000000000000006": { - "balance": "0x01" + "balance": "0x1" }, "0000000000000000000000000000000000000007": { - "balance": "0x01" + "balance": "0x1" }, "0000000000000000000000000000000000000008": { - "balance": "0x01" + "balance": "0x1" + }, + "0000000000000000000000000000000000000009": { + "balance": "0x1" } - } -} + }, + "number": "0x0", + "gasUsed": "0x0", + "parentHash": "0x0000000000000000000000000000000000000000000000000000000000000000" +} \ No newline at end of file diff --git a/cmd/puppeth/testdata/stureby_parity.json b/cmd/puppeth/testdata/stureby_parity.json index f3fa8386a..fb84b39e2 100644 --- a/cmd/puppeth/testdata/stureby_parity.json +++ b/cmd/puppeth/testdata/stureby_parity.json @@ -1,181 +1,186 @@ { - "name":"Stureby", - "dataDir":"stureby", - "engine":{ - "Ethash":{ - "params":{ - "minimumDifficulty":"0x20000", - "difficultyBoundDivisor":"0x800", - "durationLimit":"0xd", - "blockReward":{ - "0x0":"0x4563918244f40000", - "0x7530":"0x29a2241af62c0000", - "0x9c40":"0x1bc16d674ec80000" + "name": "stureby", + "dataDir": "stureby", + "engine": { + "Ethash": { + "params": { + "minimumDifficulty": "0x20000", + "difficultyBoundDivisor": "0x800", + "durationLimit": "0xd", + "blockReward": { + "0x0": "0x4563918244f40000", + "0x7530": "0x29a2241af62c0000", + "0x9c40": "0x1bc16d674ec80000" }, - "homesteadTransition":"0x2710", - "eip100bTransition":"0x7530", - "difficultyBombDelays":{ - "0x7530":"0x2dc6c0", - "0x9c40":"0x1e8480" - } + "difficultyBombDelays": { + "0x7530": "0x2dc6c0", + "0x9c40": "0x1e8480" + }, + "homesteadTransition": "0x2710", + "eip100bTransition": "0x7530" } } }, - "params":{ - "accountStartNonce":"0x0", - "maximumExtraDataSize":"0x20", - "gasLimitBoundDivisor":"0x400", - "minGasLimit":"0x1388", - "networkID":"0x4cb2e", - "chainID":"0x4cb2e", - "maxCodeSize":"0x6000", - "maxCodeSizeTransition":"0x0", + "params": { + "accountStartNonce": "0x0", + "maximumExtraDataSize": "0x20", + "minGasLimit": "0x1388", + "gasLimitBoundDivisor": "0x400", + "networkID": "0x4cb2e", + "chainID": "0x4cb2e", + "maxCodeSize": "0x6000", + "maxCodeSizeTransition": "0x0", "eip98Transition": "0x7fffffffffffffff", - "eip150Transition":"0x3a98", - "eip160Transition":"0x59d8", - "eip161abcTransition":"0x59d8", - "eip161dTransition":"0x59d8", - "eip155Transition":"0x59d8", - "eip140Transition":"0x7530", - "eip211Transition":"0x7530", - "eip214Transition":"0x7530", - "eip658Transition":"0x7530", - "eip145Transition":"0x9c40", - "eip1014Transition":"0x9c40", - "eip1052Transition":"0x9c40", - "eip1283Transition":"0x9c40" + "eip150Transition": "0x3a98", + "eip160Transition": "0x59d8", + "eip161abcTransition": "0x59d8", + "eip161dTransition": "0x59d8", + "eip155Transition": "0x59d8", + "eip140Transition": "0x7530", + "eip211Transition": "0x7530", + "eip214Transition": "0x7530", + "eip658Transition": "0x7530", + "eip145Transition": "0x9c40", + "eip1014Transition": "0x9c40", + "eip1052Transition": "0x9c40", + "eip1283Transition": "0x9c40", + "eip1283DisableTransition": "0x9c40", + "eip1283ReenableTransition": "0xc350", + "eip1344Transition": "0xc350", + "eip1884Transition": "0xc350", + "eip2028Transition": "0xc350" }, - "genesis":{ - "seal":{ - "ethereum":{ - "nonce":"0x0000000000000000", - "mixHash":"0x0000000000000000000000000000000000000000000000000000000000000000" + "genesis": { + "seal": { + "ethereum": { + "nonce": "0x0000000000000000", + "mixHash": "0x0000000000000000000000000000000000000000000000000000000000000000" } }, - "difficulty":"0x20000", - 
"author":"0x0000000000000000000000000000000000000000", - "timestamp":"0x59a4e76d", - "parentHash":"0x0000000000000000000000000000000000000000000000000000000000000000", - "extraData":"0x0000000000000000000000000000000000000000000000000000000b4dc0ffee", - "gasLimit":"0x47b760" + "difficulty": "0x20000", + "author": "0x0000000000000000000000000000000000000000", + "timestamp": "0x59a4e76d", + "parentHash": "0x0000000000000000000000000000000000000000000000000000000000000000", + "extraData": "0x0000000000000000000000000000000000000000000000000000000b4dc0ffee", + "gasLimit": "0x47b760" }, - "nodes":[ - "enode://dfa7aca3f5b635fbfe7d0b20575f25e40d9e27b4bfbb3cf74364a42023ad9f25c1a4383bcc8cced86ee511a7d03415345a4df05be37f1dff040e4c780699f1c0@168.61.153.255:31303", - "enode://ef441b20dd70aeabf0eac35c3b8a2854e5ce04db0e30be9152ea9fd129359dcbb3f803993303ff5781c755dfd7223f3fe43505f583cccb740949407677412ba9@40.74.91.252:31303", - "enode://953b5ea1c8987cf46008232a0160324fd00d41320ecf00e23af86ec8f5396b19eb57ddab37c78141be56f62e9077de4f4dfa0747fa768ed8c8531bbfb1046237@40.70.214.166:31303", - "enode://276e613dd4b277a66591e565711e6c8bb107f0905248a9f8f8228c1a87992e156e5114bb9937c02824a9d9d25f76340442cf86e2028bf5293cae19904fb2b98e@35.178.251.52:30303", - "enode://064c820d41e52ed7d426ac64b60506c2998235bedc7e67cb497c6faf7bb4fc54fe56fc82d0add3180b747c0c4f40a1108a6f84d7d0629ed606d504528e61cc57@3.8.5.3:30303", - "enode://90069fdabcc5e684fa5d59430bebbb12755d9362dfe5006a1485b13d71a78a3812d36e74dd7d88e50b51add01e097ea80f16263aeaa4f0230db6c79e2a97e7ca@217.29.191.142:30303", - "enode://0aac74b7fd28726275e466acb5e03bc88a95927e9951eb66b5efb239b2f798ada0690853b2f2823fe4efa408f0f3d4dd258430bc952a5ff70677b8625b3e3b14@40.115.33.57:40404", - "enode://0b96415a10f835106d83e090a0528eed5e7887e5c802a6d084e9f1993a9d0fc713781e6e4101f6365e9b91259712f291acc0a9e6e667e22023050d602c36fbe2@40.115.33.57:40414" - ], - "accounts":{ - "0000000000000000000000000000000000000001":{ - "balance":"1", - "nonce":"0", - "builtin":{ - "name":"ecrecover", - "pricing":{ - "linear":{ - "base":3000, - "word":0 + "nodes": [], + "accounts": { + "0000000000000000000000000000000000000001": { + "balance": "0x1", + "builtin": { + "name": "ecrecover", + "pricing": { + "linear": { + "base": 3000, + "word": 0 } } } }, - "0000000000000000000000000000000000000002":{ - "balance":"1", - "nonce":"0", - "builtin":{ - "name":"sha256", - "pricing":{ - "linear":{ - "base":60, - "word":12 + "0000000000000000000000000000000000000002": { + "balance": "0x1", + "builtin": { + "name": "sha256", + "pricing": { + "linear": { + "base": 60, + "word": 12 } } } }, - "0000000000000000000000000000000000000003":{ - "balance":"1", - "nonce":"0", - "builtin":{ - "name":"ripemd160", - "pricing":{ - "linear":{ - "base":600, - "word":120 + "0000000000000000000000000000000000000003": { + "balance": "0x1", + "builtin": { + "name": "ripemd160", + "pricing": { + "linear": { + "base": 600, + "word": 120 } } } }, - "0000000000000000000000000000000000000004":{ - "balance":"1", - "nonce":"0", - "builtin":{ - "name":"identity", - "pricing":{ - "linear":{ - "base":15, - "word":3 + "0000000000000000000000000000000000000004": { + "balance": "0x1", + "builtin": { + "name": "identity", + "pricing": { + "linear": { + "base": 15, + "word": 3 } } } }, - "0000000000000000000000000000000000000005":{ - "balance":"1", - "nonce":"0", - "builtin":{ - "name":"modexp", - "activate_at":"0x7530", - "pricing":{ - "modexp":{ - "divisor":20 + "0000000000000000000000000000000000000005": { + "balance": "0x1", + "builtin": { + 
"name": "modexp", + "pricing": { + "modexp": { + "divisor": 20 } - } + }, + "activate_at": "0x7530" } }, - "0000000000000000000000000000000000000006":{ - "balance":"1", - "nonce":"0", - "builtin":{ - "name":"alt_bn128_add", - "activate_at":"0x7530", - "pricing":{ - "linear":{ - "base":500, - "word":0 + "0000000000000000000000000000000000000006": { + "balance": "0x1", + "builtin": { + "name": "alt_bn128_add", + "pricing": { + "alt_bn128_const_operations": { + "price": 500, + "eip1108_transition_price": 150 } - } + }, + "activate_at": "0x7530", + "eip1108_transition": "0xc350" } }, - "0000000000000000000000000000000000000007":{ - "balance":"1", - "nonce":"0", - "builtin":{ - "name":"alt_bn128_mul", - "activate_at":"0x7530", - "pricing":{ - "linear":{ - "base":40000, - "word":0 + "0000000000000000000000000000000000000007": { + "balance": "0x1", + "builtin": { + "name": "alt_bn128_mul", + "pricing": { + "alt_bn128_const_operations": { + "price": 40000, + "eip1108_transition_price": 6000 } - } + }, + "activate_at": "0x7530", + "eip1108_transition": "0xc350" } }, - "0000000000000000000000000000000000000008":{ - "balance":"1", - "nonce":"0", - "builtin":{ - "name":"alt_bn128_pairing", - "activate_at":"0x7530", - "pricing":{ - "alt_bn128_pairing":{ - "base":100000, - "pair":80000 + "0000000000000000000000000000000000000008": { + "balance": "0x1", + "builtin": { + "name": "alt_bn128_pairing", + "pricing": { + "alt_bn128_pairing": { + "base": 100000, + "pair": 80000, + "eip1108_transition_base": 45000, + "eip1108_transition_pair": 34000 } - } + }, + "activate_at": "0x7530", + "eip1108_transition": "0xc350" + } + }, + "0000000000000000000000000000000000000009": { + "balance": "0x1", + "builtin": { + "name": "blake2_f", + "pricing": { + "blake2_f": { + "gas_per_round": 1 + } + }, + "activate_at": "0xc350" } } } -} +} \ No newline at end of file diff --git a/cmd/puppeth/wizard_genesis.go b/cmd/puppeth/wizard_genesis.go index 499f320f6..ab3e2247b 100644 --- a/cmd/puppeth/wizard_genesis.go +++ b/cmd/puppeth/wizard_genesis.go @@ -51,6 +51,7 @@ func (w *wizard) makeGenesis() { ByzantiumBlock: big.NewInt(0), ConstantinopleBlock: big.NewInt(0), PetersburgBlock: big.NewInt(0), + IstanbulBlock: big.NewInt(0), }, } // Figure out which consensus engine to choose @@ -230,6 +231,10 @@ func (w *wizard) manageGenesis() { fmt.Printf("Which block should Petersburg come into effect? (default = %v)\n", w.conf.Genesis.Config.PetersburgBlock) w.conf.Genesis.Config.PetersburgBlock = w.readDefaultBigInt(w.conf.Genesis.Config.PetersburgBlock) + fmt.Println() + fmt.Printf("Which block should Istanbul come into effect? 
(default = %v)\n", w.conf.Genesis.Config.IstanbulBlock) + w.conf.Genesis.Config.IstanbulBlock = w.readDefaultBigInt(w.conf.Genesis.Config.IstanbulBlock) + out, _ := json.MarshalIndent(w.conf.Genesis.Config, "", " ") fmt.Printf("Chain configuration updated:\n\n%s\n", out) @@ -268,7 +273,7 @@ func (w *wizard) manageGenesis() { } else { saveGenesis(folder, w.network, "parity", spec) } - // Export the genesis spec used by Harmony (formerly EthereumJ + // Export the genesis spec used by Harmony (formerly EthereumJ) saveGenesis(folder, w.network, "harmony", w.conf.Genesis) case "3": @@ -291,7 +296,7 @@ func (w *wizard) manageGenesis() { func saveGenesis(folder, network, client string, spec interface{}) { path := filepath.Join(folder, fmt.Sprintf("%s-%s.json", network, client)) - out, _ := json.Marshal(spec) + out, _ := json.MarshalIndent(spec, "", " ") if err := ioutil.WriteFile(path, out, 0644); err != nil { log.Error("Failed to save genesis file", "client", client, "err", err) return diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index ec02856c4..d3075190f 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -1480,9 +1480,12 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" { Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name) } - cfg.NoPruning = ctx.GlobalString(GCModeFlag.Name) == "archive" - cfg.NoPrefetch = ctx.GlobalBool(CacheNoPrefetchFlag.Name) - + if ctx.GlobalIsSet(GCModeFlag.Name) { + cfg.NoPruning = ctx.GlobalString(GCModeFlag.Name) == "archive" + } + if ctx.GlobalIsSet(CacheNoPrefetchFlag.Name) { + cfg.NoPrefetch = ctx.GlobalBool(CacheNoPrefetchFlag.Name) + } if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheTrieFlag.Name) { cfg.TrieCleanCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheTrieFlag.Name) / 100 } diff --git a/common/bytes.go b/common/bytes.go index 910c97d3c..fa457b92c 100644 --- a/common/bytes.go +++ b/common/bytes.go @@ -134,3 +134,14 @@ func LeftPadBytes(slice []byte, l int) []byte { return padded } + +// TrimLeftZeroes returns a subslice of s without leading zeroes +func TrimLeftZeroes(s []byte) []byte { + idx := 0 + for ; idx < len(s); idx++ { + if s[idx] != 0 { + break + } + } + return s[idx:] +} diff --git a/common/mclock/mclock.go b/common/mclock/mclock.go index 0c941082f..d0e0cd78b 100644 --- a/common/mclock/mclock.go +++ b/common/mclock/mclock.go @@ -36,47 +36,39 @@ func (t AbsTime) Add(d time.Duration) AbsTime { return t + AbsTime(d) } -// Clock interface makes it possible to replace the monotonic system clock with +// The Clock interface makes it possible to replace the monotonic system clock with // a simulated clock. type Clock interface { Now() AbsTime Sleep(time.Duration) After(time.Duration) <-chan time.Time - AfterFunc(d time.Duration, f func()) Event + AfterFunc(d time.Duration, f func()) Timer } -// Event represents a cancellable event returned by AfterFunc -type Event interface { - Cancel() bool +// Timer represents a cancellable event returned by AfterFunc +type Timer interface { + Stop() bool } // System implements Clock using the system clock. type System struct{} -// Now implements Clock. +// Now returns the current monotonic time. func (System) Now() AbsTime { return AbsTime(monotime.Now()) } -// Sleep implements Clock. +// Sleep blocks for the given duration. func (System) Sleep(d time.Duration) { time.Sleep(d) } -// After implements Clock. 
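On the Event to Timer rename in mclock here: the new interface deliberately mirrors *time.Timer (Stop() bool rather than Cancel() bool), which is why the SystemEvent wrapper can be deleted and time.AfterFunc's result returned directly. A self-contained sketch of why that satisfies the interface:

```go
package main

import (
	"fmt"
	"time"
)

// Timer matches the renamed mclock interface: Stop reports whether the
// call stopped the timer before it fired.
type Timer interface {
	Stop() bool
}

func main() {
	// *time.Timer already has Stop() bool, so the value returned by
	// time.AfterFunc satisfies Timer with no wrapper type at all.
	var t Timer = time.AfterFunc(time.Hour, func() {})
	fmt.Println(t.Stop()) // true: stopped before firing
}
```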
+// After returns a channel which receives the current time after d has elapsed. func (System) After(d time.Duration) <-chan time.Time { return time.After(d) } -// AfterFunc implements Clock. -func (System) AfterFunc(d time.Duration, f func()) Event { - return (*SystemEvent)(time.AfterFunc(d, f)) -} - -// SystemEvent implements Event using time.Timer. -type SystemEvent time.Timer - -// Cancel implements Event. -func (e *SystemEvent) Cancel() bool { - return (*time.Timer)(e).Stop() +// AfterFunc runs f on a new goroutine after the duration has elapsed. +func (System) AfterFunc(d time.Duration, f func()) Timer { + return time.AfterFunc(d, f) } diff --git a/common/mclock/simclock.go b/common/mclock/simclock.go index af0f71c43..4d351252f 100644 --- a/common/mclock/simclock.go +++ b/common/mclock/simclock.go @@ -32,22 +32,17 @@ import ( // the timeout using a channel or semaphore. type Simulated struct { now AbsTime - scheduled []event + scheduled []*simTimer mu sync.RWMutex cond *sync.Cond lastId uint64 } -type event struct { +// simTimer implements Timer on the virtual clock. +type simTimer struct { do func() at AbsTime id uint64 -} - -// SimulatedEvent implements Event for a virtual clock. -type SimulatedEvent struct { - at AbsTime - id uint64 s *Simulated } @@ -75,6 +70,7 @@ func (s *Simulated) Run(d time.Duration) { } } +// ActiveTimers returns the number of timers that haven't fired. func (s *Simulated) ActiveTimers() int { s.mu.RLock() defer s.mu.RUnlock() @@ -82,6 +78,7 @@ func (s *Simulated) ActiveTimers() int { return len(s.scheduled) } +// WaitForTimers waits until the clock has at least n scheduled timers. func (s *Simulated) WaitForTimers(n int) { s.mu.Lock() defer s.mu.Unlock() @@ -92,7 +89,7 @@ func (s *Simulated) WaitForTimers(n int) { } } -// Now implements Clock. +// Now returns the current virtual time. func (s *Simulated) Now() AbsTime { s.mu.RLock() defer s.mu.RUnlock() @@ -100,12 +97,13 @@ func (s *Simulated) Now() AbsTime { return s.now } -// Sleep implements Clock. +// Sleep blocks until the clock has advanced by d. func (s *Simulated) Sleep(d time.Duration) { <-s.After(d) } -// After implements Clock. +// After returns a channel which receives the current time after the clock +// has advanced by d. func (s *Simulated) After(d time.Duration) <-chan time.Time { after := make(chan time.Time, 1) s.AfterFunc(d, func() { @@ -114,8 +112,9 @@ func (s *Simulated) After(d time.Duration) <-chan time.Time { return after } -// AfterFunc implements Clock. -func (s *Simulated) AfterFunc(d time.Duration, do func()) Event { +// AfterFunc runs fn after the clock has advanced by d. Unlike with the system +// clock, fn runs on the goroutine that calls Run. +func (s *Simulated) AfterFunc(d time.Duration, fn func()) Timer { s.mu.Lock() defer s.mu.Unlock() s.init() @@ -133,12 +132,27 @@ func (s *Simulated) AfterFunc(d time.Duration, do func()) Event { l = m + 1 } } - s.scheduled = append(s.scheduled, event{}) + ev := &simTimer{do: fn, at: at, s: s} + s.scheduled = append(s.scheduled, nil) copy(s.scheduled[l+1:], s.scheduled[l:ll]) - e := event{do: do, at: at, id: id} - s.scheduled[l] = e + s.scheduled[l] = ev s.cond.Broadcast() - return &SimulatedEvent{at: at, id: id, s: s} + return ev +} + +func (ev *simTimer) Stop() bool { + s := ev.s + s.mu.Lock() + defer s.mu.Unlock() + + for i := 0; i < len(s.scheduled); i++ { + if s.scheduled[i] == ev { + s.scheduled = append(s.scheduled[:i], s.scheduled[i+1:]...) 
+ s.cond.Broadcast() + return true + } + } + return false } func (s *Simulated) init() { @@ -146,31 +160,3 @@ func (s *Simulated) init() { s.cond = sync.NewCond(&s.mu) } } - -// Cancel implements Event. -func (e *SimulatedEvent) Cancel() bool { - s := e.s - s.mu.Lock() - defer s.mu.Unlock() - - l, h := 0, len(s.scheduled) - ll := h - for l != h { - m := (l + h) / 2 - if e.id == s.scheduled[m].id { - l = m - break - } - if (e.at < s.scheduled[m].at) || ((e.at == s.scheduled[m].at) && (e.id < s.scheduled[m].id)) { - h = m - } else { - l = m + 1 - } - } - if l >= ll || s.scheduled[l].id != e.id { - return false - } - copy(s.scheduled[l:ll-1], s.scheduled[l+1:]) - s.scheduled = s.scheduled[:ll-1] - return true -} diff --git a/common/mclock/simclock_test.go b/common/mclock/simclock_test.go new file mode 100644 index 000000000..09e4391c1 --- /dev/null +++ b/common/mclock/simclock_test.go @@ -0,0 +1,115 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package mclock + +import ( + "testing" + "time" +) + +var _ Clock = System{} +var _ Clock = new(Simulated) + +func TestSimulatedAfter(t *testing.T) { + const timeout = 30 * time.Minute + const adv = time.Minute + + var ( + c Simulated + end = c.Now().Add(timeout) + ch = c.After(timeout) + ) + for c.Now() < end.Add(-adv) { + c.Run(adv) + select { + case <-ch: + t.Fatal("Timer fired early") + default: + } + } + + c.Run(adv) + select { + case stamp := <-ch: + want := time.Time{}.Add(timeout) + if !stamp.Equal(want) { + t.Errorf("Wrong time sent on timer channel: got %v, want %v", stamp, want) + } + default: + t.Fatal("Timer didn't fire") + } +} + +func TestSimulatedAfterFunc(t *testing.T) { + var c Simulated + + called1 := false + timer1 := c.AfterFunc(100*time.Millisecond, func() { called1 = true }) + if c.ActiveTimers() != 1 { + t.Fatalf("%d active timers, want one", c.ActiveTimers()) + } + if fired := timer1.Stop(); !fired { + t.Fatal("Stop returned false even though timer didn't fire") + } + if c.ActiveTimers() != 0 { + t.Fatalf("%d active timers, want zero", c.ActiveTimers()) + } + if called1 { + t.Fatal("timer 1 called") + } + if fired := timer1.Stop(); fired { + t.Fatal("Stop returned true after timer was already stopped") + } + + called2 := false + timer2 := c.AfterFunc(100*time.Millisecond, func() { called2 = true }) + c.Run(50 * time.Millisecond) + if called2 { + t.Fatal("timer 2 called") + } + c.Run(51 * time.Millisecond) + if !called2 { + t.Fatal("timer 2 not called") + } + if fired := timer2.Stop(); fired { + t.Fatal("Stop returned true after timer has fired") + } +} + +func TestSimulatedSleep(t *testing.T) { + var ( + c Simulated + timeout = 1 * time.Hour + done = make(chan AbsTime) + ) + go func() { + c.Sleep(timeout) + done <- c.Now() + }() + + c.WaitForTimers(1) + c.Run(2 * timeout) + select { + case stamp := 
<-done: + want := AbsTime(2 * timeout) + if stamp != want { + t.Errorf("Wrong time after sleep: got %v, want %v", stamp, want) + } + case <-time.After(5 * time.Second): + t.Fatal("Sleep didn't return in time") + } +} diff --git a/common/types.go b/common/types.go index 5cba4e9f3..8ca51a05f 100644 --- a/common/types.go +++ b/common/types.go @@ -149,7 +149,7 @@ func (h *Hash) UnmarshalGraphQL(input interface{}) error { var err error switch input := input.(type) { case string: - *h = HexToHash(input) + err = h.UnmarshalText([]byte(input)) default: err = fmt.Errorf("Unexpected type for Bytes32: %v", input) } @@ -288,7 +288,7 @@ func (a *Address) UnmarshalGraphQL(input interface{}) error { var err error switch input := input.(type) { case string: - *a = HexToAddress(input) + err = a.UnmarshalText([]byte(input)) default: err = fmt.Errorf("Unexpected type for Address: %v", input) } diff --git a/consensus/clique/clique.go b/consensus/clique/clique.go index 02b6da35b..100c20529 100644 --- a/consensus/clique/clique.go +++ b/consensus/clique/clique.go @@ -311,7 +311,7 @@ func (c *Clique) verifyCascadingFields(chain consensus.ChainReader, header *type if number == 0 { return nil } - // Ensure that the block's timestamp isn't too close to it's parent + // Ensure that the block's timestamp isn't too close to its parent var parent *types.Header if len(parents) > 0 { parent = parents[len(parents)-1] @@ -522,7 +522,7 @@ func (c *Clique) Prepare(chain consensus.ChainReader, header *types.Header) erro // Set the correct difficulty header.Difficulty = CalcDifficulty(snap, c.signer) - // Ensure the extra data has all it's components + // Ensure the extra data has all its components if len(header.Extra) < extraVanity { header.Extra = append(header.Extra, bytes.Repeat([]byte{0x00}, extraVanity-len(header.Extra))...) } diff --git a/consensus/errors.go b/consensus/errors.go index a005c5f63..ac5242fb5 100644 --- a/consensus/errors.go +++ b/consensus/errors.go @@ -31,7 +31,7 @@ var ( // to the current node. ErrFutureBlock = errors.New("block in the future") - // ErrInvalidNumber is returned if a block's number doesn't equal it's parent's + // ErrInvalidNumber is returned if a block's number doesn't equal its parent's // plus one. ErrInvalidNumber = errors.New("invalid block number") ) diff --git a/consensus/ethash/consensus.go b/consensus/ethash/consensus.go index d271518f4..3cff2d9fe 100644 --- a/consensus/ethash/consensus.go +++ b/consensus/ethash/consensus.go @@ -86,7 +86,7 @@ func (ethash *Ethash) VerifyHeader(chain consensus.ChainReader, header *types.He if ethash.config.PowMode == ModeFullFake { return nil } - // Short circuit if the header is known, or it's parent not + // Short circuit if the header is known, or its parent is not number := header.Number.Uint64() if chain.GetHeader(header.Hash(), number) != nil { return nil @@ -252,7 +252,7 @@ func (ethash *Ethash) verifyHeader(chain consensus.ChainReader, header, parent * if header.Time <= parent.Time { return errZeroBlockTime } - // Verify the block's difficulty based in it's timestamp and parent's difficulty + // Verify the block's difficulty based on its timestamp and parent's difficulty expected := ethash.CalcDifficulty(chain, header.Time, parent) if expected.Cmp(header.Difficulty) != 0 { diff --git a/core/asm/compiler.go b/core/asm/compiler.go index c7a544070..799709929 100644 --- a/core/asm/compiler.go +++ b/core/asm/compiler.go @@ -57,6 +57,7 @@ func NewCompiler(debug bool) *Compiler { // second stage to push labels and determine the right // position. 
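Regarding the UnmarshalGraphQL change in common/types.go above: HexToHash and HexToAddress silently crop or pad malformed input, while UnmarshalText validates both the length and the hex digits, so GraphQL callers now receive an error instead of a mangled value. A small sketch of the behavioural difference (the input string is an arbitrary example):

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
)

func main() {
	// HexToAddress never fails: short input is simply zero-padded.
	loose := common.HexToAddress("0xdeadbeef")
	fmt.Println(loose.Hex()) // a 20-byte address, no error reported

	// UnmarshalText is strict: a 4-byte string is not a valid address.
	var strict common.Address
	if err := strict.UnmarshalText([]byte("0xdeadbeef")); err != nil {
		fmt.Println("rejected:", err)
	}
}
```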
func (c *Compiler) Feed(ch <-chan token) { + var prev token for i := range ch { switch i.typ { case number: @@ -73,10 +74,14 @@ func (c *Compiler) Feed(ch <-chan token) { c.labels[i.text] = c.pc c.pc++ case label: - c.pc += 5 + c.pc += 4 + if prev.typ == element && isJump(prev.text) { + c.pc++ + } } c.tokens = append(c.tokens, i) + prev = i } if c.debug { fmt.Fprintln(os.Stderr, "found", len(c.labels), "labels") @@ -181,6 +186,8 @@ func (c *Compiler) compileElement(element token) error { pos := big.NewInt(int64(c.labels[rvalue.text])).Bytes() pos = append(make([]byte, 4-len(pos)), pos...) c.pushBin(pos) + case lineEnd: + c.pos-- default: return compileErr(rvalue, rvalue.text, "number, string or label") } @@ -201,8 +208,8 @@ func (c *Compiler) compileElement(element token) error { case stringValue: value = []byte(rvalue.text[1 : len(rvalue.text)-1]) case label: - value = make([]byte, 4) - copy(value, big.NewInt(int64(c.labels[rvalue.text])).Bytes()) + value = big.NewInt(int64(c.labels[rvalue.text])).Bytes() + value = append(make([]byte, 4-len(value)), value...) default: return compileErr(rvalue, rvalue.text, "number, string or label") } diff --git a/core/asm/compiler_test.go b/core/asm/compiler_test.go new file mode 100644 index 000000000..ce9df436b --- /dev/null +++ b/core/asm/compiler_test.go @@ -0,0 +1,71 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package asm + +import ( + "testing" +) + +func TestCompiler(t *testing.T) { + tests := []struct { + input, output string + }{ + { + input: ` + GAS + label: + PUSH @label +`, + output: "5a5b6300000001", + }, + { + input: ` + PUSH @label + label: +`, + output: "63000000055b", + }, + { + input: ` + PUSH @label + JUMP + label: +`, + output: "6300000006565b", + }, + { + input: ` + JUMP @label + label: +`, + output: "6300000006565b", + }, + } + for _, test := range tests { + ch := Lex([]byte(test.input), false) + c := NewCompiler(false) + c.Feed(ch) + output, err := c.Compile() + if len(err) != 0 { + t.Errorf("compile error: %v\ninput: %s", err, test.input) + continue + } + if output != test.output { + t.Errorf("incorrect output\ninput: %sgot: %s\nwant: %s\n", test.input, output, test.output) + } + } +} diff --git a/core/blockchain.go b/core/blockchain.go index d88110f5d..1fd1dec3c 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -232,10 +232,16 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par if bc.genesisBlock == nil { return nil, ErrNoGenesis } + + var nilBlock *types.Block + bc.currentBlock.Store(nilBlock) + bc.currentFastBlock.Store(nilBlock) + // Initialize the chain with ancient data if it isn't empty. 
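A gloss on the assembler fix above: a label reference normally contributes only its 4 offset bytes, because the PUSH in "PUSH @label" is counted by its own element token, but "JUMP @label" has no explicit PUSH, so the implicit PUSH4 opcode adds one more byte. The old code charged every label 5 bytes and mis-positioned labels that followed plain pushes. A paraphrase of the size accounting under that reading (not the actual compiler code):

```go
package main

import "fmt"

// labelSize returns how many bytes a label reference adds to the
// program counter: the 4-byte big-endian offset, plus one byte for the
// implicit PUSH4 opcode when the label follows a jump instruction.
func labelSize(prevIsJump bool) int {
	size := 4
	if prevIsJump {
		size++
	}
	return size
}

func main() {
	fmt.Println(labelSize(false)) // 4: "PUSH @label", opcode counted separately
	fmt.Println(labelSize(true))  // 5: "JUMP @label" expands to PUSH4 <off> JUMP
}
```

This matches the new test vectors above: "PUSH @label; label:" assembles to 63000000055b, while "JUMP @label; label:" assembles to 6300000006565b.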
if bc.empty() { rawdb.InitDatabaseFromFreezer(bc.db) } + if err := bc.loadLastState(); err != nil { return nil, err } @@ -1580,6 +1586,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, [] // Some other error occurred, abort case err != nil: + bc.futureBlocks.Remove(block.Hash()) stats.ignored += len(it.chain) bc.reportBlock(block, nil, err) return it.index, events, coalescedLogs, err @@ -2170,6 +2177,11 @@ func (bc *BlockChain) HasHeader(hash common.Hash, number uint64) bool { return bc.hc.HasHeader(hash, number) } +// GetCanonicalHash returns the canonical hash for a given block number +func (bc *BlockChain) GetCanonicalHash(number uint64) common.Hash { + return bc.hc.GetCanonicalHash(number) +} + // GetBlockHashesFromHash retrieves a number of block hashes starting at a given // hash, fetching towards the genesis block. func (bc *BlockChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []common.Hash { @@ -2182,9 +2194,6 @@ func (bc *BlockChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []com // // Note: ancestor == 0 returns the same block, 1 returns its parent and so on. func (bc *BlockChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) { - bc.chainmu.RLock() - defer bc.chainmu.RUnlock() - return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical) } diff --git a/core/blockchain_test.go b/core/blockchain_test.go index be105d2db..97ea04f56 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -1318,7 +1318,7 @@ func TestEIP155Transition(t *testing.T) { funds = big.NewInt(1000000000) deleteAddr = common.Address{1} gspec = &Genesis{ - Config: ¶ms.ChainConfig{ChainID: big.NewInt(1), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)}, + Config: ¶ms.ChainConfig{ChainID: big.NewInt(1), EIP150Block: big.NewInt(0), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)}, Alloc: GenesisAlloc{address: {Balance: funds}, deleteAddr: {Balance: new(big.Int)}}, } genesis = gspec.MustCommit(db) @@ -1389,7 +1389,7 @@ func TestEIP155Transition(t *testing.T) { } // generate an invalid chain id transaction - config := ¶ms.ChainConfig{ChainID: big.NewInt(2), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)} + config := ¶ms.ChainConfig{ChainID: big.NewInt(2), EIP150Block: big.NewInt(0), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)} blocks, _ = GenerateChain(config, blocks[len(blocks)-1], ethash.NewFaker(), db, 4, func(i int, block *BlockGen) { var ( tx *types.Transaction @@ -1425,6 +1425,7 @@ func TestEIP161AccountRemoval(t *testing.T) { ChainID: big.NewInt(1), HomesteadBlock: new(big.Int), EIP155Block: new(big.Int), + EIP150Block: new(big.Int), EIP158Block: big.NewInt(2), }, Alloc: GenesisAlloc{address: {Balance: funds}}, @@ -2288,6 +2289,81 @@ func TestSideImportPrunedBlocks(t *testing.T) { } } +// TestDeleteCreateRevert tests a weird state transition corner case that we hit +// while changing the internals of statedb. The workflow is that a contract is +// self destructed, then in a followup transaction (but same block) it's created +// again and the transaction reverted. +// +// The original statedb implementation flushed dirty objects to the tries after +// each transaction, so this works ok. The rework accumulated writes in memory +// first, but the journal wiped the entire state object on create-revert. 
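Stepping back to the NewBlockChain change above, where a typed nil *types.Block is stored before loadLastState runs: sync/atomic.Value panics on an untyped nil and on inconsistent concrete types, so seeding it with a typed nil pointer pins the type and lets readers type-assert safely before the first real block lands. A standalone illustration with a toy block type (not the real one):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

type block struct{ number uint64 }

func main() {
	var v atomic.Value

	// v.Store(nil) would panic: atomic.Value rejects untyped nil.
	// A typed nil pointer is fine and fixes the concrete type.
	var nilBlock *block
	v.Store(nilBlock)

	// Readers can now type-assert before any real block is stored.
	if b, _ := v.Load().(*block); b == nil {
		fmt.Println("no block yet")
	}
	v.Store(&block{number: 1})
	fmt.Println(v.Load().(*block).number) // 1
}
```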
+func TestDeleteCreateRevert(t *testing.T) { + var ( + aa = common.HexToAddress("0x000000000000000000000000000000000000aaaa") + bb = common.HexToAddress("0x000000000000000000000000000000000000bbbb") + // Generate a canonical chain to act as the main dataset + engine = ethash.NewFaker() + db = rawdb.NewMemoryDatabase() + + // A sender who makes transactions and has some funds + key, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + address = crypto.PubkeyToAddress(key.PublicKey) + funds = big.NewInt(1000000000) + gspec = &Genesis{ + Config: params.TestChainConfig, + Alloc: GenesisAlloc{ + address: {Balance: funds}, + // The address 0xAAAA selfdestructs if called + aa: { + // Code needs to just selfdestruct + Code: []byte{byte(vm.PC), 0xFF}, + Nonce: 1, + Balance: big.NewInt(0), + }, + // The address 0xBBBB sends 1 wei to 0xAAAA, then reverts + bb: { + Code: []byte{ + byte(vm.PC), // [0] + byte(vm.DUP1), // [0,0] + byte(vm.DUP1), // [0,0,0] + byte(vm.DUP1), // [0,0,0,0] + byte(vm.PUSH1), 0x01, // [0,0,0,0,1] (value) + byte(vm.PUSH2), 0xaa, 0xaa, // [0,0,0,0,1, 0xaaaa] + byte(vm.GAS), + byte(vm.CALL), + byte(vm.REVERT), + }, + Balance: big.NewInt(1), + }, + }, + } + genesis = gspec.MustCommit(db) + ) + + blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 1, func(i int, b *BlockGen) { + b.SetCoinbase(common.Address{1}) + // One transaction to AAAA + tx, _ := types.SignTx(types.NewTransaction(0, aa, + big.NewInt(0), 50000, big.NewInt(1), nil), types.HomesteadSigner{}, key) + b.AddTx(tx) + // One transaction to BBBB + tx, _ = types.SignTx(types.NewTransaction(1, bb, + big.NewInt(0), 100000, big.NewInt(1), nil), types.HomesteadSigner{}, key) + b.AddTx(tx) + }) + // Import the canonical chain + diskdb := rawdb.NewMemoryDatabase() + gspec.MustCommit(diskdb) + + chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{}, nil) + if err != nil { + t.Fatalf("failed to create tester chain: %v", err) + } + if n, err := chain.InsertChain(blocks); err != nil { + t.Fatalf("block %d: failed to insert into chain: %v", n, err) + } +} + func TestProcessingStateDiffs(t *testing.T) { defaultTrieCleanCache := 256 defaultTrieDirtyCache := 256 diff --git a/core/chain_makers.go b/core/chain_makers.go index 17f404211..0b0fcdb4a 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -103,7 +103,7 @@ func (b *BlockGen) AddTxWithChain(bc *BlockChain, tx *types.Transaction) { b.SetCoinbase(common.Address{}) } b.statedb.Prepare(tx.Hash(), common.Hash{}, len(b.txs)) - receipt, _, err := ApplyTransaction(b.config, bc, &b.header.Coinbase, b.gasPool, b.statedb, b.header, tx, &b.header.GasUsed, vm.Config{}) + receipt, err := ApplyTransaction(b.config, bc, &b.header.Coinbase, b.gasPool, b.statedb, b.header, tx, &b.header.GasUsed, vm.Config{}) if err != nil { panic(err) } diff --git a/core/forkid/forkid.go b/core/forkid/forkid.go index 8c1700879..1e2d7a744 100644 --- a/core/forkid/forkid.go +++ b/core/forkid/forkid.go @@ -50,6 +50,9 @@ type ID struct { Next uint64 // Block number of the next upcoming fork, or 0 if no forks are known } +// Filter is a fork ID filter to validate a remotely advertised ID. +type Filter func(id ID) error + // NewID calculates the Ethereum fork ID from the chain config and head. 
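The extended rule comments in the next hunk (rules #1a and #1b) reduce, once the checksums match, to a single comparison of the local head against the remote's announced FORK_NEXT. A distilled sketch of just that comparison, with block numbers borrowed from the test cases further below (this is not the full filter):

```go
package main

import (
	"errors"
	"fmt"
)

var errLocalIncompatibleOrStale = errors.New("local incompatible or needs update")

// checkRemoteNext applies rule #1a in miniature, assuming the fork
// checksums already matched: a fork the remote still has ahead of it,
// but that the local head has already passed, was never scheduled
// locally, so the chains are incompatible.
func checkRemoteNext(head, remoteNext uint64) error {
	if remoteNext > 0 && head >= remoteNext {
		return errLocalIncompatibleOrStale
	}
	return nil // rule #1b: no pending remote fork, or not reached yet
}

func main() {
	fmt.Println(checkRemoteNext(7279999, 7279999)) // incompatible (rule #1a)
	fmt.Println(checkRemoteNext(7987396, 0))       // <nil>: accept (rule #1b)
}
```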
func NewID(chain *core.BlockChain) ID { return newID( @@ -80,9 +83,9 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID { return ID{Hash: checksumToBytes(hash), Next: next} } -// NewFilter creates an filter that returns if a fork ID should be rejected or not +// NewFilter creates a filter that decides whether a fork ID should be rejected // based on the local chain's status. -func NewFilter(chain *core.BlockChain) func(id ID) error { +func NewFilter(chain *core.BlockChain) Filter { return newFilter( chain.Config(), chain.Genesis().Hash(), @@ -92,10 +95,16 @@ func NewFilter(chain *core.BlockChain) func(id ID) error { ) } +// NewStaticFilter creates a filter at block zero. +func NewStaticFilter(config *params.ChainConfig, genesis common.Hash) Filter { + head := func() uint64 { return 0 } + return newFilter(config, genesis, head) +} + // newFilter is the internal version of NewFilter, taking closures as its arguments // instead of a chain. The reason is to allow testing it without having to simulate // an entire blockchain. -func newFilter(config *params.ChainConfig, genesis common.Hash, headfn func() uint64) func(id ID) error { +func newFilter(config *params.ChainConfig, genesis common.Hash, headfn func() uint64) Filter { // Calculate all the valid fork hash and fork next combos var ( forks = gatherForks(config) @@ -114,10 +123,13 @@ func newFilter(config *params.ChainConfig, genesis common.Hash, headfn func() ui // Create a validator that will filter out incompatible chains return func(id ID) error { // Run the fork checksum validation ruleset: - // 1. If local and remote FORK_CSUM matches, connect. + // 1. If local and remote FORK_CSUM matches, compare local head to FORK_NEXT. // The two nodes are in the same fork state currently. They might know // of differing future forks, but that's not relevant until the fork // triggers (might be postponed, nodes might be updated to match). + // 1a. If a fork block announced by the remote but not yet passed by it has + // already passed locally, disconnect, since the chains are incompatible. + // 1b. If the remote announced no fork, or one not yet passed locally, connect. // 2. If the remote FORK_CSUM is a subset of the local past forks and the // remote FORK_NEXT matches with the locally following fork block number, // connect. @@ -139,7 +151,12 @@ func newFilter(config *params.ChainConfig, genesis common.Hash, headfn func() ui // Found the first unpassed fork block, check if our current state matches // the remote checksum (rule #1). if sums[i] == id.Hash { - // Yay, fork checksum matched, ignore any upcoming fork + // Fork checksum matched, check if a remote future fork block already passed + // locally without the local node being aware of it (rule #1a). + if id.Next > 0 && head >= id.Next { + return ErrLocalIncompatibleOrStale + } + // The remote's next fork (if any) hasn't passed locally, accept the connection (rule #1b). 
return nil } // The local and remote nodes are in different forks currently, check if the diff --git a/core/forkid/forkid_test.go b/core/forkid/forkid_test.go index b33f85bec..ee201ae9a 100644 --- a/core/forkid/forkid_test.go +++ b/core/forkid/forkid_test.go @@ -55,8 +55,10 @@ func TestCreation(t *testing.T) { {4369999, ID{Hash: checksumToBytes(0x3edd5b10), Next: 4370000}}, // Last Spurious block {4370000, ID{Hash: checksumToBytes(0xa00bc324), Next: 7280000}}, // First Byzantium block {7279999, ID{Hash: checksumToBytes(0xa00bc324), Next: 7280000}}, // Last Byzantium block - {7280000, ID{Hash: checksumToBytes(0x668db0af), Next: 0}}, // First and last Constantinople, first Petersburg block - {7987396, ID{Hash: checksumToBytes(0x668db0af), Next: 0}}, // Today Petersburg block + {7280000, ID{Hash: checksumToBytes(0x668db0af), Next: 9069000}}, // First and last Constantinople, first Petersburg block + {9068999, ID{Hash: checksumToBytes(0x668db0af), Next: 9069000}}, // Last Petersburg block + {9069000, ID{Hash: checksumToBytes(0x879d6e30), Next: 0}}, // Today Istanbul block + {10000000, ID{Hash: checksumToBytes(0x879d6e30), Next: 0}}, // Future Istanbul block }, }, // Ropsten test cases @@ -72,8 +74,10 @@ func TestCreation(t *testing.T) { {4229999, ID{Hash: checksumToBytes(0x3ea159c7), Next: 4230000}}, // Last Byzantium block {4230000, ID{Hash: checksumToBytes(0x97b544f3), Next: 4939394}}, // First Constantinople block {4939393, ID{Hash: checksumToBytes(0x97b544f3), Next: 4939394}}, // Last Constantinople block - {4939394, ID{Hash: checksumToBytes(0xd6e2149b), Next: 0}}, // First Petersburg block - {5822692, ID{Hash: checksumToBytes(0xd6e2149b), Next: 0}}, // Today Petersburg block + {4939394, ID{Hash: checksumToBytes(0xd6e2149b), Next: 6485846}}, // First Petersburg block + {6485845, ID{Hash: checksumToBytes(0xd6e2149b), Next: 6485846}}, // Last Petersburg block + {6485846, ID{Hash: checksumToBytes(0x4bc66396), Next: 0}}, // First Istanbul block + {7500000, ID{Hash: checksumToBytes(0x4bc66396), Next: 0}}, // Future Istanbul block }, }, // Rinkeby test cases @@ -90,8 +94,10 @@ func TestCreation(t *testing.T) { {3660662, ID{Hash: checksumToBytes(0x8d748b57), Next: 3660663}}, // Last Byzantium block {3660663, ID{Hash: checksumToBytes(0xe49cab14), Next: 4321234}}, // First Constantinople block {4321233, ID{Hash: checksumToBytes(0xe49cab14), Next: 4321234}}, // Last Constantinople block - {4321234, ID{Hash: checksumToBytes(0xafec6b27), Next: 0}}, // First Petersburg block - {4586649, ID{Hash: checksumToBytes(0xafec6b27), Next: 0}}, // Today Petersburg block + {4321234, ID{Hash: checksumToBytes(0xafec6b27), Next: 5435345}}, // First Petersburg block + {5435344, ID{Hash: checksumToBytes(0xafec6b27), Next: 5435345}}, // Last Petersburg block + {5435345, ID{Hash: checksumToBytes(0xcbdb8838), Next: 0}}, // First Istanbul block + {6000000, ID{Hash: checksumToBytes(0xcbdb8838), Next: 0}}, // Future Istanbul block }, }, // Goerli test cases @@ -99,8 +105,10 @@ func TestCreation(t *testing.T) { params.GoerliChainConfig, params.GoerliGenesisHash, []testcase{ - {0, ID{Hash: checksumToBytes(0xa3f5ab08), Next: 0}}, // Unsynced, last Frontier, Homestead, Tangerine, Spurious, Byzantium, Constantinople and first Petersburg block - {795329, ID{Hash: checksumToBytes(0xa3f5ab08), Next: 0}}, // Today Petersburg block + {0, ID{Hash: checksumToBytes(0xa3f5ab08), Next: 1561651}}, // Unsynced, last Frontier, Homestead, Tangerine, Spurious, Byzantium, Constantinople and first Petersburg block + {1561650, ID{Hash: 
checksumToBytes(0xa3f5ab08), Next: 1561651}}, // Last Petersburg block + {1561651, ID{Hash: checksumToBytes(0xc25efa5c), Next: 0}}, // First Istanbul block + {2000000, ID{Hash: checksumToBytes(0xc25efa5c), Next: 0}}, // Future Istanbul block }, }, } @@ -145,7 +153,7 @@ func TestValidation(t *testing.T) { // Local is mainnet Petersburg, remote announces Byzantium + knowledge about Petersburg. Remote // is simply out of sync, accept. - {7987396, ID{Hash: checksumToBytes(0x668db0af), Next: 7280000}, nil}, + {7987396, ID{Hash: checksumToBytes(0xa00bc324), Next: 7280000}, nil}, // Local is mainnet Petersburg, remote announces Spurious + knowledge about Byzantium. Remote // is definitely out of sync. It may or may not need the Petersburg update, we don't know yet. @@ -172,6 +180,16 @@ func TestValidation(t *testing.T) { // Local is mainnet Petersburg, remote is Rinkeby Petersburg. {7987396, ID{Hash: checksumToBytes(0xafec6b27), Next: 0}, ErrLocalIncompatibleOrStale}, + + // Local is mainnet Istanbul, far in the future. Remote announces Gopherium (non existing fork) + // at some future block 88888888, for itself, but past block for local. Local is incompatible. + // + // This case detects non-upgraded nodes with majority hash power (typical Ropsten mess). + {88888888, ID{Hash: checksumToBytes(0x879d6e30), Next: 88888888}, ErrLocalIncompatibleOrStale}, + + // Local is mainnet Byzantium. Remote is also in Byzantium, but announces Gopherium (non existing + // fork) at block 7279999, before Petersburg. Local is incompatible. + {7279999, ID{Hash: checksumToBytes(0xa00bc324), Next: 7279999}, ErrLocalIncompatibleOrStale}, } for i, tt := range tests { filter := newFilter(params.MainnetChainConfig, params.MainnetGenesisHash, func() uint64 { return tt.head }) diff --git a/core/genesis.go b/core/genesis.go index 8261c18cc..df0c96798 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -207,6 +207,9 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, override if overrideIstanbul != nil { newcfg.IstanbulBlock = overrideIstanbul } + if err := newcfg.CheckConfigForkOrder(); err != nil { + return newcfg, common.Hash{}, err + } storedcfg := rawdb.ReadChainConfig(db, stored) if storedcfg == nil { log.Warn("Found genesis block without chain config") @@ -295,6 +298,13 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { if block.Number().Sign() != 0 { return nil, fmt.Errorf("can't commit genesis block with number > 0") } + config := g.Config + if config == nil { + config = params.AllEthashProtocolChanges + } + if err := config.CheckConfigForkOrder(); err != nil { + return nil, err + } rawdb.WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty) rawdb.WriteBlock(db, block) rawdb.WriteReceipts(db, block.Hash(), block.NumberU64(), nil) @@ -302,11 +312,6 @@ func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { rawdb.WriteHeadBlockHash(db, block.Hash()) rawdb.WriteHeadFastBlockHash(db, block.Hash()) rawdb.WriteHeadHeaderHash(db, block.Hash()) - - config := g.Config - if config == nil { - config = params.AllEthashProtocolChanges - } rawdb.WriteChainConfig(db, block.Hash(), config) return block, nil } diff --git a/core/headerchain.go b/core/headerchain.go index 034858f65..4682069cf 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -349,8 +349,11 @@ func (hc *HeaderChain) GetAncestor(hash common.Hash, number, ancestor uint64, ma } for ancestor != 0 { if rawdb.ReadCanonicalHash(hc.chainDb, number) == hash { - number -= ancestor - return 
rawdb.ReadCanonicalHash(hc.chainDb, number), number + ancestorHash := rawdb.ReadCanonicalHash(hc.chainDb, number-ancestor) + if rawdb.ReadCanonicalHash(hc.chainDb, number) == hash { + number -= ancestor + return ancestorHash, number + } } if *maxNonCanonical == 0 { return common.Hash{}, 0 @@ -445,6 +448,10 @@ func (hc *HeaderChain) GetHeaderByNumber(number uint64) *types.Header { return hc.GetHeader(hash, number) } +func (hc *HeaderChain) GetCanonicalHash(number uint64) common.Hash { + return rawdb.ReadCanonicalHash(hc.chainDb, number) +} + // CurrentHeader retrieves the current head header of the canonical chain. The // header is retrieved from the HeaderChain's internal cache. func (hc *HeaderChain) CurrentHeader() *types.Header { diff --git a/core/rawdb/freezer.go b/core/rawdb/freezer.go index 41677fbba..5497c59d4 100644 --- a/core/rawdb/freezer.go +++ b/core/rawdb/freezer.go @@ -80,9 +80,9 @@ type freezer struct { func newFreezer(datadir string, namespace string) (*freezer, error) { // Create the initial freezer object var ( - readMeter = metrics.NewRegisteredMeter(namespace+"ancient/read", nil) - writeMeter = metrics.NewRegisteredMeter(namespace+"ancient/write", nil) - sizeCounter = metrics.NewRegisteredCounter(namespace+"ancient/size", nil) + readMeter = metrics.NewRegisteredMeter(namespace+"ancient/read", nil) + writeMeter = metrics.NewRegisteredMeter(namespace+"ancient/write", nil) + sizeGauge = metrics.NewRegisteredGauge(namespace+"ancient/size", nil) ) // Ensure the datadir is not a symbolic link if it exists. if info, err := os.Lstat(datadir); !os.IsNotExist(err) { @@ -103,7 +103,7 @@ func newFreezer(datadir string, namespace string) (*freezer, error) { instanceLock: lock, } for name, disableSnappy := range freezerNoSnappy { - table, err := newTable(datadir, name, readMeter, writeMeter, sizeCounter, disableSnappy) + table, err := newTable(datadir, name, readMeter, writeMeter, sizeGauge, disableSnappy) if err != nil { for _, table := range freezer.tables { table.Close() diff --git a/core/rawdb/freezer_table.go b/core/rawdb/freezer_table.go index 61804f1f2..9fb341f02 100644 --- a/core/rawdb/freezer_table.go +++ b/core/rawdb/freezer_table.go @@ -94,18 +94,18 @@ type freezerTable struct { // to count how many historic items have gone missing. 
itemOffset uint32 // Offset (number of discarded items) - headBytes uint32 // Number of bytes written to the head file - readMeter metrics.Meter // Meter for measuring the effective amount of data read - writeMeter metrics.Meter // Meter for measuring the effective amount of data written - sizeCounter metrics.Counter // Counter for tracking the combined size of all freezer tables + headBytes uint32 // Number of bytes written to the head file + readMeter metrics.Meter // Meter for measuring the effective amount of data read + writeMeter metrics.Meter // Meter for measuring the effective amount of data written + sizeGauge metrics.Gauge // Gauge for tracking the combined size of all freezer tables logger log.Logger // Logger with database path and table name embedded lock sync.RWMutex // Mutex protecting the data file descriptors } // newTable opens a freezer table with default settings - 2G files -func newTable(path string, name string, readMeter metrics.Meter, writeMeter metrics.Meter, sizeCounter metrics.Counter, disableSnappy bool) (*freezerTable, error) { - return newCustomTable(path, name, readMeter, writeMeter, sizeCounter, 2*1000*1000*1000, disableSnappy) +func newTable(path string, name string, readMeter metrics.Meter, writeMeter metrics.Meter, sizeGauge metrics.Gauge, disableSnappy bool) (*freezerTable, error) { + return newCustomTable(path, name, readMeter, writeMeter, sizeGauge, 2*1000*1000*1000, disableSnappy) } // openFreezerFileForAppend opens a freezer table file and seeks to the end @@ -149,7 +149,7 @@ func truncateFreezerFile(file *os.File, size int64) error { // newCustomTable opens a freezer table, creating the data and index files if they are // non-existent. Both files are truncated to the shortest common length to ensure // they don't go out of sync.
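// A small sketch of the gauge arithmetic these tables rely on (the metric
// name is assumed for illustration); unlike the old counter, a gauge is
// meant to move both ways as data is appended and truncated:
//
//	sizeGauge := metrics.NewRegisteredGauge("ancient/size/example", nil)
//	sizeGauge.Inc(int64(bLen + indexEntrySize)) // grow on Append
//	sizeGauge.Dec(int64(oldSize - newSize))     // shrink on truncate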
-func newCustomTable(path string, name string, readMeter metrics.Meter, writeMeter metrics.Meter, sizeCounter metrics.Counter, maxFilesize uint32, noCompression bool) (*freezerTable, error) { +func newCustomTable(path string, name string, readMeter metrics.Meter, writeMeter metrics.Meter, sizeGauge metrics.Gauge, maxFilesize uint32, noCompression bool) (*freezerTable, error) { // Ensure the containing directory exists and open the indexEntry file if err := os.MkdirAll(path, 0755); err != nil { return nil, err @@ -172,7 +172,7 @@ func newCustomTable(path string, name string, readMeter metrics.Meter, writeMete files: make(map[uint32]*os.File), readMeter: readMeter, writeMeter: writeMeter, - sizeCounter: sizeCounter, + sizeGauge: sizeGauge, name: name, path: path, logger: log.New("database", path, "table", name), @@ -189,7 +189,7 @@ func newCustomTable(path string, name string, readMeter metrics.Meter, writeMete tab.Close() return nil, err } - tab.sizeCounter.Inc(int64(size)) + tab.sizeGauge.Inc(int64(size)) return tab, nil } @@ -378,7 +378,7 @@ func (t *freezerTable) truncate(items uint64) error { if err != nil { return err } - t.sizeCounter.Dec(int64(oldSize - newSize)) + t.sizeGauge.Dec(int64(oldSize - newSize)) return nil } @@ -510,7 +510,7 @@ func (t *freezerTable) Append(item uint64, blob []byte) error { t.index.Write(idx.marshallBinary()) t.writeMeter.Mark(int64(bLen + indexEntrySize)) - t.sizeCounter.Inc(int64(bLen + indexEntrySize)) + t.sizeGauge.Inc(int64(bLen + indexEntrySize)) atomic.AddUint64(&t.items, 1) return nil diff --git a/core/rawdb/freezer_table_test.go b/core/rawdb/freezer_table_test.go index 61ba7a17e..7de108151 100644 --- a/core/rawdb/freezer_table_test.go +++ b/core/rawdb/freezer_table_test.go @@ -56,7 +56,7 @@ func TestFreezerBasics(t *testing.T) { // set cutoff at 50 bytes f, err := newCustomTable(os.TempDir(), fmt.Sprintf("unittest-%d", rand.Uint64()), - metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter(), 50, true) + metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge(), 50, true) if err != nil { t.Fatal(err) } @@ -99,11 +99,11 @@ func TestFreezerBasicsClosing(t *testing.T) { // set cutoff at 50 bytes var ( fname = fmt.Sprintf("basics-close-%d", rand.Uint64()) - rm, wm, sc = metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter() + rm, wm, sg = metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge() f *freezerTable err error ) - f, err = newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err = newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -112,7 +112,7 @@ func TestFreezerBasicsClosing(t *testing.T) { data := getChunk(15, x) f.Append(uint64(x), data) f.Close() - f, err = newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err = newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -129,7 +129,7 @@ func TestFreezerBasicsClosing(t *testing.T) { t.Fatalf("test %d, got \n%x != \n%x", y, got, exp) } f.Close() - f, err = newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err = newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -139,11 +139,11 @@ func TestFreezerBasicsClosing(t *testing.T) { // TestFreezerRepairDanglingHead tests that we can recover if index entries are removed func TestFreezerRepairDanglingHead(t *testing.T) { t.Parallel() - rm, wm, sc := metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter() + rm, wm, sg := metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge() 
fname := fmt.Sprintf("dangling_headtest-%d", rand.Uint64()) { // Fill table - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -172,7 +172,7 @@ func TestFreezerRepairDanglingHead(t *testing.T) { idxFile.Close() // Now open it again { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -190,11 +190,11 @@ func TestFreezerRepairDanglingHead(t *testing.T) { // TestFreezerRepairDanglingHeadLarge tests that we can recover if very many index entries are removed func TestFreezerRepairDanglingHeadLarge(t *testing.T) { t.Parallel() - rm, wm, sc := metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter() + rm, wm, sg := metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge() fname := fmt.Sprintf("dangling_headtest-%d", rand.Uint64()) { // Fill a table and close it - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -222,7 +222,7 @@ func TestFreezerRepairDanglingHeadLarge(t *testing.T) { idxFile.Close() // Now open it again { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -243,7 +243,7 @@ func TestFreezerRepairDanglingHeadLarge(t *testing.T) { } // And if we open it, we should now be able to read all of them (new values) { - f, _ := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, _ := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) for y := 1; y < 255; y++ { exp := getChunk(15, ^y) got, err := f.Retrieve(uint64(y)) @@ -260,11 +260,11 @@ func TestFreezerRepairDanglingHeadLarge(t *testing.T) { // TestSnappyDetection tests that we fail to open a snappy database and vice versa func TestSnappyDetection(t *testing.T) { t.Parallel() - rm, wm, sc := metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter() + rm, wm, sg := metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge() fname := fmt.Sprintf("snappytest-%d", rand.Uint64()) // Open with snappy { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -277,7 +277,7 @@ func TestSnappyDetection(t *testing.T) { } // Open without snappy { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, false) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, false) if err != nil { t.Fatal(err) } @@ -289,7 +289,7 @@ func TestSnappyDetection(t *testing.T) { // Open with snappy { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -317,11 +317,11 @@ func assertFileSize(f string, size int64) error { // the index is repaired func TestFreezerRepairDanglingIndex(t *testing.T) { t.Parallel() - rm, wm, sc := metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter() + rm, wm, sg := metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge() fname := fmt.Sprintf("dangling_indextest-%d", rand.Uint64()) { // Fill a table and close it - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) 
} @@ -357,7 +357,7 @@ func TestFreezerRepairDanglingIndex(t *testing.T) { // 45, 45, 15 // with 3+3+1 items { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -374,11 +374,11 @@ func TestFreezerRepairDanglingIndex(t *testing.T) { func TestFreezerTruncate(t *testing.T) { t.Parallel() - rm, wm, sc := metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter() + rm, wm, sg := metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge() fname := fmt.Sprintf("truncation-%d", rand.Uint64()) { // Fill table - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -395,7 +395,7 @@ func TestFreezerTruncate(t *testing.T) { } // Reopen, truncate { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -417,10 +417,10 @@ func TestFreezerTruncate(t *testing.T) { // That will rewind the index, and _should_ truncate the head file func TestFreezerRepairFirstFile(t *testing.T) { t.Parallel() - rm, wm, sc := metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter() + rm, wm, sg := metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge() fname := fmt.Sprintf("truncationfirst-%d", rand.Uint64()) { // Fill table - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -448,7 +448,7 @@ func TestFreezerRepairFirstFile(t *testing.T) { } // Reopen { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -473,10 +473,10 @@ func TestFreezerRepairFirstFile(t *testing.T) { // - check that we did not keep the rdonly file descriptors func TestFreezerReadAndTruncate(t *testing.T) { t.Parallel() - rm, wm, sc := metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter() + rm, wm, sg := metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge() fname := fmt.Sprintf("read_truncate-%d", rand.Uint64()) { // Fill table - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -493,7 +493,7 @@ func TestFreezerReadAndTruncate(t *testing.T) { } // Reopen and read all files { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 50, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 50, true) if err != nil { t.Fatal(err) } @@ -519,10 +519,10 @@ func TestFreezerReadAndTruncate(t *testing.T) { func TestOffset(t *testing.T) { t.Parallel() - rm, wm, sc := metrics.NewMeter(), metrics.NewMeter(), metrics.NewCounter() + rm, wm, sg := metrics.NewMeter(), metrics.NewMeter(), metrics.NewGauge() fname := fmt.Sprintf("offset-%d", rand.Uint64()) { // Fill table - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 40, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 40, true) if err != nil { t.Fatal(err) } @@ -578,7 +578,7 @@ func TestOffset(t *testing.T) { } // Now open again { - f, err := newCustomTable(os.TempDir(), fname, rm, wm, sc, 40, true) + f, err := newCustomTable(os.TempDir(), fname, rm, wm, sg, 40, true) if err != nil { t.Fatal(err) } diff --git a/core/state/state_object.go 
b/core/state/state_object.go index 45ae95a2a..8680de021 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -79,9 +79,10 @@ type stateObject struct { trie Trie // storage trie, which becomes non-nil on first access code Code // contract bytecode, which gets set when code is loaded - originStorage Storage // Storage cache of original entries to dedup rewrites - dirtyStorage Storage // Storage entries that need to be flushed to disk - fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. + originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction + pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block + dirtyStorage Storage // Storage entries that have been modified in the current transaction execution + fakeStorage Storage // Fake storage constructed by the caller for debugging purposes. // Cache flags. // When an object is marked suicided it will be deleted from the trie @@ -113,13 +114,17 @@ func newObject(db *StateDB, address common.Address, data Account) *stateObject { if data.CodeHash == nil { data.CodeHash = emptyCodeHash } + if data.Root == (common.Hash{}) { + data.Root = emptyRoot + } return &stateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - data: data, - originStorage: make(Storage), - dirtyStorage: make(Storage), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: data, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), } } @@ -183,9 +188,11 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has if s.fakeStorage != nil { return s.fakeStorage[key] } - // If we have the original value cached, return that - value, cached := s.originStorage[key] - if cached { + // If we have a pending write or a clean cached value, return that + if value, pending := s.pendingStorage[key]; pending { + return value + } + if value, cached := s.originStorage[key]; cached { return value } // Track the amount of time wasted on reading the storage trie @@ -198,6 +205,7 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has s.setError(err) return common.Hash{} } + var value common.Hash if len(enc) > 0 { _, content, _, err := rlp.Split(enc) if err != nil { @@ -252,17 +260,29 @@ func (s *stateObject) setState(key, value common.Hash) { s.dirtyStorage[key] = value } +// finalise moves all dirty storage slots into the pending area to be hashed or +// committed later. It is invoked at the end of every transaction. +func (s *stateObject) finalise() { + for key, value := range s.dirtyStorage { + s.pendingStorage[key] = value + } + if len(s.dirtyStorage) > 0 { + s.dirtyStorage = make(Storage) + } +} + // updateTrie writes cached storage modifications into the object's storage trie.
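// With the split caches above, reads follow a three-tier order; a hedged
// sketch of what GetState effectively does (fakeStorage aside):
//
//	if value, dirty := s.dirtyStorage[key]; dirty {
//		return value // written by the current transaction
//	}
//	return s.GetCommittedState(db, key) // pendingStorage, then originStorage, then the trie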
func (s *stateObject) updateTrie(db Database) Trie { + // Make sure all dirty slots are finalized into the pending storage area + s.finalise() + // Track the amount of time wasted on updating the storge trie if metrics.EnabledExpensive { defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now()) } - // Update all the dirty slots in the trie + // Insert all the pending updates into the trie tr := s.getTrie(db) - for key, value := range s.dirtyStorage { - delete(s.dirtyStorage, key) - + for key, value := range s.pendingStorage { // Skip noop changes, persist actual changes if value == s.originStorage[key] { continue @@ -274,9 +294,12 @@ func (s *stateObject) updateTrie(db Database) Trie { continue } // Encoding []byte cannot fail, ok to ignore the error. - v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) + v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) s.setError(tr.TryUpdate(key[:], v)) } + if len(s.pendingStorage) > 0 { + s.pendingStorage = make(Storage) + } return tr } @@ -356,6 +379,7 @@ func (s *stateObject) deepCopy(db *StateDB) *stateObject { stateObject.code = s.code stateObject.dirtyStorage = s.dirtyStorage.Copy() stateObject.originStorage = s.originStorage.Copy() + stateObject.pendingStorage = s.pendingStorage.Copy() stateObject.suicided = s.suicided stateObject.dirtyCode = s.dirtyCode stateObject.deleted = s.deleted diff --git a/core/state/state_object_test.go b/core/state/state_object_test.go new file mode 100644 index 000000000..e86d3b994 --- /dev/null +++ b/core/state/state_object_test.go @@ -0,0 +1,70 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
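// The benchmarks below compare three ways of stripping leading zero bytes
// before RLP-encoding storage values. A plausible shape for the new
// common.TrimLeftZeroes helper (assumed here; the real implementation lives
// in the common package) is a plain index scan with no cutset handling:
//
//	func TrimLeftZeroes(s []byte) []byte {
//		idx := 0
//		for ; idx < len(s) && s[idx] == 0; idx++ {
//		}
//		return s[idx:]
//	}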
+ +package state + +import ( + "bytes" + "fmt" + "math/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" +) + +func BenchmarkCutOriginal(b *testing.B) { + value := common.HexToHash("0x01") + for i := 0; i < b.N; i++ { + bytes.TrimLeft(value[:], "\x00") + } +} + +func BenchmarkCutsetterFn(b *testing.B) { + value := common.HexToHash("0x01") + cutSetFn := func(r rune) bool { + return int32(r) == int32(0) + } + for i := 0; i < b.N; i++ { + bytes.TrimLeftFunc(value[:], cutSetFn) + } +} + +func BenchmarkCutCustomTrim(b *testing.B) { + value := common.HexToHash("0x01") + for i := 0; i < b.N; i++ { + common.TrimLeftZeroes(value[:]) + } +} + +func xTestFuzzCutter(t *testing.T) { + rand.Seed(time.Now().Unix()) + for { + v := make([]byte, 20) + zeroes := rand.Intn(21) + rand.Read(v[zeroes:]) + exp := bytes.TrimLeft(v[:], "\x00") + got := common.TrimLeftZeroes(v) + if !bytes.Equal(exp, got) { + + fmt.Printf("Input %x\n", v) + fmt.Printf("Exp %x\n", exp) + fmt.Printf("Got %x\n", got) + t.Fatalf("Error") + } + //break + } +} diff --git a/core/state/statedb.go b/core/state/statedb.go index b07f08fd2..4b4f374c9 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -67,8 +67,9 @@ type StateDB struct { trie Trie // This map holds 'live' objects, which will get modified while processing a state transition. - stateObjects map[common.Address]*stateObject - stateObjectsDirty map[common.Address]struct{} + stateObjects map[common.Address]*stateObject + stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie + stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution // DB error. // State objects are used by the consensus core and VM which are @@ -111,13 +112,14 @@ func New(root common.Hash, db Database) (*StateDB, error) { return nil, err } return &StateDB{ - db: db, - trie: tr, - stateObjects: make(map[common.Address]*stateObject), - stateObjectsDirty: make(map[common.Address]struct{}), - logs: make(map[common.Hash][]*types.Log), - preimages: make(map[common.Hash][]byte), - journal: newJournal(), + db: db, + trie: tr, + stateObjects: make(map[common.Address]*stateObject), + stateObjectsPending: make(map[common.Address]struct{}), + stateObjectsDirty: make(map[common.Address]struct{}), + logs: make(map[common.Hash][]*types.Log), + preimages: make(map[common.Hash][]byte), + journal: newJournal(), }, nil } @@ -141,6 +143,7 @@ func (self *StateDB) Reset(root common.Hash) error { } self.trie = tr self.stateObjects = make(map[common.Address]*stateObject) + self.stateObjectsPending = make(map[common.Address]struct{}) self.stateObjectsDirty = make(map[common.Address]struct{}) self.thash = common.Hash{} self.bhash = common.Hash{} @@ -421,15 +424,15 @@ func (self *StateDB) Suicide(addr common.Address) bool { // // updateStateObject writes the given object to the trie. 
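// Taken together, the two sets give state objects a two-phase exit path:
// Finalise moves journal-dirty addresses into both stateObjectsPending and
// stateObjectsDirty; IntermediateRoot drains the pending set by writing
// account data into the main trie; Commit drains the dirty set by flushing
// code and storage tries into the trie database.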
-func (s *StateDB) updateStateObject(stateObject *stateObject) { +func (s *StateDB) updateStateObject(obj *stateObject) { // Track the amount of time wasted on updating the account from the trie if metrics.EnabledExpensive { defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now()) } // Encode the account and update the account trie - addr := stateObject.Address() + addr := obj.Address() - data, err := rlp.EncodeToBytes(stateObject) + data, err := rlp.EncodeToBytes(obj) if err != nil { panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err)) } @@ -437,25 +440,33 @@ func (s *StateDB) updateStateObject(stateObject *stateObject) { } // deleteStateObject removes the given object from the state trie. -func (s *StateDB) deleteStateObject(stateObject *stateObject) { +func (s *StateDB) deleteStateObject(obj *stateObject) { // Track the amount of time wasted on deleting the account from the trie if metrics.EnabledExpensive { defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now()) } // Delete the account from the trie - stateObject.deleted = true - - addr := stateObject.Address() + addr := obj.Address() s.setError(s.trie.TryDelete(addr[:])) } -// Retrieve a state object given by the address. Returns nil if not found. -func (s *StateDB) getStateObject(addr common.Address) (stateObject *stateObject) { - // Prefer live objects +// getStateObject retrieves a state object given by the address, returning nil if +// the object is not found or was deleted in this execution context. If you need +// to differentiate between non-existent/just-deleted, use getDeletedStateObject. +func (s *StateDB) getStateObject(addr common.Address) *stateObject { + if obj := s.getDeletedStateObject(addr); obj != nil && !obj.deleted { + return obj + } + return nil +} + +// getDeletedStateObject is similar to getStateObject, but instead of returning +// nil for a deleted state object, it returns the actual object with the deleted +// flag set. This is needed by the state journal to revert to the correct self- +// destructed object instead of wiping all knowledge about the state object. +func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { + // Prefer live objects if any is available if obj := s.stateObjects[addr]; obj != nil { - if obj.deleted { - return nil - } return obj } // Track the amount of time wasted on loading the object from the database @@ -486,7 +497,7 @@ func (self *StateDB) setStateObject(object *stateObject) { // Retrieve a state object or create a new state object if nil. func (self *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { stateObject := self.getStateObject(addr) - if stateObject == nil || stateObject.deleted { + if stateObject == nil { stateObject, _ = self.createObject(addr) } return stateObject @@ -495,7 +506,8 @@ func (self *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { // createObject creates a new state object. If there is an existing account with // the given address, it is overwritten and returned as the second return value. func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { - prev = self.getStateObject(addr) + prev = self.getDeletedStateObject(addr) // Note, prev might have been deleted, we need that! 
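// Without the deleted object, a journal revert of a create-over-selfdestruct
// would wipe all knowledge of the account instead of restoring the destructed
// version (the regression exercised by TestDeleteCreateRevert below).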
+ newobj = newObject(self, addr, Account{}) newobj.setNonce(0) // sets the object to dirty if prev == nil { @@ -558,15 +570,16 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common func (self *StateDB) Copy() *StateDB { // Copy all the basic fields, initialize the memory ones state := &StateDB{ - db: self.db, - trie: self.db.CopyTrie(self.trie), - stateObjects: make(map[common.Address]*stateObject, len(self.journal.dirties)), - stateObjectsDirty: make(map[common.Address]struct{}, len(self.journal.dirties)), - refund: self.refund, - logs: make(map[common.Hash][]*types.Log, len(self.logs)), - logSize: self.logSize, - preimages: make(map[common.Hash][]byte, len(self.preimages)), - journal: newJournal(), + db: self.db, + trie: self.db.CopyTrie(self.trie), + stateObjects: make(map[common.Address]*stateObject, len(self.journal.dirties)), + stateObjectsPending: make(map[common.Address]struct{}, len(self.stateObjectsPending)), + stateObjectsDirty: make(map[common.Address]struct{}, len(self.journal.dirties)), + refund: self.refund, + logs: make(map[common.Hash][]*types.Log, len(self.logs)), + logSize: self.logSize, + preimages: make(map[common.Hash][]byte, len(self.preimages)), + journal: newJournal(), } // Copy the dirty states, logs, and preimages for addr := range self.journal.dirties { @@ -575,18 +588,29 @@ func (self *StateDB) Copy() *StateDB { // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for // nil if object, exist := self.stateObjects[addr]; exist { + // Even though the original object is dirty, we are not copying the journal, + // so we need to make sure that anyside effect the journal would have caused + // during a commit (or similar op) is already applied to the copy. state.stateObjects[addr] = object.deepCopy(state) - state.stateObjectsDirty[addr] = struct{}{} + + state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits + state.stateObjectsPending[addr] = struct{}{} // Mark the copy pending to force external (account) commits } } // Above, we don't copy the actual journal. This means that if the copy is copied, the // loop above will be a no-op, since the copy's journal is empty. // Thus, here we iterate over stateObjects, to enable copies of copies + for addr := range self.stateObjectsPending { + if _, exist := state.stateObjects[addr]; !exist { + state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state) + } + state.stateObjectsPending[addr] = struct{}{} + } for addr := range self.stateObjectsDirty { if _, exist := state.stateObjects[addr]; !exist { state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state) - state.stateObjectsDirty[addr] = struct{}{} } + state.stateObjectsDirty[addr] = struct{}{} } for hash, logs := range self.logs { cpy := make([]*types.Log, len(logs)) @@ -631,11 +655,12 @@ func (self *StateDB) GetRefund() uint64 { return self.refund } -// Finalise finalises the state by removing the self destructed objects -// and clears the journal as well as the refunds. +// Finalise finalises the state by removing the self destructed objects and clears +// the journal as well as the refunds. Finalise, however, will not push any updates +// into the tries just yet. Only IntermediateRoot or Commit will do that. 
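// A hedged sketch of the resulting call pattern over a block (the setup and
// applyTx helper are assumed for illustration):
//
//	for _, tx := range blockTxs {
//		applyTx(statedb, tx)   // execute; writes land in the journal
//		statedb.Finalise(true) // journal -> pending sets, no trie writes yet
//	}
//	root := statedb.IntermediateRoot(true) // pending -> hashed tries (receipt root)
//	finalRoot, err := statedb.Commit(true) // tries -> the trie database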
func (s *StateDB) Finalise(deleteEmptyObjects bool) { for addr := range s.journal.dirties { - stateObject, exist := s.stateObjects[addr] + obj, exist := s.stateObjects[addr] if !exist { // ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2 // That tx goes out of gas, and although the notion of 'touched' does not exist there, the @@ -645,13 +670,12 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) { // Thus, we can safely ignore it here continue } - - if stateObject.suicided || (deleteEmptyObjects && stateObject.empty()) { - s.deleteStateObject(stateObject) + if obj.suicided || (deleteEmptyObjects && obj.empty()) { + obj.deleted = true } else { - stateObject.updateRoot(s.db) - s.updateStateObject(stateObject) + obj.finalise() } + s.stateObjectsPending[addr] = struct{}{} s.stateObjectsDirty[addr] = struct{}{} } // Invalidate journal because reverting across transactions is not allowed. @@ -662,8 +686,21 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) { // It is called in between transactions to get the root hash that // goes into transaction receipts. func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { + // Finalise all the dirty storage states and write them into the tries s.Finalise(deleteEmptyObjects) + for addr := range s.stateObjectsPending { + obj := s.stateObjects[addr] + if obj.deleted { + s.deleteStateObject(obj) + } else { + obj.updateRoot(s.db) + s.updateStateObject(obj) + } + } + if len(s.stateObjectsPending) > 0 { + s.stateObjectsPending = make(map[common.Address]struct{}) + } // Track the amount of time wasted on hashing the account trie if metrics.EnabledExpensive { defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now()) @@ -680,46 +717,40 @@ func (self *StateDB) Prepare(thash, bhash common.Hash, ti int) { } func (s *StateDB) clearJournalAndRefund() { - s.journal = newJournal() - s.validRevisions = s.validRevisions[:0] - s.refund = 0 + if len(s.journal.entries) > 0 { + s.journal = newJournal() + s.refund = 0 + } + s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entires } // Commit writes the state to the underlying in-memory trie database. -func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) { - defer s.clearJournalAndRefund() +func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) { + // Finalize any pending changes and merge everything into the tries + s.IntermediateRoot(deleteEmptyObjects) - for addr := range s.journal.dirties { - s.stateObjectsDirty[addr] = struct{}{} - } // Commit objects to the trie, measuring the elapsed time - for addr, stateObject := range s.stateObjects { - _, isDirty := s.stateObjectsDirty[addr] - switch { - case stateObject.suicided || (isDirty && deleteEmptyObjects && stateObject.empty()): - // If the object has been removed, don't bother syncing it - // and just mark it for deletion in the trie. 
- s.deleteStateObject(stateObject) - case isDirty: + for addr := range s.stateObjectsDirty { + if obj := s.stateObjects[addr]; !obj.deleted { // Write any contract code associated with the state object - if stateObject.code != nil && stateObject.dirtyCode { - s.db.TrieDB().InsertBlob(common.BytesToHash(stateObject.CodeHash()), stateObject.code) - stateObject.dirtyCode = false + if obj.code != nil && obj.dirtyCode { + s.db.TrieDB().InsertBlob(common.BytesToHash(obj.CodeHash()), obj.code) + obj.dirtyCode = false } - // Write any storage changes in the state object to its storage trie. - if err := stateObject.CommitTrie(s.db); err != nil { + // Write any storage changes in the state object to its storage trie + if err := obj.CommitTrie(s.db); err != nil { return common.Hash{}, err } - // Update the object in the main account trie. - s.updateStateObject(stateObject) } - delete(s.stateObjectsDirty, addr) + } + if len(s.stateObjectsDirty) > 0 { + s.stateObjectsDirty = make(map[common.Address]struct{}) } // Write the account trie changes, measuing the amount of wasted time if metrics.EnabledExpensive { defer func(start time.Time) { s.AccountCommits += time.Since(start) }(time.Now()) } - root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error { + return s.trie.Commit(func(leaf []byte, parent common.Hash) error { var account Account if err := rlp.DecodeBytes(leaf, &account); err != nil { return nil @@ -733,5 +764,4 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) } return nil }) - return root, err } diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index bf073bc94..bfb081191 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -25,10 +25,11 @@ import ( "math/rand" "reflect" "strings" + "sync" "testing" "testing/quick" - check "gopkg.in/check.v1" + "gopkg.in/check.v1" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" @@ -53,8 +54,13 @@ func TestUpdateLeaks(t *testing.T) { if i%3 == 0 { state.SetCode(addr, []byte{i, i, i, i, i}) } - state.IntermediateRoot(false) } + + root := state.IntermediateRoot(false) + if err := state.Database().TrieDB().Commit(root, false); err != nil { + t.Errorf("can not commit trie %v to persistent database", root.Hex()) + } + // Ensure that no data was leaked into the database it := db.NewIterator() for it.Next() { @@ -98,27 +104,45 @@ func TestIntermediateLeaks(t *testing.T) { } // Commit and cross check the databases. 
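// Note the extra flush the reworked tests need: StateDB.Commit only writes
// into the in-memory trie database, so the data must additionally be pushed
// to the persistent store via Database().TrieDB().Commit(root, false) before
// the key-value iterators below can see it.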
- if _, err := transState.Commit(false); err != nil { + transRoot, err := transState.Commit(false) + if err != nil { t.Fatalf("failed to commit transition state: %v", err) } - if _, err := finalState.Commit(false); err != nil { + if err = transState.Database().TrieDB().Commit(transRoot, false); err != nil { + t.Errorf("can not commit trie %v to persistent database", transRoot.Hex()) + } + + finalRoot, err := finalState.Commit(false) + if err != nil { t.Fatalf("failed to commit final state: %v", err) } + if err = finalState.Database().TrieDB().Commit(finalRoot, false); err != nil { + t.Errorf("can not commit trie %v to persistent database", finalRoot.Hex()) + } + it := finalDb.NewIterator() for it.Next() { - key := it.Key() - if _, err := transDb.Get(key); err != nil { - t.Errorf("entry missing from the transition database: %x -> %x", key, it.Value()) + key, fvalue := it.Key(), it.Value() + tvalue, err := transDb.Get(key) + if err != nil { + t.Errorf("entry missing from the transition database: %x -> %x", key, fvalue) + } + if !bytes.Equal(fvalue, tvalue) { + t.Errorf("the value associate key %x is mismatch,: %x in transition database ,%x in final database", key, tvalue, fvalue) } } it.Release() it = transDb.NewIterator() for it.Next() { - key := it.Key() - if _, err := finalDb.Get(key); err != nil { + key, tvalue := it.Key(), it.Value() + fvalue, err := finalDb.Get(key) + if err != nil { t.Errorf("extra entry in the transition database: %x -> %x", key, it.Value()) } + if !bytes.Equal(fvalue, tvalue) { + t.Errorf("the value associate key %x is mismatch,: %x in transition database ,%x in final database", key, tvalue, fvalue) + } } } @@ -136,32 +160,45 @@ func TestCopy(t *testing.T) { } orig.Finalise(false) - // Copy the state, modify both in-memory + // Copy the state copy := orig.Copy() + // Copy the copy state + ccopy := copy.Copy() + + // modify all in memory for i := byte(0); i < 255; i++ { origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) origObj.AddBalance(big.NewInt(2 * int64(i))) copyObj.AddBalance(big.NewInt(3 * int64(i))) + ccopyObj.AddBalance(big.NewInt(4 * int64(i))) orig.updateStateObject(origObj) copy.updateStateObject(copyObj) + ccopy.updateStateObject(copyObj) } - // Finalise the changes on both concurrently - done := make(chan struct{}) - go func() { - orig.Finalise(true) - close(done) - }() - copy.Finalise(true) - <-done - // Verify that the two states have been updated independently + // Finalise the changes on all concurrently + finalise := func(wg *sync.WaitGroup, db *StateDB) { + defer wg.Done() + db.Finalise(true) + } + + var wg sync.WaitGroup + wg.Add(3) + go finalise(&wg, orig) + go finalise(&wg, copy) + go finalise(&wg, ccopy) + wg.Wait() + + // Verify that the three states have been updated independently for i := byte(0); i < 255; i++ { origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 { t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want) @@ -169,6 +206,9 @@ func TestCopy(t *testing.T) { if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 { t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), 
want) } + if want := big.NewInt(5 * int64(i)); ccopyObj.Balance().Cmp(want) != 0 { + t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want) + } } } @@ -438,14 +478,206 @@ func (s *StateSuite) TestTouchDelete(c *check.C) { // TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy. // See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512 func TestCopyOfCopy(t *testing.T) { - sdb, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) addr := common.HexToAddress("aaaa") - sdb.SetBalance(addr, big.NewInt(42)) + state.SetBalance(addr, big.NewInt(42)) - if got := sdb.Copy().GetBalance(addr).Uint64(); got != 42 { + if got := state.Copy().GetBalance(addr).Uint64(); got != 42 { t.Fatalf("1st copy fail, expected 42, got %v", got) } - if got := sdb.Copy().Copy().GetBalance(addr).Uint64(); got != 42 { + if got := state.Copy().Copy().GetBalance(addr).Uint64(); got != 42 { t.Fatalf("2nd copy fail, expected 42, got %v", got) } } + +// Tests a regression where committing a copy lost some internal meta information, +// leading to corrupted subsequent copies. +// +// See https://github.com/ethereum/go-ethereum/issues/20106. +func TestCopyCommitCopy(t *testing.T) { + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) + + // Create an account and check if the retrieved balance is correct + addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe") + skey := common.HexToHash("aaa") + sval := common.HexToHash("bbb") + + state.SetBalance(addr, big.NewInt(42)) // Change the account trie + state.SetCode(addr, []byte("hello")) // Change an external metadata + state.SetState(addr, skey, sval) // Change the storage trie + + if balance := state.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("initial balance mismatch: have %v, want %v", balance, 42) + } + if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello")) + } + if val := state.GetState(addr, skey); val != sval { + t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := state.GetCommittedState(addr, skey); val != (common.Hash{}) { + t.Fatalf("initial committed storage slot mismatch: have %x, want %x", val, common.Hash{}) + } + // Copy the non-committed state database and check pre/post commit balance + copyOne := state.Copy() + if balance := copyOne.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("first copy pre-commit balance mismatch: have %v, want %v", balance, 42) + } + if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("first copy pre-commit code mismatch: have %x, want %x", code, []byte("hello")) + } + if val := copyOne.GetState(addr, skey); val != sval { + t.Fatalf("first copy pre-commit non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { + t.Fatalf("first copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{}) + } + + copyOne.Commit(false) + if balance := copyOne.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("first copy post-commit balance mismatch: have %v, want %v", balance, 42) + } + if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("first copy post-commit code 
mismatch: have %x, want %x", code, []byte("hello")) + } + if val := copyOne.GetState(addr, skey); val != sval { + t.Fatalf("first copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := copyOne.GetCommittedState(addr, skey); val != sval { + t.Fatalf("first copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) + } + // Copy the copy and check the balance once more + copyTwo := copyOne.Copy() + if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("second copy balance mismatch: have %v, want %v", balance, 42) + } + if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("second copy code mismatch: have %x, want %x", code, []byte("hello")) + } + if val := copyTwo.GetState(addr, skey); val != sval { + t.Fatalf("second copy non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := copyTwo.GetCommittedState(addr, skey); val != sval { + t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) + } +} + +// Tests a regression where committing a copy lost some internal meta information, +// leading to corrupted subsequent copies. +// +// See https://github.com/ethereum/go-ethereum/issues/20106. +func TestCopyCopyCommitCopy(t *testing.T) { + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) + + // Create an account and check if the retrieved balance is correct + addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe") + skey := common.HexToHash("aaa") + sval := common.HexToHash("bbb") + + state.SetBalance(addr, big.NewInt(42)) // Change the account trie + state.SetCode(addr, []byte("hello")) // Change an external metadata + state.SetState(addr, skey, sval) // Change the storage trie + + if balance := state.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("initial balance mismatch: have %v, want %v", balance, 42) + } + if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello")) + } + if val := state.GetState(addr, skey); val != sval { + t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := state.GetCommittedState(addr, skey); val != (common.Hash{}) { + t.Fatalf("initial committed storage slot mismatch: have %x, want %x", val, common.Hash{}) + } + // Copy the non-committed state database and check pre/post commit balance + copyOne := state.Copy() + if balance := copyOne.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("first copy balance mismatch: have %v, want %v", balance, 42) + } + if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("first copy code mismatch: have %x, want %x", code, []byte("hello")) + } + if val := copyOne.GetState(addr, skey); val != sval { + t.Fatalf("first copy non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { + t.Fatalf("first copy committed storage slot mismatch: have %x, want %x", val, common.Hash{}) + } + // Copy the copy and check the balance once more + copyTwo := copyOne.Copy() + if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("second copy pre-commit balance mismatch: have %v, want %v", balance, 42) + } + if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("second copy pre-commit code 
mismatch: have %x, want %x", code, []byte("hello")) + } + if val := copyTwo.GetState(addr, skey); val != sval { + t.Fatalf("second copy pre-commit non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := copyTwo.GetCommittedState(addr, skey); val != (common.Hash{}) { + t.Fatalf("second copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{}) + } + copyTwo.Commit(false) + if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("second copy post-commit balance mismatch: have %v, want %v", balance, 42) + } + if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("second copy post-commit code mismatch: have %x, want %x", code, []byte("hello")) + } + if val := copyTwo.GetState(addr, skey); val != sval { + t.Fatalf("second copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := copyTwo.GetCommittedState(addr, skey); val != sval { + t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) + } + // Copy the copy-copy and check the balance once more + copyThree := copyTwo.Copy() + if balance := copyThree.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { + t.Fatalf("third copy balance mismatch: have %v, want %v", balance, 42) + } + if code := copyThree.GetCode(addr); !bytes.Equal(code, []byte("hello")) { + t.Fatalf("third copy code mismatch: have %x, want %x", code, []byte("hello")) + } + if val := copyThree.GetState(addr, skey); val != sval { + t.Fatalf("third copy non-committed storage slot mismatch: have %x, want %x", val, sval) + } + if val := copyThree.GetCommittedState(addr, skey); val != sval { + t.Fatalf("third copy committed storage slot mismatch: have %x, want %x", val, sval) + } +} + +// TestDeleteCreateRevert tests a weird state transition corner case that we hit +// while changing the internals of statedb. The workflow is that a contract is +// self destructed, then in a followup transaction (but same block) it's created +// again and the transaction reverted. +// +// The original statedb implementation flushed dirty objects to the tries after +// each transaction, so this works ok. The rework accumulated writes in memory +// first, but the journal wiped the entire state object on create-revert. 
+func TestDeleteCreateRevert(t *testing.T) { + // Create an initial state with a single contract + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase())) + + addr := toAddr([]byte("so")) + state.SetBalance(addr, big.NewInt(1)) + + root, _ := state.Commit(false) + state.Reset(root) + + // Simulate self-destructing in one transaction, then create-reverting in another + state.Suicide(addr) + state.Finalise(true) + + id := state.Snapshot() + state.SetBalance(addr, big.NewInt(2)) + state.RevertToSnapshot(id) + + // Commit the entire state and make sure we don't crash and have the correct state + root, _ = state.Commit(true) + state.Reset(root) + + if state.getStateObject(addr) != nil { + t.Fatalf("self-destructed contract came alive") + } +} diff --git a/core/state/sync_test.go b/core/state/sync_test.go index de098dce0..f4a221bd9 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -136,7 +136,7 @@ func TestEmptyStateSync(t *testing.T) { func TestIterativeStateSyncIndividual(t *testing.T) { testIterativeStateSync(t, 1) } func TestIterativeStateSyncBatched(t *testing.T) { testIterativeStateSync(t, 100) } -func testIterativeStateSync(t *testing.T, batch int) { +func testIterativeStateSync(t *testing.T, count int) { // Create a random state to copy srcDb, srcRoot, srcAccounts := makeTestState() @@ -144,7 +144,7 @@ func testIterativeStateSync(t *testing.T, batch int) { dstDb := rawdb.NewMemoryDatabase() sched := NewStateSync(srcRoot, dstDb, trie.NewSyncBloom(1, dstDb)) - queue := append([]common.Hash{}, sched.Missing(batch)...) + queue := append([]common.Hash{}, sched.Missing(count)...) for len(queue) > 0 { results := make([]trie.SyncResult, len(queue)) for i, hash := range queue { @@ -157,10 +157,12 @@ func testIterativeStateSync(t *testing.T, batch int) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := dstDb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } - queue = append(queue[:0], sched.Missing(batch)...) + batch.Write() + queue = append(queue[:0], sched.Missing(count)...) } // Cross check that the two states are in sync checkStateAccounts(t, dstDb, srcRoot, srcAccounts) @@ -190,9 +192,11 @@ func TestIterativeDelayedStateSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := dstDb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } + batch.Write() queue = append(queue[len(results):], sched.Missing(0)...) 
} // Cross check that the two states are in sync @@ -205,7 +209,7 @@ func TestIterativeDelayedStateSync(t *testing.T) { func TestIterativeRandomStateSyncIndividual(t *testing.T) { testIterativeRandomStateSync(t, 1) } func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomStateSync(t, 100) } -func testIterativeRandomStateSync(t *testing.T, batch int) { +func testIterativeRandomStateSync(t *testing.T, count int) { // Create a random state to copy srcDb, srcRoot, srcAccounts := makeTestState() @@ -214,7 +218,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) { sched := NewStateSync(srcRoot, dstDb, trie.NewSyncBloom(1, dstDb)) queue := make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(batch) { + for _, hash := range sched.Missing(count) { queue[hash] = struct{}{} } for len(queue) > 0 { @@ -231,11 +235,13 @@ func testIterativeRandomStateSync(t *testing.T, batch int) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := dstDb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } + batch.Write() queue = make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(batch) { + for _, hash := range sched.Missing(count) { queue[hash] = struct{}{} } } @@ -277,9 +283,11 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := dstDb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } + batch.Write() for _, hash := range sched.Missing(0) { queue[hash] = struct{}{} } @@ -316,9 +324,11 @@ func TestIncompleteStateSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := dstDb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } + batch.Write() for _, result := range results { added = append(added, result.Hash) } diff --git a/core/state_processor.go b/core/state_processor.go index bed6a0730..cfe17d587 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -68,7 +68,7 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg // Iterate over and process the individual transactions for i, tx := range block.Transactions() { statedb.Prepare(tx.Hash(), block.Hash(), i) - receipt, _, err := ApplyTransaction(p.config, p.bc, nil, gp, statedb, header, tx, usedGas, cfg) + receipt, err := ApplyTransaction(p.config, p.bc, nil, gp, statedb, header, tx, usedGas, cfg) if err != nil { return nil, nil, 0, err } @@ -85,10 +85,10 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg // and uses the input parameters for its environment. It returns the receipt // for the transaction, gas used and an error if the transaction failed, // indicating the block was invalid. 
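// Call sites shrink accordingly; a minimal sketch of the new shape (the
// surrounding variables are assumed to exist):
//
//	receipt, err := ApplyTransaction(config, bc, &coinbase, gp, statedb, header, tx, &usedGas, vm.Config{})
//	if err != nil {
//		return err // the enclosing block is invalid
//	}
//	gasUsed := receipt.GasUsed // gas is now read off the receipt, not returned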
-func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *common.Address, gp *GasPool, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64, cfg vm.Config) (*types.Receipt, uint64, error) { +func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *common.Address, gp *GasPool, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64, cfg vm.Config) (*types.Receipt, error) { msg, err := tx.AsMessage(types.MakeSigner(config, header.Number)) if err != nil { - return nil, 0, err + return nil, err } // Create a new context to be used in the EVM environment context := NewEVMContext(msg, header, bc, author) @@ -98,7 +98,7 @@ func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *commo // Apply the transaction to the current state (included in the env) _, gas, failed, err := ApplyMessage(vmenv, msg, gp) if err != nil { - return nil, 0, err + return nil, err } // Update the state with pending changes var root []byte @@ -125,5 +125,5 @@ func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *commo receipt.BlockNumber = header.Number receipt.TransactionIndex = uint(statedb.TxIndex()) - return receipt, gas, err + return receipt, err } diff --git a/core/tx_pool.go b/core/tx_pool.go index a49b42261..f7032dbd1 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -97,13 +97,14 @@ var ( queuedNofundsMeter = metrics.NewRegisteredMeter("txpool/queued/nofunds", nil) // Dropped due to out-of-funds // General tx metrics - validMeter = metrics.NewRegisteredMeter("txpool/valid", nil) + knownTxMeter = metrics.NewRegisteredMeter("txpool/known", nil) + validTxMeter = metrics.NewRegisteredMeter("txpool/valid", nil) invalidTxMeter = metrics.NewRegisteredMeter("txpool/invalid", nil) underpricedTxMeter = metrics.NewRegisteredMeter("txpool/underpriced", nil) - pendingCounter = metrics.NewRegisteredCounter("txpool/pending", nil) - queuedCounter = metrics.NewRegisteredCounter("txpool/queued", nil) - localCounter = metrics.NewRegisteredCounter("txpool/local", nil) + pendingGauge = metrics.NewRegisteredGauge("txpool/pending", nil) + queuedGauge = metrics.NewRegisteredGauge("txpool/queued", nil) + localGauge = metrics.NewRegisteredGauge("txpool/local", nil) ) // TxStatus is the current status of a transaction as seen by the pool. 
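The metrics hunk above swaps the pool-size counters for gauges and splits the old validMeter into knownTxMeter and validTxMeter. The distinction matters: a meter accumulates event counts and rates (how many duplicate announcements were seen), while a gauge tracks a level that can move both ways (how many transactions are currently pending). A minimal sketch of the two types against the go-ethereum metrics package, with made-up values purely for illustration, not code from this change:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/metrics"
)

func main() {
	metrics.Enabled = true // otherwise the constructors return nil no-op metrics

	// A meter counts discrete events, e.g. re-announced (already known) txs.
	known := metrics.NewRegisteredMeter("txpool/known", nil)
	// A gauge tracks a current level, e.g. the size of the pending set.
	pending := metrics.NewRegisteredGauge("txpool/pending", nil)

	known.Mark(1)  // one already-known transaction observed
	pending.Inc(3) // three transactions promoted to pending
	pending.Dec(1) // one mined or dropped; the level goes back down

	fmt.Println(known.Count(), pending.Value()) // 1 2
}
```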
@@ -564,16 +565,15 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e hash := tx.Hash() if pool.all.Get(hash) != nil { log.Trace("Discarding already known transaction", "hash", hash) + knownTxMeter.Mark(1) return false, fmt.Errorf("known transaction: %x", hash) } - // If the transaction fails basic validation, discard it if err := pool.validateTx(tx, local); err != nil { log.Trace("Discarding invalid transaction", "hash", hash, "err", err) invalidTxMeter.Mark(1) return false, err } - // If the transaction pool is full, discard underpriced transactions if uint64(pool.all.Count()) >= pool.config.GlobalSlots+pool.config.GlobalQueue { // If the new transaction is underpriced, don't accept it @@ -590,7 +590,6 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e pool.removeTx(tx.Hash(), false) } } - // Try to replace an existing transaction in the pending pool from, _ := types.Sender(pool.signer, tx) // already validated if list := pool.pending[from]; list != nil && list.Overlaps(tx) { @@ -613,13 +612,11 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e log.Trace("Pooled new executable transaction", "hash", hash, "from", from, "to", tx.To()) return old != nil, nil } - // New transaction isn't replacing a pending one, push into queue replaced, err = pool.enqueueTx(hash, tx) if err != nil { return false, err } - // Mark local addresses and journal local transactions if local { if !pool.locals.contains(from) { @@ -628,7 +625,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e } } if local || pool.locals.contains(from) { - localCounter.Inc(1) + localGauge.Inc(1) } pool.journalTx(from, tx) @@ -658,7 +655,7 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er queuedReplaceMeter.Mark(1) } else { // Nothing was replaced, bump the queued counter - queuedCounter.Inc(1) + queuedGauge.Inc(1) } if pool.all.Get(hash) == nil { pool.all.Add(tx) @@ -707,7 +704,7 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T pendingReplaceMeter.Mark(1) } else { // Nothing was replaced, bump the pending counter - pendingCounter.Inc(1) + pendingGauge.Inc(1) } // Failsafe to work around direct pending inserts (tests) if pool.all.Get(hash) == nil { @@ -768,15 +765,41 @@ func (pool *TxPool) AddRemote(tx *types.Transaction) error { // addTxs attempts to queue a batch of transactions if they are valid. 
func (pool *TxPool) addTxs(txs []*types.Transaction, local, sync bool) []error { + // Filter out known ones without obtaining the pool lock or recovering signatures + var ( + errs = make([]error, len(txs)) + news = make([]*types.Transaction, 0, len(txs)) + ) + for i, tx := range txs { + // If the transaction is known, pre-set the error slot + if pool.all.Get(tx.Hash()) != nil { + errs[i] = fmt.Errorf("known transaction: %x", tx.Hash()) + knownTxMeter.Mark(1) + continue + } + // Accumulate all unknown transactions for deeper processing + news = append(news, tx) + } + if len(news) == 0 { + return errs + } // Cache senders in transactions before obtaining lock (pool.signer is immutable) - for _, tx := range txs { + for _, tx := range news { types.Sender(pool.signer, tx) } - + // Process all the new transactions and merge any errors into the original slice pool.mu.Lock() - errs, dirtyAddrs := pool.addTxsLocked(txs, local) + newErrs, dirtyAddrs := pool.addTxsLocked(news, local) pool.mu.Unlock() + var nilSlot = 0 + for _, err := range newErrs { + for errs[nilSlot] != nil { + nilSlot++ + } + errs[nilSlot] = err + } + // Reorg the pool internals if needed and return done := pool.requestPromoteExecutables(dirtyAddrs) if sync { <-done @@ -796,26 +819,29 @@ func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) ([]error, dirty.addTx(tx) } } - validMeter.Mark(int64(len(dirty.accounts))) + validTxMeter.Mark(int64(len(dirty.accounts))) return errs, dirty } // Status returns the status (unknown/pending/queued) of a batch of transactions // identified by their hashes. func (pool *TxPool) Status(hashes []common.Hash) []TxStatus { - pool.mu.RLock() - defer pool.mu.RUnlock() - status := make([]TxStatus, len(hashes)) for i, hash := range hashes { - if tx := pool.all.Get(hash); tx != nil { - from, _ := types.Sender(pool.signer, tx) // already validated - if pool.pending[from] != nil && pool.pending[from].txs.items[tx.Nonce()] != nil { - status[i] = TxStatusPending - } else { - status[i] = TxStatusQueued - } + tx := pool.Get(hash) + if tx == nil { + continue } + from, _ := types.Sender(pool.signer, tx) // already validated + pool.mu.RLock() + if txList := pool.pending[from]; txList != nil && txList.txs.items[tx.Nonce()] != nil { + status[i] = TxStatusPending + } else if txList := pool.queue[from]; txList != nil && txList.txs.items[tx.Nonce()] != nil { + status[i] = TxStatusQueued + } + // implicit else: the tx may have been included in a block between + // checking pool.Get and obtaining the lock.
In that case, TxStatusUnknown is correct + pool.mu.RUnlock() } return status } @@ -841,7 +867,7 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { pool.priced.Removed(1) } if pool.locals.contains(addr) { - localCounter.Dec(1) + localGauge.Dec(1) } // Remove the transaction from the pending lists and reset the account nonce if pending := pool.pending[addr]; pending != nil { @@ -858,7 +884,7 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { // Update the account nonce if needed pool.pendingNonces.setIfLower(addr, tx.Nonce()) // Reduce the pending counter - pendingCounter.Dec(int64(1 + len(invalids))) + pendingGauge.Dec(int64(1 + len(invalids))) return } } @@ -866,7 +892,7 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { if future := pool.queue[addr]; future != nil { if removed, _ := future.Remove(tx); removed { // Reduce the queued counter - queuedCounter.Dec(1) + queuedGauge.Dec(1) } if future.Empty() { delete(pool.queue, addr) @@ -1164,7 +1190,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans promoted = append(promoted, tx) } } - queuedCounter.Dec(int64(len(readies))) + queuedGauge.Dec(int64(len(readies))) // Drop all transactions over the allowed limit var caps types.Transactions @@ -1179,9 +1205,9 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans } // Mark all the items dropped as removed pool.priced.Removed(len(forwards) + len(drops) + len(caps)) - queuedCounter.Dec(int64(len(forwards) + len(drops) + len(caps))) + queuedGauge.Dec(int64(len(forwards) + len(drops) + len(caps))) if pool.locals.contains(addr) { - localCounter.Dec(int64(len(forwards) + len(drops) + len(caps))) + localGauge.Dec(int64(len(forwards) + len(drops) + len(caps))) } // Delete the entire queue entry if it became empty. if list.Empty() { @@ -1240,9 +1266,9 @@ func (pool *TxPool) truncatePending() { log.Trace("Removed fairness-exceeding pending transaction", "hash", hash) } pool.priced.Removed(len(caps)) - pendingCounter.Dec(int64(len(caps))) + pendingGauge.Dec(int64(len(caps))) if pool.locals.contains(offenders[i]) { - localCounter.Dec(int64(len(caps))) + localGauge.Dec(int64(len(caps))) } pending-- } @@ -1267,9 +1293,9 @@ func (pool *TxPool) truncatePending() { log.Trace("Removed fairness-exceeding pending transaction", "hash", hash) } pool.priced.Removed(len(caps)) - pendingCounter.Dec(int64(len(caps))) + pendingGauge.Dec(int64(len(caps))) if pool.locals.contains(addr) { - localCounter.Dec(int64(len(caps))) + localGauge.Dec(int64(len(caps))) } pending-- } @@ -1353,9 +1379,9 @@ func (pool *TxPool) demoteUnexecutables() { log.Trace("Demoting pending transaction", "hash", hash) pool.enqueueTx(hash, tx) } - pendingCounter.Dec(int64(len(olds) + len(drops) + len(invalids))) + pendingGauge.Dec(int64(len(olds) + len(drops) + len(invalids))) if pool.locals.contains(addr) { - localCounter.Dec(int64(len(olds) + len(drops) + len(invalids))) + localGauge.Dec(int64(len(olds) + len(drops) + len(invalids))) } // If there's a gap in front, alert (should never happen) and postpone all transactions if list.Len() > 0 && list.txs.Get(nonce) == nil { @@ -1365,7 +1391,7 @@ func (pool *TxPool) demoteUnexecutables() { log.Error("Demoting invalidated transaction", "hash", hash) pool.enqueueTx(hash, tx) } - pendingCounter.Dec(int64(len(gapped))) + pendingGauge.Dec(int64(len(gapped))) } // Delete the entire queue entry if it became empty. 
if list.Empty() { diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 388668ed8..0f1e7ac8f 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -1438,6 +1438,71 @@ func TestTransactionPoolStableUnderpricing(t *testing.T) { } } +// Tests that the pool rejects duplicate transactions. +func TestTransactionDeduplication(t *testing.T) { + t.Parallel() + + // Create the pool to test transaction deduplication with + statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) + blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} + + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) + defer pool.Stop() + + // Create a test account to add transactions with + key, _ := crypto.GenerateKey() + pool.currentState.AddBalance(crypto.PubkeyToAddress(key.PublicKey), big.NewInt(1000000000)) + + // Create a batch of transactions and add a few of them + txs := make([]*types.Transaction, 16) + for i := 0; i < len(txs); i++ { + txs[i] = pricedTransaction(uint64(i), 100000, big.NewInt(1), key) + } + var firsts []*types.Transaction + for i := 0; i < len(txs); i += 2 { + firsts = append(firsts, txs[i]) + } + errs := pool.AddRemotesSync(firsts) + if len(errs) != len(firsts) { + t.Fatalf("first add mismatching result count: have %d, want %d", len(errs), len(firsts)) + } + for i, err := range errs { + if err != nil { + t.Errorf("add %d failed: %v", i, err) + } + } + pending, queued := pool.Stats() + if pending != 1 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 1) + } + if queued != len(txs)/2-1 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, len(txs)/2-1) + } + // Try to add all of them now and ensure the previous ones are rejected as already known + errs = pool.AddRemotesSync(txs) + if len(errs) != len(txs) { + t.Fatalf("all add mismatching result count: have %d, want %d", len(errs), len(txs)) + } + for i, err := range errs { + if i%2 == 0 && err == nil { + t.Errorf("add %d succeeded, should have failed as known", i) + } + if i%2 == 1 && err != nil { + t.Errorf("add %d failed: %v", i, err) + } + } + pending, queued = pool.Stats() + if pending != len(txs) { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, len(txs)) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } +} + // Tests that the pool rejects replacement transactions that don't meet the minimum // price bump required.
func TestTransactionReplacement(t *testing.T) { diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 875054f89..9b0ba09ed 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -106,8 +106,13 @@ func (c *ecrecover) Run(input []byte) ([]byte, error) { if !allZero(input[32:63]) || !crypto.ValidateSignatureValues(v, r, s, false) { return nil, nil } + // We must make sure not to modify the 'input', so placing the 'v' along with + // the signature needs to be done on a new allocation + sig := make([]byte, 65) + copy(sig, input[64:128]) + sig[64] = v // v needs to be at the end for libsecp256k1 - pubKey, err := crypto.Ecrecover(input[:32], append(input[64:128], v)) + pubKey, err := crypto.Ecrecover(input[:32], sig) // make sure the public key is a valid one if err != nil { return nil, nil diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index ae95b4462..b4a0c07dc 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -17,6 +17,7 @@ package vm import ( + "bytes" "fmt" "math/big" "reflect" @@ -409,6 +410,11 @@ func testPrecompiled(addr string, test precompiledTest, t *testing.T) { } else if common.Bytes2Hex(res) != test.expected { t.Errorf("Expected %v, got %v", test.expected, common.Bytes2Hex(res)) } + // Verify that the precompile did not touch the input buffer + exp := common.Hex2Bytes(test.input) + if !bytes.Equal(in, exp) { + t.Errorf("Precompiled %v modified input data", addr) + } }) } @@ -423,6 +429,11 @@ func testPrecompiledFailure(addr string, test precompiledFailureTest, t *testing if !reflect.DeepEqual(err, test.expectedError) { t.Errorf("Expected error [%v], got [%v]", test.expectedError, err) } + // Verify that the precompile did not touch the input buffer + exp := common.Hex2Bytes(test.input) + if !bytes.Equal(in, exp) { + t.Errorf("Precompiled %v modified input data", addr) + } }) } @@ -574,3 +585,55 @@ func TestPrecompileBlake2FMalformedInput(t *testing.T) { testPrecompiledFailure("09", test, t) } } + +// EcRecover test vectors +var ecRecoverTests = []precompiledTest{ + { + input: "a8b53bdf3306a35a7103ab5504a0c9b492295564b6202b1942a84ef300107281" + + "000000000000000000000000000000000000000000000000000000000000001b" + + "3078356531653033663533636531386237373263636230303933666637316633" + + "6635336635633735623734646362333161383561613862383839326234653862" + + "1122334455667788991011121314151617181920212223242526272829303132", + expected: "", + name: "CallEcrecoverUnrecoverableKey", + }, + { + input: "18c547e4f7b0f325ad1e56f57e26c745b09a3e503d86e00e5255ff7f715d3d1c" + + "000000000000000000000000000000000000000000000000000000000000001c" + + "73b1693892219d736caba55bdb67216e485557ea6b6af75f37096c9aa6a5a75f" + + "eeb940b1d03b21e36b0e47e79769f095fe2ab855bd91e3a38756b7d75a9c4549", + expected: "000000000000000000000000a94f5374fce5edbc8e2a8697c15331677e6ebf0b", + name: "ValidKey", + }, + { + input: "18c547e4f7b0f325ad1e56f57e26c745b09a3e503d86e00e5255ff7f715d3d1c" + + "100000000000000000000000000000000000000000000000000000000000001c" + + "73b1693892219d736caba55bdb67216e485557ea6b6af75f37096c9aa6a5a75f" + + "eeb940b1d03b21e36b0e47e79769f095fe2ab855bd91e3a38756b7d75a9c4549", + expected: "", + name: "InvalidHighV-bits-1", + }, + { + input: "18c547e4f7b0f325ad1e56f57e26c745b09a3e503d86e00e5255ff7f715d3d1c" + + "000000000000000000000000000000000000001000000000000000000000001c" + + "73b1693892219d736caba55bdb67216e485557ea6b6af75f37096c9aa6a5a75f" + + "eeb940b1d03b21e36b0e47e79769f095fe2ab855bd91e3a38756b7d75a9c4549", + expected: "", + 
name: "InvalidHighV-bits-2", + }, + { + input: "18c547e4f7b0f325ad1e56f57e26c745b09a3e503d86e00e5255ff7f715d3d1c" + + "000000000000000000000000000000000000001000000000000000000000011c" + + "73b1693892219d736caba55bdb67216e485557ea6b6af75f37096c9aa6a5a75f" + + "eeb940b1d03b21e36b0e47e79769f095fe2ab855bd91e3a38756b7d75a9c4549", + expected: "", + name: "InvalidHighV-bits-3", + }, +} + +func TestPrecompiledEcrecover(t *testing.T) { + for _, test := range ecRecoverTests { + testPrecompiled("01", test, t) + } + +} diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 7b6909c92..d65664b67 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -384,7 +384,7 @@ func opSAR(pc *uint64, interpreter *EVMInterpreter, contract *Contract, memory * func opSha3(pc *uint64, interpreter *EVMInterpreter, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) { offset, size := stack.pop(), stack.pop() - data := memory.Get(offset.Int64(), size.Int64()) + data := memory.GetPtr(offset.Int64(), size.Int64()) if interpreter.hasher == nil { interpreter.hasher = sha3.NewLegacyKeccak256().(keccakState) @@ -602,11 +602,9 @@ func opPop(pc *uint64, interpreter *EVMInterpreter, contract *Contract, memory * } func opMload(pc *uint64, interpreter *EVMInterpreter, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) { - offset := stack.pop() - val := interpreter.intPool.get().SetBytes(memory.Get(offset.Int64(), 32)) - stack.push(val) - - interpreter.intPool.put(offset) + v := stack.peek() + offset := v.Int64() + v.SetBytes(memory.GetPtr(offset, 32)) return nil, nil } @@ -691,7 +689,7 @@ func opCreate(pc *uint64, interpreter *EVMInterpreter, contract *Contract, memor var ( value = stack.pop() offset, size = stack.pop(), stack.pop() - input = memory.Get(offset.Int64(), size.Int64()) + input = memory.GetCopy(offset.Int64(), size.Int64()) gas = contract.Gas ) if interpreter.evm.chainRules.IsEIP150 { @@ -725,7 +723,7 @@ func opCreate2(pc *uint64, interpreter *EVMInterpreter, contract *Contract, memo endowment = stack.pop() offset, size = stack.pop(), stack.pop() salt = stack.pop() - input = memory.Get(offset.Int64(), size.Int64()) + input = memory.GetCopy(offset.Int64(), size.Int64()) gas = contract.Gas ) @@ -757,7 +755,7 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, contract *Contract, memory toAddr := common.BigToAddress(addr) value = math.U256(value) // Get the arguments from the memory. - args := memory.Get(inOffset.Int64(), inSize.Int64()) + args := memory.GetPtr(inOffset.Int64(), inSize.Int64()) if value.Sign() != 0 { gas += params.CallStipend @@ -786,7 +784,7 @@ func opCallCode(pc *uint64, interpreter *EVMInterpreter, contract *Contract, mem toAddr := common.BigToAddress(addr) value = math.U256(value) // Get arguments from the memory. - args := memory.Get(inOffset.Int64(), inSize.Int64()) + args := memory.GetPtr(inOffset.Int64(), inSize.Int64()) if value.Sign() != 0 { gas += params.CallStipend @@ -814,7 +812,7 @@ func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, contract *Contract, addr, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.BigToAddress(addr) // Get arguments from the memory. 
- args := memory.Get(inOffset.Int64(), inSize.Int64()) + args := memory.GetPtr(inOffset.Int64(), inSize.Int64()) ret, returnGas, err := interpreter.evm.DelegateCall(contract, toAddr, args, gas) if err != nil { @@ -839,7 +837,7 @@ func opStaticCall(pc *uint64, interpreter *EVMInterpreter, contract *Contract, m addr, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.BigToAddress(addr) // Get arguments from the memory. - args := memory.Get(inOffset.Int64(), inSize.Int64()) + args := memory.GetPtr(inOffset.Int64(), inSize.Int64()) ret, returnGas, err := interpreter.evm.StaticCall(contract, toAddr, args, gas) if err != nil { @@ -895,7 +893,7 @@ func makeLog(size int) executionFunc { topics[i] = common.BigToHash(stack.pop()) } - d := memory.Get(mStart.Int64(), mSize.Int64()) + d := memory.GetCopy(mStart.Int64(), mSize.Int64()) interpreter.evm.StateDB.AddLog(&types.Log{ Address: contract.Address(), Topics: topics, diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index 50d0a9dda..b12df3905 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -509,12 +509,12 @@ func TestOpMstore(t *testing.T) { v := "abcdef00000000000000abba000000000deaf000000c0de00100000000133700" stack.pushN(new(big.Int).SetBytes(common.Hex2Bytes(v)), big.NewInt(0)) opMstore(&pc, evmInterpreter, nil, mem, stack) - if got := common.Bytes2Hex(mem.Get(0, 32)); got != v { + if got := common.Bytes2Hex(mem.GetCopy(0, 32)); got != v { t.Fatalf("Mstore fail, got %v, expected %v", got, v) } stack.pushN(big.NewInt(0x1), big.NewInt(0)) opMstore(&pc, evmInterpreter, nil, mem, stack) - if common.Bytes2Hex(mem.Get(0, 32)) != "0000000000000000000000000000000000000000000000000000000000000001" { + if common.Bytes2Hex(mem.GetCopy(0, 32)) != "0000000000000000000000000000000000000000000000000000000000000001" { t.Fatalf("Mstore failed to overwrite previous value") } poolOfIntPools.put(evmInterpreter.intPool) diff --git a/core/vm/memory.go b/core/vm/memory.go index 7e6f0eb94..496a4024b 100644 --- a/core/vm/memory.go +++ b/core/vm/memory.go @@ -70,7 +70,7 @@ func (m *Memory) Resize(size uint64) { } // Get returns offset + size as a new slice -func (m *Memory) Get(offset, size int64) (cpy []byte) { +func (m *Memory) GetCopy(offset, size int64) (cpy []byte) { if size == 0 { return nil } diff --git a/crypto/ecies/ecies_test.go b/crypto/ecies/ecies_test.go index 2836b8126..2def505d0 100644 --- a/crypto/ecies/ecies_test.go +++ b/crypto/ecies/ecies_test.go @@ -35,7 +35,6 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" - "flag" "fmt" "math/big" "testing" @@ -43,14 +42,6 @@ import ( "github.com/ethereum/go-ethereum/crypto" ) -var dumpEnc bool - -func init() { - flDump := flag.Bool("dump", false, "write encrypted test message to file") - flag.Parse() - dumpEnc = *flDump -} - // Ensure the KDF generates appropriately sized keys. func TestKDF(t *testing.T) { msg := []byte("Hello, world") diff --git a/dashboard/README.md b/dashboard/README.md index 641c5f44b..67b65bda3 100644 --- a/dashboard/README.md +++ b/dashboard/README.md @@ -48,8 +48,8 @@ For more IDE support install the `linter-eslint` package too, which finds the `. 
[ESLint]: https://eslint.org/ [Airbnb]: https://github.com/airbnb/javascript/tree/master/react [Webpack]: https://webpack.github.io/ -[WA]: http://webpack.github.io/analyse/ -[WV]: http://chrisbateman.github.io/webpack-visualizer/ +[WA]: https://webpack.github.io/analyse/ +[WV]: https://chrisbateman.github.io/webpack-visualizer/ [Node.js]: https://nodejs.org/en/ [Flow]: https://flow.org/ [Atom]: https://atom.io/ diff --git a/dashboard/dashboard.go b/dashboard/dashboard.go index d69a750f1..b576293bc 100644 --- a/dashboard/dashboard.go +++ b/dashboard/dashboard.go @@ -125,7 +125,7 @@ func (db *Dashboard) APIs() []rpc.API { return nil } // Start starts the data collection thread and the listening server of the dashboard. // Implements the node.Service interface. func (db *Dashboard) Start(server *p2p.Server) error { - log.Info("Starting dashboard") + log.Info("Starting dashboard", "url", fmt.Sprintf("http://%s:%d", db.config.Host, db.config.Port)) db.wg.Add(3) go db.collectSystemData() diff --git a/eth/api_backend.go b/eth/api_backend.go index 69904a70f..4b74ccff5 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -72,6 +72,23 @@ func (b *EthAPIBackend) HeaderByNumber(ctx context.Context, number rpc.BlockNumb return b.eth.blockchain.GetHeaderByNumber(uint64(number)), nil } +func (b *EthAPIBackend) HeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Header, error) { + if blockNr, ok := blockNrOrHash.Number(); ok { + return b.HeaderByNumber(ctx, blockNr) + } + if hash, ok := blockNrOrHash.Hash(); ok { + header := b.eth.blockchain.GetHeaderByHash(hash) + if header == nil { + return nil, errors.New("header for hash not found") + } + if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { + return nil, errors.New("hash is not currently canonical") + } + return header, nil + } + return nil, errors.New("invalid arguments; neither block nor hash specified") +} + func (b *EthAPIBackend) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) { return b.eth.blockchain.GetHeaderByHash(hash), nil } @@ -93,6 +110,27 @@ func (b *EthAPIBackend) BlockByHash(ctx context.Context, hash common.Hash) (*typ return b.eth.blockchain.GetBlockByHash(hash), nil } +func (b *EthAPIBackend) BlockByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Block, error) { + if blockNr, ok := blockNrOrHash.Number(); ok { + return b.BlockByNumber(ctx, blockNr) + } + if hash, ok := blockNrOrHash.Hash(); ok { + header := b.eth.blockchain.GetHeaderByHash(hash) + if header == nil { + return nil, errors.New("header for hash not found") + } + if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { + return nil, errors.New("hash is not currently canonical") + } + block := b.eth.blockchain.GetBlock(hash, header.Number.Uint64()) + if block == nil { + return nil, errors.New("header found, but block body is missing") + } + return block, nil + } + return nil, errors.New("invalid arguments; neither block nor hash specified") +} + func (b *EthAPIBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*state.StateDB, *types.Header, error) { // Pending state is only known by the miner if number == rpc.PendingBlockNumber { @@ -111,6 +149,27 @@ func (b *EthAPIBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.B return stateDb, header, err } +func (b *EthAPIBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash 
rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) { + if blockNr, ok := blockNrOrHash.Number(); ok { + return b.StateAndHeaderByNumber(ctx, blockNr) + } + if hash, ok := blockNrOrHash.Hash(); ok { + header, err := b.HeaderByHash(ctx, hash) + if err != nil { + return nil, nil, err + } + if header == nil { + return nil, nil, errors.New("header for hash not found") + } + if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { + return nil, nil, errors.New("hash is not currently canonical") + } + stateDb, err := b.eth.BlockChain().StateAt(header.Root) + return stateDb, header, err + } + return nil, nil, errors.New("invalid arguments; neither block nor hash specified") +} + func (b *EthAPIBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) { return b.eth.blockchain.GetReceiptsByHash(hash), nil } diff --git a/eth/backend.go b/eth/backend.go index 2711e8642..aa1db76ef 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -69,8 +69,6 @@ type Ethereum struct { // Channel for shutting down the service shutdownChan chan bool - server *p2p.Server - // Handlers txPool *core.TxPool blockchain *core.BlockChain diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index edd0eb4d9..f8982f696 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -1574,13 +1574,14 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error { func (d *Downloader) processFastSyncContent(latest *types.Header) error { // Start syncing state of the reported head block. This should get us most of // the state of the pivot block. - stateSync := d.syncState(latest.Root) - defer stateSync.Cancel() - go func() { - if err := stateSync.Wait(); err != nil && err != errCancelStateFetch && err != errCanceled { + sync := d.syncState(latest.Root) + defer sync.Cancel() + closeOnErr := func(s *stateSync) { + if err := s.Wait(); err != nil && err != errCancelStateFetch && err != errCanceled { d.queue.Close() // wake up Results } - }() + } + go closeOnErr(sync) // Figure out the ideal pivot block. Note, that this goalpost may move if the // sync takes long enough for the chain head to move significantly. 
pivot := uint64(0) @@ -1600,12 +1601,12 @@ func (d *Downloader) processFastSyncContent(latest *types.Header) error { if len(results) == 0 { // If pivot sync is done, stop if oldPivot == nil { - return stateSync.Cancel() + return sync.Cancel() } // If sync failed, stop select { case <-d.cancelCh: - stateSync.Cancel() + sync.Cancel() return errCanceled default: } @@ -1625,28 +1626,24 @@ func (d *Downloader) processFastSyncContent(latest *types.Header) error { } } P, beforeP, afterP := splitAroundPivot(pivot, results) - if err := d.commitFastSyncData(beforeP, stateSync); err != nil { + if err := d.commitFastSyncData(beforeP, sync); err != nil { return err } if P != nil { // If new pivot block found, cancel old state retrieval and restart if oldPivot != P { - stateSync.Cancel() + sync.Cancel() - stateSync = d.syncState(P.Header.Root) - defer stateSync.Cancel() - go func() { - if err := stateSync.Wait(); err != nil && err != errCancelStateFetch && err != errCanceled { - d.queue.Close() // wake up Results - } - }() + sync = d.syncState(P.Header.Root) + defer sync.Cancel() + go closeOnErr(sync) oldPivot = P } // Wait for completion, occasionally checking for pivot staleness select { - case <-stateSync.done: - if stateSync.err != nil { - return stateSync.err + case <-sync.done: + if sync.err != nil { + return sync.err } if err := d.commitPivotBlock(P); err != nil { return err diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index b422557d5..f875b3a84 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -347,7 +347,7 @@ func (s *stateSync) commit(force bool) error { } start := time.Now() b := s.d.stateDB.NewBatch() - if written, err := s.sched.Commit(b); written == 0 || err != nil { + if err := s.sched.Commit(b); err != nil { return err } if err := b.Write(); err != nil { diff --git a/eth/handler.go b/eth/handler.go index 4ce2d1c82..d2355a876 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/fetcher" @@ -63,7 +64,8 @@ func errResp(code errCode, format string, v ...interface{}) error { } type ProtocolManager struct { - networkID uint64 + networkID uint64 + forkFilter forkid.Filter // Fork ID filter, constant across the lifetime of the node fastSync uint32 // Flag whether fast sync is enabled (gets disabled if we already have blocks) acceptTxs uint32 // Flag whether we're considered synchronised (enables transaction processing) @@ -103,6 +105,7 @@ func NewProtocolManager(config *params.ChainConfig, checkpoint *params.TrustedCh // Create the protocol manager with the base fields manager := &ProtocolManager{ networkID: networkID, + forkFilter: forkid.NewFilter(blockchain), eventMux: mux, txpool: txpool, blockchain: blockchain, @@ -304,7 +307,7 @@ func (pm *ProtocolManager) handle(p *peer) error { number = head.Number.Uint64() td = pm.blockchain.GetTd(hash, number) ) - if err := p.Handshake(pm.networkID, td, hash, genesis.Hash()); err != nil { + if err := p.Handshake(pm.networkID, td, hash, genesis.Hash(), forkid.NewID(pm.blockchain), pm.forkFilter); err != nil { p.Log().Debug("Ethereum handshake failed", "err", err) return err } diff --git a/eth/handler_test.go b/eth/handler_test.go index 0f1672fd4..256883d1f 100644 --- 
a/eth/handler_test.go +++ b/eth/handler_test.go @@ -39,8 +39,8 @@ import ( ) // Tests that block headers can be retrieved from a remote chain based on user queries. -func TestGetBlockHeaders62(t *testing.T) { testGetBlockHeaders(t, 62) } func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) } +func TestGetBlockHeaders64(t *testing.T) { testGetBlockHeaders(t, 64) } func testGetBlockHeaders(t *testing.T, protocol int) { pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil) @@ -198,8 +198,8 @@ func testGetBlockHeaders(t *testing.T, protocol int) { } // Tests that block contents can be retrieved from a remote chain based on their hashes. -func TestGetBlockBodies62(t *testing.T) { testGetBlockBodies(t, 62) } func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) } +func TestGetBlockBodies64(t *testing.T) { testGetBlockBodies(t, 64) } func testGetBlockBodies(t *testing.T, protocol int) { pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil) @@ -271,6 +271,7 @@ func testGetBlockBodies(t *testing.T, protocol int) { // Tests that the node state database can be retrieved based on hashes. func TestGetNodeData63(t *testing.T) { testGetNodeData(t, 63) } +func TestGetNodeData64(t *testing.T) { testGetNodeData(t, 64) } func testGetNodeData(t *testing.T, protocol int) { // Define three accounts to simulate transactions with @@ -367,6 +368,7 @@ func testGetNodeData(t *testing.T, protocol int) { // Tests that the transaction receipts can be retrieved based on hashes. func TestGetReceipt63(t *testing.T) { testGetReceipt(t, 63) } +func TestGetReceipt64(t *testing.T) { testGetReceipt(t, 64) } func testGetReceipt(t *testing.T, protocol int) { // Define three accounts to simulate transactions with diff --git a/eth/helper_test.go b/eth/helper_test.go index 1482e99c4..e66910334 100644 --- a/eth/helper_test.go +++ b/eth/helper_test.go @@ -22,6 +22,7 @@ package eth import ( "crypto/ecdsa" "crypto/rand" + "fmt" "math/big" "sort" "sync" @@ -30,6 +31,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" @@ -171,20 +173,35 @@ func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*te head = pm.blockchain.CurrentHeader() td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) ) - tp.handshake(nil, td, head.Hash(), genesis.Hash()) + tp.handshake(nil, td, head.Hash(), genesis.Hash(), forkid.NewID(pm.blockchain), forkid.NewFilter(pm.blockchain)) } return tp, errc } // handshake simulates a trivial handshake that expects the same state from the // remote side as we are simulating locally. 
-func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, genesis common.Hash) { - msg := &statusData{ - ProtocolVersion: uint32(p.version), - NetworkId: DefaultConfig.NetworkId, - TD: td, - CurrentBlock: head, - GenesisBlock: genesis, +func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) { + var msg interface{} + switch { + case p.version == eth63: + msg = &statusData63{ + ProtocolVersion: uint32(p.version), + NetworkId: DefaultConfig.NetworkId, + TD: td, + CurrentBlock: head, + GenesisBlock: genesis, + } + case p.version == eth64: + msg = &statusData{ + ProtocolVersion: uint32(p.version), + NetworkID: DefaultConfig.NetworkId, + TD: td, + Head: head, + Genesis: genesis, + ForkID: forkID, + } + default: + panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version)) } if err := p2p.ExpectMsg(p.app, StatusMsg, msg); err != nil { t.Fatalf("status recv: %v", err) diff --git a/eth/peer.go b/eth/peer.go index 814c787b8..0beec1d84 100644 --- a/eth/peer.go +++ b/eth/peer.go @@ -25,6 +25,7 @@ import ( mapset "github.com/deckarep/golang-set" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/rlp" @@ -353,22 +354,46 @@ func (p *peer) RequestReceipts(hashes []common.Hash) error { // Handshake executes the eth protocol handshake, negotiating version number, // network IDs, difficulties, head and genesis blocks. -func (p *peer) Handshake(network uint64, td *big.Int, head common.Hash, genesis common.Hash) error { +func (p *peer) Handshake(network uint64, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) error { // Send out own handshake in a new thread errc := make(chan error, 2) - var status statusData // safe to read after two values have been received from errc + var ( + status63 statusData63 // safe to read after two values have been received from errc + status statusData // safe to read after two values have been received from errc + ) go func() { - errc <- p2p.Send(p.rw, StatusMsg, &statusData{ - ProtocolVersion: uint32(p.version), - NetworkId: network, - TD: td, - CurrentBlock: head, - GenesisBlock: genesis, - }) + switch { + case p.version == eth63: + errc <- p2p.Send(p.rw, StatusMsg, &statusData63{ + ProtocolVersion: uint32(p.version), + NetworkId: network, + TD: td, + CurrentBlock: head, + GenesisBlock: genesis, + }) + case p.version == eth64: + errc <- p2p.Send(p.rw, StatusMsg, &statusData{ + ProtocolVersion: uint32(p.version), + NetworkID: network, + TD: td, + Head: head, + Genesis: genesis, + ForkID: forkID, + }) + default: + panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version)) + } }() go func() { - errc <- p.readStatus(network, &status, genesis) + switch { + case p.version == eth63: + errc <- p.readStatusLegacy(network, &status63, genesis) + case p.version == eth64: + errc <- p.readStatus(network, &status, genesis, forkFilter) + default: + panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version)) + } }() timeout := time.NewTimer(handshakeTimeout) defer timeout.Stop() @@ -382,11 +407,18 @@ func (p *peer) Handshake(network uint64, td *big.Int, head common.Hash, genesis return p2p.DiscReadTimeout } } - p.td, p.head = status.TD, status.CurrentBlock + switch { + case p.version == eth63: + p.td, p.head = status63.TD, status63.CurrentBlock + case p.version == 
eth64: + p.td, p.head = status.TD, status.Head + default: + panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version)) + } return nil } -func (p *peer) readStatus(network uint64, status *statusData, genesis common.Hash) (err error) { +func (p *peer) readStatusLegacy(network uint64, status *statusData63, genesis common.Hash) error { msg, err := p.rw.ReadMsg() if err != nil { return err @@ -402,10 +434,10 @@ func (p *peer) readStatus(network uint64, status *statusData, genesis common.Has return errResp(ErrDecode, "msg %v: %v", msg, err) } if status.GenesisBlock != genesis { - return errResp(ErrGenesisBlockMismatch, "%x (!= %x)", status.GenesisBlock[:8], genesis[:8]) + return errResp(ErrGenesisMismatch, "%x (!= %x)", status.GenesisBlock[:8], genesis[:8]) } if status.NetworkId != network { - return errResp(ErrNetworkIdMismatch, "%d (!= %d)", status.NetworkId, network) + return errResp(ErrNetworkIDMismatch, "%d (!= %d)", status.NetworkId, network) } if int(status.ProtocolVersion) != p.version { return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, p.version) @@ -413,6 +445,36 @@ func (p *peer) readStatus(network uint64, status *statusData, genesis common.Has return nil } +func (p *peer) readStatus(network uint64, status *statusData, genesis common.Hash, forkFilter forkid.Filter) error { + msg, err := p.rw.ReadMsg() + if err != nil { + return err + } + if msg.Code != StatusMsg { + return errResp(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg) + } + if msg.Size > protocolMaxMsgSize { + return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, protocolMaxMsgSize) + } + // Decode the handshake and make sure everything matches + if err := msg.Decode(&status); err != nil { + return errResp(ErrDecode, "msg %v: %v", msg, err) + } + if status.NetworkID != network { + return errResp(ErrNetworkIDMismatch, "%d (!= %d)", status.NetworkID, network) + } + if int(status.ProtocolVersion) != p.version { + return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, p.version) + } + if status.Genesis != genesis { + return errResp(ErrGenesisMismatch, "%x (!= %x)", status.Genesis, genesis) + } + if err := forkFilter(status.ForkID); err != nil { + return errResp(ErrForkIDRejected, "%v", err) + } + return nil +} + // String implements fmt.Stringer. func (p *peer) String() string { return fmt.Sprintf("Peer %s [%s]", p.id, diff --git a/eth/protocol.go b/eth/protocol.go index de0c979d8..62e4d13d1 100644 --- a/eth/protocol.go +++ b/eth/protocol.go @@ -23,6 +23,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/rlp" @@ -30,24 +31,23 @@ import ( // Constants to match up protocol versions and messages const ( - eth62 = 62 eth63 = 63 + eth64 = 64 ) // protocolName is the official short name of the protocol used during capability negotiation. const protocolName = "eth" // ProtocolVersions are the supported versions of the eth protocol (first is primary). -var ProtocolVersions = []uint{eth63} +var ProtocolVersions = []uint{eth64, eth63} // protocolLengths are the number of implemented message corresponding to different protocol versions. 
-var protocolLengths = map[uint]uint64{eth63: 17, eth62: 8} +var protocolLengths = map[uint]uint64{eth64: 17, eth63: 17} const protocolMaxMsgSize = 10 * 1024 * 1024 // Maximum cap on the size of a protocol message // eth protocol message codes const ( - // Protocol messages belonging to eth/62 StatusMsg = 0x00 NewBlockHashesMsg = 0x01 TxMsg = 0x02 @@ -56,12 +56,10 @@ const ( GetBlockBodiesMsg = 0x05 BlockBodiesMsg = 0x06 NewBlockMsg = 0x07 - - // Protocol messages belonging to eth/63 - GetNodeDataMsg = 0x0d - NodeDataMsg = 0x0e - GetReceiptsMsg = 0x0f - ReceiptsMsg = 0x10 + GetNodeDataMsg = 0x0d + NodeDataMsg = 0x0e + GetReceiptsMsg = 0x0f + ReceiptsMsg = 0x10 ) type errCode int @@ -71,11 +69,11 @@ const ( ErrDecode ErrInvalidMsgCode ErrProtocolVersionMismatch - ErrNetworkIdMismatch - ErrGenesisBlockMismatch + ErrNetworkIDMismatch + ErrGenesisMismatch + ErrForkIDRejected ErrNoStatusMsg ErrExtraStatusMsg - ErrSuspendedPeer ) func (e errCode) String() string { @@ -88,11 +86,11 @@ var errorToString = map[int]string{ ErrDecode: "Invalid message", ErrInvalidMsgCode: "Invalid message code", ErrProtocolVersionMismatch: "Protocol version mismatch", - ErrNetworkIdMismatch: "NetworkId mismatch", - ErrGenesisBlockMismatch: "Genesis block mismatch", + ErrNetworkIDMismatch: "Network ID mismatch", + ErrGenesisMismatch: "Genesis mismatch", + ErrForkIDRejected: "Fork ID rejected", ErrNoStatusMsg: "No status message", ErrExtraStatusMsg: "Extra status message", - ErrSuspendedPeer: "Suspended peer", } type txPool interface { @@ -108,8 +106,8 @@ type txPool interface { SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription } -// statusData is the network packet for the status message. -type statusData struct { +// statusData63 is the network packet for the status message for eth/63. +type statusData63 struct { ProtocolVersion uint32 NetworkId uint64 TD *big.Int @@ -117,6 +115,16 @@ type statusData struct { GenesisBlock common.Hash } +// statusData is the network packet for the status message for eth/64 and later. +type statusData struct { + ProtocolVersion uint32 + NetworkID uint64 + TD *big.Int + Head common.Hash + Genesis common.Hash + ForkID forkid.ID +} + // newBlockHashesData is the network packet for the block announcements. type newBlockHashesData []struct { Hash common.Hash // Hash of one particular block being announced diff --git a/eth/protocol_test.go b/eth/protocol_test.go index e817d673a..ca418942b 100644 --- a/eth/protocol_test.go +++ b/eth/protocol_test.go @@ -18,15 +18,24 @@ package eth import ( "fmt" + "math/big" "sync" "testing" "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" ) @@ -37,10 +46,7 @@ func init() { var testAccount, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") // Tests that handshake failures are detected and reported correctly. 
-func TestStatusMsgErrors62(t *testing.T) { testStatusMsgErrors(t, 62) } -func TestStatusMsgErrors63(t *testing.T) { testStatusMsgErrors(t, 63) } - -func testStatusMsgErrors(t *testing.T, protocol int) { +func TestStatusMsgErrors63(t *testing.T) { pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) var ( genesis = pm.blockchain.Genesis() @@ -59,21 +65,20 @@ func testStatusMsgErrors(t *testing.T, protocol int) { wantError: errResp(ErrNoStatusMsg, "first msg has code 2 (!= 0)"), }, { - code: StatusMsg, data: statusData{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash()}, - wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", protocol), + code: StatusMsg, data: statusData63{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash()}, + wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", 63), }, { - code: StatusMsg, data: statusData{uint32(protocol), 999, td, head.Hash(), genesis.Hash()}, - wantError: errResp(ErrNetworkIdMismatch, "999 (!= %d)", DefaultConfig.NetworkId), + code: StatusMsg, data: statusData63{63, 999, td, head.Hash(), genesis.Hash()}, + wantError: errResp(ErrNetworkIDMismatch, "999 (!= %d)", DefaultConfig.NetworkId), }, { - code: StatusMsg, data: statusData{uint32(protocol), DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}}, - wantError: errResp(ErrGenesisBlockMismatch, "0300000000000000 (!= %x)", genesis.Hash().Bytes()[:8]), + code: StatusMsg, data: statusData63{63, DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}}, + wantError: errResp(ErrGenesisMismatch, "0300000000000000 (!= %x)", genesis.Hash().Bytes()[:8]), }, } - for i, test := range tests { - p, errc := newTestPeer("peer", protocol, pm, false) + p, errc := newTestPeer("peer", 63, pm, false) // The send call might hang until reset because // the protocol might not read the payload. go p2p.Send(p.app, test.code, test.data) @@ -92,9 +97,155 @@ func testStatusMsgErrors(t *testing.T, protocol int) { } } +func TestStatusMsgErrors64(t *testing.T) { + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) + var ( + genesis = pm.blockchain.Genesis() + head = pm.blockchain.CurrentHeader() + td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) + forkID = forkid.NewID(pm.blockchain) + ) + defer pm.Stop() + + tests := []struct { + code uint64 + data interface{} + wantError error + }{ + { + code: TxMsg, data: []interface{}{}, + wantError: errResp(ErrNoStatusMsg, "first msg has code 2 (!= 0)"), + }, + { + code: StatusMsg, data: statusData{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash(), forkID}, + wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", 64), + }, + { + code: StatusMsg, data: statusData{64, 999, td, head.Hash(), genesis.Hash(), forkID}, + wantError: errResp(ErrNetworkIDMismatch, "999 (!= %d)", DefaultConfig.NetworkId), + }, + { + code: StatusMsg, data: statusData{64, DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}, forkID}, + wantError: errResp(ErrGenesisMismatch, "0300000000000000000000000000000000000000000000000000000000000000 (!= %x)", genesis.Hash()), + }, + { + code: StatusMsg, data: statusData{64, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash(), forkid.ID{Hash: [4]byte{0x00, 0x01, 0x02, 0x03}}}, + wantError: errResp(ErrForkIDRejected, forkid.ErrLocalIncompatibleOrStale.Error()), + }, + } + for i, test := range tests { + p, errc := newTestPeer("peer", 64, pm, false) + // The send call might hang until reset because + // the protocol might not read the payload. 
+ go p2p.Send(p.app, test.code, test.data) + + select { + case err := <-errc: + if err == nil { + t.Errorf("test %d: protocol returned nil error, want %q", i, test.wantError) + } else if err.Error() != test.wantError.Error() { + t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.wantError) + } + case <-time.After(2 * time.Second): + t.Errorf("protocol did not shut down within 2 seconds") + } + p.close() + } +} + +func TestForkIDSplit(t *testing.T) { + var ( + engine = ethash.NewFaker() + + configNoFork = &params.ChainConfig{HomesteadBlock: big.NewInt(1)} + configProFork = &params.ChainConfig{ + HomesteadBlock: big.NewInt(1), + EIP150Block: big.NewInt(2), + EIP155Block: big.NewInt(2), + EIP158Block: big.NewInt(2), + ByzantiumBlock: big.NewInt(3), + } + dbNoFork = rawdb.NewMemoryDatabase() + dbProFork = rawdb.NewMemoryDatabase() + + gspecNoFork = &core.Genesis{Config: configNoFork} + gspecProFork = &core.Genesis{Config: configProFork} + + genesisNoFork = gspecNoFork.MustCommit(dbNoFork) + genesisProFork = gspecProFork.MustCommit(dbProFork) + + chainNoFork, _ = core.NewBlockChain(dbNoFork, nil, configNoFork, engine, vm.Config{}, nil) + chainProFork, _ = core.NewBlockChain(dbProFork, nil, configProFork, engine, vm.Config{}, nil) + + blocksNoFork, _ = core.GenerateChain(configNoFork, genesisNoFork, engine, dbNoFork, 2, nil) + blocksProFork, _ = core.GenerateChain(configProFork, genesisProFork, engine, dbProFork, 2, nil) + + ethNoFork, _ = NewProtocolManager(configNoFork, nil, downloader.FullSync, 1, new(event.TypeMux), new(testTxPool), engine, chainNoFork, dbNoFork, 1, nil) + ethProFork, _ = NewProtocolManager(configProFork, nil, downloader.FullSync, 1, new(event.TypeMux), new(testTxPool), engine, chainProFork, dbProFork, 1, nil) + ) + ethNoFork.Start(1000) + ethProFork.Start(1000) + + // Both nodes should allow the other to connect (same genesis, next fork is the same) + p2pNoFork, p2pProFork := p2p.MsgPipe() + peerNoFork := newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork) + peerProFork := newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork) + + errc := make(chan error, 2) + go func() { errc <- ethNoFork.handle(peerProFork) }() + go func() { errc <- ethProFork.handle(peerNoFork) }() + + select { + case err := <-errc: + t.Fatalf("frontier nofork <-> profork failed: %v", err) + case <-time.After(250 * time.Millisecond): + p2pNoFork.Close() + p2pProFork.Close() + } + // Progress into Homestead. Forks match, so we don't care what the future holds + chainNoFork.InsertChain(blocksNoFork[:1]) + chainProFork.InsertChain(blocksProFork[:1]) + + p2pNoFork, p2pProFork = p2p.MsgPipe() + peerNoFork = newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork) + peerProFork = newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork) + + errc = make(chan error, 2) + go func() { errc <- ethNoFork.handle(peerProFork) }() + go func() { errc <- ethProFork.handle(peerNoFork) }() + + select { + case err := <-errc: + t.Fatalf("homestead nofork <-> profork failed: %v", err) + case <-time.After(250 * time.Millisecond): + p2pNoFork.Close() + p2pProFork.Close() + } + // Progress into Spurious.
Forks mismatch, signalling differing chains, reject + chainNoFork.InsertChain(blocksNoFork[1:2]) + chainProFork.InsertChain(blocksProFork[1:2]) + + p2pNoFork, p2pProFork = p2p.MsgPipe() + peerNoFork = newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork) + peerProFork = newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork) + + errc = make(chan error, 2) + go func() { errc <- ethNoFork.handle(peerProFork) }() + go func() { errc <- ethProFork.handle(peerNoFork) }() + + select { + case err := <-errc: + if want := errResp(ErrForkIDRejected, forkid.ErrLocalIncompatibleOrStale.Error()); err.Error() != want.Error() { + t.Fatalf("fork ID rejection error mismatch: have %v, want %v", err, want) + } + case <-time.After(250 * time.Millisecond): + t.Fatalf("split peers not rejected") + } +} + // This test checks that received transactions are added to the local pool. -func TestRecvTransactions62(t *testing.T) { testRecvTransactions(t, 62) } func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) } +func TestRecvTransactions64(t *testing.T) { testRecvTransactions(t, 64) } func testRecvTransactions(t *testing.T, protocol int) { txAdded := make(chan []*types.Transaction) @@ -121,8 +272,8 @@ func testRecvTransactions(t *testing.T, protocol int) { } // This test checks that pending transactions are sent. -func TestSendTransactions62(t *testing.T) { testSendTransactions(t, 62) } func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) } +func TestSendTransactions64(t *testing.T) { testSendTransactions(t, 64) } func testSendTransactions(t *testing.T, protocol int) { pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) diff --git a/eth/tracers/tracer.go b/eth/tracers/tracer.go index c0729fb1d..724c5443a 100644 --- a/eth/tracers/tracer.go +++ b/eth/tracers/tracer.go @@ -99,7 +99,7 @@ func (mw *memoryWrapper) slice(begin, end int64) []byte { log.Warn("Tracer accessed out of bound memory", "available", mw.memory.Len(), "offset", begin, "size", end-begin) return nil } - return mw.memory.Get(begin, end-begin) + return mw.memory.GetCopy(begin, end-begin) } // getUint returns the 32 bytes at the specified address interpreted as a uint. 
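The Memory.Get split above (GetPtr for transient reads, GetCopy for data that escapes the current interpreter step) is an aliasing contract: SHA3 and the call variants consume their argument bytes immediately, so they may borrow the backing store, while log payloads and CREATE input are retained after later MSTOREs mutate that same buffer, so they must take a defensive copy. A standalone sketch of the idea (simplified types, not the geth implementation):

```go
package main

import "fmt"

// Memory is a toy stand-in for the EVM's linear memory.
type Memory struct{ store []byte }

// GetPtr returns a view aliasing the backing store: cheap, but its contents
// change whenever the underlying memory is written again.
func (m *Memory) GetPtr(offset, size int64) []byte {
	if size == 0 {
		return nil
	}
	return m.store[offset : offset+size]
}

// GetCopy returns an independent snapshot, safe to retain indefinitely.
func (m *Memory) GetCopy(offset, size int64) []byte {
	if size == 0 {
		return nil
	}
	cpy := make([]byte, size)
	copy(cpy, m.store[offset:offset+size])
	return cpy
}

func main() {
	m := &Memory{store: []byte{1, 2, 3, 4}}
	view := m.GetPtr(0, 2)
	snap := m.GetCopy(0, 2)

	m.store[0] = 9 // a later MSTORE-style write...

	fmt.Println(view[0], snap[0]) // 9 1: the view aliases memory, the copy does not
}
```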
diff --git a/ethdb/leveldb/leveldb.go b/ethdb/leveldb/leveldb.go index aba6593c7..378d4c3cd 100644 --- a/ethdb/leveldb/leveldb.go +++ b/ethdb/leveldb/leveldb.go @@ -62,14 +62,18 @@ type Database struct { fn string // filename for reporting db *leveldb.DB // LevelDB instance - compTimeMeter metrics.Meter // Meter for measuring the total time spent in database compaction - compReadMeter metrics.Meter // Meter for measuring the data read during compaction - compWriteMeter metrics.Meter // Meter for measuring the data written during compaction - writeDelayNMeter metrics.Meter // Meter for measuring the write delay number due to database compaction - writeDelayMeter metrics.Meter // Meter for measuring the write delay duration due to database compaction - diskSizeGauge metrics.Gauge // Gauge for tracking the size of all the levels in the database - diskReadMeter metrics.Meter // Meter for measuring the effective amount of data read - diskWriteMeter metrics.Meter // Meter for measuring the effective amount of data written + compTimeMeter metrics.Meter // Meter for measuring the total time spent in database compaction + compReadMeter metrics.Meter // Meter for measuring the data read during compaction + compWriteMeter metrics.Meter // Meter for measuring the data written during compaction + writeDelayNMeter metrics.Meter // Meter for measuring the write delay number due to database compaction + writeDelayMeter metrics.Meter // Meter for measuring the write delay duration due to database compaction + diskSizeGauge metrics.Gauge // Gauge for tracking the size of all the levels in the database + diskReadMeter metrics.Meter // Meter for measuring the effective amount of data read + diskWriteMeter metrics.Meter // Meter for measuring the effective amount of data written + memCompGauge metrics.Gauge // Gauge for tracking the number of memory compactions + level0CompGauge metrics.Gauge // Gauge for tracking the number of table compactions in level0 + nonlevel0CompGauge metrics.Gauge // Gauge for tracking the number of table compactions in non-level0 + seekCompGauge metrics.Gauge // Gauge for tracking the number of table compactions caused by read operations quitLock sync.Mutex // Mutex protecting the quit channel access quitChan chan chan error // Quit channel to stop the metrics collection before closing the database @@ -96,6 +100,7 @@ func New(file string, cache int, handles int, namespace string) (*Database, erro BlockCacheCapacity: cache / 2 * opt.MiB, WriteBuffer: cache / 4 * opt.MiB, // Two of these are used internally Filter: filter.NewBloomFilter(10), + DisableSeeksCompaction: true, }) if _, corrupted := err.(*errors.ErrCorrupted); corrupted { db, err = leveldb.RecoverFile(file, nil) @@ -118,6 +123,10 @@ func New(file string, cache int, handles int, namespace string) (*Database, erro ldb.diskWriteMeter = metrics.NewRegisteredMeter(namespace+"disk/write", nil) ldb.writeDelayMeter = metrics.NewRegisteredMeter(namespace+"compact/writedelay/duration", nil) ldb.writeDelayNMeter = metrics.NewRegisteredMeter(namespace+"compact/writedelay/counter", nil) + ldb.memCompGauge = metrics.NewRegisteredGauge(namespace+"compact/memory", nil) + ldb.level0CompGauge = metrics.NewRegisteredGauge(namespace+"compact/level0", nil) + ldb.nonlevel0CompGauge = metrics.NewRegisteredGauge(namespace+"compact/nonlevel0", nil) + ldb.seekCompGauge = metrics.NewRegisteredGauge(namespace+"compact/seek", nil) // Start up the metrics gathering and return go ldb.meter(metricsGatheringInterval) @@ -375,6 +384,29 @@ func (db *Database)
meter(refresh time.Duration) { } iostats[0], iostats[1] = nRead, nWrite + compCount, err := db.db.GetProperty("leveldb.compcount") + if err != nil { + db.log.Error("Failed to read database iostats", "err", err) + merr = err + continue + } + + var ( + memComp uint32 + level0Comp uint32 + nonLevel0Comp uint32 + seekComp uint32 + ) + if n, err := fmt.Sscanf(compCount, "MemComp:%d Level0Comp:%d NonLevel0Comp:%d SeekComp:%d", &memComp, &level0Comp, &nonLevel0Comp, &seekComp); n != 4 || err != nil { + db.log.Error("Compaction count statistic not found") + merr = err + continue + } + db.memCompGauge.Update(int64(memComp)) + db.level0CompGauge.Update(int64(level0Comp)) + db.nonlevel0CompGauge.Update(int64(nonLevel0Comp)) + db.seekCompGauge.Update(int64(seekComp)) + // Sleep a bit, then repeat the stats collection select { case errc = <-db.quitChan: diff --git a/graphql/graphiql.go b/graphql/graphiql.go index 483d4cea3..864ebf57d 100644 --- a/graphql/graphiql.go +++ b/graphql/graphiql.go @@ -52,7 +52,7 @@ func (h GraphiQL) ServeHTTP(w http.ResponseWriter, r *http.Request) { respond(w, errorJSON("only GET requests are supported"), http.StatusMethodNotAllowed) return } - + w.Header().Set("Content-Type", "text/html") w.Write(graphiql) } diff --git a/graphql/graphql.go b/graphql/graphql.go index df279f42b..ddd928dff 100644 --- a/graphql/graphql.go +++ b/graphql/graphql.go @@ -36,20 +36,19 @@ import ( ) var ( - errOnlyOnMainChain = errors.New("this operation is only available for blocks on the canonical chain") - errBlockInvariant = errors.New("block objects must be instantiated with at least one of num or hash") + errBlockInvariant = errors.New("block objects must be instantiated with at least one of num or hash") ) // Account represents an Ethereum account at a particular block. type Account struct { - backend ethapi.Backend - address common.Address - blockNumber rpc.BlockNumber + backend ethapi.Backend + address common.Address + blockNrOrHash rpc.BlockNumberOrHash } // getState fetches the StateDB object for an account. 
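The new compaction metering above is a two-step round trip: ask goleveldb for the "leveldb.compcount" property, then pull four counters out of the returned string with fmt.Sscanf and feed them into the registered gauges. A self-contained sketch of just the parsing step (the hardcoded sample string is hypothetical; the real value comes from db.db.GetProperty):

```go
package main

import (
	"fmt"
	"log"
)

func main() {
	// Hypothetical property value, following the layout the Sscanf format
	// string in the hunk above expects.
	compCount := "MemComp:12 Level0Comp:3 NonLevel0Comp:7 SeekComp:0"

	var memComp, level0Comp, nonLevel0Comp, seekComp uint32
	n, err := fmt.Sscanf(compCount, "MemComp:%d Level0Comp:%d NonLevel0Comp:%d SeekComp:%d",
		&memComp, &level0Comp, &nonLevel0Comp, &seekComp)
	if n != 4 || err != nil {
		log.Fatalf("compaction count statistic not found: n=%d err=%v", n, err)
	}
	// In the database code these four values feed the compact/memory,
	// compact/level0, compact/nonlevel0 and compact/seek gauges.
	fmt.Println(memComp, level0Comp, nonLevel0Comp, seekComp)
}
```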
func (a *Account) getState(ctx context.Context) (*state.StateDB, error) { - state, _, err := a.backend.StateAndHeaderByNumber(ctx, a.blockNumber) + state, _, err := a.backend.StateAndHeaderByNumberOrHash(ctx, a.blockNrOrHash) return state, err } @@ -102,9 +101,9 @@ func (l *Log) Transaction(ctx context.Context) *Transaction { func (l *Log) Account(ctx context.Context, args BlockNumberArgs) *Account { return &Account{ - backend: l.backend, - address: l.log.Address, - blockNumber: args.Number(), + backend: l.backend, + address: l.log.Address, + blockNrOrHash: args.NumberOrLatest(), } } @@ -136,10 +135,10 @@ func (t *Transaction) resolve(ctx context.Context) (*types.Transaction, error) { tx, blockHash, _, index := rawdb.ReadTransaction(t.backend.ChainDb(), t.hash) if tx != nil { t.tx = tx + blockNrOrHash := rpc.BlockNumberOrHashWithHash(blockHash, false) t.block = &Block{ - backend: t.backend, - hash: blockHash, - canonical: unknown, + backend: t.backend, + numberOrHash: &blockNrOrHash, } t.index = index } else { @@ -203,9 +202,9 @@ func (t *Transaction) To(ctx context.Context, args BlockNumberArgs) (*Account, e return nil, nil } return &Account{ - backend: t.backend, - address: *to, - blockNumber: args.Number(), + backend: t.backend, + address: *to, + blockNrOrHash: args.NumberOrLatest(), }, nil } @@ -221,9 +220,9 @@ func (t *Transaction) From(ctx context.Context, args BlockNumberArgs) (*Account, from, _ := types.Sender(signer, tx) return &Account{ - backend: t.backend, - address: from, - blockNumber: args.Number(), + backend: t.backend, + address: from, + blockNrOrHash: args.NumberOrLatest(), }, nil } @@ -293,9 +292,9 @@ func (t *Transaction) CreatedContract(ctx context.Context, args BlockNumberArgs) return nil, err } return &Account{ - backend: t.backend, - address: receipt.ContractAddress, - blockNumber: args.Number(), + backend: t.backend, + address: receipt.ContractAddress, + blockNrOrHash: args.NumberOrLatest(), }, nil } @@ -317,45 +316,16 @@ func (t *Transaction) Logs(ctx context.Context) (*[]*Log, error) { type BlockType int -const ( - unknown BlockType = iota - isCanonical - notCanonical -) - // Block represents an Ethereum block. -// backend, and either num or hash are mandatory. All other fields are lazily fetched +// backend, and numberOrHash are mandatory. All other fields are lazily fetched // when required. type Block struct { - backend ethapi.Backend - num *rpc.BlockNumber - hash common.Hash - header *types.Header - block *types.Block - receipts []*types.Receipt - canonical BlockType // Indicates if this block is on the main chain or not. 
-} - -func (b *Block) onMainChain(ctx context.Context) error { - if b.canonical == unknown { - header, err := b.resolveHeader(ctx) - if err != nil { - return err - } - canonHeader, err := b.backend.HeaderByNumber(ctx, rpc.BlockNumber(header.Number.Uint64())) - if err != nil { - return err - } - if header.Hash() == canonHeader.Hash() { - b.canonical = isCanonical - } else { - b.canonical = notCanonical - } - } - if b.canonical != isCanonical { - return errOnlyOnMainChain - } - return nil + backend ethapi.Backend + numberOrHash *rpc.BlockNumberOrHash + hash common.Hash + header *types.Header + block *types.Block + receipts []*types.Receipt } // resolve returns the internal Block object representing this block, fetching @@ -364,14 +334,17 @@ func (b *Block) resolve(ctx context.Context) (*types.Block, error) { if b.block != nil { return b.block, nil } - var err error - if b.hash != (common.Hash{}) { - b.block, err = b.backend.BlockByHash(ctx, b.hash) - } else { - b.block, err = b.backend.BlockByNumber(ctx, *b.num) + if b.numberOrHash == nil { + latest := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) + b.numberOrHash = &latest } + var err error + b.block, err = b.backend.BlockByNumberOrHash(ctx, *b.numberOrHash) if b.block != nil && b.header == nil { b.header = b.block.Header() + if hash, ok := b.numberOrHash.Hash(); ok { + b.hash = hash + } } return b.block, err } @@ -380,7 +353,7 @@ func (b *Block) resolve(ctx context.Context) (*types.Block, error) { // if necessary. Call this function instead of `resolve` unless you need the // additional data (transactions and uncles). func (b *Block) resolveHeader(ctx context.Context) (*types.Header, error) { - if b.num == nil && b.hash == (common.Hash{}) { + if b.numberOrHash == nil && b.hash == (common.Hash{}) { return nil, errBlockInvariant } var err error @@ -388,7 +361,7 @@ func (b *Block) resolveHeader(ctx context.Context) (*types.Header, error) { if b.hash != (common.Hash{}) { b.header, err = b.backend.HeaderByHash(ctx, b.hash) } else { - b.header, err = b.backend.HeaderByNumber(ctx, *b.num) + b.header, err = b.backend.HeaderByNumberOrHash(ctx, *b.numberOrHash) } } return b.header, err @@ -416,15 +389,12 @@ func (b *Block) resolveReceipts(ctx context.Context) ([]*types.Receipt, error) { } func (b *Block) Number(ctx context.Context) (hexutil.Uint64, error) { - if b.num == nil || *b.num == rpc.LatestBlockNumber { - header, err := b.resolveHeader(ctx) - if err != nil { - return 0, err - } - num := rpc.BlockNumber(header.Number.Uint64()) - b.num = &num + header, err := b.resolveHeader(ctx) + if err != nil { + return 0, err } - return hexutil.Uint64(*b.num), nil + + return hexutil.Uint64(header.Number.Uint64()), nil } func (b *Block) Hash(ctx context.Context) (common.Hash, error) { @@ -456,26 +426,17 @@ func (b *Block) GasUsed(ctx context.Context) (hexutil.Uint64, error) { func (b *Block) Parent(ctx context.Context) (*Block, error) { // If the block header hasn't been fetched, and we'll need it, fetch it. 
- if b.num == nil && b.hash != (common.Hash{}) && b.header == nil { + if b.numberOrHash == nil && b.header == nil { if _, err := b.resolveHeader(ctx); err != nil { return nil, err } } if b.header != nil && b.header.Number.Uint64() > 0 { - num := rpc.BlockNumber(b.header.Number.Uint64() - 1) + num := rpc.BlockNumberOrHashWithNumber(rpc.BlockNumber(b.header.Number.Uint64() - 1)) return &Block{ - backend: b.backend, - num: &num, - hash: b.header.ParentHash, - canonical: unknown, - }, nil - } - if b.num != nil && *b.num != 0 { - num := *b.num - 1 - return &Block{ - backend: b.backend, - num: &num, - canonical: isCanonical, + backend: b.backend, + numberOrHash: &num, + hash: b.header.ParentHash, }, nil } return nil, nil @@ -561,13 +522,11 @@ func (b *Block) Ommers(ctx context.Context) (*[]*Block, error) { } ret := make([]*Block, 0, len(block.Uncles())) for _, uncle := range block.Uncles() { - blockNumber := rpc.BlockNumber(uncle.Number.Uint64()) + blockNumberOrHash := rpc.BlockNumberOrHashWithHash(uncle.Hash(), false) ret = append(ret, &Block{ - backend: b.backend, - num: &blockNumber, - hash: uncle.Hash(), - header: uncle, - canonical: notCanonical, + backend: b.backend, + numberOrHash: &blockNumberOrHash, + header: uncle, }) } return &ret, nil @@ -603,16 +562,26 @@ func (b *Block) TotalDifficulty(ctx context.Context) (hexutil.Big, error) { // BlockNumberArgs encapsulates arguments to accessors that specify a block number. type BlockNumberArgs struct { + // TODO: Ideally we could use input unions to allow the query to specify the + // block parameter by hash, block number, or tag but input unions aren't part of the + // standard GraphQL schema SDL yet, see: https://github.com/graphql/graphql-spec/issues/488 Block *hexutil.Uint64 } -// Number returns the provided block number, or rpc.LatestBlockNumber if none +// NumberOr returns the provided block number argument, or the "current" block number or hash if none // was provided. -func (a BlockNumberArgs) Number() rpc.BlockNumber { +func (a BlockNumberArgs) NumberOr(current rpc.BlockNumberOrHash) rpc.BlockNumberOrHash { if a.Block != nil { - return rpc.BlockNumber(*a.Block) + blockNr := rpc.BlockNumber(*a.Block) + return rpc.BlockNumberOrHashWithNumber(blockNr) } - return rpc.LatestBlockNumber + return current +} + +// NumberOrLatest returns the provided block number argument, or the "latest" block number if none +// was provided. 
+func (a BlockNumberArgs) NumberOrLatest() rpc.BlockNumberOrHash { + return a.NumberOr(rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber)) } func (b *Block) Miner(ctx context.Context, args BlockNumberArgs) (*Account, error) { @@ -621,9 +590,9 @@ func (b *Block) Miner(ctx context.Context, args BlockNumberArgs) (*Account, erro return nil, err } return &Account{ - backend: b.backend, - address: header.Coinbase, - blockNumber: args.Number(), + backend: b.backend, + address: header.Coinbase, + blockNrOrHash: args.NumberOrLatest(), }, nil } @@ -683,13 +652,11 @@ func (b *Block) OmmerAt(ctx context.Context, args struct{ Index int32 }) (*Block return nil, nil } uncle := uncles[args.Index] - blockNumber := rpc.BlockNumber(uncle.Number.Uint64()) + blockNumberOrHash := rpc.BlockNumberOrHashWithHash(uncle.Hash(), false) return &Block{ - backend: b.backend, - num: &blockNumber, - hash: uncle.Hash(), - header: uncle, - canonical: notCanonical, + backend: b.backend, + numberOrHash: &blockNumberOrHash, + header: uncle, }, nil } @@ -757,20 +724,16 @@ func (b *Block) Logs(ctx context.Context, args struct{ Filter BlockFilterCriteri func (b *Block) Account(ctx context.Context, args struct { Address common.Address }) (*Account, error) { - err := b.onMainChain(ctx) - if err != nil { - return nil, err - } - if b.num == nil { + if b.numberOrHash == nil { _, err := b.resolveHeader(ctx) if err != nil { return nil, err } } return &Account{ - backend: b.backend, - address: args.Address, - blockNumber: *b.num, + backend: b.backend, + address: args.Address, + blockNrOrHash: *b.numberOrHash, }, nil } @@ -807,17 +770,13 @@ func (c *CallResult) Status() hexutil.Uint64 { func (b *Block) Call(ctx context.Context, args struct { Data ethapi.CallArgs }) (*CallResult, error) { - err := b.onMainChain(ctx) - if err != nil { - return nil, err - } - if b.num == nil { - _, err := b.resolveHeader(ctx) + if b.numberOrHash == nil { + _, err := b.resolve(ctx) if err != nil { return nil, err } } - result, gas, failed, err := ethapi.DoCall(ctx, b.backend, args.Data, *b.num, nil, vm.Config{}, 5*time.Second, b.backend.RPCGasCap()) + result, gas, failed, err := ethapi.DoCall(ctx, b.backend, args.Data, *b.numberOrHash, nil, vm.Config{}, 5*time.Second, b.backend.RPCGasCap()) status := hexutil.Uint64(1) if failed { status = 0 @@ -832,17 +791,13 @@ func (b *Block) Call(ctx context.Context, args struct { func (b *Block) EstimateGas(ctx context.Context, args struct { Data ethapi.CallArgs }) (hexutil.Uint64, error) { - err := b.onMainChain(ctx) - if err != nil { - return hexutil.Uint64(0), err - } - if b.num == nil { + if b.numberOrHash == nil { _, err := b.resolveHeader(ctx) if err != nil { return hexutil.Uint64(0), err } } - gas, err := ethapi.DoEstimateGas(ctx, b.backend, args.Data, *b.num, b.backend.RPCGasCap()) + gas, err := ethapi.DoEstimateGas(ctx, b.backend, args.Data, *b.numberOrHash, b.backend.RPCGasCap()) return gas, err } @@ -875,17 +830,19 @@ func (p *Pending) Transactions(ctx context.Context) (*[]*Transaction, error) { func (p *Pending) Account(ctx context.Context, args struct { Address common.Address }) *Account { + pendingBlockNr := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) return &Account{ - backend: p.backend, - address: args.Address, - blockNumber: rpc.PendingBlockNumber, + backend: p.backend, + address: args.Address, + blockNrOrHash: pendingBlockNr, } } func (p *Pending) Call(ctx context.Context, args struct { Data ethapi.CallArgs }) (*CallResult, error) { - result, gas, failed, err := ethapi.DoCall(ctx, 
p.backend, args.Data, rpc.PendingBlockNumber, nil, vm.Config{}, 5*time.Second, p.backend.RPCGasCap()) + pendingBlockNr := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + result, gas, failed, err := ethapi.DoCall(ctx, p.backend, args.Data, pendingBlockNr, nil, vm.Config{}, 5*time.Second, p.backend.RPCGasCap()) status := hexutil.Uint64(1) if failed { status = 0 @@ -900,7 +857,8 @@ func (p *Pending) Call(ctx context.Context, args struct { func (p *Pending) EstimateGas(ctx context.Context, args struct { Data ethapi.CallArgs }) (hexutil.Uint64, error) { - return ethapi.DoEstimateGas(ctx, p.backend, args.Data, rpc.PendingBlockNumber, p.backend.RPCGasCap()) + pendingBlockNr := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + return ethapi.DoEstimateGas(ctx, p.backend, args.Data, pendingBlockNr, p.backend.RPCGasCap()) } // Resolver is the top-level object in the GraphQL hierarchy. @@ -914,24 +872,23 @@ func (r *Resolver) Block(ctx context.Context, args struct { }) (*Block, error) { var block *Block if args.Number != nil { - num := rpc.BlockNumber(uint64(*args.Number)) + number := rpc.BlockNumber(uint64(*args.Number)) + numberOrHash := rpc.BlockNumberOrHashWithNumber(number) block = &Block{ - backend: r.backend, - num: &num, - canonical: isCanonical, + backend: r.backend, + numberOrHash: &numberOrHash, } } else if args.Hash != nil { + numberOrHash := rpc.BlockNumberOrHashWithHash(*args.Hash, false) block = &Block{ - backend: r.backend, - hash: *args.Hash, - canonical: unknown, + backend: r.backend, + numberOrHash: &numberOrHash, } } else { - num := rpc.LatestBlockNumber + numberOrHash := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) block = &Block{ - backend: r.backend, - num: &num, - canonical: isCanonical, + backend: r.backend, + numberOrHash: &numberOrHash, } } // Resolve the header, return nil if it doesn't exist. 
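The GraphQL rewrite above collapses the old (num, hash, canonical) triple on Block and Account into a single rpc.BlockNumberOrHash, so every resolver follows one rule: a number resolves through the canonical chain, while a hash resolves directly and is only checked for canonicality when requested. A simplified, self-contained model of that rule, with toy types standing in for the real rpc package and backends:

```go
package main

import (
	"errors"
	"fmt"
)

// blockNumberOrHash mirrors the shape of rpc.BlockNumberOrHash: exactly one
// of number/hash is set, plus an optional canonicality requirement.
type blockNumberOrHash struct {
	number           *int64
	hash             *string
	requireCanonical bool
}

func withNumber(n int64) blockNumberOrHash { return blockNumberOrHash{number: &n} }

func withHash(h string, canonical bool) blockNumberOrHash {
	return blockNumberOrHash{hash: &h, requireCanonical: canonical}
}

// resolveHeader shows the branching each ...OrHash helper in this patch uses:
// try the number form first, then the hash form with a canonical check.
func resolveHeader(b blockNumberOrHash, canonicalAt func(int64) string, headerNumber func(string) (int64, bool)) (string, error) {
	if b.number != nil {
		return canonicalAt(*b.number), nil
	}
	if b.hash != nil {
		num, ok := headerNumber(*b.hash)
		if !ok {
			return "", errors.New("header for hash not found")
		}
		if b.requireCanonical && canonicalAt(num) != *b.hash {
			return "", errors.New("hash is not currently canonical")
		}
		return *b.hash, nil
	}
	return "", errors.New("invalid arguments; neither block nor hash specified")
}

func main() {
	chain := map[int64]string{0: "0xgenesis", 1: "0xabc"}
	canonicalAt := func(n int64) string { return chain[n] }
	headerNumber := func(h string) (int64, bool) {
		for n, hash := range chain {
			if hash == h {
				return n, true
			}
		}
		return 0, false
	}
	fmt.Println(resolveHeader(withNumber(1), canonicalAt, headerNumber))
	fmt.Println(resolveHeader(withHash("0xabc", true), canonicalAt, headerNumber))
}
```

The same branch structure reappears almost verbatim in the les backend methods further down in this patch.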
@@ -963,11 +920,10 @@ func (r *Resolver) Blocks(ctx context.Context, args struct { } ret := make([]*Block, 0, to-from+1) for i := from; i <= to; i++ { - num := i + numberOrHash := rpc.BlockNumberOrHashWithNumber(i) ret = append(ret, &Block{ - backend: r.backend, - num: &num, - canonical: isCanonical, + backend: r.backend, + numberOrHash: &numberOrHash, }) } return ret, nil diff --git a/internal/build/archive.go b/internal/build/archive.go index ac680ba63..8571edd5a 100644 --- a/internal/build/archive.go +++ b/internal/build/archive.go @@ -183,3 +183,49 @@ func (a *TarballArchive) Close() error { } return a.file.Close() } + +func ExtractTarballArchive(archive string, dest string) error { + // We're only interested in gzipped archives, wrap the reader now + ar, err := os.Open(archive) + if err != nil { + return err + } + defer ar.Close() + + gzr, err := gzip.NewReader(ar) + if err != nil { + return err + } + defer gzr.Close() + + // Iterate over all the files in the tarball + tr := tar.NewReader(gzr) + for { + // Fetch the next tarball header and abort if needed + header, err := tr.Next() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + // Figure out the target and create it + target := filepath.Join(dest, header.Name) + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(target, 0755); err != nil { + return err + } + case tar.TypeReg: + file, err := os.OpenFile(target, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode)) + if err != nil { + return err + } + if _, err := io.Copy(file, tr); err != nil { + return err + } + file.Close() + } + } +} diff --git a/internal/build/gosrc.go b/internal/build/gosrc.go new file mode 100644 index 000000000..c85e46968 --- /dev/null +++ b/internal/build/gosrc.go @@ -0,0 +1,81 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package build + +import ( + "bytes" + "crypto/sha256" + "fmt" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "strings" +) + +// EnsureGoSources ensures that path contains a file with the given SHA256 hash, +// and if not, it downloads a fresh Go source package from upstream and replaces +// path with it (if the hash matches).
+func EnsureGoSources(version string, hash []byte, path string) error { + // Sanity check the destination path to ensure we don't do weird things + if !strings.HasSuffix(path, ".tar.gz") { + return fmt.Errorf("destination path (%s) must end with .tar.gz", path) + } + // If the file exists, validate its hash + if archive, err := ioutil.ReadFile(path); err == nil { // Go sources are ~20MB, it's fine to read all + hasher := sha256.New() + hasher.Write(archive) + have := hasher.Sum(nil) + + if bytes.Equal(have, hash) { + fmt.Printf("Go %s [%x] available at %s\n", version, hash, path) + return nil + } + fmt.Printf("Go %s hash mismatch (have %x, want %x) at %s, deleting old archive\n", version, have, hash, path) + if err := os.Remove(path); err != nil { + return err + } + } + // Archive missing or bad hash, download a new one + fmt.Printf("Downloading Go %s [want %x] into %s\n", version, hash, path) + + res, err := http.Get(fmt.Sprintf("https://dl.google.com/go/go%s.src.tar.gz", version)) + if err != nil { + return fmt.Errorf("failed to access Go sources: %v", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return fmt.Errorf("failed to access Go sources: code %d", res.StatusCode) + } + archive, err := ioutil.ReadAll(res.Body) + if err != nil { + return err + } + // Sanity check the downloaded archive, save if checks out + hasher := sha256.New() + hasher.Write(archive) + + if have := hasher.Sum(nil); !bytes.Equal(have, hash) { + return fmt.Errorf("downloaded Go %s hash mismatch (have %x, want %x)", version, have, hash) + } + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + if err := ioutil.WriteFile(path, archive, 0644); err != nil { + return err + } + fmt.Printf("Downloaded Go %s [%x] into %s\n", version, hash, path) + return nil +} diff --git a/internal/build/util.go b/internal/build/util.go index 971d948c4..a1f456777 100644 --- a/internal/build/util.go +++ b/internal/build/util.go @@ -153,32 +153,6 @@ func GoTool(tool string, args ...string) *exec.Cmd { return exec.Command(filepath.Join(runtime.GOROOT(), "bin", "go"), args...) } -// ExpandPackagesNoVendor expands a cmd/go import path pattern, skipping -// vendored packages. -func ExpandPackagesNoVendor(patterns []string) []string { - expand := false - for _, pkg := range patterns { - if strings.Contains(pkg, "...") { - expand = true - } - } - if expand { - cmd := GoTool("list", patterns...) - out, err := cmd.CombinedOutput() - if err != nil { - log.Fatalf("package listing failed: %v\n%s", err, string(out)) - } - var packages []string - for _, line := range strings.Split(string(out), "\n") { - if !strings.Contains(line, "/vendor/") { - packages = append(packages, strings.TrimSpace(line)) - } - } - return packages - } - return patterns -} - // UploadSFTP uploads files to a remote host using the sftp command line tool. // The destination host may be specified either as [user@]host[: or as a URI in // the form sftp://[user@]host[:port]. diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 354614d0a..ea7bb7fc8 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -530,8 +530,8 @@ func (s *PublicBlockChainAPI) BlockNumber() hexutil.Uint64 { // GetBalance returns the amount of wei for the given address in the state of the // given block number. The rpc.LatestBlockNumber and rpc.PendingBlockNumber meta // block numbers are also allowed.
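EnsureGoSources boils down to "trust a cached tarball only if its SHA256 matches the pinned hash, otherwise re-download and re-verify". The verification half, as a minimal sketch (verifyArchive and the sample path are hypothetical names, not part of the patch):

```go
package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"io/ioutil"
)

// verifyArchive mirrors the check EnsureGoSources performs before trusting
// a cached tarball: hash the whole file and compare against the pin.
func verifyArchive(path string, want []byte) (bool, error) {
	archive, err := ioutil.ReadFile(path)
	if err != nil {
		return false, err
	}
	have := sha256.Sum256(archive)
	return bytes.Equal(have[:], want), nil
}

func main() {
	// Hypothetical pin; real callers pass the SHA256 of the exact Go
	// source release they want to build against.
	want := sha256.Sum256([]byte("example"))
	ok, err := verifyArchive("/tmp/go.tar.gz", want[:])
	fmt.Println(ok, err)
}
```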
-func (s *PublicBlockChainAPI) GetBalance(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (*hexutil.Big, error) { - state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) +func (s *PublicBlockChainAPI) GetBalance(ctx context.Context, address common.Address, blockNrOrHash rpc.BlockNumberOrHash) (*hexutil.Big, error) { + state, _, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return nil, err } @@ -555,8 +555,8 @@ type StorageResult struct { } // GetProof returns the Merkle-proof for a given account and optionally some storage keys. -func (s *PublicBlockChainAPI) GetProof(ctx context.Context, address common.Address, storageKeys []string, blockNr rpc.BlockNumber) (*AccountResult, error) { - state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) +func (s *PublicBlockChainAPI) GetProof(ctx context.Context, address common.Address, storageKeys []string, blockNrOrHash rpc.BlockNumberOrHash) (*AccountResult, error) { + state, _, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return nil, err } @@ -712,8 +712,8 @@ func (s *PublicBlockChainAPI) GetUncleCountByBlockHash(ctx context.Context, bloc } // GetCode returns the code stored at the given address in the state for the given block number. -func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { - state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) +func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Bytes, error) { + state, _, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return nil, err } @@ -724,8 +724,8 @@ func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Addres // GetStorageAt returns the storage from the state at the given address, key and // block number. The rpc.LatestBlockNumber and rpc.PendingBlockNumber meta block // numbers are also allowed. 
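Viewed from the client side, swapping rpc.BlockNumber for rpc.BlockNumberOrHash in these signatures lets state queries address a specific block hash, not just a number or tag; the object form follows the EIP-1898 shape with "blockHash" and "requireCanonical" keys. A hedged example against a hypothetical local node (the zero hash is a placeholder, so the second call should simply yield a "header not found" style error):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/rpc"
)

func main() {
	// Assumes a node with the HTTP RPC enabled on localhost:8545.
	client, err := rpc.Dial("http://localhost:8545")
	if err != nil {
		log.Fatal(err)
	}
	var balance string
	addr := common.HexToAddress("0x0000000000000000000000000000000000000000")

	// The familiar number/tag form still works.
	if err := client.CallContext(context.Background(), &balance, "eth_getBalance", addr, "latest"); err != nil {
		log.Fatal(err)
	}
	fmt.Println("latest:", balance)

	// The object form addresses state by block hash. The zero hash is a
	// placeholder, so expect an error rather than a balance here.
	byHash := map[string]interface{}{
		"blockHash":        common.Hash{},
		"requireCanonical": false,
	}
	if err := client.CallContext(context.Background(), &balance, "eth_getBalance", addr, byHash); err != nil {
		fmt.Println("by hash:", err)
	} else {
		fmt.Println("by hash:", balance)
	}
}
```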
-func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNr rpc.BlockNumber) (hexutil.Bytes, error) { - state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) +func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Bytes, error) { + state, _, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return nil, err } @@ -757,10 +757,10 @@ type account struct { StateDiff *map[common.Hash]common.Hash `json:"stateDiff"` } -func DoCall(ctx context.Context, b Backend, args CallArgs, blockNr rpc.BlockNumber, overrides map[common.Address]account, vmCfg vm.Config, timeout time.Duration, globalGasCap *big.Int) ([]byte, uint64, bool, error) { +func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides map[common.Address]account, vmCfg vm.Config, timeout time.Duration, globalGasCap *big.Int) ([]byte, uint64, bool, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) - state, header, err := b.StateAndHeaderByNumber(ctx, blockNr) + state, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return nil, 0, false, err } @@ -874,16 +874,16 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNr rpc.BlockNumb // // Note, this function doesn't make and changes in the state/blockchain and is // useful to execute and retrieve values. -func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber, overrides *map[common.Address]account) (hexutil.Bytes, error) { +func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *map[common.Address]account) (hexutil.Bytes, error) { var accounts map[common.Address]account if overrides != nil { accounts = *overrides } - result, _, _, err := DoCall(ctx, s.b, args, blockNr, accounts, vm.Config{}, 5*time.Second, s.b.RPCGasCap()) + result, _, _, err := DoCall(ctx, s.b, args, blockNrOrHash, accounts, vm.Config{}, 5*time.Second, s.b.RPCGasCap()) return (hexutil.Bytes)(result), err } -func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNr rpc.BlockNumber, gasCap *big.Int) (hexutil.Uint64, error) { +func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, gasCap *big.Int) (hexutil.Uint64, error) { // Binary search the gas requirement, as it may be higher than the amount used var ( lo uint64 = params.TxGas - 1 @@ -894,7 +894,7 @@ func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNr rpc.Bl hi = uint64(*args.Gas) } else { // Retrieve the block to act as the gas ceiling - block, err := b.BlockByNumber(ctx, blockNr) + block, err := b.BlockByNumberOrHash(ctx, blockNrOrHash) if err != nil { return 0, err } @@ -910,7 +910,7 @@ func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNr rpc.Bl executable := func(gas uint64) bool { args.Gas = (*hexutil.Uint64)(&gas) - _, _, failed, err := DoCall(ctx, b, args, rpc.PendingBlockNumber, nil, vm.Config{}, 0, gasCap) + _, _, failed, err := DoCall(ctx, b, args, blockNrOrHash, nil, vm.Config{}, 0, gasCap) if err != nil || failed { return false } @@ -937,7 +937,8 @@ func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNr rpc.Bl // EstimateGas returns an estimate of the amount of gas needed to 
execute the // given transaction against the current pending block. func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (hexutil.Uint64, error) { - return DoEstimateGas(ctx, s.b, args, rpc.PendingBlockNumber, s.b.RPCGasCap()) + blockNrOrHash := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + return DoEstimateGas(ctx, s.b, args, blockNrOrHash, s.b.RPCGasCap()) } // ExecutionResult groups all structured logs emitted by the EVM @@ -1224,9 +1225,9 @@ func (s *PublicTransactionPoolAPI) GetRawTransactionByBlockHashAndIndex(ctx cont } // GetTransactionCount returns the number of transactions the given address has sent for the given block number -func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (*hexutil.Uint64, error) { +func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, address common.Address, blockNrOrHash rpc.BlockNumberOrHash) (*hexutil.Uint64, error) { // Ask transaction pool for the nonce which includes pending transactions - if blockNr == rpc.PendingBlockNumber { + if blockNr, ok := blockNrOrHash.Number(); ok && blockNr == rpc.PendingBlockNumber { nonce, err := s.b.GetPoolNonce(ctx, address) if err != nil { return nil, err @@ -1234,7 +1235,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, addr return (*hexutil.Uint64)(&nonce), nil } // Resolve block number and use its state to ask for the nonce - state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) + state, _, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return nil, err } @@ -1405,7 +1406,8 @@ func (args *SendTxArgs) setDefaults(ctx context.Context, b Backend) error { Value: args.Value, Data: input, } - estimated, err := DoEstimateGas(ctx, b, callArgs, rpc.PendingBlockNumber, b.RPCGasCap()) + pendingBlockNr := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + estimated, err := DoEstimateGas(ctx, b, callArgs, pendingBlockNr, b.RPCGasCap()) if err != nil { return err } diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index 06c6db33b..73b6c89ce 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -52,9 +52,12 @@ type Backend interface { SetHead(number uint64) HeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Header, error) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) + HeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Header, error) BlockByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Block, error) BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) + BlockByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Block, error) StateAndHeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*state.StateDB, *types.Header, error) + StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) GetTd(hash common.Hash) *big.Int GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header) (*vm.EVM, func() error, error) diff --git a/les/api_backend.go b/les/api_backend.go index 5cd432dcf..e01e1be98 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -65,6 +65,26 @@ func (b *LesApiBackend) HeaderByNumber(ctx context.Context, number rpc.BlockNumb return 
b.eth.blockchain.GetHeaderByNumberOdr(ctx, uint64(number)) } +func (b *LesApiBackend) HeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Header, error) { + if blockNr, ok := blockNrOrHash.Number(); ok { + return b.HeaderByNumber(ctx, blockNr) + } + if hash, ok := blockNrOrHash.Hash(); ok { + header, err := b.HeaderByHash(ctx, hash) + if err != nil { + return nil, err + } + if header == nil { + return nil, errors.New("header for hash not found") + } + if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { + return nil, errors.New("hash is not currently canonical") + } + return header, nil + } + return nil, errors.New("invalid arguments; neither block nor hash specified") +} + func (b *LesApiBackend) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) { return b.eth.blockchain.GetHeaderByHash(hash), nil } @@ -81,6 +101,26 @@ func (b *LesApiBackend) BlockByHash(ctx context.Context, hash common.Hash) (*typ return b.eth.blockchain.GetBlockByHash(ctx, hash) } +func (b *LesApiBackend) BlockByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Block, error) { + if blockNr, ok := blockNrOrHash.Number(); ok { + return b.BlockByNumber(ctx, blockNr) + } + if hash, ok := blockNrOrHash.Hash(); ok { + block, err := b.BlockByHash(ctx, hash) + if err != nil { + return nil, err + } + if block == nil { + return nil, errors.New("header found, but block body is missing") + } + if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(block.NumberU64()) != hash { + return nil, errors.New("hash is not currently canonical") + } + return block, nil + } + return nil, errors.New("invalid arguments; neither block nor hash specified") +} + func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*state.StateDB, *types.Header, error) { header, err := b.HeaderByNumber(ctx, number) if err != nil { @@ -92,6 +132,23 @@ func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.B return light.NewState(ctx, header, b.eth.odr), header, nil } +func (b *LesApiBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) { + if blockNr, ok := blockNrOrHash.Number(); ok { + return b.StateAndHeaderByNumber(ctx, blockNr) + } + if hash, ok := blockNrOrHash.Hash(); ok { + header := b.eth.blockchain.GetHeaderByHash(hash) + if header == nil { + return nil, nil, errors.New("header for hash not found") + } + if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { + return nil, nil, errors.New("hash is not currently canonical") + } + return light.NewState(ctx, header, b.eth.odr), header, nil + } + return nil, nil, errors.New("invalid arguments; neither block nor hash specified") +} + func (b *LesApiBackend) GetReceipts(ctx context.Context, hash common.Hash) (types.Receipts, error) { if number := rawdb.ReadHeaderNumber(b.eth.chainDb, hash); number != nil { return light.GetBlockReceipts(ctx, b.eth.odr, hash, *number) diff --git a/les/api_test.go b/les/api_test.go index 7d3b4ce5d..660af8eee 100644 --- a/les/api_test.go +++ b/les/api_test.go @@ -43,11 +43,25 @@ import ( "github.com/mattn/go-colorable" ) -/* -This test is not meant to be a part of the automatic testing process because it -runs for a long time and also requires a large database in order to do a meaningful -request performance test. 
When testServerDataDir is empty, the test is skipped. -*/ +// Additional command line flags for the test binary. +var ( + loglevel = flag.Int("loglevel", 0, "verbosity of logs") + simAdapter = flag.String("adapter", "exec", "type of simulation: sim|socket|exec|docker") +) + +func TestMain(m *testing.M) { + flag.Parse() + log.PrintOrigins(true) + log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(true)))) + // register the Delivery service which will run as a devp2p + // protocol when using the exec adapter + adapters.RegisterServices(services) + os.Exit(m.Run()) +} + +// This test is not meant to be a part of the automatic testing process because it +// runs for a long time and also requires a large database in order to do a meaningful +// request performance test. When testServerDataDir is empty, the test is skipped. const ( testServerDataDir = "" // should always be empty on the master branch @@ -377,29 +391,13 @@ func getCapacityInfo(ctx context.Context, t *testing.T, server *rpc.Client) (min return } -func init() { - flag.Parse() - // register the Delivery service which will run as a devp2p - // protocol when using the exec adapter - adapters.RegisterServices(services) - - log.PrintOrigins(true) - log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(true)))) -} - -var ( - adapter = flag.String("adapter", "exec", "type of simulation: sim|socket|exec|docker") - loglevel = flag.Int("loglevel", 0, "verbosity of logs") - nodes = flag.Int("nodes", 0, "number of nodes") -) - var services = adapters.Services{ "lesclient": newLesClientService, "lesserver": newLesServerService, } func NewNetwork() (*simulations.Network, func(), error) { - adapter, adapterTeardown, err := NewAdapter(*adapter, services) + adapter, adapterTeardown, err := NewAdapter(*simAdapter, services) if err != nil { return nil, adapterTeardown, err } diff --git a/les/balance.go b/les/balance.go index 4f08a304e..2813db01c 100644 --- a/les/balance.go +++ b/les/balance.go @@ -42,7 +42,7 @@ type balanceTracker struct { negTimeFactor, negRequestFactor float64 sumReqCost uint64 lastUpdate, nextUpdate, initTime mclock.AbsTime - updateEvent mclock.Event + updateEvent mclock.Timer // since only a limited and fixed number of callbacks are needed, they are // stored in a fixed size array ordered by priority threshold. 
callbacks [balanceCallbackCount]balanceCallback @@ -67,7 +67,7 @@ type balanceCallback struct { // init initializes balanceTracker func (bt *balanceTracker) init(clock mclock.Clock, capacity uint64) { bt.clock = clock - bt.initTime = clock.Now() + bt.initTime, bt.lastUpdate = clock.Now(), clock.Now() // Init timestamps for i := range bt.callbackIndex { bt.callbackIndex[i] = -1 } @@ -86,7 +86,7 @@ func (bt *balanceTracker) stop(now mclock.AbsTime) { bt.timeFactor = 0 bt.requestFactor = 0 if bt.updateEvent != nil { - bt.updateEvent.Cancel() + bt.updateEvent.Stop() bt.updateEvent = nil } } @@ -235,7 +235,7 @@ func (bt *balanceTracker) checkCallbacks(now mclock.AbsTime) { // updateAfter schedules a balance update and callback check in the future func (bt *balanceTracker) updateAfter(dt time.Duration) { - if bt.updateEvent == nil || bt.updateEvent.Cancel() { + if bt.updateEvent == nil || bt.updateEvent.Stop() { if dt == 0 { bt.updateEvent = nil } else { diff --git a/les/balance_test.go b/les/balance_test.go new file mode 100644 index 000000000..b571c2cc5 --- /dev/null +++ b/les/balance_test.go @@ -0,0 +1,260 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+ +package les + +import ( + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" +) + +func TestSetBalance(t *testing.T) { + var clock = &mclock.Simulated{} + var inputs = []struct { + pos uint64 + neg uint64 + }{ + {1000, 0}, + {0, 1000}, + {1000, 1000}, + } + + tracker := balanceTracker{} + tracker.init(clock, 1000) + defer tracker.stop(clock.Now()) + + for _, i := range inputs { + tracker.setBalance(i.pos, i.neg) + pos, neg := tracker.getBalance(clock.Now()) + if pos != i.pos { + t.Fatalf("Positive balance mismatch, want %v, got %v", i.pos, pos) + } + if neg != i.neg { + t.Fatalf("Negative balance mismatch, want %v, got %v", i.neg, neg) + } + } +} + +func TestBalanceTimeCost(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000) + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + tracker.setBalance(uint64(time.Minute), 0) // 1 minute time allowance + + var inputs = []struct { + runTime time.Duration + expPos uint64 + expNeg uint64 + }{ + {time.Second, uint64(time.Second * 59), 0}, + {0, uint64(time.Second * 59), 0}, + {time.Second * 59, 0, 0}, + {time.Second, 0, uint64(time.Second)}, + } + for _, i := range inputs { + clock.Run(i.runTime) + if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos { + t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos) + } + if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg { + t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg) + } + } + + tracker.setBalance(uint64(time.Minute), 0) // Refill 1 minute time allowance + for _, i := range inputs { + clock.Run(i.runTime) + if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos { + t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos) + } + if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg { + t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg) + } + } +} + +func TestBalanceReqCost(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000) + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + tracker.setBalance(uint64(time.Minute), 0) // 1 minute time serving time allowance + var inputs = []struct { + reqCost uint64 + expPos uint64 + expNeg uint64 + }{ + {uint64(time.Second), uint64(time.Second * 59), 0}, + {0, uint64(time.Second * 59), 0}, + {uint64(time.Second * 59), 0, 0}, + {uint64(time.Second), 0, uint64(time.Second)}, + } + for _, i := range inputs { + tracker.requestCost(i.reqCost) + if pos, _ := tracker.getBalance(clock.Now()); pos != i.expPos { + t.Fatalf("Positive balance mismatch, want %v, got %v", i.expPos, pos) + } + if _, neg := tracker.getBalance(clock.Now()); neg != i.expNeg { + t.Fatalf("Negative balance mismatch, want %v, got %v", i.expNeg, neg) + } + } +} + +func TestBalanceToPriority(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000) // cap = 1000 + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + var inputs = []struct { + pos uint64 + neg uint64 + priority int64 + }{ + {1000, 0, ^int64(1)}, + {2000, 0, ^int64(2)}, // Higher balance, lower priority value + {0, 0, 0}, + {0, 1000, 1000}, + } + for _, i := range inputs { + tracker.setBalance(i.pos, i.neg) + priority := tracker.getPriority(clock.Now()) + if priority != i.priority { + 
t.Fatalf("Priority mismatch, want %v, got %v", i.priority, priority) + } + } +} + +func TestEstimatedPriority(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000000000) // cap = 1000,000,000 + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + tracker.setBalance(uint64(time.Minute), 0) + var inputs = []struct { + runTime time.Duration // time cost + futureTime time.Duration // diff of future time + reqCost uint64 // single request cost + priority int64 // expected estimated priority + }{ + {time.Second, time.Second, 0, ^int64(58)}, + {0, time.Second, 0, ^int64(58)}, + + // 2 seconds time cost, 1 second estimated time cost, 10^9 request cost, + // 10^9 estimated request cost per second. + {time.Second, time.Second, 1000000000, ^int64(55)}, + + // 3 seconds time cost, 3 second estimated time cost, 10^9*2 request cost, + // 4*10^9 estimated request cost. + {time.Second, 3 * time.Second, 1000000000, ^int64(48)}, + + // All positive balance is used up + {time.Second * 55, 0, 0, 0}, + + // 1 minute estimated time cost, 4/58 * 10^9 estimated request cost per sec. + {0, time.Minute, 0, int64(time.Minute) + int64(time.Second)*120/29}, + } + for _, i := range inputs { + clock.Run(i.runTime) + tracker.requestCost(i.reqCost) + priority := tracker.estimatedPriority(clock.Now()+mclock.AbsTime(i.futureTime), true) + if priority != i.priority { + t.Fatalf("Estimated priority mismatch, want %v, got %v", i.priority, priority) + } + } +} + +func TestCallbackChecking(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000000) // cap = 1000,000 + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + var inputs = []struct { + priority int64 + expDiff time.Duration + }{ + {^int64(500), time.Millisecond * 500}, + {0, time.Second}, + {int64(time.Second), 2 * time.Second}, + } + tracker.setBalance(uint64(time.Second), 0) + for _, i := range inputs { + diff, _ := tracker.timeUntil(i.priority) + if diff != i.expDiff { + t.Fatalf("Time difference mismatch, want %v, got %v", i.expDiff, diff) + } + } +} + +func TestCallback(t *testing.T) { + var ( + clock = &mclock.Simulated{} + tracker = balanceTracker{} + ) + tracker.init(clock, 1000) // cap = 1000 + defer tracker.stop(clock.Now()) + tracker.setFactors(false, 1, 1) + tracker.setFactors(true, 1, 1) + + callCh := make(chan struct{}, 1) + tracker.setBalance(uint64(time.Minute), 0) + tracker.addCallback(balanceCallbackZero, 0, func() { callCh <- struct{}{} }) + + clock.Run(time.Minute) + select { + case <-callCh: + case <-time.NewTimer(time.Second).C: + t.Fatalf("Callback hasn't been called yet") + } + + tracker.setBalance(uint64(time.Minute), 0) + tracker.addCallback(balanceCallbackZero, 0, func() { callCh <- struct{}{} }) + tracker.removeCallback(balanceCallbackZero) + + clock.Run(time.Minute) + select { + case <-callCh: + t.Fatalf("Callback shouldn't be called") + case <-time.NewTimer(time.Millisecond * 100).C: + } +} diff --git a/les/checkpointoracle.go b/les/checkpointoracle.go index 4695fbc16..5494e3d6d 100644 --- a/les/checkpointoracle.go +++ b/les/checkpointoracle.go @@ -35,11 +35,8 @@ type checkpointOracle struct { config *params.CheckpointOracleConfig contract *checkpointoracle.CheckpointOracle - // Whether the contract backend is set. 
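Returning to the balance-test expectations above: the expected priorities are written as bitwise complements, ^int64(n), because for positive balances the tracker maps a larger balance to a numerically smaller priority, and eviction always targets the numerically largest value first. Since ^n == -n-1 for int64, the mapping is easy to check by hand:

```go
package main

import "fmt"

func main() {
	// ^n == -n-1, so a bigger positive balance yields a smaller (more
	// negative) priority value, i.e. a better-protected client.
	for _, bal := range []int64{1, 2, 58} {
		fmt.Printf("balance unit %d -> priority %d\n", bal, ^bal)
	}
}
```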
- running int32 - - getLocal func(uint64) params.TrustedCheckpoint // Function used to retrieve local checkpoint - syncDoneHook func() // Function used to notify that light syncing has completed. + running int32 // Flag indicating whether the contract backend is set + getLocal func(uint64) params.TrustedCheckpoint // Function used to retrieve local checkpoint } // newCheckpointOracle returns a checkpoint registrar handler. diff --git a/les/client_handler.go b/les/client_handler.go index aff05ddbc..7fdb16571 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -40,14 +40,16 @@ type clientHandler struct { downloader *downloader.Downloader backend *LightEthereum - closeCh chan struct{} - wg sync.WaitGroup // WaitGroup used to track all connected peers. + closeCh chan struct{} + wg sync.WaitGroup // WaitGroup used to track all connected peers. + syncDone func() // Test hook called when syncing is done. } func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler { handler := &clientHandler{ - backend: backend, - closeCh: make(chan struct{}), + checkpoint: checkpoint, + backend: backend, + closeCh: make(chan struct{}), } if ulcServers != nil { ulc, err := newULC(ulcServers, ulcFraction) diff --git a/les/clientpool.go b/les/clientpool.go index cff5f41ed..0b4d1b961 100644 --- a/les/clientpool.go +++ b/les/clientpool.go @@ -17,67 +17,81 @@ package les import ( + "encoding/binary" "io" "math" "sync" "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/prque" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" + "github.com/hashicorp/golang-lru" ) const ( - negBalanceExpTC = time.Hour // time constant for exponentially reducing negative balance - fixedPointMultiplier = 0x1000000 // constant to convert logarithms to fixed point format - connectedBias = time.Minute // this bias is applied in favor of already connected clients in order to avoid kicking them out very soon - lazyQueueRefresh = time.Second * 10 // refresh period of the connected queue -) + negBalanceExpTC = time.Hour // time constant for exponentially reducing negative balance + fixedPointMultiplier = 0x1000000 // constant to convert logarithms to fixed point format + lazyQueueRefresh = time.Second * 10 // refresh period of the connected queue + persistCumulativeTimeRefresh = time.Minute * 5 // refresh period of the cumulative running time persistence + posBalanceCacheLimit = 8192 // the maximum number of cached items in positive balance queue + negBalanceCacheLimit = 8192 // the maximum number of cached items in negative balance queue -var ( - clientPoolDbKey = []byte("clientPool") - clientBalanceDbKey = []byte("clientPool-balance") + // connectedBias is applied to already connected clients so that + // they won't be kicked out very soon and we can ensure all connected + // clients have enough time to request or sync some data. + // + // todo(rjl493456442) make it configurable. It can be the option of + // free trial time! + connectedBias = time.Minute * 3 ) // clientPool implements a client database that assigns a priority to each client // based on a positive and negative balance.
Positive balance is externally assigned // to prioritized clients and is decreased with connection time and processed // requests (unless the price factors are zero). If the positive balance is zero -then negative balance is accumulated. Balance tracking and priority calculation -for connected clients is done by balanceTracker. connectedQueue ensures that -clients with the lowest positive or highest negative balance get evicted when -the total capacity allowance is full and new clients with a better balance want -to connect. Already connected nodes receive a small bias in their favor in order -to avoid accepting and instantly kicking out clients. -Balances of disconnected clients are stored in posBalanceQueue and negBalanceQueue -and are also saved in the database. Negative balance is transformed into a -logarithmic form with a constantly shifting linear offset in order to implement -an exponential decrease. negBalanceQueue has a limited size and drops the smallest -values when necessary. Positive balances are stored in the database as long as -they exist, posBalanceQueue only acts as a cache for recently accessed entries. +then negative balance is accumulated. +// +// Balance tracking and priority calculation for connected clients is done by +// balanceTracker. connectedQueue ensures that clients with the lowest positive or +// highest negative balance get evicted when the total capacity allowance is full +// and new clients with a better balance want to connect. +// +// Already connected nodes receive a small bias in their favor in order to avoid +// accepting and instantly kicking out clients. In theory, we try to ensure that +// each client can have several minutes of connection time. +// +// Balances of disconnected clients are stored in nodeDB, including both positive +// and negative balances. Negative balance is transformed into a logarithmic form +// with a constantly shifting linear offset in order to implement an exponential +// decrease. Besides, nodeDB runs a background thread that checks the negative +// balance of each disconnected client and drops the record once the balance is +// low enough. type clientPool struct { - db ethdb.Database + ndb *nodeDB lock sync.Mutex clock mclock.Clock - stopCh chan chan struct{} + stopCh chan struct{} closed bool removePeer func(enode.ID) - queueLimit, countLimit int - freeClientCap, capacityLimit, connectedCapacity uint64 + connectedMap map[enode.ID]*clientInfo + connectedQueue *prque.LazyQueue - connectedMap map[enode.ID]*clientInfo - posBalanceMap map[enode.ID]*posBalance - negBalanceMap map[string]*negBalance - connectedQueue *prque.LazyQueue - posBalanceQueue, negBalanceQueue *prque.Prque - posFactors, negFactors priceFactors - posBalanceAccessCounter int64 - startupTime mclock.AbsTime - logOffsetAtStartup int64 + posFactors, negFactors priceFactors + + connLimit int // The maximum number of connections that clientpool can support + capLimit uint64 // The maximum cumulative capacity that clientpool can support + connectedCap uint64 // The sum of the capacities of all currently connected clients + freeClientCap uint64 // The capacity value of each free client + startTime mclock.AbsTime // The timestamp at which the clientpool started running + cumulativeTime int64 // The cumulative running time of clientpool at the start point. + disableBias bool // Disable connection bias (used in testing) } // clientPeer represents a client in the pool.
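The "logarithmic form with a constantly shifting linear offset" in the comment above is the whole trick behind cheap exponential decay: instead of periodically rewriting every stored negative balance, the pool stores log(balance)*fixedPointMultiplier + logOffset(now) once, and decay happens implicitly as logOffset keeps growing. A small sketch with the same constant (store/load are illustrative names, not the patch's API):

```go
package main

import (
	"fmt"
	"math"
)

const fixedPointMultiplier = 0x1000000

// store records a balance in logarithmic form, shifted by the current
// time-dependent offset.
func store(balance float64, logOffset int64) int64 {
	return int64(math.Log(balance)*fixedPointMultiplier) + logOffset
}

// load recovers the current balance; as logOffset grows with time, the
// recovered value shrinks exponentially without any database writes.
func load(logValue, logOffset int64) float64 {
	return math.Exp(float64(logValue-logOffset) / fixedPointMultiplier)
}

func main() {
	v := store(1000, 0)
	fmt.Printf("now: %.0f\n", load(v, 0)) // ~1000

	// One negBalanceExpTC of wall time advances the offset by exactly one
	// fixedPointMultiplier, so the balance has decayed by a factor of e.
	later := int64(1 * fixedPointMultiplier)
	fmt.Printf("one time constant later: %.0f\n", load(v, later)) // ~368
}
```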
@@ -138,22 +152,25 @@ type priceFactors struct { } // newClientPool creates a new client pool -func newClientPool(db ethdb.Database, freeClientCap uint64, queueLimit int, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { +func newClientPool(db ethdb.Database, freeClientCap uint64, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { + ndb := newNodeDB(db, clock) pool := &clientPool{ - db: db, - clock: clock, - connectedMap: make(map[enode.ID]*clientInfo), - posBalanceMap: make(map[enode.ID]*posBalance), - negBalanceMap: make(map[string]*negBalance), - connectedQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh), - negBalanceQueue: prque.New(negSetIndex), - posBalanceQueue: prque.New(posSetIndex), - freeClientCap: freeClientCap, - queueLimit: queueLimit, - removePeer: removePeer, - stopCh: make(chan chan struct{}), + ndb: ndb, + clock: clock, + connectedMap: make(map[enode.ID]*clientInfo), + connectedQueue: prque.NewLazyQueue(connSetIndex, connPriority, connMaxPriority, clock, lazyQueueRefresh), + freeClientCap: freeClientCap, + removePeer: removePeer, + startTime: clock.Now(), + cumulativeTime: ndb.getCumulativeTime(), + stopCh: make(chan struct{}), + } + // If the negative balance of free client is even lower than 1, + // delete this entry. + ndb.nbEvictCallBack = func(now mclock.AbsTime, b negBalance) bool { + balance := math.Exp(float64(b.logValue-pool.logOffset(now)) / fixedPointMultiplier) + return balance <= 1 } - pool.loadFromDb() go func() { for { select { @@ -161,8 +178,9 @@ func newClientPool(db ethdb.Database, freeClientCap uint64, queueLimit int, cloc pool.lock.Lock() pool.connectedQueue.Refresh() pool.lock.Unlock() - case stop := <-pool.stopCh: - close(stop) + case <-clock.After(persistCumulativeTimeRefresh): + pool.ndb.setCumulativeTime(pool.logOffset(clock.Now())) + case <-pool.stopCh: return } } @@ -172,13 +190,12 @@ func newClientPool(db ethdb.Database, freeClientCap uint64, queueLimit int, cloc // stop shuts the client pool down func (f *clientPool) stop() { - stop := make(chan struct{}) - f.stopCh <- stop - <-stop + close(f.stopCh) f.lock.Lock() f.closed = true - f.saveToDb() f.lock.Unlock() + f.ndb.setCumulativeTime(f.logOffset(f.clock.Now())) + f.ndb.close() } // connect should be called after a successful handshake. If the connection was @@ -187,7 +204,7 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { f.lock.Lock() defer f.lock.Unlock() - // Short circuit is clientPool is already closed. + // Short circuit if clientPool is already closed. 
if f.closed { return false } @@ -199,14 +216,19 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { return false } // Create a clientInfo but do not add it yet - now := f.clock.Now() - posBalance := f.getPosBalance(id).value + var ( + posBalance uint64 + negBalance uint64 + now = f.clock.Now() + ) + pb := f.ndb.getOrNewPB(id) + posBalance = pb.value e := &clientInfo{pool: f, peer: peer, address: freeID, queueIndex: -1, id: id, priority: posBalance != 0} - var negBalance uint64 - nb := f.negBalanceMap[freeID] - if nb != nil { + nb := f.ndb.getOrNewNB(freeID) + if nb.logValue != 0 { negBalance = uint64(math.Exp(float64(nb.logValue-f.logOffset(now)) / fixedPointMultiplier)) + negBalance *= uint64(time.Second) } // If the client is a free client, assign with a low free capacity, // Otherwise assign with the given value(priority client) @@ -219,6 +241,7 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { } e.capacity = capacity + // Starts a balance tracker e.balanceTracker.init(f.clock, capacity) e.balanceTracker.setBalance(posBalance, negBalance) f.setClientPriceFactors(e) @@ -228,9 +251,9 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { // // If the priority of the newly added client is lower than the priority of // all connected clients, the client is rejected. - newCapacity := f.connectedCapacity + capacity + newCapacity := f.connectedCap + capacity newCount := f.connectedQueue.Size() + 1 - if newCapacity > f.capacityLimit || newCount > f.countLimit { + if newCapacity > f.capLimit || newCount > f.connLimit { var ( kickList []*clientInfo kickPriority int64 @@ -241,10 +264,13 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { kickPriority = priority newCapacity -= c.capacity newCount-- - return newCapacity > f.capacityLimit || newCount > f.countLimit + return newCapacity > f.capLimit || newCount > f.connLimit }) - if newCapacity > f.capacityLimit || newCount > f.countLimit || (e.balanceTracker.estimatedPriority(now+mclock.AbsTime(connectedBias), false)-kickPriority) > 0 { - // reject client + bias := connectedBias + if f.disableBias { + bias = 0 + } + if newCapacity > f.capLimit || newCount > f.connLimit || (e.balanceTracker.estimatedPriority(now+mclock.AbsTime(bias), false)-kickPriority) > 0 { for _, c := range kickList { f.connectedQueue.Push(c) } @@ -257,21 +283,22 @@ func (f *clientPool) connect(peer clientPeer, capacity uint64) bool { f.dropClient(c, now, true) } } - // client accepted, finish setting it up - if nb != nil { - delete(f.negBalanceMap, freeID) - f.negBalanceQueue.Remove(nb.queueIndex) - } + // Register new client to connection queue. + f.connectedMap[id] = e + f.connectedQueue.Push(e) + f.connectedCap += e.capacity + + // If the current client is a paid client, monitor the status of client, + // downgrade it to normal client if positive balance is used up. if e.priority { e.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) }) } - f.connectedMap[id] = e - f.connectedQueue.Push(e) - f.connectedCapacity += e.capacity - totalConnectedGauge.Update(int64(f.connectedCapacity)) + // If the capacity of client is not the default value(free capacity), notify + // it to update capacity. 
if e.capacity != f.freeClientCap {
		e.peer.updateCapacity(e.capacity)
	}
+	totalConnectedGauge.Update(int64(f.connectedCap))
	clientConnectedMeter.Mark(1)
	log.Debug("Client accepted", "address", freeID)
	return true
@@ -284,15 +311,14 @@ func (f *clientPool) disconnect(p clientPeer) {
	f.lock.Lock()
	defer f.lock.Unlock()

+	// Short circuit if client pool is already closed.
	if f.closed {
		return
	}
-	address := p.freeClientId()
-	id := p.ID()
	// Short circuit if the peer hasn't been registered.
-	e := f.connectedMap[id]
+	e := f.connectedMap[p.ID()]
	if e == nil {
-		log.Debug("Client not connected", "address", address, "id", peerIdToString(id))
+		log.Debug("Client not connected", "address", p.freeClientId(), "id", peerIdToString(p.ID()))
		return
	}
	f.dropClient(e, f.clock.Now(), false)
@@ -307,8 +333,8 @@ func (f *clientPool) dropClient(e *clientInfo, now mclock.AbsTime, kick bool) {
	f.finalizeBalance(e, now)
	f.connectedQueue.Remove(e.queueIndex)
	delete(f.connectedMap, e.id)
-	f.connectedCapacity -= e.capacity
-	totalConnectedGauge.Update(int64(f.connectedCapacity))
+	f.connectedCap -= e.capacity
+	totalConnectedGauge.Update(int64(f.connectedCap))
	if kick {
		clientKickedMeter.Mark(1)
		log.Debug("Client kicked out", "address", e.address)
@@ -324,18 +350,17 @@ func (f *clientPool) dropClient(e *clientInfo, now mclock.AbsTime, kick bool) {
func (f *clientPool) finalizeBalance(c *clientInfo, now mclock.AbsTime) {
	c.balanceTracker.stop(now)
	pos, neg := c.balanceTracker.getBalance(now)
-	pb := f.getPosBalance(c.id)
+
+	pb, nb := f.ndb.getOrNewPB(c.id), f.ndb.getOrNewNB(c.address)
	pb.value = pos
-	f.storePosBalance(pb)
-	if neg < 1 {
-		neg = 1
-	}
-	nb := &negBalance{address: c.address, queueIndex: -1, logValue: int64(math.Log(float64(neg))*fixedPointMultiplier) + f.logOffset(now)}
-	f.negBalanceMap[c.address] = nb
-	f.negBalanceQueue.Push(nb, -nb.logValue)
-	if f.negBalanceQueue.Size() > f.queueLimit {
-		nn := f.negBalanceQueue.PopItem().(*negBalance)
-		delete(f.negBalanceMap, nn.address)
+	f.ndb.setPB(c.id, pb)
+
+	neg /= uint64(time.Second) // Convert the negative balance to seconds.
+	if neg > 1 {
+		nb.logValue = int64(math.Log(float64(neg))*fixedPointMultiplier) + f.logOffset(now)
+		f.ndb.setNB(c.address, nb)
+	} else {
+		f.ndb.delNB(c.address) // Negative balance is small enough, drop it directly.
	}
}
@@ -351,27 +376,28 @@ func (f *clientPool) balanceExhausted(id enode.ID) {
	}
	c.priority = false
	if c.capacity != f.freeClientCap {
-		f.connectedCapacity += f.freeClientCap - c.capacity
-		totalConnectedGauge.Update(int64(f.connectedCapacity))
+		f.connectedCap += f.freeClientCap - c.capacity
+		totalConnectedGauge.Update(int64(f.connectedCap))
		c.capacity = f.freeClientCap
		c.peer.updateCapacity(c.capacity)
	}
+	f.ndb.delPB(id)
}

// setConnLimit sets the maximum number and total capacity of connected clients,
// dropping some of them if necessary.
-func (f *clientPool) setLimits(count int, totalCap uint64) { +func (f *clientPool) setLimits(totalConn int, totalCap uint64) { f.lock.Lock() defer f.lock.Unlock() - f.countLimit = count - f.capacityLimit = totalCap - now := mclock.Now() - f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool { - c := data.(*clientInfo) - f.dropClient(c, now, true) - return f.connectedCapacity > f.capacityLimit || f.connectedQueue.Size() > f.countLimit - }) + f.connLimit = totalConn + f.capLimit = totalCap + if f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit { + f.connectedQueue.MultiPop(func(data interface{}, priority int64) bool { + f.dropClient(data.(*clientInfo), mclock.Now(), true) + return f.connectedCap > f.capLimit || f.connectedQueue.Size() > f.connLimit + }) + } } // requestCost feeds request cost after serving a request from the given peer. @@ -388,11 +414,14 @@ func (f *clientPool) requestCost(p *peer, cost uint64) { // logOffset calculates the time-dependent offset for the logarithmic // representation of negative balance +// +// From another point of view, the result returned by the function represents +// the total time that the clientpool is cumulatively running(total_hours/multiplier). func (f *clientPool) logOffset(now mclock.AbsTime) int64 { // Note: fixedPointMultiplier acts as a multiplier here; the reason for dividing the divisor // is to avoid int64 overflow. We assume that int64(negBalanceExpTC) >> fixedPointMultiplier. - logDecay := int64((time.Duration(now - f.startupTime)) / (negBalanceExpTC / fixedPointMultiplier)) - return f.logOffsetAtStartup + logDecay + cumulativeTime := int64((time.Duration(now - f.startTime)) / (negBalanceExpTC / fixedPointMultiplier)) + return f.cumulativeTime + cumulativeTime } // setPriceFactors changes pricing factors for both positive and negative balances. 
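
The logOffset scheme above is worth spelling out: a negative balance is persisted once as log(balance)*fixedPointMultiplier plus the offset at write time, and exponential decay is applied implicitly because the offset keeps growing, so idle entries never need rewriting. A minimal standalone Go sketch of the arithmetic, not part of the patch; the constant values here are assumptions for the example, the real ones live elsewhere in the les package:

    package main

    import (
    	"fmt"
    	"math"
    	"time"
    )

    // Assumed example constants, chosen only for illustration.
    const (
    	fixedPointMultiplier = 0x1000000
    	negBalanceExpTC      = time.Hour // decay time constant of negative balances
    )

    // logOffset grows linearly with elapsed time; subtracting it from a stored
    // logValue decays the balance exponentially without touching the entry.
    func logOffset(elapsed time.Duration) int64 {
    	return int64(elapsed / (negBalanceExpTC / fixedPointMultiplier))
    }

    func main() {
    	// Write: store a balance of 1000 (in seconds) at t=0.
    	logValue := int64(math.Log(1000)*fixedPointMultiplier) + logOffset(0)

    	// Read one time constant later: the balance has decayed by a factor of e.
    	balance := math.Exp(float64(logValue-logOffset(negBalanceExpTC)) / fixedPointMultiplier)
    	fmt.Printf("decayed balance: %.1f\n", balance) // ~367.9
    }
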
@@ -413,100 +442,6 @@ func (f *clientPool) setClientPriceFactors(c *clientInfo) { c.balanceTracker.setFactors(false, f.posFactors.timeFactor+float64(c.capacity)*f.posFactors.capacityFactor/1000000, f.posFactors.requestFactor) } -// clientPoolStorage is the RLP representation of the pool's database storage -type clientPoolStorage struct { - LogOffset uint64 - List []*negBalance -} - -// loadFromDb restores pool status from the database storage -// (automatically called at initialization) -func (f *clientPool) loadFromDb() { - enc, err := f.db.Get(clientPoolDbKey) - if err != nil { - return - } - var storage clientPoolStorage - err = rlp.DecodeBytes(enc, &storage) - if err != nil { - log.Error("Failed to decode client list", "err", err) - return - } - f.logOffsetAtStartup = int64(storage.LogOffset) - f.startupTime = f.clock.Now() - for _, e := range storage.List { - log.Debug("Loaded free client record", "address", e.address, "logValue", e.logValue) - f.negBalanceMap[e.address] = e - f.negBalanceQueue.Push(e, -e.logValue) - } -} - -// saveToDb saves pool status to the database storage -// (automatically called during shutdown) -func (f *clientPool) saveToDb() { - now := f.clock.Now() - storage := clientPoolStorage{ - LogOffset: uint64(f.logOffset(now)), - } - for _, c := range f.connectedMap { - f.finalizeBalance(c, now) - } - i := 0 - storage.List = make([]*negBalance, len(f.negBalanceMap)) - for _, e := range f.negBalanceMap { - storage.List[i] = e - i++ - } - enc, err := rlp.EncodeToBytes(storage) - if err != nil { - log.Error("Failed to encode negative balance list", "err", err) - } else { - f.db.Put(clientPoolDbKey, enc) - } -} - -// storePosBalance stores a single positive balance entry in the database -func (f *clientPool) storePosBalance(b *posBalance) { - if b.value == b.lastStored { - return - } - enc, err := rlp.EncodeToBytes(b) - if err != nil { - log.Error("Failed to encode client balance", "err", err) - } else { - f.db.Put(append(clientBalanceDbKey, b.id[:]...), enc) - b.lastStored = b.value - } -} - -// getPosBalance retrieves a single positive balance entry from cache or the database -func (f *clientPool) getPosBalance(id enode.ID) *posBalance { - if b, ok := f.posBalanceMap[id]; ok { - f.posBalanceQueue.Remove(b.queueIndex) - f.posBalanceAccessCounter-- - f.posBalanceQueue.Push(b, f.posBalanceAccessCounter) - return b - } - balance := &posBalance{} - if enc, err := f.db.Get(append(clientBalanceDbKey, id[:]...)); err == nil { - if err := rlp.DecodeBytes(enc, balance); err != nil { - log.Error("Failed to decode client balance", "err", err) - balance = &posBalance{} - } - } - balance.id = id - balance.queueIndex = -1 - if f.posBalanceQueue.Size() >= f.queueLimit { - b := f.posBalanceQueue.PopItem().(*posBalance) - f.storePosBalance(b) - delete(f.posBalanceMap, b.id) - } - f.posBalanceAccessCounter-- - f.posBalanceQueue.Push(balance, f.posBalanceAccessCounter) - f.posBalanceMap[id] = balance - return balance -} - // addBalance updates the positive balance of a client. // If setTotal is false then the given amount is added to the balance. 
// If setTotal is true then amount represents the total amount ever added to the
@@ -516,11 +451,21 @@ func (f *clientPool) addBalance(id enode.ID, amount uint64, setTotal bool) {
	f.lock.Lock()
	defer f.lock.Unlock()

-	pb := f.getPosBalance(id)
+	pb := f.ndb.getOrNewPB(id)
	c := f.connectedMap[id]
-	var negBalance uint64
	if c != nil {
-		pb.value, negBalance = c.balanceTracker.getBalance(f.clock.Now())
+		posBalance, negBalance := c.balanceTracker.getBalance(f.clock.Now())
+		pb.value = posBalance
+		defer func() {
+			c.balanceTracker.setBalance(pb.value, negBalance)
+			if !c.priority && pb.value > 0 {
+				// The capacity should be adjusted based on the requirement,
+				// but we have no idea about the new capacity, so a second
+				// call is needed to update it.
+				c.priority = true
+				c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) })
+			}
+		}()
	}
	if setTotal {
		if pb.value+amount > pb.lastTotal {
@@ -533,21 +478,12 @@ func (f *clientPool) addBalance(id enode.ID, amount uint64, setTotal bool) {
		pb.value += amount
		pb.lastTotal += amount
	}
-	f.storePosBalance(pb)
-	if c != nil {
-		c.balanceTracker.setBalance(pb.value, negBalance)
-		if !c.priority && pb.value > 0 {
-			c.priority = true
-			c.balanceTracker.addCallback(balanceCallbackZero, 0, func() { f.balanceExhausted(id) })
-		}
-	}
+	f.ndb.setPB(id, pb)
}

// posBalance represents a recently accessed positive balance entry
type posBalance struct {
-	id                           enode.ID
-	value, lastStored, lastTotal uint64
-	queueIndex                   int // position in posBalanceQueue
+	value, lastTotal uint64
}

// EncodeRLP implements rlp.Encoder
@@ -564,44 +500,207 @@ func (e *posBalance) DecodeRLP(s *rlp.Stream) error {
		return err
	}
	e.value = entry.Value
-	e.lastStored = entry.Value
	e.lastTotal = entry.LastTotal
	return nil
}

-// posSetIndex callback updates posBalance item index in posBalanceQueue
-func posSetIndex(a interface{}, index int) {
-	a.(*posBalance).queueIndex = index
-}
-
// negBalance represents a negative balance entry of a disconnected client
-type negBalance struct {
-	address    string
-	logValue   int64
-	queueIndex int // position in negBalanceQueue
-}
+type negBalance struct{ logValue int64 }

// EncodeRLP implements rlp.Encoder
func (e *negBalance) EncodeRLP(w io.Writer) error {
-	return rlp.Encode(w, []interface{}{e.address, uint64(e.logValue)})
+	return rlp.Encode(w, []interface{}{uint64(e.logValue)})
}

// DecodeRLP implements rlp.Decoder
func (e *negBalance) DecodeRLP(s *rlp.Stream) error {
	var entry struct {
-		Address  string
		LogValue uint64
	}
	if err := s.Decode(&entry); err != nil {
		return err
	}
-	e.address = entry.Address
	e.logValue = int64(entry.LogValue)
-	e.queueIndex = -1
	return nil
}

-// negSetIndex callback updates negBalance item index in negBalanceQueue
-func negSetIndex(a interface{}, index int) {
-	a.(*negBalance).queueIndex = index
+const (
+	// nodeDBVersion is the version identifier of the node data in db
+	nodeDBVersion = 0
+
+	// dbCleanupCycle is the cycle of db for useless data cleanup
+	dbCleanupCycle = time.Hour
+)
+
+var (
+	positiveBalancePrefix    = []byte("pb:")             // dbVersion(uint16 big endian) + positiveBalancePrefix + id -> balance
+	negativeBalancePrefix    = []byte("nb:")             // dbVersion(uint16 big endian) + negativeBalancePrefix + ip -> balance
+	cumulativeRunningTimeKey = []byte("cumulativeTime:") // dbVersion(uint16 big endian) + cumulativeRunningTimeKey -> cumulativeTime
+)
+
+type nodeDB struct {
+	db     ethdb.Database
+	pcache *lru.Cache
+	ncache *lru.Cache
+	auxbuf []byte // 37-byte auxiliary buffer for key encoding
+	verbuf
[2]byte // 2-byte auxiliary buffer for db version + nbEvictCallBack func(mclock.AbsTime, negBalance) bool // Callback to determine whether the negative balance can be evicted. + clock mclock.Clock + closeCh chan struct{} + cleanupHook func() // Test hook used for testing +} + +func newNodeDB(db ethdb.Database, clock mclock.Clock) *nodeDB { + pcache, _ := lru.New(posBalanceCacheLimit) + ncache, _ := lru.New(negBalanceCacheLimit) + ndb := &nodeDB{ + db: db, + pcache: pcache, + ncache: ncache, + auxbuf: make([]byte, 37), + clock: clock, + closeCh: make(chan struct{}), + } + binary.BigEndian.PutUint16(ndb.verbuf[:], uint16(nodeDBVersion)) + go ndb.expirer() + return ndb +} + +func (db *nodeDB) close() { + close(db.closeCh) +} + +func (db *nodeDB) key(id []byte, neg bool) []byte { + prefix := positiveBalancePrefix + if neg { + prefix = negativeBalancePrefix + } + if len(prefix)+len(db.verbuf)+len(id) > len(db.auxbuf) { + db.auxbuf = append(db.auxbuf, make([]byte, len(prefix)+len(db.verbuf)+len(id)-len(db.auxbuf))...) + } + copy(db.auxbuf[:len(db.verbuf)], db.verbuf[:]) + copy(db.auxbuf[len(db.verbuf):len(db.verbuf)+len(prefix)], prefix) + copy(db.auxbuf[len(prefix)+len(db.verbuf):len(prefix)+len(db.verbuf)+len(id)], id) + return db.auxbuf[:len(prefix)+len(db.verbuf)+len(id)] +} + +func (db *nodeDB) getCumulativeTime() int64 { + blob, err := db.db.Get(append(cumulativeRunningTimeKey, db.verbuf[:]...)) + if err != nil || len(blob) == 0 { + return 0 + } + return int64(binary.BigEndian.Uint64(blob)) +} + +func (db *nodeDB) setCumulativeTime(v int64) { + binary.BigEndian.PutUint64(db.auxbuf[:8], uint64(v)) + db.db.Put(append(cumulativeRunningTimeKey, db.verbuf[:]...), db.auxbuf[:8]) +} + +func (db *nodeDB) getOrNewPB(id enode.ID) posBalance { + key := db.key(id.Bytes(), false) + item, exist := db.pcache.Get(string(key)) + if exist { + return item.(posBalance) + } + var balance posBalance + if enc, err := db.db.Get(key); err == nil { + if err := rlp.DecodeBytes(enc, &balance); err != nil { + log.Error("Failed to decode positive balance", "err", err) + } + } + db.pcache.Add(string(key), balance) + return balance +} + +func (db *nodeDB) setPB(id enode.ID, b posBalance) { + key := db.key(id.Bytes(), false) + enc, err := rlp.EncodeToBytes(&(b)) + if err != nil { + log.Error("Failed to encode positive balance", "err", err) + return + } + db.db.Put(key, enc) + db.pcache.Add(string(key), b) +} + +func (db *nodeDB) delPB(id enode.ID) { + key := db.key(id.Bytes(), false) + db.db.Delete(key) + db.pcache.Remove(string(key)) +} + +func (db *nodeDB) getOrNewNB(id string) negBalance { + key := db.key([]byte(id), true) + item, exist := db.ncache.Get(string(key)) + if exist { + return item.(negBalance) + } + var balance negBalance + if enc, err := db.db.Get(key); err == nil { + if err := rlp.DecodeBytes(enc, &balance); err != nil { + log.Error("Failed to decode negative balance", "err", err) + } + } + db.ncache.Add(string(key), balance) + return balance +} + +func (db *nodeDB) setNB(id string, b negBalance) { + key := db.key([]byte(id), true) + enc, err := rlp.EncodeToBytes(&(b)) + if err != nil { + log.Error("Failed to encode negative balance", "err", err) + return + } + db.db.Put(key, enc) + db.ncache.Add(string(key), b) +} + +func (db *nodeDB) delNB(id string) { + key := db.key([]byte(id), true) + db.db.Delete(key) + db.ncache.Remove(string(key)) +} + +func (db *nodeDB) expirer() { + for { + select { + case <-db.clock.After(dbCleanupCycle): + db.expireNodes() + case <-db.closeCh: + return + } + } +} + +// 
expireNodes iterates the whole node db and checks whether the negative balance
+// entry can be deleted.
+//
+// The rationale behind this is: the server doesn't need to keep the negative
+// balance records if they are low enough.
+func (db *nodeDB) expireNodes() {
+	var (
+		visited int
+		deleted int
+		start   = time.Now()
+	)
+	iter := db.db.NewIteratorWithPrefix(append(db.verbuf[:], negativeBalancePrefix...))
+	for iter.Next() {
+		visited += 1
+		var balance negBalance
+		if err := rlp.DecodeBytes(iter.Value(), &balance); err != nil {
+			log.Error("Failed to decode negative balance", "err", err)
+			continue
+		}
+		if db.nbEvictCallBack != nil && db.nbEvictCallBack(db.clock.Now(), balance) {
+			deleted += 1
+			db.db.Delete(iter.Key())
+		}
+	}
+	// Invoke testing hook if it's not nil.
+	if db.cleanupHook != nil {
+		db.cleanupHook()
+	}
+	log.Debug("Expire nodes", "visited", visited, "deleted", deleted, "elapsed", common.PrettyDuration(time.Since(start)))
}
diff --git a/les/clientpool_test.go b/les/clientpool_test.go
index 225f828ec..53973696c 100644
--- a/les/clientpool_test.go
+++ b/les/clientpool_test.go
@@ -17,8 +17,11 @@
 package les

 import (
+	"bytes"
	"fmt"
+	"math"
	"math/rand"
+	"reflect"
	"testing"
	"time"

@@ -51,7 +54,7 @@ func TestClientPoolL100C300P20(t *testing.T) {
	testClientPool(t, 100, 300, 20, false)
}

-const testClientPoolTicks = 500000
+const testClientPoolTicks = 100000

 type poolTestPeer int

@@ -65,6 +68,14 @@ func (i poolTestPeer) freeClientId() string {

 func (i poolTestPeer) updateCapacity(uint64) {}

+type poolTestPeerWithCap struct {
+	poolTestPeer
+
+	cap uint64
+}
+
+func (i *poolTestPeerWithCap) updateCapacity(cap uint64) { i.cap = cap }
+
 func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomDisconnect bool) {
	rand.Seed(time.Now().UnixNano())
	var (
@@ -76,8 +87,9 @@
		disconnFn = func(id enode.ID) { disconnCh <- int(id[0]) + int(id[1])<<8 }
-		pool      = newClientPool(db, 1, 10000, &clock, disconnFn)
+		pool      = newClientPool(db, 1, &clock, disconnFn)
	)
+	pool.disableBias = true
	pool.setLimits(connLimit, uint64(connLimit))
	pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})

@@ -89,16 +101,9 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD
			t.Fatalf("Test peer #%d rejected", i)
		}
	}
-	// since all accepted peers are new and should not be kicked out, the next one should be rejected
-	if pool.connect(poolTestPeer(connLimit), 0) {
-		connected[connLimit] = true
-		t.Fatalf("Peer accepted over connected limit")
-	}
-
	// randomly connect and disconnect peers, expect to have a similar total connection time at the end
	for tickCounter := 0; tickCounter < testClientPoolTicks; tickCounter++ {
		clock.Run(1 * time.Second)
-		//time.Sleep(time.Microsecond * 100)

		if tickCounter == testClientPoolTicks/4 {
			// give a positive balance to some of the peers
@@ -137,11 +142,11 @@
	}
	expTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2*(connLimit-paidCount)/(clientCount-paidCount)
-	expMin := expTicks - expTicks/10
-	expMax := expTicks + expTicks/10
+	expMin := expTicks - expTicks/5
+	expMax := expTicks + expTicks/5
	paidTicks := testClientPoolTicks/2*connLimit/clientCount + testClientPoolTicks/2
-	paidMin := paidTicks - paidTicks/10
-	paidMax := paidTicks + paidTicks/10
+	paidMin := paidTicks - paidTicks/5
+	paidMax := paidTicks + paidTicks/5

	// check if the
total connected time of peers are all in the expected range for i, c := range connected { @@ -157,24 +162,380 @@ func testClientPool(t *testing.T, connLimit, clientCount, paidCount int, randomD t.Errorf("Total connected time of test node #%d (%d) outside expected range (%d to %d)", i, connTicks[i], min, max) } } - - // a previously unknown peer should be accepted now - if !pool.connect(poolTestPeer(54321), 0) { - t.Fatalf("Previously unknown peer rejected") - } - - // close and restart pool - pool.stop() - pool = newClientPool(db, 1, 10000, &clock, func(id enode.ID) {}) - pool.setLimits(connLimit, uint64(connLimit)) - - // try connecting all known peers (connLimit should be filled up) - for i := 0; i < clientCount; i++ { - pool.connect(poolTestPeer(i), 0) - } - // expect pool to remember known nodes and kick out one of them to accept a new one - if !pool.connect(poolTestPeer(54322), 0) { - t.Errorf("Previously unknown peer rejected after restarting pool") - } pool.stop() } + +func TestConnectPaidClient(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + ) + pool := newClientPool(db, 1, &clock, nil) + defer pool.stop() + pool.setLimits(10, uint64(10)) + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + // Add balance for an external client and mark it as paid client + pool.addBalance(poolTestPeer(0).ID(), 1000, false) + + if !pool.connect(poolTestPeer(0), 10) { + t.Fatalf("Failed to connect paid client") + } +} + +func TestConnectPaidClientToSmallPool(t *testing.T) { + var ( + clock mclock.Simulated + db = rawdb.NewMemoryDatabase() + ) + pool := newClientPool(db, 1, &clock, nil) + defer pool.stop() + pool.setLimits(10, uint64(10)) // Total capacity limit is 10 + pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1}) + + // Add balance for an external client and mark it as paid client + pool.addBalance(poolTestPeer(0).ID(), 1000, false) + + // Connect a fat paid client to pool, should reject it. 
+	if pool.connect(poolTestPeer(0), 100) {
+		t.Fatalf("Connected fat paid client, should reject it")
+	}
+}
+
+func TestConnectPaidClientToFullPool(t *testing.T) {
+	var (
+		clock mclock.Simulated
+		db    = rawdb.NewMemoryDatabase()
+	)
+	removeFn := func(enode.ID) {} // Noop
+	pool := newClientPool(db, 1, &clock, removeFn)
+	defer pool.stop()
+	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
+	pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
+
+	for i := 0; i < 10; i++ {
+		pool.addBalance(poolTestPeer(i).ID(), 1000000000, false)
+		pool.connect(poolTestPeer(i), 1)
+	}
+	pool.addBalance(poolTestPeer(11).ID(), 1000, false) // Add low balance to new paid client
+	if pool.connect(poolTestPeer(11), 1) {
+		t.Fatalf("Low balance paid client should be rejected")
+	}
+	clock.Run(time.Second)
+	pool.addBalance(poolTestPeer(12).ID(), 1000000000*60*3, false) // Add high balance to new paid client
+	if !pool.connect(poolTestPeer(12), 1) {
+		t.Fatalf("High balance paid client should be accepted")
+	}
+}
+
+func TestPaidClientKickedOut(t *testing.T) {
+	var (
+		clock    mclock.Simulated
+		db       = rawdb.NewMemoryDatabase()
+		kickedCh = make(chan int, 1)
+	)
+	removeFn := func(id enode.ID) { kickedCh <- int(id[0]) }
+	pool := newClientPool(db, 1, &clock, removeFn)
+	defer pool.stop()
+	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
+	pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
+
+	for i := 0; i < 10; i++ {
+		pool.addBalance(poolTestPeer(i).ID(), 1000000000, false) // 1 second allowance
+		pool.connect(poolTestPeer(i), 1)
+		clock.Run(time.Millisecond)
+	}
+	clock.Run(time.Second)
+	clock.Run(connectedBias)
+	if !pool.connect(poolTestPeer(11), 0) {
+		t.Fatalf("Free client should be accepted")
+	}
+	select {
+	case id := <-kickedCh:
+		if id != 0 {
+			t.Fatalf("Kicked client mismatch, want %v, got %v", 0, id)
+		}
+	case <-time.NewTimer(time.Second).C:
+		t.Fatalf("timeout")
+	}
+}
+
+func TestConnectFreeClient(t *testing.T) {
+	var (
+		clock mclock.Simulated
+		db    = rawdb.NewMemoryDatabase()
+	)
+	pool := newClientPool(db, 1, &clock, nil)
+	defer pool.stop()
+	pool.setLimits(10, uint64(10))
+	pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
+	if !pool.connect(poolTestPeer(0), 10) {
+		t.Fatalf("Failed to connect free client")
+	}
+}
+
+func TestConnectFreeClientToFullPool(t *testing.T) {
+	var (
+		clock mclock.Simulated
+		db    = rawdb.NewMemoryDatabase()
+	)
+	removeFn := func(enode.ID) {} // Noop
+	pool := newClientPool(db, 1, &clock, removeFn)
+	defer pool.stop()
+	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
+	pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
+
+	for i := 0; i < 10; i++ {
+		pool.connect(poolTestPeer(i), 1)
+	}
+	if pool.connect(poolTestPeer(11), 1) {
+		t.Fatalf("New free client should be rejected")
+	}
+	clock.Run(time.Minute)
+	if pool.connect(poolTestPeer(12), 1) {
+		t.Fatalf("New free client should be rejected")
+	}
+	clock.Run(time.Millisecond)
+	clock.Run(4 * time.Minute)
+	if !pool.connect(poolTestPeer(13), 1) {
+		t.Fatalf("Old client connected for more than 5min should be kicked")
+	}
+}
+
+func TestFreeClientKickedOut(t *testing.T) {
+	var (
+		clock  mclock.Simulated
+		db     = rawdb.NewMemoryDatabase()
+		kicked = make(chan int, 10)
+	)
+	removeFn := func(id enode.ID) { kicked <- int(id[0]) }
+	pool := newClientPool(db, 1, &clock, removeFn)
+	defer pool.stop()
+	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
+	pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
+
+	for i := 0; i < 10; i++ {
+		pool.connect(poolTestPeer(i), 1)
+		clock.Run(time.Millisecond)
+	}
+	if pool.connect(poolTestPeer(10), 1) {
+		t.Fatalf("New free client should be rejected")
+	}
+	clock.Run(5 * time.Minute)
+	for i := 0; i < 10; i++ {
+		pool.connect(poolTestPeer(i+10), 1)
+	}
+	for i := 0; i < 10; i++ {
+		select {
+		case id := <-kicked:
+			if id >= 10 {
+				t.Fatalf("Old client should be kicked, now got: %d", id)
+			}
+		case <-time.NewTimer(time.Second).C:
+			t.Fatalf("timeout")
+		}
+	}
+}
+
+func TestPositiveBalanceCalculation(t *testing.T) {
+	var (
+		clock  mclock.Simulated
+		db     = rawdb.NewMemoryDatabase()
+		kicked = make(chan int, 10)
+	)
+	removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
+	pool := newClientPool(db, 1, &clock, removeFn)
+	defer pool.stop()
+	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
+	pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
+
+	pool.addBalance(poolTestPeer(0).ID(), uint64(time.Minute*3), false)
+	pool.connect(poolTestPeer(0), 10)
+	clock.Run(time.Minute)
+
+	pool.disconnect(poolTestPeer(0))
+	pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID())
+	if pb.value != uint64(time.Minute*2) {
+		t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute*2), pb.value)
+	}
+}
+
+func TestDowngradePriorityClient(t *testing.T) {
+	var (
+		clock  mclock.Simulated
+		db     = rawdb.NewMemoryDatabase()
+		kicked = make(chan int, 10)
+	)
+	removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
+	pool := newClientPool(db, 1, &clock, removeFn)
+	defer pool.stop()
+	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
+	pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
+
+	p := &poolTestPeerWithCap{
+		poolTestPeer: poolTestPeer(0),
+	}
+	pool.addBalance(p.ID(), uint64(time.Minute), false)
+	pool.connect(p, 10)
+	if p.cap != 10 {
+		t.Fatalf("The capacity of priority peer hasn't been updated, got: %d", p.cap)
+	}
+
+	clock.Run(time.Minute) // All positive balance should be used up.
+	time.Sleep(300 * time.Millisecond) // Ensure the callback is called
+	if p.cap != 1 {
+		t.Fatalf("The capacity of peer should be downgraded, got: %d", p.cap)
+	}
+	pb := pool.ndb.getOrNewPB(poolTestPeer(0).ID())
+	if pb.value != 0 {
+		t.Fatalf("Positive balance mismatch, want %v, got %v", 0, pb.value)
+	}
+
+	pool.addBalance(poolTestPeer(0).ID(), uint64(time.Minute), false)
+	pb = pool.ndb.getOrNewPB(poolTestPeer(0).ID())
+	if pb.value != uint64(time.Minute) {
+		t.Fatalf("Positive balance mismatch, want %v, got %v", uint64(time.Minute), pb.value)
+	}
+}
+
+func TestNegativeBalanceCalculation(t *testing.T) {
+	var (
+		clock  mclock.Simulated
+		db     = rawdb.NewMemoryDatabase()
+		kicked = make(chan int, 10)
+	)
+	removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
+	pool := newClientPool(db, 1, &clock, removeFn)
+	defer pool.stop()
+	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
+	pool.setPriceFactors(priceFactors{1, 0, 1}, priceFactors{1, 0, 1})
+
+	for i := 0; i < 10; i++ {
+		pool.connect(poolTestPeer(i), 1)
+	}
+	clock.Run(time.Second)
+
+	for i := 0; i < 10; i++ {
+		pool.disconnect(poolTestPeer(i))
+		nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId())
+		if nb.logValue != 0 {
+			t.Fatalf("Short connection shouldn't be recorded")
+		}
+	}
+
+	for i := 0; i < 10; i++ {
+		pool.connect(poolTestPeer(i), 1)
+	}
+	clock.Run(time.Minute)
+	for i := 0; i < 10; i++ {
+		pool.disconnect(poolTestPeer(i))
+		nb := pool.ndb.getOrNewNB(poolTestPeer(i).freeClientId())
+		nb.logValue -= pool.logOffset(clock.Now())
+		nb.logValue /= fixedPointMultiplier
+		if nb.logValue != int64(math.Log(float64(time.Minute/time.Second))) {
+			t.Fatalf("Negative balance mismatch, want %v, got %v", int64(math.Log(float64(time.Minute/time.Second))), nb.logValue)
+		}
+	}
+}
+
+func TestNodeDB(t *testing.T) {
+	ndb := newNodeDB(rawdb.NewMemoryDatabase(), mclock.System{})
+	defer ndb.close()
+
+	if !bytes.Equal(ndb.verbuf[:], []byte{0x00, 0x00}) {
+		t.Fatalf("version buffer mismatch, want %v, got %v", []byte{0x00, 0x00}, ndb.verbuf)
+	}
+	var cases = []struct {
+		id       enode.ID
+		ip       string
+		balance  interface{}
+		positive bool
+	}{
+		{enode.ID{0x00, 0x01, 0x02}, "", posBalance{value: 100, lastTotal: 200}, true},
+		{enode.ID{0x00, 0x01, 0x02}, "", posBalance{value: 200, lastTotal: 300}, true},
+		{enode.ID{}, "127.0.0.1", negBalance{logValue: 10}, false},
+		{enode.ID{}, "127.0.0.1", negBalance{logValue: 20}, false},
+	}
+	for _, c := range cases {
+		if c.positive {
+			ndb.setPB(c.id, c.balance.(posBalance))
+			if pb := ndb.getOrNewPB(c.id); !reflect.DeepEqual(pb, c.balance.(posBalance)) {
+				t.Fatalf("Positive balance mismatch, want %v, got %v", c.balance.(posBalance), pb)
+			}
+		} else {
+			ndb.setNB(c.ip, c.balance.(negBalance))
+			if nb := ndb.getOrNewNB(c.ip); !reflect.DeepEqual(nb, c.balance.(negBalance)) {
+				t.Fatalf("Negative balance mismatch, want %v, got %v", c.balance.(negBalance), nb)
+			}
+		}
+	}
+	for _, c := range cases {
+		if c.positive {
+			ndb.delPB(c.id)
+			if pb := ndb.getOrNewPB(c.id); !reflect.DeepEqual(pb, posBalance{}) {
+				t.Fatalf("Positive balance mismatch, want %v, got %v", posBalance{}, pb)
+			}
+		} else {
+			ndb.delNB(c.ip)
+			if nb := ndb.getOrNewNB(c.ip); !reflect.DeepEqual(nb, negBalance{}) {
+				t.Fatalf("Negative balance mismatch, want %v, got %v", negBalance{}, nb)
+			}
+		}
+	}
+	ndb.setCumulativeTime(100)
+	if ndb.getCumulativeTime() != 100 {
+		t.Fatalf("Cumulative time mismatch, want %v, got %v", 100, ndb.getCumulativeTime())
+	}
+}
+
+func TestNodeDBExpiration(t *testing.T) {
+	var (
iterated int
+		done     = make(chan struct{}, 1)
+	)
+	callback := func(now mclock.AbsTime, b negBalance) bool {
+		iterated += 1
+		return true
+	}
+	clock := &mclock.Simulated{}
+	ndb := newNodeDB(rawdb.NewMemoryDatabase(), clock)
+	defer ndb.close()
+	ndb.nbEvictCallBack = callback
+	ndb.cleanupHook = func() { done <- struct{}{} }
+
+	var cases = []struct {
+		ip      string
+		balance negBalance
+	}{
+		{"127.0.0.1", negBalance{logValue: 1}},
+		{"127.0.0.2", negBalance{logValue: 1}},
+		{"127.0.0.3", negBalance{logValue: 1}},
+		{"127.0.0.4", negBalance{logValue: 1}},
+	}
+	for _, c := range cases {
+		ndb.setNB(c.ip, c.balance)
+	}
+	time.Sleep(100 * time.Millisecond) // Ensure the db expirer is registered.
+	clock.Run(time.Hour + time.Minute)
+	select {
+	case <-done:
+	case <-time.NewTimer(time.Second).C:
+		t.Fatalf("timeout")
+	}
+	if iterated != 4 {
+		t.Fatalf("Failed to evict useless negative balances, want %v, got %d", 4, iterated)
+	}
+
+	for _, c := range cases {
+		ndb.setNB(c.ip, c.balance)
+	}
+	clock.Run(time.Hour + time.Minute)
+	select {
+	case <-done:
+	case <-time.NewTimer(time.Second).C:
+		t.Fatalf("timeout")
+	}
+	if iterated != 8 {
+		t.Fatalf("Failed to evict useless negative balances, want %v, got %d", 8, iterated)
+	}
+}
diff --git a/les/costtracker.go b/les/costtracker.go
index d1f5b54ca..81da04566 100644
--- a/les/costtracker.go
+++ b/les/costtracker.go
@@ -28,6 +28,7 @@ import (
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/les/flowcontrol"
	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
)

const makeCostStats = false // make request cost statistics during operation

@@ -87,7 +88,7 @@ const (
	gfUsageTC = time.Second
	gfRaiseTC = time.Second * 200
	gfDropTC  = time.Second * 50
-	gfDbKey   = "_globalCostFactorV3"
+	gfDbKey   = "_globalCostFactorV6"
)

// costTracker is responsible for calculating costs and cost estimates on the
@@ -226,6 +227,9 @@ type reqInfo struct {
	// servingTime is the CPU time corresponding to the actual processing of
	// the request.
	servingTime float64
+
+	// msgCode indicates the type of request.
+	msgCode uint64
}

// gfLoop starts an event loop which updates the global cost factor which is
@@ -269,11 +273,43 @@ func (ct *costTracker) gfLoop() {
		for {
			select {
			case r := <-ct.reqInfoCh:
+				relCost := int64(factor * r.servingTime * 100 / r.avgTimeCost) // Convert the value to a percentage form
+
+				// Record more metrics if we are debugging
+				if metrics.EnabledExpensive {
+					switch r.msgCode {
+					case GetBlockHeadersMsg:
+						relativeCostHeaderHistogram.Update(relCost)
+					case GetBlockBodiesMsg:
+						relativeCostBodyHistogram.Update(relCost)
+					case GetReceiptsMsg:
+						relativeCostReceiptHistogram.Update(relCost)
+					case GetCodeMsg:
+						relativeCostCodeHistogram.Update(relCost)
+					case GetProofsV2Msg:
+						relativeCostProofHistogram.Update(relCost)
+					case GetHelperTrieProofsMsg:
+						relativeCostHelperProofHistogram.Update(relCost)
+					case SendTxV2Msg:
+						relativeCostSendTxHistogram.Update(relCost)
+					case GetTxStatusMsg:
+						relativeCostTxStatusHistogram.Update(relCost)
+					}
+				}
+				// SendTxV2 and GetTxStatus requests are two special cases.
+				// All other requests only put pressure on the database, and
+				// the corresponding delay is relatively stable, while these
+				// two requests involve a txpool query, which is usually
+				// unstable.
+				//
+				// TODO(rjl493456442) fix this.
+ if r.msgCode == SendTxV2Msg || r.msgCode == GetTxStatusMsg { + continue + } requestServedMeter.Mark(int64(r.servingTime)) requestServedTimer.Update(time.Duration(r.servingTime)) requestEstimatedMeter.Mark(int64(r.avgTimeCost / factor)) requestEstimatedTimer.Update(time.Duration(r.avgTimeCost / factor)) - relativeCostHistogram.Update(int64(r.avgTimeCost / factor / r.servingTime)) + relativeCostHistogram.Update(relCost) now := mclock.Now() dt := float64(now - expUpdate) @@ -324,6 +360,7 @@ func (ct *costTracker) gfLoop() { default: } } + globalFactorGauge.Update(int64(1000 * factor)) log.Debug("global cost factor updated", "factor", factor) } } @@ -375,7 +412,7 @@ func (ct *costTracker) updateStats(code, amount, servingTime, realCost uint64) { avg := reqAvgTimeCost[code] avgTimeCost := avg.baseCost + amount*avg.reqCost select { - case ct.reqInfoCh <- reqInfo{float64(avgTimeCost), float64(servingTime)}: + case ct.reqInfoCh <- reqInfo{float64(avgTimeCost), float64(servingTime), code}: default: } if makeCostStats { diff --git a/les/distributor.go b/les/distributor.go index 62abef47d..6d8114972 100644 --- a/les/distributor.go +++ b/les/distributor.go @@ -110,13 +110,15 @@ func (d *requestDistributor) registerTestPeer(p distPeer) { d.peerLock.Unlock() } -// distMaxWait is the maximum waiting time after which further necessary waiting -// times are recalculated based on new feedback from the servers -const distMaxWait = time.Millisecond * 50 +var ( + // distMaxWait is the maximum waiting time after which further necessary waiting + // times are recalculated based on new feedback from the servers + distMaxWait = time.Millisecond * 50 -// waitForPeers is the time window in which a request does not fail even if it -// has no suitable peers to send to at the moment -const waitForPeers = time.Second * 3 + // waitForPeers is the time window in which a request does not fail even if it + // has no suitable peers to send to at the moment + waitForPeers = time.Second * 3 +) // main event loop func (d *requestDistributor) loop() { diff --git a/les/distributor_test.go b/les/distributor_test.go index 00d43e1d6..539c1aa7d 100644 --- a/les/distributor_test.go +++ b/les/distributor_test.go @@ -86,8 +86,8 @@ func (p *testDistPeer) worker(t *testing.T, checkOrder bool, stop chan struct{}) const ( testDistBufLimit = 10000000 testDistMaxCost = 1000000 - testDistPeerCount = 5 - testDistReqCount = 5000 + testDistPeerCount = 2 + testDistReqCount = 10 testDistMaxResendCount = 3 ) @@ -128,6 +128,9 @@ func testRequestDistributor(t *testing.T, resend bool) { go peers[i].worker(t, !resend, stop) dist.registerTestPeer(peers[i]) } + // Disable the mechanism that we will wait a few time for request + // even there is no suitable peer to send right now. + waitForPeers = 0 var wg sync.WaitGroup diff --git a/les/enr_entry.go b/les/enr_entry.go new file mode 100644 index 000000000..c2a92dd99 --- /dev/null +++ b/les/enr_entry.go @@ -0,0 +1,32 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "github.com/ethereum/go-ethereum/rlp" +) + +// lesEntry is the "les" ENR entry. This is set for LES servers only. +type lesEntry struct { + // Ignore additional fields (for forward compatibility). + Rest []rlp.RawValue `rlp:"tail"` +} + +// ENRKey implements enr.Entry. +func (e lesEntry) ENRKey() string { + return "les" +} diff --git a/les/metrics.go b/les/metrics.go index 797631b8e..9ef8c3651 100644 --- a/les/metrics.go +++ b/les/metrics.go @@ -60,6 +60,15 @@ var ( miscOutTxStatusPacketsMeter = metrics.NewRegisteredMeter("les/misc/out/packets/txStatus", nil) miscOutTxStatusTrafficMeter = metrics.NewRegisteredMeter("les/misc/out/traffic/txStatus", nil) + miscServingTimeHeaderTimer = metrics.NewRegisteredTimer("les/misc/serve/header", nil) + miscServingTimeBodyTimer = metrics.NewRegisteredTimer("les/misc/serve/body", nil) + miscServingTimeCodeTimer = metrics.NewRegisteredTimer("les/misc/serve/code", nil) + miscServingTimeReceiptTimer = metrics.NewRegisteredTimer("les/misc/serve/receipt", nil) + miscServingTimeTrieProofTimer = metrics.NewRegisteredTimer("les/misc/serve/proof", nil) + miscServingTimeHelperTrieTimer = metrics.NewRegisteredTimer("les/misc/serve/helperTrie", nil) + miscServingTimeTxTimer = metrics.NewRegisteredTimer("les/misc/serve/txs", nil) + miscServingTimeTxStatusTimer = metrics.NewRegisteredTimer("les/misc/serve/txStatus", nil) + connectionTimer = metrics.NewRegisteredTimer("les/connection/duration", nil) serverConnectionGauge = metrics.NewRegisteredGauge("les/connection/server", nil) clientConnectionGauge = metrics.NewRegisteredGauge("les/connection/client", nil) @@ -69,12 +78,21 @@ var ( totalConnectedGauge = metrics.NewRegisteredGauge("les/server/totalConnected", nil) blockProcessingTimer = metrics.NewRegisteredTimer("les/server/blockProcessingTime", nil) - requestServedMeter = metrics.NewRegisteredMeter("les/server/req/avgServedTime", nil) - requestServedTimer = metrics.NewRegisteredTimer("les/server/req/servedTime", nil) - requestEstimatedMeter = metrics.NewRegisteredMeter("les/server/req/avgEstimatedTime", nil) - requestEstimatedTimer = metrics.NewRegisteredTimer("les/server/req/estimatedTime", nil) - relativeCostHistogram = metrics.NewRegisteredHistogram("les/server/req/relative", nil, metrics.NewExpDecaySample(1028, 0.015)) + requestServedMeter = metrics.NewRegisteredMeter("les/server/req/avgServedTime", nil) + requestServedTimer = metrics.NewRegisteredTimer("les/server/req/servedTime", nil) + requestEstimatedMeter = metrics.NewRegisteredMeter("les/server/req/avgEstimatedTime", nil) + requestEstimatedTimer = metrics.NewRegisteredTimer("les/server/req/estimatedTime", nil) + relativeCostHistogram = metrics.NewRegisteredHistogram("les/server/req/relative", nil, metrics.NewExpDecaySample(1028, 0.015)) + relativeCostHeaderHistogram = metrics.NewRegisteredHistogram("les/server/req/relative/header", nil, metrics.NewExpDecaySample(1028, 0.015)) + relativeCostBodyHistogram = metrics.NewRegisteredHistogram("les/server/req/relative/body", nil, metrics.NewExpDecaySample(1028, 0.015)) + relativeCostReceiptHistogram = metrics.NewRegisteredHistogram("les/server/req/relative/receipt", nil, metrics.NewExpDecaySample(1028, 0.015)) + relativeCostCodeHistogram = metrics.NewRegisteredHistogram("les/server/req/relative/code", nil, metrics.NewExpDecaySample(1028, 
0.015))
+	relativeCostProofHistogram       = metrics.NewRegisteredHistogram("les/server/req/relative/proof", nil, metrics.NewExpDecaySample(1028, 0.015))
+	relativeCostHelperProofHistogram = metrics.NewRegisteredHistogram("les/server/req/relative/helperTrie", nil, metrics.NewExpDecaySample(1028, 0.015))
+	relativeCostSendTxHistogram      = metrics.NewRegisteredHistogram("les/server/req/relative/txs", nil, metrics.NewExpDecaySample(1028, 0.015))
+	relativeCostTxStatusHistogram    = metrics.NewRegisteredHistogram("les/server/req/relative/txStatus", nil, metrics.NewExpDecaySample(1028, 0.015))
+	globalFactorGauge                = metrics.NewRegisteredGauge("les/server/globalFactor", nil)

	recentServedGauge    = metrics.NewRegisteredGauge("les/server/recentRequestServed", nil)
	recentEstimatedGauge = metrics.NewRegisteredGauge("les/server/recentRequestEstimated", nil)
	sqServedGauge        = metrics.NewRegisteredGauge("les/server/servingQueue/served", nil)
diff --git a/les/odr_test.go b/les/odr_test.go
index 97217e948..7d1087822 100644
--- a/les/odr_test.go
+++ b/les/odr_test.go
@@ -188,6 +188,15 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od

	client.handler.synchronise(client.peer.peer)

+	// Ensure the client has synced all necessary data.
+	clientHead := client.handler.backend.blockchain.CurrentHeader()
+	if clientHead.Number.Uint64() != 4 {
+		t.Fatalf("Failed to sync the chain with server, head: %v", clientHead.Number.Uint64())
+	}
+	// Disable the mechanism that waits for a while for a request
+	// even if there is no suitable peer to send it to right now.
+	waitForPeers = 0
+
	test := func(expFail uint64) {
		// Mark this as a helper to put the failures at the correct lines
		t.Helper()
@@ -196,7 +205,9 @@ func testOdr(t *testing.T, protocol int, expFail uint64, checkCached bool, fn od
			bhash := rawdb.ReadCanonicalHash(server.db, i)
			b1 := fn(light.NoOdr, server.db, server.handler.server.chainConfig, server.handler.blockchain, nil, bhash)

-			ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+			// Set the timeout to 1 second here to ensure there is enough
+			// time for Travis to complete the request.
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			b2 := fn(ctx, client.db, client.handler.backend.chainConfig, nil, client.handler.backend.blockchain, bhash)
			cancel()
diff --git a/les/request_test.go b/les/request_test.go
index 69b57ca31..8d09703c5 100644
--- a/les/request_test.go
+++ b/les/request_test.go
@@ -81,8 +81,15 @@ func testAccess(t *testing.T, protocol int, fn accessTestFn) {
	// Assemble the test environment
	server, client, tearDown := newClientServerEnv(t, 4, protocol, nil, nil, 0, false, true)
	defer tearDown()
+
	client.handler.synchronise(client.peer.peer)

+	// Ensure the client has synced all necessary data.
+ clientHead := client.handler.backend.blockchain.CurrentHeader() + if clientHead.Number.Uint64() != 4 { + t.Fatalf("Failed to sync the chain with server, head: %v", clientHead.Number.Uint64()) + } + test := func(expFail uint64) { for i := uint64(0); i <= server.handler.blockchain.CurrentHeader().Number.Uint64(); i++ { bhash := rawdb.ReadCanonicalHash(server.db, i) diff --git a/les/server.go b/les/server.go index 8e790323f..997a24191 100644 --- a/les/server.go +++ b/les/server.go @@ -30,6 +30,7 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" ) @@ -112,7 +113,8 @@ func NewLesServer(e *eth.Ethereum, config *eth.Config) (*LesServer, error) { maxCapacity = totalRecharge } srv.fcManager.SetCapacityLimits(srv.freeCapacity, maxCapacity, srv.freeCapacity*2) - srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, 10000, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) }) + srv.clientPool = newClientPool(srv.chainDb, srv.freeCapacity, mclock.System{}, func(id enode.ID) { go srv.peers.Unregister(peerIdToString(id)) }) + srv.clientPool.setPriceFactors(priceFactors{0, 1, 1}, priceFactors{0, 1, 1}) checkpoint := srv.latestLocalCheckpoint() if !checkpoint.Empty() { @@ -135,12 +137,17 @@ func (s *LesServer) APIs() []rpc.API { } func (s *LesServer) Protocols() []p2p.Protocol { - return s.makeProtocols(ServerProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} { + ps := s.makeProtocols(ServerProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} { if p := s.peers.Peer(peerIdToString(id)); p != nil { return p.Info() } return nil }) + // Add "les" ENR entries. + for i := range ps { + ps[i].Attributes = []enr.Entry{&lesEntry{}} + } + return ps } // Start starts the LES server @@ -176,9 +183,9 @@ func (s *LesServer) Stop() { s.peers.Close() s.fcManager.Stop() - s.clientPool.stop() s.costTracker.stop() s.handler.stop() + s.clientPool.stop() // client pool should be closed after handler. s.servingQueue.stop() // Note, bloom trie indexer is closed by parent bloombits indexer. 
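
The lesEntry type from les/enr_entry.go, attached to the protocols just above via Attributes, is easiest to see with a small usage sketch. This is illustrative only and not part of the patch; it reuses the public enr API, and the main function and record handling are assumptions for the example:

    package main

    import (
    	"fmt"

    	"github.com/ethereum/go-ethereum/p2p/enr"
    	"github.com/ethereum/go-ethereum/rlp"
    )

    // lesEntry mirrors the type added by this patch: an otherwise empty entry
    // whose presence under the "les" key advertises LES server support.
    type lesEntry struct {
    	// Ignore additional fields (for forward compatibility).
    	Rest []rlp.RawValue `rlp:"tail"`
    }

    // ENRKey implements enr.Entry.
    func (e lesEntry) ENRKey() string { return "les" }

    func main() {
    	var record enr.Record
    	record.Set(&lesEntry{}) // stores the RLP encoding under the "les" key

    	var entry lesEntry
    	fmt.Println("advertises les:", record.Load(&entry) == nil)
    }

Because the entry carries no fields beyond the forward-compatibility tail, its mere presence in the record is the whole signal.
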
diff --git a/les/server_handler.go b/les/server_handler.go index 79c0a08a9..16249ef1b 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -268,6 +268,7 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { if metrics.EnabledExpensive { miscInHeaderPacketsMeter.Mark(1) miscInHeaderTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeHeaderTimer.UpdateSince(start) }(time.Now()) } var req struct { ReqID uint64 @@ -380,6 +381,7 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { if metrics.EnabledExpensive { miscInBodyPacketsMeter.Mark(1) miscInBodyTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeBodyTimer.UpdateSince(start) }(time.Now()) } var req struct { ReqID uint64 @@ -428,6 +430,7 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { if metrics.EnabledExpensive { miscInCodePacketsMeter.Mark(1) miscInCodeTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeCodeTimer.UpdateSince(start) }(time.Now()) } var req struct { ReqID uint64 @@ -499,6 +502,7 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { if metrics.EnabledExpensive { miscInReceiptPacketsMeter.Mark(1) miscInReceiptTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeReceiptTimer.UpdateSince(start) }(time.Now()) } var req struct { ReqID uint64 @@ -555,6 +559,7 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { if metrics.EnabledExpensive { miscInTrieProofPacketsMeter.Mark(1) miscInTrieProofTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeTrieProofTimer.UpdateSince(start) }(time.Now()) } var req struct { ReqID uint64 @@ -657,6 +662,7 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { if metrics.EnabledExpensive { miscInHelperTriePacketsMeter.Mark(1) miscInHelperTrieTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeHelperTrieTimer.UpdateSince(start) }(time.Now()) } var req struct { ReqID uint64 @@ -731,6 +737,7 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { if metrics.EnabledExpensive { miscInTxsPacketsMeter.Mark(1) miscInTxsTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeTxTimer.UpdateSince(start) }(time.Now()) } var req struct { ReqID uint64 @@ -779,6 +786,7 @@ func (h *serverHandler) handleMsg(p *peer, wg *sync.WaitGroup) error { if metrics.EnabledExpensive { miscInTxStatusPacketsMeter.Mark(1) miscInTxStatusTrafficMeter.Mark(int64(msg.Size)) + defer func(start time.Time) { miscServingTimeTxStatusTimer.UpdateSince(start) }(time.Now()) } var req struct { ReqID uint64 diff --git a/les/sync.go b/les/sync.go index 693394464..1214fefca 100644 --- a/les/sync.go +++ b/les/sync.go @@ -135,21 +135,24 @@ func (h *clientHandler) synchronise(peer *peer) { mode = legacyCheckpointSync log.Debug("Disable checkpoint syncing", "reason", "checkpoint is hardcoded") case h.backend.oracle == nil || !h.backend.oracle.isRunning(): - mode = legacyCheckpointSync + if h.checkpoint == nil { + mode = lightSync // Downgrade to light sync unfortunately. + } else { + checkpoint = h.checkpoint + mode = legacyCheckpointSync + } log.Debug("Disable checkpoint syncing", "reason", "checkpoint syncing is not activated") } // Notify testing framework if syncing has completed(for testing purpose). 
defer func() { - if h.backend.oracle != nil && h.backend.oracle.syncDoneHook != nil { - h.backend.oracle.syncDoneHook() + if h.syncDone != nil { + h.syncDone() } }() start := time.Now() if mode == checkpointSync || mode == legacyCheckpointSync { // Validate the advertised checkpoint - if mode == legacyCheckpointSync { - checkpoint = h.checkpoint - } else if mode == checkpointSync { + if mode == checkpointSync { if err := h.validateCheckpoint(peer); err != nil { log.Debug("Failed to validate checkpoint", "reason", err) h.removePeer(peer.id) diff --git a/les/sync_test.go b/les/sync_test.go index 63833c1ab..8df6223b8 100644 --- a/les/sync_test.go +++ b/les/sync_test.go @@ -30,17 +30,14 @@ import ( ) // Test light syncing which will download all headers from genesis. -func TestLightSyncingLes2(t *testing.T) { testCheckpointSyncing(t, 2, 0) } func TestLightSyncingLes3(t *testing.T) { testCheckpointSyncing(t, 3, 0) } // Test legacy checkpoint syncing which will download tail headers // based on a hardcoded checkpoint. -func TestLegacyCheckpointSyncingLes2(t *testing.T) { testCheckpointSyncing(t, 2, 1) } func TestLegacyCheckpointSyncingLes3(t *testing.T) { testCheckpointSyncing(t, 3, 1) } // Test checkpoint syncing which will download tail headers based // on a verified checkpoint. -func TestCheckpointSyncingLes2(t *testing.T) { testCheckpointSyncing(t, 2, 2) } func TestCheckpointSyncingLes3(t *testing.T) { testCheckpointSyncing(t, 3, 2) } func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { @@ -92,7 +89,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { for { _, hash, _, err := server.handler.server.oracle.contract.Contract().GetLatestCheckpoint(nil) if err != nil || hash == [32]byte{} { - time.Sleep(100 * time.Millisecond) + time.Sleep(10 * time.Millisecond) continue } break @@ -102,7 +99,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { } done := make(chan error) - client.handler.backend.oracle.syncDoneHook = func() { + client.handler.syncDone = func() { header := client.handler.backend.blockchain.CurrentHeader() if header.Number.Uint64() == expected { done <- nil @@ -131,3 +128,102 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { t.Error("checkpoint syncing timeout") } } + +func TestMissOracleBackend(t *testing.T) { testMissOracleBackend(t, true) } +func TestMissOracleBackendNoCheckpoint(t *testing.T) { testMissOracleBackend(t, false) } + +func testMissOracleBackend(t *testing.T, hasCheckpoint bool) { + config := light.TestServerIndexerConfig + + waitIndexers := func(cIndexer, bIndexer, btIndexer *core.ChainIndexer) { + for { + cs, _, _ := cIndexer.Sections() + bts, _, _ := btIndexer.Sections() + if cs >= 1 && bts >= 1 { + break + } + time.Sleep(10 * time.Millisecond) + } + } + // Generate 512+4 blocks (totally 1 CHT sections) + server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), 3, waitIndexers, nil, 0, false, false) + defer tearDown() + + expected := config.ChtSize + config.ChtConfirms + + s, _, head := server.chtIndexer.Sections() + cp := ¶ms.TrustedCheckpoint{ + SectionIndex: 0, + SectionHead: head, + CHTRoot: light.GetChtRoot(server.db, s-1, head), + BloomRoot: light.GetBloomTrieRoot(server.db, s-1, head), + } + // Register the assembled checkpoint into oracle. + header := server.backend.Blockchain().CurrentHeader() + + data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...) 
+ sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey) + sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper + if _, err := server.handler.server.oracle.contract.RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil { + t.Error("register checkpoint failed", err) + } + server.backend.Commit() + + // Wait for the checkpoint registration + for { + _, hash, _, err := server.handler.server.oracle.contract.Contract().GetLatestCheckpoint(nil) + if err != nil || hash == [32]byte{} { + time.Sleep(100 * time.Millisecond) + continue + } + break + } + expected += 1 + + // Explicitly set the oracle as nil. In normal use case it can happen + // that user wants to unlock something which blocks the oracle backend + // initialisation. But at the same time syncing starts. + // + // See https://github.com/ethereum/go-ethereum/issues/20097 for more detail. + // + // In this case, client should run light sync or legacy checkpoint sync + // if hardcoded checkpoint is configured. + client.handler.backend.oracle = nil + + // For some private networks it can happen checkpoint syncing is enabled + // but there is no hardcoded checkpoint configured. + if hasCheckpoint { + client.handler.checkpoint = cp + client.handler.backend.blockchain.AddTrustedCheckpoint(cp) + } + + done := make(chan error) + client.handler.syncDone = func() { + header := client.handler.backend.blockchain.CurrentHeader() + if header.Number.Uint64() == expected { + done <- nil + } else { + done <- fmt.Errorf("blockchain length mismatch, want %d, got %d", expected, header.Number) + } + } + + // Create connected peer pair. + _, err1, _, err2 := newTestPeerPair("peer", 2, server.handler, client.handler) + select { + case <-time.After(time.Millisecond * 100): + case err := <-err1: + t.Fatalf("peer 1 handshake error: %v", err) + case err := <-err2: + t.Fatalf("peer 2 handshake error: %v", err) + } + + select { + case err := <-done: + if err != nil { + t.Error("sync failed", err) + } + return + case <-time.NewTimer(10 * time.Second).C: + t.Error("checkpoint syncing timeout") + } +} diff --git a/les/test_helper.go b/les/test_helper.go index 79cf323d6..ee3d7a32e 100644 --- a/les/test_helper.go +++ b/les/test_helper.go @@ -71,10 +71,10 @@ var ( var ( // The block frequency for creating checkpoint(only used in test) - sectionSize = big.NewInt(512) + sectionSize = big.NewInt(128) // The number of confirmations needed to generate a checkpoint(only used in test). - processConfirms = big.NewInt(4) + processConfirms = big.NewInt(1) // The token bucket buffer limit for testing purpose. testBufLimit = uint64(1000000) @@ -280,7 +280,7 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da } server.costTracker, server.freeCapacity = newCostTracker(db, server.config) server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism. 
- server.clientPool = newClientPool(db, 1, 10000, clock, nil) + server.clientPool = newClientPool(db, 1, clock, nil) server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true }) if server.oracle != nil { @@ -517,7 +517,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer if connect { cpeer, err1, speer, err2 = newTestPeerPair("peer", protocol, server, client) select { - case <-time.After(time.Millisecond * 100): + case <-time.After(time.Millisecond * 300): case err := <-err1: t.Fatalf("peer 1 handshake error: %v", err) case err := <-err2: diff --git a/light/lightchain.go b/light/lightchain.go index 7f64d1c28..02b90138a 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -426,6 +426,11 @@ func (lc *LightChain) HasHeader(hash common.Hash, number uint64) bool { return lc.hc.HasHeader(hash, number) } +// GetCanonicalHash returns the canonical hash for a given block number +func (bc *LightChain) GetCanonicalHash(number uint64) common.Hash { + return bc.hc.GetCanonicalHash(number) +} + // GetBlockHashesFromHash retrieves a number of block hashes starting at a given // hash, fetching towards the genesis block. func (lc *LightChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []common.Hash { @@ -438,9 +443,6 @@ func (lc *LightChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []com // // Note: ancestor == 0 returns the same block, 1 returns its parent and so on. func (lc *LightChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) { - lc.chainmu.RLock() - defer lc.chainmu.RUnlock() - return lc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical) } diff --git a/light/postprocess.go b/light/postprocess.go index 083dcfceb..af3b25792 100644 --- a/light/postprocess.go +++ b/light/postprocess.go @@ -79,21 +79,21 @@ var ( } // TestServerIndexerConfig wraps a set of configs as a test indexer config for server side. TestServerIndexerConfig = &IndexerConfig{ - ChtSize: 512, - ChtConfirms: 4, - BloomSize: 64, - BloomConfirms: 4, - BloomTrieSize: 512, - BloomTrieConfirms: 4, + ChtSize: 128, + ChtConfirms: 1, + BloomSize: 16, + BloomConfirms: 1, + BloomTrieSize: 128, + BloomTrieConfirms: 1, } // TestClientIndexerConfig wraps a set of configs as a test indexer config for client side. TestClientIndexerConfig = &IndexerConfig{ - ChtSize: 512, - ChtConfirms: 32, - BloomSize: 512, - BloomConfirms: 32, - BloomTrieSize: 512, - BloomTrieConfirms: 32, + ChtSize: 128, + ChtConfirms: 8, + BloomSize: 128, + BloomConfirms: 8, + BloomTrieSize: 128, + BloomTrieConfirms: 8, } ) diff --git a/log/README.md b/log/README.md index b4476577b..47426806d 100644 --- a/log/README.md +++ b/log/README.md @@ -1,8 +1,8 @@ -![obligatory xkcd](http://imgs.xkcd.com/comics/standards.png) +![obligatory xkcd](https://imgs.xkcd.com/comics/standards.png) # log15 [![godoc reference](https://godoc.org/github.com/inconshreveable/log15?status.png)](https://godoc.org/github.com/inconshreveable/log15) [![Build Status](https://travis-ci.org/inconshreveable/log15.svg?branch=master)](https://travis-ci.org/inconshreveable/log15) -Package log15 provides an opinionated, simple toolkit for best-practice logging in Go (golang) that is both human and machine readable. 
It is modeled after the Go standard library's [`io`](http://golang.org/pkg/io/) and [`net/http`](http://golang.org/pkg/net/http/) packages and is an alternative to the standard library's [`log`](http://golang.org/pkg/log/) package. +Package log15 provides an opinionated, simple toolkit for best-practice logging in Go (golang) that is both human and machine readable. It is modeled after the Go standard library's [`io`](https://golang.org/pkg/io/) and [`net/http`](https://golang.org/pkg/net/http/) packages and is an alternative to the standard library's [`log`](https://golang.org/pkg/log/) package. ## Features - A simple, easy-to-understand API diff --git a/metrics/README.md b/metrics/README.md index bc2a45a83..e2d794500 100644 --- a/metrics/README.md +++ b/metrics/README.md @@ -5,7 +5,7 @@ go-metrics Go port of Coda Hale's Metrics library: <https://github.com/dropwizard/metrics>. -Documentation: <http://godoc.org/github.com/rcrowley/go-metrics>. +Documentation: <https://godoc.org/github.com/rcrowley/go-metrics>. Usage ----- @@ -128,7 +128,7 @@ go stathat.Stathat(metrics.DefaultRegistry, 10e9, "example@example.com") Maintain all metrics along with expvars at `/debug/metrics`: -This uses the same mechanism as [the official expvar](http://golang.org/pkg/expvar/) +This uses the same mechanism as [the official expvar](https://golang.org/pkg/expvar/) but exposed under `/debug/metrics`, which shows a json representation of all your usual expvars as well as all your go-metrics. diff --git a/metrics/gauge.go b/metrics/gauge.go index 0fbfdb860..b6b2758b0 100644 --- a/metrics/gauge.go +++ b/metrics/gauge.go @@ -6,6 +6,8 @@ import "sync/atomic" type Gauge interface { Snapshot() Gauge Update(int64) + Dec(int64) + Inc(int64) Value() int64 } @@ -65,6 +67,16 @@ func (GaugeSnapshot) Update(int64) { panic("Update called on a GaugeSnapshot") } +// Dec panics. +func (GaugeSnapshot) Dec(int64) { + panic("Dec called on a GaugeSnapshot") +} + +// Inc panics. +func (GaugeSnapshot) Inc(int64) { + panic("Inc called on a GaugeSnapshot") +} + // Value returns the value at the time the snapshot was taken. func (g GaugeSnapshot) Value() int64 { return int64(g) } @@ -77,6 +89,12 @@ func (NilGauge) Snapshot() Gauge { return NilGauge{} } // Update is a no-op. func (NilGauge) Update(v int64) {} +// Dec is a no-op. +func (NilGauge) Dec(i int64) {} + +// Inc is a no-op. +func (NilGauge) Inc(i int64) {} + // Value is a no-op. func (NilGauge) Value() int64 { return 0 } @@ -101,6 +119,16 @@ func (g *StandardGauge) Value() int64 { return atomic.LoadInt64(&g.value) } +// Dec decrements the gauge's current value by the given amount. +func (g *StandardGauge) Dec(i int64) { + atomic.AddInt64(&g.value, -i) +} + +// Inc increments the gauge's current value by the given amount. +func (g *StandardGauge) Inc(i int64) { + atomic.AddInt64(&g.value, i) +} + // FunctionalGauge returns value from given function type FunctionalGauge struct { value func() int64 } @@ -118,3 +146,13 @@ func (g FunctionalGauge) Snapshot() Gauge { return GaugeSnapshot(g.Value()) } func (FunctionalGauge) Update(int64) { panic("Update called on a FunctionalGauge") } + +// Dec panics. +func (FunctionalGauge) Dec(int64) { + panic("Dec called on a FunctionalGauge") +} + +// Inc panics.
+func (FunctionalGauge) Inc(int64) { + panic("Inc called on a FunctionalGauge") +} diff --git a/miner/worker.go b/miner/worker.go index 4a9528c39..183499ec3 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -704,7 +704,7 @@ func (w *worker) updateSnapshot() { func (w *worker) commitTransaction(tx *types.Transaction, coinbase common.Address) ([]*types.Log, error) { snap := w.current.state.Snapshot() - receipt, _, err := core.ApplyTransaction(w.chainConfig, w.chain, &coinbase, w.current.gasPool, w.current.state, w.current.header, tx, &w.current.header.GasUsed, *w.chain.GetVMConfig()) + receipt, err := core.ApplyTransaction(w.chainConfig, w.chain, &coinbase, w.current.gasPool, w.current.state, w.current.header, tx, &w.current.header.GasUsed, *w.chain.GetVMConfig()) if err != nil { w.current.state.RevertToSnapshot(snap) return nil, err diff --git a/miner/worker_test.go b/miner/worker_test.go index 1604e988d..faebee1a7 100644 --- a/miner/worker_test.go +++ b/miner/worker_test.go @@ -18,9 +18,11 @@ package miner import ( "math/big" + "math/rand" "testing" "time" + "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus/clique" @@ -35,6 +37,15 @@ import ( "github.com/ethereum/go-ethereum/params" ) +const ( + // testCode is the testing contract binary code which initialises some + // variables in its constructor. + testCode = "0x60806040527fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0060005534801561003457600080fd5b5060fc806100436000396000f3fe6080604052348015600f57600080fd5b506004361060325760003560e01c80630c4dae8814603757806398a213cf146053575b600080fd5b603d607e565b6040518082815260200191505060405180910390f35b607c60048036036020811015606757600080fd5b81019080803590602001909291905050506084565b005b60005481565b806000819055507fe9e44f9f7da8c559de847a3232b57364adc0354f15a2cd8dc636d54396f9587a6000546040518082815260200191505060405180910390a15056fea265627a7a723058208ae31d9424f2d0bc2a3da1a5dd659db2d71ec322a17db8f87e19e209e3a1ff4a64736f6c634300050a0032" + + // testGas is the gas required for contract deployment. + testGas = 144109 +) + var ( // Test chain configurations testTxPoolConfig core.TxPoolConfig @@ -73,6 +84,7 @@ func init() { pendingTxs = append(pendingTxs, tx1) tx2, _ := types.SignTx(types.NewTransaction(1, testUserAddress, big.NewInt(1000), params.TxGas, nil, nil), types.HomesteadSigner{}, testBankKey) newTxs = append(newTxs, tx2) + rand.Seed(time.Now().UnixNano()) } // testWorkerBackend implements worker.Backend interfaces and wraps all information needed during the testing.
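The Inc/Dec methods added to the Gauge interface above allow counter-style adjustments without a read-modify-write race, since StandardGauge backs them with atomic.AddInt64. A minimal usage sketch; the metric name is hypothetical:

import "github.com/ethereum/go-ethereum/metrics"

func trackPeers() {
	// GetOrRegisterGauge returns a StandardGauge when metrics are enabled
	// and a NilGauge (whose Inc/Dec are no-ops) when they are disabled.
	// A nil registry selects metrics.DefaultRegistry.
	peers := metrics.GetOrRegisterGauge("p2p/peers", nil)
	peers.Inc(1) // a peer connected
	peers.Dec(1) // a peer disconnected
}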
@@ -81,29 +93,30 @@ type testWorkerBackend struct { txPool *core.TxPool chain *core.BlockChain testTxFeed event.Feed + genesis *core.Genesis uncleBlock *types.Block } -func newTestWorkerBackend(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine, n int) *testWorkerBackend { - var ( - db = rawdb.NewMemoryDatabase() - gspec = core.Genesis{ - Config: chainConfig, - Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}, - } - ) +func newTestWorkerBackend(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine, db ethdb.Database, n int) *testWorkerBackend { + var gspec = core.Genesis{ + Config: chainConfig, + Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}, + } - switch engine.(type) { + switch e := engine.(type) { case *clique.Clique: gspec.ExtraData = make([]byte, 32+common.AddressLength+crypto.SignatureLength) - copy(gspec.ExtraData[32:], testBankAddress[:]) + copy(gspec.ExtraData[32:32+common.AddressLength], testBankAddress.Bytes()) + e.Authorize(testBankAddress, func(account accounts.Account, s string, data []byte) ([]byte, error) { + return crypto.Sign(crypto.Keccak256(data), testBankKey) + }) case *ethash.Ethash: default: t.Fatalf("unexpected consensus engine type: %T", engine) } genesis := gspec.MustCommit(db) - chain, _ := core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil) + chain, _ := core.NewBlockChain(db, &core.CacheConfig{TrieDirtyDisabled: true}, gspec.Config, engine, vm.Config{}, nil) txpool := core.NewTxPool(testTxPoolConfig, chainConfig, chain) // Generate a small n-block chain and an uncle block for it @@ -127,6 +140,7 @@ func newTestWorkerBackend(t *testing.T, chainConfig *params.ChainConfig, engine db: db, chain: chain, txPool: txpool, + genesis: &gspec, uncleBlock: blocks[0], } } @@ -137,14 +151,124 @@ func (b *testWorkerBackend) PostChainEvents(events []interface{}) { b.chain.PostChainEvents(events, nil) } -func newTestWorker(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine, blocks int) (*worker, *testWorkerBackend) { - backend := newTestWorkerBackend(t, chainConfig, engine, blocks) +func (b *testWorkerBackend) newRandomUncle() *types.Block { + var parent *types.Block + cur := b.chain.CurrentBlock() + if cur.NumberU64() == 0 { + parent = b.chain.Genesis() + } else { + parent = b.chain.GetBlockByHash(b.chain.CurrentBlock().ParentHash()) + } + blocks, _ := core.GenerateChain(b.chain.Config(), parent, b.chain.Engine(), b.db, 1, func(i int, gen *core.BlockGen) { + var addr = make([]byte, common.AddressLength) + rand.Read(addr) + gen.SetCoinbase(common.BytesToAddress(addr)) + }) + return blocks[0] +} + +func (b *testWorkerBackend) newRandomTx(creation bool) *types.Transaction { + var tx *types.Transaction + if creation { + tx, _ = types.SignTx(types.NewContractCreation(b.txPool.Nonce(testBankAddress), big.NewInt(0), testGas, nil, common.FromHex(testCode)), types.HomesteadSigner{}, testBankKey) + } else { + tx, _ = types.SignTx(types.NewTransaction(b.txPool.Nonce(testBankAddress), testUserAddress, big.NewInt(1000), params.TxGas, nil, nil), types.HomesteadSigner{}, testBankKey) + } + return tx +} + +func newTestWorker(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine, db ethdb.Database, blocks int) (*worker, *testWorkerBackend) { + backend := newTestWorkerBackend(t, chainConfig, engine, db, blocks) backend.txPool.AddLocals(pendingTxs) w := newWorker(testConfig, chainConfig, engine, backend, new(event.TypeMux), nil) w.setEtherbase(testBankAddress) return 
w, backend } +func TestGenerateBlockAndImportEthash(t *testing.T) { + testGenerateBlockAndImport(t, false) +} + +func TestGenerateBlockAndImportClique(t *testing.T) { + testGenerateBlockAndImport(t, true) +} + +func testGenerateBlockAndImport(t *testing.T, isClique bool) { + var ( + engine consensus.Engine + chainConfig *params.ChainConfig + db = rawdb.NewMemoryDatabase() + ) + if isClique { + chainConfig = params.AllCliqueProtocolChanges + chainConfig.Clique = ¶ms.CliqueConfig{Period: 1, Epoch: 30000} + engine = clique.New(chainConfig.Clique, db) + } else { + chainConfig = params.AllEthashProtocolChanges + engine = ethash.NewFaker() + } + + w, b := newTestWorker(t, chainConfig, engine, db, 0) + defer w.close() + + db2 := rawdb.NewMemoryDatabase() + b.genesis.MustCommit(db2) + chain, _ := core.NewBlockChain(db2, nil, b.chain.Config(), engine, vm.Config{}, nil) + defer chain.Stop() + + newBlock := make(chan struct{}) + listenNewBlock := func() { + sub := w.mux.Subscribe(core.NewMinedBlockEvent{}) + defer sub.Unsubscribe() + + for item := range sub.Chan() { + block := item.Data.(core.NewMinedBlockEvent).Block + _, err := chain.InsertChain([]*types.Block{block}) + if err != nil { + t.Fatalf("Failed to insert new mined block:%d, error:%v", block.NumberU64(), err) + } + newBlock <- struct{}{} + } + } + + // Ensure worker has finished initialization + for { + b := w.pendingBlock() + if b != nil && b.NumberU64() == 1 { + break + } + } + w.start() // Start mining! + + // Ignore first 2 commits caused by start operation + ignored := make(chan struct{}, 2) + w.skipSealHook = func(task *task) bool { + ignored <- struct{}{} + return true + } + for i := 0; i < 2; i++ { + <-ignored + } + + go listenNewBlock() + + // Ignore empty commit here for less noise + w.skipSealHook = func(task *task) bool { + return len(task.receipts) == 0 + } + for i := 0; i < 5; i++ { + b.txPool.AddLocal(b.newRandomTx(true)) + b.txPool.AddLocal(b.newRandomTx(false)) + b.PostChainEvents([]interface{}{core.ChainSideEvent{Block: b.newRandomUncle()}}) + b.PostChainEvents([]interface{}{core.ChainSideEvent{Block: b.newRandomUncle()}}) + select { + case <-newBlock: + case <-time.NewTimer(3 * time.Second).C: // Worker needs 1s to include new changes. + t.Fatalf("timeout") + } + } +} + func TestPendingStateAndBlockEthash(t *testing.T) { testPendingStateAndBlock(t, ethashChainConfig, ethash.NewFaker()) } @@ -155,7 +279,7 @@ func TestPendingStateAndBlockClique(t *testing.T) { func testPendingStateAndBlock(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine) { defer engine.Close() - w, b := newTestWorker(t, chainConfig, engine, 0) + w, b := newTestWorker(t, chainConfig, engine, rawdb.NewMemoryDatabase(), 0) defer w.close() // Ensure snapshot has been updated. 
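The newRandomTx helper used above distinguishes contract creation from plain transfers only by the transaction constructor it calls. A compact sketch of the two forms, assuming the test fixtures (testBankKey, testUserAddress, testGas, testCode) from this file are in scope; the helper name is hypothetical:

// makeTx builds either a contract-creation or a value-transfer transaction
// and signs it with the Homestead signer, mirroring newRandomTx.
func makeTx(nonce uint64, creation bool) (*types.Transaction, error) {
	if creation {
		tx := types.NewContractCreation(nonce, big.NewInt(0), testGas, nil, common.FromHex(testCode))
		return types.SignTx(tx, types.HomesteadSigner{}, testBankKey)
	}
	tx := types.NewTransaction(nonce, testUserAddress, big.NewInt(1000), params.TxGas, nil, nil)
	return types.SignTx(tx, types.HomesteadSigner{}, testBankKey)
}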
@@ -187,7 +311,7 @@ func TestEmptyWorkClique(t *testing.T) { func testEmptyWork(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine) { defer engine.Close() - w, _ := newTestWorker(t, chainConfig, engine, 0) + w, _ := newTestWorker(t, chainConfig, engine, rawdb.NewMemoryDatabase(), 0) defer w.close() var ( @@ -241,7 +365,7 @@ func TestStreamUncleBlock(t *testing.T) { ethash := ethash.NewFaker() defer ethash.Close() - w, b := newTestWorker(t, ethashChainConfig, ethash, 1) + w, b := newTestWorker(t, ethashChainConfig, ethash, rawdb.NewMemoryDatabase(), 1) defer w.close() var taskCh = make(chan struct{}) @@ -304,7 +428,7 @@ func TestRegenerateMiningBlockClique(t *testing.T) { func testRegenerateMiningBlock(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine) { defer engine.Close() - w, b := newTestWorker(t, chainConfig, engine, 0) + w, b := newTestWorker(t, chainConfig, engine, rawdb.NewMemoryDatabase(), 0) defer w.close() var taskCh = make(chan struct{}) @@ -369,7 +493,7 @@ func TestAdjustIntervalClique(t *testing.T) { func testAdjustInterval(t *testing.T, chainConfig *params.ChainConfig, engine consensus.Engine) { defer engine.Close() - w, _ := newTestWorker(t, chainConfig, engine, 0) + w, _ := newTestWorker(t, chainConfig, engine, rawdb.NewMemoryDatabase(), 0) defer w.close() w.skipSealHook = func(task *task) bool { diff --git a/p2p/dial.go b/p2p/dial.go index 8dee5063f..68e06cce5 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -33,12 +33,7 @@ const ( // private networks. dialHistoryExpiration = inboundThrottleTime + 5*time.Second - // Discovery lookups are throttled and can only run - // once every few seconds. - lookupInterval = 4 * time.Second - - // If no peers are found for this amount of time, the initial bootnodes are - // attempted to be connected. + // If no peers are found for this amount of time, the initial bootnodes are dialed. fallbackInterval = 20 * time.Second // Endpoint resolution is throttled with bounded backoff. @@ -52,6 +47,10 @@ type NodeDialer interface { Dial(*enode.Node) (net.Conn, error) } +type nodeResolver interface { + Resolve(*enode.Node) *enode.Node +} + // TCPDialer implements the NodeDialer interface by using a net.Dialer to // create TCP connections to nodes in the network type TCPDialer struct { @@ -69,7 +68,6 @@ func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) { // of the main loop in Server.run. type dialstate struct { maxDynDials int - ntab discoverTable netrestrict *netutil.Netlist self enode.ID bootnodes []*enode.Node // default dials when there are no peers @@ -79,55 +77,23 @@ type dialstate struct { lookupRunning bool dialing map[enode.ID]connFlag lookupBuf []*enode.Node // current discovery lookup results - randomNodes []*enode.Node // filled from Table static map[enode.ID]*dialTask hist expHeap } -type discoverTable interface { - Close() - Resolve(*enode.Node) *enode.Node - LookupRandom() []*enode.Node - ReadRandomNodes([]*enode.Node) int -} - type task interface { Do(*Server) } -// A dialTask is generated for each node that is dialed. Its -// fields cannot be accessed while the task is running. -type dialTask struct { - flags connFlag - dest *enode.Node - lastResolved time.Time - resolveDelay time.Duration -} - -// discoverTask runs discovery table operations. -// Only one discoverTask is active at any time. -// discoverTask.Do performs a random lookup. 
-type discoverTask struct { - results []*enode.Node -} - -// A waitExpireTask is generated if there are no other tasks -// to keep the loop in Server.run ticking. -type waitExpireTask struct { - time.Duration -} - -func newDialState(self enode.ID, ntab discoverTable, maxdyn int, cfg *Config) *dialstate { +func newDialState(self enode.ID, maxdyn int, cfg *Config) *dialstate { s := &dialstate{ maxDynDials: maxdyn, - ntab: ntab, self: self, netrestrict: cfg.NetRestrict, log: cfg.Logger, static: make(map[enode.ID]*dialTask), dialing: make(map[enode.ID]connFlag), bootnodes: make([]*enode.Node, len(cfg.BootstrapNodes)), - randomNodes: make([]*enode.Node, maxdyn/2), } copy(s.bootnodes, cfg.BootstrapNodes) if s.log == nil { @@ -151,10 +117,6 @@ func (s *dialstate) removeStatic(n *enode.Node) { } func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task { - if s.start.IsZero() { - s.start = now - } - var newtasks []task addDial := func(flag connFlag, n *enode.Node) bool { if err := s.checkDial(n, peers); err != nil { @@ -166,20 +128,9 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti return true } - // Compute number of dynamic dials necessary at this point. - needDynDials := s.maxDynDials - for _, p := range peers { - if p.rw.is(dynDialedConn) { - needDynDials-- - } + if s.start.IsZero() { + s.start = now } - for _, flag := range s.dialing { - if flag&dynDialedConn != 0 { - needDynDials-- - } - } - - // Expire the dial history on every invocation. s.hist.expire(now) // Create dials for static nodes if they are not connected. @@ -194,6 +145,20 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti newtasks = append(newtasks, t) } } + + // Compute number of dynamic dials needed. + needDynDials := s.maxDynDials + for _, p := range peers { + if p.rw.is(dynDialedConn) { + needDynDials-- + } + } + for _, flag := range s.dialing { + if flag&dynDialedConn != 0 { + needDynDials-- + } + } + // If we don't have any peers whatsoever, try to dial a random bootnode. This // scenario is useful for the testnet (and private networks) where the discovery // table might be full of mostly bad peers, making it hard to find good ones. @@ -201,24 +166,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti bootnode := s.bootnodes[0] s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...) s.bootnodes = append(s.bootnodes, bootnode) - if addDial(dynDialedConn, bootnode) { needDynDials-- } } - // Use random nodes from the table for half of the necessary - // dynamic dials. - randomCandidates := needDynDials / 2 - if randomCandidates > 0 { - n := s.ntab.ReadRandomNodes(s.randomNodes) - for i := 0; i < randomCandidates && i < n; i++ { - if addDial(dynDialedConn, s.randomNodes[i]) { - needDynDials-- - } - } - } - // Create dynamic dials from random lookup results, removing tried - // items from the result buffer. + + // Create dynamic dials from discovery results. i := 0 for ; i < len(s.lookupBuf) && needDynDials > 0; i++ { if addDial(dynDialedConn, s.lookupBuf[i]) { @@ -226,10 +179,11 @@ func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Ti } } s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])] + // Launch a discovery lookup if more candidates are needed. 
if len(s.lookupBuf) < needDynDials && !s.lookupRunning { s.lookupRunning = true - newtasks = append(newtasks, &discoverTask{}) + newtasks = append(newtasks, &discoverTask{want: needDynDials - len(s.lookupBuf)}) } // Launch a timer to wait for the next node to expire if all @@ -279,6 +233,15 @@ func (s *dialstate) taskDone(t task, now time.Time) { } } +// A dialTask is generated for each node that is dialed. Its +// fields cannot be accessed while the task is running. +type dialTask struct { + flags connFlag + dest *enode.Node + lastResolved time.Time + resolveDelay time.Duration +} + func (t *dialTask) Do(srv *Server) { if t.dest.Incomplete() { if !t.resolve(srv) { @@ -304,8 +267,8 @@ func (t *dialTask) Do(srv *Server) { // discovery network with useless queries for nodes that don't exist. // The backoff delay resets when the node is found. func (t *dialTask) resolve(srv *Server) bool { - if srv.ntab == nil { - srv.log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") + if srv.staticNodeResolver == nil { + srv.log.Debug("Can't resolve node", "id", t.dest.ID(), "err", "discovery is disabled") return false } if t.resolveDelay == 0 { @@ -314,20 +277,20 @@ func (t *dialTask) resolve(srv *Server) bool { if time.Since(t.lastResolved) < t.resolveDelay { return false } - resolved := srv.ntab.Resolve(t.dest) + resolved := srv.staticNodeResolver.Resolve(t.dest) t.lastResolved = time.Now() if resolved == nil { t.resolveDelay *= 2 if t.resolveDelay > maxResolveDelay { t.resolveDelay = maxResolveDelay } - srv.log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) + srv.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay) return false } // The node was found. t.resolveDelay = initialResolveDelay t.dest = resolved - srv.log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) + srv.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) return true } @@ -350,26 +313,34 @@ func (t *dialTask) String() string { return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP()) } +// discoverTask runs discovery table operations. +// Only one discoverTask is active at any time. +// discoverTask.Do performs a random lookup. +type discoverTask struct { + want int + results []*enode.Node +} + func (t *discoverTask) Do(srv *Server) { - // newTasks generates a lookup task whenever dynamic dials are - // necessary. Lookups need to take some time, otherwise the - // event loop spins too fast. - next := srv.lastLookup.Add(lookupInterval) - if now := time.Now(); now.Before(next) { - time.Sleep(next.Sub(now)) - } - srv.lastLookup = time.Now() - t.results = srv.ntab.LookupRandom() + t.results = enode.ReadNodes(srv.discmix, t.want) } func (t *discoverTask) String() string { - s := "discovery lookup" + s := "discovery query" if len(t.results) > 0 { s += fmt.Sprintf(" (%d results)", len(t.results)) + } else { + s += fmt.Sprintf(" (want %d)", t.want) } return s } +// A waitExpireTask is generated if there are no other tasks +// to keep the loop in Server.run ticking. 
+type waitExpireTask struct { + time.Duration +} + func (t waitExpireTask) Do(*Server) { time.Sleep(t.Duration) } diff --git a/p2p/dial_test.go b/p2p/dial_test.go index de8fc4a6e..6189ec4d0 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -73,7 +73,7 @@ func runDialTest(t *testing.T, test dialtest) { t.Errorf("ERROR round %d: got %v\nwant %v\nstate: %v\nrunning: %v", i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running)) } - t.Logf("round %d new tasks: %s", i, strings.TrimSpace(spew.Sdump(new))) + t.Logf("round %d (running %d) new tasks: %s", i, running, strings.TrimSpace(spew.Sdump(new))) // Time advances by 16 seconds on every round. vtime = vtime.Add(16 * time.Second) @@ -81,19 +81,11 @@ func runDialTest(t *testing.T, test dialtest) { } } -type fakeTable []*enode.Node - -func (t fakeTable) Self() *enode.Node { return new(enode.Node) } -func (t fakeTable) Close() {} -func (t fakeTable) LookupRandom() []*enode.Node { return nil } -func (t fakeTable) Resolve(*enode.Node) *enode.Node { return nil } -func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t) } - // This test checks that dynamic dials are launched from discovery results. func TestDialStateDynDial(t *testing.T) { config := &Config{Logger: testlog.Logger(t, log.LvlTrace)} runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, fakeTable{}, 5, config), + init: newDialState(enode.ID{}, 5, config), rounds: []round{ // A discovery query is launched. { @@ -102,7 +94,9 @@ func TestDialStateDynDial(t *testing.T) { {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, }, - new: []task{&discoverTask{}}, + new: []task{ + &discoverTask{want: 3}, + }, }, // Dynamic dials are launched when it completes. 
{ @@ -188,7 +182,7 @@ func TestDialStateDynDial(t *testing.T) { }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(7), nil)}, - &discoverTask{}, + &discoverTask{want: 2}, }, }, // Peer 7 is connected, but there still aren't enough dynamic peers @@ -218,7 +212,7 @@ func TestDialStateDynDial(t *testing.T) { &discoverTask{}, }, new: []task{ - &discoverTask{}, + &discoverTask{want: 2}, }, }, }, @@ -235,35 +229,37 @@ func TestDialStateDynDialBootnode(t *testing.T) { }, Logger: testlog.Logger(t, log.LvlTrace), } - table := fakeTable{ - newNode(uintID(4), nil), - newNode(uintID(5), nil), - newNode(uintID(6), nil), - newNode(uintID(7), nil), - newNode(uintID(8), nil), - } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, table, 5, config), + init: newDialState(enode.ID{}, 5, config), rounds: []round{ - // 2 dynamic dials attempted, bootnodes pending fallback interval { + new: []task{ + &discoverTask{want: 5}, + }, + }, + { + done: []task{ + &discoverTask{ + results: []*enode.Node{ + newNode(uintID(4), nil), + newNode(uintID(5), nil), + }, + }, + }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &discoverTask{}, + &discoverTask{want: 3}, }, }, // No dials succeed, bootnodes still pending fallback interval + {}, + // 1 bootnode attempted as fallback interval was reached { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, - }, - // No dials succeed, bootnodes still pending fallback interval - {}, - // No dials succeed, 2 dynamic dials attempted and 1 bootnode too as fallback interval was reached - { new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, }, @@ -275,15 +271,12 @@ func TestDialStateDynDialBootnode(t *testing.T) { }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, }, // No dials succeed, 3rd bootnode is attempted { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, }, new: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, @@ -293,115 +286,19 @@ func TestDialStateDynDialBootnode(t *testing.T) { { done: []task{ &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - }, - new: []task{}, - }, - // Random dial succeeds, no more bootnodes are attempted - { - new: []task{ - &waitExpireTask{3 * time.Second}, - }, - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(4), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - }, - }, - }, - }) -} - -func TestDialStateDynDialFromTable(t *testing.T) { - // This table always returns the same random nodes - // in the order given below. - table := fakeTable{ - newNode(uintID(1), nil), - newNode(uintID(2), nil), - newNode(uintID(3), nil), - newNode(uintID(4), nil), - newNode(uintID(5), nil), - newNode(uintID(6), nil), - newNode(uintID(7), nil), - newNode(uintID(8), nil), - } - - runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, table, 10, &Config{Logger: testlog.Logger(t, log.LvlTrace)}), - rounds: []round{ - // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. 
- { - new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &discoverTask{}, - }, - }, - // Dialing nodes 1,2 succeeds. Dials from the lookup are launched. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(1), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(2), nil)}, &discoverTask{results: []*enode.Node{ - newNode(uintID(10), nil), - newNode(uintID(11), nil), - newNode(uintID(12), nil), + newNode(uintID(6), nil), }}, }, new: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)}, - &discoverTask{}, + &dialTask{flags: dynDialedConn, dest: newNode(uintID(6), nil)}, + &discoverTask{want: 4}, }, }, - // Dialing nodes 3,4,5 fails. The dials from the lookup succeed. + // Random dial succeeds, no more bootnodes are attempted { peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, - }, - done: []task{ - &dialTask{flags: dynDialedConn, dest: newNode(uintID(3), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(4), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(5), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(10), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(11), nil)}, - &dialTask{flags: dynDialedConn, dest: newNode(uintID(12), nil)}, - }, - }, - // Waiting for expiry. No waitExpireTask is launched because the - // discovery query is still running. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, - }, - }, - // Nodes 3,4 are not tried again because only the first two - // returned random nodes (nodes 1,2) are tried and they're - // already connected. - { - peers: []*Peer{ - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(1), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(2), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(10), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(11), nil)}}, - {rw: &conn{flags: dynDialedConn, node: newNode(uintID(12), nil)}}, + {rw: &conn{flags: dynDialedConn, node: newNode(uintID(6), nil)}}, }, }, }, @@ -416,11 +313,11 @@ func newNode(id enode.ID, ip net.IP) *enode.Node { return enode.SignNull(&r, id) } -// This test checks that candidates that do not match the netrestrict list are not dialed. +// This test checks that candidates that do not match the netrestrict list are not dialed.
func TestDialStateNetRestrict(t *testing.T) { // This table always returns the same random nodes // in the order given below. - table := fakeTable{ + nodes := []*enode.Node{ newNode(uintID(1), net.ParseIP("127.0.0.1")), newNode(uintID(2), net.ParseIP("127.0.0.2")), newNode(uintID(3), net.ParseIP("127.0.0.3")), @@ -434,12 +331,23 @@ func TestDialStateNetRestrict(t *testing.T) { restrict.Add("127.0.2.0/24") runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, table, 10, &Config{NetRestrict: restrict}), + init: newDialState(enode.ID{}, 10, &Config{NetRestrict: restrict}), rounds: []round{ { new: []task{ - &dialTask{flags: dynDialedConn, dest: table[4]}, - &discoverTask{}, + &discoverTask{want: 10}, + }, + }, + { + done: []task{ + &discoverTask{results: nodes}, + }, + new: []task{ + &dialTask{flags: dynDialedConn, dest: nodes[4]}, + &dialTask{flags: dynDialedConn, dest: nodes[5]}, + &dialTask{flags: dynDialedConn, dest: nodes[6]}, + &dialTask{flags: dynDialedConn, dest: nodes[7]}, + &discoverTask{want: 6}, }, }, }, @@ -459,7 +367,7 @@ func TestDialStateStaticDial(t *testing.T) { Logger: testlog.Logger(t, log.LvlTrace), } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, fakeTable{}, 0, config), + init: newDialState(enode.ID{}, 0, config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -544,7 +452,7 @@ func TestDialStateCache(t *testing.T) { Logger: testlog.Logger(t, log.LvlTrace), } runDialTest(t, dialtest{ - init: newDialState(enode.ID{}, fakeTable{}, 0, config), + init: newDialState(enode.ID{}, 0, config), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -618,8 +526,8 @@ func TestDialResolve(t *testing.T) { Dialer: TCPDialer{&net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}}, } resolved := newNode(uintID(1), net.IP{127, 0, 55, 234}) - table := &resolveMock{answer: resolved} - state := newDialState(enode.ID{}, table, 0, config) + resolver := &resolveMock{answer: resolved} + state := newDialState(enode.ID{}, 0, config) // Check that the task is generated with an incomplete ID. dest := newNode(uintID(1), nil) @@ -630,10 +538,14 @@ func TestDialResolve(t *testing.T) { } // Now run the task, it should resolve the ID once. - srv := &Server{ntab: table, log: config.Logger, Config: *config} + srv := &Server{ + Config: *config, + log: config.Logger, + staticNodeResolver: resolver, + } tasks[0].Do(srv) - if !reflect.DeepEqual(table.resolveCalls, []*enode.Node{dest}) { - t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) + if !reflect.DeepEqual(resolver.calls, []*enode.Node{dest}) { + t.Fatalf("wrong resolve calls, got %v", resolver.calls) } // Report it as done to the dialer, which should update the static node record. 
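TestDialStateNetRestrict above relies on netutil.Netlist to filter dial candidates: checkDial rejects any node whose IP falls outside the whitelist, so only nodes[4..7] are dialed. A minimal sketch of the Netlist API the test uses (note that Add panics on an invalid CIDR):

import (
	"fmt"
	"net"

	"github.com/ethereum/go-ethereum/p2p/netutil"
)

func netlistExample() {
	// Whitelist a single subnet, as the test does; candidates outside
	// it never become dial tasks.
	restrict := new(netutil.Netlist)
	restrict.Add("127.0.2.0/24")
	fmt.Println(restrict.Contains(net.ParseIP("127.0.2.9"))) // true
	fmt.Println(restrict.Contains(net.ParseIP("127.0.1.9"))) // false
}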
@@ -666,18 +578,13 @@ func uintID(i uint32) enode.ID { return id } -// implements discoverTable for TestDialResolve +// resolveMock is a nodeResolver for use in TestDialResolve. type resolveMock struct { - resolveCalls []*enode.Node - answer *enode.Node + calls []*enode.Node + answer *enode.Node } func (t *resolveMock) Resolve(n *enode.Node) *enode.Node { - t.resolveCalls = append(t.resolveCalls, n) + t.calls = append(t.calls, n) return t.answer } - -func (t *resolveMock) Self() *enode.Node { return new(enode.Node) } -func (t *resolveMock) Close() {} -func (t *resolveMock) LookupRandom() []*enode.Node { return nil } -func (t *resolveMock) ReadRandomNodes(buf []*enode.Node) int { return 0 } diff --git a/p2p/discover/common.go b/p2p/discover/common.go index 3c080359f..cef6a9fc4 100644 --- a/p2p/discover/common.go +++ b/p2p/discover/common.go @@ -25,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/netutil" ) +// UDPConn is a network connection on which discovery can operate. type UDPConn interface { ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) @@ -32,7 +33,7 @@ LocalAddr() net.Addr } -// Config holds Table-related settings. +// Config holds settings for the discovery listener. type Config struct { // These settings are required and configure the UDP listener: PrivateKey *ecdsa.PrivateKey @@ -50,7 +51,7 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { } // ReadPacket is a packet that couldn't be handled. Those packets are sent to the unhandled -// channel if configured. +// channel if configured. This type is exported for internal use; do not use it. type ReadPacket struct { Data []byte Addr *net.UDPAddr } diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go new file mode 100644 index 000000000..f988e0683 --- /dev/null +++ b/p2p/discover/lookup.go @@ -0,0 +1,209 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package discover + +import ( + "context" + + "github.com/ethereum/go-ethereum/p2p/enode" +) + +// lookup performs a network search for nodes close to the given target. It approaches the +// target by querying nodes that are closer to it on each iteration. The given target does +// not need to be an actual node identifier.
+type lookup struct { + tab *Table + queryfunc func(*node) ([]*node, error) + replyCh chan []*node + cancelCh <-chan struct{} + asked, seen map[enode.ID]bool + result nodesByDistance + replyBuffer []*node + queries int +} + +type queryFunc func(*node) ([]*node, error) + +func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup { + it := &lookup{ + tab: tab, + queryfunc: q, + asked: make(map[enode.ID]bool), + seen: make(map[enode.ID]bool), + result: nodesByDistance{target: target}, + replyCh: make(chan []*node, alpha), + cancelCh: ctx.Done(), + queries: -1, + } + // Don't query further if we hit ourself. + // Unlikely to happen often in practice. + it.asked[tab.self().ID()] = true + return it +} + +// run runs the lookup to completion and returns the closest nodes found. +func (it *lookup) run() []*enode.Node { + for it.advance() { + } + return unwrapNodes(it.result.entries) +} + +// advance advances the lookup until any new nodes have been found. +// It returns false when the lookup has ended. +func (it *lookup) advance() bool { + for it.startQueries() { + select { + case nodes := <-it.replyCh: + it.replyBuffer = it.replyBuffer[:0] + for _, n := range nodes { + if n != nil && !it.seen[n.ID()] { + it.seen[n.ID()] = true + it.result.push(n, bucketSize) + it.replyBuffer = append(it.replyBuffer, n) + } + } + it.queries-- + if len(it.replyBuffer) > 0 { + return true + } + case <-it.cancelCh: + it.shutdown() + } + } + return false +} + +func (it *lookup) shutdown() { + for it.queries > 0 { + <-it.replyCh + it.queries-- + } + it.queryfunc = nil + it.replyBuffer = nil +} + +func (it *lookup) startQueries() bool { + if it.queryfunc == nil { + return false + } + + // The first query returns nodes from the local table. + if it.queries == -1 { + it.tab.mutex.Lock() + closest := it.tab.closest(it.result.target, bucketSize, false) + it.tab.mutex.Unlock() + it.queries = 1 + it.replyCh <- closest.entries + return true + } + + // Ask the closest nodes that we haven't asked yet. + for i := 0; i < len(it.result.entries) && it.queries < alpha; i++ { + n := it.result.entries[i] + if !it.asked[n.ID()] { + it.asked[n.ID()] = true + it.queries++ + go it.query(n, it.replyCh) + } + } + // The lookup ends when no more nodes can be asked. + return it.queries > 0 +} + +func (it *lookup) query(n *node, reply chan<- []*node) { + fails := it.tab.db.FindFails(n.ID(), n.IP()) + r, err := it.queryfunc(n) + if err == errClosed { + // Avoid recording failures on shutdown. + reply <- nil + return + } else if len(r) == 0 { + fails++ + it.tab.db.UpdateFindFails(n.ID(), n.IP(), fails) + it.tab.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err) + if fails >= maxFindnodeFailures { + it.tab.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails) + it.tab.delete(n) + } + } else if fails > 0 { + // Reset failure counter because it counts _consecutive_ failures. + it.tab.db.UpdateFindFails(n.ID(), n.IP(), 0) + } + + // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll + // just remove those again during revalidation. + for _, n := range r { + it.tab.addSeenNode(n) + } + reply <- r +} + +// lookupIterator performs lookup operations and iterates over all seen nodes. +// When a lookup finishes, a new one is created through nextLookup. 
+type lookupIterator struct { + buffer []*node + nextLookup lookupFunc + ctx context.Context + cancel func() + lookup *lookup +} + +type lookupFunc func(ctx context.Context) *lookup + +func newLookupIterator(ctx context.Context, next lookupFunc) *lookupIterator { + ctx, cancel := context.WithCancel(ctx) + return &lookupIterator{ctx: ctx, cancel: cancel, nextLookup: next} +} + +// Node returns the current node. +func (it *lookupIterator) Node() *enode.Node { + if len(it.buffer) == 0 { + return nil + } + return unwrapNode(it.buffer[0]) +} + +// Next moves to the next node. +func (it *lookupIterator) Next() bool { + // Consume next node in buffer. + if len(it.buffer) > 0 { + it.buffer = it.buffer[1:] + } + // Advance the lookup to refill the buffer. + for len(it.buffer) == 0 { + if it.ctx.Err() != nil { + it.lookup = nil + it.buffer = nil + return false + } + if it.lookup == nil { + it.lookup = it.nextLookup(it.ctx) + continue + } + if !it.lookup.advance() { + it.lookup = nil + continue + } + it.buffer = it.lookup.replyBuffer + } + return true +} + +// Close ends the iterator. +func (it *lookupIterator) Close() { + it.cancel() +} diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 2292055e1..e35e48c5e 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -17,11 +17,14 @@ package discover import ( + "bytes" "crypto/ecdsa" "encoding/hex" + "errors" "fmt" "math/rand" "net" + "reflect" "sort" "sync" @@ -169,6 +172,28 @@ func hasDuplicates(slice []*node) bool { return false } +func checkNodesEqual(got, want []*enode.Node) error { + if reflect.DeepEqual(got, want) { + return nil + } + output := new(bytes.Buffer) + fmt.Fprintf(output, "got %d nodes:\n", len(got)) + for _, n := range got { + fmt.Fprintf(output, " %v %v\n", n.ID(), n) + } + fmt.Fprintf(output, "want %d:\n", len(want)) + for _, n := range want { + fmt.Fprintf(output, " %v %v\n", n.ID(), n) + } + return errors.New(output.String()) +} + +func sortByID(nodes []*enode.Node) { + sort.Slice(nodes, func(i, j int) bool { + return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes()) + }) +} + func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { return sort.SliceIsSorted(slice, func(i, j int) bool { return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0 diff --git a/p2p/discover/v4_udp_lookup_test.go b/p2p/discover/v4_lookup_test.go similarity index 75% rename from p2p/discover/v4_udp_lookup_test.go rename to p2p/discover/v4_lookup_test.go index bc1cdfb08..9b4042c5a 100644 --- a/p2p/discover/v4_udp_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -20,7 +20,6 @@ import ( "crypto/ecdsa" "fmt" "net" - "reflect" "sort" "testing" @@ -49,19 +48,7 @@ func TestUDPv4_Lookup(t *testing.T) { }() // Answer lookup packets. - for done := false; !done; { - done = test.waitPacketOut(func(p packetV4, to *net.UDPAddr, hash []byte) { - n, key := lookupTestnet.nodeByAddr(to) - switch p.(type) { - case *pingV4: - test.packetInFrom(nil, key, to, &pongV4{Expiration: futureExp, ReplyTok: hash}) - case *findnodeV4: - dist := enode.LogDist(n.ID(), lookupTestnet.target.id()) - nodes := lookupTestnet.nodesAtDistance(dist - 1) - test.packetInFrom(nil, key, to, &neighborsV4{Expiration: futureExp, Nodes: nodes}) - } - }) - } + serveTestnet(test, lookupTestnet) // Verify result nodes. 
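lookupIterator is what backs the public UDPv4.RandomNodes API shown further below: each Next call consumes the buffered results of the current lookup and transparently starts a new one when the previous lookup is exhausted. A consumption sketch, assuming disc is an initialised *discover.UDPv4 (for example from discover.ListenV4) and the helper name is hypothetical:

import (
	"github.com/ethereum/go-ethereum/p2p/discover"
	"github.com/ethereum/go-ethereum/p2p/enode"
)

// crawl drains up to limit nodes from a random-walk iterator.
func crawl(disc *discover.UDPv4, limit int) []*enode.Node {
	it := disc.RandomNodes()
	defer it.Close() // stops the underlying lookups
	var out []*enode.Node
	for len(out) < limit && it.Next() {
		out = append(out, it.Node())
	}
	return out
}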
results := <-resultC @@ -78,8 +65,94 @@ func TestUDPv4_Lookup(t *testing.T) { if !sortedByDistanceTo(lookupTestnet.target.id(), wrapNodes(results)) { t.Errorf("result set not sorted by distance to target") } - if !reflect.DeepEqual(results, lookupTestnet.closest(bucketSize)) { - t.Errorf("results aren't the closest %d nodes", bucketSize) + if err := checkNodesEqual(results, lookupTestnet.closest(bucketSize)); err != nil { + t.Errorf("results aren't the closest %d nodes\n%v", bucketSize, err) + } +} + +func TestUDPv4_LookupIterator(t *testing.T) { + t.Parallel() + test := newUDPTest(t) + defer test.close() + + // Seed table with initial nodes. + bootnodes := make([]*node, len(lookupTestnet.dists[256])) + for i := range lookupTestnet.dists[256] { + bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) + } + fillTable(test.table, bootnodes) + go serveTestnet(test, lookupTestnet) + + // Create the iterator and collect the nodes it yields. + iter := test.udp.RandomNodes() + seen := make(map[enode.ID]*enode.Node) + for limit := lookupTestnet.len(); iter.Next() && len(seen) < limit; { + seen[iter.Node().ID()] = iter.Node() + } + iter.Close() + + // Check that all nodes in lookupTestnet were seen by the iterator. + results := make([]*enode.Node, 0, len(seen)) + for _, n := range seen { + results = append(results, n) + } + sortByID(results) + want := lookupTestnet.nodes() + if err := checkNodesEqual(results, want); err != nil { + t.Fatal(err) + } +} + +// TestUDPv4_LookupIteratorClose checks that lookupIterator ends when its Close +// method is called. +func TestUDPv4_LookupIteratorClose(t *testing.T) { + t.Parallel() + test := newUDPTest(t) + defer test.close() + + // Seed table with initial nodes. + bootnodes := make([]*node, len(lookupTestnet.dists[256])) + for i := range lookupTestnet.dists[256] { + bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) + } + fillTable(test.table, bootnodes) + go serveTestnet(test, lookupTestnet) + + it := test.udp.RandomNodes() + if ok := it.Next(); !ok || it.Node() == nil { + t.Fatalf("iterator didn't return any node") + } + + it.Close() + + ncalls := 0 + for ; ncalls < 100 && it.Next(); ncalls++ { + if it.Node() == nil { + t.Error("iterator returned Node() == nil node after Next() == true") + } + } + t.Logf("iterator returned %d nodes after close", ncalls) + if it.Next() { + t.Errorf("Next() == true after close and %d more calls", ncalls) + } + if n := it.Node(); n != nil { + t.Errorf("iterator returned non-nil node after close and %d more calls", ncalls) + } +} + +func serveTestnet(test *udpTest, testnet *preminedTestnet) { + for done := false; !done; { + done = test.waitPacketOut(func(p packetV4, to *net.UDPAddr, hash []byte) { + n, key := testnet.nodeByAddr(to) + switch p.(type) { + case *pingV4: + test.packetInFrom(nil, key, to, &pongV4{Expiration: futureExp, ReplyTok: hash}) + case *findnodeV4: + dist := enode.LogDist(n.ID(), testnet.target.id()) + nodes := testnet.nodesAtDistance(dist - 1) + test.packetInFrom(nil, key, to, &neighborsV4{Expiration: futureExp, Nodes: nodes}) + } + }) } } @@ -148,6 +221,25 @@ type preminedTestnet struct { dists [hashBits + 1][]*ecdsa.PrivateKey } +func (tn *preminedTestnet) len() int { + n := 0 + for _, keys := range tn.dists { + n += len(keys) + } + return n +} + +func (tn *preminedTestnet) nodes() []*enode.Node { + result := make([]*enode.Node, 0, tn.len()) + for dist, keys := range tn.dists { + for index := range keys { + result = append(result, tn.node(dist, index)) + } + } + sortByID(result) + return result +} + func (tn 
*preminedTestnet) node(dist, index int) *enode.Node { key := tn.dists[dist][index] ip := net.IP{127, byte(dist >> 8), byte(dist), byte(index)} diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index a8f7101b0..bfb66fcb1 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -19,6 +19,7 @@ package discover import ( "bytes" "container/list" + "context" "crypto/ecdsa" crand "crypto/rand" "errors" @@ -207,7 +208,8 @@ type UDPv4 struct { addReplyMatcher chan *replyMatcher gotreply chan reply - closing chan struct{} + closeCtx context.Context + cancelCloseCtx func() } // replyMatcher represents a pending reply. @@ -256,20 +258,23 @@ type reply struct { } func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { + closeCtx, cancel := context.WithCancel(context.Background()) t := &UDPv4{ conn: c, priv: cfg.PrivateKey, netrestrict: cfg.NetRestrict, localNode: ln, db: ln.Database(), - closing: make(chan struct{}), gotreply: make(chan reply), addReplyMatcher: make(chan *replyMatcher), + closeCtx: closeCtx, + cancelCloseCtx: cancel, log: cfg.Log, } if t.log == nil { t.log = log.Root() } + tab, err := newTable(t, ln.Database(), cfg.Bootnodes, t.log) if err != nil { return nil, err @@ -291,126 +296,13 @@ func (t *UDPv4) Self() *enode.Node { // Close shuts down the socket and aborts any running queries. func (t *UDPv4) Close() { t.closeOnce.Do(func() { - close(t.closing) + t.cancelCloseCtx() t.conn.Close() t.wg.Wait() t.tab.close() }) } -// ReadRandomNodes reads random nodes from the local table. -func (t *UDPv4) ReadRandomNodes(buf []*enode.Node) int { - return t.tab.ReadRandomNodes(buf) -} - -// LookupRandom finds random nodes in the network. -func (t *UDPv4) LookupRandom() []*enode.Node { - if t.tab.len() == 0 { - // All nodes were dropped, refresh. The very first query will hit this - // case and run the bootstrapping logic. - <-t.tab.refresh() - } - return t.lookupRandom() -} - -func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node { - if t.tab.len() == 0 { - // All nodes were dropped, refresh. The very first query will hit this - // case and run the bootstrapping logic. - <-t.tab.refresh() - } - return unwrapNodes(t.lookup(encodePubkey(key))) -} - -func (t *UDPv4) lookupRandom() []*enode.Node { - var target encPubkey - crand.Read(target[:]) - return unwrapNodes(t.lookup(target)) -} - -func (t *UDPv4) lookupSelf() []*enode.Node { - return unwrapNodes(t.lookup(encodePubkey(&t.priv.PublicKey))) -} - -// lookup performs a network search for nodes close to the given target. It approaches the -// target by querying nodes that are closer to it on each iteration. The given target does -// not need to be an actual node identifier. -func (t *UDPv4) lookup(targetKey encPubkey) []*node { - var ( - target = enode.ID(crypto.Keccak256Hash(targetKey[:])) - asked = make(map[enode.ID]bool) - seen = make(map[enode.ID]bool) - reply = make(chan []*node, alpha) - pendingQueries = 0 - result *nodesByDistance - ) - // Don't query further if we hit ourself. - // Unlikely to happen often in practice. - asked[t.Self().ID()] = true - - // Generate the initial result set. 
- t.tab.mutex.Lock() - result = t.tab.closest(target, bucketSize, false) - t.tab.mutex.Unlock() - - for { - // ask the alpha closest nodes that we haven't asked yet - for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ { - n := result.entries[i] - if !asked[n.ID()] { - asked[n.ID()] = true - pendingQueries++ - go t.lookupWorker(n, targetKey, reply) - } - } - if pendingQueries == 0 { - // we have asked all closest nodes, stop the search - break - } - select { - case nodes := <-reply: - for _, n := range nodes { - if n != nil && !seen[n.ID()] { - seen[n.ID()] = true - result.push(n, bucketSize) - } - } - case <-t.tab.closeReq: - return nil // shutdown, no need to continue. - } - pendingQueries-- - } - return result.entries -} - -func (t *UDPv4) lookupWorker(n *node, targetKey encPubkey, reply chan<- []*node) { - fails := t.db.FindFails(n.ID(), n.IP()) - r, err := t.findnode(n.ID(), n.addr(), targetKey) - if err == errClosed { - // Avoid recording failures on shutdown. - reply <- nil - return - } else if len(r) == 0 { - fails++ - t.db.UpdateFindFails(n.ID(), n.IP(), fails) - t.log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err) - if fails >= maxFindnodeFailures { - t.log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails) - t.tab.delete(n) - } - } else if fails > 0 { - // Reset failure counter because it counts _consecutive_ failures. - t.db.UpdateFindFails(n.ID(), n.IP(), 0) - } - - // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll - // just remove those again during revalidation. - for _, n := range r { - t.tab.addSeenNode(n) - } - reply <- r -} - // Resolve searches for a specific node with the given ID and tries to get the most recent // version of the node record for it. It returns n if the node could not be resolved. func (t *UDPv4) Resolve(n *enode.Node) *enode.Node { @@ -498,6 +390,45 @@ func (t *UDPv4) makePing(toaddr *net.UDPAddr) *pingV4 { } } +// LookupPubkey finds the closest nodes to the given public key. +func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node { + if t.tab.len() == 0 { + // All nodes were dropped, refresh. The very first query will hit this + // case and run the bootstrapping logic. + <-t.tab.refresh() + } + return t.newLookup(t.closeCtx, encodePubkey(key)).run() +} + +// RandomNodes is an iterator yielding nodes from a random walk of the DHT. +func (t *UDPv4) RandomNodes() enode.Iterator { + return newLookupIterator(t.closeCtx, t.newRandomLookup) +} + +// lookupRandom implements transport. +func (t *UDPv4) lookupRandom() []*enode.Node { + return t.newRandomLookup(t.closeCtx).run() +} + +// lookupSelf implements transport. +func (t *UDPv4) lookupSelf() []*enode.Node { + return t.newLookup(t.closeCtx, encodePubkey(&t.priv.PublicKey)).run() +} + +func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup { + var target encPubkey + crand.Read(target[:]) + return t.newLookup(ctx, target) +} + +func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup { + target := enode.ID(crypto.Keccak256Hash(targetKey[:])) + it := newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { + return t.findnode(n.ID(), n.addr(), targetKey) + }) + return it +} + // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. 
func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) { @@ -575,7 +506,7 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchF select { case t.addReplyMatcher <- p: // loop will handle it - case <-t.closing: + case <-t.closeCtx.Done(): ch <- errClosed } return p @@ -589,7 +520,7 @@ func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req packetV4) bool { case t.gotreply <- reply{from, fromIP, req, matched}: // loop will handle it return <-matched - case <-t.closing: + case <-t.closeCtx.Done(): return false } } @@ -635,7 +566,7 @@ func (t *UDPv4) loop() { resetTimeout() select { - case <-t.closing: + case <-t.closeCtx.Done(): for el := plist.Front(); el != nil; el = el.Next() { el.Value.(*replyMatcher).errc <- errClosed } diff --git a/p2p/dnsdisc/client.go b/p2p/dnsdisc/client.go new file mode 100644 index 000000000..677c0aa92 --- /dev/null +++ b/p2p/dnsdisc/client.go @@ -0,0 +1,260 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package dnsdisc + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "net" + "strings" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + lru "github.com/hashicorp/golang-lru" +) + +// Client discovers nodes by querying DNS servers. +type Client struct { + cfg Config + clock mclock.Clock + linkCache linkCache + trees map[string]*clientTree + + entries *lru.Cache +} + +// Config holds configuration options for the client. +type Config struct { + Timeout time.Duration // timeout used for DNS lookups (default 5s) + RecheckInterval time.Duration // time between tree root update checks (default 30min) + CacheLimit int // maximum number of cached records (default 1000) + ValidSchemes enr.IdentityScheme // acceptable ENR identity schemes (default enode.ValidSchemes) + Resolver Resolver // the DNS resolver to use (defaults to system DNS) + Logger log.Logger // destination of client log messages (defaults to root logger) +} + +// Resolver is a DNS resolver that can query TXT records.
+type Resolver interface { + LookupTXT(ctx context.Context, domain string) ([]string, error) +} + +func (cfg Config) withDefaults() Config { + const ( + defaultTimeout = 5 * time.Second + defaultRecheck = 30 * time.Minute + defaultCache = 1000 + ) + if cfg.Timeout == 0 { + cfg.Timeout = defaultTimeout + } + if cfg.RecheckInterval == 0 { + cfg.RecheckInterval = defaultRecheck + } + if cfg.CacheLimit == 0 { + cfg.CacheLimit = defaultCache + } + if cfg.ValidSchemes == nil { + cfg.ValidSchemes = enode.ValidSchemes + } + if cfg.Resolver == nil { + cfg.Resolver = new(net.Resolver) + } + if cfg.Logger == nil { + cfg.Logger = log.Root() + } + return cfg +} + +// NewClient creates a client. +func NewClient(cfg Config, urls ...string) (*Client, error) { + c := &Client{ + cfg: cfg.withDefaults(), + clock: mclock.System{}, + trees: make(map[string]*clientTree), + } + var err error + if c.entries, err = lru.New(c.cfg.CacheLimit); err != nil { + return nil, err + } + for _, url := range urls { + if err := c.AddTree(url); err != nil { + return nil, err + } + } + return c, nil +} + +// SyncTree downloads the entire node tree at the given URL. This doesn't add the tree for +// later use, but any previously-synced entries are reused. +func (c *Client) SyncTree(url string) (*Tree, error) { + le, err := parseLink(url) + if err != nil { + return nil, fmt.Errorf("invalid enrtree URL: %v", err) + } + ct := newClientTree(c, le) + t := &Tree{entries: make(map[string]entry)} + if err := ct.syncAll(t.entries); err != nil { + return nil, err + } + t.root = ct.root + return t, nil +} + +// AddTree adds an enrtree:// URL to crawl. +func (c *Client) AddTree(url string) error { + le, err := parseLink(url) + if err != nil { + return fmt.Errorf("invalid enrtree URL: %v", err) + } + ct, err := c.ensureTree(le) + if err != nil { + return err + } + c.linkCache.add(ct) + return nil +} + +func (c *Client) ensureTree(le *linkEntry) (*clientTree, error) { + if tree, ok := c.trees[le.domain]; ok { + if !tree.matchPubkey(le.pubkey) { + return nil, fmt.Errorf("conflicting public keys for domain %q", le.domain) + } + return tree, nil + } + ct := newClientTree(c, le) + c.trees[le.domain] = ct + return ct, nil +} + +// RandomNode retrieves the next random node. +func (c *Client) RandomNode(ctx context.Context) *enode.Node { + for { + ct := c.randomTree() + if ct == nil { + return nil + } + n, err := ct.syncRandom(ctx) + if err != nil { + if err == ctx.Err() { + return nil // context canceled. + } + c.cfg.Logger.Debug("Error in DNS random node sync", "tree", ct.loc.domain, "err", err) + continue + } + if n != nil { + return n + } + } +} + +// randomTree returns a random tree. +func (c *Client) randomTree() *clientTree { + if !c.linkCache.valid() { + c.gcTrees() + } + limit := rand.Intn(len(c.trees)) + for _, ct := range c.trees { + if limit == 0 { + return ct + } + limit-- + } + return nil +} + +// gcTrees rebuilds the 'trees' map. +func (c *Client) gcTrees() { + trees := make(map[string]*clientTree) + for t := range c.linkCache.all() { + trees[t.loc.domain] = t + } + c.trees = trees +} + +// resolveRoot retrieves a root entry via DNS.
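Putting the pieces together: a client is built from a Config whose zero values fall back to the defaults listed above, and can either sync one tree on demand or crawl registered trees incrementally. A sketch with a hypothetical domain; the enrtree:// URL format embeds the tree's public key before the @:

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/dnsdisc"
)

func syncExample() error {
	c, err := dnsdisc.NewClient(dnsdisc.Config{}) // all defaults
	if err != nil {
		return err
	}
	// nodes.example.org is a placeholder; the key is taken from the test below.
	tree, err := c.SyncTree("enrtree://AKPYQIUQIL7PSIACI32J7FGZW56E5FKHEFCCOFHILBIMW3M6LWXS2@nodes.example.org")
	if err != nil {
		return err
	}
	for _, n := range tree.Nodes() {
		fmt.Println(n.String())
	}
	return nil
}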
+func (c *Client) resolveRoot(ctx context.Context, loc *linkEntry) (rootEntry, error) { + txts, err := c.cfg.Resolver.LookupTXT(ctx, loc.domain) + c.cfg.Logger.Trace("Updating DNS discovery root", "tree", loc.domain, "err", err) + if err != nil { + return rootEntry{}, err + } + for _, txt := range txts { + if strings.HasPrefix(txt, rootPrefix) { + return parseAndVerifyRoot(txt, loc) + } + } + return rootEntry{}, nameError{loc.domain, errNoRoot} +} + +func parseAndVerifyRoot(txt string, loc *linkEntry) (rootEntry, error) { + e, err := parseRoot(txt) + if err != nil { + return e, err + } + if !e.verifySignature(loc.pubkey) { + return e, entryError{typ: "root", err: errInvalidSig} + } + return e, nil +} + +// resolveEntry retrieves an entry from the cache or fetches it from the network +// if it isn't cached. +func (c *Client) resolveEntry(ctx context.Context, domain, hash string) (entry, error) { + cacheKey := truncateHash(hash) + if e, ok := c.entries.Get(cacheKey); ok { + return e.(entry), nil + } + e, err := c.doResolveEntry(ctx, domain, hash) + if err != nil { + return nil, err + } + c.entries.Add(cacheKey, e) + return e, nil +} + +// doResolveEntry fetches an entry via DNS. +func (c *Client) doResolveEntry(ctx context.Context, domain, hash string) (entry, error) { + wantHash, err := b32format.DecodeString(hash) + if err != nil { + return nil, fmt.Errorf("invalid base32 hash") + } + name := hash + "." + domain + txts, err := c.cfg.Resolver.LookupTXT(ctx, hash+"."+domain) + c.cfg.Logger.Trace("DNS discovery lookup", "name", name, "err", err) + if err != nil { + return nil, err + } + for _, txt := range txts { + e, err := parseEntry(txt, c.cfg.ValidSchemes) + if err == errUnknownEntry { + continue + } + if !bytes.HasPrefix(crypto.Keccak256([]byte(txt)), wantHash) { + err = nameError{name, errHashMismatch} + } else if err != nil { + err = nameError{name, err} + } + return e, err + } + return nil, nameError{name, errNoEntry} +} diff --git a/p2p/dnsdisc/client_test.go b/p2p/dnsdisc/client_test.go new file mode 100644 index 000000000..d8e3ecee3 --- /dev/null +++ b/p2p/dnsdisc/client_test.go @@ -0,0 +1,316 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
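// Editor's note (not part of the patch): the hard-coded resolver maps in the
// tests below follow the naming scheme that doResolveEntry verifies: an entry
// is served at base32(keccak256(entryText)[:16]) + "." + domain, unpadded. A
// sketch of recomputing a fixture name by hand, using the package's b32format:
//
//	hash := crypto.Keccak256([]byte(entryText))[:16]
//	name := b32format.EncodeToString(hash) + "." + domain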
+ +package dnsdisc + +import ( + "context" + "crypto/ecdsa" + "math/rand" + "reflect" + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/testlog" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +const ( + signingKeySeed = 0x111111 + nodesSeed1 = 0x2945237 + nodesSeed2 = 0x4567299 +) + +func TestClientSyncTree(t *testing.T) { + r := mapResolver{ + "n": "enrtree-root:v1 e=JWXYDBPXYWG6FX3GMDIBFA6CJ4 l=C7HRFPF3BLGF3YR4DY5KX3SMBE seq=1 sig=o908WmNp7LibOfPsr4btQwatZJ5URBr2ZAuxvK4UWHlsB9sUOTJQaGAlLPVAhM__XJesCHxLISo94z5Z2a463gA", + "C7HRFPF3BLGF3YR4DY5KX3SMBE.n": "enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@morenodes.example.org", + "JWXYDBPXYWG6FX3GMDIBFA6CJ4.n": "enrtree-branch:2XS2367YHAXJFGLZHVAWLQD4ZY,H4FHT4B454P6UXFD7JCYQ5PWDY,MHTDO6TMUBRIA2XWG5LUDACK24", + "2XS2367YHAXJFGLZHVAWLQD4ZY.n": "enr:-HW4QOFzoVLaFJnNhbgMoDXPnOvcdVuj7pDpqRvh6BRDO68aVi5ZcjB3vzQRZH2IcLBGHzo8uUN3snqmgTiE56CH3AMBgmlkgnY0iXNlY3AyNTZrMaECC2_24YYkYHEgdzxlSNKQEnHhuNAbNlMlWJxrJxbAFvA", + "H4FHT4B454P6UXFD7JCYQ5PWDY.n": "enr:-HW4QAggRauloj2SDLtIHN1XBkvhFZ1vtf1raYQp9TBW2RD5EEawDzbtSmlXUfnaHcvwOizhVYLtr7e6vw7NAf6mTuoCgmlkgnY0iXNlY3AyNTZrMaECjrXI8TLNXU0f8cthpAMxEshUyQlK-AM0PW2wfrnacNI", + "MHTDO6TMUBRIA2XWG5LUDACK24.n": "enr:-HW4QLAYqmrwllBEnzWWs7I5Ev2IAs7x_dZlbYdRdMUx5EyKHDXp7AV5CkuPGUPdvbv1_Ms1CPfhcGCvSElSosZmyoqAgmlkgnY0iXNlY3AyNTZrMaECriawHKWdDRk2xeZkrOXBQ0dfMFLHY4eENZwdufn1S1o", + } + var ( + wantNodes = testNodes(0x29452, 3) + wantLinks = []string{"enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@morenodes.example.org"} + wantSeq = uint(1) + ) + + c, _ := NewClient(Config{Resolver: r, Logger: testlog.Logger(t, log.LvlTrace)}) + stree, err := c.SyncTree("enrtree://AKPYQIUQIL7PSIACI32J7FGZW56E5FKHEFCCOFHILBIMW3M6LWXS2@n") + if err != nil { + t.Fatal("sync error:", err) + } + if !reflect.DeepEqual(sortByID(stree.Nodes()), sortByID(wantNodes)) { + t.Errorf("wrong nodes in synced tree:\nhave %v\nwant %v", spew.Sdump(stree.Nodes()), spew.Sdump(wantNodes)) + } + if !reflect.DeepEqual(stree.Links(), wantLinks) { + t.Errorf("wrong links in synced tree: %v", stree.Links()) + } + if stree.Seq() != wantSeq { + t.Errorf("synced tree has wrong seq: %d", stree.Seq()) + } + if len(c.trees) > 0 { + t.Errorf("tree from SyncTree added to client") + } +} + +// In this test, syncing the tree fails because it contains an invalid ENR entry. 
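// Editor's note (not part of the patch): the root fixture above is the exact
// scanf format handled by parseRoot in tree.go, namely
// "enrtree-root:v1 e=<enr subtree hash> l=<link subtree hash> seq=<sequence> sig=<base64 signature>",
// so a hand-rolled fixture only has to keep those four fields mutually
// consistent with the signing key.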
+func TestClientSyncTreeBadNode(t *testing.T) { + // var b strings.Builder + // b.WriteString(enrPrefix) + // b.WriteString("-----") + // badHash := subdomain(&b) + // tree, _ := MakeTree(3, nil, []string{"enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@morenodes.example.org"}) + // tree.entries[badHash] = &b + // tree.root.eroot = badHash + // url, _ := tree.Sign(testKey(signingKeySeed), "n") + // fmt.Println(url) + // fmt.Printf("%#v\n", tree.ToTXT("n")) + + r := mapResolver{ + "n": "enrtree-root:v1 e=INDMVBZEEQ4ESVYAKGIYU74EAA l=C7HRFPF3BLGF3YR4DY5KX3SMBE seq=3 sig=Vl3AmunLur0JZ3sIyJPSH6A3Vvdp4F40jWQeCmkIhmcgwE4VC5U9wpK8C_uL_CMY29fd6FAhspRvq2z_VysTLAA", + "C7HRFPF3BLGF3YR4DY5KX3SMBE.n": "enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@morenodes.example.org", + "INDMVBZEEQ4ESVYAKGIYU74EAA.n": "enr:-----", + } + c, _ := NewClient(Config{Resolver: r, Logger: testlog.Logger(t, log.LvlTrace)}) + _, err := c.SyncTree("enrtree://AKPYQIUQIL7PSIACI32J7FGZW56E5FKHEFCCOFHILBIMW3M6LWXS2@n") + wantErr := nameError{name: "INDMVBZEEQ4ESVYAKGIYU74EAA.n", err: entryError{typ: "enr", err: errInvalidENR}} + if err != wantErr { + t.Fatalf("expected sync error %q, got %q", wantErr, err) + } +} + +// This test checks that RandomNode hits all entries. +func TestClientRandomNode(t *testing.T) { + nodes := testNodes(nodesSeed1, 30) + tree, url := makeTestTree("n", nodes, nil) + r := mapResolver(tree.ToTXT("n")) + c, _ := NewClient(Config{Resolver: r, Logger: testlog.Logger(t, log.LvlTrace)}) + if err := c.AddTree(url); err != nil { + t.Fatal(err) + } + + checkRandomNode(t, c, nodes) +} + +// This test checks that RandomNode traverses linked trees as well as explicitly added trees. +func TestClientRandomNodeLinks(t *testing.T) { + nodes := testNodes(nodesSeed1, 40) + tree1, url1 := makeTestTree("t1", nodes[:10], nil) + tree2, url2 := makeTestTree("t2", nodes[10:], []string{url1}) + cfg := Config{ + Resolver: newMapResolver(tree1.ToTXT("t1"), tree2.ToTXT("t2")), + Logger: testlog.Logger(t, log.LvlTrace), + } + c, _ := NewClient(cfg) + if err := c.AddTree(url2); err != nil { + t.Fatal(err) + } + + checkRandomNode(t, c, nodes) +} + +// This test verifies that RandomNode re-checks the root of the tree to catch +// updates to nodes. +func TestClientRandomNodeUpdates(t *testing.T) { + var ( + clock = new(mclock.Simulated) + nodes = testNodes(nodesSeed1, 30) + resolver = newMapResolver() + cfg = Config{ + Resolver: resolver, + Logger: testlog.Logger(t, log.LvlTrace), + RecheckInterval: 20 * time.Minute, + } + c, _ = NewClient(cfg) + ) + c.clock = clock + tree1, url := makeTestTree("n", nodes[:25], nil) + + // Sync the original tree. + resolver.add(tree1.ToTXT("n")) + c.AddTree(url) + checkRandomNode(t, c, nodes[:25]) + + // Update some nodes and ensure RandomNode returns the new nodes as well. + keys := testKeys(nodesSeed1, len(nodes)) + for i, n := range nodes[:len(nodes)/2] { + r := n.Record() + r.Set(enr.IP{127, 0, 0, 1}) + r.SetSeq(55) + enode.SignV4(r, keys[i]) + n2, _ := enode.New(enode.ValidSchemes, r) + nodes[i] = n2 + } + tree2, _ := makeTestTree("n", nodes, nil) + clock.Run(cfg.RecheckInterval + 1*time.Second) + resolver.clear() + resolver.add(tree2.ToTXT("n")) + checkRandomNode(t, c, nodes) +} + +// This test verifies that RandomNode re-checks the root of the tree to catch +// updates to links. 
+func TestClientRandomNodeLinkUpdates(t *testing.T) { + var ( + clock = new(mclock.Simulated) + nodes = testNodes(nodesSeed1, 30) + resolver = newMapResolver() + cfg = Config{ + Resolver: resolver, + Logger: testlog.Logger(t, log.LvlTrace), + RecheckInterval: 20 * time.Minute, + } + c, _ = NewClient(cfg) + ) + c.clock = clock + tree3, url3 := makeTestTree("t3", nodes[20:30], nil) + tree2, url2 := makeTestTree("t2", nodes[10:20], nil) + tree1, url1 := makeTestTree("t1", nodes[0:10], []string{url2}) + resolver.add(tree1.ToTXT("t1")) + resolver.add(tree2.ToTXT("t2")) + resolver.add(tree3.ToTXT("t3")) + + // Sync tree1 using RandomNode. + c.AddTree(url1) + checkRandomNode(t, c, nodes[:20]) + + // Add link to tree3, remove link to tree2. + tree1, _ = makeTestTree("t1", nodes[:10], []string{url3}) + resolver.add(tree1.ToTXT("t1")) + clock.Run(cfg.RecheckInterval + 1*time.Second) + t.Log("tree1 updated") + + var wantNodes []*enode.Node + wantNodes = append(wantNodes, tree1.Nodes()...) + wantNodes = append(wantNodes, tree3.Nodes()...) + checkRandomNode(t, c, wantNodes) + + // Check that linked trees are GCed when they're no longer referenced. + if len(c.trees) != 2 { + t.Errorf("client knows %d trees, want 2", len(c.trees)) + } +} + +func checkRandomNode(t *testing.T, c *Client, wantNodes []*enode.Node) { + t.Helper() + + var ( + want = make(map[enode.ID]*enode.Node) + maxCalls = len(wantNodes) * 2 + calls = 0 + ctx = context.Background() + ) + for _, n := range wantNodes { + want[n.ID()] = n + } + for ; len(want) > 0 && calls < maxCalls; calls++ { + n := c.RandomNode(ctx) + if n == nil { + t.Fatalf("RandomNode returned nil (call %d)", calls) + } + delete(want, n.ID()) + } + t.Logf("checkRandomNode called RandomNode %d times to find %d nodes", calls, len(wantNodes)) + for _, n := range want { + t.Errorf("RandomNode didn't discover node %v", n.ID()) + } +} + +func makeTestTree(domain string, nodes []*enode.Node, links []string) (*Tree, string) { + tree, err := MakeTree(1, nodes, links) + if err != nil { + panic(err) + } + url, err := tree.Sign(testKey(signingKeySeed), domain) + if err != nil { + panic(err) + } + return tree, url +} + +// testKeys creates deterministic private keys for testing. 
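// Editor's note (not part of the patch): the determinism below depends on
// ecdsa.GenerateKey consuming the seeded math/rand stream identically on every
// run. That holds for a fixed Go toolchain but is not guaranteed across Go
// releases, so the hard-coded fixtures in this file are coupled to the
// generator behavior current when they were produced.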
+func testKeys(seed int64, n int) []*ecdsa.PrivateKey { + rand := rand.New(rand.NewSource(seed)) + keys := make([]*ecdsa.PrivateKey, n) + for i := 0; i < n; i++ { + key, err := ecdsa.GenerateKey(crypto.S256(), rand) + if err != nil { + panic("can't generate key: " + err.Error()) + } + keys[i] = key + } + return keys +} + +func testKey(seed int64) *ecdsa.PrivateKey { + return testKeys(seed, 1)[0] +} + +func testNodes(seed int64, n int) []*enode.Node { + keys := testKeys(seed, n) + nodes := make([]*enode.Node, n) + for i, key := range keys { + record := new(enr.Record) + record.SetSeq(uint64(i)) + enode.SignV4(record, key) + n, err := enode.New(enode.ValidSchemes, record) + if err != nil { + panic(err) + } + nodes[i] = n + } + return nodes +} + +func testNode(seed int64) *enode.Node { + return testNodes(seed, 1)[0] +} + +type mapResolver map[string]string + +func newMapResolver(maps ...map[string]string) mapResolver { + mr := make(mapResolver) + for _, m := range maps { + mr.add(m) + } + return mr +} + +func (mr mapResolver) clear() { + for k := range mr { + delete(mr, k) + } +} + +func (mr mapResolver) add(m map[string]string) { + for k, v := range m { + mr[k] = v + } +} + +func (mr mapResolver) LookupTXT(ctx context.Context, name string) ([]string, error) { + if record, ok := mr[name]; ok { + return []string{record}, nil + } + return nil, nil +} diff --git a/p2p/dnsdisc/doc.go b/p2p/dnsdisc/doc.go new file mode 100644 index 000000000..227467d08 --- /dev/null +++ b/p2p/dnsdisc/doc.go @@ -0,0 +1,18 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package dnsdisc implements node discovery via DNS (EIP-1459). +package dnsdisc diff --git a/p2p/dnsdisc/error.go b/p2p/dnsdisc/error.go new file mode 100644 index 000000000..e0998c735 --- /dev/null +++ b/p2p/dnsdisc/error.go @@ -0,0 +1,63 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package dnsdisc + +import ( + "errors" + "fmt" +) + +// Entry parse errors. 
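// Editor's note (not part of the patch): the errors in this file are sentinel
// values and plain comparable structs rather than wrapped errors, which is
// what lets the tests assert on them with ==, for example (fixture name
// hypothetical):
//
//	wantErr := nameError{name: "X.n", err: entryError{typ: "enr", err: errInvalidENR}}
//	if err != wantErr { /* fail */ }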
+var ( + errUnknownEntry = errors.New("unknown entry type") + errNoPubkey = errors.New("missing public key") + errBadPubkey = errors.New("invalid public key") + errInvalidENR = errors.New("invalid node record") + errInvalidChild = errors.New("invalid child hash") + errInvalidSig = errors.New("invalid base64 signature") + errSyntax = errors.New("invalid syntax") +) + +// Resolver/sync errors +var ( + errNoRoot = errors.New("no valid root found") + errNoEntry = errors.New("no valid tree entry found") + errHashMismatch = errors.New("hash mismatch") + errENRInLinkTree = errors.New("enr entry in link tree") + errLinkInENRTree = errors.New("link entry in ENR tree") +) + +type nameError struct { + name string + err error +} + +func (err nameError) Error() string { + if ee, ok := err.err.(entryError); ok { + return fmt.Sprintf("invalid %s entry at %s: %v", ee.typ, err.name, ee.err) + } + return err.name + ": " + err.err.Error() +} + +type entryError struct { + typ string + err error +} + +func (err entryError) Error() string { + return fmt.Sprintf("invalid %s entry: %v", err.typ, err.err) +} diff --git a/p2p/dnsdisc/sync.go b/p2p/dnsdisc/sync.go new file mode 100644 index 000000000..533dacc65 --- /dev/null +++ b/p2p/dnsdisc/sync.go @@ -0,0 +1,277 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package dnsdisc + +import ( + "context" + "crypto/ecdsa" + "math/rand" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +// clientTree is a full tree being synced. +type clientTree struct { + c *Client + loc *linkEntry + root *rootEntry + lastRootCheck mclock.AbsTime // last revalidation of root + enrs *subtreeSync + links *subtreeSync + linkCache linkCache +} + +func newClientTree(c *Client, loc *linkEntry) *clientTree { + ct := &clientTree{c: c, loc: loc} + ct.linkCache.self = ct + return ct +} + +func (ct *clientTree) matchPubkey(key *ecdsa.PublicKey) bool { + return keysEqual(ct.loc.pubkey, key) +} + +func keysEqual(k1, k2 *ecdsa.PublicKey) bool { + return k1.Curve == k2.Curve && k1.X.Cmp(k2.X) == 0 && k1.Y.Cmp(k2.Y) == 0 +} + +// syncAll retrieves all entries of the tree. +func (ct *clientTree) syncAll(dest map[string]entry) error { + if err := ct.updateRoot(); err != nil { + return err + } + if err := ct.links.resolveAll(dest); err != nil { + return err + } + if err := ct.enrs.resolveAll(dest); err != nil { + return err + } + return nil +} + +// syncRandom retrieves a single entry of the tree. The Node return value +// is non-nil if the entry was a node. +func (ct *clientTree) syncRandom(ctx context.Context) (*enode.Node, error) { + if ct.rootUpdateDue() { + if err := ct.updateRoot(); err != nil { + return nil, err + } + } + // Link tree sync has priority, run it to completion before syncing ENRs. 
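	// (Editor's note: links must complete first because each resolved link
	// feeds the client's linkCache, which RandomNode's tree selection and
	// garbage collection depend on.)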
+ if !ct.links.done() { + err := ct.syncNextLink(ctx) + return nil, err + } + + // Sync next random entry in ENR tree. Once every node has been visited, we simply + // start over. This is fine because entries are cached. + if ct.enrs.done() { + ct.enrs = newSubtreeSync(ct.c, ct.loc, ct.root.eroot, false) + } + return ct.syncNextRandomENR(ctx) +} + +func (ct *clientTree) syncNextLink(ctx context.Context) error { + hash := ct.links.missing[0] + e, err := ct.links.resolveNext(ctx, hash) + if err != nil { + return err + } + ct.links.missing = ct.links.missing[1:] + + if le, ok := e.(*linkEntry); ok { + lt, err := ct.c.ensureTree(le) + if err != nil { + return err + } + ct.linkCache.add(lt) + } + return nil +} + +func (ct *clientTree) syncNextRandomENR(ctx context.Context) (*enode.Node, error) { + index := rand.Intn(len(ct.enrs.missing)) + hash := ct.enrs.missing[index] + e, err := ct.enrs.resolveNext(ctx, hash) + if err != nil { + return nil, err + } + ct.enrs.missing = removeHash(ct.enrs.missing, index) + if ee, ok := e.(*enrEntry); ok { + return ee.node, nil + } + return nil, nil +} + +func (ct *clientTree) String() string { + return ct.loc.String() +} + +// removeHash removes the element at index from h. +func removeHash(h []string, index int) []string { + if len(h) == 1 { + return nil + } + last := len(h) - 1 + if index < last { + h[index] = h[last] + h[last] = "" + } + return h[:last] +} + +// updateRoot ensures that the given tree has an up-to-date root. +func (ct *clientTree) updateRoot() error { + ct.lastRootCheck = ct.c.clock.Now() + ctx, cancel := context.WithTimeout(context.Background(), ct.c.cfg.Timeout) + defer cancel() + root, err := ct.c.resolveRoot(ctx, ct.loc) + if err != nil { + return err + } + ct.root = &root + + // Invalidate subtrees if changed. + if ct.links == nil || root.lroot != ct.links.root { + ct.links = newSubtreeSync(ct.c, ct.loc, root.lroot, true) + ct.linkCache.reset() + } + if ct.enrs == nil || root.eroot != ct.enrs.root { + ct.enrs = newSubtreeSync(ct.c, ct.loc, root.eroot, false) + } + return nil +} + +// rootUpdateDue returns true when a root update is needed. +func (ct *clientTree) rootUpdateDue() bool { + return ct.root == nil || time.Duration(ct.c.clock.Now()-ct.lastRootCheck) > ct.c.cfg.RecheckInterval +} + +// subtreeSync is the sync of an ENR or link subtree. +type subtreeSync struct { + c *Client + loc *linkEntry + root string + missing []string // missing tree node hashes + link bool // true if this sync is for the link tree +} + +func newSubtreeSync(c *Client, loc *linkEntry, root string, link bool) *subtreeSync { + return &subtreeSync{c, loc, root, []string{root}, link} +} + +func (ts *subtreeSync) done() bool { + return len(ts.missing) == 0 +} + +func (ts *subtreeSync) resolveAll(dest map[string]entry) error { + for !ts.done() { + hash := ts.missing[0] + ctx, cancel := context.WithTimeout(context.Background(), ts.c.cfg.Timeout) + e, err := ts.resolveNext(ctx, hash) + cancel() + if err != nil { + return err + } + dest[hash] = e + ts.missing = ts.missing[1:] + } + return nil +} + +func (ts *subtreeSync) resolveNext(ctx context.Context, hash string) (entry, error) { + e, err := ts.c.resolveEntry(ctx, ts.loc.domain, hash) + if err != nil { + return nil, err + } + switch e := e.(type) { + case *enrEntry: + if ts.link { + return nil, errENRInLinkTree + } + case *linkEntry: + if !ts.link { + return nil, errLinkInENRTree + } + case *branchEntry: + ts.missing = append(ts.missing, e.children...) 
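	// (Editor's note: appending the children here makes the sync a
	// breadth-first walk: 'missing' is consumed as a FIFO queue by
	// resolveAll and sampled at random indices by syncNextRandomENR.)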
+ } + return e, nil +} + +// linkCache tracks the links of a tree. +type linkCache struct { + self *clientTree + directM map[*clientTree]struct{} // direct links + allM map[*clientTree]struct{} // direct & transitive links +} + +// reset clears the cache. +func (lc *linkCache) reset() { + lc.directM = nil + lc.allM = nil +} + +// add adds a direct link to the cache. +func (lc *linkCache) add(ct *clientTree) { + if lc.directM == nil { + lc.directM = make(map[*clientTree]struct{}) + } + if _, ok := lc.directM[ct]; !ok { + lc.invalidate() + } + lc.directM[ct] = struct{}{} +} + +// invalidate resets the cache of transitive links. +func (lc *linkCache) invalidate() { + lc.allM = nil +} + +// valid returns true when the cache of transitive links is up-to-date. +func (lc *linkCache) valid() bool { + // Re-check validity of child caches to catch updates. + for ct := range lc.allM { + if ct != lc.self && !ct.linkCache.valid() { + lc.allM = nil + break + } + } + return lc.allM != nil +} + +// all returns all trees reachable through the cache. +func (lc *linkCache) all() map[*clientTree]struct{} { + if lc.valid() { + return lc.allM + } + // Remake lc.allM it by taking the union of all() across children. + m := make(map[*clientTree]struct{}) + if lc.self != nil { + m[lc.self] = struct{}{} + } + for ct := range lc.directM { + m[ct] = struct{}{} + for lt := range ct.linkCache.all() { + m[lt] = struct{}{} + } + } + lc.allM = m + return m +} diff --git a/p2p/dnsdisc/tree.go b/p2p/dnsdisc/tree.go new file mode 100644 index 000000000..eba2ff9c0 --- /dev/null +++ b/p2p/dnsdisc/tree.go @@ -0,0 +1,385 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package dnsdisc + +import ( + "bytes" + "crypto/ecdsa" + "encoding/base32" + "encoding/base64" + "fmt" + "io" + "sort" + "strings" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/rlp" + "golang.org/x/crypto/sha3" +) + +// Tree is a merkle tree of node records. +type Tree struct { + root *rootEntry + entries map[string]entry +} + +// Sign signs the tree with the given private key and sets the sequence number. +func (t *Tree) Sign(key *ecdsa.PrivateKey, domain string) (url string, err error) { + root := *t.root + sig, err := crypto.Sign(root.sigHash(), key) + if err != nil { + return "", err + } + root.sig = sig + t.root = &root + link := &linkEntry{domain, &key.PublicKey} + return link.String(), nil +} + +// SetSignature verifies the given signature and assigns it as the tree's current +// signature if valid. 
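// Editor's note (not part of the patch): a hedged sketch of the signing
// round-trip. Sign yields the enrtree:// URL that consumers hand to
// Client.SyncTree or AddTree, and ToTXT yields the records to publish under
// the domain; SetSignature below is the alternative path for attaching an
// externally produced signature.
//
//	tree, _ := MakeTree(1, nodes, nil)
//	url, err := tree.Sign(key, "nodes.example.org")
//	if err != nil { /* handle */ }
//	for name, txt := range tree.ToTXT("nodes.example.org") {
//		publishTXT(name, txt) // publishTXT is a hypothetical DNS-update helper
//	}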
+func (t *Tree) SetSignature(pubkey *ecdsa.PublicKey, signature string) error { + sig, err := b64format.DecodeString(signature) + if err != nil || len(sig) != crypto.SignatureLength { + return errInvalidSig + } + root := *t.root + root.sig = sig + if !root.verifySignature(pubkey) { + return errInvalidSig + } + t.root = &root + return nil +} + +// Seq returns the sequence number of the tree. +func (t *Tree) Seq() uint { + return t.root.seq +} + +// Signature returns the signature of the tree. +func (t *Tree) Signature() string { + return b64format.EncodeToString(t.root.sig) +} + +// ToTXT returns all DNS TXT records required for the tree. +func (t *Tree) ToTXT(domain string) map[string]string { + records := map[string]string{domain: t.root.String()} + for _, e := range t.entries { + sd := subdomain(e) + if domain != "" { + sd = sd + "." + domain + } + records[sd] = e.String() + } + return records +} + +// Links returns all links contained in the tree. +func (t *Tree) Links() []string { + var links []string + for _, e := range t.entries { + if le, ok := e.(*linkEntry); ok { + links = append(links, le.String()) + } + } + return links +} + +// Nodes returns all nodes contained in the tree. +func (t *Tree) Nodes() []*enode.Node { + var nodes []*enode.Node + for _, e := range t.entries { + if ee, ok := e.(*enrEntry); ok { + nodes = append(nodes, ee.node) + } + } + return nodes +} + +const ( + hashAbbrev = 16 + maxChildren = 300 / hashAbbrev * (13 / 8) + minHashLength = 12 +) + +// MakeTree creates a tree containing the given nodes and links. +func MakeTree(seq uint, nodes []*enode.Node, links []string) (*Tree, error) { + // Sort records by ID and ensure all nodes have a valid record. + records := make([]*enode.Node, len(nodes)) + + copy(records, nodes) + sortByID(records) + for _, n := range records { + if len(n.Record().Signature()) == 0 { + return nil, fmt.Errorf("can't add node %v: unsigned node record", n.ID()) + } + } + + // Create the leaf list. + enrEntries := make([]entry, len(records)) + for i, r := range records { + enrEntries[i] = &enrEntry{r} + } + linkEntries := make([]entry, len(links)) + for i, l := range links { + le, err := parseLink(l) + if err != nil { + return nil, err + } + linkEntries[i] = le + } + + // Create intermediate nodes. 
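	// (Editor's note: under Go's integer constant arithmetic, maxChildren
	// above evaluates to (300/16)*(13/8) = 18*1 = 18, so each branch entry
	// built below carries at most 18 abbreviated child hashes, keeping every
	// TXT record within ordinary DNS size limits.)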
+ t := &Tree{entries: make(map[string]entry)} + eroot := t.build(enrEntries) + t.entries[subdomain(eroot)] = eroot + lroot := t.build(linkEntries) + t.entries[subdomain(lroot)] = lroot + t.root = &rootEntry{seq: seq, eroot: subdomain(eroot), lroot: subdomain(lroot)} + return t, nil +} + +func (t *Tree) build(entries []entry) entry { + if len(entries) == 1 { + return entries[0] + } + if len(entries) <= maxChildren { + hashes := make([]string, len(entries)) + for i, e := range entries { + hashes[i] = subdomain(e) + t.entries[hashes[i]] = e + } + return &branchEntry{hashes} + } + var subtrees []entry + for len(entries) > 0 { + n := maxChildren + if len(entries) < n { + n = len(entries) + } + sub := t.build(entries[:n]) + entries = entries[n:] + subtrees = append(subtrees, sub) + t.entries[subdomain(sub)] = sub + } + return t.build(subtrees) +} + +func sortByID(nodes []*enode.Node) []*enode.Node { + sort.Slice(nodes, func(i, j int) bool { + return bytes.Compare(nodes[i].ID().Bytes(), nodes[j].ID().Bytes()) < 0 + }) + return nodes +} + +// Entry Types + +type entry interface { + fmt.Stringer +} + +type ( + rootEntry struct { + eroot string + lroot string + seq uint + sig []byte + } + branchEntry struct { + children []string + } + enrEntry struct { + node *enode.Node + } + linkEntry struct { + domain string + pubkey *ecdsa.PublicKey + } +) + +// Entry Encoding + +var ( + b32format = base32.StdEncoding.WithPadding(base32.NoPadding) + b64format = base64.RawURLEncoding +) + +const ( + rootPrefix = "enrtree-root:v1" + linkPrefix = "enrtree://" + branchPrefix = "enrtree-branch:" + enrPrefix = "enr:" +) + +func subdomain(e entry) string { + h := sha3.NewLegacyKeccak256() + io.WriteString(h, e.String()) + return b32format.EncodeToString(h.Sum(nil)[:16]) +} + +func (e *rootEntry) String() string { + return fmt.Sprintf(rootPrefix+" e=%s l=%s seq=%d sig=%s", e.eroot, e.lroot, e.seq, b64format.EncodeToString(e.sig)) +} + +func (e *rootEntry) sigHash() []byte { + h := sha3.NewLegacyKeccak256() + fmt.Fprintf(h, rootPrefix+" e=%s l=%s seq=%d", e.eroot, e.lroot, e.seq) + return h.Sum(nil) +} + +func (e *rootEntry) verifySignature(pubkey *ecdsa.PublicKey) bool { + sig := e.sig[:crypto.RecoveryIDOffset] // remove recovery id + return crypto.VerifySignature(crypto.FromECDSAPub(pubkey), e.sigHash(), sig) +} + +func (e *branchEntry) String() string { + return branchPrefix + strings.Join(e.children, ",") +} + +func (e *enrEntry) String() string { + return e.node.String() +} + +func (e *linkEntry) String() string { + pubkey := b32format.EncodeToString(crypto.CompressPubkey(e.pubkey)) + return fmt.Sprintf("%s%s@%s", linkPrefix, pubkey, e.domain) +} + +// Entry Parsing + +func parseEntry(e string, validSchemes enr.IdentityScheme) (entry, error) { + switch { + case strings.HasPrefix(e, linkPrefix): + return parseLinkEntry(e) + case strings.HasPrefix(e, branchPrefix): + return parseBranch(e) + case strings.HasPrefix(e, enrPrefix): + return parseENR(e, validSchemes) + default: + return nil, errUnknownEntry + } +} + +func parseRoot(e string) (rootEntry, error) { + var eroot, lroot, sig string + var seq uint + if _, err := fmt.Sscanf(e, rootPrefix+" e=%s l=%s seq=%d sig=%s", &eroot, &lroot, &seq, &sig); err != nil { + return rootEntry{}, entryError{"root", errSyntax} + } + if !isValidHash(eroot) || !isValidHash(lroot) { + return rootEntry{}, entryError{"root", errInvalidChild} + } + sigb, err := b64format.DecodeString(sig) + if err != nil || len(sigb) != crypto.SignatureLength { + return rootEntry{}, entryError{"root", 
errInvalidSig} + } + return rootEntry{eroot, lroot, seq, sigb}, nil +} + +func parseLinkEntry(e string) (entry, error) { + le, err := parseLink(e) + if err != nil { + return nil, err + } + return le, nil +} + +func parseLink(e string) (*linkEntry, error) { + if !strings.HasPrefix(e, linkPrefix) { + return nil, fmt.Errorf("wrong/missing scheme 'enrtree' in URL") + } + e = e[len(linkPrefix):] + pos := strings.IndexByte(e, '@') + if pos == -1 { + return nil, entryError{"link", errNoPubkey} + } + keystring, domain := e[:pos], e[pos+1:] + keybytes, err := b32format.DecodeString(keystring) + if err != nil { + return nil, entryError{"link", errBadPubkey} + } + key, err := crypto.DecompressPubkey(keybytes) + if err != nil { + return nil, entryError{"link", errBadPubkey} + } + return &linkEntry{domain, key}, nil +} + +func parseBranch(e string) (entry, error) { + e = e[len(branchPrefix):] + if e == "" { + return &branchEntry{}, nil // empty entry is OK + } + hashes := make([]string, 0, strings.Count(e, ",")) + for _, c := range strings.Split(e, ",") { + if !isValidHash(c) { + return nil, entryError{"branch", errInvalidChild} + } + hashes = append(hashes, c) + } + return &branchEntry{hashes}, nil +} + +func parseENR(e string, validSchemes enr.IdentityScheme) (entry, error) { + e = e[len(enrPrefix):] + enc, err := b64format.DecodeString(e) + if err != nil { + return nil, entryError{"enr", errInvalidENR} + } + var rec enr.Record + if err := rlp.DecodeBytes(enc, &rec); err != nil { + return nil, entryError{"enr", err} + } + n, err := enode.New(validSchemes, &rec) + if err != nil { + return nil, entryError{"enr", err} + } + return &enrEntry{n}, nil +} + +func isValidHash(s string) bool { + dlen := b32format.DecodedLen(len(s)) + if dlen < minHashLength || dlen > 32 || strings.ContainsAny(s, "\n\r") { + return false + } + buf := make([]byte, 32) + _, err := b32format.Decode(buf, []byte(s)) + return err == nil +} + +// truncateHash truncates the given base32 hash string to the minimum acceptable length. +func truncateHash(hash string) string { + maxLen := b32format.EncodedLen(minHashLength) + if len(hash) < maxLen { + panic(fmt.Errorf("dnsdisc: hash %q is too short", hash)) + } + return hash[:maxLen] +} + +// URL encoding + +// ParseURL parses an enrtree:// URL and returns its components. +func ParseURL(url string) (domain string, pubkey *ecdsa.PublicKey, err error) { + le, err := parseLink(url) + if err != nil { + return "", nil, err + } + return le.domain, le.pubkey, nil +} diff --git a/p2p/dnsdisc/tree_test.go b/p2p/dnsdisc/tree_test.go new file mode 100644 index 000000000..b6d0a8433 --- /dev/null +++ b/p2p/dnsdisc/tree_test.go @@ -0,0 +1,144 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
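// Editor's note (not part of the patch): the enrtree:// URLs exercised below
// have the shape enrtree://<key>@<domain>, where <key> is the unpadded base32
// encoding of the 33-byte compressed secp256k1 public key that signs the tree
// root; parseLink and ParseURL above split on '@' and decompress that key. A
// sketch of building one from a private key:
//
//	pub := b32format.EncodeToString(crypto.CompressPubkey(&priv.PublicKey))
//	url := linkPrefix + pub + "@nodes.example.org"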
+ +package dnsdisc + +import ( + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +func TestParseRoot(t *testing.T) { + tests := []struct { + input string + e rootEntry + err error + }{ + { + input: "enrtree-root:v1 e=TO4Q75OQ2N7DX4EOOR7X66A6OM seq=3 sig=N-YY6UB9xD0hFx1Gmnt7v0RfSxch5tKyry2SRDoLx7B4GfPXagwLxQqyf7gAMvApFn_ORwZQekMWa_pXrcGCtw", + err: entryError{"root", errSyntax}, + }, + { + input: "enrtree-root:v1 e=TO4Q75OQ2N7DX4EOOR7X66A6OM l=TO4Q75OQ2N7DX4EOOR7X66A6OM seq=3 sig=N-YY6UB9xD0hFx1Gmnt7v0RfSxch5tKyry2SRDoLx7B4GfPXagwLxQqyf7gAMvApFn_ORwZQekMWa_pXrcGCtw", + err: entryError{"root", errInvalidSig}, + }, + { + input: "enrtree-root:v1 e=QFT4PBCRX4XQCV3VUYJ6BTCEPU l=JGUFMSAGI7KZYB3P7IZW4S5Y3A seq=3 sig=3FmXuVwpa8Y7OstZTx9PIb1mt8FrW7VpDOFv4AaGCsZ2EIHmhraWhe4NxYhQDlw5MjeFXYMbJjsPeKlHzmJREQE", + e: rootEntry{ + eroot: "QFT4PBCRX4XQCV3VUYJ6BTCEPU", + lroot: "JGUFMSAGI7KZYB3P7IZW4S5Y3A", + seq: 3, + sig: hexutil.MustDecode("0xdc5997b95c296bc63b3acb594f1f4f21bd66b7c16b5bb5690ce16fe006860ac6761081e686b69685ee0dc588500e5c393237855d831b263b0f78a947ce62511101"), + }, + }, + } + for i, test := range tests { + e, err := parseRoot(test.input) + if !reflect.DeepEqual(e, test.e) { + t.Errorf("test %d: wrong entry %s, want %s", i, spew.Sdump(e), spew.Sdump(test.e)) + } + if err != test.err { + t.Errorf("test %d: wrong error %q, want %q", i, err, test.err) + } + } +} + +func TestParseEntry(t *testing.T) { + testkey := testKey(signingKeySeed) + tests := []struct { + input string + e entry + err error + }{ + // Subtrees: + { + input: "enrtree-branch:1,2", + err: entryError{"branch", errInvalidChild}, + }, + { + input: "enrtree-branch:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", + err: entryError{"branch", errInvalidChild}, + }, + { + input: "enrtree-branch:", + e: &branchEntry{}, + }, + { + input: "enrtree-branch:AAAAAAAAAAAAAAAAAAAA", + e: &branchEntry{[]string{"AAAAAAAAAAAAAAAAAAAA"}}, + }, + { + input: "enrtree-branch:AAAAAAAAAAAAAAAAAAAA,BBBBBBBBBBBBBBBBBBBB", + e: &branchEntry{[]string{"AAAAAAAAAAAAAAAAAAAA", "BBBBBBBBBBBBBBBBBBBB"}}, + }, + // Links + { + input: "enrtree://AKPYQIUQIL7PSIACI32J7FGZW56E5FKHEFCCOFHILBIMW3M6LWXS2@nodes.example.org", + e: &linkEntry{"nodes.example.org", &testkey.PublicKey}, + }, + { + input: "enrtree://nodes.example.org", + err: entryError{"link", errNoPubkey}, + }, + { + input: "enrtree://AP62DT7WOTEQZGQZOU474PP3KMEGVTTE7A7NPRXKX3DUD57@nodes.example.org", + err: entryError{"link", errBadPubkey}, + }, + { + input: "enrtree://AP62DT7WONEQZGQZOU474PP3KMEGVTTE7A7NPRXKX3DUD57TQHGIA@nodes.example.org", + err: entryError{"link", errBadPubkey}, + }, + // ENRs + { + input: "enr:-HW4QES8QIeXTYlDzbfr1WEzE-XKY4f8gJFJzjJL-9D7TC9lJb4Z3JPRRz1lP4pL_N_QpT6rGQjAU9Apnc-C1iMP36OAgmlkgnY0iXNlY3AyNTZrMaED5IdwfMxdmR8W37HqSFdQLjDkIwBd4Q_MjxgZifgKSdM", + e: &enrEntry{node: testNode(nodesSeed1)}, + }, + { + input: "enr:-HW4QLZHjM4vZXkbp-5xJoHsKSbE7W39FPC8283X-y8oHcHPTnDDlIlzL5ArvDUlHZVDPgmFASrh7cWgLOLxj4wprRkHgmlkgnY0iXNlY3AyNTZrMaEC3t2jLMhDpCDX5mbSEwDn4L3iUfyXzoO8G28XvjGRkrAg=", + err: entryError{"enr", errInvalidENR}, + }, + // Invalid: + {input: "", err: errUnknownEntry}, + {input: "foo", err: errUnknownEntry}, + {input: "enrtree", err: errUnknownEntry}, + {input: "enrtree-x=", err: errUnknownEntry}, + } + for i, test := range tests { + e, err := parseEntry(test.input, enode.ValidSchemes) + if !reflect.DeepEqual(e, test.e) { + t.Errorf("test %d: wrong entry %s, want 
%s", i, spew.Sdump(e), spew.Sdump(test.e)) + } + if err != test.err { + t.Errorf("test %d: wrong error %q, want %q", i, err, test.err) + } + } +} + +func TestMakeTree(t *testing.T) { + nodes := testNodes(nodesSeed2, 50) + tree, err := MakeTree(2, nodes, nil) + if err != nil { + t.Fatal(err) + } + txt := tree.ToTXT("") + if len(txt) < len(nodes)+1 { + t.Fatal("too few TXT records in output") + } +} diff --git a/p2p/enode/iter.go b/p2p/enode/iter.go new file mode 100644 index 000000000..112b76d06 --- /dev/null +++ b/p2p/enode/iter.go @@ -0,0 +1,286 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package enode + +import ( + "sync" + "time" +) + +// Iterator represents a sequence of nodes. The Next method moves to the next node in the +// sequence. It returns false when the sequence has ended or the iterator is closed. Close +// may be called concurrently with Next and Node, and interrupts Next if it is blocked. +type Iterator interface { + Next() bool // moves to next node + Node() *Node // returns current node + Close() // ends the iterator +} + +// ReadNodes reads at most n nodes from the given iterator. The return value contains no +// duplicates and no nil values. To prevent looping indefinitely for small repeating node +// sequences, this function calls Next at most n times. +func ReadNodes(it Iterator, n int) []*Node { + seen := make(map[ID]*Node, n) + for i := 0; i < n && it.Next(); i++ { + // Remove duplicates, keeping the node with higher seq. + node := it.Node() + prevNode, ok := seen[node.ID()] + if ok && prevNode.Seq() > node.Seq() { + continue + } + seen[node.ID()] = node + } + result := make([]*Node, 0, len(seen)) + for _, node := range seen { + result = append(result, node) + } + return result +} + +// IterNodes makes an iterator which runs through the given nodes once. +func IterNodes(nodes []*Node) Iterator { + return &sliceIter{nodes: nodes, index: -1} +} + +// CycleNodes makes an iterator which cycles through the given nodes indefinitely. +func CycleNodes(nodes []*Node) Iterator { + return &sliceIter{nodes: nodes, index: -1, cycle: true} +} + +type sliceIter struct { + mu sync.Mutex + nodes []*Node + index int + cycle bool +} + +func (it *sliceIter) Next() bool { + it.mu.Lock() + defer it.mu.Unlock() + + if len(it.nodes) == 0 { + return false + } + it.index++ + if it.index == len(it.nodes) { + if it.cycle { + it.index = 0 + } else { + it.nodes = nil + return false + } + } + return true +} + +func (it *sliceIter) Node() *Node { + if len(it.nodes) == 0 { + return nil + } + return it.nodes[it.index] +} + +func (it *sliceIter) Close() { + it.mu.Lock() + defer it.mu.Unlock() + + it.nodes = nil +} + +// Filter wraps an iterator such that Next only returns nodes for which +// the 'check' function returns true. 
+func Filter(it Iterator, check func(*Node) bool) Iterator { + return &filterIter{it, check} +} + +type filterIter struct { + Iterator + check func(*Node) bool +} + +func (f *filterIter) Next() bool { + for f.Iterator.Next() { + if f.check(f.Node()) { + return true + } + } + return false +} + +// FairMix aggregates multiple node iterators. The mixer itself is an iterator which ends +// only when Close is called. Source iterators added via AddSource are removed from the +// mix when they end. +// +// The distribution of nodes returned by Next is approximately fair, i.e. FairMix +// attempts to draw from all sources equally often. However, if a certain source is slow +// and doesn't return a node within the configured timeout, a node from any other source +// will be returned. +// +// It's safe to call AddSource and Close concurrently with Next. +type FairMix struct { + wg sync.WaitGroup + fromAny chan *Node + timeout time.Duration + cur *Node + + mu sync.Mutex + closed chan struct{} + sources []*mixSource + last int +} + +type mixSource struct { + it Iterator + next chan *Node + timeout time.Duration +} + +// NewFairMix creates a mixer. +// +// The timeout specifies how long the mixer will wait for the next fairly-chosen source +// before giving up and taking a node from any other source. A good way to set the timeout +// is deciding how long you'd want to wait for a node on average. Passing a negative +// timeout makes the mixer completely fair. +func NewFairMix(timeout time.Duration) *FairMix { + m := &FairMix{ + fromAny: make(chan *Node), + closed: make(chan struct{}), + timeout: timeout, + } + return m +} + +// AddSource adds a source of nodes. +func (m *FairMix) AddSource(it Iterator) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed == nil { + return + } + m.wg.Add(1) + source := &mixSource{it, make(chan *Node), m.timeout} + m.sources = append(m.sources, source) + go m.runSource(m.closed, source) +} + +// Close shuts down the mixer and all current sources. +// Calling this is required to release resources associated with the mixer. +func (m *FairMix) Close() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed == nil { + return + } + for _, s := range m.sources { + s.it.Close() + } + close(m.closed) + m.wg.Wait() + close(m.fromAny) + m.sources = nil + m.closed = nil +} + +// Next returns a node from a random source. +func (m *FairMix) Next() bool { + m.cur = nil + + var timeout <-chan time.Time + if m.timeout >= 0 { + timer := time.NewTimer(m.timeout) + timeout = timer.C + defer timer.Stop() + } + for { + source := m.pickSource() + if source == nil { + return m.nextFromAny() + } + select { + case n, ok := <-source.next: + if ok { + m.cur = n + source.timeout = m.timeout + return true + } + // This source has ended. + m.deleteSource(source) + case <-timeout: + source.timeout /= 2 + return m.nextFromAny() + } + } +} + +// Node returns the current node. +func (m *FairMix) Node() *Node { + return m.cur +} + +// nextFromAny is used when there are no sources or when the 'fair' choice +// doesn't turn up a node quickly enough. +func (m *FairMix) nextFromAny() bool { + n, ok := <-m.fromAny + if ok { + m.cur = n + } + return ok +} + +// pickSource chooses the next source to read from, cycling through them in order. +func (m *FairMix) pickSource() *mixSource { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.sources) == 0 { + return nil + } + m.last = (m.last + 1) % len(m.sources) + return m.sources[m.last] +} + +// deleteSource deletes a source. 
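// Editor's note (not part of the patch): a hedged sketch of composing several
// discovery sources into one FairMix, roughly how the p2p server builds its
// dial-candidate stream in this change. Sources that end are dropped
// automatically; Close is required to release the mixer's goroutines.
//
//	mix := NewFairMix(5 * time.Second)
//	mix.AddSource(CycleNodes(staticNodes)) // staticNodes is a hypothetical []*Node
//	mix.AddSource(dnsIter)                 // e.g. a protocol's DialCandidates iterator
//	defer mix.Close()
//	candidates := ReadNodes(mix, 32)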
+func (m *FairMix) deleteSource(s *mixSource) { + m.mu.Lock() + defer m.mu.Unlock() + + for i := range m.sources { + if m.sources[i] == s { + copy(m.sources[i:], m.sources[i+1:]) + m.sources[len(m.sources)-1] = nil + m.sources = m.sources[:len(m.sources)-1] + break + } + } +} + +// runSource reads a single source in a loop. +func (m *FairMix) runSource(closed chan struct{}, s *mixSource) { + defer m.wg.Done() + defer close(s.next) + for s.it.Next() { + n := s.it.Node() + select { + case s.next <- n: + case m.fromAny <- n: + case <-closed: + return + } + } +} diff --git a/p2p/enode/iter_test.go b/p2p/enode/iter_test.go new file mode 100644 index 000000000..6009661f3 --- /dev/null +++ b/p2p/enode/iter_test.go @@ -0,0 +1,291 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package enode + +import ( + "encoding/binary" + "runtime" + "sync/atomic" + "testing" + "time" + + "github.com/ethereum/go-ethereum/p2p/enr" +) + +func TestReadNodes(t *testing.T) { + nodes := ReadNodes(new(genIter), 10) + checkNodes(t, nodes, 10) +} + +// This test checks that ReadNodes terminates when reading N nodes from an iterator +// which returns fewer than N nodes in an endless cycle. +func TestReadNodesCycle(t *testing.T) { + iter := &callCountIter{ + Iterator: CycleNodes([]*Node{ + testNode(0, 0), + testNode(1, 0), + testNode(2, 0), + }), + } + nodes := ReadNodes(iter, 10) + checkNodes(t, nodes, 3) + if iter.count != 10 { + t.Fatalf("%d calls to Next, want %d", iter.count, 10) + } +} + +func TestFilterNodes(t *testing.T) { + nodes := make([]*Node, 100) + for i := range nodes { + nodes[i] = testNode(uint64(i), uint64(i)) + } + + it := Filter(IterNodes(nodes), func(n *Node) bool { + return n.Seq() >= 50 + }) + for i := 50; i < len(nodes); i++ { + if !it.Next() { + t.Fatal("Next returned false") + } + if it.Node() != nodes[i] { + t.Fatalf("iterator returned wrong node %v\nwant %v", it.Node(), nodes[i]) + } + } + if it.Next() { + t.Fatal("Next returned true after underlying iterator has ended") + } +} + +func checkNodes(t *testing.T, nodes []*Node, wantLen int) { + if len(nodes) != wantLen { + t.Errorf("slice has %d nodes, want %d", len(nodes), wantLen) + return + } + seen := make(map[ID]bool) + for i, e := range nodes { + if e == nil { + t.Errorf("nil node at index %d", i) + return + } + if seen[e.ID()] { + t.Errorf("slice has duplicate node %v", e.ID()) + return + } + seen[e.ID()] = true + } +} + +// This test checks fairness of FairMix in the happy case where all sources return nodes +// within the mixer's timeout.
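// (Editor's note: despite its name, approxEqual, defined near the end of this
// file, returns true when the difference exceeds ε, so the check in
// testMixerFairness below reads "fail if any source's share deviates from
// len(nodes)/3 by more than 30".)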
+func TestFairMix(t *testing.T) { + for i := 0; i < 500; i++ { + testMixerFairness(t) + } +} + +func testMixerFairness(t *testing.T) { + mix := NewFairMix(1 * time.Second) + mix.AddSource(&genIter{index: 1}) + mix.AddSource(&genIter{index: 2}) + mix.AddSource(&genIter{index: 3}) + defer mix.Close() + + nodes := ReadNodes(mix, 500) + checkNodes(t, nodes, 500) + + // Verify that the nodes slice contains an approximately equal number of nodes + // from each source. + d := idPrefixDistribution(nodes) + for _, count := range d { + if approxEqual(count, len(nodes)/3, 30) { + t.Fatalf("ID distribution is unfair: %v", d) + } + } +} + +// This test checks that FairMix falls back to an alternative source when +// the 'fair' choice doesn't return a node within the timeout. +func TestFairMixNextFromAll(t *testing.T) { + mix := NewFairMix(1 * time.Millisecond) + mix.AddSource(&genIter{index: 1}) + mix.AddSource(CycleNodes(nil)) + defer mix.Close() + + nodes := ReadNodes(mix, 500) + checkNodes(t, nodes, 500) + + d := idPrefixDistribution(nodes) + if len(d) > 1 || d[1] != len(nodes) { + t.Fatalf("wrong ID distribution: %v", d) + } +} + +// This test ensures FairMix works for Next with no sources. +func TestFairMixEmpty(t *testing.T) { + var ( + mix = NewFairMix(1 * time.Second) + testN = testNode(1, 1) + ch = make(chan *Node) + ) + defer mix.Close() + + go func() { + mix.Next() + ch <- mix.Node() + }() + + mix.AddSource(CycleNodes([]*Node{testN})) + if n := <-ch; n != testN { + t.Errorf("got wrong node: %v", n) + } +} + +// This test checks closing a source while Next runs. +func TestFairMixRemoveSource(t *testing.T) { + mix := NewFairMix(1 * time.Second) + source := make(blockingIter) + mix.AddSource(source) + + sig := make(chan *Node) + go func() { + <-sig + mix.Next() + sig <- mix.Node() + }() + + sig <- nil + runtime.Gosched() + source.Close() + + wantNode := testNode(0, 0) + mix.AddSource(CycleNodes([]*Node{wantNode})) + n := <-sig + + if len(mix.sources) != 1 { + t.Fatalf("have %d sources, want one", len(mix.sources)) + } + if n != wantNode { + t.Fatalf("mixer returned wrong node") + } +} + +type blockingIter chan struct{} + +func (it blockingIter) Next() bool { + <-it + return false +} + +func (it blockingIter) Node() *Node { + return nil +} + +func (it blockingIter) Close() { + close(it) +} + +func TestFairMixClose(t *testing.T) { + for i := 0; i < 20 && !t.Failed(); i++ { + testMixerClose(t) + } +} + +func testMixerClose(t *testing.T) { + mix := NewFairMix(-1) + mix.AddSource(CycleNodes(nil)) + mix.AddSource(CycleNodes(nil)) + + done := make(chan struct{}) + go func() { + defer close(done) + if mix.Next() { + t.Error("Next returned true") + } + }() + // This call is supposed to make it more likely that NextNode is + // actually executing by the time we call Close. 
+ runtime.Gosched() + + mix.Close() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("Next didn't unblock on Close") + } + + mix.Close() // shouldn't crash +} + +func idPrefixDistribution(nodes []*Node) map[uint32]int { + d := make(map[uint32]int) + for _, node := range nodes { + id := node.ID() + d[binary.BigEndian.Uint32(id[:4])]++ + } + return d +} + +func approxEqual(x, y, ε int) bool { + if y > x { + x, y = y, x + } + return x-y > ε +} + +// genIter creates fake nodes with numbered IDs based on 'index' and 'gen' +type genIter struct { + node *Node + index, gen uint32 +} + +func (s *genIter) Next() bool { + index := atomic.LoadUint32(&s.index) + if index == ^uint32(0) { + s.node = nil + return false + } + s.node = testNode(uint64(index)<<32|uint64(s.gen), 0) + s.gen++ + return true +} + +func (s *genIter) Node() *Node { + return s.node +} + +func (s *genIter) Close() { + s.index = ^uint32(0) +} + +func testNode(id, seq uint64) *Node { + var nodeID ID + binary.BigEndian.PutUint64(nodeID[:], id) + r := new(enr.Record) + r.SetSeq(seq) + return SignNull(r, nodeID) +} + +// callCountIter counts calls to NextNode. +type callCountIter struct { + Iterator + count int +} + +func (it *callCountIter) Next() bool { + it.count++ + return it.Iterator.Next() +} diff --git a/p2p/message.go b/p2p/message.go index b98773222..10b55a939 100644 --- a/p2p/message.go +++ b/p2p/message.go @@ -39,9 +39,13 @@ import ( // separate Msg with a bytes.Reader as Payload for each send. type Msg struct { Code uint64 - Size uint32 // size of the paylod + Size uint32 // Size of the raw payload Payload io.Reader ReceivedAt time.Time + + meterCap Cap // Protocol name and version for egress metering + meterCode uint64 // Message within protocol for egress metering + meterSize uint32 // Compressed message size for ingress metering } // Decode parses the RLP content of a message into diff --git a/p2p/metrics.go b/p2p/metrics.go index c04e5ab4c..8b29efdcd 100644 --- a/p2p/metrics.go +++ b/p2p/metrics.go @@ -45,7 +45,7 @@ var ( ingressTrafficMeter = metrics.NewRegisteredMeter(MetricsInboundTraffic, nil) // Meter metering the cumulative ingress traffic egressConnectMeter = metrics.NewRegisteredMeter(MetricsOutboundConnects, nil) // Meter counting the egress connections egressTrafficMeter = metrics.NewRegisteredMeter(MetricsOutboundTraffic, nil) // Meter metering the cumulative egress traffic - activePeerCounter = metrics.NewRegisteredCounter("p2p/peers", nil) // Gauge tracking the current peer count + activePeerGauge = metrics.NewRegisteredGauge("p2p/peers", nil) // Gauge tracking the current peer count PeerIngressRegistry = metrics.NewPrefixedChildRegistry(metrics.EphemeralRegistry, MetricsInboundTraffic+"/") // Registry containing the peer ingress PeerEgressRegistry = metrics.NewPrefixedChildRegistry(metrics.EphemeralRegistry, MetricsOutboundTraffic+"/") // Registry containing the peer egress @@ -124,7 +124,7 @@ func newMeteredConn(conn net.Conn, ingress bool, ip net.IP) net.Conn { } else { egressConnectMeter.Mark(1) } - activePeerCounter.Inc(1) + activePeerGauge.Inc(1) return &meteredConn{ Conn: conn, @@ -200,7 +200,7 @@ func (c *meteredConn) Close() error { IP: c.ip, Elapsed: time.Since(c.connected), }) - activePeerCounter.Dec(1) + activePeerGauge.Dec(1) return err } id := c.id @@ -212,7 +212,7 @@ func (c *meteredConn) Close() error { IP: c.ip, ID: id, }) - activePeerCounter.Dec(1) + activePeerGauge.Dec(1) return err } ingress, egress := uint64(c.ingressMeter.Count()), uint64(c.egressMeter.Count()) 
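// Editor's note (not part of the patch): the counter-to-gauge rename in this
// file tracks the metric's semantics. A live peer count moves both up and
// down, which is what a Gauge models, while a Counter reports an accumulated
// total. The call sites keep their shape, since gauges in this metrics
// package support Inc and Dec, as the hunks here show:
//
//	activePeerGauge := metrics.NewRegisteredGauge("p2p/peers", nil)
//	activePeerGauge.Inc(1) // on connect
//	activePeerGauge.Dec(1) // on close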
@@ -233,6 +233,6 @@ func (c *meteredConn) Close() error { Ingress: ingress, Egress: egress, }) - activePeerCounter.Dec(1) + activePeerGauge.Dec(1) return err } diff --git a/p2p/peer.go b/p2p/peer.go index 372ba8d02..9a9788bc1 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/rlp" @@ -300,6 +301,9 @@ func (p *Peer) handle(msg Msg) error { if err != nil { return fmt.Errorf("msg code out of range: %v", msg.Code) } + if metrics.Enabled { + metrics.GetOrRegisterMeter(fmt.Sprintf("%s/%s/%d/%#02x", MetricsInboundTraffic, proto.Name, proto.Version, msg.Code-proto.offset), nil).Mark(int64(msg.meterSize)) + } select { case proto.in <- msg: return nil @@ -398,7 +402,11 @@ func (rw *protoRW) WriteMsg(msg Msg) (err error) { if msg.Code >= rw.Length { return newPeerError(errInvalidMsgCode, "not handled") } + msg.meterCap = rw.cap() + msg.meterCode = msg.Code + msg.Code += rw.offset + select { case <-rw.wstart: err = rw.w.WriteMsg(msg) diff --git a/p2p/protocol.go b/p2p/protocol.go index 9ce4c2020..fa23a087c 100644 --- a/p2p/protocol.go +++ b/p2p/protocol.go @@ -54,6 +54,11 @@ type Protocol struct { // but returns nil, it is assumed that the protocol handshake is still running. PeerInfo func(id enode.ID) interface{} + // DialCandidates, if non-nil, is a way to tell Server about protocol-specific nodes + // that should be dialed. The server continuously reads nodes from the iterator and + // attempts to create connections to them. + DialCandidates enode.Iterator + // Attributes contains protocol specific information for the node record. Attributes []enr.Entry } diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 52e1eb8a4..115021fa9 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -38,6 +38,7 @@ import ( "github.com/ethereum/go-ethereum/common/bitutil" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" "github.com/golang/snappy" "golang.org/x/crypto/sha3" @@ -602,6 +603,10 @@ func (rw *rlpxFrameRW) WriteMsg(msg Msg) error { msg.Payload = bytes.NewReader(payload) msg.Size = uint32(len(payload)) } + msg.meterSize = msg.Size + if metrics.Enabled && msg.meterCap.Name != "" { // don't meter non-subprotocol messages + metrics.GetOrRegisterMeter(fmt.Sprintf("%s/%s/%d/%#02x", MetricsOutboundTraffic, msg.meterCap.Name, msg.meterCap.Version, msg.meterCode), nil).Mark(int64(msg.meterSize)) + } // write header headbuf := make([]byte, 32) fsize := uint32(len(ptype)) + msg.Size @@ -686,6 +691,7 @@ func (rw *rlpxFrameRW) ReadMsg() (msg Msg, err error) { return msg, err } msg.Size = uint32(content.Len()) + msg.meterSize = msg.Size msg.Payload = content // if snappy is enabled, verify and decompress message diff --git a/p2p/server.go b/p2p/server.go index 692c9eb7d..246148741 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -45,6 +45,11 @@ import ( const ( defaultDialTimeout = 15 * time.Second + // This is the fairness knob for the discovery mixer. When looking for peers, we'll + // wait this long for a single source of candidates before moving on and trying other + // sources. + discmixTimeout = 5 * time.Second + // Connectivity defaults. 
maxActiveDialTasks = 16 defaultMaxPendingPeers = 50 @@ -167,16 +172,20 @@ type Server struct { lock sync.Mutex // protects running running bool - nodedb *enode.DB - localnode *enode.LocalNode - ntab discoverTable listener net.Listener ourHandshake *protoHandshake - DiscV5 *discv5.Network loopWG sync.WaitGroup // loop, listenLoop peerFeed event.Feed log log.Logger + nodedb *enode.DB + localnode *enode.LocalNode + ntab *discover.UDPv4 + DiscV5 *discv5.Network + discmix *enode.FairMix + + staticNodeResolver nodeResolver + // Channels into the run loop. quit chan struct{} addstatic chan *enode.Node @@ -470,7 +479,7 @@ func (srv *Server) Start() (err error) { } dynPeers := srv.maxDialedConns() - dialer := newDialState(srv.localnode.ID(), srv.ntab, dynPeers, &srv.Config) + dialer := newDialState(srv.localnode.ID(), dynPeers, &srv.Config) srv.loopWG.Add(1) go srv.run(dialer) return nil @@ -521,6 +530,18 @@ func (srv *Server) setupLocalNode() error { } func (srv *Server) setupDiscovery() error { + srv.discmix = enode.NewFairMix(discmixTimeout) + + // Add protocol-specific discovery sources. + added := make(map[string]bool) + for _, proto := range srv.Protocols { + if proto.DialCandidates != nil && !added[proto.Name] { + srv.discmix.AddSource(proto.DialCandidates) + added[proto.Name] = true + } + } + + // Don't listen on UDP endpoint if DHT is disabled. if srv.NoDiscovery && !srv.DiscoveryV5 { return nil } @@ -562,7 +583,10 @@ func (srv *Server) setupDiscovery() error { return err } srv.ntab = ntab + srv.discmix.AddSource(ntab.RandomNodes()) + srv.staticNodeResolver = ntab } + // Discovery V5 if srv.DiscoveryV5 { var ntab *discv5.Network @@ -620,6 +644,7 @@ func (srv *Server) run(dialstate dialer) { srv.log.Info("Started P2P networking", "self", srv.localnode.Node().URLv4()) defer srv.loopWG.Done() defer srv.nodedb.Close() + defer srv.discmix.Close() var ( peers = make(map[enode.ID]*Peer) diff --git a/p2p/server_test.go b/p2p/server_test.go index e8bc627e1..383445c83 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -233,8 +233,8 @@ func TestServerTaskScheduling(t *testing.T) { Config: Config{MaxPeers: 10}, localnode: enode.NewLocalNode(db, newkey()), nodedb: db, + discmix: enode.NewFairMix(0), quit: make(chan struct{}), - ntab: fakeTable{}, running: true, log: log.New(), } @@ -282,9 +282,9 @@ func TestServerManyTasks(t *testing.T) { quit: make(chan struct{}), localnode: enode.NewLocalNode(db, newkey()), nodedb: db, - ntab: fakeTable{}, running: true, log: log.New(), + discmix: enode.NewFairMix(0), } done = make(chan *testTask) start, end = 0, 0 diff --git a/p2p/simulations/adapters/types.go b/p2p/simulations/adapters/types.go index f65ce7b60..850de96a1 100644 --- a/p2p/simulations/adapters/types.go +++ b/p2p/simulations/adapters/types.go @@ -101,6 +101,11 @@ type NodeConfig struct { // services registered by calling the RegisterService function) Services []string + // Properties are the names of the properties this node should hold + // within running services (e.g. 
"bootnode", "lightnode" or any custom values) + // These values need to be checked and acted upon by node Services + Properties []string + // Enode node *enode.Node @@ -120,6 +125,7 @@ type nodeConfigJSON struct { PrivateKey string `json:"private_key"` Name string `json:"name"` Services []string `json:"services"` + Properties []string `json:"properties"` EnableMsgEvents bool `json:"enable_msg_events"` Port uint16 `json:"port"` } @@ -131,6 +137,7 @@ func (n *NodeConfig) MarshalJSON() ([]byte, error) { ID: n.ID.String(), Name: n.Name, Services: n.Services, + Properties: n.Properties, Port: n.Port, EnableMsgEvents: n.EnableMsgEvents, } @@ -168,6 +175,7 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error { n.Name = confJSON.Name n.Services = confJSON.Services + n.Properties = confJSON.Properties n.Port = confJSON.Port n.EnableMsgEvents = confJSON.EnableMsgEvents diff --git a/p2p/simulations/http_test.go b/p2p/simulations/http_test.go index ed43c0ed7..84f6ce2a5 100644 --- a/p2p/simulations/http_test.go +++ b/p2p/simulations/http_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/rand" "net/http/httptest" + "os" "reflect" "sync" "sync/atomic" @@ -38,15 +39,13 @@ import ( "github.com/mattn/go-colorable" ) -var ( - loglevel = flag.Int("loglevel", 2, "verbosity of logs") -) +func TestMain(m *testing.M) { + loglevel := flag.Int("loglevel", 2, "verbosity of logs") -func init() { flag.Parse() - log.PrintOrigins(true) log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(*loglevel), log.StreamHandler(colorable.NewColorableStderr(), log.TerminalFormat(true)))) + os.Exit(m.Run()) } // testService implements the node.Service interface and provides protocols diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go index f03c953e8..58fd9a28b 100644 --- a/p2p/simulations/network.go +++ b/p2p/simulations/network.go @@ -56,6 +56,9 @@ type Network struct { Nodes []*Node `json:"nodes"` nodeMap map[enode.ID]int + // Maps a node property string to node indexes of all nodes that hold this property + propertyMap map[string][]int + Conns []*Conn `json:"conns"` connMap map[string]int @@ -71,6 +74,7 @@ func NewNetwork(nodeAdapter adapters.NodeAdapter, conf *NetworkConfig) *Network NetworkConfig: *conf, nodeAdapter: nodeAdapter, nodeMap: make(map[enode.ID]int), + propertyMap: make(map[string][]int), connMap: make(map[string]int), quitc: make(chan struct{}), } @@ -120,9 +124,16 @@ func (net *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error) Config: conf, } log.Trace("Node created", "id", conf.ID) - net.nodeMap[conf.ID] = len(net.Nodes) + + nodeIndex := len(net.Nodes) + net.nodeMap[conf.ID] = nodeIndex net.Nodes = append(net.Nodes, node) + // Register any node properties with the network-level propertyMap + for _, property := range conf.Properties { + net.propertyMap[property] = append(net.propertyMap[property], nodeIndex) + } + // emit a "control" event net.events.Send(ControlEvent(node)) @@ -410,7 +421,7 @@ func (net *Network) getNode(id enode.ID) *Node { return net.Nodes[i] } -// GetNode gets the node with the given name, returning nil if the node does +// GetNodeByName gets the node with the given name, returning nil if the node does // not exist func (net *Network) GetNodeByName(name string) *Node { net.lock.RLock() @@ -427,19 +438,104 @@ func (net *Network) getNodeByName(name string) *Node { return nil } -// GetNodes returns the existing nodes -func (net *Network) GetNodes() (nodes []*Node) { +// GetNodeIDs returns the IDs of all existing nodes +// Nodes can optionally be excluded by 
specifying their enode.ID. +func (net *Network) GetNodeIDs(excludeIDs ...enode.ID) []enode.ID { net.lock.RLock() defer net.lock.RUnlock() - return net.getNodes() + return net.getNodeIDs(excludeIDs) } -func (net *Network) getNodes() (nodes []*Node) { - nodes = append(nodes, net.Nodes...) +func (net *Network) getNodeIDs(excludeIDs []enode.ID) []enode.ID { + // Get all current nodeIDs + nodeIDs := make([]enode.ID, 0, len(net.nodeMap)) + for id := range net.nodeMap { + nodeIDs = append(nodeIDs, id) + } + + if len(excludeIDs) > 0 { + // Return the difference of nodeIDs and excludeIDs + return filterIDs(nodeIDs, excludeIDs) + } else { + return nodeIDs + } +} + +// GetNodes returns the existing nodes. +// Nodes can optionally be excluded by specifying their enode.ID. +func (net *Network) GetNodes(excludeIDs ...enode.ID) []*Node { + net.lock.RLock() + defer net.lock.RUnlock() + + return net.getNodes(excludeIDs) +} + +func (net *Network) getNodes(excludeIDs []enode.ID) []*Node { + if len(excludeIDs) > 0 { + nodeIDs := net.getNodeIDs(excludeIDs) + return net.getNodesByID(nodeIDs) + } else { + return net.Nodes + } +} + +// GetNodesByID returns existing nodes with the given enode.IDs. +// If a node doesn't exist with a given enode.ID, it is ignored. +func (net *Network) GetNodesByID(nodeIDs []enode.ID) []*Node { + net.lock.RLock() + defer net.lock.RUnlock() + + return net.getNodesByID(nodeIDs) +} + +func (net *Network) getNodesByID(nodeIDs []enode.ID) []*Node { + nodes := make([]*Node, 0, len(nodeIDs)) + for _, id := range nodeIDs { + node := net.getNode(id) + if node != nil { + nodes = append(nodes, node) + } + } + return nodes } +// GetNodesByProperty returns existing nodes that have the given property string registered in their NodeConfig +func (net *Network) GetNodesByProperty(property string) []*Node { + net.lock.RLock() + defer net.lock.RUnlock() + + return net.getNodesByProperty(property) +} + +func (net *Network) getNodesByProperty(property string) []*Node { + nodes := make([]*Node, 0, len(net.propertyMap[property])) + for _, nodeIndex := range net.propertyMap[property] { + nodes = append(nodes, net.Nodes[nodeIndex]) + } + + return nodes +} + +// GetNodeIDsByProperty returns the enode IDs of existing nodes that have the given property string registered in their NodeConfig +func (net *Network) GetNodeIDsByProperty(property string) []enode.ID { + net.lock.RLock() + defer net.lock.RUnlock() + + return net.getNodeIDsByProperty(property) +} + +func (net *Network) getNodeIDsByProperty(property string) []enode.ID { + nodeIDs := make([]enode.ID, 0, len(net.propertyMap[property])) + for _, nodeIndex := range net.propertyMap[property] { + node := net.Nodes[nodeIndex] + nodeIDs = append(nodeIDs, node.ID()) + } + + return nodeIDs +} + // GetRandomUpNode returns a random node on the network that is currently running.
func (net *Network) GetRandomUpNode(excludeIDs ...enode.ID) *Node { net.lock.RLock() @@ -469,7 +565,7 @@ func (net *Network) GetRandomDownNode(excludeIDs ...enode.ID) *Node { } func (net *Network) getDownNodeIDs() (ids []enode.ID) { - for _, node := range net.getNodes() { + for _, node := range net.Nodes { if !node.Up() { ids = append(ids, node.ID()) } @@ -477,6 +573,13 @@ func (net *Network) getDownNodeIDs() (ids []enode.ID) { return ids } +// GetRandomNode returns a random node on the network, regardless of whether it is running or not. +func (net *Network) GetRandomNode(excludeIDs ...enode.ID) *Node { + net.lock.RLock() + defer net.lock.RUnlock() + return net.getRandomNode(net.getNodeIDs(nil), excludeIDs) // no need to exclude twice +} + func (net *Network) getRandomNode(ids []enode.ID, excludeIDs []enode.ID) *Node { filtered := filterIDs(ids, excludeIDs) @@ -616,6 +719,7 @@ func (net *Network) Reset() { //re-initialize the maps net.connMap = make(map[string]int) net.nodeMap = make(map[enode.ID]int) + net.propertyMap = make(map[string][]int) net.Nodes = nil net.Conns = nil @@ -634,12 +738,14 @@ type Node struct { upMu sync.RWMutex } +// Up returns whether the node is currently up (online). func (n *Node) Up() bool { n.upMu.RLock() defer n.upMu.RUnlock() return n.up } +// SetUp sets the up (online) status of the node to the given value. func (n *Node) SetUp(up bool) { n.upMu.Lock() defer n.upMu.Unlock() diff --git a/p2p/simulations/network_test.go b/p2p/simulations/network_test.go index 01cd1000d..f504b9a69 100644 --- a/p2p/simulations/network_test.go +++ b/p2p/simulations/network_test.go @@ -17,6 +17,7 @@ package simulations import ( + "bytes" "context" "encoding/json" "fmt" @@ -393,6 +394,275 @@ func TestNetworkSimulation(t *testing.T) { } } +func createTestNodes(count int, network *Network) (nodes []*Node, err error) { + for i := 0; i < count; i++ { + nodeConf := adapters.RandomNodeConfig() + node, err := network.NewNodeWithConfig(nodeConf) + if err != nil { + return nil, err + } + if err := network.Start(node.ID()); err != nil { + return nil, err + } + + nodes = append(nodes, node) + } + + return nodes, nil +} + +func createTestNodesWithProperty(property string, count int, network *Network) (propertyNodes []*Node, err error) { + for i := 0; i < count; i++ { + nodeConf := adapters.RandomNodeConfig() + nodeConf.Properties = append(nodeConf.Properties, property) + + node, err := network.NewNodeWithConfig(nodeConf) + if err != nil { + return nil, err + } + if err := network.Start(node.ID()); err != nil { + return nil, err + } + + propertyNodes = append(propertyNodes, node) + } + + return propertyNodes, nil +} + +// TestGetNodeIDs creates a set of nodes and attempts to retrieve their IDs. +// It then tests again whilst excluding a node ID from being returned. +// If a node ID is not returned, or more node IDs than expected are returned, the test fails.
+func TestGetNodeIDs(t *testing.T) { + adapter := adapters.NewSimAdapter(adapters.Services{ + "test": newTestService, + }) + network := NewNetwork(adapter, &NetworkConfig{ + DefaultService: "test", + }) + defer network.Shutdown() + + numNodes := 5 + nodes, err := createTestNodes(numNodes, network) + if err != nil { + t.Fatalf("Could not create test nodes: %v", err) + } + + gotNodeIDs := network.GetNodeIDs() + if len(gotNodeIDs) != numNodes { + t.Fatalf("Expected %d nodes, got %d", numNodes, len(gotNodeIDs)) + } + + for _, node1 := range nodes { + match := false + for _, node2ID := range gotNodeIDs { + if bytes.Equal(node1.ID().Bytes(), node2ID.Bytes()) { + match = true + break + } + } + + if !match { + t.Fatalf("A created node was not returned by GetNodeIDs(), ID: %s", node1.ID().String()) + } + } + + excludeNodeID := nodes[3].ID() + gotNodeIDsExcl := network.GetNodeIDs(excludeNodeID) + if len(gotNodeIDsExcl) != numNodes-1 { + t.Fatalf("Expected one less node ID to be returned") + } + for _, nodeID := range gotNodeIDsExcl { + if bytes.Equal(excludeNodeID.Bytes(), nodeID.Bytes()) { + t.Fatalf("GetNodeIDs returned the node ID we excluded, ID: %s", nodeID.String()) + } + } +} + +// TestGetNodes creates a set of nodes and attempts to retrieve them again. +// It then tests again whilst excluding a node from being returned. +// If a node is not returned, or more nodes than expected are returned, the test fails. +func TestGetNodes(t *testing.T) { + adapter := adapters.NewSimAdapter(adapters.Services{ + "test": newTestService, + }) + network := NewNetwork(adapter, &NetworkConfig{ + DefaultService: "test", + }) + defer network.Shutdown() + + numNodes := 5 + nodes, err := createTestNodes(numNodes, network) + if err != nil { + t.Fatalf("Could not create test nodes: %v", err) + } + + gotNodes := network.GetNodes() + if len(gotNodes) != numNodes { + t.Fatalf("Expected %d nodes, got %d", numNodes, len(gotNodes)) + } + + for _, node1 := range nodes { + match := false + for _, node2 := range gotNodes { + if bytes.Equal(node1.ID().Bytes(), node2.ID().Bytes()) { + match = true + break + } + } + + if !match { + t.Fatalf("A created node was not returned by GetNodes(), ID: %s", node1.ID().String()) + } + } + + excludeNodeID := nodes[3].ID() + gotNodesExcl := network.GetNodes(excludeNodeID) + if len(gotNodesExcl) != numNodes-1 { + t.Fatalf("Expected one less node to be returned") + } + for _, node := range gotNodesExcl { + if bytes.Equal(excludeNodeID.Bytes(), node.ID().Bytes()) { + t.Fatalf("GetNodes returned the node we excluded, ID: %s", node.ID().String()) + } + } +} + +// TestGetNodesByID creates a set of nodes and attempts to retrieve a subset of them by ID. +// If a node is not returned, or more nodes than expected are returned, the test fails. +func TestGetNodesByID(t *testing.T) { + adapter := adapters.NewSimAdapter(adapters.Services{ + "test": newTestService, + }) + network := NewNetwork(adapter, &NetworkConfig{ + DefaultService: "test", + }) + defer network.Shutdown() + + numNodes := 5 + nodes, err := createTestNodes(numNodes, network) + if err != nil { + t.Fatalf("Could not create test nodes: %v", err) + } + + numSubsetNodes := 2 + subsetNodes := nodes[0:numSubsetNodes] + var subsetNodeIDs []enode.ID + for _, node := range subsetNodes { + subsetNodeIDs = append(subsetNodeIDs, node.ID()) + } + + gotNodesByID := network.GetNodesByID(subsetNodeIDs) + if len(gotNodesByID) != numSubsetNodes { + t.Fatalf("Expected %d nodes, got %d", numSubsetNodes, len(gotNodesByID)) + } + + for _, node1 := range subsetNodes { + match := false + for _, node2 := range gotNodesByID { + if bytes.Equal(node1.ID().Bytes(), node2.ID().Bytes()) { + match = true + break + } + } + + if !match { + t.Fatalf("A created node was not returned by GetNodesByID(), ID: %s", node1.ID().String()) + } + } +} + +// TestGetNodesByProperty creates a subset of nodes with a property assigned. +// GetNodesByProperty is then checked for correctness by comparing the nodes returned to those initially created. +// If a node with a property is not found, or more nodes than expected are returned, the test fails. +func TestGetNodesByProperty(t *testing.T) { + adapter := adapters.NewSimAdapter(adapters.Services{ + "test": newTestService, + }) + network := NewNetwork(adapter, &NetworkConfig{ + DefaultService: "test", + }) + defer network.Shutdown() + + numNodes := 3 + _, err := createTestNodes(numNodes, network) + if err != nil { + t.Fatalf("Failed to create nodes: %v", err) + } + + numPropertyNodes := 3 + propertyTest := "test" + propertyNodes, err := createTestNodesWithProperty(propertyTest, numPropertyNodes, network) + if err != nil { + t.Fatalf("Failed to create nodes with property: %v", err) + } + + gotNodesByProperty := network.GetNodesByProperty(propertyTest) + if len(gotNodesByProperty) != numPropertyNodes { + t.Fatalf("Expected %d nodes with a property, got %d", numPropertyNodes, len(gotNodesByProperty)) + } + + for _, node1 := range propertyNodes { + match := false + for _, node2 := range gotNodesByProperty { + if bytes.Equal(node1.ID().Bytes(), node2.ID().Bytes()) { + match = true + break + } + } + + if !match { + t.Fatalf("A created node with property was not returned by GetNodesByProperty(), ID: %s", node1.ID().String()) + } + } +} + +// TestGetNodeIDsByProperty creates a subset of nodes with a property assigned. +// GetNodeIDsByProperty is then checked for correctness by comparing the node IDs returned to those initially created. +// If a node ID with a property is not found, or more node IDs than expected are returned, the test fails. +func TestGetNodeIDsByProperty(t *testing.T) { + adapter := adapters.NewSimAdapter(adapters.Services{ + "test": newTestService, + }) + network := NewNetwork(adapter, &NetworkConfig{ + DefaultService: "test", + }) + defer network.Shutdown() + + numNodes := 3 + _, err := createTestNodes(numNodes, network) + if err != nil { + t.Fatalf("Failed to create nodes: %v", err) + } + + numPropertyNodes := 3 + propertyTest := "test" + propertyNodes, err := createTestNodesWithProperty(propertyTest, numPropertyNodes, network) + if err != nil { + t.Fatalf("Failed to create nodes with property: %v", err) + } + + gotNodeIDsByProperty := network.GetNodeIDsByProperty(propertyTest) + if len(gotNodeIDsByProperty) != numPropertyNodes { + t.Fatalf("Expected %d nodes with a property, got %d", numPropertyNodes, len(gotNodeIDsByProperty)) + } + + for _, node1 := range propertyNodes { + match := false + id1 := node1.ID() + for _, id2 := range gotNodeIDsByProperty { + if bytes.Equal(id1.Bytes(), id2.Bytes()) { + match = true + break + } + } + + if !match { + t.Fatalf("Not all node IDs were returned by GetNodeIDsByProperty(), ID: %s", id1.String()) + } + } +} + func triggerChecks(ctx context.Context, ids []enode.ID, trigger chan enode.ID, interval time.Duration) { tick := time.NewTicker(interval) defer tick.Stop() diff --git a/params/bootnodes.go b/params/bootnodes.go index 36f13d178..967cba5bc 100644 --- a/params/bootnodes.go +++ b/params/bootnodes.go @@ -29,13 +29,6 @@ var MainnetBootnodes = []string{ "enode://715171f50508aba88aecd1250af392a45a330af91d7b90701c436b618c86aaa1589c9184561907bebbb56439b8f8787bc01f49a7c77276c58c1b09822d75e8e8@52.231.165.108:30303", // bootnode-azure-koreasouth-001 "enode://5d6d7cd20d6da4bb83a1d28cadb5d409b64edf314c0335df658c1a54e32c7c4a7ab7823d57c39b6a757556e68ff1df17c748b698544a55cb488b52479a92b60f@104.42.217.25:30303", // bootnode-azure-westus-001 - // Ethereum Foundation Go Bootnodes (legacy) - "enode://a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@52.16.188.185:30303", // IE - "enode://3f1d12044546b76342d59d4a05532c14b85aa669704bfe1f864fe079415aa2c02d743e03218e57a33fb94523adb54032871a6c51b2cc5514cb7c7e35b3ed0a99@13.93.211.84:30303", // US-WEST - "enode://78de8a0916848093c73790ead81d1928bec737d565119932b98c6b100d944b7a95e94f847f689fc723399d2e31129d182f7ef3863f2b4c820abbf3ab2722344d@191.235.84.50:30303", // BR - "enode://158f8aab45f6d19c6cbf4a089c2670541a8da11978a2f90dbf6a502a4a3bab80d288afdbeb7ec0ef6d92de563767f3b1ea9e8e334ca711e9f8e2df5a0385e8e6@13.75.154.138:30303", // AU - "enode://1118980bf48b0a3640bdba04e0fe78b1add18e1cd99bf22d53daac1fd9972ad650df52176e7c7d89d1114cfef2bc23a2959aa54998a46afcf7d91809f0855082@52.74.57.123:30303", // SG - // Ethereum Foundation C++ Bootnodes "enode://979b7fa28feeb35a4741660a16076f1943202cb72b6af70d327f053e248bab9ba81760f39d0701ef1d8f89cc1fbd2cacba0710a12cd5314d5e0c9021aa3637f9@5.1.83.226:30303", // DE } diff --git a/params/config.go b/params/config.go index 200add01b..c90de56dc 100644 --- a/params/config.go +++ b/params/config.go @@ -65,16 +65,16 @@ var ( ByzantiumBlock: big.NewInt(4370000), ConstantinopleBlock: big.NewInt(7280000), PetersburgBlock: big.NewInt(7280000), - IstanbulBlock: nil, + IstanbulBlock: big.NewInt(9069000), Ethash: new(EthashConfig), }
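(Aside: the Istanbul activations added above are exactly what the new CheckConfigForkOrder helper, introduced further down in params/config.go, validates. A hedged sketch of the failure mode it guards against; the config literal here is hypothetical and not part of this change:)

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/params"
)

func main() {
	// Hypothetical config that schedules Istanbul while leaving
	// Petersburg unscheduled (nil) — an unsupported fork ordering.
	cfg := &params.ChainConfig{
		ChainID:             big.NewInt(1337),
		HomesteadBlock:      big.NewInt(0),
		EIP150Block:         big.NewInt(0),
		EIP155Block:         big.NewInt(0),
		EIP158Block:         big.NewInt(0),
		ByzantiumBlock:      big.NewInt(0),
		ConstantinopleBlock: big.NewInt(0),
		PetersburgBlock:     nil,
		IstanbulBlock:       big.NewInt(100),
	}
	// Expected, per the error format in this diff: "unsupported fork
	// ordering: petersburgBlock not enabled, but istanbulBlock enabled at 100"
	fmt.Println(cfg.CheckConfigForkOrder())
}
```

// MainnetTrustedCheckpoint contains the light client trusted checkpoint for the main network.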
MainnetTrustedCheckpoint = &TrustedCheckpoint{ - SectionIndex: 253, - SectionHead: common.HexToHash("0xf35fabd036e2030196183bb70ae194f6ce1ea7b58559e3825c168f1df9c0a258"), - CHTRoot: common.HexToHash("0x8992849e2be3390696eaf66312626e484045501cd3ec207922c27a6a80a7bb07"), - BloomRoot: common.HexToHash("0xcc510b51ca4d73fb3fdf43208d73286f8f23817cdc31b8ea9f4de8d645f07df4"), + SectionIndex: 270, + SectionHead: common.HexToHash("0xb67c33d838a60c282c2fb49b188fbbac1ef8565ffb4a1c4909b0a05885e72e40"), + CHTRoot: common.HexToHash("0x781daa4607782300da85d440df3813ba38a1262585231e35e9480726de81dbfc"), + BloomRoot: common.HexToHash("0xfd8951fa6d779cbc981df40dc31056ed1a549db529349d7dfae016f9d96cae72"), } // MainnetCheckpointOracle contains a set of configs for the main network oracle. @@ -103,16 +103,16 @@ var ( ByzantiumBlock: big.NewInt(1700000), ConstantinopleBlock: big.NewInt(4230000), PetersburgBlock: big.NewInt(4939394), - IstanbulBlock: nil, + IstanbulBlock: big.NewInt(6485846), Ethash: new(EthashConfig), } // TestnetTrustedCheckpoint contains the light client trusted checkpoint for the Ropsten test network. TestnetTrustedCheckpoint = &TrustedCheckpoint{ - SectionIndex: 187, - SectionHead: common.HexToHash("0x7d6db64d8ec43303e4392fb726d2346f7231b246decca3d8140dd7e2c0d0b07d"), - CHTRoot: common.HexToHash("0xa5095e1a004a8642fb93ca682eb91e8f20ef5bce151e47404fbb68772d17705b"), - BloomRoot: common.HexToHash("0x90b28050f948ec6fb35b23a91d9aed38ce0c92d3cdd6e1d383c1bddf8b4071cf"), + SectionIndex: 204, + SectionHead: common.HexToHash("0xa39168b51c3205456f30ce6a91f3590a43295b15a1c8c2ab86bb8c06b8ad1808"), + CHTRoot: common.HexToHash("0x9a3654147b79882bfc4e16fbd3421512aa7e4dfadc6c511923980e0877bdf3b4"), + BloomRoot: common.HexToHash("0xe72b979522d94fa45c1331639316da234a9bb85062d64d72e13afe1d3f5c17d5"), } // TestnetCheckpointOracle contains a set of configs for the Ropsten test network oracle. @@ -141,7 +141,7 @@ var ( ByzantiumBlock: big.NewInt(1035301), ConstantinopleBlock: big.NewInt(3660663), PetersburgBlock: big.NewInt(4321234), - IstanbulBlock: nil, + IstanbulBlock: big.NewInt(5435345), Clique: &CliqueConfig{ Period: 15, Epoch: 30000, @@ -150,10 +150,10 @@ var ( // RinkebyTrustedCheckpoint contains the light client trusted checkpoint for the Rinkeby test network. RinkebyTrustedCheckpoint = &TrustedCheckpoint{ - SectionIndex: 148, - SectionHead: common.HexToHash("0x45918f4686732c2a3e80827e1bc39cdb6a27fa362ddfe1fdfb61c69a7f1df1a9"), - CHTRoot: common.HexToHash("0x8ac7046391fec14834a2a0183513937c0b5f696666545991477d24b067008961"), - BloomRoot: common.HexToHash("0xfe4b852517612d7da54bf7e9fc18861a83171a93c72583bb6a61893b74422168"), + SectionIndex: 163, + SectionHead: common.HexToHash("0x36e5deaa46f258bece94b05d8e10f1ef68f422fb62ed47a2b6e616aa26e84997"), + CHTRoot: common.HexToHash("0x829b9feca1c2cdf5a4cf3efac554889e438ee4df8718c2ce3e02555a02d9e9e5"), + BloomRoot: common.HexToHash("0x58c01de24fdae7c082ebbe7665f189d0aa4d90ee10e72086bf56651c63269e54"), } // RinkebyCheckpointOracle contains a set of configs for the Rinkeby test network oracle. @@ -180,7 +180,7 @@ var ( ByzantiumBlock: big.NewInt(0), ConstantinopleBlock: big.NewInt(0), PetersburgBlock: big.NewInt(0), - IstanbulBlock: nil, + IstanbulBlock: big.NewInt(1561651), Clique: &CliqueConfig{ Period: 15, Epoch: 30000, @@ -189,10 +189,10 @@ var ( // GoerliTrustedCheckpoint contains the light client trusted checkpoint for the Görli test network. 
GoerliTrustedCheckpoint = &TrustedCheckpoint{ - SectionIndex: 32, - SectionHead: common.HexToHash("0x50eaedd8361fa9edd0ac2dec410310b9bdf67b963b60f3b1dce47f84b30670f9"), - CHTRoot: common.HexToHash("0x6504db73139f75ffa9102ae980e41b361cf3d5b66cea06c79cde9f457368820c"), - BloomRoot: common.HexToHash("0x7551ae027bb776252a20ded51ee2ff0cbfbd1d8d57261b9161cc1f2f80237001"), + SectionIndex: 47, + SectionHead: common.HexToHash("0x00c5b54c6c9a73660501fd9273ccdb4c5bbdbe5d7b8b650e28f881ec9d2337f6"), + CHTRoot: common.HexToHash("0xef35caa155fd659f57167e7d507de2f8132cbb31f771526481211d8a977d704c"), + BloomRoot: common.HexToHash("0xbda330402f66008d52e7adc748da28535b1212a7912a21244acd2ba77ff0ff06"), } // GoerliCheckpointOracle contains a set of configs for the Goerli test network oracle. @@ -213,16 +213,16 @@ var ( // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. - AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, new(EthashConfig), nil} + AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, new(EthashConfig), nil} // AllCliqueProtocolChanges contains every protocol change (EIPs) introduced // and accepted by the Ethereum core developers into the Clique consensus. // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. - AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, nil, &CliqueConfig{Period: 0, Epoch: 30000}} + AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, &CliqueConfig{Period: 0, Epoch: 30000}} - TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, new(EthashConfig), nil} + TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, new(EthashConfig), nil} TestRules = TestChainConfig.Rules(new(big.Int)) ) @@ -415,6 +415,42 @@ func (c *ChainConfig) CheckCompatible(newcfg *ChainConfig, height uint64) *Confi return lasterr } +// CheckConfigForkOrder checks that we don't "skip" any forks; geth isn't pluggable enough +// to guarantee that forks can be implemented in a different order than on official networks. +func (c *ChainConfig) CheckConfigForkOrder() error { + type fork struct { + name string + block *big.Int + } + var lastFork fork + for _, cur := range []fork{ + {"homesteadBlock", c.HomesteadBlock}, + {"eip150Block", c.EIP150Block}, + {"eip155Block", c.EIP155Block}, + {"eip158Block", c.EIP158Block}, + {"byzantiumBlock", c.ByzantiumBlock}, + {"constantinopleBlock", c.ConstantinopleBlock}, + {"petersburgBlock", c.PetersburgBlock}, + {"istanbulBlock", c.IstanbulBlock}, + } { + if lastFork.name != "" { + // Next one must be higher number + if lastFork.block == nil && cur.block != nil { + return
fmt.Errorf("unsupported fork ordering: %v not enabled, but %v enabled at %v", + lastFork.name, cur.name, cur.block) + } + if lastFork.block != nil && cur.block != nil { + if lastFork.block.Cmp(cur.block) > 0 { + return fmt.Errorf("unsupported fork ordering: %v enabled at %v, but %v enabled at %v", + lastFork.name, lastFork.block, cur.name, cur.block) + } + } + } + lastFork = cur + } + return nil +} + func (c *ChainConfig) checkCompatible(newcfg *ChainConfig, head *big.Int) *ConfigCompatError { if isForkIncompatible(c.HomesteadBlock, newcfg.HomesteadBlock, head) { return newCompatError("Homestead fork block", c.HomesteadBlock, newcfg.HomesteadBlock) diff --git a/params/version.go b/params/version.go index 8d61e4213..4d6e315df 100644 --- a/params/version.go +++ b/params/version.go @@ -23,7 +23,7 @@ import ( const ( VersionMajor = 1 // Major version component of the current release VersionMinor = 9 // Minor version component of the current release - VersionPatch = 4 // Patch version component of the current release + VersionPatch = 7 // Patch version component of the current release VersionMeta = "unstable" // Version metadata to append to the version string ) diff --git a/rlp/decode.go b/rlp/decode.go index 4f29f2fb0..524395915 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -55,81 +55,23 @@ var ( } ) -// Decoder is implemented by types that require custom RLP -// decoding rules or need to decode into private fields. +// Decoder is implemented by types that require custom RLP decoding rules or need to decode +// into private fields. // -// The DecodeRLP method should read one value from the given -// Stream. It is not forbidden to read less or more, but it might -// be confusing. +// The DecodeRLP method should read one value from the given Stream. It is not forbidden to +// read less or more, but it might be confusing. type Decoder interface { DecodeRLP(*Stream) error } -// Decode parses RLP-encoded data from r and stores the result in the -// value pointed to by val. Val must be a non-nil pointer. If r does -// not implement ByteReader, Decode will do its own buffering. +// Decode parses RLP-encoded data from r and stores the result in the value pointed to by +// val. Please see package-level documentation for the decoding rules. Val must be a +// non-nil pointer. // -// Decode uses the following type-dependent decoding rules: +// If r does not implement ByteReader, Decode will do its own buffering. // -// If the type implements the Decoder interface, decode calls -// DecodeRLP. -// -// To decode into a pointer, Decode will decode into the value pointed -// to. If the pointer is nil, a new value of the pointer's element -// type is allocated. If the pointer is non-nil, the existing value -// will be reused. -// -// To decode into a struct, Decode expects the input to be an RLP -// list. The decoded elements of the list are assigned to each public -// field in the order given by the struct's definition. The input list -// must contain an element for each decoded field. Decode returns an -// error if there are too few or too many elements. -// -// The decoding of struct fields honours certain struct tags, "tail", -// "nil" and "-". -// -// The "-" tag ignores fields. -// -// For an explanation of "tail", see the example. -// -// The "nil" tag applies to pointer-typed fields and changes the decoding -// rules for the field such that input values of size zero decode as a nil -// pointer. This tag can be useful when decoding recursive types. 
-// -// type StructWithEmptyOK struct { -// Foo *[20]byte `rlp:"nil"` -// } -// -// To decode into a slice, the input must be a list and the resulting -// slice will contain the input elements in order. For byte slices, -// the input must be an RLP string. Array types decode similarly, with -// the additional restriction that the number of input elements (or -// bytes) must match the array's length. -// -// To decode into a Go string, the input must be an RLP string. The -// input bytes are taken as-is and will not necessarily be valid UTF-8. -// -// To decode into an unsigned integer type, the input must also be an RLP -// string. The bytes are interpreted as a big endian representation of -// the integer. If the RLP string is larger than the bit size of the -// type, Decode will return an error. Decode also supports *big.Int. -// There is no size limit for big integers. -// -// To decode into a boolean, the input must contain an unsigned integer -// of value zero (false) or one (true). -// -// To decode into an interface value, Decode stores one of these -// in the value: -// -// []interface{}, for RLP lists -// []byte, for RLP strings -// -// Non-empty interface types are not supported, nor are signed integers, -// floating point numbers, maps, channels and functions. -// -// Note that Decode does not set an input limit for all readers -// and may be vulnerable to panics cause by huge value sizes. If -// you need an input limit, use +// Note that Decode does not set an input limit for all readers and may be vulnerable to +// panics caused by huge value sizes. If you need an input limit, use // NewStream(r, limit).Decode(val) func Decode(r io.Reader, val interface{}) error { @@ -140,9 +82,8 @@ func Decode(r io.Reader, val interface{}) error { return stream.Decode(val) }
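(A brief illustration of the input-limit note above — a hedged sketch, not part of this diff; the 1 KiB limit and the payload are made up:)

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/ethereum/go-ethereum/rlp"
)

func main() {
	input := []byte{0xC2, 0x01, 0x02} // RLP list [1, 2]

	// Unlike rlp.Decode on an arbitrary reader, NewStream enforces an
	// input limit; inputs longer than the limit fail to decode.
	s := rlp.NewStream(bytes.NewReader(input), 1024)

	var val []uint
	if err := s.Decode(&val); err != nil {
		fmt.Println("decode failed:", err)
		return
	}
	fmt.Println(val) // [1 2]
}
```

-// DecodeBytes parses RLP data from b into val. -// Please see the documentation of Decode for the decoding rules. -// The input must contain exactly one value and no trailing data. +// DecodeBytes parses RLP data from b into val. Please see package-level documentation for +// the decoding rules. The input must contain exactly one value and no trailing data.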
func DecodeBytes(b []byte, val interface{}) error { r := bytes.NewReader(b) @@ -211,14 +152,15 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { switch { case typ == rawValueType: return decodeRawValue, nil - case typ.Implements(decoderInterface): return decodeDecoder, nil - case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface): - return decodeDecoderNoPtr, nil case typ.AssignableTo(reflect.PtrTo(bigInt)): return decodeBigInt, nil case typ.AssignableTo(bigInt): return decodeBigIntNoPtr, nil + case kind == reflect.Ptr: + return makePtrDecoder(typ, tags) + case reflect.PtrTo(typ).Implements(decoderInterface): + return decodeDecoder, nil case isUint(kind): return decodeUint, nil case kind == reflect.Bool: @@ -229,11 +171,6 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { return makeListDecoder(typ, tags) case kind == reflect.Struct: return makeStructDecoder(typ) - case kind == reflect.Ptr: - if tags.nilOK { - return makeOptionalPtrDecoder(typ) - } - return makePtrDecoder(typ) case kind == reflect.Interface: return decodeInterface, nil default: @@ -448,6 +385,11 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { if err != nil { return nil, err } + for _, f := range fields { + if f.info.decoderErr != nil { + return nil, structFieldError{typ, f.index, f.info.decoderErr} + } + } dec := func(s *Stream, val reflect.Value) (err error) { if _, err := s.List(); err != nil { return wrapStreamError(err, typ) @@ -465,15 +407,22 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { return dec, nil } -// makePtrDecoder creates a decoder that decodes into -// the pointer's element type. -func makePtrDecoder(typ reflect.Type) (decoder, error) { +// makePtrDecoder creates a decoder that decodes into the pointer's element type. +func makePtrDecoder(typ reflect.Type, tag tags) (decoder, error) { etype := typ.Elem() etypeinfo := cachedTypeInfo1(etype, tags{}) - if etypeinfo.decoderErr != nil { + switch { + case etypeinfo.decoderErr != nil: return nil, etypeinfo.decoderErr + case !tag.nilOK: + return makeSimplePtrDecoder(etype, etypeinfo), nil + default: + return makeNilPtrDecoder(etype, etypeinfo, tag.nilKind), nil } - dec := func(s *Stream, val reflect.Value) (err error) { +} + +func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder { + return func(s *Stream, val reflect.Value) (err error) { newval := val if val.IsNil() { newval = reflect.New(etype) @@ -483,30 +432,35 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } -// makeOptionalPtrDecoder creates a decoder that decodes empty values -// as nil. Non-empty values are decoded into a value of the element type, -// just like makePtrDecoder does. +// makeNilPtrDecoder creates a decoder that decodes empty values as nil. Non-empty +// values are decoded into a value of the element type, just like makePtrDecoder does. // // This decoder is used for pointer-typed struct fields with struct tag "nil". 
-func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { - etype := typ.Elem() - etypeinfo := cachedTypeInfo1(etype, tags{}) - if etypeinfo.decoderErr != nil { - return nil, etypeinfo.decoderErr - } - dec := func(s *Stream, val reflect.Value) (err error) { +func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, nilKind Kind) decoder { + typ := reflect.PtrTo(etype) + nilPtr := reflect.Zero(typ) + return func(s *Stream, val reflect.Value) (err error) { kind, size, err := s.Kind() - if err != nil || size == 0 && kind != Byte { + if err != nil { + val.Set(nilPtr) + return wrapStreamError(err, typ) + } + // Handle empty values as a nil pointer. + if kind != Byte && size == 0 { + if kind != nilKind { + return &decodeError{ + msg: fmt.Sprintf("wrong kind of empty value (got %v, want %v)", kind, nilKind), + typ: typ, + } + } // rearm s.Kind. This is important because the input // position must advance to the next value even though // we don't read anything. s.kind = -1 - // set the pointer to nil. - val.Set(reflect.Zero(typ)) - return err + val.Set(nilPtr) + return nil } newval := val if val.IsNil() { @@ -517,7 +471,6 @@ func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } var ifsliceType = reflect.TypeOf([]interface{}{}) @@ -546,21 +499,8 @@ func decodeInterface(s *Stream, val reflect.Value) error { return nil } -// This decoder is used for non-pointer values of types -// that implement the Decoder interface using a pointer receiver. -func decodeDecoderNoPtr(s *Stream, val reflect.Value) error { - return val.Addr().Interface().(Decoder).DecodeRLP(s) -} - func decodeDecoder(s *Stream, val reflect.Value) error { - // Decoder instances are not handled using the pointer rule if the type - // implements Decoder with pointer receiver (i.e. always) - // because it might handle empty values specially. - // We need to allocate one here in this case, like makePtrDecoder does. - if val.Kind() == reflect.Ptr && val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) - } - return val.Interface().(Decoder).DecodeRLP(s) + return val.Addr().Interface().(Decoder).DecodeRLP(s) } // Kind represents the kind of value contained in an RLP stream. 
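(To make the new nil-pointer handling concrete before the tests that follow: a hedged sketch assembled from the encoder test vectors and the doc.go rules elsewhere in this diff. The struct and the combined output bytes are illustrative; the per-field encodings 0x80 and 0xC0 come directly from the test cases below.)

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/rlp"
)

type nilFields struct {
	A *uint64                 // default nil kind for unsigned integers: empty string (0x80)
	B *[]uint64               // default nil kind for slices: empty list (0xC0)
	C *uint64 `rlp:"nilList"` // tag overrides the default: empty list (0xC0)
}

func main() {
	out, err := rlp.EncodeToBytes(&nilFields{})
	if err != nil {
		panic(err)
	}
	// A three-byte payload [0x80, 0xC0, 0xC0] wrapped in a list header.
	fmt.Printf("%X\n", out) // C380C0C0
}
```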
diff --git a/rlp/decode_test.go b/rlp/decode_test.go index fa57182c9..634d1cf3b 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -327,6 +327,10 @@ type recstruct struct { Child *recstruct `rlp:"nil"` } +type invalidNilTag struct { + X []byte `rlp:"nil"` +} + type invalidTail1 struct { A uint `rlp:"tail"` B string @@ -353,6 +357,18 @@ type tailPrivateFields struct { x, y bool } +type nilListUint struct { + X *uint `rlp:"nilList"` +} + +type nilStringSlice struct { + X *[]uint `rlp:"nilString"` +} + +type intField struct { + X int +} + var ( veryBigInt = big.NewInt(0).Add( big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), @@ -485,20 +501,20 @@ var decodeTests = []decodeTest{ error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I", }, { - input: "C0", - ptr: new(invalidTail1), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail1.A (must be on last field)", - }, - { - input: "C0", - ptr: new(invalidTail2), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail2.B (field type is not slice)", + input: "C103", + ptr: new(intField), + error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)", }, { input: "C50102C20102", ptr: new(tailUint), error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]", }, + { + input: "C0", + ptr: new(invalidNilTag), + error: `rlp: invalid struct tag "nil" for rlp.invalidNilTag.X (field is not a pointer)`, + }, // struct tag "tail" { @@ -521,6 +537,16 @@ var decodeTests = []decodeTest{ ptr: new(tailPrivateFields), value: tailPrivateFields{A: 1, Tail: []uint{2, 3}}, }, + { + input: "C0", + ptr: new(invalidTail1), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail1.A (must be on last field)`, + }, + { + input: "C0", + ptr: new(invalidTail2), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail2.B (field type is not slice)`, + }, // struct tag "-" { @@ -529,6 +555,43 @@ var decodeTests = []decodeTest{ value: hasIgnoredField{A: 1, C: 2}, }, + // struct tag "nilList" + { + input: "C180", + ptr: new(nilListUint), + error: "rlp: wrong kind of empty value (got String, want List) for *uint, decoding into (rlp.nilListUint).X", + }, + { + input: "C1C0", + ptr: new(nilListUint), + value: nilListUint{}, + }, + { + input: "C103", + ptr: new(nilListUint), + value: func() interface{} { + v := uint(3) + return nilListUint{X: &v} + }(), + }, + + // struct tag "nilString" + { + input: "C1C0", + ptr: new(nilStringSlice), + error: "rlp: wrong kind of empty value (got List, want String) for *[]uint, decoding into (rlp.nilStringSlice).X", + }, + { + input: "C180", + ptr: new(nilStringSlice), + value: nilStringSlice{}, + }, + { + input: "C2C103", + ptr: new(nilStringSlice), + value: nilStringSlice{X: &[]uint{3}}, + }, + // RawValue {input: "01", ptr: new(RawValue), value: RawValue(unhex("01"))}, {input: "82FFFF", ptr: new(RawValue), value: RawValue(unhex("82FFFF"))}, @@ -672,6 +735,22 @@ func TestDecodeDecoder(t *testing.T) { } } +func TestDecodeDecoderNilPointer(t *testing.T) { + var s struct { + T1 *testDecoder `rlp:"nil"` + T2 *testDecoder + } + if err := Decode(bytes.NewReader(unhex("C2C002")), &s); err != nil { + t.Fatalf("Decode error: %v", err) + } + if s.T1 != nil { + t.Errorf("decoder T1 allocated for empty input (called: %v)", s.T1.called) + } + if s.T2 == nil || !s.T2.called { + t.Errorf("decoder T2 not allocated/called") + } +} + type byteDecoder byte func (bd *byteDecoder) DecodeRLP(s *Stream) error { diff --git a/rlp/doc.go b/rlp/doc.go index 
b3a81fe23..7e6ee8520 100644 --- a/rlp/doc.go +++ b/rlp/doc.go @@ -17,17 +17,114 @@ /* Package rlp implements the RLP serialization format. -The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily -nested arrays of binary data, and RLP is the main encoding method used -to serialize objects in Ethereum. The only purpose of RLP is to encode -structure; encoding specific atomic data types (eg. strings, ints, -floats) is left up to higher-order protocols; in Ethereum integers -must be represented in big endian binary form with no leading zeroes -(thus making the integer value zero equivalent to the empty byte -array). +The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily nested arrays of +binary data, and RLP is the main encoding method used to serialize objects in Ethereum. +The only purpose of RLP is to encode structure; encoding specific atomic data types (eg. +strings, ints, floats) is left up to higher-order protocols. In Ethereum integers must be +represented in big endian binary form with no leading zeroes (thus making the integer +value zero equivalent to the empty string). -RLP values are distinguished by a type tag. The type tag precedes the -value in the input stream and defines the size and kind of the bytes -that follow. +RLP values are distinguished by a type tag. The type tag precedes the value in the input +stream and defines the size and kind of the bytes that follow. + + +Encoding Rules + +Package rlp uses reflection and encodes RLP based on the Go type of the value. + +If the type implements the Encoder interface, Encode calls EncodeRLP. It does not +call EncodeRLP on nil pointer values. + +To encode a pointer, the value being pointed to is encoded. A nil pointer to a struct +type, slice or array always encodes as an empty RLP list unless the slice or array has +elememt type byte. A nil pointer to any other value encodes as the empty string. + +Struct values are encoded as an RLP list of all their encoded public fields. Recursive +struct types are supported. + +To encode slices and arrays, the elements are encoded as an RLP list of the value's +elements. Note that arrays and slices with element type uint8 or byte are always encoded +as an RLP string. + +A Go string is encoded as an RLP string. + +An unsigned integer value is encoded as an RLP string. Zero always encodes as an empty RLP +string. big.Int values are treated as integers. Signed integers (int, int8, int16, ...) +are not supported and will return an error when encoding. + +Boolean values are encoded as the unsigned integers zero (false) and one (true). + +An interface value encodes as the value contained in the interface. + +Floating point numbers, maps, channels and functions are not supported. + + +Decoding Rules + +Decoding uses the following type-dependent rules: + +If the type implements the Decoder interface, DecodeRLP is called. + +To decode into a pointer, the value will be decoded as the element type of the pointer. If +the pointer is nil, a new value of the pointer's element type is allocated. If the pointer +is non-nil, the existing value will be reused. Note that package rlp never leaves a +pointer-type struct field as nil unless one of the "nil" struct tags is present. + +To decode into a struct, decoding expects the input to be an RLP list. The decoded +elements of the list are assigned to each public field in the order given by the struct's +definition. The input list must contain an element for each decoded field. 
Decoding +returns an error if there are too few or too many elements for the struct. + +To decode into a slice, the input must be a list and the resulting slice will contain the +input elements in order. For byte slices, the input must be an RLP string. Array types +decode similarly, with the additional restriction that the number of input elements (or +bytes) must match the array's defined length. + +To decode into a Go string, the input must be an RLP string. The input bytes are taken +as-is and will not necessarily be valid UTF-8. + +To decode into an unsigned integer type, the input must also be an RLP string. The bytes +are interpreted as a big endian representation of the integer. If the RLP string is larger +than the bit size of the type, decoding will return an error. Decode also supports +*big.Int. There is no size limit for big integers. + +To decode into a boolean, the input must contain an unsigned integer of value zero (false) +or one (true). + +To decode into an interface value, one of these types is stored in the value: + + []interface{}, for RLP lists + []byte, for RLP strings + +Non-empty interface types are not supported when decoding. +Signed integers, floating point numbers, maps, channels and functions cannot be decoded into. + + +Struct Tags + +Package rlp honours certain struct tags: "-", "tail", "nil", "nilList" and "nilString". + +The "-" tag ignores fields. + +The "tail" tag, which may only be used on the last exported struct field, allows slurping +up any excess list elements into a slice. See examples for more details. + +The "nil" tag applies to pointer-typed fields and changes the decoding rules for the field +such that input values of size zero decode as a nil pointer. This tag can be useful when +decoding recursive types. + + type StructWithOptionalFoo struct { + Foo *[20]byte `rlp:"nil"` + } + +RLP supports two kinds of empty values: empty lists and empty strings. When using the +"nil" tag, the kind of empty value allowed for a type is chosen automatically. A struct +field whose Go type is a pointer to an unsigned integer, string, boolean or byte +array/slice expects an empty RLP string. Any other pointer field type encodes/decodes as +an empty RLP list. + +The choice of null value can be made explicit with the "nilList" and "nilString" struct +tags. Using these tags encodes/decodes a Go nil pointer value as the kind of empty +RLP value defined by the tag. */ package rlp diff --git a/rlp/encode.go b/rlp/encode.go index f255c38a9..9c9e8d706 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -49,36 +49,7 @@ type Encoder interface { // perform many small writes in some cases. Consider making w // buffered. // -// Encode uses the following type-dependent encoding rules: -// -// If the type implements the Encoder interface, Encode calls -// EncodeRLP. This is true even for nil pointers, please see the -// documentation for Encoder. -// -// To encode a pointer, the value being pointed to is encoded. For nil -// pointers, Encode will encode the zero value of the type. A nil -// pointer to a struct type always encodes as an empty RLP list. -// A nil pointer to an array encodes as an empty list (or empty string -// if the array has element type byte). -// -// Struct values are encoded as an RLP list of all their encoded -// public fields. Recursive struct types are supported. -// -// To encode slices and arrays, the elements are encoded as an RLP -// list of the value's elements. 
Note that arrays and slices with -// element type uint8 or byte are always encoded as an RLP string. -// -// A Go string is encoded as an RLP string. -// -// An unsigned integer value is encoded as an RLP string. Zero always -// encodes as an empty RLP string. Encode also supports *big.Int. -// -// Boolean values are encoded as unsigned integers zero (false) and one (true). -// -// An interface value encodes as the value contained in the interface. -// -// Signed integers are not supported, nor are floating point numbers, maps, -// channels and functions. +// Please see package-level documentation of encoding rules. func Encode(w io.Writer, val interface{}) error { if outer, ok := w.(*encbuf); ok { // Encode was called by some type's EncodeRLP. @@ -95,7 +66,7 @@ func Encode(w io.Writer, val interface{}) error { } // EncodeToBytes returns the RLP encoding of val. -// Please see the documentation of Encode for the encoding rules. +// Please see package-level documentation for the encoding rules. func EncodeToBytes(val interface{}) ([]byte, error) { eb := encbufPool.Get().(*encbuf) defer encbufPool.Put(eb) @@ -349,16 +320,14 @@ func makeWriter(typ reflect.Type, ts tags) (writer, error) { switch { case typ == rawValueType: return writeRawValue, nil - case typ.Implements(encoderInterface): - return writeEncoder, nil - case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(encoderInterface): - return writeEncoderNoPtr, nil - case kind == reflect.Interface: - return writeInterface, nil case typ.AssignableTo(reflect.PtrTo(bigInt)): return writeBigIntPtr, nil case typ.AssignableTo(bigInt): return writeBigIntNoPtr, nil + case kind == reflect.Ptr: + return makePtrWriter(typ, ts) + case reflect.PtrTo(typ).Implements(encoderInterface): + return makeEncoderWriter(typ), nil case isUint(kind): return writeUint, nil case kind == reflect.Bool: @@ -373,8 +342,8 @@ func makeWriter(typ reflect.Type, ts tags) (writer, error) { return makeSliceWriter(typ, ts) case kind == reflect.Struct: return makeStructWriter(typ) - case kind == reflect.Ptr: - return makePtrWriter(typ) + case kind == reflect.Interface: + return writeInterface, nil default: return nil, fmt.Errorf("rlp: type %v is not RLP-serializable", typ) } @@ -470,26 +439,6 @@ func writeString(val reflect.Value, w *encbuf) error { return nil } -func writeEncoder(val reflect.Value, w *encbuf) error { - return val.Interface().(Encoder).EncodeRLP(w) -} - -// writeEncoderNoPtr handles non-pointer values that implement Encoder -// with a pointer receiver. -func writeEncoderNoPtr(val reflect.Value, w *encbuf) error { - if !val.CanAddr() { - // We can't get the address. It would be possible to make the - // value addressable by creating a shallow copy, but this - // creates other problems so we're not doing it (yet). - // - // package json simply doesn't call MarshalJSON for cases like - // this, but encodes the value as if it didn't implement the - // interface. We don't want to handle it that way. - return fmt.Errorf("rlp: game over: unadressable value of type %v, EncodeRLP is pointer method", val.Type()) - } - return val.Addr().Interface().(Encoder).EncodeRLP(w) -} - func writeInterface(val reflect.Value, w *encbuf) error { if val.IsNil() { // Write empty list. 
This is consistent with the previous RLP @@ -531,6 +480,11 @@ func makeStructWriter(typ reflect.Type) (writer, error) { if err != nil { return nil, err } + for _, f := range fields { + if f.info.writerErr != nil { + return nil, structFieldError{typ, f.index, f.info.writerErr} + } + } writer := func(val reflect.Value, w *encbuf) error { lh := w.list() for _, f := range fields { @@ -544,44 +498,51 @@ func makeStructWriter(typ reflect.Type) (writer, error) { return writer, nil } -func makePtrWriter(typ reflect.Type) (writer, error) { +func makePtrWriter(typ reflect.Type, ts tags) (writer, error) { etypeinfo := cachedTypeInfo1(typ.Elem(), tags{}) if etypeinfo.writerErr != nil { return nil, etypeinfo.writerErr } - - // determine nil pointer handler - var nilfunc func(*encbuf) error - kind := typ.Elem().Kind() - switch { - case kind == reflect.Array && isByte(typ.Elem().Elem()): - nilfunc = func(w *encbuf) error { - w.str = append(w.str, 0x80) - return nil - } - case kind == reflect.Struct || kind == reflect.Array: - nilfunc = func(w *encbuf) error { - // encoding the zero value of a struct/array could trigger - // infinite recursion, avoid that. - w.listEnd(w.list()) - return nil - } - default: - zero := reflect.Zero(typ.Elem()) - nilfunc = func(w *encbuf) error { - return etypeinfo.writer(zero, w) - } + // Determine how to encode nil pointers. + var nilKind Kind + if ts.nilOK { + nilKind = ts.nilKind // use struct tag if provided + } else { + nilKind = defaultNilKind(typ.Elem()) } writer := func(val reflect.Value, w *encbuf) error { if val.IsNil() { - return nilfunc(w) + if nilKind == String { + w.str = append(w.str, 0x80) + } else { + w.listEnd(w.list()) + } + return nil } return etypeinfo.writer(val.Elem(), w) } return writer, nil } +func makeEncoderWriter(typ reflect.Type) writer { + if typ.Implements(encoderInterface) { + return func(val reflect.Value, w *encbuf) error { + return val.Interface().(Encoder).EncodeRLP(w) + } + } + w := func(val reflect.Value, w *encbuf) error { + if !val.CanAddr() { + // package json simply doesn't call MarshalJSON for this case, but encodes the + // value as if it didn't implement the interface. We don't want to handle it that + // way. + return fmt.Errorf("rlp: unadressable value of type %v, EncodeRLP is pointer method", val.Type()) + } + return val.Addr().Interface().(Encoder).EncodeRLP(w) + } + return w +} + // putint writes i to the beginning of b in big endian byte // order, using the least number of bytes needed to represent i. 
func putint(b []byte, i uint64) (size int) { diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 6e49b89a8..b4b9e5128 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -33,8 +33,9 @@ type testEncoder struct { func (e *testEncoder) EncodeRLP(w io.Writer) error { if e == nil { - w.Write([]byte{0, 0, 0, 0}) - } else if e.err != nil { + panic("EncodeRLP called on nil value") + } + if e.err != nil { return e.err } else { w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) @@ -42,6 +43,13 @@ func (e *testEncoder) EncodeRLP(w io.Writer) error { return nil } +type testEncoderValueMethod struct{} + +func (e testEncoderValueMethod) EncodeRLP(w io.Writer) error { + w.Write([]byte{0xFA, 0xFE, 0xF0}) + return nil +} + type byteEncoder byte func (e byteEncoder) EncodeRLP(w io.Writer) error { @@ -52,8 +60,8 @@ func (e byteEncoder) EncodeRLP(w io.Writer) error { type undecodableEncoder func() func (f undecodableEncoder) EncodeRLP(w io.Writer) error { - _, err := w.Write(EmptyList) - return err + w.Write([]byte{0xF5, 0xF5, 0xF5}) + return nil } type encodableReader struct { @@ -226,6 +234,7 @@ var encTests = []encTest{ {val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"}, {val: &tailRaw{A: 1, Tail: nil}, output: "C101"}, {val: &hasIgnoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + {val: &intField{X: 3}, error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)"}, // nil {val: (*uint)(nil), output: "80"}, @@ -239,22 +248,66 @@ var encTests = []encTest{ {val: (*[]struct{ uint })(nil), output: "C0"}, {val: (*interface{})(nil), output: "C0"}, + // nil struct fields + { + val: struct { + X *[]byte + }{}, + output: "C180", + }, + { + val: struct { + X *[2]byte + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 `rlp:"nilList"` + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 `rlp:"nilString"` + }{}, + output: "C180", + }, + // interfaces {val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct // Encoder - {val: (*testEncoder)(nil), output: "00000000"}, + {val: (*testEncoder)(nil), output: "C0"}, {val: &testEncoder{}, output: "00010001000100010001"}, {val: &testEncoder{errors.New("test error")}, error: "test error"}, - // verify that the Encoder interface works for unsupported types like func(). - {val: undecodableEncoder(func() {}), output: "C0"}, - // verify that pointer method testEncoder.EncodeRLP is called for + {val: struct{ E testEncoderValueMethod }{}, output: "C3FAFEF0"}, + {val: struct{ E *testEncoderValueMethod }{}, output: "C1C0"}, + + // Verify that the Encoder interface works for unsupported types like func(). + {val: undecodableEncoder(func() {}), output: "F5F5F5"}, + + // Verify that pointer method testEncoder.EncodeRLP is called for // addressable non-pointer values. {val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"}, {val: &struct{ TE testEncoder }{testEncoder{errors.New("test error")}}, error: "test error"}, - // verify the error for non-addressable non-pointer Encoder - {val: testEncoder{}, error: "rlp: game over: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, - // verify the special case for []byte + + // Verify the error for non-addressable non-pointer Encoder. + {val: testEncoder{}, error: "rlp: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, + + // Verify Encoder takes precedence over []byte. 
{val: []byteEncoder{0, 1, 2, 3, 4}, output: "C5C0C0C0C0C0"}, } diff --git a/rlp/encoder_example_test.go b/rlp/encoder_example_test.go index 1cffa241c..42c1c5c89 100644 --- a/rlp/encoder_example_test.go +++ b/rlp/encoder_example_test.go @@ -28,15 +28,7 @@ type MyCoolType struct { // EncodeRLP writes x as RLP list [a, b] that omits the Name field. func (x *MyCoolType) EncodeRLP(w io.Writer) (err error) { - // Note: the receiver can be a nil pointer. This allows you to - // control the encoding of nil, but it also means that you have to - // check for a nil receiver. - if x == nil { - err = Encode(w, []uint{0, 0}) - } else { - err = Encode(w, []uint{x.a, x.b}) - } - return err + return Encode(w, []uint{x.a, x.b}) } func ExampleEncoder() { @@ -49,6 +41,6 @@ func ExampleEncoder() { fmt.Printf("%v → %X\n", t, bytes) // Output: - // → C28080 + // → C0 // &{foobar 5 6} → C20506 } diff --git a/rlp/typecache.go b/rlp/typecache.go index ab5ee3da7..e9a1e3f9e 100644 --- a/rlp/typecache.go +++ b/rlp/typecache.go @@ -35,22 +35,28 @@ type typeinfo struct { writerErr error // error from makeWriter } -// represents struct tags +// tags represents struct tags. type tags struct { // rlp:"nil" controls whether empty input results in a nil pointer. nilOK bool + + // This controls whether nil pointers are encoded/decoded as empty strings + // or empty lists. + nilKind Kind + // rlp:"tail" controls whether this field swallows additional list // elements. It can only be set for the last field, which must be // of slice type. tail bool + // rlp:"-" ignores fields. ignored bool } +// typekey is the key of a type in typeCache. It includes the struct tags because +// they might generate a different decoder. type typekey struct { reflect.Type - // the key must include the struct tags because they - // might generate a different decoder. 
tags } @@ -120,6 +126,25 @@ func structFields(typ reflect.Type) (fields []field, err error) { return fields, nil } +type structFieldError struct { + typ reflect.Type + field int + err error +} + +func (e structFieldError) Error() string { + return fmt.Sprintf("%v (struct field %v.%s)", e.err, e.typ, e.typ.Field(e.field).Name) +} + +type structTagError struct { + typ reflect.Type + field, tag, err string +} + +func (e structTagError) Error() string { + return fmt.Sprintf("rlp: invalid struct tag %q for %v.%s (%s)", e.tag, e.typ, e.field, e.err) +} + func parseStructTag(typ reflect.Type, fi, lastPublic int) (tags, error) { f := typ.Field(fi) var ts tags @@ -128,15 +153,26 @@ func parseStructTag(typ reflect.Type, fi, lastPublic int) (tags, error) { case "": case "-": ts.ignored = true - case "nil": + case "nil", "nilString", "nilList": ts.nilOK = true + if f.Type.Kind() != reflect.Ptr { + return ts, structTagError{typ, f.Name, t, "field is not a pointer"} + } + switch t { + case "nil": + ts.nilKind = defaultNilKind(f.Type.Elem()) + case "nilString": + ts.nilKind = String + case "nilList": + ts.nilKind = List + } case "tail": ts.tail = true if fi != lastPublic { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name) + return ts, structTagError{typ, f.Name, t, "must be on last field"} } if f.Type.Kind() != reflect.Slice { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (field type is not slice)`, typ, f.Name) + return ts, structTagError{typ, f.Name, t, "field type is not slice"} } default: return ts, fmt.Errorf("rlp: unknown struct tag %q on %v.%s", t, typ, f.Name) @@ -160,6 +196,20 @@ func (i *typeinfo) generate(typ reflect.Type, tags tags) { i.writer, i.writerErr = makeWriter(typ, tags) } +// defaultNilKind determines whether a nil pointer to typ encodes/decodes +// as an empty string or empty list. 
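The structTagError added here surfaces through Encode/Decode whenever a tag is misapplied; in particular, parseStructTag above now rejects "nil" and its variants on non-pointer fields with a field-qualified message. A minimal sketch with a hypothetical badTag type (the message format follows structTagError.Error):

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/rlp"
)

type badTag struct {
	X uint64 `rlp:"nil"` // invalid: the nil tags only apply to pointer fields
}

func main() {
	_, err := rlp.EncodeToBytes(badTag{})
	fmt.Println(err)
	// rlp: invalid struct tag "nil" for main.badTag.X (field is not a pointer)
}
```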
+func defaultNilKind(typ reflect.Type) Kind { + k := typ.Kind() + if isUint(k) || k == reflect.String || k == reflect.Bool || isByteArray(typ) { + return String + } + return List +} + func isUint(k reflect.Kind) bool { return k >= reflect.Uint && k <= reflect.Uintptr } + +func isByteArray(typ reflect.Type) bool { + return (typ.Kind() == reflect.Slice || typ.Kind() == reflect.Array) && isByte(typ.Elem()) +} diff --git a/rpc/types.go b/rpc/types.go index f31f09a77..e6b9f2a30 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -18,10 +18,12 @@ package rpc import ( "context" + "encoding/json" "fmt" "math" "strings" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" ) @@ -105,3 +107,94 @@ func (bn *BlockNumber) UnmarshalJSON(data []byte) error { func (bn BlockNumber) Int64() int64 { return (int64)(bn) } + +type BlockNumberOrHash struct { + BlockNumber *BlockNumber `json:"blockNumber,omitempty"` + BlockHash *common.Hash `json:"blockHash,omitempty"` + RequireCanonical bool `json:"requireCanonical,omitempty"` +} + +func (bnh *BlockNumberOrHash) UnmarshalJSON(data []byte) error { + type erased BlockNumberOrHash + e := erased{} + err := json.Unmarshal(data, &e) + if err == nil { + if e.BlockNumber != nil && e.BlockHash != nil { + return fmt.Errorf("cannot specify both BlockHash and BlockNumber, choose one or the other") + } + bnh.BlockNumber = e.BlockNumber + bnh.BlockHash = e.BlockHash + bnh.RequireCanonical = e.RequireCanonical + return nil + } + var input string + err = json.Unmarshal(data, &input) + if err != nil { + return err + } + switch input { + case "earliest": + bn := EarliestBlockNumber + bnh.BlockNumber = &bn + return nil + case "latest": + bn := LatestBlockNumber + bnh.BlockNumber = &bn + return nil + case "pending": + bn := PendingBlockNumber + bnh.BlockNumber = &bn + return nil + default: + if len(input) == 66 { + hash := common.Hash{} + err := hash.UnmarshalText([]byte(input)) + if err != nil { + return err + } + bnh.BlockHash = &hash + return nil + } else { + blckNum, err := hexutil.DecodeUint64(input) + if err != nil { + return err + } + if blckNum > math.MaxInt64 { + return fmt.Errorf("blocknumber too high") + } + bn := BlockNumber(blckNum) + bnh.BlockNumber = &bn + return nil + } + } +} + +func (bnh *BlockNumberOrHash) Number() (BlockNumber, bool) { + if bnh.BlockNumber != nil { + return *bnh.BlockNumber, true + } + return BlockNumber(0), false +} + +func (bnh *BlockNumberOrHash) Hash() (common.Hash, bool) { + if bnh.BlockHash != nil { + return *bnh.BlockHash, true + } + return common.Hash{}, false +} + +func BlockNumberOrHashWithNumber(blockNr BlockNumber) BlockNumberOrHash { + return BlockNumberOrHash{ + BlockNumber: &blockNr, + BlockHash: nil, + RequireCanonical: false, + } +} + +func BlockNumberOrHashWithHash(hash common.Hash, canonical bool) BlockNumberOrHash { + return BlockNumberOrHash{ + BlockNumber: nil, + BlockHash: &hash, + RequireCanonical: canonical, + } +} diff --git a/rpc/types_test.go b/rpc/types_test.go index 68b6d3c54..89b0c9171 100644 --- a/rpc/types_test.go +++ b/rpc/types_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "testing" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" ) @@ -64,3 +65,60 @@ func TestBlockNumberJSONUnmarshal(t *testing.T) { } } } + +func TestBlockNumberOrHash_UnmarshalJSON(t *testing.T) { + tests := []struct { + input string + mustFail bool + expected BlockNumberOrHash + }{ + 0: {`"0x"`, true, BlockNumberOrHash{}}, + 1: {`"0x0"`, false, 
BlockNumberOrHashWithNumber(0)}, + 2: {`"0X1"`, false, BlockNumberOrHashWithNumber(1)}, + 3: {`"0x00"`, true, BlockNumberOrHash{}}, + 4: {`"0x01"`, true, BlockNumberOrHash{}}, + 5: {`"0x1"`, false, BlockNumberOrHashWithNumber(1)}, + 6: {`"0x12"`, false, BlockNumberOrHashWithNumber(18)}, + 7: {`"0x7fffffffffffffff"`, false, BlockNumberOrHashWithNumber(math.MaxInt64)}, + 8: {`"0x8000000000000000"`, true, BlockNumberOrHash{}}, + 9: {"0", true, BlockNumberOrHash{}}, + 10: {`"ff"`, true, BlockNumberOrHash{}}, + 11: {`"pending"`, false, BlockNumberOrHashWithNumber(PendingBlockNumber)}, + 12: {`"latest"`, false, BlockNumberOrHashWithNumber(LatestBlockNumber)}, + 13: {`"earliest"`, false, BlockNumberOrHashWithNumber(EarliestBlockNumber)}, + 14: {`someString`, true, BlockNumberOrHash{}}, + 15: {`""`, true, BlockNumberOrHash{}}, + 16: {``, true, BlockNumberOrHash{}}, + 17: {`"0x0000000000000000000000000000000000000000000000000000000000000000"`, false, BlockNumberOrHashWithHash(common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000"), false)}, + 18: {`{"blockHash":"0x0000000000000000000000000000000000000000000000000000000000000000"}`, false, BlockNumberOrHashWithHash(common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000"), false)}, + 19: {`{"blockHash":"0x0000000000000000000000000000000000000000000000000000000000000000","requireCanonical":false}`, false, BlockNumberOrHashWithHash(common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000"), false)}, + 20: {`{"blockHash":"0x0000000000000000000000000000000000000000000000000000000000000000","requireCanonical":true}`, false, BlockNumberOrHashWithHash(common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000"), true)}, + 21: {`{"blockNumber":"0x1"}`, false, BlockNumberOrHashWithNumber(1)}, + 22: {`{"blockNumber":"pending"}`, false, BlockNumberOrHashWithNumber(PendingBlockNumber)}, + 23: {`{"blockNumber":"latest"}`, false, BlockNumberOrHashWithNumber(LatestBlockNumber)}, + 24: {`{"blockNumber":"earliest"}`, false, BlockNumberOrHashWithNumber(EarliestBlockNumber)}, + 25: {`{"blockNumber":"0x1", "blockHash":"0x0000000000000000000000000000000000000000000000000000000000000000"}`, true, BlockNumberOrHash{}}, + } + + for i, test := range tests { + var bnh BlockNumberOrHash + err := json.Unmarshal([]byte(test.input), &bnh) + if test.mustFail && err == nil { + t.Errorf("Test %d should fail", i) + continue + } + if !test.mustFail && err != nil { + t.Errorf("Test %d should pass but got err: %v", i, err) + continue + } + hash, hashOk := bnh.Hash() + expectedHash, expectedHashOk := test.expected.Hash() + num, numOk := bnh.Number() + expectedNum, expectedNumOk := test.expected.Number() + if bnh.RequireCanonical != test.expected.RequireCanonical || + hash != expectedHash || hashOk != expectedHashOk || + num != expectedNum || numOk != expectedNumOk { + t.Errorf("Test %d got unexpected value, want %v, got %v", i, test.expected, bnh) + } + } +} diff --git a/statediff/api.go b/statediff/api.go index 52c604f97..06dab7ec7 100644 --- a/statediff/api.go +++ b/statediff/api.go @@ -89,3 +89,8 @@ func (api *PublicStateDiffAPI) Stream(ctx context.Context) (*rpc.Subscription, e return rpcSub, nil } + +// StateDiffAt returns a statediff payload at the specific blockheight +func (api *PublicStateDiffAPI) StateDiffAt(ctx context.Context, blockNumber uint64) (*Payload, error) { + return api.sds.StateDiffAt(blockNumber) +} diff --git a/statediff/builder_test.go 
b/statediff/builder_test.go index 2c9253de1..1cd5cdf38 100644 --- a/statediff/builder_test.go +++ b/statediff/builder_test.go @@ -41,10 +41,6 @@ var ( burnAddress = common.HexToAddress("0x0") burnLeafKey = testhelpers.AddressToLeafKey(burnAddress) - block0Hash = common.HexToHash("0xd1721cfd0b29c36fd7a68f25c128e86413fb666a6e1d68e89b875bd299262661") - block1Hash = common.HexToHash("0xbbe88de60ba33a3f18c0caa37d827bfb70252e19e40a07cd34041696c35ecb1a") - block2Hash = common.HexToHash("0x34ad0fd9bb2911986b75d518c822641079dea823bc6952343ebf05da1062b6f5") - block3Hash = common.HexToHash("0x9872058136c560a6ebed0c0522b8d3016fc21f4fb0fb6585ddd8fd4c54f9909a") balanceChange10000 = int64(10000) balanceChange1000 = int64(1000) block1BankBalance = int64(99990000) @@ -140,13 +136,13 @@ type arguments struct { } func TestBuilder(t *testing.T) { - _, blockMap, chain := testhelpers.MakeChain(3, testhelpers.Genesis) + blockHashes, blockMap, chain := testhelpers.MakeChain(3, testhelpers.Genesis) contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) defer chain.Stop() - block0 = blockMap[block0Hash] - block1 = blockMap[block1Hash] - block2 = blockMap[block2Hash] - block3 = blockMap[block3Hash] + block0 = blockMap[blockHashes[3]] + block1 = blockMap[blockHashes[2]] + block2 = blockMap[blockHashes[1]] + block3 = blockMap[blockHashes[0]] config := statediff.Config{ PathsAndProofs: true, IntermediateNodes: false, @@ -164,16 +160,17 @@ func TestBuilder(t *testing.T) { oldStateRoot: block0.Root(), newStateRoot: block0.Root(), blockNumber: block0.Number(), - blockHash: block0Hash, + blockHash: block0.Hash(), }, &statediff.StateDiff{ BlockNumber: block0.Number(), - BlockHash: block0Hash, + BlockHash: block0.Hash(), CreatedAccounts: emptyAccountDiffEventualMap, DeletedAccounts: emptyAccountDiffEventualMap, UpdatedAccounts: emptyAccountDiffIncrementalMap, }, }, + { "testBlock1", //10000 transferred from testBankAddress to account1Addr @@ -181,7 +178,7 @@ func TestBuilder(t *testing.T) { oldStateRoot: block0.Root(), newStateRoot: block1.Root(), blockNumber: block1.Number(), - blockHash: block1Hash, + blockHash: block1.Hash(), }, &statediff.StateDiff{ BlockNumber: block1.Number(), @@ -228,7 +225,7 @@ func TestBuilder(t *testing.T) { oldStateRoot: block1.Root(), newStateRoot: block2.Root(), blockNumber: block2.Number(), - blockHash: block2Hash, + blockHash: block2.Hash(), }, &statediff.StateDiff{ BlockNumber: block2.Number(), @@ -374,13 +371,13 @@ func TestBuilder(t *testing.T) { } func TestBuilderWithWatchedAddressList(t *testing.T) { - _, blockMap, chain := testhelpers.MakeChain(3, testhelpers.Genesis) + blockHashes, blockMap, chain := testhelpers.MakeChain(3, testhelpers.Genesis) contractLeafKey = testhelpers.AddressToLeafKey(testhelpers.ContractAddr) defer chain.Stop() - block0 = blockMap[block0Hash] - block1 = blockMap[block1Hash] - block2 = blockMap[block2Hash] - block3 = blockMap[block3Hash] + block0 = blockMap[blockHashes[3]] + block1 = blockMap[blockHashes[2]] + block2 = blockMap[blockHashes[1]] + block3 = blockMap[blockHashes[0]] config := statediff.Config{ PathsAndProofs: true, IntermediateNodes: false, @@ -399,11 +396,11 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { oldStateRoot: block0.Root(), newStateRoot: block0.Root(), blockNumber: block0.Number(), - blockHash: block0Hash, + blockHash: block0.Hash(), }, &statediff.StateDiff{ BlockNumber: block0.Number(), - BlockHash: block0Hash, + BlockHash: block0.Hash(), CreatedAccounts: emptyAccountDiffEventualMap, DeletedAccounts: 
emptyAccountDiffEventualMap, UpdatedAccounts: emptyAccountDiffIncrementalMap, @@ -416,7 +413,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { oldStateRoot: block0.Root(), newStateRoot: block1.Root(), blockNumber: block1.Number(), - blockHash: block1Hash, + blockHash: block1.Hash(), }, &statediff.StateDiff{ BlockNumber: block1.Number(), @@ -444,7 +441,7 @@ func TestBuilderWithWatchedAddressList(t *testing.T) { oldStateRoot: block1.Root(), newStateRoot: block2.Root(), blockNumber: block2.Number(), - blockHash: block2Hash, + blockHash: block2.Hash(), }, &statediff.StateDiff{ BlockNumber: block2.Number(), diff --git a/statediff/service.go b/statediff/service.go index d3eab1065..dbe8aba81 100644 --- a/statediff/service.go +++ b/statediff/service.go @@ -39,6 +39,7 @@ const chainEventChanSize = 20000 type blockChain interface { SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription GetBlockByHash(hash common.Hash) *types.Block + GetBlockByNumber(number uint64) *types.Block AddToStateDiffProcessedCollection(hash common.Hash) GetReceiptsByHash(hash common.Hash) types.Receipts } @@ -53,6 +54,8 @@ type IService interface { Subscribe(id rpc.ID, sub chan<- Payload, quitChan chan<- bool) // Method to unsubscribe from state diff processing Unsubscribe(id rpc.ID) error + // Method to get statediff at specific block + StateDiffAt(blockNumber uint64) (*Payload, error) } // Service is the underlying struct for the state diffing service @@ -132,9 +135,12 @@ func (sds *Service) Loop(chainEventCh chan core.ChainEvent) { log.Error(fmt.Sprintf("Parent block is nil, skipping this block (%d)", currentBlock.Number())) continue } - if err := sds.processStateDiff(currentBlock, parentBlock); err != nil { + payload, err := sds.processStateDiff(currentBlock, parentBlock) + if err != nil { log.Error(fmt.Sprintf("Error building statediff for block %d; error: ", currentBlock.Number()) + err.Error()) + continue } + sds.send(*payload) case err := <-errCh: log.Warn("Error from chain event subscription, breaking loop", "error", err) sds.close() @@ -148,14 +154,14 @@ func (sds *Service) Loop(chainEventCh chan core.ChainEvent) { } // processStateDiff method builds the state diff payload from the current and parent block before sending it to listening subscriptions -func (sds *Service) processStateDiff(currentBlock, parentBlock *types.Block) error { +func (sds *Service) processStateDiff(currentBlock, parentBlock *types.Block) (*Payload, error) { stateDiff, err := sds.Builder.BuildStateDiff(parentBlock.Root(), currentBlock.Root(), currentBlock.Number(), currentBlock.Hash()) if err != nil { - return err + return nil, err } stateDiffRlp, err := rlp.EncodeToBytes(stateDiff) if err != nil { - return err + return nil, err } payload := Payload{ StateDiffRlp: stateDiffRlp, @@ -163,19 +169,17 @@ func (sds *Service) processStateDiff(currentBlock, parentBlock *types.Block) err if sds.StreamBlock { blockBuff := new(bytes.Buffer) if err = currentBlock.EncodeRLP(blockBuff); err != nil { - return err + return nil, err } payload.BlockRlp = blockBuff.Bytes() receiptBuff := new(bytes.Buffer) receipts := sds.BlockChain.GetReceiptsByHash(currentBlock.Hash()) if err = rlp.Encode(receiptBuff, receipts); err != nil { - return err + return nil, err } payload.ReceiptsRlp = receiptBuff.Bytes() } - - sds.send(payload) - return nil + return &payload, nil } // Subscribe is used by the API to subscribe to the service loop @@ -269,3 +273,12 @@ func (sds *Service) close() { } sds.Unlock() } + +// StateDiffAt returns a statediff payload at 
the specific blockheight +// This operation cannot be performed back past the point of db pruning; it requires an archival node +func (sds *Service) StateDiffAt(blockNumber uint64) (*Payload, error) { + currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) + parentBlock := sds.BlockChain.GetBlockByHash(currentBlock.ParentHash()) + log.Info(fmt.Sprintf("sending state diff at %d", blockNumber)) + return sds.processStateDiff(currentBlock, parentBlock) +} diff --git a/statediff/service_test.go b/statediff/service_test.go index 6119f6ecb..39bac5823 100644 --- a/statediff/service_test.go +++ b/statediff/service_test.go @@ -93,7 +93,7 @@ func testErrorInChainEventLoop(t *testing.T) { blockMapping := make(map[common.Hash]*types.Block) blockMapping[parentBlock1.Hash()] = parentBlock1 blockMapping[parentBlock2.Hash()] = parentBlock2 - blockChain.SetParentBlocksToReturn(blockMapping) + blockChain.SetBlocksForHashes(blockMapping) blockChain.SetChainEvents([]core.ChainEvent{event1, event2, event3}) blockChain.SetReceiptsForHash(testBlock1.Hash(), testReceipts1) blockChain.SetReceiptsForHash(testBlock2.Hash(), testReceipts2) @@ -149,9 +149,9 @@ func testErrorInChainEventLoop(t *testing.T) { } //look up the parent block from its hash expectedHashes := []common.Hash{testBlock1.ParentHash(), testBlock2.ParentHash()} - if !reflect.DeepEqual(blockChain.ParentHashesLookedUp, expectedHashes) { + if !reflect.DeepEqual(blockChain.HashesLookedUp, expectedHashes) { t.Error("Test failure:", t.Name()) - t.Logf("Actual parent hash does not equal expected.\nactual:%+v\nexpected: %+v", blockChain.ParentHashesLookedUp, expectedHashes) + t.Logf("Actual parent hash does not equal expected.\nactual:%+v\nexpected: %+v", blockChain.HashesLookedUp, expectedHashes) } } @@ -170,7 +170,7 @@ func testErrorInBlockLoop(t *testing.T) { service.Subscribe(rpc.NewID(), payloadChan, quitChan) blockMapping := make(map[common.Hash]*types.Block) blockMapping[parentBlock1.Hash()] = parentBlock1 - blockChain.SetParentBlocksToReturn(blockMapping) + blockChain.SetBlocksForHashes(blockMapping) blockChain.SetChainEvents([]core.ChainEvent{event1, event2}) // Need to have listeners on the channels or the subscription will be closed and the processing halted go func() { @@ -194,3 +194,75 @@ func testErrorInBlockLoop(t *testing.T) { t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", builder.NewStateRoot, testBlock1.Root()) } } + +func TestGetStateDiffAt(t *testing.T) { + testErrorInStateDiffAt(t) +} + +func testErrorInStateDiffAt(t *testing.T) { + mockStateDiff := statediff.StateDiff{ + BlockNumber: testBlock1.Number(), + BlockHash: testBlock1.Hash(), + } + expectedStateDiffRlp, err := rlp.EncodeToBytes(mockStateDiff) + if err != nil { + t.Error(err) + } + expectedReceiptsRlp, err := rlp.EncodeToBytes(testReceipts1) + if err != nil { + t.Error(err) + } + expectedBlockRlp, err := rlp.EncodeToBytes(testBlock1) + if err != nil { + t.Error(err) + } + expectedStateDiffPayload := statediff.Payload{ + StateDiffRlp: expectedStateDiffRlp, + ReceiptsRlp: expectedReceiptsRlp, + BlockRlp: expectedBlockRlp, + } + expectedStateDiffPayloadRlp, err := rlp.EncodeToBytes(expectedStateDiffPayload) + if err != nil { + t.Error(err) + } + builder := mocks.Builder{} + builder.SetStateDiffToBuild(mockStateDiff) + blockChain := mocks.BlockChain{} + blockMapping := make(map[common.Hash]*types.Block) + blockMapping[parentBlock1.Hash()] = parentBlock1 + blockChain.SetBlocksForHashes(blockMapping) + blockChain.SetBlockForNumber(testBlock1, 
testBlock1.NumberU64()) + blockChain.SetReceiptsForHash(testBlock1.Hash(), testReceipts1) + service := statediff.Service{ + Mutex: sync.Mutex{}, + Builder: &builder, + BlockChain: &blockChain, + QuitChan: make(chan bool), + Subscriptions: make(map[rpc.ID]statediff.Subscription), + StreamBlock: true, + } + stateDiffPayload, err := service.StateDiffAt(testBlock1.NumberU64()) + if err != nil { + t.Error(err) + } + stateDiffPayloadRlp, err := rlp.EncodeToBytes(stateDiffPayload) + if err != nil { + t.Error(err) + } + if !bytes.Equal(builder.BlockHash.Bytes(), testBlock1.Hash().Bytes()) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", builder.BlockHash, testBlock1.Hash()) + } + if !bytes.Equal(builder.OldStateRoot.Bytes(), parentBlock1.Root().Bytes()) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", builder.OldStateRoot, parentBlock1.Root()) + } + if !bytes.Equal(builder.NewStateRoot.Bytes(), testBlock1.Root().Bytes()) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", builder.NewStateRoot, testBlock1.Root()) + } + if !bytes.Equal(expectedStateDiffPayloadRlp, stateDiffPayloadRlp) { + t.Error("Test failure:", t.Name()) + t.Logf("Actual does not equal expected.\nactual:%+v\nexpected: %+v", expectedStateDiffPayload, stateDiffPayload) + } +} diff --git a/statediff/testhelpers/helpers.go b/statediff/testhelpers/helpers.go index 8f52bc8cc..ea41ec7bc 100644 --- a/statediff/testhelpers/helpers.go +++ b/statediff/testhelpers/helpers.go @@ -34,10 +34,6 @@ import ( // reassembly. func MakeChain(n int, parent *types.Block) ([]common.Hash, map[common.Hash]*types.Block, *core.BlockChain) { blocks, _ := core.GenerateChain(params.TestChainConfig, parent, ethash.NewFaker(), Testdb, n, testChainGen) - headers := make([]*types.Header, len(blocks)) - for i, block := range blocks { - headers[i] = block.Header() - } chain, _ := core.NewBlockChain(Testdb, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil) hashes := make([]common.Hash, n+1) diff --git a/statediff/testhelpers/mocks/api.go b/statediff/testhelpers/mocks/api.go index 3b43ab7dd..999d82d54 100644 --- a/statediff/testhelpers/mocks/api.go +++ b/statediff/testhelpers/mocks/api.go @@ -36,6 +36,7 @@ import ( type MockStateDiffService struct { sync.Mutex Builder statediff.Builder + BlockChain *BlockChain ReturnProtocol []p2p.Protocol ReturnAPIs []rpc.API BlockChan chan *types.Block @@ -77,10 +78,12 @@ func (sds *MockStateDiffService) Loop(chan core.ChainEvent) { "current block number", currentBlock.Number()) continue } - if err := sds.process(currentBlock, parentBlock); err != nil { - println(err.Error()) + payload, err := sds.processStateDiff(currentBlock, parentBlock) + if err != nil { log.Error("Error building statediff", "block number", currentBlock.Number(), "error", err) + continue } + sds.send(*payload) case <-sds.QuitChan: log.Debug("Quitting the statediff block channel") sds.close() @@ -89,16 +92,16 @@ func (sds *MockStateDiffService) Loop(chan core.ChainEvent) { } } -// process method builds the state diff payload from the current and parent block and streams it to listening subscriptions -func (sds *MockStateDiffService) process(currentBlock, parentBlock *types.Block) error { +// processStateDiff method builds the state diff payload from the current and parent block and streams it to listening subscriptions +func (sds *MockStateDiffService) 
processStateDiff(currentBlock, parentBlock *types.Block) (*statediff.Payload, error) { stateDiff, err := sds.Builder.BuildStateDiff(parentBlock.Root(), currentBlock.Root(), currentBlock.Number(), currentBlock.Hash()) if err != nil { - return err + return nil, err } stateDiffRlp, err := rlp.EncodeToBytes(stateDiff) if err != nil { - return err + return nil, err } payload := statediff.Payload{ StateDiffRlp: stateDiffRlp, @@ -106,14 +109,11 @@ func (sds *MockStateDiffService) process(currentBlock, parentBlock *types.Block) if sds.streamBlock { rlpBuff := new(bytes.Buffer) if err = currentBlock.EncodeRLP(rlpBuff); err != nil { - return err + return nil, err } payload.BlockRlp = rlpBuff.Bytes() } - - // If we have any websocket subscription listening in, send the data to them - sds.send(payload) - return nil + return &payload, nil } // Subscribe mock method @@ -185,3 +185,11 @@ func (sds *MockStateDiffService) Stop() error { close(sds.QuitChan) return nil } + +// StateDiffAt mock method +func (sds *MockStateDiffService) StateDiffAt(blockNumber uint64) (*statediff.Payload, error) { + currentBlock := sds.BlockChain.GetBlockByNumber(blockNumber) + parentBlock := sds.BlockChain.GetBlockByHash(currentBlock.ParentHash()) + log.Info(fmt.Sprintf("sending state diff at %d", blockNumber)) + return sds.processStateDiff(currentBlock, parentBlock) +} diff --git a/statediff/testhelpers/mocks/api_test.go b/statediff/testhelpers/mocks/api_test.go index b76ba4328..89e51db08 100644 --- a/statediff/testhelpers/mocks/api_test.go +++ b/statediff/testhelpers/mocks/api_test.go @@ -55,6 +55,10 @@ var bankAccount1, _ = rlp.EncodeToBytes(state.Account{ }) func TestAPI(t *testing.T) { + testSubscriptionAPI(t) + testHTTPAPI(t) +} +func testSubscriptionAPI(t *testing.T) { _, blockMap, chain := testhelpers.MakeChain(3, testhelpers.Genesis) defer chain.Stop() block0Hash := common.HexToHash("0xd1721cfd0b29c36fd7a68f25c128e86413fb666a6e1d68e89b875bd299262661") @@ -140,3 +144,75 @@ func TestAPI(t *testing.T) { t.Errorf("channel quit before delivering payload") } } + +func testHTTPAPI(t *testing.T) { + _, blockMap, chain := testhelpers.MakeChain(3, testhelpers.Genesis) + defer chain.Stop() + block0Hash := common.HexToHash("0xd1721cfd0b29c36fd7a68f25c128e86413fb666a6e1d68e89b875bd299262661") + block1Hash := common.HexToHash("0xbbe88de60ba33a3f18c0caa37d827bfb70252e19e40a07cd34041696c35ecb1a") + block0 = blockMap[block0Hash] + block1 = blockMap[block1Hash] + config := statediff.Config{ + PathsAndProofs: true, + IntermediateNodes: false, + } + mockBlockChain := &BlockChain{} + mockBlockChain.SetBlocksForHashes(blockMap) + mockBlockChain.SetBlockForNumber(block1, block1.Number().Uint64()) + mockService := MockStateDiffService{ + Mutex: sync.Mutex{}, + Builder: statediff.NewBuilder(testhelpers.Testdb, chain, config), + BlockChain: mockBlockChain, + streamBlock: true, + } + payload, err := mockService.StateDiffAt(block1.Number().Uint64()) + expectedBlockRlp, _ := rlp.EncodeToBytes(block1) + if !bytes.Equal(payload.BlockRlp, expectedBlockRlp) { + t.Errorf("payload does not have expected block\r\nactual block rlp: %v\r\nexpected block rlp: %v", payload.BlockRlp, expectedBlockRlp) + } + expectedStateDiff := statediff.StateDiff{ + BlockNumber: block1.Number(), + BlockHash: block1.Hash(), + CreatedAccounts: []statediff.AccountDiff{ + { + Leaf: true, + Key: burnLeafKey.Bytes(), + Value: burnAccount1, + Proof: [][]byte{{248, 113, 160, 87, 118, 82, 182, 37, 183, 123, 219, 91, 247, 123, 196, 63, 49, 37, 202, 215, 70, 77, 103, 157, 21, 117, 86, 
82, 119, 211, 97, 27, 128, 83, 231, 128, 128, 128, 128, 160, 254, 136, 159, 16, 229, 219, 143, 44, 43, 243, 85, 146, 129, 82, 161, 127, 110, 59, 185, 154, 146, 65, 172, 109, 132, 199, 126, 98, 100, 80, 156, 121, 128, 128, 128, 128, 128, 128, 128, 128, 160, 17, 219, 12, 218, 52, 168, 150, 218, 190, 182, 131, 155, 176, 106, 56, 244, 149, 20, 207, 164, 134, 67, 89, 132, 235, 1, 59, 125, 249, 238, 133, 197, 128, 128}, + {248, 113, 160, 51, 128, 199, 183, 174, 129, 165, 142, 185, 141, 156, 120, 222, 74, 31, 215, 253, 149, 53, 252, 149, 62, 210, 190, 96, 45, 170, 164, 23, 103, 49, 42, 184, 78, 248, 76, 128, 136, 27, 193, 109, 103, 78, 200, 0, 0, 160, 86, 232, 31, 23, 27, 204, 85, 166, 255, 131, 69, 230, 146, 192, 248, 110, 91, 72, 224, 27, 153, 108, 173, 192, 1, 98, 47, 181, 227, 99, 180, 33, 160, 197, 210, 70, 1, 134, 247, 35, 60, 146, 126, 125, 178, 220, 199, 3, 192, 229, 0, 182, 83, 202, 130, 39, 59, 123, 250, 216, 4, 93, 133, 164, 112}}, + Path: []byte{5, 3, 8, 0, 12, 7, 11, 7, 10, 14, 8, 1, 10, 5, 8, 14, 11, 9, 8, 13, 9, 12, 7, 8, 13, 14, 4, 10, 1, 15, 13, 7, 15, 13, 9, 5, 3, 5, 15, 12, 9, 5, 3, 14, 13, 2, 11, 14, 6, 0, 2, 13, 10, 10, 10, 4, 1, 7, 6, 7, 3, 1, 2, 10, 16}, + Storage: []statediff.StorageDiff{}, + }, + { + Leaf: true, + Key: testhelpers.Account1LeafKey.Bytes(), + Value: account1, + Proof: [][]byte{{248, 113, 160, 87, 118, 82, 182, 37, 183, 123, 219, 91, 247, 123, 196, 63, 49, 37, 202, 215, 70, 77, 103, 157, 21, 117, 86, 82, 119, 211, 97, 27, 128, 83, 231, 128, 128, 128, 128, 160, 254, 136, 159, 16, 229, 219, 143, 44, 43, 243, 85, 146, 129, 82, 161, 127, 110, 59, 185, 154, 146, 65, 172, 109, 132, 199, 126, 98, 100, 80, 156, 121, 128, 128, 128, 128, 128, 128, 128, 128, 160, 17, 219, 12, 218, 52, 168, 150, 218, 190, 182, 131, 155, 176, 106, 56, 244, 149, 20, 207, 164, 134, 67, 89, 132, 235, 1, 59, 125, 249, 238, 133, 197, 128, 128}, + {248, 107, 160, 57, 38, 219, 105, 170, 206, 213, 24, 233, 185, 240, 244, 52, 164, 115, 231, 23, 65, 9, 201, 67, 84, 139, 184, 242, 59, 228, 28, 167, 109, 154, 210, 184, 72, 248, 70, 128, 130, 39, 16, 160, 86, 232, 31, 23, 27, 204, 85, 166, 255, 131, 69, 230, 146, 192, 248, 110, 91, 72, 224, 27, 153, 108, 173, 192, 1, 98, 47, 181, 227, 99, 180, 33, 160, 197, 210, 70, 1, 134, 247, 35, 60, 146, 126, 125, 178, 220, 199, 3, 192, 229, 0, 182, 83, 202, 130, 39, 59, 123, 250, 216, 4, 93, 133, 164, 112}}, + Path: []byte{14, 9, 2, 6, 13, 11, 6, 9, 10, 10, 12, 14, 13, 5, 1, 8, 14, 9, 11, 9, 15, 0, 15, 4, 3, 4, 10, 4, 7, 3, 14, 7, 1, 7, 4, 1, 0, 9, 12, 9, 4, 3, 5, 4, 8, 11, 11, 8, 15, 2, 3, 11, 14, 4, 1, 12, 10, 7, 6, 13, 9, 10, 13, 2, 16}, + Storage: []statediff.StorageDiff{}, + }, + }, + DeletedAccounts: emptyAccountDiffEventualMap, + UpdatedAccounts: []statediff.AccountDiff{ + { + Leaf: true, + Key: testhelpers.BankLeafKey.Bytes(), + Value: bankAccount1, + Proof: [][]byte{{248, 113, 160, 87, 118, 82, 182, 37, 183, 123, 219, 91, 247, 123, 196, 63, 49, 37, 202, 215, 70, 77, 103, 157, 21, 117, 86, 82, 119, 211, 97, 27, 128, 83, 231, 128, 128, 128, 128, 160, 254, 136, 159, 16, 229, 219, 143, 44, 43, 243, 85, 146, 129, 82, 161, 127, 110, 59, 185, 154, 146, 65, 172, 109, 132, 199, 126, 98, 100, 80, 156, 121, 128, 128, 128, 128, 128, 128, 128, 128, 160, 17, 219, 12, 218, 52, 168, 150, 218, 190, 182, 131, 155, 176, 106, 56, 244, 149, 20, 207, 164, 134, 67, 89, 132, 235, 1, 59, 125, 249, 238, 133, 197, 128, 128}, + {248, 109, 160, 48, 191, 73, 244, 64, 161, 205, 5, 39, 228, 208, 110, 39, 101, 101, 76, 15, 86, 69, 34, 87, 81, 109, 121, 58, 155, 141, 96, 77, 207, 223, 
42, 184, 74, 248, 72, 1, 132, 5, 245, 185, 240, 160, 86, 232, 31, 23, 27, 204, 85, 166, 255, 131, 69, 230, 146, 192, 248, 110, 91, 72, 224, 27, 153, 108, 173, 192, 1, 98, 47, 181, 227, 99, 180, 33, 160, 197, 210, 70, 1, 134, 247, 35, 60, 146, 126, 125, 178, 220, 199, 3, 192, 229, 0, 182, 83, 202, 130, 39, 59, 123, 250, 216, 4, 93, 133, 164, 112}}, + Path: []byte{0, 0, 11, 15, 4, 9, 15, 4, 4, 0, 10, 1, 12, 13, 0, 5, 2, 7, 14, 4, 13, 0, 6, 14, 2, 7, 6, 5, 6, 5, 4, 12, 0, 15, 5, 6, 4, 5, 2, 2, 5, 7, 5, 1, 6, 13, 7, 9, 3, 10, 9, 11, 8, 13, 6, 0, 4, 13, 12, 15, 13, 15, 2, 10, 16}, + Storage: []statediff.StorageDiff{}, + }, + }, + } + expectedStateDiffBytes, err := rlp.EncodeToBytes(expectedStateDiff) + if err != nil { + t.Error(err) + } + sort.Slice(payload.StateDiffRlp, func(i, j int) bool { return payload.StateDiffRlp[i] < payload.StateDiffRlp[j] }) + sort.Slice(expectedStateDiffBytes, func(i, j int) bool { return expectedStateDiffBytes[i] < expectedStateDiffBytes[j] }) + if !bytes.Equal(payload.StateDiffRlp, expectedStateDiffBytes) { + t.Errorf("payload does not have expected state diff\r\nactual state diff rlp: %v\r\nexpected state diff rlp: %v", payload.StateDiffRlp, expectedStateDiffBytes) + } +} diff --git a/statediff/testhelpers/mocks/blockchain.go b/statediff/testhelpers/mocks/blockchain.go index 508435236..b5eb6ff92 100644 --- a/statediff/testhelpers/mocks/blockchain.go +++ b/statediff/testhelpers/mocks/blockchain.go @@ -29,34 +29,35 @@ import ( // BlockChain is a mock blockchain for testing type BlockChain struct { - ParentHashesLookedUp []common.Hash - parentBlocksToReturn map[common.Hash]*types.Block - callCount int - ChainEvents []core.ChainEvent - Receipts map[common.Hash]types.Receipts + HashesLookedUp []common.Hash + blocksToReturnByHash map[common.Hash]*types.Block + blocksToReturnByNumber map[uint64]*types.Block + callCount int + ChainEvents []core.ChainEvent + Receipts map[common.Hash]types.Receipts } // AddToStateDiffProcessedCollection mock method func (blockChain *BlockChain) AddToStateDiffProcessedCollection(hash common.Hash) {} -// SetParentBlocksToReturn mock method -func (blockChain *BlockChain) SetParentBlocksToReturn(blocks map[common.Hash]*types.Block) { - if blockChain.parentBlocksToReturn == nil { - blockChain.parentBlocksToReturn = make(map[common.Hash]*types.Block) +// SetBlocksForHashes mock method +func (blockChain *BlockChain) SetBlocksForHashes(blocks map[common.Hash]*types.Block) { + if blockChain.blocksToReturnByHash == nil { + blockChain.blocksToReturnByHash = make(map[common.Hash]*types.Block) } - blockChain.parentBlocksToReturn = blocks + blockChain.blocksToReturnByHash = blocks } // GetBlockByHash mock method func (blockChain *BlockChain) GetBlockByHash(hash common.Hash) *types.Block { - blockChain.ParentHashesLookedUp = append(blockChain.ParentHashesLookedUp, hash) + blockChain.HashesLookedUp = append(blockChain.HashesLookedUp, hash) - var parentBlock *types.Block - if len(blockChain.parentBlocksToReturn) > 0 { - parentBlock = blockChain.parentBlocksToReturn[hash] + var block *types.Block + if len(blockChain.blocksToReturnByHash) > 0 { + block = blockChain.blocksToReturnByHash[hash] } - return parentBlock + return block } // SetChainEvents mock method @@ -66,7 +67,7 @@ func (blockChain *BlockChain) SetChainEvents(chainEvents []core.ChainEvent) { // SubscribeChainEvent mock method func (blockChain *BlockChain) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription { - subErr := errors.New("Subscription Error") + subErr := 
errors.New("subscription error") var eventCounter int subscription := event.NewSubscription(func(quit <-chan struct{}) error { @@ -100,3 +101,16 @@ func (blockChain *BlockChain) SetReceiptsForHash(hash common.Hash, receipts type func (blockChain *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts { return blockChain.Receipts[hash] } + +// SetBlockForNumber mock method +func (blockChain *BlockChain) SetBlockForNumber(block *types.Block, number uint64) { + if blockChain.blocksToReturnByNumber == nil { + blockChain.blocksToReturnByNumber = make(map[uint64]*types.Block) + } + blockChain.blocksToReturnByNumber[number] = block +} + +// GetBlockByNumber mock method +func (blockChain *BlockChain) GetBlockByNumber(number uint64) *types.Block { + return blockChain.blocksToReturnByNumber[number] +} diff --git a/statediff/testhelpers/test_data.go b/statediff/testhelpers/test_data.go index 2f6088f86..604bd23b7 100644 --- a/statediff/testhelpers/test_data.go +++ b/statediff/testhelpers/test_data.go @@ -23,10 +23,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/statediff" ) // AddressToLeafKey hashes an returns an address @@ -36,67 +33,14 @@ func AddressToLeafKey(address common.Address) common.Hash { // Test variables var ( - BlockNumber = big.NewInt(rand.Int63()) - BlockHash = "0xfa40fbe2d98d98b3363a778d52f2bcd29d6790b9b3f3cab2b167fd12d3550f73" - CodeHash = common.Hex2Bytes("0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470") - NewNonceValue = rand.Uint64() - NewBalanceValue = rand.Int63() - ContractRoot = common.HexToHash("0x56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - StoragePath = common.HexToHash("0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes() - StorageKey = common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001").Bytes() - StorageValue = common.Hex2Bytes("0x03") - storage = []statediff.StorageDiff{{ - Key: StorageKey, - Value: StorageValue, - Path: StoragePath, - Proof: [][]byte{}, - }} - emptyStorage = make([]statediff.StorageDiff, 0) - address = common.HexToAddress("0xaE9BEa628c4Ce503DcFD7E305CaB4e29E7476592") - ContractLeafKey = AddressToLeafKey(address) - anotherAddress = common.HexToAddress("0xaE9BEa628c4Ce503DcFD7E305CaB4e29E7476593") - AnotherContractLeafKey = AddressToLeafKey(anotherAddress) - testAccount = state.Account{ - Nonce: NewNonceValue, - Balance: big.NewInt(NewBalanceValue), - Root: ContractRoot, - CodeHash: CodeHash, - } - valueBytes, _ = rlp.EncodeToBytes(testAccount) - CreatedAccountDiffs = []statediff.AccountDiff{ - { - Key: ContractLeafKey.Bytes(), - Value: valueBytes, - Storage: storage, - }, - { - Key: AnotherContractLeafKey.Bytes(), - Value: valueBytes, - Storage: emptyStorage, - }, - } - - UpdatedAccountDiffs = []statediff.AccountDiff{{ - Key: ContractLeafKey.Bytes(), - Value: valueBytes, - Storage: storage, - }} - - DeletedAccountDiffs = []statediff.AccountDiff{{ - Key: ContractLeafKey.Bytes(), - Value: valueBytes, - Storage: storage, - }} - - TestStateDiff = statediff.StateDiff{ - BlockNumber: BlockNumber, - BlockHash: common.HexToHash(BlockHash), - CreatedAccounts: CreatedAccountDiffs, - DeletedAccounts: DeletedAccountDiffs, - UpdatedAccounts: UpdatedAccountDiffs, - } - Testdb = rawdb.NewMemoryDatabase() + 
BlockNumber = big.NewInt(rand.Int63()) + BlockHash = "0xfa40fbe2d98d98b3363a778d52f2bcd29d6790b9b3f3cab2b167fd12d3550f73" + CodeHash = common.Hex2Bytes("0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470") + StoragePath = common.HexToHash("0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes() + StorageKey = common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001").Bytes() + StorageValue = common.Hex2Bytes("0x03") + Testdb = rawdb.NewMemoryDatabase() TestBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") TestBankAddress = crypto.PubkeyToAddress(TestBankKey.PublicKey) //0x71562b71999873DB5b286dF957af199Ec94617F7 BankLeafKey = AddressToLeafKey(TestBankAddress) diff --git a/tests/block_test.go b/tests/block_test.go index 33eaed1e1..3a55e4c34 100644 --- a/tests/block_test.go +++ b/tests/block_test.go @@ -25,17 +25,9 @@ func TestBlockchain(t *testing.T) { bt := new(testMatcher) // General state tests are 'exported' as blockchain tests, but we can run them natively. - bt.skipLoad(`^ValidBlocks/bcStateTests/`) - // Skip random failures due to selfish mining test. + bt.skipLoad(`^GeneralStateTests/`) + // Skip random failures due to selfish mining test bt.skipLoad(`.*bcForgedTest/bcForkUncle\.json`) - bt.skipLoad(`.*bcMultiChainTest/(ChainAtoChainB_blockorder|CallContractFromNotBestBlock)`) - bt.skipLoad(`.*bcTotalDifficultyTest/(lotsOfLeafs|lotsOfBranches|sideChainWithMoreTransactions)`) - - // These are not formatted like the rest -- due to the large postState, the postState - // was replaced by a hash, instead of a genesisAlloc map - // See https://github.com/ethereum/tests/pull/616 - bt.skipLoad(`.*bcExploitTest/ShanghaiLove.json`) - bt.skipLoad(`.*bcExploitTest/SuicideIssue.json`) // Slow tests bt.slow(`.*bcExploitTest/DelegateCallSpam.json`) @@ -45,9 +37,20 @@ func TestBlockchain(t *testing.T) { bt.slow(`.*/bcGasPricerTest/RPC_API_Test.json`) bt.slow(`.*/bcWalletTest/`) + // Very slow test + bt.skipLoad(`.*/stTimeConsuming/.*`) + + // test takes a lot for time and goes easily OOM because of sha3 calculation on a huge range, + // using 4.6 TGas + bt.skipLoad(`.*randomStatetest94.json.*`) + bt.walk(t, blockTestDir, func(t *testing.T, name string, test *BlockTest) { if err := bt.checkFailure(t, name, test.Run()); err != nil { t.Error(err) } }) + + // There is also a LegacyTests folder, containing blockchain tests generated + // prior to Istanbul. However, they are all derived from GeneralStateTests, + // which run natively, so there's no reason to run them here. } diff --git a/tests/init_test.go b/tests/init_test.go index 053cbd6fc..622318adb 100644 --- a/tests/init_test.go +++ b/tests/init_test.go @@ -18,6 +18,7 @@ package tests import ( "encoding/json" + "flag" "fmt" "io" "io/ioutil" @@ -33,10 +34,22 @@ import ( "github.com/ethereum/go-ethereum/params" ) +// Command line flags to configure the interpreters. 
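The TestMain hook in the hunk below is what makes these interpreter flags reachable from go test: custom flags have to be registered before flag.Parse runs, and TestMain is the standard place for a test package to parse them. A minimal sketch of the pattern (the invocation path and -run filter are illustrative):

```go
package tests

import (
	"flag"
	"os"
	"testing"
)

// Flags must be registered before flag.Parse is called in TestMain.
var testEVM = flag.String("vm.evm", "", "EVM configuration")

// TestMain replaces the default test entry point for the package, so a run
// like the following can pass interpreter settings through to every test:
//
//	go test ./tests -run TestState -vm.evm=/path/to/interpreter.so
func TestMain(m *testing.M) {
	flag.Parse()
	os.Exit(m.Run())
}
```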
+var ( + testEVM = flag.String("vm.evm", "", "EVM configuration") + testEWASM = flag.String("vm.ewasm", "", "EWASM configuration") +) + +func TestMain(m *testing.M) { + flag.Parse() + os.Exit(m.Run()) +} + var ( baseDir = filepath.Join(".", "testdata") blockTestDir = filepath.Join(baseDir, "BlockchainTests") stateTestDir = filepath.Join(baseDir, "GeneralStateTests") + legacyStateTestDir = filepath.Join(baseDir, "LegacyTests", "Constantinople", "GeneralStateTests") transactionTestDir = filepath.Join(baseDir, "TransactionTests") vmTestDir = filepath.Join(baseDir, "VMTests") rlpTestDir = filepath.Join(baseDir, "RLPTests") diff --git a/tests/solidity/bytecode.js b/tests/solidity/bytecode.js new file mode 100644 index 000000000..8796aabfa --- /dev/null +++ b/tests/solidity/bytecode.js @@ -0,0 +1,6 @@ +{ + "linkReferences": {}, + "object": "608060405234801561001057600080fd5b5061001961007a565b604051809103906000f080158015610035573d6000803e3d6000fd5b506000806101000a81548173ffffffffffffffffffffffffffffffffffffffff021916908373ffffffffffffffffffffffffffffffffffffffff16021790555061008a565b60405161015f8061055c83390190565b6104c3806100996000396000f3fe60806040526004361061005c576000357c01000000000000000000000000000000000000000000000000000000009004806355313dea146100615780636d3d141614610078578063b9d1e5aa1461008f578063f8a8fd6d146100a6575b600080fd5b34801561006d57600080fd5b506100766100bd565b005b34801561008457600080fd5b5061008d6100bf565b005b34801561009b57600080fd5b506100a46100c4565b005b3480156100b257600080fd5b506100bb6100c6565b005b005b600080fd5bfe5b600160021a6002f35b60058110156100e3576001810190506100cf565b5060065b60058111156100fb576001810190506100e7565b5060015b6005811215610113576001810190506100ff565b5060065b600581131561012b57600181019050610117565b5060021561013857600051505b60405160208101602060048337505060405160208101602060048339505060405160208101602060048360003c50503660005b81811015610182576002815260018101905061016b565b505060008020506000602060403e6010608060106040610123612710fa506020610123600af05060008060009054906101000a900473ffffffffffffffffffffffffffffffffffffffff169050600060405180807f697353616d654164647265737328616464726573732c61646472657373290000815250601e01905060405180910390209050600033905060405182815281600482015281602482015260648101604052602081604483600088611388f1505060405182815281600482015281602482015260648101604052602081604483600088611388f250506040518281528160048201528160248201526064810160405260208160448387611388f4505060006242004290507f50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb206001026040518082815260200191505060405180910390a07f50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb206001027f50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb206001026040518082815260200191505060405180910390a13373ffffffffffffffffffffffffffffffffffffffff166001027f50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb206001027f50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb206001026040518082815260200191505060405180910390a2806001023373ffffffffffffffffffffffffffffffffffffffff166001027f50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb206001027f50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb206001026040518082815260200191505060405180910390a380600102816001023373ffffffffffffffffffffffffffffffffffffffff166001027f50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb206001027f50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb206001026040518082815260200191505060405180910390a46002fffea165627a7a723058200e51baa2b454b47fdf0ef596fa24aff8
ed3a3727b7481ebd25349182ce7152a30029608060405234801561001057600080fd5b5061013f806100206000396000f3fe60806040526004361061003b576000357c010000000000000000000000000000000000000000000000000000000090048063161e715014610040575b600080fd5b34801561004c57600080fd5b506100af6004803603604081101561006357600080fd5b81019080803573ffffffffffffffffffffffffffffffffffffffff169060200190929190803573ffffffffffffffffffffffffffffffffffffffff1690602001909291905050506100c9565b604051808215151515815260200191505060405180910390f35b60008173ffffffffffffffffffffffffffffffffffffffff168373ffffffffffffffffffffffffffffffffffffffff161415610108576001905061010d565b600090505b9291505056fea165627a7a72305820358f67a58c115ea636b0b8e5c4ca7a52b8192d0f3fa98a4434d6ea04596b5d0d0029", + "opcodes": "PUSH1 0x80 PUSH1 0x40 MSTORE CALLVALUE DUP1 ISZERO PUSH2 0x10 JUMPI PUSH1 0x0 DUP1 REVERT JUMPDEST POP PUSH2 0x19 PUSH2 0x7A JUMP JUMPDEST PUSH1 0x40 MLOAD DUP1 SWAP2 SUB SWAP1 PUSH1 0x0 CREATE DUP1 ISZERO DUP1 ISZERO PUSH2 0x35 JUMPI RETURNDATASIZE PUSH1 0x0 DUP1 RETURNDATACOPY RETURNDATASIZE PUSH1 0x0 REVERT JUMPDEST POP PUSH1 0x0 DUP1 PUSH2 0x100 EXP DUP2 SLOAD DUP2 PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF MUL NOT AND SWAP1 DUP4 PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF AND MUL OR SWAP1 SSTORE POP PUSH2 0x8A JUMP JUMPDEST PUSH1 0x40 MLOAD PUSH2 0x15F DUP1 PUSH2 0x55C DUP4 CODECOPY ADD SWAP1 JUMP JUMPDEST PUSH2 0x4C3 DUP1 PUSH2 0x99 PUSH1 0x0 CODECOPY PUSH1 0x0 RETURN INVALID PUSH1 0x80 PUSH1 0x40 MSTORE PUSH1 0x4 CALLDATASIZE LT PUSH2 0x5C JUMPI PUSH1 0x0 CALLDATALOAD PUSH29 0x100000000000000000000000000000000000000000000000000000000 SWAP1 DIV DUP1 PUSH4 0x55313DEA EQ PUSH2 0x61 JUMPI DUP1 PUSH4 0x6D3D1416 EQ PUSH2 0x78 JUMPI DUP1 PUSH4 0xB9D1E5AA EQ PUSH2 0x8F JUMPI DUP1 PUSH4 0xF8A8FD6D EQ PUSH2 0xA6 JUMPI JUMPDEST PUSH1 0x0 DUP1 REVERT JUMPDEST CALLVALUE DUP1 ISZERO PUSH2 0x6D JUMPI PUSH1 0x0 DUP1 REVERT JUMPDEST POP PUSH2 0x76 PUSH2 0xBD JUMP JUMPDEST STOP JUMPDEST CALLVALUE DUP1 ISZERO PUSH2 0x84 JUMPI PUSH1 0x0 DUP1 REVERT JUMPDEST POP PUSH2 0x8D PUSH2 0xBF JUMP JUMPDEST STOP JUMPDEST CALLVALUE DUP1 ISZERO PUSH2 0x9B JUMPI PUSH1 0x0 DUP1 REVERT JUMPDEST POP PUSH2 0xA4 PUSH2 0xC4 JUMP JUMPDEST STOP JUMPDEST CALLVALUE DUP1 ISZERO PUSH2 0xB2 JUMPI PUSH1 0x0 DUP1 REVERT JUMPDEST POP PUSH2 0xBB PUSH2 0xC6 JUMP JUMPDEST STOP JUMPDEST STOP JUMPDEST PUSH1 0x0 DUP1 REVERT JUMPDEST INVALID JUMPDEST PUSH1 0x1 PUSH1 0x2 BYTE PUSH1 0x2 RETURN JUMPDEST PUSH1 0x5 DUP2 LT ISZERO PUSH2 0xE3 JUMPI PUSH1 0x1 DUP2 ADD SWAP1 POP PUSH2 0xCF JUMP JUMPDEST POP PUSH1 0x6 JUMPDEST PUSH1 0x5 DUP2 GT ISZERO PUSH2 0xFB JUMPI PUSH1 0x1 DUP2 ADD SWAP1 POP PUSH2 0xE7 JUMP JUMPDEST POP PUSH1 0x1 JUMPDEST PUSH1 0x5 DUP2 SLT ISZERO PUSH2 0x113 JUMPI PUSH1 0x1 DUP2 ADD SWAP1 POP PUSH2 0xFF JUMP JUMPDEST POP PUSH1 0x6 JUMPDEST PUSH1 0x5 DUP2 SGT ISZERO PUSH2 0x12B JUMPI PUSH1 0x1 DUP2 ADD SWAP1 POP PUSH2 0x117 JUMP JUMPDEST POP PUSH1 0x2 ISZERO PUSH2 0x138 JUMPI PUSH1 0x0 MLOAD POP JUMPDEST PUSH1 0x40 MLOAD PUSH1 0x20 DUP2 ADD PUSH1 0x20 PUSH1 0x4 DUP4 CALLDATACOPY POP POP PUSH1 0x40 MLOAD PUSH1 0x20 DUP2 ADD PUSH1 0x20 PUSH1 0x4 DUP4 CODECOPY POP POP PUSH1 0x40 MLOAD PUSH1 0x20 DUP2 ADD PUSH1 0x20 PUSH1 0x4 DUP4 PUSH1 0x0 EXTCODECOPY POP POP CALLDATASIZE PUSH1 0x0 JUMPDEST DUP2 DUP2 LT ISZERO PUSH2 0x182 JUMPI PUSH1 0x2 DUP2 MSTORE PUSH1 0x1 DUP2 ADD SWAP1 POP PUSH2 0x16B JUMP JUMPDEST POP POP PUSH1 0x0 DUP1 KECCAK256 POP PUSH1 0x0 PUSH1 0x20 PUSH1 0x40 RETURNDATACOPY PUSH1 0x10 PUSH1 0x80 PUSH1 0x10 PUSH1 0x40 PUSH2 0x123 PUSH2 0x2710 STATICCALL POP PUSH1 
0x20 PUSH2 0x123 PUSH1 0xA CREATE POP PUSH1 0x0 DUP1 PUSH1 0x0 SWAP1 SLOAD SWAP1 PUSH2 0x100 EXP SWAP1 DIV PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF AND SWAP1 POP PUSH1 0x0 PUSH1 0x40 MLOAD DUP1 DUP1 PUSH32 0x697353616D654164647265737328616464726573732C61646472657373290000 DUP2 MSTORE POP PUSH1 0x1E ADD SWAP1 POP PUSH1 0x40 MLOAD DUP1 SWAP2 SUB SWAP1 KECCAK256 SWAP1 POP PUSH1 0x0 CALLER SWAP1 POP PUSH1 0x40 MLOAD DUP3 DUP2 MSTORE DUP2 PUSH1 0x4 DUP3 ADD MSTORE DUP2 PUSH1 0x24 DUP3 ADD MSTORE PUSH1 0x64 DUP2 ADD PUSH1 0x40 MSTORE PUSH1 0x20 DUP2 PUSH1 0x44 DUP4 PUSH1 0x0 DUP9 PUSH2 0x1388 CALL POP POP PUSH1 0x40 MLOAD DUP3 DUP2 MSTORE DUP2 PUSH1 0x4 DUP3 ADD MSTORE DUP2 PUSH1 0x24 DUP3 ADD MSTORE PUSH1 0x64 DUP2 ADD PUSH1 0x40 MSTORE PUSH1 0x20 DUP2 PUSH1 0x44 DUP4 PUSH1 0x0 DUP9 PUSH2 0x1388 CALLCODE POP POP PUSH1 0x40 MLOAD DUP3 DUP2 MSTORE DUP2 PUSH1 0x4 DUP3 ADD MSTORE DUP2 PUSH1 0x24 DUP3 ADD MSTORE PUSH1 0x64 DUP2 ADD PUSH1 0x40 MSTORE PUSH1 0x20 DUP2 PUSH1 0x44 DUP4 DUP8 PUSH2 0x1388 DELEGATECALL POP POP PUSH1 0x0 PUSH3 0x420042 SWAP1 POP PUSH32 0x50CB9FE53DAA9737B786AB3646F04D0150DC50EF4E75F59509D83667AD5ADB20 PUSH1 0x1 MUL PUSH1 0x40 MLOAD DUP1 DUP3 DUP2 MSTORE PUSH1 0x20 ADD SWAP2 POP POP PUSH1 0x40 MLOAD DUP1 SWAP2 SUB SWAP1 LOG0 PUSH32 0x50CB9FE53DAA9737B786AB3646F04D0150DC50EF4E75F59509D83667AD5ADB20 PUSH1 0x1 MUL PUSH32 0x50CB9FE53DAA9737B786AB3646F04D0150DC50EF4E75F59509D83667AD5ADB20 PUSH1 0x1 MUL PUSH1 0x40 MLOAD DUP1 DUP3 DUP2 MSTORE PUSH1 0x20 ADD SWAP2 POP POP PUSH1 0x40 MLOAD DUP1 SWAP2 SUB SWAP1 LOG1 CALLER PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF AND PUSH1 0x1 MUL PUSH32 0x50CB9FE53DAA9737B786AB3646F04D0150DC50EF4E75F59509D83667AD5ADB20 PUSH1 0x1 MUL PUSH32 0x50CB9FE53DAA9737B786AB3646F04D0150DC50EF4E75F59509D83667AD5ADB20 PUSH1 0x1 MUL PUSH1 0x40 MLOAD DUP1 DUP3 DUP2 MSTORE PUSH1 0x20 ADD SWAP2 POP POP PUSH1 0x40 MLOAD DUP1 SWAP2 SUB SWAP1 LOG2 DUP1 PUSH1 0x1 MUL CALLER PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF AND PUSH1 0x1 MUL PUSH32 0x50CB9FE53DAA9737B786AB3646F04D0150DC50EF4E75F59509D83667AD5ADB20 PUSH1 0x1 MUL PUSH32 0x50CB9FE53DAA9737B786AB3646F04D0150DC50EF4E75F59509D83667AD5ADB20 PUSH1 0x1 MUL PUSH1 0x40 MLOAD DUP1 DUP3 DUP2 MSTORE PUSH1 0x20 ADD SWAP2 POP POP PUSH1 0x40 MLOAD DUP1 SWAP2 SUB SWAP1 LOG3 DUP1 PUSH1 0x1 MUL DUP2 PUSH1 0x1 MUL CALLER PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF AND PUSH1 0x1 MUL PUSH32 0x50CB9FE53DAA9737B786AB3646F04D0150DC50EF4E75F59509D83667AD5ADB20 PUSH1 0x1 MUL PUSH32 0x50CB9FE53DAA9737B786AB3646F04D0150DC50EF4E75F59509D83667AD5ADB20 PUSH1 0x1 MUL PUSH1 0x40 MLOAD DUP1 DUP3 DUP2 MSTORE PUSH1 0x20 ADD SWAP2 POP POP PUSH1 0x40 MLOAD DUP1 SWAP2 SUB SWAP1 LOG4 PUSH1 0x2 SELFDESTRUCT INVALID LOG1 PUSH6 0x627A7A723058 KECCAK256 0xe MLOAD 0xba LOG2 0xb4 SLOAD 0xb4 PUSH32 0xDF0EF596FA24AFF8ED3A3727B7481EBD25349182CE7152A30029608060405234 DUP1 ISZERO PUSH2 0x10 JUMPI PUSH1 0x0 DUP1 REVERT JUMPDEST POP PUSH2 0x13F DUP1 PUSH2 0x20 PUSH1 0x0 CODECOPY PUSH1 0x0 RETURN INVALID PUSH1 0x80 PUSH1 0x40 MSTORE PUSH1 0x4 CALLDATASIZE LT PUSH2 0x3B JUMPI PUSH1 0x0 CALLDATALOAD PUSH29 0x100000000000000000000000000000000000000000000000000000000 SWAP1 DIV DUP1 PUSH4 0x161E7150 EQ PUSH2 0x40 JUMPI JUMPDEST PUSH1 0x0 DUP1 REVERT JUMPDEST CALLVALUE DUP1 ISZERO PUSH2 0x4C JUMPI PUSH1 0x0 DUP1 REVERT JUMPDEST POP PUSH2 0xAF PUSH1 0x4 DUP1 CALLDATASIZE SUB PUSH1 0x40 DUP2 LT ISZERO PUSH2 0x63 JUMPI PUSH1 0x0 DUP1 REVERT JUMPDEST DUP2 ADD SWAP1 DUP1 DUP1 CALLDATALOAD PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF AND SWAP1 
PUSH1 0x20 ADD SWAP1 SWAP3 SWAP2 SWAP1 DUP1 CALLDATALOAD PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF AND SWAP1 PUSH1 0x20 ADD SWAP1 SWAP3 SWAP2 SWAP1 POP POP POP PUSH2 0xC9 JUMP JUMPDEST PUSH1 0x40 MLOAD DUP1 DUP3 ISZERO ISZERO ISZERO ISZERO DUP2 MSTORE PUSH1 0x20 ADD SWAP2 POP POP PUSH1 0x40 MLOAD DUP1 SWAP2 SUB SWAP1 RETURN JUMPDEST PUSH1 0x0 DUP2 PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF AND DUP4 PUSH20 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF AND EQ ISZERO PUSH2 0x108 JUMPI PUSH1 0x1 SWAP1 POP PUSH2 0x10D JUMP JUMPDEST PUSH1 0x0 SWAP1 POP JUMPDEST SWAP3 SWAP2 POP POP JUMP INVALID LOG1 PUSH6 0x627A7A723058 KECCAK256 CALLDATALOAD DUP16 PUSH8 0xA58C115EA636B0B8 0xe5 0xc4 0xca PUSH27 0x52B8192D0F3FA98A4434D6EA04596B5D0D00290000000000000000 ", + "sourceMap": "221:8828:0:-;;;263:110;8:9:-1;5:2;;;30:1;27;20:12;5:2;263:110:0;324:11;;:::i;:::-;;;;;;;;;;;8:9:-1;5:2;;;45:16;42:1;39;24:38;77:16;74:1;67:27;5:2;324:11:0;316:5;;:19;;;;;;;;;;;;;;;;;;221:8828;;;;;;;;;;;;:::o;:::-;;;;;;;" +} diff --git a/tests/solidity/contracts/Migrations.sol b/tests/solidity/contracts/Migrations.sol new file mode 100644 index 000000000..c378ffb02 --- /dev/null +++ b/tests/solidity/contracts/Migrations.sol @@ -0,0 +1,23 @@ +pragma solidity >=0.4.21 <0.6.0; + +contract Migrations { + address public owner; + uint public last_completed_migration; + + constructor() public { + owner = msg.sender; + } + + modifier restricted() { + if (msg.sender == owner) _; + } + + function setCompleted(uint completed) public restricted { + last_completed_migration = completed; + } + + function upgrade(address new_address) public restricted { + Migrations upgraded = Migrations(new_address); + upgraded.setCompleted(last_completed_migration); + } +} diff --git a/tests/solidity/contracts/OpCodes.sol b/tests/solidity/contracts/OpCodes.sol new file mode 100644 index 000000000..9e3a0ebb0 --- /dev/null +++ b/tests/solidity/contracts/OpCodes.sol @@ -0,0 +1,322 @@ +pragma solidity >=0.4.21 <0.6.0; + +contract Test1 { + function isSameAddress(address a, address b) public returns(bool){ //Simply add the two arguments and return + if (a == b) return true; + return false; + } +} + +contract OpCodes { + + Test1 test1; + + constructor() public { //Constructor function + test1 = new Test1(); //Create new "Test1" function + } + + modifier onlyOwner(address _owner) { + require(msg.sender == _owner); + _; + } + // Add a todo to the list + function test() public { + + //simple_instructions + /*assembly { pop(sub(dup1, mul(dup1, dup1))) }*/ + + //keywords + assembly { pop(address) return(2, byte(2,1)) } + + //label_complex + /*assembly { 7 abc: 8 eq jump(abc) jumpi(eq(7, 8), abc) pop } + assembly { pop(jumpi(eq(7, 8), abc)) jump(abc) }*/ + + //functional + /*assembly { let x := 2 add(7, mul(6, x)) mul(7, 8) add =: x }*/ + + //for_statement + assembly { for { let i := 1 } lt(i, 5) { i := add(i, 1) } {} } + assembly { for { let i := 6 } gt(i, 5) { i := add(i, 1) } {} } + assembly { for { let i := 1 } slt(i, 5) { i := add(i, 1) } {} } + assembly { for { let i := 6 } sgt(i, 5) { i := add(i, 1) } {} } + + //no_opcodes_in_strict + assembly { pop(callvalue()) } + + //no_dup_swap_in_strict + /*assembly { swap1() }*/ + + //print_functional + assembly { let x := mul(sload(0x12), 7) } + + //print_if + assembly { if 2 { pop(mload(0)) }} + + //function_definitions_multiple_args + assembly { function f(a, d){ mstore(a, d) } function g(a, d) -> x, y {}} + + //sstore + assembly { function f(a, d){ sstore(a, d) } function g(a, d) -> x, y {}} + + //mstore8 + 
assembly { function f(a, d){ mstore8(a, d) } function g(a, d) -> x, y {}} + + //calldatacopy + assembly { + let a := mload(0x40) + let b := add(a, 32) + calldatacopy(a, 4, 32) + /*calldatacopy(b, add(4, 32), 32)*/ + /*result := add(mload(a), mload(b))*/ + } + + //codecopy + assembly { + let a := mload(0x40) + let b := add(a, 32) + codecopy(a, 4, 32) + } + + //codecopy + assembly { + let a := mload(0x40) + let b := add(a, 32) + extcodecopy(0, a, 4, 32) + } + + //for_statement + assembly { let x := calldatasize() for { let i := 0} lt(i, x) { i := add(i, 1) } { mstore(i, 2) } } + + //keccak256 + assembly { pop(keccak256(0,0)) } + + //returndatasize + assembly { let r := returndatasize } + + //returndatacopy + assembly { returndatacopy(64, 32, 0) } + //byzantium vs const Constantinople + //staticcall + assembly { pop(staticcall(10000, 0x123, 64, 0x10, 128, 0x10)) } + + /*//create2 Constantinople + assembly { pop(create2(10, 0x123, 32, 64)) }*/ + + //create Constantinople + assembly { pop(create(10, 0x123, 32)) } + + //shift Constantinople + /*assembly { pop(shl(10, 32)) } + assembly { pop(shr(10, 32)) } + assembly { pop(sar(10, 32)) }*/ + + + //not + assembly { pop( not(0x1f)) } + + //exp + assembly { pop( exp(2, 226)) } + + //mod + assembly { pop( mod(3, 9)) } + + //smod + assembly { pop( smod(3, 9)) } + + //div + assembly { pop( div(4, 2)) } + + //sdiv + assembly { pop( sdiv(4, 2)) } + + //iszero + assembly { pop(iszero(1)) } + + //and + assembly { pop(and(2,3)) } + + //or + assembly { pop(or(3,3)) } + + //xor + assembly { pop(xor(3,3)) } + + //addmod + assembly { pop(addmod(3,3,6)) } + + //mulmod + assembly { pop(mulmod(3,3,3)) } + + //signextend + assembly { pop(signextend(1, 10)) } + + //sha3 + assembly { pop(calldataload(0)) } + + //blockhash + assembly { pop(blockhash(sub(number(), 1))) } + + //balance + assembly { pop(balance(0x0)) } + + //caller + assembly { pop(caller()) } + + //codesize + assembly { pop(codesize()) } + + //extcodesize + assembly { pop(extcodesize(0x1)) } + + //origin + assembly { pop(origin()) } + + //gas + assembly { pop(gas())} + + //msize + assembly { pop(msize())} + + //pc + assembly { pop(pc())} + + //gasprice + assembly { pop(gasprice())} + + //coinbase + assembly { pop(coinbase())} + + //timestamp + assembly { pop(timestamp())} + + //number + assembly { pop(number())} + + //difficulty + assembly { pop(difficulty())} + + //gaslimit + assembly { pop(gaslimit())} + + //call + address contractAddr = address(test1); + bytes4 sig = bytes4(keccak256("isSameAddress(address,address)")); //Function signature + address a = msg.sender; + + assembly { + let x := mload(0x40) //Find empty storage location using "free memory pointer" + mstore(x,sig) //Place signature at begining of empty storage + mstore(add(x,0x04),a) // first address parameter. just after signature + mstore(add(x,0x24),a) // 2nd address parameter - first padded. add 32 bytes (not 20 bytes) + mstore(0x40,add(x,0x64)) // this is missing in other examples. Set free pointer before function call. so it is used by called function. + // new free pointer position after the output values of the called function. 
+ + let success := call( + 5000, //5k gas + contractAddr, //To addr + 0, //No wei passed + x, // Inputs are at location x + 0x44, //Input size: a 4-byte selector plus two padded words, so 68 bytes + x, //Store output over input + 0x20) //Output is 32 bytes long + } + + //callcode + assembly { + let x := mload(0x40) //Find an empty memory slot using the "free memory pointer" + mstore(x,sig) //Place the signature at the beginning of that memory + mstore(add(x,0x04),a) // first address parameter, just after the signature + mstore(add(x,0x24),a) // 2nd address parameter - the first is padded to 32 bytes, so advance 32 bytes (not 20) + mstore(0x40,add(x,0x64)) // this is missing in other examples: set the free memory pointer before the call so the called function can use it + // the new free pointer position is after the output values of the called function. + + let success := callcode( + 5000, //5k gas + contractAddr, //To addr + 0, //No wei passed + x, // Inputs are at location x + 0x44, //Input size: a 4-byte selector plus two padded words, so 68 bytes + x, //Store output over input + 0x20) //Output is 32 bytes long + } + + //delegatecall + assembly { + let x := mload(0x40) //Find an empty memory slot using the "free memory pointer" + mstore(x,sig) //Place the signature at the beginning of that memory + mstore(add(x,0x04),a) // first address parameter, just after the signature + mstore(add(x,0x24),a) // 2nd address parameter - the first is padded to 32 bytes, so advance 32 bytes (not 20) + mstore(0x40,add(x,0x64)) // this is missing in other examples: set the free memory pointer before the call so the called function can use it + // the new free pointer position is after the output values of the called function. + + let success := delegatecall( + 5000, //5k gas + contractAddr, //To addr + x, // Inputs are at location x + 0x44, //Input size: a 4-byte selector plus two padded words, so 68 bytes + x, //Store output over input + 0x20) //Output is 32 bytes long + } + + uint256 _id = 0x420042; + + //log0 + log0( + bytes32(0x50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb20) + ); + + //log1 + log1( + bytes32(0x50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb20), + bytes32(0x50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb20) + ); + + //log2 + log2( + bytes32(0x50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb20), + bytes32(0x50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb20), + bytes32(uint256(msg.sender)) + ); + + //log3 + log3( + bytes32(0x50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb20), + bytes32(0x50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb20), + bytes32(uint256(msg.sender)), + bytes32(_id) + ); + + //log4 + log4( + bytes32(0x50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb20), + bytes32(0x50cb9fe53daa9737b786ab3646f04d0150dc50ef4e75f59509d83667ad5adb20), + bytes32(uint256(msg.sender)), + bytes32(_id), + bytes32(_id) + + ); + + //selfdestruct + assembly { selfdestruct(0x02) } + } + + function test_revert() public { + + //revert + assembly{ revert(0, 0) } + } + + function test_invalid() public { + + //invalid + assembly{ invalid() } + } + + function test_stop() public { + + //stop + assembly{ stop() } + } + +} diff --git a/tests/solidity/migrations/1_initial_migration.js b/tests/solidity/migrations/1_initial_migration.js new file mode 100644 index 000000000..ee2135d29 --- /dev/null +++ b/tests/solidity/migrations/1_initial_migration.js @@ -0,0 +1,5 @@ +const Migrations = artifacts.require("Migrations"); + +module.exports = function(deployer) { + deployer.deploy(Migrations); +}; diff --git a/tests/solidity/migrations/2_opCodes_migration.js
b/tests/solidity/migrations/2_opCodes_migration.js new file mode 100644 index 000000000..65c6b6dc1 --- /dev/null +++ b/tests/solidity/migrations/2_opCodes_migration.js @@ -0,0 +1,5 @@ +var OpCodes = artifacts.require("./OpCodes.sol"); + +module.exports = function(deployer) { + deployer.deploy(OpCodes); +}; diff --git a/tests/solidity/test/opCodes.js b/tests/solidity/test/opCodes.js new file mode 100644 index 000000000..80abacef2 --- /dev/null +++ b/tests/solidity/test/opCodes.js @@ -0,0 +1,34 @@ +const OpCodes = artifacts.require('./OpCodes.sol') +const assert = require('assert') +let contractInstance +const Web3 = require('web3'); +const web3 = new Web3(new Web3.providers.HttpProvider('http://localhost:8545')); +// const web3 = new Web3(new Web3.providers.HttpProvider('http://localhost:9545')); + +contract('OpCodes', (accounts) => { + beforeEach(async () => { + contractInstance = await OpCodes.deployed() + }) + it('Should run the majority of opcodes without errors', async () => { + await contractInstance.test() + await contractInstance.test_stop() + }) + + it('Should throw invalid op code', async () => { + try { + await contractInstance.test_invalid() + } + catch(error) { + console.error(error); + } + }) + + it('Should revert', async () => { + try { + await contractInstance.test_revert() + } + catch(error) { + console.error(error); + } + }) +}) diff --git a/tests/solidity/truffle-config.js b/tests/solidity/truffle-config.js new file mode 100644 index 000000000..c06d8316f --- /dev/null +++ b/tests/solidity/truffle-config.js @@ -0,0 +1,108 @@ +/** + * Use this file to configure your truffle project. It's seeded with some + * common settings for different networks and features like migrations, + * compilation and testing. Uncomment the ones you need or modify + * them to suit your project as necessary. + * + * More information about configuration can be found at: + * + * truffleframework.com/docs/advanced/configuration + * + * To deploy via Infura you'll need a wallet provider (like truffle-hdwallet-provider) + * to sign your transactions before they're sent to a remote public node. Infura API + * keys are available for free at: infura.io/register + * + * You'll also need a mnemonic - the twelve word phrase the wallet uses to generate + * public/private key pairs. If you're publishing your code to GitHub make sure you load this + * phrase from a file you've .gitignored so it doesn't accidentally become public. + * + */ + +// const HDWalletProvider = require('truffle-hdwallet-provider'); +// const infuraKey = "fj4jll3k....."; +// +// const fs = require('fs'); +// const mnemonic = fs.readFileSync(".secret").toString().trim(); + +// module.exports = { +// /** +// * Networks define how you connect to your ethereum client and let you set the +// * defaults web3 uses to send transactions. If you don't specify one truffle +// * will spin up a development blockchain for you on port 9545 when you +// * run `develop` or `test`. You can ask a truffle command to use a specific +// * network from the command line, e.g +// * +// * $ truffle test --network +// */ +// +// networks: { +// // Useful for testing. The `development` name is special - truffle uses it by default +// // if it's defined here and no other network is specified at the command line. +// // You should run a client (like ganache-cli, geth or parity) in a separate terminal +// // tab if you use this network and you must also set the `host`, `port` and `network_id` +// // options below to some value.
+// // +// // development: { +// // host: "127.0.0.1", // Localhost (default: none) +// // port: 8545, // Standard Ethereum port (default: none) +// // network_id: "*", // Any network (default: none) +// // }, +// +// // Another network with more advanced options... +// // advanced: { +// // port: 8777, // Custom port +// // network_id: 1342, // Custom network +// // gas: 8500000, // Gas sent with each transaction (default: ~6700000) +// // gasPrice: 20000000000, // 20 gwei (in wei) (default: 100 gwei) +// // from:
, // Account to send txs from (default: accounts[0]) +// // websockets: true // Enable EventEmitter interface for web3 (default: false) +// // }, +// +// // Useful for deploying to a public network. +// // NB: It's important to wrap the provider as a function. +// // ropsten: { +// // provider: () => new HDWalletProvider(mnemonic, `https://ropsten.infura.io/${infuraKey}`), +// // network_id: 3, // Ropsten's id +// // gas: 5500000, // Ropsten has a lower block limit than mainnet +// // confirmations: 2, // # of confs to wait between deployments. (default: 0) +// // timeoutBlocks: 200, // # of blocks before a deployment times out (minimum/default: 50) +// // skipDryRun: true // Skip dry run before migrations? (default: false for public nets ) +// // }, +// +// // Useful for private networks +// // private: { +// // provider: () => new HDWalletProvider(mnemonic, `https://network.io`), +// // network_id: 2111, // This network is yours, in the cloud. +// // production: true // Treats this network as if it was a public net. (default: false) +// // } +// }, +// +// // Set default mocha options here, use special reporters etc. +// mocha: { +// // timeout: 100000 +// }, +// +// // Configure your compilers +// compilers: { +// solc: { +// // version: "0.5.1", // Fetch exact version from solc-bin (default: truffle's version) +// // docker: true, // Use "0.5.1" you've installed locally with docker (default: false) +// // settings: { // See the solidity docs for advice about optimization and evmVersion +// // optimizer: { +// // enabled: false, +// // runs: 200 +// // }, +// // evmVersion: "byzantium" +// // } +// } +// } +// } +module.exports = { + networks: { + development: { + host: 'localhost', + port: 8545, + network_id: '*' + } + } +} diff --git a/tests/state_test.go b/tests/state_test.go index 0cf124d72..c6a6947bc 100644 --- a/tests/state_test.go +++ b/tests/state_test.go @@ -19,12 +19,10 @@ package tests import ( "bufio" "bytes" - "flag" "fmt" "reflect" "testing" - "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/core/vm" ) @@ -42,9 +40,11 @@ func TestState(t *testing.T) { st.slow(`^stStaticCall/static_Call1MB`) st.slow(`^stSystemOperationsTest/CallRecursiveBomb`) st.slow(`^stTransactionTest/Opcodes_TransactionInit`) + + // Very time consuming + st.skipLoad(`^stTimeConsuming/`) + // Broken tests: - st.skipLoad(`^stTransactionTest/OverflowGasRequire\.json`) // gasLimit > 256 bits - st.skipLoad(`^stTransactionTest/zeroSigTransa[^/]*\.json`) // EIP-86 is not supported yet // Expected failures: //st.fails(`^stRevertTest/RevertPrecompiledTouch(_storage)?\.json/Byzantium/0`, "bug in test") //st.fails(`^stRevertTest/RevertPrecompiledTouch(_storage)?\.json/Byzantium/3`, "bug in test") @@ -66,25 +66,34 @@ func TestState(t *testing.T) { }) } }) + // For Istanbul, older tests were moved into LegacyTests + st.walk(t, legacyStateTestDir, func(t *testing.T, name string, test *StateTest) { + for _, subtest := range test.Subtests() { + subtest := subtest + key := fmt.Sprintf("%s/%d", subtest.Fork, subtest.Index) + name := name + "/" + key + t.Run(key, func(t *testing.T) { + withTrace(t, test.gasLimit(subtest), func(vmconfig vm.Config) error { + _, err := test.Run(subtest, vmconfig) + return st.checkFailure(t, name, err) + }) + }) + } + }) } // Transactions with gasLimit above this value will not get a VM trace on failure. const traceErrorLimit = 400000 -// The VM config for state tests that accepts --vm.* command line arguments. 
-var testVMConfig = func() vm.Config { - vmconfig := vm.Config{} - flag.StringVar(&vmconfig.EVMInterpreter, utils.EVMInterpreterFlag.Name, utils.EVMInterpreterFlag.Value, utils.EVMInterpreterFlag.Usage) - flag.StringVar(&vmconfig.EWASMInterpreter, utils.EWASMInterpreterFlag.Name, utils.EWASMInterpreterFlag.Value, utils.EWASMInterpreterFlag.Usage) - flag.Parse() - return vmconfig -}() - func withTrace(t *testing.T, gasLimit uint64, test func(vm.Config) error) { - err := test(testVMConfig) + // Use config from command line arguments. + config := vm.Config{EVMInterpreter: *testEVM, EWASMInterpreter: *testEWASM} + err := test(config) if err == nil { return } + + // Test failed, re-run with tracing enabled. t.Error(err) if gasLimit > traceErrorLimit { t.Log("gas limit too high for EVM trace") @@ -93,7 +102,8 @@ func withTrace(t *testing.T, gasLimit uint64, test func(vm.Config) error) { buf := new(bytes.Buffer) w := bufio.NewWriter(buf) tracer := vm.NewJSONLogger(&vm.LogConfig{DisableMemory: true}, w) - err2 := test(vm.Config{Debug: true, Tracer: tracer}) + config.Debug, config.Tracer = true, tracer + err2 := test(config) if !reflect.DeepEqual(err, err2) { t.Errorf("different error for second run: %v", err2) } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index c6341e524..59ebcb6e1 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -144,11 +144,29 @@ func (t *StateTest) Subtests() []StateSubtest { return sub } -// Run executes a specific subtest. +// Run executes a specific subtest and verifies the post-state and logs func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateDB, error) { + statedb, root, err := t.RunNoVerify(subtest, vmconfig) + if err != nil { + return statedb, err + } + post := t.json.Post[subtest.Fork][subtest.Index] + // N.B: We need to do this in a two-step process, because the first Commit takes care + // of suicides, and we need to touch the coinbase _after_ it has potentially suicided. 
+ if root != common.Hash(post.Root) { + return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root) + } + if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { + return statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs) + } + return statedb, nil +} + +// RunNoVerify runs a specific subtest and returns the statedb and post-state root +func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config) (*state.StateDB, common.Hash, error) { config, eips, err := getVMConfig(subtest.Fork) if err != nil { - return nil, UnsupportedForkError{subtest.Fork} + return nil, common.Hash{}, UnsupportedForkError{subtest.Fork} } vmconfig.ExtraEips = eips block := t.genesis(config).ToBlock(nil) @@ -157,7 +175,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD post := t.json.Post[subtest.Fork][subtest.Index] msg, err := t.json.Tx.toMessage(post) if err != nil { - return nil, err + return nil, common.Hash{}, err } context := core.NewEVMContext(msg, block.Header(), nil, &t.json.Env.Coinbase) context.GetHash = vmTestBlockHash @@ -179,15 +197,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD statedb.AddBalance(block.Coinbase(), new(big.Int)) // And _now_ get the state root root := statedb.IntermediateRoot(config.IsEIP158(block.Number())) - // N.B: We need to do this in a two-step process, because the first Commit takes care - // of suicides, and we need to touch the coinbase _after_ it has potentially suicided. - if root != common.Hash(post.Root) { - return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root) - } - if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { - return statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs) - } - return statedb, nil + return statedb, root, nil } func (t *StateTest) gasLimit(subtest StateSubtest) uint64 { diff --git a/tests/testdata b/tests/testdata index 553c0ea76..b5eb9900e 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 553c0ea76c739dbb97a8af9fb81c51510bf7493d +Subproject commit b5eb9900ee2147b40d3e681fe86efa4fd693959a diff --git a/trie/sync.go b/trie/sync.go index 6f40b45a1..e5a0c1749 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -57,14 +57,12 @@ type SyncResult struct { // persisted data items. type syncMemBatch struct { batch map[common.Hash][]byte // In-memory membatch of recently completed items - order []common.Hash // Order of completion to prevent out-of-order data loss } // newSyncMemBatch allocates a new memory-buffer for not-yet persisted trie nodes. func newSyncMemBatch() *syncMemBatch { return &syncMemBatch{ batch: make(map[common.Hash][]byte), - order: make([]common.Hash, 0, 256), } } @@ -223,20 +221,18 @@ func (s *Sync) Process(results []SyncResult) (bool, int, error) { } // Commit flushes the data stored in the internal membatch out to persistent -// storage, returning the number of items written and any occurred error. -func (s *Sync) Commit(dbw ethdb.KeyValueWriter) (int, error) { +// storage, returning any occurred error. 
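+//
+// A minimal usage sketch of the new batch-based flush (editor's illustration;
+// it mirrors the updated tests further below and assumes a database that can
+// produce an ethdb.Batch):
+//
+//	batch := diskdb.NewBatch()
+//	if err := sched.Commit(batch); err != nil {
+//		// handle the flush error
+//	}
+//	batch.Write()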
+func (s *Sync) Commit(dbw ethdb.Batch) error { // Dump the membatch into a database dbw - for i, key := range s.membatch.order { - if err := dbw.Put(key[:], s.membatch.batch[key]); err != nil { - return i, err + for key, value := range s.membatch.batch { + if err := dbw.Put(key[:], value); err != nil { + return err } s.bloom.Add(key[:]) } - written := len(s.membatch.order) // TODO(karalabe): could an order change improve write performance? - // Drop the membatch data and return s.membatch = newSyncMemBatch() - return written, nil + return nil } // Pending returns the number of state entries currently pending for download. @@ -330,7 +326,6 @@ func (s *Sync) children(req *request, object node) ([]*request, error) { func (s *Sync) commit(req *request) (err error) { // Write the node content to the membatch s.membatch.batch[req.hash] = req.data - s.membatch.order = append(s.membatch.order, req.hash) delete(s.requests, req.hash) diff --git a/trie/sync_test.go b/trie/sync_test.go index 0621bb435..6025b87fc 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -105,7 +105,7 @@ func TestEmptySync(t *testing.T) { func TestIterativeSyncIndividual(t *testing.T) { testIterativeSync(t, 1) } func TestIterativeSyncBatched(t *testing.T) { testIterativeSync(t, 100) } -func testIterativeSync(t *testing.T, batch int) { +func testIterativeSync(t *testing.T, count int) { // Create a random trie to copy srcDb, srcTrie, srcData := makeTestTrie() @@ -114,7 +114,7 @@ func testIterativeSync(t *testing.T, batch int) { triedb := NewDatabase(diskdb) sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) - queue := append([]common.Hash{}, sched.Missing(batch)...) + queue := append([]common.Hash{}, sched.Missing(count)...) for len(queue) > 0 { results := make([]SyncResult, len(queue)) for i, hash := range queue { @@ -127,10 +127,12 @@ func testIterativeSync(t *testing.T, batch int) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(diskdb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := diskdb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } - queue = append(queue[:0], sched.Missing(batch)...) + batch.Write() + queue = append(queue[:0], sched.Missing(count)...) } // Cross check that the two tries are in sync checkTrieContents(t, triedb, srcTrie.Hash().Bytes(), srcData) @@ -161,9 +163,11 @@ func TestIterativeDelayedSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(diskdb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := diskdb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } + batch.Write() queue = append(queue[len(results):], sched.Missing(10000)...) 
} // Cross check that the two tries are in sync @@ -176,7 +180,7 @@ func TestIterativeDelayedSync(t *testing.T) { func TestIterativeRandomSyncIndividual(t *testing.T) { testIterativeRandomSync(t, 1) } func TestIterativeRandomSyncBatched(t *testing.T) { testIterativeRandomSync(t, 100) } -func testIterativeRandomSync(t *testing.T, batch int) { +func testIterativeRandomSync(t *testing.T, count int) { // Create a random trie to copy srcDb, srcTrie, srcData := makeTestTrie() @@ -186,7 +190,7 @@ func testIterativeRandomSync(t *testing.T, batch int) { sched := NewSync(srcTrie.Hash(), diskdb, nil, NewSyncBloom(1, diskdb)) queue := make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(batch) { + for _, hash := range sched.Missing(count) { queue[hash] = struct{}{} } for len(queue) > 0 { @@ -203,11 +207,13 @@ func testIterativeRandomSync(t *testing.T, batch int) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(diskdb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := diskdb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } + batch.Write() queue = make(map[common.Hash]struct{}) - for _, hash := range sched.Missing(batch) { + for _, hash := range sched.Missing(count) { queue[hash] = struct{}{} } } @@ -248,9 +254,11 @@ func TestIterativeRandomDelayedSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(diskdb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := diskdb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } + batch.Write() for _, result := range results { delete(queue, result.Hash) } @@ -293,9 +301,11 @@ func TestDuplicateAvoidanceSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(diskdb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := diskdb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } + batch.Write() queue = append(queue[:0], sched.Missing(0)...) 
} // Cross check that the two tries are in sync @@ -329,9 +339,11 @@ func TestIncompleteSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(diskdb); err != nil { - t.Fatalf("failed to commit data #%d: %v", index, err) + batch := diskdb.NewBatch() + if err := sched.Commit(batch); err != nil { + t.Fatalf("failed to commit data: %v", err) } + batch.Write() for _, result := range results { added = append(added, result.Hash) } diff --git a/vendor/github.com/cloudflare/cloudflare-go/CODE_OF_CONDUCT.md b/vendor/github.com/cloudflare/cloudflare-go/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..bfbc69d22 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/CODE_OF_CONDUCT.md @@ -0,0 +1,77 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at ggalow@cloudflare.com. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. 
The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq + diff --git a/vendor/github.com/cloudflare/cloudflare-go/LICENSE b/vendor/github.com/cloudflare/cloudflare-go/LICENSE new file mode 100644 index 000000000..33865c30f --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/LICENSE @@ -0,0 +1,26 @@ +Copyright (c) 2015-2019, Cloudflare. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors +may be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/cloudflare/cloudflare-go/README.md b/vendor/github.com/cloudflare/cloudflare-go/README.md new file mode 100644 index 000000000..8f5d77b74 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/README.md @@ -0,0 +1,107 @@ +# cloudflare-go + +[![GoDoc](https://img.shields.io/badge/godoc-reference-5673AF.svg?style=flat-square)](https://godoc.org/github.com/cloudflare/cloudflare-go) +[![Build Status](https://img.shields.io/travis/cloudflare/cloudflare-go/master.svg?style=flat-square)](https://travis-ci.org/cloudflare/cloudflare-go) +[![Go Report Card](https://goreportcard.com/badge/github.com/cloudflare/cloudflare-go?style=flat-square)](https://goreportcard.com/report/github.com/cloudflare/cloudflare-go) + +> **Note**: This library is under active development as we expand it to cover +> our (expanding!) API. Consider the public API of this package a little +> unstable as we work towards a v1.0. + +A Go library for interacting with +[Cloudflare's API v4](https://api.cloudflare.com/). 
This library allows you to: + +* Manage and automate changes to your DNS records within Cloudflare +* Manage and automate changes to your zones (domains) on Cloudflare, including + adding new zones to your account +* List and modify the status of WAF (Web Application Firewall) rules for your + zones +* Fetch Cloudflare's IP ranges for automating your firewall whitelisting + +A command-line client, [flarectl](cmd/flarectl), is also available as part of +this project. + +## Features + +The current feature list includes: + +* [x] Cache purging +* [x] Cloudflare IPs +* [x] Custom hostnames +* [x] DNS Records +* [x] Firewall (partial) +* [ ] [Keyless SSL](https://blog.cloudflare.com/keyless-ssl-the-nitty-gritty-technical-details/) +* [x] [Load Balancing](https://blog.cloudflare.com/introducing-load-balancing-intelligent-failover-with-cloudflare/) +* [x] [Logpush Jobs](https://developers.cloudflare.com/logs/logpush/) +* [ ] Organization Administration +* [x] [Origin CA](https://blog.cloudflare.com/universal-ssl-encryption-all-the-way-to-the-origin-for-free/) +* [x] [Railgun](https://www.cloudflare.com/railgun/) administration +* [x] Rate Limiting +* [x] User Administration (partial) +* [x] Virtual DNS Management +* [x] Web Application Firewall (WAF) +* [x] Zone Lockdown and User-Agent Block rules +* [x] Zones + +Pull Requests are welcome, but please open an issue (or comment in an existing +issue) to discuss any non-trivial changes before submitting code. + +## Installation + +You need a working Go environment. + +``` +go get github.com/cloudflare/cloudflare-go +``` + +## Getting Started + +```go +package main + +import ( + "fmt" + "log" + "os" + + "github.com/cloudflare/cloudflare-go" +) + +func main() { + // Construct a new API object + api, err := cloudflare.New(os.Getenv("CF_API_KEY"), os.Getenv("CF_API_EMAIL")) + if err != nil { + log.Fatal(err) + } + + // Fetch user details on the account + u, err := api.UserDetails() + if err != nil { + log.Fatal(err) + } + // Print user details + fmt.Println(u) + + // Fetch the zone ID + id, err := api.ZoneIDByName("example.com") // Assuming example.com exists in your Cloudflare account already + if err != nil { + log.Fatal(err) + } + + // Fetch zone details + zone, err := api.ZoneDetails(id) + if err != nil { + log.Fatal(err) + } + // Print zone details + fmt.Println(zone) +} +``` + +Also refer to the +[API documentation](https://godoc.org/github.com/cloudflare/cloudflare-go) for +how to use this package in-depth. + +# License + +BSD licensed. See the [LICENSE](LICENSE) file for details. diff --git a/vendor/github.com/cloudflare/cloudflare-go/access_application.go b/vendor/github.com/cloudflare/cloudflare-go/access_application.go new file mode 100644 index 000000000..0893c5681 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/access_application.go @@ -0,0 +1,180 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + "net/url" + "strconv" + "time" + + "github.com/pkg/errors" +) + +// AccessApplication represents an Access application. +type AccessApplication struct { + ID string `json:"id,omitempty"` + CreatedAt *time.Time `json:"created_at,omitempty"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` + AUD string `json:"aud,omitempty"` + Name string `json:"name"` + Domain string `json:"domain"` + SessionDuration string `json:"session_duration,omitempty"` +} + +// AccessApplicationListResponse represents the response from the list +// access applications endpoint. 
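+//
+// An illustrative paged listing via AccessApplications (editor's sketch; the
+// zone ID and page sizes are placeholders):
+//
+//	apps, _, err := api.AccessApplications("zone-id",
+//		cloudflare.PaginationOptions{PerPage: 25, Page: 1})
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	for _, app := range apps {
+//		fmt.Println(app.Name, app.Domain)
+//	}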
+type AccessApplicationListResponse struct { + Result []AccessApplication `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// AccessApplicationDetailResponse is the API response, containing a single +// access application. +type AccessApplicationDetailResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result AccessApplication `json:"result"` +} + +// AccessApplications returns all applications within a zone. +// +// API reference: https://api.cloudflare.com/#access-applications-list-access-applications +func (api *API) AccessApplications(zoneID string, pageOpts PaginationOptions) ([]AccessApplication, ResultInfo, error) { + v := url.Values{} + if pageOpts.PerPage > 0 { + v.Set("per_page", strconv.Itoa(pageOpts.PerPage)) + } + if pageOpts.Page > 0 { + v.Set("page", strconv.Itoa(pageOpts.Page)) + } + + uri := "/zones/" + zoneID + "/access/apps" + if len(v) > 0 { + uri = uri + "?" + v.Encode() + } + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []AccessApplication{}, ResultInfo{}, errors.Wrap(err, errMakeRequestError) + } + + var accessApplicationListResponse AccessApplicationListResponse + err = json.Unmarshal(res, &accessApplicationListResponse) + if err != nil { + return []AccessApplication{}, ResultInfo{}, errors.Wrap(err, errUnmarshalError) + } + + return accessApplicationListResponse.Result, accessApplicationListResponse.ResultInfo, nil +} + +// AccessApplication returns a single application based on the +// application ID. +// +// API reference: https://api.cloudflare.com/#access-applications-access-applications-details +func (api *API) AccessApplication(zoneID, applicationID string) (AccessApplication, error) { + uri := fmt.Sprintf( + "/zones/%s/access/apps/%s", + zoneID, + applicationID, + ) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return AccessApplication{}, errors.Wrap(err, errMakeRequestError) + } + + var accessApplicationDetailResponse AccessApplicationDetailResponse + err = json.Unmarshal(res, &accessApplicationDetailResponse) + if err != nil { + return AccessApplication{}, errors.Wrap(err, errUnmarshalError) + } + + return accessApplicationDetailResponse.Result, nil +} + +// CreateAccessApplication creates a new access application. +// +// API reference: https://api.cloudflare.com/#access-applications-create-access-application +func (api *API) CreateAccessApplication(zoneID string, accessApplication AccessApplication) (AccessApplication, error) { + uri := "/zones/" + zoneID + "/access/apps" + + res, err := api.makeRequest("POST", uri, accessApplication) + if err != nil { + return AccessApplication{}, errors.Wrap(err, errMakeRequestError) + } + + var accessApplicationDetailResponse AccessApplicationDetailResponse + err = json.Unmarshal(res, &accessApplicationDetailResponse) + if err != nil { + return AccessApplication{}, errors.Wrap(err, errUnmarshalError) + } + + return accessApplicationDetailResponse.Result, nil +} + +// UpdateAccessApplication updates an existing access application. 
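+//
+// An illustrative update (editor's sketch; the IDs and field values are
+// placeholders, and the ID must be non-empty or the call fails):
+//
+//	app, err := api.UpdateAccessApplication("zone-id", cloudflare.AccessApplication{
+//		ID:              "existing-application-id",
+//		Name:            "Admin Panel",
+//		Domain:          "admin.example.com",
+//		SessionDuration: "24h",
+//	})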
+// +// API reference: https://api.cloudflare.com/#access-applications-update-access-application +func (api *API) UpdateAccessApplication(zoneID string, accessApplication AccessApplication) (AccessApplication, error) { + if accessApplication.ID == "" { + return AccessApplication{}, errors.Errorf("access application ID cannot be empty") + } + + uri := fmt.Sprintf( + "/zones/%s/access/apps/%s", + zoneID, + accessApplication.ID, + ) + + res, err := api.makeRequest("PUT", uri, accessApplication) + if err != nil { + return AccessApplication{}, errors.Wrap(err, errMakeRequestError) + } + + var accessApplicationDetailResponse AccessApplicationDetailResponse + err = json.Unmarshal(res, &accessApplicationDetailResponse) + if err != nil { + return AccessApplication{}, errors.Wrap(err, errUnmarshalError) + } + + return accessApplicationDetailResponse.Result, nil +} + +// DeleteAccessApplication deletes an access application. +// +// API reference: https://api.cloudflare.com/#access-applications-delete-access-application +func (api *API) DeleteAccessApplication(zoneID, applicationID string) error { + uri := fmt.Sprintf( + "/zones/%s/access/apps/%s", + zoneID, + applicationID, + ) + + _, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + + return nil +} + +// RevokeAccessApplicationTokens revokes tokens associated with an +// access application. +// +// API reference: https://api.cloudflare.com/#access-applications-revoke-access-tokens +func (api *API) RevokeAccessApplicationTokens(zoneID, applicationID string) error { + uri := fmt.Sprintf( + "/zones/%s/access/apps/%s/revoke-tokens", + zoneID, + applicationID, + ) + + _, err := api.makeRequest("POST", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + + return nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/access_identity_provider.go b/vendor/github.com/cloudflare/cloudflare-go/access_identity_provider.go new file mode 100644 index 000000000..b41ed8ff0 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/access_identity_provider.go @@ -0,0 +1,331 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + + "github.com/pkg/errors" +) + +// AccessIdentityProvider is the structure of the provider object. +type AccessIdentityProvider struct { + ID string `json:"id,omitempty"` + Name string `json:"name"` + Type string `json:"type"` + Config interface{} `json:"config"` +} + +// AccessAzureADConfiguration is the representation of the Azure AD identity +// provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/azuread/ +type AccessAzureADConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + DirectoryID string `json:"directory_id"` + SupportGroups bool `json:"support_groups"` +} + +// AccessCentrifyConfiguration is the representation of the Centrify identity +// provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/centrify/ +type AccessCentrifyConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + CentrifyAccount string `json:"centrify_account"` + CentrifyAppID string `json:"centrify_app_id"` +} + +// AccessCentrifySAMLConfiguration is the representation of the Centrify +// identity provider using SAML.
+// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/saml-centrify/ +type AccessCentrifySAMLConfiguration struct { + IssuerURL string `json:"issuer_url"` + SsoTargetURL string `json:"sso_target_url"` + Attributes []string `json:"attributes"` + EmailAttributeName string `json:"email_attribute_name"` + SignRequest bool `json:"sign_request"` + IdpPublicCert string `json:"idp_public_cert"` +} + +// AccessFacebookConfiguration is the representation of the Facebook identity +// provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/facebook-login/ +type AccessFacebookConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` +} + +// AccessGSuiteConfiguration is the representation of the GSuite identity +// provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/gsuite/ +type AccessGSuiteConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + AppsDomain string `json:"apps_domain"` +} + +// AccessGenericOIDCConfiguration is the representation of the generic OpenID +// Connect (OIDC) connector. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/generic-oidc/ +type AccessGenericOIDCConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + AuthURL string `json:"auth_url"` + TokenURL string `json:"token_url"` + CertsURL string `json:"certs_url"` +} + +// AccessGitHubConfiguration is the representation of the GitHub identity +// provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/github/ +type AccessGitHubConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` +} + +// AccessGoogleConfiguration is the representation of the Google identity +// provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/google/ +type AccessGoogleConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` +} + +// AccessJumpCloudSAMLConfiguration is the representation of the Jump Cloud +// identity provider using SAML. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/jumpcloud-saml/ +type AccessJumpCloudSAMLConfiguration struct { + IssuerURL string `json:"issuer_url"` + SsoTargetURL string `json:"sso_target_url"` + Attributes []string `json:"attributes"` + EmailAttributeName string `json:"email_attribute_name"` + SignRequest bool `json:"sign_request"` + IdpPublicCert string `json:"idp_public_cert"` +} + +// AccessOktaConfiguration is the representation of the Okta identity provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/okta/ +type AccessOktaConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + OktaAccount string `json:"okta_account"` +} + +// AccessOktaSAMLConfiguration is the representation of the Okta identity +// provider using SAML. 
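+//
+// An illustrative configuration (editor's sketch; every value is a
+// placeholder):
+//
+//	cfg := cloudflare.AccessOktaSAMLConfiguration{
+//		IssuerURL:          "https://example.okta.com/app/abc123",
+//		SsoTargetURL:       "https://example.okta.com/app/abc123/sso/saml",
+//		Attributes:         []string{"email"},
+//		EmailAttributeName: "email",
+//		SignRequest:        false,
+//		IdpPublicCert:      "MIIC...",
+//	}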
+// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/saml-okta/ +type AccessOktaSAMLConfiguration struct { + IssuerURL string `json:"issuer_url"` + SsoTargetURL string `json:"sso_target_url"` + Attributes []string `json:"attributes"` + EmailAttributeName string `json:"email_attribute_name"` + SignRequest bool `json:"sign_request"` + IdpPublicCert string `json:"idp_public_cert"` +} + +// AccessOneTimePinConfiguration is the representation of the default One Time +// Pin identity provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/one-time-pin/ +type AccessOneTimePinConfiguration struct{} + +// AccessOneLoginOIDCConfiguration is the representation of the OneLogin +// OpenID connector as an identity provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/onelogin-oidc/ +type AccessOneLoginOIDCConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + OneloginAccount string `json:"onelogin_account"` +} + +// AccessOneLoginSAMLConfiguration is the representation of the OneLogin +// identity provider using SAML. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/onelogin-saml/ +type AccessOneLoginSAMLConfiguration struct { + IssuerURL string `json:"issuer_url"` + SsoTargetURL string `json:"sso_target_url"` + Attributes []string `json:"attributes"` + EmailAttributeName string `json:"email_attribute_name"` + SignRequest bool `json:"sign_request"` + IdpPublicCert string `json:"idp_public_cert"` +} + +// AccessPingSAMLConfiguration is the representation of the Ping identity +// provider using SAML. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/ping-saml/ +type AccessPingSAMLConfiguration struct { + IssuerURL string `json:"issuer_url"` + SsoTargetURL string `json:"sso_target_url"` + Attributes []string `json:"attributes"` + EmailAttributeName string `json:"email_attribute_name"` + SignRequest bool `json:"sign_request"` + IdpPublicCert string `json:"idp_public_cert"` +} + +// AccessYandexConfiguration is the representation of the Yandex identity provider. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/yandex/ +type AccessYandexConfiguration struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` +} + +// AccessADSAMLConfiguration is the representation of the Active Directory +// identity provider using SAML. +// +// API reference: https://developers.cloudflare.com/access/configuring-identity-providers/adfs/ +type AccessADSAMLConfiguration struct { + IssuerURL string `json:"issuer_url"` + SsoTargetURL string `json:"sso_target_url"` + Attributes []string `json:"attributes"` + EmailAttributeName string `json:"email_attribute_name"` + SignRequest bool `json:"sign_request"` + IdpPublicCert string `json:"idp_public_cert"` +} + +// AccessIdentityProvidersListResponse is the API response for multiple +// Access Identity Providers. +type AccessIdentityProvidersListResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result []AccessIdentityProvider `json:"result"` +} + +// AccessIdentityProviderListResponse is the API response for a single +// Access Identity Provider. 
+type AccessIdentityProviderListResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result AccessIdentityProvider `json:"result"` +} + +// AccessIdentityProviders returns all Access Identity Providers for an +// account. +// +// API reference: https://api.cloudflare.com/#access-identity-providers-list-access-identity-providers +func (api *API) AccessIdentityProviders(accountID string) ([]AccessIdentityProvider, error) { + uri := "/accounts/" + accountID + "/access/identity_providers" + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []AccessIdentityProvider{}, errors.Wrap(err, errMakeRequestError) + } + + var accessIdentityProviderResponse AccessIdentityProvidersListResponse + err = json.Unmarshal(res, &accessIdentityProviderResponse) + if err != nil { + return []AccessIdentityProvider{}, errors.Wrap(err, errUnmarshalError) + } + + return accessIdentityProviderResponse.Result, nil +} + +// AccessIdentityProviderDetails returns a single Access Identity +// Provider for an account. +// +// API reference: https://api.cloudflare.com/#access-identity-providers-access-identity-providers-details +func (api *API) AccessIdentityProviderDetails(accountID, identityProviderID string) (AccessIdentityProvider, error) { + uri := fmt.Sprintf( + "/accounts/%s/access/identity_providers/%s", + accountID, + identityProviderID, + ) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return AccessIdentityProvider{}, errors.Wrap(err, errMakeRequestError) + } + + var accessIdentityProviderResponse AccessIdentityProviderListResponse + err = json.Unmarshal(res, &accessIdentityProviderResponse) + if err != nil { + return AccessIdentityProvider{}, errors.Wrap(err, errUnmarshalError) + } + + return accessIdentityProviderResponse.Result, nil +} + +// CreateAccessIdentityProvider creates a new Access Identity Provider. +// +// API reference: https://api.cloudflare.com/#access-identity-providers-create-access-identity-provider +func (api *API) CreateAccessIdentityProvider(accountID string, identityProviderConfiguration AccessIdentityProvider) (AccessIdentityProvider, error) { + uri := "/accounts/" + accountID + "/access/identity_providers" + + res, err := api.makeRequest("POST", uri, identityProviderConfiguration) + if err != nil { + return AccessIdentityProvider{}, errors.Wrap(err, errMakeRequestError) + } + + var accessIdentityProviderResponse AccessIdentityProviderListResponse + err = json.Unmarshal(res, &accessIdentityProviderResponse) + if err != nil { + return AccessIdentityProvider{}, errors.Wrap(err, errUnmarshalError) + } + + return accessIdentityProviderResponse.Result, nil +} + +// UpdateAccessIdentityProvider updates an existing Access Identity +// Provider. 
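+//
+// An illustrative update (editor's sketch; the IDs are placeholders, and the
+// "okta" type string plus its config shape are assumptions based on the
+// configuration types above):
+//
+//	idp, err := api.UpdateAccessIdentityProvider("account-id", "provider-uuid",
+//		cloudflare.AccessIdentityProvider{
+//			Name: "Corporate Okta",
+//			Type: "okta",
+//			Config: cloudflare.AccessOktaConfiguration{
+//				ClientID:     "client-id",
+//				ClientSecret: "client-secret",
+//				OktaAccount:  "https://example.okta.com",
+//			},
+//		})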
+// +// API reference: https://api.cloudflare.com/#access-identity-providers-create-access-identity-provider +func (api *API) UpdateAccessIdentityProvider(accountID, identityProviderUUID string, identityProviderConfiguration AccessIdentityProvider) (AccessIdentityProvider, error) { + uri := fmt.Sprintf( + "/accounts/%s/access/identity_providers/%s", + accountID, + identityProviderUUID, + ) + + res, err := api.makeRequest("PUT", uri, identityProviderConfiguration) + if err != nil { + return AccessIdentityProvider{}, errors.Wrap(err, errMakeRequestError) + } + + var accessIdentityProviderResponse AccessIdentityProviderListResponse + err = json.Unmarshal(res, &accessIdentityProviderResponse) + if err != nil { + return AccessIdentityProvider{}, errors.Wrap(err, errUnmarshalError) + } + + return accessIdentityProviderResponse.Result, nil +} + +// DeleteAccessIdentityProvider deletes an Access Identity Provider. +// +// API reference: https://api.cloudflare.com/#access-identity-providers-create-access-identity-provider +func (api *API) DeleteAccessIdentityProvider(accountID, identityProviderUUID string) (AccessIdentityProvider, error) { + uri := fmt.Sprintf( + "/accounts/%s/access/identity_providers/%s", + accountID, + identityProviderUUID, + ) + + res, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return AccessIdentityProvider{}, errors.Wrap(err, errMakeRequestError) + } + + var accessIdentityProviderResponse AccessIdentityProviderListResponse + err = json.Unmarshal(res, &accessIdentityProviderResponse) + if err != nil { + return AccessIdentityProvider{}, errors.Wrap(err, errUnmarshalError) + } + + return accessIdentityProviderResponse.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/access_organization.go b/vendor/github.com/cloudflare/cloudflare-go/access_organization.go new file mode 100644 index 000000000..5bc4a16aa --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/access_organization.go @@ -0,0 +1,101 @@ +package cloudflare + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +// AccessOrganization represents an Access organization. +type AccessOrganization struct { + CreatedAt *time.Time `json:"created_at"` + UpdatedAt *time.Time `json:"updated_at"` + Name string `json:"name"` + AuthDomain string `json:"auth_domain"` + LoginDesign AccessOrganizationLoginDesign `json:"login_design"` +} + +// AccessOrganizationLoginDesign represents the login design options. +type AccessOrganizationLoginDesign struct { + BackgroundColor string `json:"background_color"` + TextColor string `json:"text_color"` + LogoPath string `json:"logo_path"` +} + +// AccessOrganizationListResponse represents the response from the list +// access organization endpoint. +type AccessOrganizationListResponse struct { + Result AccessOrganization `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// AccessOrganizationDetailResponse is the API response, containing a +// single access organization. +type AccessOrganizationDetailResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result AccessOrganization `json:"result"` +} + +// AccessOrganization returns the Access organisation details. 
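+//
+// An illustrative call (editor's sketch; the account ID is a placeholder):
+//
+//	org, _, err := api.AccessOrganization("account-id")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	fmt.Println(org.Name, org.AuthDomain)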
+// +// API reference: https://api.cloudflare.com/#access-organizations-access-organization-details +func (api *API) AccessOrganization(accountID string) (AccessOrganization, ResultInfo, error) { + uri := "/accounts/" + accountID + "/access/organizations" + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return AccessOrganization{}, ResultInfo{}, errors.Wrap(err, errMakeRequestError) + } + + var accessOrganizationListResponse AccessOrganizationListResponse + err = json.Unmarshal(res, &accessOrganizationListResponse) + if err != nil { + return AccessOrganization{}, ResultInfo{}, errors.Wrap(err, errUnmarshalError) + } + + return accessOrganizationListResponse.Result, accessOrganizationListResponse.ResultInfo, nil +} + +// CreateAccessOrganization creates the Access organisation details. +// +// API reference: https://api.cloudflare.com/#access-organizations-create-access-organization +func (api *API) CreateAccessOrganization(accountID string, accessOrganization AccessOrganization) (AccessOrganization, error) { + uri := "/accounts/" + accountID + "/access/organizations" + + res, err := api.makeRequest("POST", uri, accessOrganization) + if err != nil { + return AccessOrganization{}, errors.Wrap(err, errMakeRequestError) + } + + var accessOrganizationDetailResponse AccessOrganizationDetailResponse + err = json.Unmarshal(res, &accessOrganizationDetailResponse) + if err != nil { + return AccessOrganization{}, errors.Wrap(err, errUnmarshalError) + } + + return accessOrganizationDetailResponse.Result, nil +} + +// UpdateAccessOrganization updates the Access organisation details. +// +// API reference: https://api.cloudflare.com/#access-organizations-update-access-organization +func (api *API) UpdateAccessOrganization(accountID string, accessOrganization AccessOrganization) (AccessOrganization, error) { + uri := "/accounts/" + accountID + "/access/organizations" + + res, err := api.makeRequest("PUT", uri, accessOrganization) + if err != nil { + return AccessOrganization{}, errors.Wrap(err, errMakeRequestError) + } + + var accessOrganizationDetailResponse AccessOrganizationDetailResponse + err = json.Unmarshal(res, &accessOrganizationDetailResponse) + if err != nil { + return AccessOrganization{}, errors.Wrap(err, errUnmarshalError) + } + + return accessOrganizationDetailResponse.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/access_policy.go b/vendor/github.com/cloudflare/cloudflare-go/access_policy.go new file mode 100644 index 000000000..dbf63e49f --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/access_policy.go @@ -0,0 +1,221 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + "net/url" + "strconv" + "time" + + "github.com/pkg/errors" +) + +// AccessPolicy defines a policy for allowing or disallowing access to +// one or more Access applications. +type AccessPolicy struct { + ID string `json:"id,omitempty"` + Precedence int `json:"precedence"` + Decision string `json:"decision"` + CreatedAt *time.Time `json:"created_at"` + UpdatedAt *time.Time `json:"updated_at"` + Name string `json:"name"` + + // The include policy works like an OR logical operator. The user must + // satisfy one of the rules. + Include []interface{} `json:"include"` + + // The exclude policy works like a NOT logical operator. The user must + // not satisfy all of the rules in exclude. + Exclude []interface{} `json:"exclude"` + + // The require policy works like an AND logical operator. The user must + // satisfy all of the rules in require.
+ Require []interface{} `json:"require"` +} + +// AccessPolicyEmail is used for managing access based on the email. +// For example, restrict access to users with the email addresses +// `test@example.com` or `someone@example.com`. +type AccessPolicyEmail struct { + Email struct { + Email string `json:"email"` + } `json:"email"` +} + +// AccessPolicyEmailDomain is used for managing access based on an email +// domain such as `example.com` instead of individual addresses. +type AccessPolicyEmailDomain struct { + EmailDomain struct { + Domain string `json:"domain"` + } `json:"email_domain"` +} + +// AccessPolicyIP is used for managing access based on the IP. It +// accepts individual IPs or CIDRs. +type AccessPolicyIP struct { + IP struct { + IP string `json:"ip"` + } `json:"ip"` +} + +// AccessPolicyEveryone is used for managing access to everyone. +type AccessPolicyEveryone struct { + Everyone struct{} `json:"everyone"` +} + +// AccessPolicyAccessGroup is used for managing access based on an +// access group. +type AccessPolicyAccessGroup struct { + Group struct { + ID string `json:"id"` + } `json:"group"` +} + +// AccessPolicyListResponse represents the response from the list +// access policies endpoint. +type AccessPolicyListResponse struct { + Result []AccessPolicy `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// AccessPolicyDetailResponse is the API response, containing a single +// access policy. +type AccessPolicyDetailResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result AccessPolicy `json:"result"` +} + +// AccessPolicies returns all access policies for an access application. +// +// API reference: https://api.cloudflare.com/#access-policy-list-access-policies +func (api *API) AccessPolicies(zoneID, applicationID string, pageOpts PaginationOptions) ([]AccessPolicy, ResultInfo, error) { + v := url.Values{} + if pageOpts.PerPage > 0 { + v.Set("per_page", strconv.Itoa(pageOpts.PerPage)) + } + if pageOpts.Page > 0 { + v.Set("page", strconv.Itoa(pageOpts.Page)) + } + + uri := fmt.Sprintf( + "/zones/%s/access/apps/%s/policies", + zoneID, + applicationID, + ) + + if len(v) > 0 { + uri = uri + "?" + v.Encode() + } + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []AccessPolicy{}, ResultInfo{}, errors.Wrap(err, errMakeRequestError) + } + + var accessPolicyListResponse AccessPolicyListResponse + err = json.Unmarshal(res, &accessPolicyListResponse) + if err != nil { + return []AccessPolicy{}, ResultInfo{}, errors.Wrap(err, errUnmarshalError) + } + + return accessPolicyListResponse.Result, accessPolicyListResponse.ResultInfo, nil +} + +// AccessPolicy returns a single policy based on the policy ID. +// +// API reference: https://api.cloudflare.com/#access-policy-access-policy-details +func (api *API) AccessPolicy(zoneID, applicationID, policyID string) (AccessPolicy, error) { + uri := fmt.Sprintf( + "/zones/%s/access/apps/%s/policies/%s", + zoneID, + applicationID, + policyID, + ) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return AccessPolicy{}, errors.Wrap(err, errMakeRequestError) + } + + var accessPolicyDetailResponse AccessPolicyDetailResponse + err = json.Unmarshal(res, &accessPolicyDetailResponse) + if err != nil { + return AccessPolicy{}, errors.Wrap(err, errUnmarshalError) + } + + return accessPolicyDetailResponse.Result, nil +} + +// CreateAccessPolicy creates a new access policy.
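+//
+// An illustrative policy admitting a single email address (editor's sketch;
+// the IDs, name and decision value are placeholders):
+//
+//	var rule cloudflare.AccessPolicyEmail
+//	rule.Email.Email = "test@example.com"
+//
+//	policy, err := api.CreateAccessPolicy("zone-id", "application-id",
+//		cloudflare.AccessPolicy{
+//			Name:       "allow one user",
+//			Decision:   "allow",
+//			Precedence: 1,
+//			Include:    []interface{}{rule},
+//		})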
+//
+// API reference: https://api.cloudflare.com/#access-policy-create-access-policy
+func (api *API) CreateAccessPolicy(zoneID, applicationID string, accessPolicy AccessPolicy) (AccessPolicy, error) {
+ uri := fmt.Sprintf(
+ "/zones/%s/access/apps/%s/policies",
+ zoneID,
+ applicationID,
+ )
+
+ res, err := api.makeRequest("POST", uri, accessPolicy)
+ if err != nil {
+ return AccessPolicy{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var accessPolicyDetailResponse AccessPolicyDetailResponse
+ err = json.Unmarshal(res, &accessPolicyDetailResponse)
+ if err != nil {
+ return AccessPolicy{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return accessPolicyDetailResponse.Result, nil
+}
+
+// UpdateAccessPolicy updates an existing access policy.
+//
+// API reference: https://api.cloudflare.com/#access-policy-update-access-policy
+func (api *API) UpdateAccessPolicy(zoneID, applicationID string, accessPolicy AccessPolicy) (AccessPolicy, error) {
+ if accessPolicy.ID == "" {
+ return AccessPolicy{}, errors.Errorf("access policy ID cannot be empty")
+ }
+ uri := fmt.Sprintf(
+ "/zones/%s/access/apps/%s/policies/%s",
+ zoneID,
+ applicationID,
+ accessPolicy.ID,
+ )
+
+ res, err := api.makeRequest("PUT", uri, accessPolicy)
+ if err != nil {
+ return AccessPolicy{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var accessPolicyDetailResponse AccessPolicyDetailResponse
+ err = json.Unmarshal(res, &accessPolicyDetailResponse)
+ if err != nil {
+ return AccessPolicy{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return accessPolicyDetailResponse.Result, nil
+}
+
+// DeleteAccessPolicy deletes an access policy.
+//
+// API reference: https://api.cloudflare.com/#access-policy-delete-access-policy
+func (api *API) DeleteAccessPolicy(zoneID, applicationID, accessPolicyID string) error {
+ uri := fmt.Sprintf(
+ "/zones/%s/access/apps/%s/policies/%s",
+ zoneID,
+ applicationID,
+ accessPolicyID,
+ )
+
+ _, err := api.makeRequest("DELETE", uri, nil)
+ if err != nil {
+ return errors.Wrap(err, errMakeRequestError)
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/access_service_tokens.go b/vendor/github.com/cloudflare/cloudflare-go/access_service_tokens.go
new file mode 100644
index 000000000..66a2bb794
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/access_service_tokens.go
@@ -0,0 +1,167 @@
+package cloudflare
+
+import (
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/pkg/errors"
+)
+
+// AccessServiceToken represents an Access Service Token.
+type AccessServiceToken struct {
+ ClientID string `json:"client_id"`
+ CreatedAt *time.Time `json:"created_at"`
+ ExpiresAt *time.Time `json:"expires_at"`
+ ID string `json:"id"`
+ Name string `json:"name"`
+ UpdatedAt *time.Time `json:"updated_at"`
+}
+
+// AccessServiceTokenUpdateResponse represents the response from the API
+// when a Service Token is updated. This base struct is also used for the
+// Create response, as the two are very similar.
+type AccessServiceTokenUpdateResponse struct {
+ CreatedAt *time.Time `json:"created_at"`
+ UpdatedAt *time.Time `json:"updated_at"`
+ ID string `json:"id"`
+ Name string `json:"name"`
+ ClientID string `json:"client_id"`
+}
+
+// AccessServiceTokenCreateResponse is the same API response as the Update
+// operation, with the exception that the `ClientSecret` is only present
+// in a Create operation. 
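+//
+// The secret is returned once, at creation time, so callers typically
+// persist it immediately. A sketch (storeSecret is a hypothetical
+// helper):
+//
+//   created, err := api.CreateAccessServiceToken(accountID, "ci-token")
+//   if err == nil {
+//     storeSecret(created.ClientID, created.ClientSecret)
+//   }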
+type AccessServiceTokenCreateResponse struct { + CreatedAt *time.Time `json:"created_at"` + UpdatedAt *time.Time `json:"updated_at"` + ID string `json:"id"` + Name string `json:"name"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` +} + +// AccessServiceTokensListResponse represents the response from the list +// Access Service Tokens endpoint. +type AccessServiceTokensListResponse struct { + Result []AccessServiceToken `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// AccessServiceTokensDetailResponse is the API response, containing a single +// Access Service Token. +type AccessServiceTokensDetailResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result AccessServiceToken `json:"result"` +} + +// AccessServiceTokensCreationDetailResponse is the API response, containing a +// single Access Service Token. +type AccessServiceTokensCreationDetailResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result AccessServiceTokenCreateResponse `json:"result"` +} + +// AccessServiceTokensUpdateDetailResponse is the API response, containing a +// single Access Service Token. +type AccessServiceTokensUpdateDetailResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result AccessServiceTokenUpdateResponse `json:"result"` +} + +// AccessServiceTokens returns all Access Service Tokens for an account. +// +// API reference: https://api.cloudflare.com/#access-service-tokens-list-access-service-tokens +func (api *API) AccessServiceTokens(accountID string) ([]AccessServiceToken, ResultInfo, error) { + uri := "/accounts/" + accountID + "/access/service_tokens" + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []AccessServiceToken{}, ResultInfo{}, errors.Wrap(err, errMakeRequestError) + } + + var accessServiceTokensListResponse AccessServiceTokensListResponse + err = json.Unmarshal(res, &accessServiceTokensListResponse) + if err != nil { + return []AccessServiceToken{}, ResultInfo{}, errors.Wrap(err, errUnmarshalError) + } + + return accessServiceTokensListResponse.Result, accessServiceTokensListResponse.ResultInfo, nil +} + +// CreateAccessServiceToken creates a new Access Service Token for an account. +// +// API reference: https://api.cloudflare.com/#access-service-tokens-create-access-service-token +func (api *API) CreateAccessServiceToken(accountID, name string) (AccessServiceTokenCreateResponse, error) { + uri := "/accounts/" + accountID + "/access/service_tokens" + marshalledName, _ := json.Marshal(struct { + Name string `json:"name"` + }{name}) + + res, err := api.makeRequest("POST", uri, marshalledName) + + if err != nil { + return AccessServiceTokenCreateResponse{}, errors.Wrap(err, errMakeRequestError) + } + + var accessServiceTokenCreation AccessServiceTokensCreationDetailResponse + err = json.Unmarshal(res, &accessServiceTokenCreation) + if err != nil { + return AccessServiceTokenCreateResponse{}, errors.Wrap(err, errUnmarshalError) + } + + return accessServiceTokenCreation.Result, nil +} + +// UpdateAccessServiceToken updates an existing Access Service Token for an +// account. 
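+//
+// Only the token name is updatable here; a usage sketch (uuid is the ID
+// of an existing token):
+//
+//   updated, err := api.UpdateAccessServiceToken(accountID, uuid, "renamed-token")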
+//
+// API reference: https://api.cloudflare.com/#access-service-tokens-update-access-service-token
+func (api *API) UpdateAccessServiceToken(accountID, uuid, name string) (AccessServiceTokenUpdateResponse, error) {
+ uri := fmt.Sprintf("/accounts/%s/access/service_tokens/%s", accountID, uuid)
+
+ marshalledName, _ := json.Marshal(struct {
+ Name string `json:"name"`
+ }{name})
+
+ res, err := api.makeRequest("PUT", uri, marshalledName)
+ if err != nil {
+ return AccessServiceTokenUpdateResponse{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var accessServiceTokenUpdate AccessServiceTokensUpdateDetailResponse
+ err = json.Unmarshal(res, &accessServiceTokenUpdate)
+ if err != nil {
+ return AccessServiceTokenUpdateResponse{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return accessServiceTokenUpdate.Result, nil
+}
+
+// DeleteAccessServiceToken removes an existing Access Service Token for an
+// account.
+//
+// API reference: https://api.cloudflare.com/#access-service-tokens-delete-access-service-token
+func (api *API) DeleteAccessServiceToken(accountID, uuid string) (AccessServiceTokenUpdateResponse, error) {
+ uri := fmt.Sprintf("/accounts/%s/access/service_tokens/%s", accountID, uuid)
+
+ res, err := api.makeRequest("DELETE", uri, nil)
+ if err != nil {
+ return AccessServiceTokenUpdateResponse{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var accessServiceTokenUpdate AccessServiceTokensUpdateDetailResponse
+ err = json.Unmarshal(res, &accessServiceTokenUpdate)
+ if err != nil {
+ return AccessServiceTokenUpdateResponse{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return accessServiceTokenUpdate.Result, nil
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/account_members.go b/vendor/github.com/cloudflare/cloudflare-go/account_members.go
new file mode 100644
index 000000000..42166e922
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/account_members.go
@@ -0,0 +1,186 @@
+package cloudflare
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/url"
+ "strconv"
+
+ "github.com/pkg/errors"
+)
+
+// AccountMember is the definition of a member of an account.
+type AccountMember struct {
+ ID string `json:"id"`
+ Code string `json:"code"`
+ User AccountMemberUserDetails `json:"user"`
+ Status string `json:"status"`
+ Roles []AccountRole `json:"roles"`
+}
+
+// AccountMemberUserDetails outlines all the personal information about
+// a member.
+type AccountMemberUserDetails struct {
+ ID string `json:"id"`
+ FirstName string `json:"first_name"`
+ LastName string `json:"last_name"`
+ Email string `json:"email"`
+ TwoFactorAuthenticationEnabled bool `json:"two_factor_authentication_enabled"`
+}
+
+// AccountMembersListResponse represents the response from the list
+// account members endpoint.
+type AccountMembersListResponse struct {
+ Result []AccountMember `json:"result"`
+ Response
+ ResultInfo `json:"result_info"`
+}
+
+// AccountMemberDetailResponse is the API response, containing a single
+// account member.
+type AccountMemberDetailResponse struct {
+ Success bool `json:"success"`
+ Errors []string `json:"errors"`
+ Messages []string `json:"messages"`
+ Result AccountMember `json:"result"`
+}
+
+// AccountMemberInvitation represents the invitation for a new member to
+// the account.
+type AccountMemberInvitation struct {
+ Email string `json:"email"`
+ Roles []string `json:"roles"`
+}
+
+// AccountMembers returns all members of an account. 
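+//
+// A paging sketch (zero-valued PaginationOptions fields fall back to
+// the API defaults; the returned ResultInfo can drive later pages):
+//
+//   members, info, err := api.AccountMembers(accountID, PaginationOptions{Page: 1, PerPage: 50})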
+//
+// API reference: https://api.cloudflare.com/#account-members-list-members
+func (api *API) AccountMembers(accountID string, pageOpts PaginationOptions) ([]AccountMember, ResultInfo, error) {
+ if accountID == "" {
+ return []AccountMember{}, ResultInfo{}, errors.New(errMissingAccountID)
+ }
+
+ v := url.Values{}
+ if pageOpts.PerPage > 0 {
+ v.Set("per_page", strconv.Itoa(pageOpts.PerPage))
+ }
+ if pageOpts.Page > 0 {
+ v.Set("page", strconv.Itoa(pageOpts.Page))
+ }
+
+ uri := "/accounts/" + accountID + "/members"
+ if len(v) > 0 {
+ uri = uri + "?" + v.Encode()
+ }
+
+ res, err := api.makeRequest("GET", uri, nil)
+ if err != nil {
+ return []AccountMember{}, ResultInfo{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var accountMemberListResponse AccountMembersListResponse
+ err = json.Unmarshal(res, &accountMemberListResponse)
+ if err != nil {
+ return []AccountMember{}, ResultInfo{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return accountMemberListResponse.Result, accountMemberListResponse.ResultInfo, nil
+}
+
+// CreateAccountMember invites a new member to join an account.
+//
+// API reference: https://api.cloudflare.com/#account-members-add-member
+func (api *API) CreateAccountMember(accountID string, emailAddress string, roles []string) (AccountMember, error) {
+ if accountID == "" {
+ return AccountMember{}, errors.New(errMissingAccountID)
+ }
+
+ uri := "/accounts/" + accountID + "/members"
+
+ var newMember = AccountMemberInvitation{
+ Email: emailAddress,
+ Roles: roles,
+ }
+ res, err := api.makeRequest("POST", uri, newMember)
+ if err != nil {
+ return AccountMember{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var accountMemberDetailResponse AccountMemberDetailResponse
+ err = json.Unmarshal(res, &accountMemberDetailResponse)
+ if err != nil {
+ return AccountMember{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return accountMemberDetailResponse.Result, nil
+}
+
+// DeleteAccountMember removes a member from an account.
+//
+// API reference: https://api.cloudflare.com/#account-members-remove-member
+func (api *API) DeleteAccountMember(accountID string, userID string) error {
+ if accountID == "" {
+ return errors.New(errMissingAccountID)
+ }
+
+ uri := fmt.Sprintf("/accounts/%s/members/%s", accountID, userID)
+
+ _, err := api.makeRequest("DELETE", uri, nil)
+ if err != nil {
+ return errors.Wrap(err, errMakeRequestError)
+ }
+
+ return nil
+}
+
+// UpdateAccountMember modifies an existing account member.
+//
+// API reference: https://api.cloudflare.com/#account-members-update-member
+func (api *API) UpdateAccountMember(accountID string, userID string, member AccountMember) (AccountMember, error) {
+ if accountID == "" {
+ return AccountMember{}, errors.New(errMissingAccountID)
+ }
+
+ uri := fmt.Sprintf("/accounts/%s/members/%s", accountID, userID)
+
+ res, err := api.makeRequest("PUT", uri, member)
+ if err != nil {
+ return AccountMember{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var accountMemberDetailResponse AccountMemberDetailResponse
+ err = json.Unmarshal(res, &accountMemberDetailResponse)
+ if err != nil {
+ return AccountMember{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return accountMemberDetailResponse.Result, nil
+}
+
+// AccountMember returns details of a single account member. 
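+//
+// A usage sketch (the "pending" status value is illustrative):
+//
+//   member, err := api.AccountMember(accountID, memberID)
+//   if err == nil && member.Status == "pending" {
+//     // the invitation has not been accepted yet
+//   }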
+// +// API reference: https://api.cloudflare.com/#account-members-member-details +func (api *API) AccountMember(accountID string, memberID string) (AccountMember, error) { + if accountID == "" { + return AccountMember{}, errors.New(errMissingAccountID) + } + + uri := fmt.Sprintf( + "/accounts/%s/members/%s", + accountID, + memberID, + ) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return AccountMember{}, errors.Wrap(err, errMakeRequestError) + } + + var accountMemberResponse AccountMemberDetailResponse + err = json.Unmarshal(res, &accountMemberResponse) + if err != nil { + return AccountMember{}, errors.Wrap(err, errUnmarshalError) + } + + return accountMemberResponse.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/account_roles.go b/vendor/github.com/cloudflare/cloudflare-go/account_roles.go new file mode 100644 index 000000000..3704313b5 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/account_roles.go @@ -0,0 +1,80 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + + "github.com/pkg/errors" +) + +// AccountRole defines the roles that a member can have attached. +type AccountRole struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Permissions map[string]AccountRolePermission `json:"permissions"` +} + +// AccountRolePermission is the shared structure for all permissions +// that can be assigned to a member. +type AccountRolePermission struct { + Read bool `json:"read"` + Edit bool `json:"edit"` +} + +// AccountRolesListResponse represents the list response from the +// account roles. +type AccountRolesListResponse struct { + Result []AccountRole `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// AccountRoleDetailResponse is the API response, containing a single +// account role. +type AccountRoleDetailResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result AccountRole `json:"result"` +} + +// AccountRoles returns all roles of an account. +// +// API reference: https://api.cloudflare.com/#account-roles-list-roles +func (api *API) AccountRoles(accountID string) ([]AccountRole, error) { + uri := "/accounts/" + accountID + "/roles" + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []AccountRole{}, errors.Wrap(err, errMakeRequestError) + } + + var accountRolesListResponse AccountRolesListResponse + err = json.Unmarshal(res, &accountRolesListResponse) + if err != nil { + return []AccountRole{}, errors.Wrap(err, errUnmarshalError) + } + + return accountRolesListResponse.Result, nil +} + +// AccountRole returns the details of a single account role. 
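+//
+// Role IDs usually come from AccountRoles above and are then fed to
+// CreateAccountMember; a sketch:
+//
+//   roles, err := api.AccountRoles(accountID)
+//   if err == nil && len(roles) > 0 {
+//     _, err = api.CreateAccountMember(accountID, "user@example.com", []string{roles[0].ID})
+//   }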
+// +// API reference: https://api.cloudflare.com/#account-roles-role-details +func (api *API) AccountRole(accountID string, roleID string) (AccountRole, error) { + uri := fmt.Sprintf("/accounts/%s/roles/%s", accountID, roleID) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return AccountRole{}, errors.Wrap(err, errMakeRequestError) + } + + var accountRole AccountRoleDetailResponse + err = json.Unmarshal(res, &accountRole) + if err != nil { + return AccountRole{}, errors.Wrap(err, errUnmarshalError) + } + + return accountRole.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/accounts.go b/vendor/github.com/cloudflare/cloudflare-go/accounts.go new file mode 100644 index 000000000..7d34b7b8f --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/accounts.go @@ -0,0 +1,114 @@ +package cloudflare + +import ( + "encoding/json" + "net/url" + "strconv" + + "github.com/pkg/errors" +) + +// AccountSettings outlines the available options for an account. +type AccountSettings struct { + EnforceTwoFactor bool `json:"enforce_twofactor"` +} + +// Account represents the root object that owns resources. +type Account struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Settings *AccountSettings `json:"settings"` +} + +// AccountResponse represents the response from the accounts endpoint for a +// single account ID. +type AccountResponse struct { + Result Account `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// AccountListResponse represents the response from the list accounts endpoint. +type AccountListResponse struct { + Result []Account `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// AccountDetailResponse is the API response, containing a single Account. +type AccountDetailResponse struct { + Success bool `json:"success"` + Errors []string `json:"errors"` + Messages []string `json:"messages"` + Result Account `json:"result"` +} + +// Accounts returns all accounts the logged in user has access to. +// +// API reference: https://api.cloudflare.com/#accounts-list-accounts +func (api *API) Accounts(pageOpts PaginationOptions) ([]Account, ResultInfo, error) { + v := url.Values{} + if pageOpts.PerPage > 0 { + v.Set("per_page", strconv.Itoa(pageOpts.PerPage)) + } + if pageOpts.Page > 0 { + v.Set("page", strconv.Itoa(pageOpts.Page)) + } + + uri := "/accounts" + if len(v) > 0 { + uri = uri + "?" + v.Encode() + } + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []Account{}, ResultInfo{}, errors.Wrap(err, errMakeRequestError) + } + + var accListResponse AccountListResponse + err = json.Unmarshal(res, &accListResponse) + if err != nil { + return []Account{}, ResultInfo{}, errors.Wrap(err, errUnmarshalError) + } + return accListResponse.Result, accListResponse.ResultInfo, nil +} + +// Account returns a single account based on the ID. +// +// API reference: https://api.cloudflare.com/#accounts-account-details +func (api *API) Account(accountID string) (Account, ResultInfo, error) { + uri := "/accounts/" + accountID + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return Account{}, ResultInfo{}, errors.Wrap(err, errMakeRequestError) + } + + var accResponse AccountResponse + err = json.Unmarshal(res, &accResponse) + if err != nil { + return Account{}, ResultInfo{}, errors.Wrap(err, errUnmarshalError) + } + + return accResponse.Result, accResponse.ResultInfo, nil +} + +// UpdateAccount allows management of an account using the account ID. 
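+//
+// A read-modify-write sketch (assumes the fetched account has a
+// non-nil Settings):
+//
+//   acc, _, err := api.Account(accountID)
+//   if err == nil {
+//     acc.Settings.EnforceTwoFactor = true
+//     acc, err = api.UpdateAccount(accountID, acc)
+//   }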
+//
+// API reference: https://api.cloudflare.com/#accounts-update-account
+func (api *API) UpdateAccount(accountID string, account Account) (Account, error) {
+ uri := "/accounts/" + accountID
+
+ res, err := api.makeRequest("PUT", uri, account)
+ if err != nil {
+ return Account{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var a AccountDetailResponse
+ err = json.Unmarshal(res, &a)
+ if err != nil {
+ return Account{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return a.Result, nil
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/argo.go b/vendor/github.com/cloudflare/cloudflare-go/argo.go
new file mode 100644
index 000000000..320c7fc25
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/argo.go
@@ -0,0 +1,120 @@
+package cloudflare
+
+import (
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/pkg/errors"
+)
+
+var validSettingValues = []string{"on", "off"}
+
+// ArgoFeatureSetting is the structure of the API object for the
+// argo smart routing and tiered caching settings.
+type ArgoFeatureSetting struct {
+ Editable bool `json:"editable,omitempty"`
+ ID string `json:"id,omitempty"`
+ ModifiedOn time.Time `json:"modified_on,omitempty"`
+ Value string `json:"value"`
+}
+
+// ArgoDetailsResponse is the API response for the argo smart routing
+// and tiered caching settings.
+type ArgoDetailsResponse struct {
+ Result ArgoFeatureSetting `json:"result"`
+ Response
+}
+
+// ArgoSmartRouting returns the current settings for smart routing.
+//
+// API reference: https://api.cloudflare.com/#argo-smart-routing-get-argo-smart-routing-setting
+func (api *API) ArgoSmartRouting(zoneID string) (ArgoFeatureSetting, error) {
+ uri := "/zones/" + zoneID + "/argo/smart_routing"
+
+ res, err := api.makeRequest("GET", uri, nil)
+ if err != nil {
+ return ArgoFeatureSetting{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var argoDetailsResponse ArgoDetailsResponse
+ err = json.Unmarshal(res, &argoDetailsResponse)
+ if err != nil {
+ return ArgoFeatureSetting{}, errors.Wrap(err, errUnmarshalError)
+ }
+ return argoDetailsResponse.Result, nil
+}
+
+// UpdateArgoSmartRouting updates the setting for smart routing.
+//
+// API reference: https://api.cloudflare.com/#argo-smart-routing-patch-argo-smart-routing-setting
+func (api *API) UpdateArgoSmartRouting(zoneID, settingValue string) (ArgoFeatureSetting, error) {
+ if !contains(validSettingValues, settingValue) {
+ return ArgoFeatureSetting{}, fmt.Errorf("invalid setting value '%s'. must be 'on' or 'off'", settingValue)
+ }
+
+ uri := "/zones/" + zoneID + "/argo/smart_routing"
+
+ res, err := api.makeRequest("PATCH", uri, ArgoFeatureSetting{Value: settingValue})
+ if err != nil {
+ return ArgoFeatureSetting{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var argoDetailsResponse ArgoDetailsResponse
+ err = json.Unmarshal(res, &argoDetailsResponse)
+ if err != nil {
+ return ArgoFeatureSetting{}, errors.Wrap(err, errUnmarshalError)
+ }
+ return argoDetailsResponse.Result, nil
+}
+
+// ArgoTieredCaching returns the current settings for tiered caching. 
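+//
+// A toggle sketch using the getter/setter pair ("on"/"off" are the
+// only values accepted by UpdateArgoTieredCaching):
+//
+//   setting, err := api.ArgoTieredCaching(zoneID)
+//   if err == nil && setting.Value == "off" {
+//     _, err = api.UpdateArgoTieredCaching(zoneID, "on")
+//   }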
+//
+// API reference: TBA
+func (api *API) ArgoTieredCaching(zoneID string) (ArgoFeatureSetting, error) {
+ uri := "/zones/" + zoneID + "/argo/tiered_caching"
+
+ res, err := api.makeRequest("GET", uri, nil)
+ if err != nil {
+ return ArgoFeatureSetting{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var argoDetailsResponse ArgoDetailsResponse
+ err = json.Unmarshal(res, &argoDetailsResponse)
+ if err != nil {
+ return ArgoFeatureSetting{}, errors.Wrap(err, errUnmarshalError)
+ }
+ return argoDetailsResponse.Result, nil
+}
+
+// UpdateArgoTieredCaching updates the setting for tiered caching.
+//
+// API reference: TBA
+func (api *API) UpdateArgoTieredCaching(zoneID, settingValue string) (ArgoFeatureSetting, error) {
+ if !contains(validSettingValues, settingValue) {
+ return ArgoFeatureSetting{}, fmt.Errorf("invalid setting value '%s'. must be 'on' or 'off'", settingValue)
+ }
+
+ uri := "/zones/" + zoneID + "/argo/tiered_caching"
+
+ res, err := api.makeRequest("PATCH", uri, ArgoFeatureSetting{Value: settingValue})
+ if err != nil {
+ return ArgoFeatureSetting{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var argoDetailsResponse ArgoDetailsResponse
+ err = json.Unmarshal(res, &argoDetailsResponse)
+ if err != nil {
+ return ArgoFeatureSetting{}, errors.Wrap(err, errUnmarshalError)
+ }
+ return argoDetailsResponse.Result, nil
+}
+
+func contains(s []string, e string) bool {
+ for _, a := range s {
+ if a == e {
+ return true
+ }
+ }
+ return false
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/auditlogs.go b/vendor/github.com/cloudflare/cloudflare-go/auditlogs.go
new file mode 100644
index 000000000..8cb8eab69
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/auditlogs.go
@@ -0,0 +1,143 @@
+package cloudflare
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "time"
+)
+
+// AuditLogAction is a member of AuditLog, the action that was taken.
+type AuditLogAction struct {
+ Result bool `json:"result"`
+ Type string `json:"type"`
+}
+
+// AuditLogActor is a member of AuditLog, who performed the action.
+type AuditLogActor struct {
+ Email string `json:"email"`
+ ID string `json:"id"`
+ IP string `json:"ip"`
+ Type string `json:"type"`
+}
+
+// AuditLogOwner is a member of AuditLog, who owns this audit log.
+type AuditLogOwner struct {
+ ID string `json:"id"`
+}
+
+// AuditLogResource is a member of AuditLog, what the action was performed on.
+type AuditLogResource struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+}
+
+// AuditLog is a resource that represents an update in the Cloudflare dashboard.
+type AuditLog struct {
+ Action AuditLogAction `json:"action"`
+ Actor AuditLogActor `json:"actor"`
+ ID string `json:"id"`
+ Metadata map[string]interface{} `json:"metadata"`
+ NewValue string `json:"newValue"`
+ OldValue string `json:"oldValue"`
+ Owner AuditLogOwner `json:"owner"`
+ Resource AuditLogResource `json:"resource"`
+ When time.Time `json:"when"`
+}
+
+// AuditLogResponse is the response returned from the Cloudflare v4 API.
+type AuditLogResponse struct {
+ Response Response
+ Result []AuditLog `json:"result"`
+ ResultInfo `json:"result_info"`
+}
+
+// AuditLogFilter is an object for filtering the audit log response from the API.
+type AuditLogFilter struct {
+ ID string
+ ActorIP string
+ ActorEmail string
+ Direction string
+ ZoneName string
+ Since string
+ Before string
+ PerPage int
+ Page int
+}
+
+// String turns an audit log filter into an HTTP query parameter
+// list. 
It will not include empty members of the struct in the
+// query parameters.
+func (a AuditLogFilter) String() string {
+ params := "?"
+ if a.ID != "" {
+ params += "&id=" + a.ID
+ }
+ if a.ActorIP != "" {
+ params += "&actor.ip=" + a.ActorIP
+ }
+ if a.ActorEmail != "" {
+ params += "&actor.email=" + a.ActorEmail
+ }
+ if a.ZoneName != "" {
+ params += "&zone.name=" + a.ZoneName
+ }
+ if a.Direction != "" {
+ params += "&direction=" + a.Direction
+ }
+ if a.Since != "" {
+ params += "&since=" + a.Since
+ }
+ if a.Before != "" {
+ params += "&before=" + a.Before
+ }
+ if a.PerPage > 0 {
+ params += "&per_page=" + fmt.Sprintf("%d", a.PerPage)
+ }
+ if a.Page > 0 {
+ params += "&page=" + fmt.Sprintf("%d", a.Page)
+ }
+ return params
+}
+
+// GetOrganizationAuditLogs will return the audit logs of a specific
+// organization, based on the ID passed in. The audit logs can be
+// filtered based on any argument in the AuditLogFilter.
+//
+// API Reference: https://api.cloudflare.com/#audit-logs-list-organization-audit-logs
+func (api *API) GetOrganizationAuditLogs(organizationID string, a AuditLogFilter) (AuditLogResponse, error) {
+ uri := "/organizations/" + organizationID + "/audit_logs" + a.String()
+
+ res, err := api.makeRequest("GET", uri, nil)
+ if err != nil {
+ return AuditLogResponse{}, err
+ }
+ buf, err := base64.RawStdEncoding.DecodeString(string(res))
+ if err != nil {
+ return AuditLogResponse{}, err
+ }
+ return unmarshalReturn(buf)
+}
+
+// unmarshalReturn will unmarshal bytes and return an AuditLogResponse.
+func unmarshalReturn(res []byte) (AuditLogResponse, error) {
+ var auditResponse AuditLogResponse
+ err := json.Unmarshal(res, &auditResponse)
+ if err != nil {
+ return auditResponse, err
+ }
+ return auditResponse, nil
+}
+
+// GetUserAuditLogs will return your user's audit logs. The audit logs can be
+// filtered based on any argument in the AuditLogFilter.
+//
+// API Reference: https://api.cloudflare.com/#audit-logs-list-user-audit-logs
+func (api *API) GetUserAuditLogs(a AuditLogFilter) (AuditLogResponse, error) {
+ uri := "/user/audit_logs" + a.String()
+ res, err := api.makeRequest("GET", uri, nil)
+ if err != nil {
+ return AuditLogResponse{}, err
+ }
+ return unmarshalReturn(res)
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/cloudflare.go b/vendor/github.com/cloudflare/cloudflare-go/cloudflare.go
new file mode 100644
index 000000000..498b9f468
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/cloudflare.go
@@ -0,0 +1,435 @@
+// Package cloudflare implements the Cloudflare v4 API.
+package cloudflare
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "io"
+ "io/ioutil"
+ "log"
+ "math"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/pkg/errors"
+ "golang.org/x/time/rate"
+)
+
+const apiURL = "https://api.cloudflare.com/client/v4"
+
+const (
+ // AuthKeyEmail specifies that we should authenticate with API key and email address
+ AuthKeyEmail = 1 << iota
+ // AuthUserService specifies that we should authenticate with a User-Service key
+ AuthUserService
+ // AuthToken specifies that we should authenticate with an API Token
+ AuthToken
+)
+
+// API holds the configuration for the current API client. A client should not
+// be modified concurrently. 
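+//
+// Construction sketch (credentials are placeholders):
+//
+//   api, err := New("apikey", "user@example.com")
+//   // or, using a scoped API Token:
+//   api, err = NewWithAPIToken("token")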
+type API struct {
+ APIKey string
+ APIEmail string
+ APIUserServiceKey string
+ APIToken string
+ BaseURL string
+ AccountID string
+ UserAgent string
+ headers http.Header
+ httpClient *http.Client
+ authType int
+ rateLimiter *rate.Limiter
+ retryPolicy RetryPolicy
+ logger Logger
+}
+
+// newClient provides the shared construction logic for New,
+// NewWithAPIToken and NewWithUserServiceKey.
+func newClient(opts ...Option) (*API, error) {
+ silentLogger := log.New(ioutil.Discard, "", log.LstdFlags)
+
+ api := &API{
+ BaseURL: apiURL,
+ headers: make(http.Header),
+ rateLimiter: rate.NewLimiter(rate.Limit(4), 1), // 4rps equates to default api limit (1200 req/5 min)
+ retryPolicy: RetryPolicy{
+ MaxRetries: 3,
+ MinRetryDelay: time.Duration(1) * time.Second,
+ MaxRetryDelay: time.Duration(30) * time.Second,
+ },
+ logger: silentLogger,
+ }
+
+ err := api.parseOptions(opts...)
+ if err != nil {
+ return nil, errors.Wrap(err, "options parsing failed")
+ }
+
+ // Fall back to http.DefaultClient if the package user does not provide
+ // their own.
+ if api.httpClient == nil {
+ api.httpClient = http.DefaultClient
+ }
+
+ return api, nil
+}
+
+// New creates a new Cloudflare v4 API client.
+func New(key, email string, opts ...Option) (*API, error) {
+ if key == "" || email == "" {
+ return nil, errors.New(errEmptyCredentials)
+ }
+
+ api, err := newClient(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ api.APIKey = key
+ api.APIEmail = email
+ api.authType = AuthKeyEmail
+
+ return api, nil
+}
+
+// NewWithAPIToken creates a new Cloudflare v4 API client using API Tokens.
+func NewWithAPIToken(token string, opts ...Option) (*API, error) {
+ if token == "" {
+ return nil, errors.New(errEmptyAPIToken)
+ }
+
+ api, err := newClient(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ api.APIToken = token
+ api.authType = AuthToken
+
+ return api, nil
+}
+
+// NewWithUserServiceKey creates a new Cloudflare v4 API client using service key authentication.
+func NewWithUserServiceKey(key string, opts ...Option) (*API, error) {
+ if key == "" {
+ return nil, errors.New(errEmptyCredentials)
+ }
+
+ api, err := newClient(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ api.APIUserServiceKey = key
+ api.authType = AuthUserService
+
+ return api, nil
+}
+
+// SetAuthType sets the authentication method (AuthKeyEmail, AuthToken, or AuthUserService).
+func (api *API) SetAuthType(authType int) {
+ api.authType = authType
+}
+
+// ZoneIDByName retrieves a zone's ID from the name.
+func (api *API) ZoneIDByName(zoneName string) (string, error) {
+ res, err := api.ListZonesContext(context.TODO(), WithZoneFilter(zoneName))
+ if err != nil {
+ return "", errors.Wrap(err, "ListZonesContext command failed")
+ }
+
+ if len(res.Result) > 1 && api.AccountID == "" {
+ return "", errors.New("ambiguous zone name used without an account ID")
+ }
+
+ for _, zone := range res.Result {
+ if api.AccountID != "" {
+ if zone.Name == zoneName && api.AccountID == zone.Account.ID {
+ return zone.ID, nil
+ }
+ } else {
+ if zone.Name == zoneName {
+ return zone.ID, nil
+ }
+ }
+ }
+
+ return "", errors.New("zone could not be found")
+}
+
+// makeRequest makes a HTTP request and returns the body as a byte slice,
+// closing it before returning. params will be serialized to JSON. 
+func (api *API) makeRequest(method, uri string, params interface{}) ([]byte, error) {
+ return api.makeRequestWithAuthType(context.TODO(), method, uri, params, api.authType)
+}
+
+func (api *API) makeRequestContext(ctx context.Context, method, uri string, params interface{}) ([]byte, error) {
+ return api.makeRequestWithAuthType(ctx, method, uri, params, api.authType)
+}
+
+func (api *API) makeRequestWithHeaders(method, uri string, params interface{}, headers http.Header) ([]byte, error) {
+ return api.makeRequestWithAuthTypeAndHeaders(context.TODO(), method, uri, params, api.authType, headers)
+}
+
+func (api *API) makeRequestWithAuthType(ctx context.Context, method, uri string, params interface{}, authType int) ([]byte, error) {
+ return api.makeRequestWithAuthTypeAndHeaders(ctx, method, uri, params, authType, nil)
+}
+
+func (api *API) makeRequestWithAuthTypeAndHeaders(ctx context.Context, method, uri string, params interface{}, authType int, headers http.Header) ([]byte, error) {
+ // Serialize the request body, passing through params that were
+ // already provided as raw JSON bytes.
+ var jsonBody []byte
+ var err error
+
+ if params != nil {
+ if paramBytes, ok := params.([]byte); ok {
+ jsonBody = paramBytes
+ } else {
+ jsonBody, err = json.Marshal(params)
+ if err != nil {
+ return nil, errors.Wrap(err, "error marshalling params to JSON")
+ }
+ }
+ } else {
+ jsonBody = nil
+ }
+
+ var resp *http.Response
+ var respErr error
+ var reqBody io.Reader
+ var respBody []byte
+ for i := 0; i <= api.retryPolicy.MaxRetries; i++ {
+ if jsonBody != nil {
+ reqBody = bytes.NewReader(jsonBody)
+ }
+ if i > 0 {
+ // expect the backoff introduced here on errored requests to dominate the effect of rate limiting
+ // don't need a random component here as the rate limiter should do something similar
+ // nb time duration could truncate an arbitrary float. 
Since our inputs are all ints, we should be ok
+ sleepDuration := time.Duration(math.Pow(2, float64(i-1)) * float64(api.retryPolicy.MinRetryDelay))
+
+ if sleepDuration > api.retryPolicy.MaxRetryDelay {
+ sleepDuration = api.retryPolicy.MaxRetryDelay
+ }
+ // useful to do some simple logging here, maybe introduce levels later
+ api.logger.Printf("Sleeping %s before retry attempt number %d for request %s %s", sleepDuration.String(), i, method, uri)
+ time.Sleep(sleepDuration)
+
+ }
+ // honour the caller's context while waiting on the rate limiter
+ err = api.rateLimiter.Wait(ctx)
+ if err != nil {
+ return nil, errors.Wrap(err, "Error caused by request rate limiting")
+ }
+ resp, respErr = api.request(ctx, method, uri, reqBody, authType, headers)
+
+ // retry if the server is rate limiting us or if it failed
+ // assumes server operations are rolled back on failure
+ if respErr != nil || resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500 {
+ // if we got a valid http response, try to read body so we can reuse the connection
+ // see https://golang.org/pkg/net/http/#Client.Do
+ if respErr == nil {
+ respBody, err = ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+
+ respErr = errors.Wrap(err, "could not read response body")
+
+ api.logger.Printf("Request: %s %s got an error response %d: %s\n", method, uri, resp.StatusCode,
+ strings.Replace(strings.Replace(string(respBody), "\n", "", -1), "\t", "", -1))
+ } else {
+ api.logger.Printf("Error performing request: %s %s : %s \n", method, uri, respErr.Error())
+ }
+ continue
+ } else {
+ respBody, err = ioutil.ReadAll(resp.Body)
+ defer resp.Body.Close()
+ if err != nil {
+ return nil, errors.Wrap(err, "could not read response body")
+ }
+ break
+ }
+ }
+ if respErr != nil {
+ return nil, respErr
+ }
+
+ switch {
+ case resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices:
+ case resp.StatusCode == http.StatusUnauthorized:
+ return nil, errors.Errorf("HTTP status %d: invalid credentials", resp.StatusCode)
+ case resp.StatusCode == http.StatusForbidden:
+ return nil, errors.Errorf("HTTP status %d: insufficient permissions", resp.StatusCode)
+ case resp.StatusCode == http.StatusServiceUnavailable,
+ resp.StatusCode == http.StatusBadGateway,
+ resp.StatusCode == http.StatusGatewayTimeout,
+ resp.StatusCode == 522,
+ resp.StatusCode == 523,
+ resp.StatusCode == 524:
+ return nil, errors.Errorf("HTTP status %d: service failure", resp.StatusCode)
+ // This isn't a great solution due to the way the `default` case is
+ // a catch all and that the `filters/validate-expr` returns a HTTP 400
+ // yet the clients need to use the HTTP body as a JSON string.
+ case resp.StatusCode == 400 && strings.HasSuffix(resp.Request.URL.Path, "/filters/validate-expr"):
+ return nil, errors.Errorf("%s", respBody)
+ default:
+ var s string
+ if respBody != nil {
+ s = string(respBody)
+ }
+ return nil, errors.Errorf("HTTP status %d: content %q", resp.StatusCode, s)
+ }
+
+ return respBody, nil
+}
+
+// request makes a HTTP request to the given API endpoint, returning the raw
+// *http.Response, or an error if one occurred. The caller is responsible for
+// closing the response body. 
+func (api *API) request(ctx context.Context, method, uri string, reqBody io.Reader, authType int, headers http.Header) (*http.Response, error) {
+ req, err := http.NewRequest(method, api.BaseURL+uri, reqBody)
+ if err != nil {
+ return nil, errors.Wrap(err, "HTTP request creation failed")
+ }
+ // WithContext returns a copy of the request; reassign so the context
+ // is actually attached.
+ req = req.WithContext(ctx)
+
+ combinedHeaders := make(http.Header)
+ copyHeader(combinedHeaders, api.headers)
+ copyHeader(combinedHeaders, headers)
+ req.Header = combinedHeaders
+
+ if authType&AuthKeyEmail != 0 {
+ req.Header.Set("X-Auth-Key", api.APIKey)
+ req.Header.Set("X-Auth-Email", api.APIEmail)
+ }
+ if authType&AuthUserService != 0 {
+ req.Header.Set("X-Auth-User-Service-Key", api.APIUserServiceKey)
+ }
+ if authType&AuthToken != 0 {
+ req.Header.Set("Authorization", "Bearer "+api.APIToken)
+ }
+
+ if api.UserAgent != "" {
+ req.Header.Set("User-Agent", api.UserAgent)
+ }
+
+ if req.Header.Get("Content-Type") == "" {
+ req.Header.Set("Content-Type", "application/json")
+ }
+
+ resp, err := api.httpClient.Do(req)
+ if err != nil {
+ return nil, errors.Wrap(err, "HTTP request failed")
+ }
+
+ return resp, nil
+}
+
+// userBaseURL returns the base URL to use for API endpoints that exist for
+// both accounts and users. If an account was set on the API instance, the
+// account-scoped URL is returned; otherwise accountBase is used.
+//
+// accountBase is the base URL for endpoints referring to the current user.
+// It exists as a parameter because it is not consistent across APIs.
+func (api *API) userBaseURL(accountBase string) string {
+ if api.AccountID != "" {
+ return "/accounts/" + api.AccountID
+ }
+ return accountBase
+}
+
+// copyHeader copies all headers from `source` and sets them on `target`.
+// based on https://godoc.org/github.com/golang/gddo/httputil/header#Copy
+func copyHeader(target, source http.Header) {
+ for k, vs := range source {
+ target[k] = vs
+ }
+}
+
+// ResponseInfo contains a code and message returned by the API as errors or
+// informational messages inside the response.
+type ResponseInfo struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+}
+
+// Response is the base API response template; each endpoint defines a unique
+// response type that embeds this one alongside its own result field.
+type Response struct {
+ Success bool `json:"success"`
+ Errors []ResponseInfo `json:"errors"`
+ Messages []ResponseInfo `json:"messages"`
+}
+
+// ResultInfo contains metadata about the Response.
+type ResultInfo struct {
+ Page int `json:"page"`
+ PerPage int `json:"per_page"`
+ TotalPages int `json:"total_pages"`
+ Count int `json:"count"`
+ Total int `json:"total_count"`
+}
+
+// RawResponse keeps the result in raw JSON form.
+type RawResponse struct {
+ Response
+ Result json.RawMessage `json:"result"`
+}
+
+// Raw makes a HTTP request with user provided params and returns the
+// result as untouched JSON. 
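+//
+// Useful for endpoints this library does not wrap yet; a sketch:
+//
+//   raw, err := api.Raw("GET", "/zones/"+zoneID+"/settings", nil)
+//   // on success, raw holds the unparsed `result` JSON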
+func (api *API) Raw(method, endpoint string, data interface{}) (json.RawMessage, error) {
+ res, err := api.makeRequest(method, endpoint, data)
+ if err != nil {
+ return nil, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var r RawResponse
+ if err := json.Unmarshal(res, &r); err != nil {
+ return nil, errors.Wrap(err, errUnmarshalError)
+ }
+ return r.Result, nil
+}
+
+// PaginationOptions can be passed to a list request to configure paging.
+// These values will be defaulted if omitted, and PerPage has min/max limits set by resource.
+type PaginationOptions struct {
+ Page int `json:"page,omitempty"`
+ PerPage int `json:"per_page,omitempty"`
+}
+
+// RetryPolicy specifies the number of retries and min/max retry delays.
+// This config is used when the client exponentially backs off after errored requests.
+type RetryPolicy struct {
+ MaxRetries int
+ MinRetryDelay time.Duration
+ MaxRetryDelay time.Duration
+}
+
+// Logger defines the interface this library needs for logging.
+// This is a subset of the methods implemented in the log package.
+type Logger interface {
+ Printf(format string, v ...interface{})
+}
+
+// ReqOption is a functional option for configuring API requests.
+type ReqOption func(opt *reqOption)
+type reqOption struct {
+ params url.Values
+}
+
+// WithZoneFilter applies a filter based on zone name.
+func WithZoneFilter(zone string) ReqOption {
+ return func(opt *reqOption) {
+ opt.params.Set("name", zone)
+ }
+}
+
+// WithPagination configures the pagination for a response.
+func WithPagination(opts PaginationOptions) ReqOption {
+ return func(opt *reqOption) {
+ opt.params.Set("page", strconv.Itoa(opts.Page))
+ opt.params.Set("per_page", strconv.Itoa(opts.PerPage))
+ }
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/custom_hostname.go b/vendor/github.com/cloudflare/cloudflare-go/custom_hostname.go
new file mode 100644
index 000000000..d982c5b50
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/custom_hostname.go
@@ -0,0 +1,161 @@
+package cloudflare
+
+import (
+ "encoding/json"
+ "net/url"
+ "strconv"
+
+ "github.com/pkg/errors"
+)
+
+// CustomHostnameSSLSettings represents the SSL settings for a custom hostname.
+type CustomHostnameSSLSettings struct {
+ HTTP2 string `json:"http2,omitempty"`
+ TLS13 string `json:"tls_1_3,omitempty"`
+ MinTLSVersion string `json:"min_tls_version,omitempty"`
+ Ciphers []string `json:"ciphers,omitempty"`
+}
+
+// CustomHostnameSSL represents the SSL section in a given custom hostname.
+type CustomHostnameSSL struct {
+ Status string `json:"status,omitempty"`
+ Method string `json:"method,omitempty"`
+ Type string `json:"type,omitempty"`
+ CnameTarget string `json:"cname_target,omitempty"`
+ CnameName string `json:"cname,omitempty"`
+ Settings CustomHostnameSSLSettings `json:"settings,omitempty"`
+}
+
+// CustomMetadata defines custom metadata for the hostname. This requires logic to be implemented by Cloudflare to act on the data provided.
+type CustomMetadata map[string]interface{}
+
+// CustomHostname represents a custom hostname in a zone.
+type CustomHostname struct {
+ ID string `json:"id,omitempty"`
+ Hostname string `json:"hostname,omitempty"`
+ CustomOriginServer string `json:"custom_origin_server,omitempty"`
+ SSL CustomHostnameSSL `json:"ssl,omitempty"`
+ CustomMetadata CustomMetadata `json:"custom_metadata,omitempty"`
+}
+
+// CustomHostnameResponse represents a response from the Custom Hostnames endpoints. 
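+//
+// CreateCustomHostname below returns this type directly; a sketch (the
+// SSL values shown are illustrative):
+//
+//   resp, err := api.CreateCustomHostname(zoneID, CustomHostname{
+//     Hostname: "shop.example.com",
+//     SSL:      CustomHostnameSSL{Method: "http", Type: "dv"},
+//   })
+//   if err == nil {
+//     fmt.Println(resp.Result.ID)
+//   }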
+type CustomHostnameResponse struct {
+ Result CustomHostname `json:"result"`
+ Response
+}
+
+// CustomHostnameListResponse represents a response from the Custom Hostnames endpoints.
+type CustomHostnameListResponse struct {
+ Result []CustomHostname `json:"result"`
+ Response
+ ResultInfo `json:"result_info"`
+}
+
+// UpdateCustomHostnameSSL modifies SSL configuration for the given custom
+// hostname in the given zone.
+//
+// API reference: https://api.cloudflare.com/#custom-hostname-for-a-zone-update-custom-hostname-configuration
+func (api *API) UpdateCustomHostnameSSL(zoneID string, customHostnameID string, ssl CustomHostnameSSL) (CustomHostname, error) {
+ return CustomHostname{}, errors.New("Not implemented")
+}
+
+// DeleteCustomHostname deletes a custom hostname (and any issued SSL
+// certificates).
+//
+// API reference: https://api.cloudflare.com/#custom-hostname-for-a-zone-delete-a-custom-hostname-and-any-issued-ssl-certificates-
+func (api *API) DeleteCustomHostname(zoneID string, customHostnameID string) error {
+ uri := "/zones/" + zoneID + "/custom_hostnames/" + customHostnameID
+ res, err := api.makeRequest("DELETE", uri, nil)
+ if err != nil {
+ return errors.Wrap(err, errMakeRequestError)
+ }
+
+ var response *CustomHostnameResponse
+ err = json.Unmarshal(res, &response)
+ if err != nil {
+ return errors.Wrap(err, errUnmarshalError)
+ }
+
+ return nil
+}
+
+// CreateCustomHostname creates a new custom hostname and requests that an SSL certificate be issued for it.
+//
+// API reference: https://api.cloudflare.com/#custom-hostname-for-a-zone-create-custom-hostname
+func (api *API) CreateCustomHostname(zoneID string, ch CustomHostname) (*CustomHostnameResponse, error) {
+ uri := "/zones/" + zoneID + "/custom_hostnames"
+ res, err := api.makeRequest("POST", uri, ch)
+ if err != nil {
+ return nil, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var response *CustomHostnameResponse
+ err = json.Unmarshal(res, &response)
+ if err != nil {
+ return nil, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return response, nil
+}
+
+// CustomHostnames fetches custom hostnames for the given zone, applying
+// filter.Hostname if not empty and scoping the result to the page'th batch
+// of 50 items.
+//
+// The returned ResultInfo can be used to implement pagination.
+//
+// API reference: https://api.cloudflare.com/#custom-hostname-for-a-zone-list-custom-hostnames
+func (api *API) CustomHostnames(zoneID string, page int, filter CustomHostname) ([]CustomHostname, ResultInfo, error) {
+ v := url.Values{}
+ v.Set("per_page", "50")
+ v.Set("page", strconv.Itoa(page))
+ if filter.Hostname != "" {
+ v.Set("hostname", filter.Hostname)
+ }
+ query := "?" + v.Encode()
+
+ uri := "/zones/" + zoneID + "/custom_hostnames" + query
+ res, err := api.makeRequest("GET", uri, nil)
+ if err != nil {
+ return []CustomHostname{}, ResultInfo{}, errors.Wrap(err, errMakeRequestError)
+ }
+ var customHostnameListResponse CustomHostnameListResponse
+ err = json.Unmarshal(res, &customHostnameListResponse)
+ if err != nil {
+ return []CustomHostname{}, ResultInfo{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return customHostnameListResponse.Result, customHostnameListResponse.ResultInfo, nil
+}
+
+// CustomHostname inspects the given custom hostname in the given zone. 
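+//
+// A lookup sketch combining the helpers in this file (the status value
+// shown is illustrative):
+//
+//   id, err := api.CustomHostnameIDByName(zoneID, "shop.example.com")
+//   if err == nil {
+//     ch, _ := api.CustomHostname(zoneID, id)
+//     fmt.Println(ch.SSL.Status) // e.g. "pending_validation"
+//   }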
+// +// API reference: https://api.cloudflare.com/#custom-hostname-for-a-zone-custom-hostname-configuration-details +func (api *API) CustomHostname(zoneID string, customHostnameID string) (CustomHostname, error) { + uri := "/zones/" + zoneID + "/custom_hostnames/" + customHostnameID + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return CustomHostname{}, errors.Wrap(err, errMakeRequestError) + } + + var response CustomHostnameResponse + err = json.Unmarshal(res, &response) + if err != nil { + return CustomHostname{}, errors.Wrap(err, errUnmarshalError) + } + + return response.Result, nil +} + +// CustomHostnameIDByName retrieves the ID for the given hostname in the given zone. +func (api *API) CustomHostnameIDByName(zoneID string, hostname string) (string, error) { + customHostnames, _, err := api.CustomHostnames(zoneID, 1, CustomHostname{Hostname: hostname}) + if err != nil { + return "", errors.Wrap(err, "CustomHostnames command failed") + } + for _, ch := range customHostnames { + if ch.Hostname == hostname { + return ch.ID, nil + } + } + return "", errors.New("CustomHostname could not be found") +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/custom_pages.go b/vendor/github.com/cloudflare/cloudflare-go/custom_pages.go new file mode 100644 index 000000000..d96788fc8 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/custom_pages.go @@ -0,0 +1,176 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/pkg/errors" +) + +// CustomPage represents a custom page configuration. +type CustomPage struct { + CreatedOn time.Time `json:"created_on"` + ModifiedOn time.Time `json:"modified_on"` + URL interface{} `json:"url"` + State string `json:"state"` + RequiredTokens []string `json:"required_tokens"` + PreviewTarget string `json:"preview_target"` + Description string `json:"description"` + ID string `json:"id"` +} + +// CustomPageResponse represents the response from the custom pages endpoint. +type CustomPageResponse struct { + Response + Result []CustomPage `json:"result"` +} + +// CustomPageDetailResponse represents the response from the custom page endpoint. +type CustomPageDetailResponse struct { + Response + Result CustomPage `json:"result"` +} + +// CustomPageOptions is used to determine whether or not the operation +// should take place on an account or zone level based on which is +// provided to the function. +// +// A non-empty value denotes desired use. +type CustomPageOptions struct { + AccountID string + ZoneID string +} + +// CustomPageParameters is used to update a particular custom page with +// the values provided. +type CustomPageParameters struct { + URL interface{} `json:"url"` + State string `json:"state"` +} + +// CustomPages lists custom pages for a zone or account. +// +// Zone API reference: https://api.cloudflare.com/#custom-pages-for-a-zone-list-available-custom-pages +// Account API reference: https://api.cloudflare.com/#custom-pages-account--list-custom-pages +func (api *API) CustomPages(options *CustomPageOptions) ([]CustomPage, error) { + var ( + pageType, identifier string + ) + + if options.AccountID == "" && options.ZoneID == "" { + return nil, errors.New("either account ID or zone ID must be provided") + } + + if options.AccountID != "" && options.ZoneID != "" { + return nil, errors.New("account ID and zone ID are mutually exclusive") + } + + // Should the account ID be defined, treat this as an account level operation. 
+ if options.AccountID != "" { + pageType = "accounts" + identifier = options.AccountID + } else { + pageType = "zones" + identifier = options.ZoneID + } + + uri := fmt.Sprintf("/%s/%s/custom_pages", pageType, identifier) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + var customPageResponse CustomPageResponse + err = json.Unmarshal(res, &customPageResponse) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return customPageResponse.Result, nil +} + +// CustomPage lists a single custom page based on the ID. +// +// Zone API reference: https://api.cloudflare.com/#custom-pages-for-a-zone-custom-page-details +// Account API reference: https://api.cloudflare.com/#custom-pages-account--custom-page-details +func (api *API) CustomPage(options *CustomPageOptions, customPageID string) (CustomPage, error) { + var ( + pageType, identifier string + ) + + if options.AccountID == "" && options.ZoneID == "" { + return CustomPage{}, errors.New("either account ID or zone ID must be provided") + } + + if options.AccountID != "" && options.ZoneID != "" { + return CustomPage{}, errors.New("account ID and zone ID are mutually exclusive") + } + + // Should the account ID be defined, treat this as an account level operation. + if options.AccountID != "" { + pageType = "accounts" + identifier = options.AccountID + } else { + pageType = "zones" + identifier = options.ZoneID + } + + uri := fmt.Sprintf("/%s/%s/custom_pages/%s", pageType, identifier, customPageID) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return CustomPage{}, errors.Wrap(err, errMakeRequestError) + } + + var customPageResponse CustomPageDetailResponse + err = json.Unmarshal(res, &customPageResponse) + if err != nil { + return CustomPage{}, errors.Wrap(err, errUnmarshalError) + } + + return customPageResponse.Result, nil +} + +// UpdateCustomPage updates a single custom page setting. +// +// Zone API reference: https://api.cloudflare.com/#custom-pages-for-a-zone-update-custom-page-url +// Account API reference: https://api.cloudflare.com/#custom-pages-account--update-custom-page +func (api *API) UpdateCustomPage(options *CustomPageOptions, customPageID string, pageParameters CustomPageParameters) (CustomPage, error) { + var ( + pageType, identifier string + ) + + if options.AccountID == "" && options.ZoneID == "" { + return CustomPage{}, errors.New("either account ID or zone ID must be provided") + } + + if options.AccountID != "" && options.ZoneID != "" { + return CustomPage{}, errors.New("account ID and zone ID are mutually exclusive") + } + + // Should the account ID be defined, treat this as an account level operation. 
+ if options.AccountID != "" { + pageType = "accounts" + identifier = options.AccountID + } else { + pageType = "zones" + identifier = options.ZoneID + } + + uri := fmt.Sprintf("/%s/%s/custom_pages/%s", pageType, identifier, customPageID) + + res, err := api.makeRequest("PUT", uri, pageParameters) + if err != nil { + return CustomPage{}, errors.Wrap(err, errMakeRequestError) + } + + var customPageResponse CustomPageDetailResponse + err = json.Unmarshal(res, &customPageResponse) + if err != nil { + return CustomPage{}, errors.Wrap(err, errUnmarshalError) + } + + return customPageResponse.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/dns.go b/vendor/github.com/cloudflare/cloudflare-go/dns.go new file mode 100644 index 000000000..6bcac2480 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/dns.go @@ -0,0 +1,174 @@ +package cloudflare + +import ( + "encoding/json" + "net/url" + "strconv" + "time" + + "github.com/pkg/errors" +) + +// DNSRecord represents a DNS record in a zone. +type DNSRecord struct { + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Name string `json:"name,omitempty"` + Content string `json:"content,omitempty"` + Proxiable bool `json:"proxiable,omitempty"` + Proxied bool `json:"proxied"` + TTL int `json:"ttl,omitempty"` + Locked bool `json:"locked,omitempty"` + ZoneID string `json:"zone_id,omitempty"` + ZoneName string `json:"zone_name,omitempty"` + CreatedOn time.Time `json:"created_on,omitempty"` + ModifiedOn time.Time `json:"modified_on,omitempty"` + Data interface{} `json:"data,omitempty"` // data returned by: SRV, LOC + Meta interface{} `json:"meta,omitempty"` + Priority int `json:"priority"` +} + +// DNSRecordResponse represents the response from the DNS endpoint. +type DNSRecordResponse struct { + Result DNSRecord `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// DNSListResponse represents the response from the list DNS records endpoint. +type DNSListResponse struct { + Result []DNSRecord `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// CreateDNSRecord creates a DNS record for the zone identifier. +// +// API reference: https://api.cloudflare.com/#dns-records-for-a-zone-create-dns-record +func (api *API) CreateDNSRecord(zoneID string, rr DNSRecord) (*DNSRecordResponse, error) { + uri := "/zones/" + zoneID + "/dns_records" + res, err := api.makeRequest("POST", uri, rr) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + var recordResp *DNSRecordResponse + err = json.Unmarshal(res, &recordResp) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return recordResp, nil +} + +// DNSRecords returns a slice of DNS records for the given zone identifier. +// +// This takes a DNSRecord to allow filtering of the results returned. +// +// API reference: https://api.cloudflare.com/#dns-records-for-a-zone-list-dns-records +func (api *API) DNSRecords(zoneID string, rr DNSRecord) ([]DNSRecord, error) { + // Construct a query string + v := url.Values{} + // Request as many records as possible per page - API max is 50 + v.Set("per_page", "50") + if rr.Name != "" { + v.Set("name", rr.Name) + } + if rr.Type != "" { + v.Set("type", rr.Type) + } + if rr.Content != "" { + v.Set("content", rr.Content) + } + + var query string + var records []DNSRecord + page := 1 + + // Loop over makeRequest until what we've fetched all records + for { + v.Set("page", strconv.Itoa(page)) + query = "?" 
+ v.Encode() + uri := "/zones/" + zoneID + "/dns_records" + query + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []DNSRecord{}, errors.Wrap(err, errMakeRequestError) + } + var r DNSListResponse + err = json.Unmarshal(res, &r) + if err != nil { + return []DNSRecord{}, errors.Wrap(err, errUnmarshalError) + } + records = append(records, r.Result...) + if r.ResultInfo.Page >= r.ResultInfo.TotalPages { + break + } + // Loop around and fetch the next page + page++ + } + return records, nil +} + +// DNSRecord returns a single DNS record for the given zone & record +// identifiers. +// +// API reference: https://api.cloudflare.com/#dns-records-for-a-zone-dns-record-details +func (api *API) DNSRecord(zoneID, recordID string) (DNSRecord, error) { + uri := "/zones/" + zoneID + "/dns_records/" + recordID + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return DNSRecord{}, errors.Wrap(err, errMakeRequestError) + } + var r DNSRecordResponse + err = json.Unmarshal(res, &r) + if err != nil { + return DNSRecord{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// UpdateDNSRecord updates a single DNS record for the given zone & record +// identifiers. +// +// API reference: https://api.cloudflare.com/#dns-records-for-a-zone-update-dns-record +func (api *API) UpdateDNSRecord(zoneID, recordID string, rr DNSRecord) error { + rec, err := api.DNSRecord(zoneID, recordID) + if err != nil { + return err + } + // Populate the record name from the existing one if the update didn't + // specify it. + if rr.Name == "" { + rr.Name = rec.Name + } + rr.Type = rec.Type + uri := "/zones/" + zoneID + "/dns_records/" + recordID + res, err := api.makeRequest("PATCH", uri, rr) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + var r DNSRecordResponse + err = json.Unmarshal(res, &r) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + return nil +} + +// DeleteDNSRecord deletes a single DNS record for the given zone & record +// identifiers. +// +// API reference: https://api.cloudflare.com/#dns-records-for-a-zone-delete-dns-record +func (api *API) DeleteDNSRecord(zoneID, recordID string) error { + uri := "/zones/" + zoneID + "/dns_records/" + recordID + res, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + var r DNSRecordResponse + err = json.Unmarshal(res, &r) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + return nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/duration.go b/vendor/github.com/cloudflare/cloudflare-go/duration.go new file mode 100644 index 000000000..ba2418acd --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/duration.go @@ -0,0 +1,40 @@ +package cloudflare + +import ( + "encoding/json" + "time" +) + +// Duration implements json.Marshaler and json.Unmarshaler for time.Duration +// using the fmt.Stringer interface of time.Duration and time.ParseDuration. +type Duration struct { + time.Duration +} + +// MarshalJSON encodes a Duration as a JSON string formatted using String. +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.Duration.String()) +} + +// UnmarshalJSON decodes a Duration from a JSON string parsed using time.ParseDuration. 
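
Reviewer aside: a minimal sketch of driving the DNS helpers above from client code. `CreateDNSRecord` and `DNSRecords` are the functions added in this hunk; the `cloudflare.New` constructor lives elsewhere in this library, and the zone ID and record values are placeholders.

```go
package main

import (
	"fmt"
	"log"
	"os"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	// Constructor defined elsewhere in this library, not in this hunk.
	api, err := cloudflare.New(os.Getenv("CF_API_KEY"), os.Getenv("CF_API_EMAIL"))
	if err != nil {
		log.Fatal(err)
	}
	zoneID := "0123456789abcdef0123456789abcdef" // hypothetical zone ID

	// Create an A record. On Cloudflare a TTL of 1 means "automatic".
	if _, err := api.CreateDNSRecord(zoneID, cloudflare.DNSRecord{
		Type:    "A",
		Name:    "www.example.com",
		Content: "203.0.113.10",
		TTL:     1,
	}); err != nil {
		log.Fatal(err)
	}

	// DNSRecords pages through results 50 at a time, as implemented above.
	records, err := api.DNSRecords(zoneID, cloudflare.DNSRecord{Type: "A"})
	if err != nil {
		log.Fatal(err)
	}
	for _, r := range records {
		fmt.Println(r.ID, r.Name, r.Content)
	}
}
```
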
+func (d *Duration) UnmarshalJSON(buf []byte) error { + var str string + + err := json.Unmarshal(buf, &str) + if err != nil { + return err + } + + dur, err := time.ParseDuration(str) + if err != nil { + return err + } + + d.Duration = dur + return nil +} + +var ( + _ = json.Marshaler((*Duration)(nil)) + _ = json.Unmarshaler((*Duration)(nil)) +) diff --git a/vendor/github.com/cloudflare/cloudflare-go/errors.go b/vendor/github.com/cloudflare/cloudflare-go/errors.go new file mode 100644 index 000000000..21c38b168 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/errors.go @@ -0,0 +1,50 @@ +package cloudflare + +// Error messages +const ( + errEmptyCredentials = "invalid credentials: key & email must not be empty" + errEmptyAPIToken = "invalid credentials: API Token must not be empty" + errMakeRequestError = "error from makeRequest" + errUnmarshalError = "error unmarshalling the JSON response" + errRequestNotSuccessful = "error reported by API" + errMissingAccountID = "account ID is empty and must be provided" +) + +var _ Error = &UserError{} + +// Error represents an error returned from this library. +type Error interface { + error + // Raised when user credentials or configuration is invalid. + User() bool + // Raised when a parsing error (e.g. JSON) occurs. + Parse() bool + // Raised when a network error occurs. + Network() bool + // Contains the most recent error. +} + +// UserError represents a user-generated error. +type UserError struct { + Err error +} + +// User is a user-caused error. +func (e *UserError) User() bool { + return true +} + +// Network error. +func (e *UserError) Network() bool { + return false +} + +// Parse error. +func (e *UserError) Parse() bool { + return true +} + +// Error wraps the underlying error. +func (e *UserError) Error() string { + return e.Err.Error() +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/filter.go b/vendor/github.com/cloudflare/cloudflare-go/filter.go new file mode 100644 index 000000000..cf3ef1c20 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/filter.go @@ -0,0 +1,241 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +// Filter holds the structure of the filter type. +type Filter struct { + ID string `json:"id,omitempty"` + Expression string `json:"expression"` + Paused bool `json:"paused"` + Description string `json:"description"` + + // Property is mentioned in documentation however isn't populated in + // any of the API requests. For now, let's just omit it unless it's + // provided. + Ref string `json:"ref,omitempty"` +} + +// FiltersDetailResponse is the API response that is returned +// for requesting all filters on a zone. +type FiltersDetailResponse struct { + Result []Filter `json:"result"` + ResultInfo `json:"result_info"` + Response +} + +// FilterDetailResponse is the API response that is returned +// for requesting a single filter on a zone. +type FilterDetailResponse struct { + Result Filter `json:"result"` + ResultInfo `json:"result_info"` + Response +} + +// FilterValidateExpression represents the JSON payload for checking +// an expression. +type FilterValidateExpression struct { + Expression string `json:"expression"` +} + +// FilterValidateExpressionResponse represents the API response for +// checking the expression. It conforms to the JSON API approach however +// we don't need all of the fields exposed. 
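
The `Duration` wrapper added above round-trips through encoding/json as the `time.Duration` string form. A small, self-contained sketch:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	d := cloudflare.Duration{Duration: 90 * time.Second}

	// MarshalJSON emits the time.Duration string form, here "1m30s".
	buf, err := json.Marshal(d)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(buf)) // "1m30s"

	// UnmarshalJSON parses the string back with time.ParseDuration.
	var out cloudflare.Duration
	if err := json.Unmarshal(buf, &out); err != nil {
		panic(err)
	}
	fmt.Println(out.Duration == d.Duration) // true
}
```
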
+type FilterValidateExpressionResponse struct { + Success bool `json:"success"` + Errors []FilterValidationExpressionMessage `json:"errors"` +} + +// FilterValidationExpressionMessage represents the API error message. +type FilterValidationExpressionMessage struct { + Message string `json:"message"` +} + +// Filter returns a single filter in a zone based on the filter ID. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-filters/get/#get-by-filter-id +func (api *API) Filter(zoneID, filterID string) (Filter, error) { + uri := fmt.Sprintf("/zones/%s/filters/%s", zoneID, filterID) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return Filter{}, errors.Wrap(err, errMakeRequestError) + } + + var filterResponse FilterDetailResponse + err = json.Unmarshal(res, &filterResponse) + if err != nil { + return Filter{}, errors.Wrap(err, errUnmarshalError) + } + + return filterResponse.Result, nil +} + +// Filters returns all filters for a zone. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-filters/get/#get-all-filters +func (api *API) Filters(zoneID string, pageOpts PaginationOptions) ([]Filter, error) { + uri := "/zones/" + zoneID + "/filters" + v := url.Values{} + + if pageOpts.PerPage > 0 { + v.Set("per_page", strconv.Itoa(pageOpts.PerPage)) + } + + if pageOpts.Page > 0 { + v.Set("page", strconv.Itoa(pageOpts.Page)) + } + + if len(v) > 0 { + uri = uri + "?" + v.Encode() + } + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []Filter{}, errors.Wrap(err, errMakeRequestError) + } + + var filtersResponse FiltersDetailResponse + err = json.Unmarshal(res, &filtersResponse) + if err != nil { + return []Filter{}, errors.Wrap(err, errUnmarshalError) + } + + return filtersResponse.Result, nil +} + +// CreateFilters creates new filters. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-filters/post/ +func (api *API) CreateFilters(zoneID string, filters []Filter) ([]Filter, error) { + uri := "/zones/" + zoneID + "/filters" + + res, err := api.makeRequest("POST", uri, filters) + if err != nil { + return []Filter{}, errors.Wrap(err, errMakeRequestError) + } + + var filtersResponse FiltersDetailResponse + err = json.Unmarshal(res, &filtersResponse) + if err != nil { + return []Filter{}, errors.Wrap(err, errUnmarshalError) + } + + return filtersResponse.Result, nil +} + +// UpdateFilter updates a single filter. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-filters/put/#update-a-single-filter +func (api *API) UpdateFilter(zoneID string, filter Filter) (Filter, error) { + if filter.ID == "" { + return Filter{}, errors.Errorf("filter ID cannot be empty") + } + + uri := fmt.Sprintf("/zones/%s/filters/%s", zoneID, filter.ID) + + res, err := api.makeRequest("PUT", uri, filter) + if err != nil { + return Filter{}, errors.Wrap(err, errMakeRequestError) + } + + var filterResponse FilterDetailResponse + err = json.Unmarshal(res, &filterResponse) + if err != nil { + return Filter{}, errors.Wrap(err, errUnmarshalError) + } + + return filterResponse.Result, nil +} + +// UpdateFilters updates many filters at once. 
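
A sketch of the create-then-update flow against `CreateFilters` and `UpdateFilter` above. It assumes an already-constructed `*cloudflare.API`; the expression and description are illustrative values only.

```go
package example

import (
	"fmt"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

// PauseScannerFilter creates a filter and immediately pauses it.
func PauseScannerFilter(api *cloudflare.API, zoneID string) error {
	created, err := api.CreateFilters(zoneID, []cloudflare.Filter{{
		Expression:  `ip.src eq 198.51.100.4`,
		Description: "single scanner host",
	}})
	if err != nil {
		return err
	}
	if len(created) == 0 {
		return fmt.Errorf("no filter returned")
	}

	// UpdateFilter requires the ID assigned by the API on creation.
	f := created[0]
	f.Paused = true
	if _, err := api.UpdateFilter(zoneID, f); err != nil {
		return err
	}
	fmt.Println("paused filter", f.ID)
	return nil
}
```
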
+// +// API reference: https://developers.cloudflare.com/firewall/api/cf-filters/put/#update-multiple-filters +func (api *API) UpdateFilters(zoneID string, filters []Filter) ([]Filter, error) { + for _, filter := range filters { + if filter.ID == "" { + return []Filter{}, errors.Errorf("filter ID cannot be empty") + } + } + + uri := "/zones/" + zoneID + "/filters" + + res, err := api.makeRequest("PUT", uri, filters) + if err != nil { + return []Filter{}, errors.Wrap(err, errMakeRequestError) + } + + var filtersResponse FiltersDetailResponse + err = json.Unmarshal(res, &filtersResponse) + if err != nil { + return []Filter{}, errors.Wrap(err, errUnmarshalError) + } + + return filtersResponse.Result, nil +} + +// DeleteFilter deletes a single filter. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-filters/delete/#delete-a-single-filter +func (api *API) DeleteFilter(zoneID, filterID string) error { + if filterID == "" { + return errors.Errorf("filter ID cannot be empty") + } + + uri := fmt.Sprintf("/zones/%s/filters/%s", zoneID, filterID) + + _, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + + return nil +} + +// DeleteFilters deletes multiple filters. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-filters/delete/#delete-multiple-filters +func (api *API) DeleteFilters(zoneID string, filterIDs []string) error { + ids := strings.Join(filterIDs, ",") + uri := fmt.Sprintf("/zones/%s/filters?id=%s", zoneID, ids) + + _, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + + return nil +} + +// ValidateFilterExpression checks correctness of a filter expression. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-filters/validation/ +func (api *API) ValidateFilterExpression(expression string) error { + uri := fmt.Sprintf("/filters/validate-expr") + expressionPayload := FilterValidateExpression{Expression: expression} + + _, err := api.makeRequest("POST", uri, expressionPayload) + if err != nil { + var filterValidationResponse FilterValidateExpressionResponse + + jsonErr := json.Unmarshal([]byte(err.Error()), &filterValidationResponse) + if jsonErr != nil { + return errors.Wrap(jsonErr, errUnmarshalError) + } + + if filterValidationResponse.Success != true { + // Unsure why but the API returns `errors` as an array but it only + // ever shows the issue with one problem at a time ¯\_(ツ)_/¯ + return errors.Errorf(filterValidationResponse.Errors[0].Message) + } + } + + return nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/firewall.go b/vendor/github.com/cloudflare/cloudflare-go/firewall.go new file mode 100644 index 000000000..4b61a7ca5 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/firewall.go @@ -0,0 +1,280 @@ +package cloudflare + +import ( + "encoding/json" + "net/url" + "strconv" + "time" + + "github.com/pkg/errors" +) + +// AccessRule represents a firewall access rule. 
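
Note how `ValidateFilterExpression` above surfaces validation failures through the error path of `makeRequest` and reports only the API's first error message. A usage sketch, again assuming an already-constructed client; the expression is only an example:

```go
package example

import (
	"log"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

// CheckExpression validates a firewall expression before creating any
// filter that uses it.
func CheckExpression(api *cloudflare.API) {
	expr := `http.request.uri.path contains "/admin"`
	if err := api.ValidateFilterExpression(expr); err != nil {
		log.Printf("expression rejected: %v", err)
		return
	}
	log.Print("expression is valid")
}
```
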
+type AccessRule struct { + ID string `json:"id,omitempty"` + Notes string `json:"notes,omitempty"` + AllowedModes []string `json:"allowed_modes,omitempty"` + Mode string `json:"mode,omitempty"` + Configuration AccessRuleConfiguration `json:"configuration,omitempty"` + Scope AccessRuleScope `json:"scope,omitempty"` + CreatedOn time.Time `json:"created_on,omitempty"` + ModifiedOn time.Time `json:"modified_on,omitempty"` +} + +// AccessRuleConfiguration represents the configuration of a firewall +// access rule. +type AccessRuleConfiguration struct { + Target string `json:"target,omitempty"` + Value string `json:"value,omitempty"` +} + +// AccessRuleScope represents the scope of a firewall access rule. +type AccessRuleScope struct { + ID string `json:"id,omitempty"` + Email string `json:"email,omitempty"` + Name string `json:"name,omitempty"` + Type string `json:"type,omitempty"` +} + +// AccessRuleResponse represents the response from the firewall access +// rule endpoint. +type AccessRuleResponse struct { + Result AccessRule `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// AccessRuleListResponse represents the response from the list access rules +// endpoint. +type AccessRuleListResponse struct { + Result []AccessRule `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// ListUserAccessRules returns a slice of access rules for the logged-in user. +// +// This takes an AccessRule to allow filtering of the results returned. +// +// API reference: https://api.cloudflare.com/#user-level-firewall-access-rule-list-access-rules +func (api *API) ListUserAccessRules(accessRule AccessRule, page int) (*AccessRuleListResponse, error) { + return api.listAccessRules("/user", accessRule, page) +} + +// CreateUserAccessRule creates a firewall access rule for the logged-in user. +// +// API reference: https://api.cloudflare.com/#user-level-firewall-access-rule-create-access-rule +func (api *API) CreateUserAccessRule(accessRule AccessRule) (*AccessRuleResponse, error) { + return api.createAccessRule("/user", accessRule) +} + +// UserAccessRule returns the details of a user's account access rule. +// +// API reference: https://api.cloudflare.com/#user-level-firewall-access-rule-list-access-rules +func (api *API) UserAccessRule(accessRuleID string) (*AccessRuleResponse, error) { + return api.retrieveAccessRule("/user", accessRuleID) +} + +// UpdateUserAccessRule updates a single access rule for the logged-in user & +// given access rule identifier. +// +// API reference: https://api.cloudflare.com/#user-level-firewall-access-rule-update-access-rule +func (api *API) UpdateUserAccessRule(accessRuleID string, accessRule AccessRule) (*AccessRuleResponse, error) { + return api.updateAccessRule("/user", accessRuleID, accessRule) +} + +// DeleteUserAccessRule deletes a single access rule for the logged-in user and +// access rule identifiers. +// +// API reference: https://api.cloudflare.com/#user-level-firewall-access-rule-update-access-rule +func (api *API) DeleteUserAccessRule(accessRuleID string) (*AccessRuleResponse, error) { + return api.deleteAccessRule("/user", accessRuleID) +} + +// ListZoneAccessRules returns a slice of access rules for the given zone +// identifier. +// +// This takes an AccessRule to allow filtering of the results returned. 
+// +// API reference: https://api.cloudflare.com/#firewall-access-rule-for-a-zone-list-access-rules +func (api *API) ListZoneAccessRules(zoneID string, accessRule AccessRule, page int) (*AccessRuleListResponse, error) { + return api.listAccessRules("/zones/"+zoneID, accessRule, page) +} + +// CreateZoneAccessRule creates a firewall access rule for the given zone +// identifier. +// +// API reference: https://api.cloudflare.com/#firewall-access-rule-for-a-zone-create-access-rule +func (api *API) CreateZoneAccessRule(zoneID string, accessRule AccessRule) (*AccessRuleResponse, error) { + return api.createAccessRule("/zones/"+zoneID, accessRule) +} + +// ZoneAccessRule returns the details of a zone's access rule. +// +// API reference: https://api.cloudflare.com/#firewall-access-rule-for-a-zone-list-access-rules +func (api *API) ZoneAccessRule(zoneID string, accessRuleID string) (*AccessRuleResponse, error) { + return api.retrieveAccessRule("/zones/"+zoneID, accessRuleID) +} + +// UpdateZoneAccessRule updates a single access rule for the given zone & +// access rule identifiers. +// +// API reference: https://api.cloudflare.com/#firewall-access-rule-for-a-zone-update-access-rule +func (api *API) UpdateZoneAccessRule(zoneID, accessRuleID string, accessRule AccessRule) (*AccessRuleResponse, error) { + return api.updateAccessRule("/zones/"+zoneID, accessRuleID, accessRule) +} + +// DeleteZoneAccessRule deletes a single access rule for the given zone and +// access rule identifiers. +// +// API reference: https://api.cloudflare.com/#firewall-access-rule-for-a-zone-delete-access-rule +func (api *API) DeleteZoneAccessRule(zoneID, accessRuleID string) (*AccessRuleResponse, error) { + return api.deleteAccessRule("/zones/"+zoneID, accessRuleID) +} + +// ListAccountAccessRules returns a slice of access rules for the given +// account identifier. +// +// This takes an AccessRule to allow filtering of the results returned. +// +// API reference: https://api.cloudflare.com/#account-level-firewall-access-rule-list-access-rules +func (api *API) ListAccountAccessRules(accountID string, accessRule AccessRule, page int) (*AccessRuleListResponse, error) { + return api.listAccessRules("/accounts/"+accountID, accessRule, page) +} + +// CreateAccountAccessRule creates a firewall access rule for the given +// account identifier. +// +// API reference: https://api.cloudflare.com/#account-level-firewall-access-rule-create-access-rule +func (api *API) CreateAccountAccessRule(accountID string, accessRule AccessRule) (*AccessRuleResponse, error) { + return api.createAccessRule("/accounts/"+accountID, accessRule) +} + +// AccountAccessRule returns the details of an account's access rule. +// +// API reference: https://api.cloudflare.com/#account-level-firewall-access-rule-access-rule-details +func (api *API) AccountAccessRule(accountID string, accessRuleID string) (*AccessRuleResponse, error) { + return api.retrieveAccessRule("/accounts/"+accountID, accessRuleID) +} + +// UpdateAccountAccessRule updates a single access rule for the given +// account & access rule identifiers. +// +// API reference: https://api.cloudflare.com/#account-level-firewall-access-rule-update-access-rule +func (api *API) UpdateAccountAccessRule(accountID, accessRuleID string, accessRule AccessRule) (*AccessRuleResponse, error) { + return api.updateAccessRule("/accounts/"+accountID, accessRuleID, accessRule) +} + +// DeleteAccountAccessRule deletes a single access rule for the given +// account and access rule identifiers. 
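
All of the user-, zone-, and account-level wrappers above funnel into the shared `listAccessRules`/`createAccessRule`/`retrieveAccessRule`/`updateAccessRule`/`deleteAccessRule` helpers that follow; only the URL prefix differs. A sketch of the zone-level variant, with placeholder IDs and notes:

```go
package example

import (
	"fmt"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

// ChallengeIP creates a zone-level access rule that challenges a single IP,
// then lists the zone's "challenge" rules (first page).
func ChallengeIP(api *cloudflare.API, zoneID string) error {
	created, err := api.CreateZoneAccessRule(zoneID, cloudflare.AccessRule{
		Mode:  "challenge",
		Notes: "suspicious host", // hypothetical note
		Configuration: cloudflare.AccessRuleConfiguration{
			Target: "ip",
			Value:  "198.51.100.4",
		},
	})
	if err != nil {
		return err
	}
	fmt.Println("created rule", created.Result.ID)

	rules, err := api.ListZoneAccessRules(zoneID, cloudflare.AccessRule{Mode: "challenge"}, 1)
	if err != nil {
		return err
	}
	fmt.Println(len(rules.Result), "challenge rules on page 1")
	return nil
}
```
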
+// +// API reference: https://api.cloudflare.com/#account-level-firewall-access-rule-delete-access-rule +func (api *API) DeleteAccountAccessRule(accountID, accessRuleID string) (*AccessRuleResponse, error) { + return api.deleteAccessRule("/accounts/"+accountID, accessRuleID) +} + +func (api *API) listAccessRules(prefix string, accessRule AccessRule, page int) (*AccessRuleListResponse, error) { + // Construct a query string + v := url.Values{} + if page <= 0 { + page = 1 + } + v.Set("page", strconv.Itoa(page)) + // Request as many rules as possible per page - API max is 100 + v.Set("per_page", "100") + if accessRule.Notes != "" { + v.Set("notes", accessRule.Notes) + } + if accessRule.Mode != "" { + v.Set("mode", accessRule.Mode) + } + if accessRule.Scope.Type != "" { + v.Set("scope_type", accessRule.Scope.Type) + } + if accessRule.Configuration.Value != "" { + v.Set("configuration_value", accessRule.Configuration.Value) + } + if accessRule.Configuration.Target != "" { + v.Set("configuration_target", accessRule.Configuration.Target) + } + v.Set("page", strconv.Itoa(page)) + query := "?" + v.Encode() + + uri := prefix + "/firewall/access_rules/rules" + query + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &AccessRuleListResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return response, nil +} + +func (api *API) createAccessRule(prefix string, accessRule AccessRule) (*AccessRuleResponse, error) { + uri := prefix + "/firewall/access_rules/rules" + res, err := api.makeRequest("POST", uri, accessRule) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &AccessRuleResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response, nil +} + +func (api *API) retrieveAccessRule(prefix, accessRuleID string) (*AccessRuleResponse, error) { + uri := prefix + "/firewall/access_rules/rules/" + accessRuleID + + res, err := api.makeRequest("GET", uri, nil) + + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &AccessRuleResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response, nil +} + +func (api *API) updateAccessRule(prefix, accessRuleID string, accessRule AccessRule) (*AccessRuleResponse, error) { + uri := prefix + "/firewall/access_rules/rules/" + accessRuleID + res, err := api.makeRequest("PATCH", uri, accessRule) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &AccessRuleResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return response, nil +} + +func (api *API) deleteAccessRule(prefix, accessRuleID string) (*AccessRuleResponse, error) { + uri := prefix + "/firewall/access_rules/rules/" + accessRuleID + res, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &AccessRuleResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/firewall_rules.go b/vendor/github.com/cloudflare/cloudflare-go/firewall_rules.go new file mode 100644 index 000000000..7a6ce5c70 --- /dev/null +++ 
b/vendor/github.com/cloudflare/cloudflare-go/firewall_rules.go @@ -0,0 +1,196 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + "github.com/pkg/errors" +) + +// FirewallRule is the struct of the firewall rule. +type FirewallRule struct { + ID string `json:"id,omitempty"` + Paused bool `json:"paused"` + Description string `json:"description"` + Action string `json:"action"` + Priority interface{} `json:"priority"` + Filter Filter `json:"filter"` + CreatedOn time.Time `json:"created_on,omitempty"` + ModifiedOn time.Time `json:"modified_on,omitempty"` +} + +// FirewallRulesDetailResponse is the API response for the firewall +// rules. +type FirewallRulesDetailResponse struct { + Result []FirewallRule `json:"result"` + ResultInfo `json:"result_info"` + Response +} + +// FirewallRuleResponse is the API response that is returned +// for requesting a single firewall rule on a zone. +type FirewallRuleResponse struct { + Result FirewallRule `json:"result"` + ResultInfo `json:"result_info"` + Response +} + +// FirewallRules returns all firewall rules. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-firewall-rules/get/#get-all-rules +func (api *API) FirewallRules(zoneID string, pageOpts PaginationOptions) ([]FirewallRule, error) { + uri := fmt.Sprintf("/zones/%s/firewall/rules", zoneID) + v := url.Values{} + + if pageOpts.PerPage > 0 { + v.Set("per_page", strconv.Itoa(pageOpts.PerPage)) + } + + if pageOpts.Page > 0 { + v.Set("page", strconv.Itoa(pageOpts.Page)) + } + + if len(v) > 0 { + uri = uri + "?" + v.Encode() + } + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []FirewallRule{}, errors.Wrap(err, errMakeRequestError) + } + + var firewallDetailResponse FirewallRulesDetailResponse + err = json.Unmarshal(res, &firewallDetailResponse) + if err != nil { + return []FirewallRule{}, errors.Wrap(err, errUnmarshalError) + } + + return firewallDetailResponse.Result, nil +} + +// FirewallRule returns a single firewall rule based on the ID. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-firewall-rules/get/#get-by-rule-id +func (api *API) FirewallRule(zoneID, firewallRuleID string) (FirewallRule, error) { + uri := fmt.Sprintf("/zones/%s/firewall/rules/%s", zoneID, firewallRuleID) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return FirewallRule{}, errors.Wrap(err, errMakeRequestError) + } + + var firewallRuleResponse FirewallRuleResponse + err = json.Unmarshal(res, &firewallRuleResponse) + if err != nil { + return FirewallRule{}, errors.Wrap(err, errUnmarshalError) + } + + return firewallRuleResponse.Result, nil +} + +// CreateFirewallRules creates new firewall rules. +// +// API reference: https://developers.cloudflare.com/firewall/api/cf-firewall-rules/post/ +func (api *API) CreateFirewallRules(zoneID string, firewallRules []FirewallRule) ([]FirewallRule, error) { + uri := fmt.Sprintf("/zones/%s/firewall/rules", zoneID) + + res, err := api.makeRequest("POST", uri, firewallRules) + if err != nil { + return []FirewallRule{}, errors.Wrap(err, errMakeRequestError) + } + + var firewallRulesDetailResponse FirewallRulesDetailResponse + err = json.Unmarshal(res, &firewallRulesDetailResponse) + if err != nil { + return []FirewallRule{}, errors.Wrap(err, errUnmarshalError) + } + + return firewallRulesDetailResponse.Result, nil +} + +// UpdateFirewallRule updates a single firewall rule. 
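
Firewall rules pair an action with a `Filter`. A sketch of `CreateFirewallRules` with an inline filter expression; whether the endpoint creates the embedded filter implicitly is an assumption here, and passing a `Filter` carrying only an existing ID is the conservative alternative. Expression and description are illustrative.

```go
package example

import (
	"fmt"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

// BlockAdminPaths creates a single "block" rule from an inline filter
// expression (assumed to be created server-side along with the rule).
func BlockAdminPaths(api *cloudflare.API, zoneID string) error {
	rules, err := api.CreateFirewallRules(zoneID, []cloudflare.FirewallRule{{
		Action:      "block",
		Description: "block admin paths",
		Filter: cloudflare.Filter{
			Expression: `http.request.uri.path contains "/admin"`,
		},
	}})
	if err != nil {
		return err
	}
	for _, r := range rules {
		fmt.Println("created rule", r.ID, "with filter", r.Filter.ID)
	}
	return nil
}
```
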
+//
+// API reference: https://developers.cloudflare.com/firewall/api/cf-firewall-rules/put/#update-a-single-rule
+func (api *API) UpdateFirewallRule(zoneID string, firewallRule FirewallRule) (FirewallRule, error) {
+	if firewallRule.ID == "" {
+		return FirewallRule{}, errors.Errorf("firewall rule ID cannot be empty")
+	}
+
+	uri := fmt.Sprintf("/zones/%s/firewall/rules/%s", zoneID, firewallRule.ID)
+
+	res, err := api.makeRequest("PUT", uri, firewallRule)
+	if err != nil {
+		return FirewallRule{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var firewallRuleResponse FirewallRuleResponse
+	err = json.Unmarshal(res, &firewallRuleResponse)
+	if err != nil {
+		return FirewallRule{}, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return firewallRuleResponse.Result, nil
+}
+
+// UpdateFirewallRules updates multiple firewall rules at once.
+//
+// API reference: https://developers.cloudflare.com/firewall/api/cf-firewall-rules/put/#update-multiple-rules
+func (api *API) UpdateFirewallRules(zoneID string, firewallRules []FirewallRule) ([]FirewallRule, error) {
+	for _, firewallRule := range firewallRules {
+		if firewallRule.ID == "" {
+			return []FirewallRule{}, errors.Errorf("firewall rule ID cannot be empty")
+		}
+	}
+
+	uri := fmt.Sprintf("/zones/%s/firewall/rules", zoneID)
+
+	res, err := api.makeRequest("PUT", uri, firewallRules)
+	if err != nil {
+		return []FirewallRule{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var firewallRulesDetailResponse FirewallRulesDetailResponse
+	err = json.Unmarshal(res, &firewallRulesDetailResponse)
+	if err != nil {
+		return []FirewallRule{}, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return firewallRulesDetailResponse.Result, nil
+}
+
+// DeleteFirewallRule deletes a single firewall rule.
+//
+// API reference: https://developers.cloudflare.com/firewall/api/cf-firewall-rules/delete/#delete-a-single-rule
+func (api *API) DeleteFirewallRule(zoneID, firewallRuleID string) error {
+	if firewallRuleID == "" {
+		return errors.Errorf("firewall rule ID cannot be empty")
+	}
+
+	uri := fmt.Sprintf("/zones/%s/firewall/rules/%s", zoneID, firewallRuleID)
+
+	_, err := api.makeRequest("DELETE", uri, nil)
+	if err != nil {
+		return errors.Wrap(err, errMakeRequestError)
+	}
+
+	return nil
+}
+
+// DeleteFirewallRules deletes multiple firewall rules at once.
+// +// API reference: https://developers.cloudflare.com/firewall/api/cf-firewall-rules/delete/#delete-multiple-rules +func (api *API) DeleteFirewallRules(zoneID string, firewallRuleIDs []string) error { + ids := strings.Join(firewallRuleIDs, ",") + uri := fmt.Sprintf("/zones/%s/firewall/rules?id=%s", zoneID, ids) + + _, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + + return nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/go.mod b/vendor/github.com/cloudflare/cloudflare-go/go.mod new file mode 100644 index 000000000..77e922338 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/go.mod @@ -0,0 +1,13 @@ +module github.com/cloudflare/cloudflare-go + +go 1.11 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-runewidth v0.0.4 // indirect + github.com/olekukonko/tablewriter v0.0.1 + github.com/pkg/errors v0.8.1 + github.com/stretchr/testify v1.4.0 + github.com/urfave/cli v1.22.1 + golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 +) diff --git a/vendor/github.com/cloudflare/cloudflare-go/go.sum b/vendor/github.com/cloudflare/cloudflare-go/go.sum new file mode 100644 index 000000000..65391c2b1 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/go.sum @@ -0,0 +1,26 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y= +github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/olekukonko/tablewriter v0.0.1 h1:b3iUnf1v+ppJiOfNX4yxxqfWKMQPZR5yoh8urCTFX88= +github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/urfave/cli v1.21.0 h1:wYSSj06510qPIzGSua9ZqsncMmWE3Zr55KBERygyrxE= +github.com/urfave/cli v1.21.0/go.mod h1:lxDj6qX9Q6lWQxIrbrT0nwecwUtRnhVZAJjJZrVUZZQ= +github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/vendor/github.com/cloudflare/cloudflare-go/ips.go b/vendor/github.com/cloudflare/cloudflare-go/ips.go new file mode 100644 index 000000000..72b5fcfbc --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/ips.go @@ -0,0 +1,44 @@ +package cloudflare + +import ( + "encoding/json" + "io/ioutil" + "net/http" + + "github.com/pkg/errors" +) + +// IPRanges contains lists of IPv4 and IPv6 CIDRs. +type IPRanges struct { + IPv4CIDRs []string `json:"ipv4_cidrs"` + IPv6CIDRs []string `json:"ipv6_cidrs"` +} + +// IPsResponse is the API response containing a list of IPs. +type IPsResponse struct { + Response + Result IPRanges `json:"result"` +} + +// IPs gets a list of Cloudflare's IP ranges. +// +// This does not require logging in to the API. +// +// API reference: https://api.cloudflare.com/#cloudflare-ips +func IPs() (IPRanges, error) { + resp, err := http.Get(apiURL + "/ips") + if err != nil { + return IPRanges{}, errors.Wrap(err, "HTTP request failed") + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return IPRanges{}, errors.Wrap(err, "Response body could not be read") + } + var r IPsResponse + err = json.Unmarshal(body, &r) + if err != nil { + return IPRanges{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/keyless.go b/vendor/github.com/cloudflare/cloudflare-go/keyless.go new file mode 100644 index 000000000..c5cc83914 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/keyless.go @@ -0,0 +1,52 @@ +package cloudflare + +import "time" + +// KeylessSSL represents Keyless SSL configuration. +type KeylessSSL struct { + ID string `json:"id"` + Name string `json:"name"` + Host string `json:"host"` + Port int `json:"port"` + Status string `json:"success"` + Enabled bool `json:"enabled"` + Permissions []string `json:"permissions"` + CreatedOn time.Time `json:"created_on"` + ModifiedOn time.Time `json:"modifed_on"` +} + +// KeylessSSLResponse represents the response from the Keyless SSL endpoint. +type KeylessSSLResponse struct { + Response + Result []KeylessSSL `json:"result"` +} + +// CreateKeyless creates a new Keyless SSL configuration for the zone. +// +// API reference: https://api.cloudflare.com/#keyless-ssl-for-a-zone-create-a-keyless-ssl-configuration +func (api *API) CreateKeyless() { +} + +// ListKeyless lists Keyless SSL configurations for a zone. +// +// API reference: https://api.cloudflare.com/#keyless-ssl-for-a-zone-list-keyless-ssls +func (api *API) ListKeyless() { +} + +// Keyless provides the configuration for a given Keyless SSL identifier. +// +// API reference: https://api.cloudflare.com/#keyless-ssl-for-a-zone-keyless-ssl-details +func (api *API) Keyless() { +} + +// UpdateKeyless updates an existing Keyless SSL configuration. +// +// API reference: https://api.cloudflare.com/#keyless-ssl-for-a-zone-update-keyless-configuration +func (api *API) UpdateKeyless() { +} + +// DeleteKeyless deletes an existing Keyless SSL configuration. 
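
ips.go above is the one helper in this vendor drop that needs no credentials at all: `IPs` is a package-level function hitting the public endpoint. A runnable sketch:

```go
package main

import (
	"fmt"
	"log"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	// IPs hits the public /ips endpoint; no API object or key is needed.
	ranges, err := cloudflare.IPs()
	if err != nil {
		log.Fatal(err)
	}
	for _, cidr := range ranges.IPv4CIDRs {
		fmt.Println(cidr)
	}
	fmt.Println(len(ranges.IPv6CIDRs), "IPv6 ranges")
}
```
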
+//
+// API reference: https://api.cloudflare.com/#keyless-ssl-for-a-zone-delete-keyless-configuration
+func (api *API) DeleteKeyless() {
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/load_balancing.go b/vendor/github.com/cloudflare/cloudflare-go/load_balancing.go
new file mode 100644
index 000000000..8b2f89a65
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/load_balancing.go
@@ -0,0 +1,387 @@
+package cloudflare
+
+import (
+	"encoding/json"
+	"time"
+
+	"github.com/pkg/errors"
+)
+
+// LoadBalancerPool represents a load balancer pool's properties.
+type LoadBalancerPool struct {
+	ID                string               `json:"id,omitempty"`
+	CreatedOn         *time.Time           `json:"created_on,omitempty"`
+	ModifiedOn        *time.Time           `json:"modified_on,omitempty"`
+	Description       string               `json:"description"`
+	Name              string               `json:"name"`
+	Enabled           bool                 `json:"enabled"`
+	MinimumOrigins    int                  `json:"minimum_origins,omitempty"`
+	Monitor           string               `json:"monitor,omitempty"`
+	Origins           []LoadBalancerOrigin `json:"origins"`
+	NotificationEmail string               `json:"notification_email,omitempty"`
+
+	// CheckRegions defines the geographic region(s) from which to run health checks - e.g. "WNAM", "WEU", "SAF", "SAM".
+	// Providing a null/empty value means "all regions", which may not be available to all plan types.
+	CheckRegions []string `json:"check_regions"`
+}
+
+// LoadBalancerOrigin represents a Load Balancer origin's properties.
+type LoadBalancerOrigin struct {
+	Name    string  `json:"name"`
+	Address string  `json:"address"`
+	Enabled bool    `json:"enabled"`
+	Weight  float64 `json:"weight"`
+}
+
+// LoadBalancerMonitor represents a load balancer monitor's properties.
+type LoadBalancerMonitor struct {
+	ID              string              `json:"id,omitempty"`
+	CreatedOn       *time.Time          `json:"created_on,omitempty"`
+	ModifiedOn      *time.Time          `json:"modified_on,omitempty"`
+	Type            string              `json:"type"`
+	Description     string              `json:"description"`
+	Method          string              `json:"method"`
+	Path            string              `json:"path"`
+	Header          map[string][]string `json:"header"`
+	Timeout         int                 `json:"timeout"`
+	Retries         int                 `json:"retries"`
+	Interval        int                 `json:"interval"`
+	Port            uint16              `json:"port,omitempty"`
+	ExpectedBody    string              `json:"expected_body"`
+	ExpectedCodes   string              `json:"expected_codes"`
+	FollowRedirects bool                `json:"follow_redirects"`
+	AllowInsecure   bool                `json:"allow_insecure"`
+	ProbeZone       string              `json:"probe_zone"`
+}
+
+// LoadBalancer represents a load balancer's properties.
+type LoadBalancer struct {
+	ID             string              `json:"id,omitempty"`
+	CreatedOn      *time.Time          `json:"created_on,omitempty"`
+	ModifiedOn     *time.Time          `json:"modified_on,omitempty"`
+	Description    string              `json:"description"`
+	Name           string              `json:"name"`
+	TTL            int                 `json:"ttl,omitempty"`
+	FallbackPool   string              `json:"fallback_pool"`
+	DefaultPools   []string            `json:"default_pools"`
+	RegionPools    map[string][]string `json:"region_pools"`
+	PopPools       map[string][]string `json:"pop_pools"`
+	Proxied        bool                `json:"proxied"`
+	Enabled        *bool               `json:"enabled,omitempty"`
+	Persistence    string              `json:"session_affinity,omitempty"`
+	PersistenceTTL int                 `json:"session_affinity_ttl,omitempty"`
+
+	// SteeringPolicy controls pool selection logic.
+ // "off" select pools in DefaultPools order + // "geo" select pools based on RegionPools/PopPools + // "dynamic_latency" select pools based on RTT (requires health checks) + // "random" selects pools in a random order + // "" maps to "geo" if RegionPools or PopPools have entries otherwise "off" + SteeringPolicy string `json:"steering_policy,omitempty"` +} + +// LoadBalancerOriginHealth represents the health of the origin. +type LoadBalancerOriginHealth struct { + Healthy bool `json:"healthy,omitempty"` + RTT Duration `json:"rtt,omitempty"` + FailureReason string `json:"failure_reason,omitempty"` + ResponseCode int `json:"response_code,omitempty"` +} + +// LoadBalancerPoolPopHealth represents the health of the pool for given PoP. +type LoadBalancerPoolPopHealth struct { + Healthy bool `json:"healthy,omitempty"` + Origins []map[string]LoadBalancerOriginHealth `json:"origins,omitempty"` +} + +// LoadBalancerPoolHealth represents the healthchecks from different PoPs for a pool. +type LoadBalancerPoolHealth struct { + ID string `json:"pool_id,omitempty"` + PopHealth map[string]LoadBalancerPoolPopHealth `json:"pop_health,omitempty"` +} + +// loadBalancerPoolResponse represents the response from the load balancer pool endpoints. +type loadBalancerPoolResponse struct { + Response + Result LoadBalancerPool `json:"result"` +} + +// loadBalancerPoolListResponse represents the response from the List Pools endpoint. +type loadBalancerPoolListResponse struct { + Response + Result []LoadBalancerPool `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// loadBalancerMonitorResponse represents the response from the load balancer monitor endpoints. +type loadBalancerMonitorResponse struct { + Response + Result LoadBalancerMonitor `json:"result"` +} + +// loadBalancerMonitorListResponse represents the response from the List Monitors endpoint. +type loadBalancerMonitorListResponse struct { + Response + Result []LoadBalancerMonitor `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// loadBalancerResponse represents the response from the load balancer endpoints. +type loadBalancerResponse struct { + Response + Result LoadBalancer `json:"result"` +} + +// loadBalancerListResponse represents the response from the List Load Balancers endpoint. +type loadBalancerListResponse struct { + Response + Result []LoadBalancer `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// loadBalancerPoolHealthResponse represents the response from the Pool Health Details endpoint. +type loadBalancerPoolHealthResponse struct { + Response + Result LoadBalancerPoolHealth `json:"result"` +} + +// CreateLoadBalancerPool creates a new load balancer pool. +// +// API reference: https://api.cloudflare.com/#load-balancer-pools-create-a-pool +func (api *API) CreateLoadBalancerPool(pool LoadBalancerPool) (LoadBalancerPool, error) { + uri := api.userBaseURL("/user") + "/load_balancers/pools" + res, err := api.makeRequest("POST", uri, pool) + if err != nil { + return LoadBalancerPool{}, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerPoolResponse + if err := json.Unmarshal(res, &r); err != nil { + return LoadBalancerPool{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// ListLoadBalancerPools lists load balancer pools connected to an account. 
+// +// API reference: https://api.cloudflare.com/#load-balancer-pools-list-pools +func (api *API) ListLoadBalancerPools() ([]LoadBalancerPool, error) { + uri := api.userBaseURL("/user") + "/load_balancers/pools" + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerPoolListResponse + if err := json.Unmarshal(res, &r); err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// LoadBalancerPoolDetails returns the details for a load balancer pool. +// +// API reference: https://api.cloudflare.com/#load-balancer-pools-pool-details +func (api *API) LoadBalancerPoolDetails(poolID string) (LoadBalancerPool, error) { + uri := api.userBaseURL("/user") + "/load_balancers/pools/" + poolID + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return LoadBalancerPool{}, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerPoolResponse + if err := json.Unmarshal(res, &r); err != nil { + return LoadBalancerPool{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// DeleteLoadBalancerPool disables and deletes a load balancer pool. +// +// API reference: https://api.cloudflare.com/#load-balancer-pools-delete-a-pool +func (api *API) DeleteLoadBalancerPool(poolID string) error { + uri := api.userBaseURL("/user") + "/load_balancers/pools/" + poolID + if _, err := api.makeRequest("DELETE", uri, nil); err != nil { + return errors.Wrap(err, errMakeRequestError) + } + return nil +} + +// ModifyLoadBalancerPool modifies a configured load balancer pool. +// +// API reference: https://api.cloudflare.com/#load-balancer-pools-modify-a-pool +func (api *API) ModifyLoadBalancerPool(pool LoadBalancerPool) (LoadBalancerPool, error) { + uri := api.userBaseURL("/user") + "/load_balancers/pools/" + pool.ID + res, err := api.makeRequest("PUT", uri, pool) + if err != nil { + return LoadBalancerPool{}, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerPoolResponse + if err := json.Unmarshal(res, &r); err != nil { + return LoadBalancerPool{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// CreateLoadBalancerMonitor creates a new load balancer monitor. +// +// API reference: https://api.cloudflare.com/#load-balancer-monitors-create-a-monitor +func (api *API) CreateLoadBalancerMonitor(monitor LoadBalancerMonitor) (LoadBalancerMonitor, error) { + uri := api.userBaseURL("/user") + "/load_balancers/monitors" + res, err := api.makeRequest("POST", uri, monitor) + if err != nil { + return LoadBalancerMonitor{}, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerMonitorResponse + if err := json.Unmarshal(res, &r); err != nil { + return LoadBalancerMonitor{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// ListLoadBalancerMonitors lists load balancer monitors connected to an account. +// +// API reference: https://api.cloudflare.com/#load-balancer-monitors-list-monitors +func (api *API) ListLoadBalancerMonitors() ([]LoadBalancerMonitor, error) { + uri := api.userBaseURL("/user") + "/load_balancers/monitors" + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerMonitorListResponse + if err := json.Unmarshal(res, &r); err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// LoadBalancerMonitorDetails returns the details for a load balancer monitor. 
+// +// API reference: https://api.cloudflare.com/#load-balancer-monitors-monitor-details +func (api *API) LoadBalancerMonitorDetails(monitorID string) (LoadBalancerMonitor, error) { + uri := api.userBaseURL("/user") + "/load_balancers/monitors/" + monitorID + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return LoadBalancerMonitor{}, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerMonitorResponse + if err := json.Unmarshal(res, &r); err != nil { + return LoadBalancerMonitor{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// DeleteLoadBalancerMonitor disables and deletes a load balancer monitor. +// +// API reference: https://api.cloudflare.com/#load-balancer-monitors-delete-a-monitor +func (api *API) DeleteLoadBalancerMonitor(monitorID string) error { + uri := api.userBaseURL("/user") + "/load_balancers/monitors/" + monitorID + if _, err := api.makeRequest("DELETE", uri, nil); err != nil { + return errors.Wrap(err, errMakeRequestError) + } + return nil +} + +// ModifyLoadBalancerMonitor modifies a configured load balancer monitor. +// +// API reference: https://api.cloudflare.com/#load-balancer-monitors-modify-a-monitor +func (api *API) ModifyLoadBalancerMonitor(monitor LoadBalancerMonitor) (LoadBalancerMonitor, error) { + uri := api.userBaseURL("/user") + "/load_balancers/monitors/" + monitor.ID + res, err := api.makeRequest("PUT", uri, monitor) + if err != nil { + return LoadBalancerMonitor{}, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerMonitorResponse + if err := json.Unmarshal(res, &r); err != nil { + return LoadBalancerMonitor{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// CreateLoadBalancer creates a new load balancer. +// +// API reference: https://api.cloudflare.com/#load-balancers-create-a-load-balancer +func (api *API) CreateLoadBalancer(zoneID string, lb LoadBalancer) (LoadBalancer, error) { + uri := "/zones/" + zoneID + "/load_balancers" + res, err := api.makeRequest("POST", uri, lb) + if err != nil { + return LoadBalancer{}, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerResponse + if err := json.Unmarshal(res, &r); err != nil { + return LoadBalancer{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// ListLoadBalancers lists load balancers configured on a zone. +// +// API reference: https://api.cloudflare.com/#load-balancers-list-load-balancers +func (api *API) ListLoadBalancers(zoneID string) ([]LoadBalancer, error) { + uri := "/zones/" + zoneID + "/load_balancers" + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerListResponse + if err := json.Unmarshal(res, &r); err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// LoadBalancerDetails returns the details for a load balancer. +// +// API reference: https://api.cloudflare.com/#load-balancers-load-balancer-details +func (api *API) LoadBalancerDetails(zoneID, lbID string) (LoadBalancer, error) { + uri := "/zones/" + zoneID + "/load_balancers/" + lbID + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return LoadBalancer{}, errors.Wrap(err, errMakeRequestError) + } + var r loadBalancerResponse + if err := json.Unmarshal(res, &r); err != nil { + return LoadBalancer{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// DeleteLoadBalancer disables and deletes a load balancer. 
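
Unlike pools and monitors, the balancer itself is zone-scoped: `CreateLoadBalancer` ties pools to a hostname in a zone. A sketch, where the hostname is an example and `poolID` would come from a pool-creation call such as the one above:

```go
package example

import (
	cloudflare "github.com/cloudflare/cloudflare-go"
)

// AttachBalancer creates a proxied load balancer in front of a single pool,
// which also serves as the fallback pool.
func AttachBalancer(api *cloudflare.API, zoneID, poolID string) (cloudflare.LoadBalancer, error) {
	return api.CreateLoadBalancer(zoneID, cloudflare.LoadBalancer{
		Name:         "lb.example.com",
		DefaultPools: []string{poolID},
		FallbackPool: poolID,
		Proxied:      true,
		// An empty SteeringPolicy maps to "geo" or "off" as documented above.
	})
}
```
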
+//
+// API reference: https://api.cloudflare.com/#load-balancers-delete-a-load-balancer
+func (api *API) DeleteLoadBalancer(zoneID, lbID string) error {
+	uri := "/zones/" + zoneID + "/load_balancers/" + lbID
+	if _, err := api.makeRequest("DELETE", uri, nil); err != nil {
+		return errors.Wrap(err, errMakeRequestError)
+	}
+	return nil
+}
+
+// ModifyLoadBalancer modifies a configured load balancer.
+//
+// API reference: https://api.cloudflare.com/#load-balancers-modify-a-load-balancer
+func (api *API) ModifyLoadBalancer(zoneID string, lb LoadBalancer) (LoadBalancer, error) {
+	uri := "/zones/" + zoneID + "/load_balancers/" + lb.ID
+	res, err := api.makeRequest("PUT", uri, lb)
+	if err != nil {
+		return LoadBalancer{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r loadBalancerResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return LoadBalancer{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// PoolHealthDetails fetches the latest healthcheck details for a single pool.
+//
+// API reference: https://api.cloudflare.com/#load-balancer-pools-pool-health-details
+func (api *API) PoolHealthDetails(poolID string) (LoadBalancerPoolHealth, error) {
+	uri := api.userBaseURL("/user") + "/load_balancers/pools/" + poolID + "/health"
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return LoadBalancerPoolHealth{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r loadBalancerPoolHealthResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return LoadBalancerPoolHealth{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/lockdown.go b/vendor/github.com/cloudflare/cloudflare-go/lockdown.go
new file mode 100644
index 000000000..164129bc5
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/lockdown.go
@@ -0,0 +1,151 @@
+package cloudflare
+
+import (
+	"encoding/json"
+	"net/url"
+	"strconv"
+
+	"github.com/pkg/errors"
+)
+
+// ZoneLockdown represents a Zone Lockdown rule. A rule only permits access to
+// the provided URL pattern(s) from the given IP address(es) or subnet(s).
+type ZoneLockdown struct {
+	ID             string               `json:"id"`
+	Description    string               `json:"description"`
+	URLs           []string             `json:"urls"`
+	Configurations []ZoneLockdownConfig `json:"configurations"`
+	Paused         bool                 `json:"paused"`
+	Priority       int                  `json:"priority,omitempty"`
+}
+
+// ZoneLockdownConfig represents a Zone Lockdown config, which comprises
+// a Target ("ip" or "ip_range") and a Value (an IP address or IP+mask,
+// respectively).
+type ZoneLockdownConfig struct {
+	Target string `json:"target"`
+	Value  string `json:"value"`
+}
+
+// ZoneLockdownResponse represents a response from the Zone Lockdown endpoint.
+type ZoneLockdownResponse struct {
+	Result ZoneLockdown `json:"result"`
+	Response
+	ResultInfo `json:"result_info"`
+}
+
+// ZoneLockdownListResponse represents a response from the List Zone Lockdown
+// endpoint.
+type ZoneLockdownListResponse struct {
+	Result []ZoneLockdown `json:"result"`
+	Response
+	ResultInfo `json:"result_info"`
+}
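
Before the lockdown endpoints themselves, a usage sketch for review: a rule restricting one URL pattern to one address, via the `CreateZoneLockdown` call that follows. The pattern and IP are illustrative values only.

```go
package example

import (
	cloudflare "github.com/cloudflare/cloudflare-go"
)

// LockDownAdmin restricts an admin URL pattern to a single office IP.
func LockDownAdmin(api *cloudflare.API, zoneID string) (*cloudflare.ZoneLockdownResponse, error) {
	return api.CreateZoneLockdown(zoneID, cloudflare.ZoneLockdown{
		Description: "restrict admin panel",
		URLs:        []string{"example.com/admin*"},
		Configurations: []cloudflare.ZoneLockdownConfig{
			{Target: "ip", Value: "198.51.100.4"},
		},
	})
}
```
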
+
+// CreateZoneLockdown creates a Zone Lockdown rule for the given zone ID.
+//
+// API reference: https://api.cloudflare.com/#zone-ZoneLockdown-create-a-ZoneLockdown-rule
+func (api *API) CreateZoneLockdown(zoneID string, ld ZoneLockdown) (*ZoneLockdownResponse, error) {
+	uri := "/zones/" + zoneID + "/firewall/lockdowns"
+	res, err := api.makeRequest("POST", uri, ld)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+
+	response := &ZoneLockdownResponse{}
+	err = json.Unmarshal(res, &response)
+	if err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return response, nil
+}
+
+// UpdateZoneLockdown updates a Zone Lockdown rule (based on the ID) for the
+// given zone ID.
+//
+// API reference: https://api.cloudflare.com/#zone-ZoneLockdown-update-ZoneLockdown-rule
+func (api *API) UpdateZoneLockdown(zoneID string, id string, ld ZoneLockdown) (*ZoneLockdownResponse, error) {
+	uri := "/zones/" + zoneID + "/firewall/lockdowns/" + id
+	res, err := api.makeRequest("PUT", uri, ld)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+
+	response := &ZoneLockdownResponse{}
+	err = json.Unmarshal(res, &response)
+	if err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return response, nil
+}
+
+// DeleteZoneLockdown deletes a Zone Lockdown rule (based on the ID) for the
+// given zone ID.
+//
+// API reference: https://api.cloudflare.com/#zone-ZoneLockdown-delete-ZoneLockdown-rule
+func (api *API) DeleteZoneLockdown(zoneID string, id string) (*ZoneLockdownResponse, error) {
+	uri := "/zones/" + zoneID + "/firewall/lockdowns/" + id
+	res, err := api.makeRequest("DELETE", uri, nil)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+
+	response := &ZoneLockdownResponse{}
+	err = json.Unmarshal(res, &response)
+	if err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return response, nil
+}
+
+// ZoneLockdown retrieves a Zone Lockdown rule (based on the ID) for the
+// given zone ID.
+//
+// API reference: https://api.cloudflare.com/#zone-ZoneLockdown-ZoneLockdown-rule-details
+func (api *API) ZoneLockdown(zoneID string, id string) (*ZoneLockdownResponse, error) {
+	uri := "/zones/" + zoneID + "/firewall/lockdowns/" + id
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+
+	response := &ZoneLockdownResponse{}
+	err = json.Unmarshal(res, &response)
+	if err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return response, nil
+}
+
+// ListZoneLockdowns retrieves a list of Zone Lockdown rules for a given
+// zone ID by page number.
+//
+// API reference: https://api.cloudflare.com/#zone-ZoneLockdown-list-ZoneLockdown-rules
+func (api *API) ListZoneLockdowns(zoneID string, page int) (*ZoneLockdownListResponse, error) {
+	v := url.Values{}
+	if page <= 0 {
+		page = 1
+	}
+
+	v.Set("page", strconv.Itoa(page))
+	v.Set("per_page", strconv.Itoa(100))
+	query := "?" + v.Encode()
+
+	uri := "/zones/" + zoneID + "/firewall/lockdowns" + query
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+
+	response := &ZoneLockdownListResponse{}
+	err = json.Unmarshal(res, &response)
+	if err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return response, nil
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/logpush.go b/vendor/github.com/cloudflare/cloudflare-go/logpush.go
new file mode 100644
index 000000000..a0134aded
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/logpush.go
@@ -0,0 +1,224 @@
+package cloudflare
+
+import (
+	"encoding/json"
+	"strconv"
+	"time"
+
+	"github.com/pkg/errors"
+)
+
+// LogpushJob describes a Logpush job.
+type LogpushJob struct {
+	ID                 int        `json:"id,omitempty"`
+	Enabled            bool       `json:"enabled"`
+	Name               string     `json:"name"`
+	LogpullOptions     string     `json:"logpull_options"`
+	DestinationConf    string     `json:"destination_conf"`
+	OwnershipChallenge string     `json:"ownership_challenge,omitempty"`
+	LastComplete       *time.Time `json:"last_complete,omitempty"`
+	LastError          *time.Time `json:"last_error,omitempty"`
+	ErrorMessage       string     `json:"error_message,omitempty"`
+}
+
+// LogpushJobsResponse is the API response, containing an array of Logpush Jobs.
+type LogpushJobsResponse struct {
+	Response
+	Result []LogpushJob `json:"result"`
+}
+
+// LogpushJobDetailsResponse is the API response, containing a single Logpush Job.
+type LogpushJobDetailsResponse struct {
+	Response
+	Result LogpushJob `json:"result"`
+}
+
+// LogpushGetOwnershipChallenge describes an ownership challenge validation.
+type LogpushGetOwnershipChallenge struct {
+	Filename string `json:"filename"`
+	Valid    bool   `json:"valid"`
+	Message  string `json:"message"`
+}
+
+// LogpushGetOwnershipChallengeResponse is the API response, containing an ownership challenge.
+type LogpushGetOwnershipChallengeResponse struct {
+	Response
+	Result LogpushGetOwnershipChallenge `json:"result"`
+}
+
+// LogpushGetOwnershipChallengeRequest is the API request for getting an ownership challenge.
+type LogpushGetOwnershipChallengeRequest struct {
+	DestinationConf string `json:"destination_conf"`
+}
+
+// LogpushOwnershipChallangeValidationResponse is the API response,
+// containing an ownership challenge validation result.
+type LogpushOwnershipChallangeValidationResponse struct {
+	Response
+	Result struct {
+		Valid bool `json:"valid"`
+	}
+}
+
+// LogpushValidateOwnershipChallengeRequest is the API request for validating an ownership challenge.
+type LogpushValidateOwnershipChallengeRequest struct {
+	DestinationConf    string `json:"destination_conf"`
+	OwnershipChallenge string `json:"ownership_challenge"`
+}
+
+// LogpushDestinationExistsResponse is the API response,
+// containing a destination exists check result.
+type LogpushDestinationExistsResponse struct {
+	Response
+	Result struct {
+		Exists bool `json:"exists"`
+	}
+}
+
+// LogpushDestinationExistsRequest is the API request for checking whether a destination exists.
+type LogpushDestinationExistsRequest struct {
+	DestinationConf string `json:"destination_conf"`
+}
+
+// CreateLogpushJob creates a new LogpushJob for a zone.
+// +// API reference: https://api.cloudflare.com/#logpush-jobs-create-logpush-job +func (api *API) CreateLogpushJob(zoneID string, job LogpushJob) (*LogpushJob, error) { + uri := "/zones/" + zoneID + "/logpush/jobs" + res, err := api.makeRequest("POST", uri, job) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + var r LogpushJobDetailsResponse + err = json.Unmarshal(res, &r) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return &r.Result, nil +} + +// LogpushJobs returns all Logpush Jobs for a zone. +// +// API reference: https://api.cloudflare.com/#logpush-jobs-list-logpush-jobs +func (api *API) LogpushJobs(zoneID string) ([]LogpushJob, error) { + uri := "/zones/" + zoneID + "/logpush/jobs" + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []LogpushJob{}, errors.Wrap(err, errMakeRequestError) + } + var r LogpushJobsResponse + err = json.Unmarshal(res, &r) + if err != nil { + return []LogpushJob{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// LogpushJob fetches detail about one Logpush Job for a zone. +// +// API reference: https://api.cloudflare.com/#logpush-jobs-logpush-job-details +func (api *API) LogpushJob(zoneID string, jobID int) (LogpushJob, error) { + uri := "/zones/" + zoneID + "/logpush/jobs/" + strconv.Itoa(jobID) + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return LogpushJob{}, errors.Wrap(err, errMakeRequestError) + } + var r LogpushJobDetailsResponse + err = json.Unmarshal(res, &r) + if err != nil { + return LogpushJob{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// UpdateLogpushJob lets you update a Logpush Job. +// +// API reference: https://api.cloudflare.com/#logpush-jobs-update-logpush-job +func (api *API) UpdateLogpushJob(zoneID string, jobID int, job LogpushJob) error { + uri := "/zones/" + zoneID + "/logpush/jobs/" + strconv.Itoa(jobID) + res, err := api.makeRequest("PUT", uri, job) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + var r LogpushJobDetailsResponse + err = json.Unmarshal(res, &r) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + return nil +} + +// DeleteLogpushJob deletes a Logpush Job for a zone. +// +// API reference: https://api.cloudflare.com/#logpush-jobs-delete-logpush-job +func (api *API) DeleteLogpushJob(zoneID string, jobID int) error { + uri := "/zones/" + zoneID + "/logpush/jobs/" + strconv.Itoa(jobID) + res, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + var r LogpushJobDetailsResponse + err = json.Unmarshal(res, &r) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + return nil +} + +// GetLogpushOwnershipChallenge returns ownership challenge. +// +// API reference: https://api.cloudflare.com/#logpush-jobs-get-ownership-challenge +func (api *API) GetLogpushOwnershipChallenge(zoneID, destinationConf string) (*LogpushGetOwnershipChallenge, error) { + uri := "/zones/" + zoneID + "/logpush/ownership" + res, err := api.makeRequest("POST", uri, LogpushGetOwnershipChallengeRequest{ + DestinationConf: destinationConf, + }) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + var r LogpushGetOwnershipChallengeResponse + err = json.Unmarshal(res, &r) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return &r.Result, nil +} + +// ValidateLogpushOwnershipChallenge returns ownership challenge validation result. 
+// +// API reference: https://api.cloudflare.com/#logpush-jobs-validate-ownership-challenge +func (api *API) ValidateLogpushOwnershipChallenge(zoneID, destinationConf, ownershipChallenge string) (bool, error) { + uri := "/zones/" + zoneID + "/logpush/ownership/validate" + res, err := api.makeRequest("POST", uri, LogpushValidateOwnershipChallengeRequest{ + DestinationConf: destinationConf, + OwnershipChallenge: ownershipChallenge, + }) + if err != nil { + return false, errors.Wrap(err, errMakeRequestError) + } + var r LogpushGetOwnershipChallengeResponse + err = json.Unmarshal(res, &r) + if err != nil { + return false, errors.Wrap(err, errUnmarshalError) + } + return r.Result.Valid, nil +} + +// CheckLogpushDestinationExists returns destination exists check result. +// +// API reference: https://api.cloudflare.com/#logpush-jobs-check-destination-exists +func (api *API) CheckLogpushDestinationExists(zoneID, destinationConf string) (bool, error) { + uri := "/zones/" + zoneID + "/logpush/validate/destination/exists" + res, err := api.makeRequest("POST", uri, LogpushDestinationExistsRequest{ + DestinationConf: destinationConf, + }) + if err != nil { + return false, errors.Wrap(err, errMakeRequestError) + } + var r LogpushDestinationExistsResponse + err = json.Unmarshal(res, &r) + if err != nil { + return false, errors.Wrap(err, errUnmarshalError) + } + return r.Result.Exists, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/options.go b/vendor/github.com/cloudflare/cloudflare-go/options.go new file mode 100644 index 000000000..1bf4f60bd --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/options.go @@ -0,0 +1,101 @@ +package cloudflare + +import ( + "net/http" + + "time" + + "golang.org/x/time/rate" +) + +// Option is a functional option for configuring the API client. +type Option func(*API) error + +// HTTPClient accepts a custom *http.Client for making API calls. +func HTTPClient(client *http.Client) Option { + return func(api *API) error { + api.httpClient = client + return nil + } +} + +// Headers allows you to set custom HTTP headers when making API calls (e.g. for +// satisfying HTTP proxies, or for debugging). +func Headers(headers http.Header) Option { + return func(api *API) error { + api.headers = headers + return nil + } +} + +// UsingAccount allows you to apply account-level changes (Load Balancing, +// Railguns) to an account instead. 
+func UsingAccount(accountID string) Option {
+	return func(api *API) error {
+		api.AccountID = accountID
+		return nil
+	}
+}
+
+// UsingRateLimit applies a non-default rate limit to client API requests.
+// If not specified, the default of 4rps will be applied.
+func UsingRateLimit(rps float64) Option {
+	return func(api *API) error {
+		// Because the rate limiter doesn't do any windowing, setting a burst
+		// makes it difficult to enforce a fixed rate, so the burst is set to 1,
+		// which effectively disables bursting. Sensible values are not checked
+		// here; ultimately the API will enforce that the value is OK.
+		api.rateLimiter = rate.NewLimiter(rate.Limit(rps), 1)
+		return nil
+	}
+}
+
+// UsingRetryPolicy applies a non-default number of retries and min/max retry delays.
+// This will be used when the client exponentially backs off after errored requests.
+func UsingRetryPolicy(maxRetries int, minRetryDelaySecs int, maxRetryDelaySecs int) Option {
+	// Seconds are quite granular for a minimum delay, but this only applies in case of failure.
+	return func(api *API) error {
+		api.retryPolicy = RetryPolicy{
+			MaxRetries:    maxRetries,
+			MinRetryDelay: time.Duration(minRetryDelaySecs) * time.Second,
+			MaxRetryDelay: time.Duration(maxRetryDelaySecs) * time.Second,
+		}
+		return nil
+	}
+}
+
+// UsingLogger can be set if you want to get log output from this API instance.
+// By default no log output is emitted.
+func UsingLogger(logger Logger) Option {
+	return func(api *API) error {
+		api.logger = logger
+		return nil
+	}
+}
+
+// UserAgent can be set if you want to send a software name and version for HTTP access logs.
+// It is recommended to set it in order to help future Customer Support diagnostics
+// and prevent collateral damage by sharing a generic User-Agent string with abusive users.
+// E.g. "my-software/1.2.3". By default the generic Go User-Agent is used.
+func UserAgent(userAgent string) Option {
+	return func(api *API) error {
+		api.UserAgent = userAgent
+		return nil
+	}
+}
+
+// parseOptions parses the supplied options functions and returns a configured
+// *API instance.
+func (api *API) parseOptions(opts ...Option) error {
+	// Range over each options function and apply it to our API type to
+	// configure it. Options functions are applied in order, with any
+	// conflicting options overriding earlier calls.
+	for _, option := range opts {
+		err := option(api)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/origin_ca.go b/vendor/github.com/cloudflare/cloudflare-go/origin_ca.go
new file mode 100644
index 000000000..fdd8c4273
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/origin_ca.go
@@ -0,0 +1,169 @@
+package cloudflare
+
+import (
+	"context"
+	"encoding/json"
+	"net/url"
+	"time"
+
+	"github.com/pkg/errors"
+)
+
+// OriginCACertificate represents a Cloudflare-issued certificate.
+//
+// API reference: https://api.cloudflare.com/#cloudflare-ca
+type OriginCACertificate struct {
+	ID              string    `json:"id"`
+	Certificate     string    `json:"certificate"`
+	Hostnames       []string  `json:"hostnames"`
+	ExpiresOn       time.Time `json:"expires_on"`
+	RequestType     string    `json:"request_type"`
+	RequestValidity int       `json:"requested_validity"`
+	CSR             string    `json:"csr"`
+}
+
+// OriginCACertificateListOptions represents the parameters used to list Cloudflare-issued certificates.
+type OriginCACertificateListOptions struct {
+	ZoneID string
+}
+
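Stepping back to options.go above: `parseOptions` applies each `Option` in order, so a later option can override an earlier one. Here's a minimal sketch of configuring a client this way, assuming the package's `New` constructor (not shown in this diff) and made-up credentials:

```go
package main

import (
	"log"
	"net/http"
	"time"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	// Options are applied in order, so conflicting options override
	// earlier calls.
	api, err := cloudflare.New(
		"deadbeef-api-key", "user@example.com", // hypothetical credentials
		cloudflare.HTTPClient(&http.Client{Timeout: 30 * time.Second}),
		cloudflare.UsingRateLimit(2),          // 2 rps instead of the default 4
		cloudflare.UsingRetryPolicy(3, 1, 10), // up to 3 retries, 1-10s backoff
		cloudflare.UserAgent("my-software/1.2.3"),
	)
	if err != nil {
		log.Fatal(err)
	}
	_ = api // a real program would call API methods here
}
```

+// OriginCACertificateID represents the ID of the revoked certificate from the Revoke Certificate endpoint.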
+type OriginCACertificateID struct { + ID string `json:"id"` +} + +// originCACertificateResponse represents the response from the Create Certificate and the Certificate Details endpoints. +type originCACertificateResponse struct { + Response + Result OriginCACertificate `json:"result"` +} + +// originCACertificateResponseList represents the response from the List Certificates endpoint. +type originCACertificateResponseList struct { + Response + Result []OriginCACertificate `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// originCACertificateResponseRevoke represents the response from the Revoke Certificate endpoint. +type originCACertificateResponseRevoke struct { + Response + Result OriginCACertificateID `json:"result"` +} + +// CreateOriginCertificate creates a Cloudflare-signed certificate. +// +// This function requires api.APIUserServiceKey be set to your Certificates API key. +// +// API reference: https://api.cloudflare.com/#cloudflare-ca-create-certificate +func (api *API) CreateOriginCertificate(certificate OriginCACertificate) (*OriginCACertificate, error) { + uri := "/certificates" + res, err := api.makeRequestWithAuthType(context.TODO(), "POST", uri, certificate, AuthUserService) + + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + var originResponse *originCACertificateResponse + + err = json.Unmarshal(res, &originResponse) + + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + if !originResponse.Success { + return nil, errors.New(errRequestNotSuccessful) + } + + return &originResponse.Result, nil +} + +// OriginCertificates lists all Cloudflare-issued certificates. +// +// This function requires api.APIUserServiceKey be set to your Certificates API key. +// +// API reference: https://api.cloudflare.com/#cloudflare-ca-list-certificates +func (api *API) OriginCertificates(options OriginCACertificateListOptions) ([]OriginCACertificate, error) { + v := url.Values{} + if options.ZoneID != "" { + v.Set("zone_id", options.ZoneID) + } + uri := "/certificates" + "?" + v.Encode() + res, err := api.makeRequestWithAuthType(context.TODO(), "GET", uri, nil, AuthUserService) + + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + var originResponse *originCACertificateResponseList + + err = json.Unmarshal(res, &originResponse) + + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + if !originResponse.Success { + return nil, errors.New(errRequestNotSuccessful) + } + + return originResponse.Result, nil +} + +// OriginCertificate returns the details for a Cloudflare-issued certificate. +// +// This function requires api.APIUserServiceKey be set to your Certificates API key. +// +// API reference: https://api.cloudflare.com/#cloudflare-ca-certificate-details +func (api *API) OriginCertificate(certificateID string) (*OriginCACertificate, error) { + uri := "/certificates/" + certificateID + res, err := api.makeRequestWithAuthType(context.TODO(), "GET", uri, nil, AuthUserService) + + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + var originResponse *originCACertificateResponse + + err = json.Unmarshal(res, &originResponse) + + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + if !originResponse.Success { + return nil, errors.New(errRequestNotSuccessful) + } + + return &originResponse.Result, nil +} + +// RevokeOriginCertificate revokes a created certificate for a zone. 
+// +// This function requires api.APIUserServiceKey be set to your Certificates API key. +// +// API reference: https://api.cloudflare.com/#cloudflare-ca-revoke-certificate +func (api *API) RevokeOriginCertificate(certificateID string) (*OriginCACertificateID, error) { + uri := "/certificates/" + certificateID + res, err := api.makeRequestWithAuthType(context.TODO(), "DELETE", uri, nil, AuthUserService) + + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + var originResponse *originCACertificateResponseRevoke + + err = json.Unmarshal(res, &originResponse) + + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + if !originResponse.Success { + return nil, errors.New(errRequestNotSuccessful) + } + + return &originResponse.Result, nil + +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/page_rules.go b/vendor/github.com/cloudflare/cloudflare-go/page_rules.go new file mode 100644 index 000000000..36f62e62f --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/page_rules.go @@ -0,0 +1,235 @@ +package cloudflare + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +// PageRuleTarget is the target to evaluate on a request. +// +// Currently Target must always be "url" and Operator must be "matches". Value +// is the URL pattern to match against. +type PageRuleTarget struct { + Target string `json:"target"` + Constraint struct { + Operator string `json:"operator"` + Value string `json:"value"` + } `json:"constraint"` +} + +/* +PageRuleAction is the action to take when the target is matched. + +Valid IDs are: + always_online + always_use_https + automatic_https_rewrites + browser_cache_ttl + browser_check + bypass_cache_on_cookie + cache_by_device_type + cache_deception_armor + cache_level + cache_on_cookie + disable_apps + disable_performance + disable_railgun + disable_security + edge_cache_ttl + email_obfuscation + explicit_cache_control + forwarding_url + host_header_override + ip_geolocation + minify + mirage + opportunistic_encryption + origin_error_page_pass_thru + polish + resolve_override + respect_strong_etag + response_buffering + rocket_loader + security_level + server_side_exclude + sort_query_string_for_cache + ssl + true_client_ip_header + waf +*/ +type PageRuleAction struct { + ID string `json:"id"` + Value interface{} `json:"value"` +} + +// PageRuleActions maps API action IDs to human-readable strings. 
+var PageRuleActions = map[string]string{
+	"always_online":               "Always Online",               // Value of type string
+	"always_use_https":            "Always Use HTTPS",            // Value of type interface{}
+	"automatic_https_rewrites":    "Automatic HTTPS Rewrites",    // Value of type string
+	"browser_cache_ttl":           "Browser Cache TTL",           // Value of type int
+	"browser_check":               "Browser Integrity Check",     // Value of type string
+	"bypass_cache_on_cookie":      "Bypass Cache on Cookie",      // Value of type string
+	"cache_by_device_type":        "Cache By Device Type",        // Value of type string
+	"cache_deception_armor":       "Cache Deception Armor",       // Value of type string
+	"cache_level":                 "Cache Level",                 // Value of type string
+	"cache_on_cookie":             "Cache On Cookie",             // Value of type string
+	"disable_apps":                "Disable Apps",                // Value of type interface{}
+	"disable_performance":         "Disable Performance",         // Value of type interface{}
+	"disable_railgun":             "Disable Railgun",             // Value of type string
+	"disable_security":            "Disable Security",            // Value of type interface{}
+	"edge_cache_ttl":              "Edge Cache TTL",              // Value of type int
+	"email_obfuscation":           "Email Obfuscation",           // Value of type string
+	"explicit_cache_control":      "Origin Cache Control",        // Value of type string
+	"forwarding_url":              "Forwarding URL",              // Value of type map[string]interface
+	"host_header_override":        "Host Header Override",        // Value of type string
+	"ip_geolocation":              "IP Geolocation Header",       // Value of type string
+	"minify":                      "Minify",                      // Value of type map[string]interface
+	"mirage":                      "Mirage",                      // Value of type string
+	"opportunistic_encryption":    "Opportunistic Encryption",    // Value of type string
+	"origin_error_page_pass_thru": "Origin Error Page Pass-thru", // Value of type string
+	"polish":                      "Polish",                      // Value of type string
+	"resolve_override":            "Resolve Override",            // Value of type string
+	"respect_strong_etag":         "Respect Strong ETags",        // Value of type string
+	"response_buffering":          "Response Buffering",          // Value of type string
+	"rocket_loader":               "Rocket Loader",               // Value of type string
+	"security_level":              "Security Level",              // Value of type string
+	"server_side_exclude":         "Server Side Excludes",        // Value of type string
+	"sort_query_string_for_cache": "Query String Sort",           // Value of type string
+	"ssl":                         "SSL",                         // Value of type string
+	"true_client_ip_header":       "True Client IP Header",       // Value of type string
+	"waf":                         "Web Application Firewall",    // Value of type string
+}
+
+// PageRule describes a Page Rule.
+type PageRule struct {
+	ID         string           `json:"id,omitempty"`
+	Targets    []PageRuleTarget `json:"targets"`
+	Actions    []PageRuleAction `json:"actions"`
+	Priority   int              `json:"priority"`
+	Status     string           `json:"status"` // can be: active, paused
+	ModifiedOn time.Time        `json:"modified_on,omitempty"`
+	CreatedOn  time.Time        `json:"created_on,omitempty"`
+}
+
+// PageRuleDetailResponse is the API response, containing a single PageRule.
+type PageRuleDetailResponse struct {
+	Success  bool     `json:"success"`
+	Errors   []string `json:"errors"`
+	Messages []string `json:"messages"`
+	Result   PageRule `json:"result"`
+}
+
+// PageRulesResponse is the API response, containing an array of PageRules.
+type PageRulesResponse struct {
+	Success  bool       `json:"success"`
+	Errors   []string   `json:"errors"`
+	Messages []string   `json:"messages"`
+	Result   []PageRule `json:"result"`
+}
+
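Putting the types above together, here's a hedged sketch that creates a redirect rule with the `forwarding_url` action, assuming the package's `New` constructor and a made-up zone ID. Note that `Constraint` is an anonymous struct, so it has to be filled in after the `PageRuleTarget` literal:

```go
package main

import (
	"fmt"
	"log"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	api, err := cloudflare.New("deadbeef-api-key", "user@example.com") // hypothetical credentials
	if err != nil {
		log.Fatal(err)
	}
	zoneID := "023e105f4ecef8ad9ca31a8372d0c353"

	// The constraint is an anonymous struct, so set its fields after the literal.
	target := cloudflare.PageRuleTarget{Target: "url"}
	target.Constraint.Operator = "matches"
	target.Constraint.Value = "example.com/old-path/*"

	rule, err := api.CreatePageRule(zoneID, cloudflare.PageRule{
		Targets: []cloudflare.PageRuleTarget{target},
		Actions: []cloudflare.PageRuleAction{{
			ID: "forwarding_url", // Value of type map[string]interface per the table above
			Value: map[string]interface{}{
				"url":         "https://example.com/new-path/$1",
				"status_code": 301,
			},
		}},
		Priority: 1,
		Status:   "active",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("created page rule", rule.ID)
}
```

+// CreatePageRule creates a new Page Rule for a zone.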
+// +// API reference: https://api.cloudflare.com/#page-rules-for-a-zone-create-a-page-rule +func (api *API) CreatePageRule(zoneID string, rule PageRule) (*PageRule, error) { + uri := "/zones/" + zoneID + "/pagerules" + res, err := api.makeRequest("POST", uri, rule) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + var r PageRuleDetailResponse + err = json.Unmarshal(res, &r) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return &r.Result, nil +} + +// ListPageRules returns all Page Rules for a zone. +// +// API reference: https://api.cloudflare.com/#page-rules-for-a-zone-list-page-rules +func (api *API) ListPageRules(zoneID string) ([]PageRule, error) { + uri := "/zones/" + zoneID + "/pagerules" + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []PageRule{}, errors.Wrap(err, errMakeRequestError) + } + var r PageRulesResponse + err = json.Unmarshal(res, &r) + if err != nil { + return []PageRule{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// PageRule fetches detail about one Page Rule for a zone. +// +// API reference: https://api.cloudflare.com/#page-rules-for-a-zone-page-rule-details +func (api *API) PageRule(zoneID, ruleID string) (PageRule, error) { + uri := "/zones/" + zoneID + "/pagerules/" + ruleID + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return PageRule{}, errors.Wrap(err, errMakeRequestError) + } + var r PageRuleDetailResponse + err = json.Unmarshal(res, &r) + if err != nil { + return PageRule{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// ChangePageRule lets you change individual settings for a Page Rule. This is +// in contrast to UpdatePageRule which replaces the entire Page Rule. +// +// API reference: https://api.cloudflare.com/#page-rules-for-a-zone-change-a-page-rule +func (api *API) ChangePageRule(zoneID, ruleID string, rule PageRule) error { + uri := "/zones/" + zoneID + "/pagerules/" + ruleID + res, err := api.makeRequest("PATCH", uri, rule) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + var r PageRuleDetailResponse + err = json.Unmarshal(res, &r) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + return nil +} + +// UpdatePageRule lets you replace a Page Rule. This is in contrast to +// ChangePageRule which lets you change individual settings. +// +// API reference: https://api.cloudflare.com/#page-rules-for-a-zone-update-a-page-rule +func (api *API) UpdatePageRule(zoneID, ruleID string, rule PageRule) error { + uri := "/zones/" + zoneID + "/pagerules/" + ruleID + res, err := api.makeRequest("PUT", uri, rule) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + var r PageRuleDetailResponse + err = json.Unmarshal(res, &r) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + return nil +} + +// DeletePageRule deletes a Page Rule for a zone. 
+// +// API reference: https://api.cloudflare.com/#page-rules-for-a-zone-delete-a-page-rule +func (api *API) DeletePageRule(zoneID, ruleID string) error { + uri := "/zones/" + zoneID + "/pagerules/" + ruleID + res, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + var r PageRuleDetailResponse + err = json.Unmarshal(res, &r) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + return nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/railgun.go b/vendor/github.com/cloudflare/cloudflare-go/railgun.go new file mode 100644 index 000000000..72d228691 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/railgun.go @@ -0,0 +1,297 @@ +package cloudflare + +import ( + "encoding/json" + "net/url" + "time" + + "github.com/pkg/errors" +) + +// Railgun represents a Railgun's properties. +type Railgun struct { + ID string `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + Enabled bool `json:"enabled"` + ZonesConnected int `json:"zones_connected"` + Build string `json:"build"` + Version string `json:"version"` + Revision string `json:"revision"` + ActivationKey string `json:"activation_key"` + ActivatedOn time.Time `json:"activated_on"` + CreatedOn time.Time `json:"created_on"` + ModifiedOn time.Time `json:"modified_on"` + UpgradeInfo struct { + LatestVersion string `json:"latest_version"` + DownloadLink string `json:"download_link"` + } `json:"upgrade_info"` +} + +// RailgunListOptions represents the parameters used to list railguns. +type RailgunListOptions struct { + Direction string +} + +// railgunResponse represents the response from the Create Railgun and the Railgun Details endpoints. +type railgunResponse struct { + Response + Result Railgun `json:"result"` +} + +// railgunsResponse represents the response from the List Railguns endpoint. +type railgunsResponse struct { + Response + Result []Railgun `json:"result"` +} + +// CreateRailgun creates a new Railgun. +// +// API reference: https://api.cloudflare.com/#railgun-create-railgun +func (api *API) CreateRailgun(name string) (Railgun, error) { + uri := api.userBaseURL("") + "/railguns" + params := struct { + Name string `json:"name"` + }{ + Name: name, + } + res, err := api.makeRequest("POST", uri, params) + if err != nil { + return Railgun{}, errors.Wrap(err, errMakeRequestError) + } + var r railgunResponse + if err := json.Unmarshal(res, &r); err != nil { + return Railgun{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// ListRailguns lists Railguns connected to an account. +// +// API reference: https://api.cloudflare.com/#railgun-list-railguns +func (api *API) ListRailguns(options RailgunListOptions) ([]Railgun, error) { + v := url.Values{} + if options.Direction != "" { + v.Set("direction", options.Direction) + } + uri := api.userBaseURL("") + "/railguns" + "?" + v.Encode() + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + var r railgunsResponse + if err := json.Unmarshal(res, &r); err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// RailgunDetails returns the details for a Railgun. 
+//
+// API reference: https://api.cloudflare.com/#railgun-railgun-details
+func (api *API) RailgunDetails(railgunID string) (Railgun, error) {
+	uri := api.userBaseURL("") + "/railguns/" + railgunID
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return Railgun{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r railgunResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return Railgun{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// RailgunZones returns the zones that are currently using a Railgun.
+//
+// API reference: https://api.cloudflare.com/#railgun-get-zones-connected-to-a-railgun
+func (api *API) RailgunZones(railgunID string) ([]Zone, error) {
+	uri := api.userBaseURL("") + "/railguns/" + railgunID + "/zones"
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+	var r ZonesResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// enableRailgun enables (true) or disables (false) a Railgun for all zones connected to it.
+//
+// API reference: https://api.cloudflare.com/#railgun-enable-or-disable-a-railgun
+func (api *API) enableRailgun(railgunID string, enable bool) (Railgun, error) {
+	uri := api.userBaseURL("") + "/railguns/" + railgunID
+	params := struct {
+		Enabled bool `json:"enabled"`
+	}{
+		Enabled: enable,
+	}
+	res, err := api.makeRequest("PATCH", uri, params)
+	if err != nil {
+		return Railgun{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r railgunResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return Railgun{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// EnableRailgun enables a Railgun for all zones connected to it.
+//
+// API reference: https://api.cloudflare.com/#railgun-enable-or-disable-a-railgun
+func (api *API) EnableRailgun(railgunID string) (Railgun, error) {
+	return api.enableRailgun(railgunID, true)
+}
+
+// DisableRailgun disables a Railgun for all zones connected to it.
+//
+// API reference: https://api.cloudflare.com/#railgun-enable-or-disable-a-railgun
+func (api *API) DisableRailgun(railgunID string) (Railgun, error) {
+	return api.enableRailgun(railgunID, false)
+}
+
+// DeleteRailgun disables and deletes a Railgun.
+//
+// API reference: https://api.cloudflare.com/#railgun-delete-railgun
+func (api *API) DeleteRailgun(railgunID string) error {
+	uri := api.userBaseURL("") + "/railguns/" + railgunID
+	if _, err := api.makeRequest("DELETE", uri, nil); err != nil {
+		return errors.Wrap(err, errMakeRequestError)
+	}
+	return nil
+}
+
+// ZoneRailgun represents the status of a Railgun on a zone.
+type ZoneRailgun struct {
+	ID        string `json:"id"`
+	Name      string `json:"name"`
+	Enabled   bool   `json:"enabled"`
+	Connected bool   `json:"connected"`
+}
+
+// zoneRailgunResponse represents the response from the Zone Railgun Details endpoint.
+type zoneRailgunResponse struct {
+	Response
+	Result ZoneRailgun `json:"result"`
+}
+
+// zoneRailgunsResponse represents the response from the Zone Railgun endpoint.
+type zoneRailgunsResponse struct {
+	Response
+	Result []ZoneRailgun `json:"result"`
+}
+
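A short sketch of the Railgun lifecycle using the account-level functions above (create, enable, inspect), assuming the package's `New` constructor and made-up credentials:

```go
package main

import (
	"fmt"
	"log"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	api, err := cloudflare.New("deadbeef-api-key", "user@example.com") // hypothetical credentials
	if err != nil {
		log.Fatal(err)
	}

	// Create a Railgun, enable it for every zone connected to it, then
	// inspect its activation details.
	rg, err := api.CreateRailgun("datacenter-1")
	if err != nil {
		log.Fatal(err)
	}
	if _, err := api.EnableRailgun(rg.ID); err != nil {
		log.Fatal(err)
	}
	details, err := api.RailgunDetails(rg.ID)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(details.Name, details.ActivationKey)
}
```

+// RailgunDiagnosis represents the test results from testing railgun connections
+// to a zone.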
+type RailgunDiagnosis struct {
+	Method          string `json:"method"`
+	HostName        string `json:"host_name"`
+	HTTPStatus      int    `json:"http_status"`
+	Railgun         string `json:"railgun"`
+	URL             string `json:"url"`
+	ResponseStatus  string `json:"response_status"`
+	Protocol        string `json:"protocol"`
+	ElapsedTime     string `json:"elapsed_time"`
+	BodySize        string `json:"body_size"`
+	BodyHash        string `json:"body_hash"`
+	MissingHeaders  string `json:"missing_headers"`
+	ConnectionClose bool   `json:"connection_close"`
+	Cloudflare      string `json:"cloudflare"`
+	CFRay           string `json:"cf-ray"`
+	// NOTE: Cloudflare's online API documentation does not yet have definitions
+	// for the following fields. See: https://api.cloudflare.com/#railgun-connections-for-a-zone-test-railgun-connection/
+	CFWANError    string `json:"cf-wan-error"`
+	CFCacheStatus string `json:"cf-cache-status"`
+}
+
+// railgunDiagnosisResponse represents the response from the Test Railgun Connection endpoint.
+type railgunDiagnosisResponse struct {
+	Response
+	Result RailgunDiagnosis `json:"result"`
+}
+
+// ZoneRailguns returns the available Railguns for a zone.
+//
+// API reference: https://api.cloudflare.com/#railguns-for-a-zone-get-available-railguns
+func (api *API) ZoneRailguns(zoneID string) ([]ZoneRailgun, error) {
+	uri := "/zones/" + zoneID + "/railguns"
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+	var r zoneRailgunsResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// ZoneRailgunDetails returns the configuration for a given Railgun.
+//
+// API reference: https://api.cloudflare.com/#railguns-for-a-zone-get-railgun-details
+func (api *API) ZoneRailgunDetails(zoneID, railgunID string) (ZoneRailgun, error) {
+	uri := "/zones/" + zoneID + "/railguns/" + railgunID
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return ZoneRailgun{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r zoneRailgunResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return ZoneRailgun{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// TestRailgunConnection tests a Railgun connection for a given zone.
+//
+// API reference: https://api.cloudflare.com/#railgun-connections-for-a-zone-test-railgun-connection
+func (api *API) TestRailgunConnection(zoneID, railgunID string) (RailgunDiagnosis, error) {
+	uri := "/zones/" + zoneID + "/railguns/" + railgunID + "/diagnose"
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return RailgunDiagnosis{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r railgunDiagnosisResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return RailgunDiagnosis{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
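A sketch tying the zone-level helpers together: list the Railguns available to a zone and run a connection diagnosis on each. Credentials and the zone ID are made up; `New` is assumed from elsewhere in the library:

```go
package main

import (
	"fmt"
	"log"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	api, err := cloudflare.New("deadbeef-api-key", "user@example.com") // hypothetical credentials
	if err != nil {
		log.Fatal(err)
	}
	zoneID := "023e105f4ecef8ad9ca31a8372d0c353"

	// Test every Railgun available to the zone and report its status.
	railguns, err := api.ZoneRailguns(zoneID)
	if err != nil {
		log.Fatal(err)
	}
	for _, rg := range railguns {
		diag, err := api.TestRailgunConnection(zoneID, rg.ID)
		if err != nil {
			log.Printf("railgun %s: %v", rg.ID, err)
			continue
		}
		fmt.Printf("%s: HTTP %d in %s\n", rg.Name, diag.HTTPStatus, diag.ElapsedTime)
	}
}
```

+// connectZoneRailgun connects (true) or disconnects (false) a Railgun for a given zone.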
+//
+// API reference: https://api.cloudflare.com/#railguns-for-a-zone-connect-or-disconnect-a-railgun
+func (api *API) connectZoneRailgun(zoneID, railgunID string, connect bool) (ZoneRailgun, error) {
+	uri := "/zones/" + zoneID + "/railguns/" + railgunID
+	params := struct {
+		Connected bool `json:"connected"`
+	}{
+		Connected: connect,
+	}
+	res, err := api.makeRequest("PATCH", uri, params)
+	if err != nil {
+		return ZoneRailgun{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r zoneRailgunResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return ZoneRailgun{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// ConnectZoneRailgun connects a Railgun for a given zone.
+//
+// API reference: https://api.cloudflare.com/#railguns-for-a-zone-connect-or-disconnect-a-railgun
+func (api *API) ConnectZoneRailgun(zoneID, railgunID string) (ZoneRailgun, error) {
+	return api.connectZoneRailgun(zoneID, railgunID, true)
+}
+
+// DisconnectZoneRailgun disconnects a Railgun for a given zone.
+//
+// API reference: https://api.cloudflare.com/#railguns-for-a-zone-connect-or-disconnect-a-railgun
+func (api *API) DisconnectZoneRailgun(zoneID, railgunID string) (ZoneRailgun, error) {
+	return api.connectZoneRailgun(zoneID, railgunID, false)
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/rate_limiting.go b/vendor/github.com/cloudflare/cloudflare-go/rate_limiting.go
new file mode 100644
index 000000000..e3eb3e2e7
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/rate_limiting.go
@@ -0,0 +1,210 @@
+package cloudflare
+
+import (
+	"encoding/json"
+	"net/url"
+	"strconv"
+
+	"github.com/pkg/errors"
+)
+
+// RateLimit is a policy that can be applied to limit traffic within a customer domain.
+type RateLimit struct {
+	ID          string                  `json:"id,omitempty"`
+	Disabled    bool                    `json:"disabled,omitempty"`
+	Description string                  `json:"description,omitempty"`
+	Match       RateLimitTrafficMatcher `json:"match"`
+	Bypass      []RateLimitKeyValue     `json:"bypass,omitempty"`
+	Threshold   int                     `json:"threshold"`
+	Period      int                     `json:"period"`
+	Action      RateLimitAction         `json:"action"`
+	Correlate   *RateLimitCorrelate     `json:"correlate,omitempty"`
+}
+
+// RateLimitTrafficMatcher contains the rules that will be used to apply a rate limit to traffic.
+type RateLimitTrafficMatcher struct {
+	Request  RateLimitRequestMatcher  `json:"request"`
+	Response RateLimitResponseMatcher `json:"response"`
+}
+
+// RateLimitRequestMatcher contains the matching rules pertaining to requests.
+type RateLimitRequestMatcher struct {
+	Methods    []string `json:"methods,omitempty"`
+	Schemes    []string `json:"schemes,omitempty"`
+	URLPattern string   `json:"url,omitempty"`
+}
+
+// RateLimitResponseMatcher contains the matching rules pertaining to responses.
+type RateLimitResponseMatcher struct {
+	Statuses      []int                            `json:"status,omitempty"`
+	OriginTraffic *bool                            `json:"origin_traffic,omitempty"` // API defaults to true, so we need an explicit empty value
+	Headers       []RateLimitResponseMatcherHeader `json:"headers,omitempty"`
+}
+
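A hedged sketch of creating a rate limit with the matcher types above and the action and threshold fields defined just below; the credentials, zone ID, and field values (e.g. the `_ALL_` scheme) are illustrative only:

```go
package main

import (
	"fmt"
	"log"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	api, err := cloudflare.New("deadbeef-api-key", "user@example.com") // hypothetical credentials
	if err != nil {
		log.Fatal(err)
	}
	zoneID := "023e105f4ecef8ad9ca31a8372d0c353"

	// Allow 100 POSTs to the login endpoint per minute, then block for 5 minutes.
	limit, err := api.CreateRateLimit(zoneID, cloudflare.RateLimit{
		Description: "login endpoint",
		Match: cloudflare.RateLimitTrafficMatcher{
			Request: cloudflare.RateLimitRequestMatcher{
				Methods:    []string{"POST"},
				Schemes:    []string{"_ALL_"}, // illustrative value
				URLPattern: "*.example.com/login",
			},
		},
		Threshold: 100,
		Period:    60,
		Action: cloudflare.RateLimitAction{
			Mode:    "ban",
			Timeout: 300,
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("created rate limit", limit.ID)
}
```

+// RateLimitResponseMatcherHeader contains the structure of the origin
+// HTTP headers used in response matcher checks.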
+type RateLimitResponseMatcherHeader struct {
+	Name  string `json:"name"`
+	Op    string `json:"op"`
+	Value string `json:"value"`
+}
+
+// RateLimitKeyValue is a key-value pair formatted as expected in the rate limit description.
+type RateLimitKeyValue struct {
+	Name  string `json:"name"`
+	Value string `json:"value"`
+}
+
+// RateLimitAction is the action that will be taken when the rate limit threshold is reached.
+type RateLimitAction struct {
+	Mode     string                   `json:"mode"`
+	Timeout  int                      `json:"timeout"`
+	Response *RateLimitActionResponse `json:"response"`
+}
+
+// RateLimitActionResponse is the response that will be returned when the rate limit action is triggered.
+type RateLimitActionResponse struct {
+	ContentType string `json:"content_type"`
+	Body        string `json:"body"`
+}
+
+// RateLimitCorrelate pertains to NAT support.
+type RateLimitCorrelate struct {
+	By string `json:"by"`
+}
+
+type rateLimitResponse struct {
+	Response
+	Result RateLimit `json:"result"`
+}
+
+type rateLimitListResponse struct {
+	Response
+	Result     []RateLimit `json:"result"`
+	ResultInfo ResultInfo  `json:"result_info"`
+}
+
+// CreateRateLimit creates a new rate limit for a zone.
+//
+// API reference: https://api.cloudflare.com/#rate-limits-for-a-zone-create-a-ratelimit
+func (api *API) CreateRateLimit(zoneID string, limit RateLimit) (RateLimit, error) {
+	uri := "/zones/" + zoneID + "/rate_limits"
+	res, err := api.makeRequest("POST", uri, limit)
+	if err != nil {
+		return RateLimit{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r rateLimitResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return RateLimit{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// ListRateLimits returns Rate Limits for a zone, paginated according to the provided options.
+//
+// API reference: https://api.cloudflare.com/#rate-limits-for-a-zone-list-rate-limits
+func (api *API) ListRateLimits(zoneID string, pageOpts PaginationOptions) ([]RateLimit, ResultInfo, error) {
+	v := url.Values{}
+	if pageOpts.PerPage > 0 {
+		v.Set("per_page", strconv.Itoa(pageOpts.PerPage))
+	}
+	if pageOpts.Page > 0 {
+		v.Set("page", strconv.Itoa(pageOpts.Page))
+	}
+
+	uri := "/zones/" + zoneID + "/rate_limits"
+	if len(v) > 0 {
+		uri = uri + "?" + v.Encode()
+	}
+
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return []RateLimit{}, ResultInfo{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var r rateLimitListResponse
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return []RateLimit{}, ResultInfo{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, r.ResultInfo, nil
+}
+
+// ListAllRateLimits returns all Rate Limits for a zone.
+//
+// API reference: https://api.cloudflare.com/#rate-limits-for-a-zone-list-rate-limits
+func (api *API) ListAllRateLimits(zoneID string) ([]RateLimit, error) {
+	pageOpts := PaginationOptions{
+		PerPage: 100, // this is the max page size allowed
+		Page:    1,
+	}
+
+	allRateLimits := make([]RateLimit, 0)
+	for {
+		rateLimits, resultInfo, err := api.ListRateLimits(zoneID, pageOpts)
+		if err != nil {
+			return []RateLimit{}, err
+		}
+		allRateLimits = append(allRateLimits, rateLimits...)
+ // total pages is not returned on this call + // if number of records is less than the max, this must be the last page + // in case TotalCount % PerPage = 0, the last request will return an empty list + if resultInfo.Count < resultInfo.PerPage { + break + } + // continue with the next page + pageOpts.Page = pageOpts.Page + 1 + } + + return allRateLimits, nil +} + +// RateLimit fetches detail about one Rate Limit for a zone. +// +// API reference: https://api.cloudflare.com/#rate-limits-for-a-zone-rate-limit-details +func (api *API) RateLimit(zoneID, limitID string) (RateLimit, error) { + uri := "/zones/" + zoneID + "/rate_limits/" + limitID + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return RateLimit{}, errors.Wrap(err, errMakeRequestError) + } + var r rateLimitResponse + err = json.Unmarshal(res, &r) + if err != nil { + return RateLimit{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// UpdateRateLimit lets you replace a Rate Limit for a zone. +// +// API reference: https://api.cloudflare.com/#rate-limits-for-a-zone-update-rate-limit +func (api *API) UpdateRateLimit(zoneID, limitID string, limit RateLimit) (RateLimit, error) { + uri := "/zones/" + zoneID + "/rate_limits/" + limitID + res, err := api.makeRequest("PUT", uri, limit) + if err != nil { + return RateLimit{}, errors.Wrap(err, errMakeRequestError) + } + var r rateLimitResponse + if err := json.Unmarshal(res, &r); err != nil { + return RateLimit{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// DeleteRateLimit deletes a Rate Limit for a zone. +// +// API reference: https://api.cloudflare.com/#rate-limits-for-a-zone-delete-rate-limit +func (api *API) DeleteRateLimit(zoneID, limitID string) error { + uri := "/zones/" + zoneID + "/rate_limits/" + limitID + res, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + var r rateLimitResponse + err = json.Unmarshal(res, &r) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + return nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/registrar.go b/vendor/github.com/cloudflare/cloudflare-go/registrar.go new file mode 100644 index 000000000..51eacf173 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/registrar.go @@ -0,0 +1,175 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/pkg/errors" +) + +// RegistrarDomain is the structure of the API response for a new +// Cloudflare Registrar domain. +type RegistrarDomain struct { + ID string `json:"id"` + Available bool `json:"available"` + SupportedTLD bool `json:"supported_tld"` + CanRegister bool `json:"can_register"` + TransferIn RegistrarTransferIn `json:"transfer_in"` + CurrentRegistrar string `json:"current_registrar"` + ExpiresAt time.Time `json:"expires_at"` + RegistryStatuses string `json:"registry_statuses"` + Locked bool `json:"locked"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + RegistrantContact RegistrantContact `json:"registrant_contact"` +} + +// RegistrarTransferIn contains the structure for a domain transfer in +// request. 
+type RegistrarTransferIn struct {
+	UnlockDomain      string `json:"unlock_domain"`
+	DisablePrivacy    string `json:"disable_privacy"`
+	EnterAuthCode     string `json:"enter_auth_code"`
+	ApproveTransfer   string `json:"approve_transfer"`
+	AcceptFoa         string `json:"accept_foa"`
+	CanCancelTransfer bool   `json:"can_cancel_transfer"`
+}
+
+// RegistrantContact is the contact details for the domain registration.
+type RegistrantContact struct {
+	ID           string `json:"id"`
+	FirstName    string `json:"first_name"`
+	LastName     string `json:"last_name"`
+	Organization string `json:"organization"`
+	Address      string `json:"address"`
+	Address2     string `json:"address2"`
+	City         string `json:"city"`
+	State        string `json:"state"`
+	Zip          string `json:"zip"`
+	Country      string `json:"country"`
+	Phone        string `json:"phone"`
+	Email        string `json:"email"`
+	Fax          string `json:"fax"`
+}
+
+// RegistrarDomainConfiguration is the structure for making updates to
+// an existing domain.
+type RegistrarDomainConfiguration struct {
+	NameServers []string `json:"name_servers"`
+	Privacy     bool     `json:"privacy"`
+	Locked      bool     `json:"locked"`
+	AutoRenew   bool     `json:"auto_renew"`
+}
+
+// RegistrarDomainDetailResponse is the structure of the detailed
+// response from the API for a single domain.
+type RegistrarDomainDetailResponse struct {
+	Response
+	Result RegistrarDomain `json:"result"`
+}
+
+// RegistrarDomainsDetailResponse is the structure of the detailed
+// response from the API.
+type RegistrarDomainsDetailResponse struct {
+	Response
+	Result []RegistrarDomain `json:"result"`
+}
+
+// RegistrarDomain returns a single domain based on the account ID and
+// domain name.
+//
+// API reference: https://api.cloudflare.com/#registrar-domains-get-domain
+func (api *API) RegistrarDomain(accountID, domainName string) (RegistrarDomain, error) {
+	uri := fmt.Sprintf("/accounts/%s/registrar/domains/%s", accountID, domainName)
+
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return RegistrarDomain{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var r RegistrarDomainDetailResponse
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return RegistrarDomain{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// RegistrarDomains returns all registrar domains based on the account
+// ID.
+//
+// API reference: https://api.cloudflare.com/#registrar-domains-list-domains
+func (api *API) RegistrarDomains(accountID string) ([]RegistrarDomain, error) {
+	uri := "/accounts/" + accountID + "/registrar/domains"
+
+	res, err := api.makeRequest("POST", uri, nil)
+	if err != nil {
+		return []RegistrarDomain{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var r RegistrarDomainsDetailResponse
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return []RegistrarDomain{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
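A minimal sketch of updating a registrar domain with the configuration struct above, assuming the package's `New` constructor and a made-up account ID:

```go
package main

import (
	"fmt"
	"log"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	api, err := cloudflare.New("deadbeef-api-key", "user@example.com") // hypothetical credentials
	if err != nil {
		log.Fatal(err)
	}
	accountID := "01a7362d577a6c3019a474fd6f485823"

	// Lock the domain and enable auto-renewal, keeping WHOIS privacy on.
	domain, err := api.UpdateRegistrarDomain(accountID, "example.com",
		cloudflare.RegistrarDomainConfiguration{
			NameServers: []string{"ns1.example.net", "ns2.example.net"},
			Privacy:     true,
			Locked:      true,
			AutoRenew:   true,
		})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("expires at", domain.ExpiresAt)
}
```

+// TransferRegistrarDomain initiates the transfer from another registrar
+// to Cloudflare Registrar.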
+// +// API reference: https://api.cloudflare.com/#registrar-domains-transfer-domain +func (api *API) TransferRegistrarDomain(accountID, domainName string) ([]RegistrarDomain, error) { + uri := fmt.Sprintf("/accounts/%s/registrar/domains/%s/transfer", accountID, domainName) + + res, err := api.makeRequest("POST", uri, nil) + if err != nil { + return []RegistrarDomain{}, errors.Wrap(err, errMakeRequestError) + } + + var r RegistrarDomainsDetailResponse + err = json.Unmarshal(res, &r) + if err != nil { + return []RegistrarDomain{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// CancelRegistrarDomainTransfer cancels a pending domain transfer. +// +// API reference: https://api.cloudflare.com/#registrar-domains-cancel-transfer +func (api *API) CancelRegistrarDomainTransfer(accountID, domainName string) ([]RegistrarDomain, error) { + uri := fmt.Sprintf("/accounts/%s/registrar/domains/%s/cancel_transfer", accountID, domainName) + + res, err := api.makeRequest("POST", uri, nil) + if err != nil { + return []RegistrarDomain{}, errors.Wrap(err, errMakeRequestError) + } + + var r RegistrarDomainsDetailResponse + err = json.Unmarshal(res, &r) + if err != nil { + return []RegistrarDomain{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// UpdateRegistrarDomain updates an existing Registrar Domain configuration. +// +// API reference: https://api.cloudflare.com/#registrar-domains-update-domain +func (api *API) UpdateRegistrarDomain(accountID, domainName string, domainConfiguration RegistrarDomainConfiguration) (RegistrarDomain, error) { + uri := fmt.Sprintf("/accounts/%s/registrar/domains/%s", accountID, domainName) + + res, err := api.makeRequest("PUT", uri, domainConfiguration) + if err != nil { + return RegistrarDomain{}, errors.Wrap(err, errMakeRequestError) + } + + var r RegistrarDomainDetailResponse + err = json.Unmarshal(res, &r) + if err != nil { + return RegistrarDomain{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/renovate.json b/vendor/github.com/cloudflare/cloudflare-go/renovate.json new file mode 100644 index 000000000..f45d8f110 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/renovate.json @@ -0,0 +1,5 @@ +{ + "extends": [ + "config:base" + ] +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/spectrum.go b/vendor/github.com/cloudflare/cloudflare-go/spectrum.go new file mode 100644 index 000000000..a95a2cd7f --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/spectrum.go @@ -0,0 +1,158 @@ +package cloudflare + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/pkg/errors" +) + +// SpectrumApplication defines a single Spectrum Application. +type SpectrumApplication struct { + ID string `json:"id,omitempty"` + Protocol string `json:"protocol,omitempty"` + IPv4 bool `json:"ipv4,omitempty"` + DNS SpectrumApplicationDNS `json:"dns,omitempty"` + OriginDirect []string `json:"origin_direct,omitempty"` + OriginPort int `json:"origin_port,omitempty"` + OriginDNS *SpectrumApplicationOriginDNS `json:"origin_dns,omitempty"` + IPFirewall bool `json:"ip_firewall,omitempty"` + ProxyProtocol bool `json:"proxy_protocol,omitempty"` + TLS string `json:"tls,omitempty"` + CreatedOn *time.Time `json:"created_on,omitempty"` + ModifiedOn *time.Time `json:"modified_on,omitempty"` +} + +// SpectrumApplicationDNS holds the external DNS configuration for a Spectrum +// Application. 
+type SpectrumApplicationDNS struct { + Type string `json:"type"` + Name string `json:"name"` +} + +// SpectrumApplicationOriginDNS holds the origin DNS configuration for a Spectrum +// Application. +type SpectrumApplicationOriginDNS struct { + Name string `json:"name"` +} + +// SpectrumApplicationDetailResponse is the structure of the detailed response +// from the API. +type SpectrumApplicationDetailResponse struct { + Response + Result SpectrumApplication `json:"result"` +} + +// SpectrumApplicationsDetailResponse is the structure of the detailed response +// from the API. +type SpectrumApplicationsDetailResponse struct { + Response + Result []SpectrumApplication `json:"result"` +} + +// SpectrumApplications fetches all of the Spectrum applications for a zone. +// +// API reference: https://developers.cloudflare.com/spectrum/api-reference/#list-spectrum-applications +func (api *API) SpectrumApplications(zoneID string) ([]SpectrumApplication, error) { + uri := "/zones/" + zoneID + "/spectrum/apps" + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []SpectrumApplication{}, errors.Wrap(err, errMakeRequestError) + } + + var spectrumApplications SpectrumApplicationsDetailResponse + err = json.Unmarshal(res, &spectrumApplications) + if err != nil { + return []SpectrumApplication{}, errors.Wrap(err, errUnmarshalError) + } + + return spectrumApplications.Result, nil +} + +// SpectrumApplication fetches a single Spectrum application based on the ID. +// +// API reference: https://developers.cloudflare.com/spectrum/api-reference/#list-spectrum-applications +func (api *API) SpectrumApplication(zoneID string, applicationID string) (SpectrumApplication, error) { + uri := fmt.Sprintf( + "/zones/%s/spectrum/apps/%s", + zoneID, + applicationID, + ) + + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return SpectrumApplication{}, errors.Wrap(err, errMakeRequestError) + } + + var spectrumApplication SpectrumApplicationDetailResponse + err = json.Unmarshal(res, &spectrumApplication) + if err != nil { + return SpectrumApplication{}, errors.Wrap(err, errUnmarshalError) + } + + return spectrumApplication.Result, nil +} + +// CreateSpectrumApplication creates a new Spectrum application. +// +// API reference: https://developers.cloudflare.com/spectrum/api-reference/#create-a-spectrum-application +func (api *API) CreateSpectrumApplication(zoneID string, appDetails SpectrumApplication) (SpectrumApplication, error) { + uri := "/zones/" + zoneID + "/spectrum/apps" + + res, err := api.makeRequest("POST", uri, appDetails) + if err != nil { + return SpectrumApplication{}, errors.Wrap(err, errMakeRequestError) + } + + var spectrumApplication SpectrumApplicationDetailResponse + err = json.Unmarshal(res, &spectrumApplication) + if err != nil { + return SpectrumApplication{}, errors.Wrap(err, errUnmarshalError) + } + + return spectrumApplication.Result, nil +} + +// UpdateSpectrumApplication updates an existing Spectrum application. 
+// +// API reference: https://developers.cloudflare.com/spectrum/api-reference/#update-a-spectrum-application +func (api *API) UpdateSpectrumApplication(zoneID, appID string, appDetails SpectrumApplication) (SpectrumApplication, error) { + uri := fmt.Sprintf( + "/zones/%s/spectrum/apps/%s", + zoneID, + appID, + ) + + res, err := api.makeRequest("PUT", uri, appDetails) + if err != nil { + return SpectrumApplication{}, errors.Wrap(err, errMakeRequestError) + } + + var spectrumApplication SpectrumApplicationDetailResponse + err = json.Unmarshal(res, &spectrumApplication) + if err != nil { + return SpectrumApplication{}, errors.Wrap(err, errUnmarshalError) + } + + return spectrumApplication.Result, nil +} + +// DeleteSpectrumApplication removes a Spectrum application based on the ID. +// +// API reference: https://developers.cloudflare.com/spectrum/api-reference/#delete-a-spectrum-application +func (api *API) DeleteSpectrumApplication(zoneID string, applicationID string) error { + uri := fmt.Sprintf( + "/zones/%s/spectrum/apps/%s", + zoneID, + applicationID, + ) + + _, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + + return nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/ssl.go b/vendor/github.com/cloudflare/cloudflare-go/ssl.go new file mode 100644 index 000000000..505dfa650 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/ssl.go @@ -0,0 +1,157 @@ +package cloudflare + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +// ZoneCustomSSL represents custom SSL certificate metadata. +type ZoneCustomSSL struct { + ID string `json:"id"` + Hosts []string `json:"hosts"` + Issuer string `json:"issuer"` + Signature string `json:"signature"` + Status string `json:"status"` + BundleMethod string `json:"bundle_method"` + GeoRestrictions ZoneCustomSSLGeoRestrictions `json:"geo_restrictions"` + ZoneID string `json:"zone_id"` + UploadedOn time.Time `json:"uploaded_on"` + ModifiedOn time.Time `json:"modified_on"` + ExpiresOn time.Time `json:"expires_on"` + Priority int `json:"priority"` + KeylessServer KeylessSSL `json:"keyless_server"` +} + +// ZoneCustomSSLGeoRestrictions represents the parameter to create or update +// geographic restrictions on a custom ssl certificate. +type ZoneCustomSSLGeoRestrictions struct { + Label string `json:"label"` +} + +// zoneCustomSSLResponse represents the response from the zone SSL details endpoint. +type zoneCustomSSLResponse struct { + Response + Result ZoneCustomSSL `json:"result"` +} + +// zoneCustomSSLsResponse represents the response from the zone SSL list endpoint. +type zoneCustomSSLsResponse struct { + Response + Result []ZoneCustomSSL `json:"result"` +} + +// ZoneCustomSSLOptions represents the parameters to create or update an existing +// custom SSL configuration. +type ZoneCustomSSLOptions struct { + Certificate string `json:"certificate"` + PrivateKey string `json:"private_key"` + BundleMethod string `json:"bundle_method,omitempty"` + GeoRestrictions ZoneCustomSSLGeoRestrictions `json:"geo_restrictions,omitempty"` + Type string `json:"type,omitempty"` +} + +// ZoneCustomSSLPriority represents a certificate's ID and priority. It is a +// subset of ZoneCustomSSL used for patch requests. +type ZoneCustomSSLPriority struct { + ID string `json:"ID"` + Priority int `json:"priority"` +} + +// CreateSSL allows you to add a custom SSL certificate to the given zone. 
+//
+// API reference: https://api.cloudflare.com/#custom-ssl-for-a-zone-create-ssl-configuration
+func (api *API) CreateSSL(zoneID string, options ZoneCustomSSLOptions) (ZoneCustomSSL, error) {
+	uri := "/zones/" + zoneID + "/custom_certificates"
+	res, err := api.makeRequest("POST", uri, options)
+	if err != nil {
+		return ZoneCustomSSL{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r zoneCustomSSLResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return ZoneCustomSSL{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// ListSSL lists the custom certificates for the given zone.
+//
+// API reference: https://api.cloudflare.com/#custom-ssl-for-a-zone-list-ssl-configurations
+func (api *API) ListSSL(zoneID string) ([]ZoneCustomSSL, error) {
+	uri := "/zones/" + zoneID + "/custom_certificates"
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+	var r zoneCustomSSLsResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// SSLDetails returns the configuration details for a custom SSL certificate.
+//
+// API reference: https://api.cloudflare.com/#custom-ssl-for-a-zone-ssl-configuration-details
+func (api *API) SSLDetails(zoneID, certificateID string) (ZoneCustomSSL, error) {
+	uri := "/zones/" + zoneID + "/custom_certificates/" + certificateID
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return ZoneCustomSSL{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r zoneCustomSSLResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return ZoneCustomSSL{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// UpdateSSL updates (replaces) a custom SSL certificate.
+//
+// API reference: https://api.cloudflare.com/#custom-ssl-for-a-zone-update-ssl-configuration
+func (api *API) UpdateSSL(zoneID, certificateID string, options ZoneCustomSSLOptions) (ZoneCustomSSL, error) {
+	uri := "/zones/" + zoneID + "/custom_certificates/" + certificateID
+	res, err := api.makeRequest("PATCH", uri, options)
+	if err != nil {
+		return ZoneCustomSSL{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r zoneCustomSSLResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return ZoneCustomSSL{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// ReprioritizeSSL allows you to change the priority (which determines the
+// certificate served for a given request) of custom SSL certificates
+// associated with the given zone.
+//
+// API reference: https://api.cloudflare.com/#custom-ssl-for-a-zone-re-prioritize-ssl-certificates
+func (api *API) ReprioritizeSSL(zoneID string, p []ZoneCustomSSLPriority) ([]ZoneCustomSSL, error) {
+	uri := "/zones/" + zoneID + "/custom_certificates/prioritize"
+	params := struct {
+		Certificates []ZoneCustomSSLPriority `json:"certificates"`
+	}{
+		Certificates: p,
+	}
+	res, err := api.makeRequest("PUT", uri, params)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+	var r zoneCustomSSLsResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
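A sketch of uploading a custom certificate with `CreateSSL`, assuming the package's `New` constructor and made-up credentials and file paths (`ioutil.ReadFile` matches the Go 1.13 era this diff targets):

```go
package main

import (
	"fmt"
	"io/ioutil"
	"log"

	cloudflare "github.com/cloudflare/cloudflare-go"
)

func main() {
	api, err := cloudflare.New("deadbeef-api-key", "user@example.com") // hypothetical credentials
	if err != nil {
		log.Fatal(err)
	}
	zoneID := "023e105f4ecef8ad9ca31a8372d0c353"

	// Hypothetical PEM files on disk.
	cert, err := ioutil.ReadFile("example.com.pem")
	if err != nil {
		log.Fatal(err)
	}
	key, err := ioutil.ReadFile("example.com-key.pem")
	if err != nil {
		log.Fatal(err)
	}

	ssl, err := api.CreateSSL(zoneID, cloudflare.ZoneCustomSSLOptions{
		Certificate:  string(cert),
		PrivateKey:   string(key),
		BundleMethod: "ubiquitous",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("certificate", ssl.ID, "expires", ssl.ExpiresOn)
}
```

+// DeleteSSL deletes a custom SSL certificate from the given zone.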
+//
+// API reference: https://api.cloudflare.com/#custom-ssl-for-a-zone-delete-an-ssl-certificate
+func (api *API) DeleteSSL(zoneID, certificateID string) error {
+	uri := "/zones/" + zoneID + "/custom_certificates/" + certificateID
+	if _, err := api.makeRequest("DELETE", uri, nil); err != nil {
+		return errors.Wrap(err, errMakeRequestError)
+	}
+	return nil
+}
diff --git a/vendor/github.com/cloudflare/cloudflare-go/universal_ssl.go b/vendor/github.com/cloudflare/cloudflare-go/universal_ssl.go
new file mode 100644
index 000000000..4bf8ffde7
--- /dev/null
+++ b/vendor/github.com/cloudflare/cloudflare-go/universal_ssl.go
@@ -0,0 +1,88 @@
+package cloudflare
+
+import (
+	"encoding/json"
+
+	"github.com/pkg/errors"
+)
+
+// UniversalSSLSetting represents a universal SSL setting's properties.
+type UniversalSSLSetting struct {
+	Enabled bool `json:"enabled"`
+}
+
+type universalSSLSettingResponse struct {
+	Response
+	Result UniversalSSLSetting `json:"result"`
+}
+
+// UniversalSSLVerificationDetails represents a universal SSL verification's properties.
+type UniversalSSLVerificationDetails struct {
+	CertificateStatus  string                       `json:"certificate_status"`
+	VerificationType   string                       `json:"verification_type"`
+	ValidationMethod   string                       `json:"validation_method"`
+	CertPackUUID       string                       `json:"cert_pack_uuid"`
+	VerificationStatus bool                         `json:"verification_status"`
+	BrandCheck         bool                         `json:"brand_check"`
+	VerificationInfo   UniversalSSLVerificationInfo `json:"verification_info"`
+}
+
+// UniversalSSLVerificationInfo represents a DCV record.
+type UniversalSSLVerificationInfo struct {
+	RecordName   string `json:"record_name"`
+	RecordTarget string `json:"record_target"`
+}
+
+type universalSSLVerificationResponse struct {
+	Response
+	Result []UniversalSSLVerificationDetails `json:"result"`
+}
+
+// UniversalSSLSettingDetails returns the details for a universal SSL setting.
+//
+// API reference: https://api.cloudflare.com/#universal-ssl-settings-for-a-zone-universal-ssl-settings-details
+func (api *API) UniversalSSLSettingDetails(zoneID string) (UniversalSSLSetting, error) {
+	uri := "/zones/" + zoneID + "/ssl/universal/settings"
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return UniversalSSLSetting{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r universalSSLSettingResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return UniversalSSLSetting{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// EditUniversalSSLSetting edits the universal SSL setting for a zone.
+//
+// API reference: https://api.cloudflare.com/#universal-ssl-settings-for-a-zone-edit-universal-ssl-settings
+func (api *API) EditUniversalSSLSetting(zoneID string, setting UniversalSSLSetting) (UniversalSSLSetting, error) {
+	uri := "/zones/" + zoneID + "/ssl/universal/settings"
+	res, err := api.makeRequest("PATCH", uri, setting)
+	if err != nil {
+		return UniversalSSLSetting{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r universalSSLSettingResponse
+	if err := json.Unmarshal(res, &r); err != nil {
+		return UniversalSSLSetting{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// UniversalSSLVerificationDetails returns the details for a universal SSL verification.
+//
+// API reference: https://api.cloudflare.com/#ssl-verification-ssl-verification-details
+func (api *API) UniversalSSLVerificationDetails(zoneID string) ([]UniversalSSLVerificationDetails, error) {
+	uri := "/zones/" + zoneID + "/ssl/verification"
+	res, err := api.makeRequest("GET", uri, nil)
+	if
err != nil { + return []UniversalSSLVerificationDetails{}, errors.Wrap(err, errMakeRequestError) + } + var r universalSSLVerificationResponse + if err := json.Unmarshal(res, &r); err != nil { + return []UniversalSSLVerificationDetails{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/user.go b/vendor/github.com/cloudflare/cloudflare-go/user.go new file mode 100644 index 000000000..bf2f47a57 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/user.go @@ -0,0 +1,113 @@ +package cloudflare + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +// User describes a user account. +type User struct { + ID string `json:"id,omitempty"` + Email string `json:"email,omitempty"` + FirstName string `json:"first_name,omitempty"` + LastName string `json:"last_name,omitempty"` + Username string `json:"username,omitempty"` + Telephone string `json:"telephone,omitempty"` + Country string `json:"country,omitempty"` + Zipcode string `json:"zipcode,omitempty"` + CreatedOn *time.Time `json:"created_on,omitempty"` + ModifiedOn *time.Time `json:"modified_on,omitempty"` + APIKey string `json:"api_key,omitempty"` + TwoFA bool `json:"two_factor_authentication_enabled,omitempty"` + Betas []string `json:"betas,omitempty"` + Accounts []Account `json:"organizations,omitempty"` +} + +// UserResponse wraps a response containing User accounts. +type UserResponse struct { + Response + Result User `json:"result"` +} + +// userBillingProfileResponse wraps a response containing Billing Profile information. +type userBillingProfileResponse struct { + Response + Result UserBillingProfile +} + +// UserBillingProfile contains Billing Profile information. +type UserBillingProfile struct { + ID string `json:"id,omitempty"` + FirstName string `json:"first_name,omitempty"` + LastName string `json:"last_name,omitempty"` + Address string `json:"address,omitempty"` + Address2 string `json:"address2,omitempty"` + Company string `json:"company,omitempty"` + City string `json:"city,omitempty"` + State string `json:"state,omitempty"` + ZipCode string `json:"zipcode,omitempty"` + Country string `json:"country,omitempty"` + Telephone string `json:"telephone,omitempty"` + CardNumber string `json:"card_number,omitempty"` + CardExpiryYear int `json:"card_expiry_year,omitempty"` + CardExpiryMonth int `json:"card_expiry_month,omitempty"` + VAT string `json:"vat,omitempty"` + CreatedOn *time.Time `json:"created_on,omitempty"` + EditedOn *time.Time `json:"edited_on,omitempty"` +} + +// UserDetails provides information about the logged-in user. +// +// API reference: https://api.cloudflare.com/#user-user-details +func (api *API) UserDetails() (User, error) { + var r UserResponse + res, err := api.makeRequest("GET", "/user", nil) + if err != nil { + return User{}, errors.Wrap(err, errMakeRequestError) + } + + err = json.Unmarshal(res, &r) + if err != nil { + return User{}, errors.Wrap(err, errUnmarshalError) + } + + return r.Result, nil +} + +// UpdateUser updates the properties of the given user. 
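+//
+// A read-modify-write sketch built on UserDetails above (the telephone value
+// is illustrative):
+//
+//	u, err := api.UserDetails()
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	u.Telephone = "+1-555-0100"
+//	if _, err := api.UpdateUser(&u); err != nil {
+//		log.Fatal(err)
+//	}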
+// +// API reference: https://api.cloudflare.com/#user-update-user +func (api *API) UpdateUser(user *User) (User, error) { + var r UserResponse + res, err := api.makeRequest("PATCH", "/user", user) + if err != nil { + return User{}, errors.Wrap(err, errMakeRequestError) + } + + err = json.Unmarshal(res, &r) + if err != nil { + return User{}, errors.Wrap(err, errUnmarshalError) + } + + return r.Result, nil +} + +// UserBillingProfile returns the billing profile of the user. +// +// API reference: https://api.cloudflare.com/#user-billing-profile +func (api *API) UserBillingProfile() (UserBillingProfile, error) { + var r userBillingProfileResponse + res, err := api.makeRequest("GET", "/user/billing/profile", nil) + if err != nil { + return UserBillingProfile{}, errors.Wrap(err, errMakeRequestError) + } + + err = json.Unmarshal(res, &r) + if err != nil { + return UserBillingProfile{}, errors.Wrap(err, errUnmarshalError) + } + + return r.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/user_agent.go b/vendor/github.com/cloudflare/cloudflare-go/user_agent.go new file mode 100644 index 000000000..6d75f3a1d --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/user_agent.go @@ -0,0 +1,149 @@ +package cloudflare + +import ( + "encoding/json" + "net/url" + "strconv" + + "github.com/pkg/errors" +) + +// UserAgentRule represents a User-Agent Block. These rules can be used to +// challenge, block or whitelist specific User-Agents for a given zone. +type UserAgentRule struct { + ID string `json:"id"` + Description string `json:"description"` + Mode string `json:"mode"` + Configuration UserAgentRuleConfig `json:"configuration"` + Paused bool `json:"paused"` +} + +// UserAgentRuleConfig represents a Zone Lockdown config, which comprises +// a Target ("ip" or "ip_range") and a Value (an IP address or IP+mask, +// respectively.) +type UserAgentRuleConfig ZoneLockdownConfig + +// UserAgentRuleResponse represents a response from the Zone Lockdown endpoint. +type UserAgentRuleResponse struct { + Result UserAgentRule `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// UserAgentRuleListResponse represents a response from the List Zone Lockdown endpoint. +type UserAgentRuleListResponse struct { + Result []UserAgentRule `json:"result"` + Response + ResultInfo `json:"result_info"` +} + +// CreateUserAgentRule creates a User-Agent Block rule for the given zone ID. +// +// API reference: https://api.cloudflare.com/#user-agent-blocking-rules-create-a-useragent-rule +func (api *API) CreateUserAgentRule(zoneID string, ld UserAgentRule) (*UserAgentRuleResponse, error) { + switch ld.Mode { + case "block", "challenge", "js_challenge", "whitelist": + break + default: + return nil, errors.New(`the User-Agent Block rule mode must be one of "block", "challenge", "js_challenge", "whitelist"`) + } + + uri := "/zones/" + zoneID + "/firewall/ua_rules" + res, err := api.makeRequest("POST", uri, ld) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &UserAgentRuleResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response, nil +} + +// UpdateUserAgentRule updates a User-Agent Block rule (based on the ID) for the given zone ID. 
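+//
+// For example, switching an existing rule to challenge mode (ruleID and the
+// matched User-Agent are placeholders; the Target/Value fields come from the
+// shared ZoneLockdownConfig shape):
+//
+//	rule := cloudflare.UserAgentRule{
+//		Mode:          "challenge",
+//		Configuration: cloudflare.UserAgentRuleConfig{Target: "ua", Value: "BadBot/1.0"},
+//	}
+//	if _, err := api.UpdateUserAgentRule(zoneID, ruleID, rule); err != nil {
+//		log.Fatal(err)
+//	}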
+// +// API reference: https://api.cloudflare.com/#user-agent-blocking-rules-update-useragent-rule +func (api *API) UpdateUserAgentRule(zoneID string, id string, ld UserAgentRule) (*UserAgentRuleResponse, error) { + uri := "/zones/" + zoneID + "/firewall/ua_rules/" + id + res, err := api.makeRequest("PUT", uri, ld) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &UserAgentRuleResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response, nil +} + +// DeleteUserAgentRule deletes a User-Agent Block rule (based on the ID) for the given zone ID. +// +// API reference: https://api.cloudflare.com/#user-agent-blocking-rules-delete-useragent-rule +func (api *API) DeleteUserAgentRule(zoneID string, id string) (*UserAgentRuleResponse, error) { + uri := "/zones/" + zoneID + "/firewall/ua_rules/" + id + res, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &UserAgentRuleResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response, nil +} + +// UserAgentRule retrieves a User-Agent Block rule (based on the ID) for the given zone ID. +// +// API reference: https://api.cloudflare.com/#user-agent-blocking-rules-useragent-rule-details +func (api *API) UserAgentRule(zoneID string, id string) (*UserAgentRuleResponse, error) { + uri := "/zones/" + zoneID + "/firewall/ua_rules/" + id + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &UserAgentRuleResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response, nil +} + +// ListUserAgentRules retrieves a list of User-Agent Block rules for a given zone ID by page number. +// +// API reference: https://api.cloudflare.com/#user-agent-blocking-rules-list-useragent-rules +func (api *API) ListUserAgentRules(zoneID string, page int) (*UserAgentRuleListResponse, error) { + v := url.Values{} + if page <= 0 { + page = 1 + } + + v.Set("page", strconv.Itoa(page)) + v.Set("per_page", strconv.Itoa(100)) + query := "?" + v.Encode() + + uri := "/zones/" + zoneID + "/firewall/ua_rules" + query + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &UserAgentRuleListResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/virtualdns.go b/vendor/github.com/cloudflare/cloudflare-go/virtualdns.go new file mode 100644 index 000000000..f8082e1f0 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/virtualdns.go @@ -0,0 +1,192 @@ +package cloudflare + +import ( + "encoding/json" + "net/url" + "strings" + "time" + + "github.com/pkg/errors" +) + +// VirtualDNS represents a Virtual DNS configuration. 
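+//
+// A cluster definition passed to CreateVirtualDNS might look like this
+// (all values illustrative):
+//
+//	v := &cloudflare.VirtualDNS{
+//		Name:                 "my-resolver",
+//		OriginIPs:            []string{"192.0.2.1", "192.0.2.2"},
+//		MinimumCacheTTL:      60,
+//		MaximumCacheTTL:      900,
+//		DeprecateAnyRequests: true,
+//	}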
+type VirtualDNS struct {
+	ID                   string   `json:"id"`
+	Name                 string   `json:"name"`
+	OriginIPs            []string `json:"origin_ips"`
+	VirtualDNSIPs        []string `json:"virtual_dns_ips"`
+	MinimumCacheTTL      uint     `json:"minimum_cache_ttl"`
+	MaximumCacheTTL      uint     `json:"maximum_cache_ttl"`
+	DeprecateAnyRequests bool     `json:"deprecate_any_requests"`
+	ModifiedOn           string   `json:"modified_on"`
+}
+
+// VirtualDNSAnalyticsMetrics represents a group of aggregated Virtual DNS metrics.
+type VirtualDNSAnalyticsMetrics struct {
+	QueryCount         *int64   `json:"queryCount"`
+	UncachedCount      *int64   `json:"uncachedCount"`
+	StaleCount         *int64   `json:"staleCount"`
+	ResponseTimeAvg    *float64 `json:"responseTimeAvg"`
+	ResponseTimeMedian *float64 `json:"responseTimeMedian"`
+	ResponseTime90th   *float64 `json:"responseTime90th"`
+	ResponseTime99th   *float64 `json:"responseTime99th"`
+}
+
+// VirtualDNSAnalytics represents a set of aggregated Virtual DNS metrics.
+// TODO: Add the queried data and not only the aggregated values.
+type VirtualDNSAnalytics struct {
+	Totals VirtualDNSAnalyticsMetrics `json:"totals"`
+	Min    VirtualDNSAnalyticsMetrics `json:"min"`
+	Max    VirtualDNSAnalyticsMetrics `json:"max"`
+}
+
+// VirtualDNSUserAnalyticsOptions represents range and dimension selection on the analytics endpoint.
+type VirtualDNSUserAnalyticsOptions struct {
+	Metrics []string
+	Since   *time.Time
+	Until   *time.Time
+}
+
+// VirtualDNSResponse represents a Virtual DNS response.
+type VirtualDNSResponse struct {
+	Response
+	Result *VirtualDNS `json:"result"`
+}
+
+// VirtualDNSListResponse represents an array of Virtual DNS responses.
+type VirtualDNSListResponse struct {
+	Response
+	Result []*VirtualDNS `json:"result"`
+}
+
+// VirtualDNSAnalyticsResponse represents a Virtual DNS analytics response.
+type VirtualDNSAnalyticsResponse struct {
+	Response
+	Result VirtualDNSAnalytics `json:"result"`
+}
+
+// CreateVirtualDNS creates a new Virtual DNS cluster.
+//
+// API reference: https://api.cloudflare.com/#virtual-dns-users--create-a-virtual-dns-cluster
+func (api *API) CreateVirtualDNS(v *VirtualDNS) (*VirtualDNS, error) {
+	res, err := api.makeRequest("POST", "/user/virtual_dns", v)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+
+	response := &VirtualDNSResponse{}
+	err = json.Unmarshal(res, &response)
+	if err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return response.Result, nil
+}
+
+// VirtualDNS fetches a single virtual DNS cluster.
+//
+// API reference: https://api.cloudflare.com/#virtual-dns-users--get-a-virtual-dns-cluster
+func (api *API) VirtualDNS(virtualDNSID string) (*VirtualDNS, error) {
+	uri := "/user/virtual_dns/" + virtualDNSID
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return nil, errors.Wrap(err, errMakeRequestError)
+	}
+
+	response := &VirtualDNSResponse{}
+	err = json.Unmarshal(res, &response)
+	if err != nil {
+		return nil, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return response.Result, nil
+}
+
+// ListVirtualDNS lists the virtual DNS clusters associated with an account.
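+//
+// For example:
+//
+//	clusters, err := api.ListVirtualDNS()
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	for _, c := range clusters {
+//		fmt.Println(c.ID, c.Name)
+//	}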
+// +// API reference: https://api.cloudflare.com/#virtual-dns-users--get-virtual-dns-clusters +func (api *API) ListVirtualDNS() ([]*VirtualDNS, error) { + res, err := api.makeRequest("GET", "/user/virtual_dns", nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &VirtualDNSListResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response.Result, nil +} + +// UpdateVirtualDNS updates a Virtual DNS cluster. +// +// API reference: https://api.cloudflare.com/#virtual-dns-users--modify-a-virtual-dns-cluster +func (api *API) UpdateVirtualDNS(virtualDNSID string, vv VirtualDNS) error { + uri := "/user/virtual_dns/" + virtualDNSID + res, err := api.makeRequest("PUT", uri, vv) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + + response := &VirtualDNSResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + + return nil +} + +// DeleteVirtualDNS deletes a Virtual DNS cluster. Note that this cannot be +// undone, and will stop all traffic to that cluster. +// +// API reference: https://api.cloudflare.com/#virtual-dns-users--delete-a-virtual-dns-cluster +func (api *API) DeleteVirtualDNS(virtualDNSID string) error { + uri := "/user/virtual_dns/" + virtualDNSID + res, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return errors.Wrap(err, errMakeRequestError) + } + + response := &VirtualDNSResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return errors.Wrap(err, errUnmarshalError) + } + + return nil +} + +// encode encodes non-nil fields into URL encoded form. +func (o VirtualDNSUserAnalyticsOptions) encode() string { + v := url.Values{} + if o.Since != nil { + v.Set("since", (*o.Since).UTC().Format(time.RFC3339)) + } + if o.Until != nil { + v.Set("until", (*o.Until).UTC().Format(time.RFC3339)) + } + if o.Metrics != nil { + v.Set("metrics", strings.Join(o.Metrics, ",")) + } + return v.Encode() +} + +// VirtualDNSUserAnalytics retrieves analytics report for a specified dimension and time range +func (api *API) VirtualDNSUserAnalytics(virtualDNSID string, o VirtualDNSUserAnalyticsOptions) (VirtualDNSAnalytics, error) { + uri := "/user/virtual_dns/" + virtualDNSID + "/dns_analytics/report?" + o.encode() + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return VirtualDNSAnalytics{}, errors.Wrap(err, errMakeRequestError) + } + + response := VirtualDNSAnalyticsResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return VirtualDNSAnalytics{}, errors.Wrap(err, errUnmarshalError) + } + + return response.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/waf.go b/vendor/github.com/cloudflare/cloudflare-go/waf.go new file mode 100644 index 000000000..9b67f79a7 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/waf.go @@ -0,0 +1,300 @@ +package cloudflare + +import ( + "encoding/json" + + "github.com/pkg/errors" +) + +// WAFPackage represents a WAF package configuration. +type WAFPackage struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + ZoneID string `json:"zone_id"` + DetectionMode string `json:"detection_mode"` + Sensitivity string `json:"sensitivity"` + ActionMode string `json:"action_mode"` +} + +// WAFPackagesResponse represents the response from the WAF packages endpoint. 
+type WAFPackagesResponse struct { + Response + Result []WAFPackage `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// WAFPackageResponse represents the response from the WAF package endpoint. +type WAFPackageResponse struct { + Response + Result WAFPackage `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// WAFPackageOptions represents options to edit a WAF package. +type WAFPackageOptions struct { + Sensitivity string `json:"sensitivity,omitempty"` + ActionMode string `json:"action_mode,omitempty"` +} + +// WAFGroup represents a WAF rule group. +type WAFGroup struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + RulesCount int `json:"rules_count"` + ModifiedRulesCount int `json:"modified_rules_count"` + PackageID string `json:"package_id"` + Mode string `json:"mode"` + AllowedModes []string `json:"allowed_modes"` +} + +// WAFGroupsResponse represents the response from the WAF groups endpoint. +type WAFGroupsResponse struct { + Response + Result []WAFGroup `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// WAFGroupResponse represents the response from the WAF group endpoint. +type WAFGroupResponse struct { + Response + Result WAFGroup `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// WAFRule represents a WAF rule. +type WAFRule struct { + ID string `json:"id"` + Description string `json:"description"` + Priority string `json:"priority"` + PackageID string `json:"package_id"` + Group struct { + ID string `json:"id"` + Name string `json:"name"` + } `json:"group"` + Mode string `json:"mode"` + DefaultMode string `json:"default_mode"` + AllowedModes []string `json:"allowed_modes"` +} + +// WAFRulesResponse represents the response from the WAF rules endpoint. +type WAFRulesResponse struct { + Response + Result []WAFRule `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// WAFRuleResponse represents the response from the WAF rule endpoint. +type WAFRuleResponse struct { + Response + Result WAFRule `json:"result"` + ResultInfo ResultInfo `json:"result_info"` +} + +// WAFRuleOptions is a subset of WAFRule, for editable options. +type WAFRuleOptions struct { + Mode string `json:"mode"` +} + +// ListWAFPackages returns a slice of the WAF packages for the given zone. +// +// API Reference: https://api.cloudflare.com/#waf-rule-packages-list-firewall-packages +func (api *API) ListWAFPackages(zoneID string) ([]WAFPackage, error) { + var p WAFPackagesResponse + var packages []WAFPackage + var res []byte + var err error + uri := "/zones/" + zoneID + "/firewall/waf/packages" + res, err = api.makeRequest("GET", uri, nil) + if err != nil { + return []WAFPackage{}, errors.Wrap(err, errMakeRequestError) + } + err = json.Unmarshal(res, &p) + if err != nil { + return []WAFPackage{}, errors.Wrap(err, errUnmarshalError) + } + if !p.Success { + // TODO: Provide an actual error message instead of always returning nil + return []WAFPackage{}, err + } + for pi := range p.Result { + packages = append(packages, p.Result[pi]) + } + return packages, nil +} + +// WAFPackage returns a WAF package for the given zone. 
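+//
+// A short sketch, with packageID as returned by ListWAFPackages above:
+//
+//	pkg, err := api.WAFPackage(zoneID, packageID)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	fmt.Println(pkg.Name, pkg.DetectionMode, pkg.Sensitivity)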
+//
+// API Reference: https://api.cloudflare.com/#waf-rule-packages-firewall-package-details
+func (api *API) WAFPackage(zoneID, packageID string) (WAFPackage, error) {
+	uri := "/zones/" + zoneID + "/firewall/waf/packages/" + packageID
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return WAFPackage{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var r WAFPackageResponse
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return WAFPackage{}, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return r.Result, nil
+}
+
+// UpdateWAFPackage lets you update a WAF package.
+//
+// API Reference: https://api.cloudflare.com/#waf-rule-packages-edit-firewall-package
+func (api *API) UpdateWAFPackage(zoneID, packageID string, opts WAFPackageOptions) (WAFPackage, error) {
+	uri := "/zones/" + zoneID + "/firewall/waf/packages/" + packageID
+	res, err := api.makeRequest("PATCH", uri, opts)
+	if err != nil {
+		return WAFPackage{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var r WAFPackageResponse
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return WAFPackage{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// ListWAFGroups returns a slice of the WAF groups for the given WAF package.
+//
+// API Reference: https://api.cloudflare.com/#waf-rule-groups-list-rule-groups
+func (api *API) ListWAFGroups(zoneID, packageID string) ([]WAFGroup, error) {
+	var groups []WAFGroup
+	var res []byte
+	var err error
+
+	uri := "/zones/" + zoneID + "/firewall/waf/packages/" + packageID + "/groups"
+	res, err = api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return []WAFGroup{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var r WAFGroupsResponse
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return []WAFGroup{}, errors.Wrap(err, errUnmarshalError)
+	}
+
+	if !r.Success {
+		// TODO: Provide an actual error message instead of always returning nil
+		return []WAFGroup{}, err
+	}
+
+	for gi := range r.Result {
+		groups = append(groups, r.Result[gi])
+	}
+	return groups, nil
+}
+
+// WAFGroup returns a WAF rule group from the given WAF package.
+//
+// API Reference: https://api.cloudflare.com/#waf-rule-groups-rule-group-details
+func (api *API) WAFGroup(zoneID, packageID, groupID string) (WAFGroup, error) {
+	uri := "/zones/" + zoneID + "/firewall/waf/packages/" + packageID + "/groups/" + groupID
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return WAFGroup{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var r WAFGroupResponse
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return WAFGroup{}, errors.Wrap(err, errUnmarshalError)
+	}
+
+	return r.Result, nil
+}
+
+// UpdateWAFGroup lets you update the mode of a WAF Group.
+//
+// API Reference: https://api.cloudflare.com/#waf-rule-groups-edit-rule-group
+func (api *API) UpdateWAFGroup(zoneID, packageID, groupID, mode string) (WAFGroup, error) {
+	opts := WAFRuleOptions{Mode: mode}
+	uri := "/zones/" + zoneID + "/firewall/waf/packages/" + packageID + "/groups/" + groupID
+	res, err := api.makeRequest("PATCH", uri, opts)
+	if err != nil {
+		return WAFGroup{}, errors.Wrap(err, errMakeRequestError)
+	}
+
+	var r WAFGroupResponse
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return WAFGroup{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r.Result, nil
+}
+
+// ListWAFRules returns a slice of the WAF rules for the given WAF package.
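+//
+// For example, printing each rule's current and default modes:
+//
+//	rules, err := api.ListWAFRules(zoneID, packageID)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	for _, rule := range rules {
+//		fmt.Println(rule.ID, rule.Mode, rule.DefaultMode)
+//	}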
+// +// API Reference: https://api.cloudflare.com/#waf-rules-list-rules +func (api *API) ListWAFRules(zoneID, packageID string) ([]WAFRule, error) { + var rules []WAFRule + var res []byte + var err error + + uri := "/zones/" + zoneID + "/firewall/waf/packages/" + packageID + "/rules" + res, err = api.makeRequest("GET", uri, nil) + if err != nil { + return []WAFRule{}, errors.Wrap(err, errMakeRequestError) + } + + var r WAFRulesResponse + err = json.Unmarshal(res, &r) + if err != nil { + return []WAFRule{}, errors.Wrap(err, errUnmarshalError) + } + + if !r.Success { + // TODO: Provide an actual error message instead of always returning nil + return []WAFRule{}, err + } + + for ri := range r.Result { + rules = append(rules, r.Result[ri]) + } + return rules, nil +} + +// WAFRule returns a WAF rule from the given WAF package. +// +// API Reference: https://api.cloudflare.com/#waf-rules-rule-details +func (api *API) WAFRule(zoneID, packageID, ruleID string) (WAFRule, error) { + uri := "/zones/" + zoneID + "/firewall/waf/packages/" + packageID + "/rules/" + ruleID + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return WAFRule{}, errors.Wrap(err, errMakeRequestError) + } + + var r WAFRuleResponse + err = json.Unmarshal(res, &r) + if err != nil { + return WAFRule{}, errors.Wrap(err, errUnmarshalError) + } + + return r.Result, nil +} + +// UpdateWAFRule lets you update the mode of a WAF Rule. +// +// API Reference: https://api.cloudflare.com/#waf-rules-edit-rule +func (api *API) UpdateWAFRule(zoneID, packageID, ruleID, mode string) (WAFRule, error) { + opts := WAFRuleOptions{Mode: mode} + uri := "/zones/" + zoneID + "/firewall/waf/packages/" + packageID + "/rules/" + ruleID + res, err := api.makeRequest("PATCH", uri, opts) + if err != nil { + return WAFRule{}, errors.Wrap(err, errMakeRequestError) + } + + var r WAFRuleResponse + err = json.Unmarshal(res, &r) + if err != nil { + return WAFRule{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/workers.go b/vendor/github.com/cloudflare/cloudflare-go/workers.go new file mode 100644 index 000000000..1ab795ec4 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/workers.go @@ -0,0 +1,314 @@ +package cloudflare + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/pkg/errors" +) + +// WorkerRequestParams provides parameters for worker requests for both enterprise and standard requests +type WorkerRequestParams struct { + ZoneID string + ScriptName string +} + +// WorkerRoute aka filters are patterns used to enable or disable workers that match requests. 
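+//
+// A multi-script route might look like this (pattern and script name are
+// illustrative):
+//
+//	route := cloudflare.WorkerRoute{
+//		Pattern: "example.com/images/*",
+//		Script:  "image-resizer",
+//	}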
+// +// API reference: https://api.cloudflare.com/#worker-filters-properties +type WorkerRoute struct { + ID string `json:"id,omitempty"` + Pattern string `json:"pattern"` + Enabled bool `json:"enabled"` + Script string `json:"script,omitempty"` +} + +// WorkerRoutesResponse embeds Response struct and slice of WorkerRoutes +type WorkerRoutesResponse struct { + Response + Routes []WorkerRoute `json:"result"` +} + +// WorkerRouteResponse embeds Response struct and a single WorkerRoute +type WorkerRouteResponse struct { + Response + WorkerRoute `json:"result"` +} + +// WorkerScript Cloudflare Worker struct with metadata +type WorkerScript struct { + WorkerMetaData + Script string `json:"script"` +} + +// WorkerMetaData contains worker script information such as size, creation & modification dates +type WorkerMetaData struct { + ID string `json:"id,omitempty"` + ETAG string `json:"etag,omitempty"` + Size int `json:"size,omitempty"` + CreatedOn time.Time `json:"created_on,omitempty"` + ModifiedOn time.Time `json:"modified_on,omitempty"` +} + +// WorkerListResponse wrapper struct for API response to worker script list API call +type WorkerListResponse struct { + Response + WorkerList []WorkerMetaData `json:"result"` +} + +// WorkerScriptResponse wrapper struct for API response to worker script calls +type WorkerScriptResponse struct { + Response + WorkerScript `json:"result"` +} + +// DeleteWorker deletes worker for a zone. +// +// API reference: https://api.cloudflare.com/#worker-script-delete-worker +func (api *API) DeleteWorker(requestParams *WorkerRequestParams) (WorkerScriptResponse, error) { + // if ScriptName is provided we will treat as org request + if requestParams.ScriptName != "" { + return api.deleteWorkerWithName(requestParams.ScriptName) + } + uri := "/zones/" + requestParams.ZoneID + "/workers/script" + res, err := api.makeRequest("DELETE", uri, nil) + var r WorkerScriptResponse + if err != nil { + return r, errors.Wrap(err, errMakeRequestError) + } + err = json.Unmarshal(res, &r) + if err != nil { + return r, errors.Wrap(err, errUnmarshalError) + } + return r, nil +} + +// DeleteWorkerWithName deletes worker for a zone. 
+// This is an enterprise only feature: https://developers.cloudflare.com/workers/api/config-api-for-enterprise
+// The account must be specified as an API option: https://godoc.org/github.com/cloudflare/cloudflare-go#UsingAccount
+//
+// API reference: https://api.cloudflare.com/#worker-script-delete-worker
+func (api *API) deleteWorkerWithName(scriptName string) (WorkerScriptResponse, error) {
+	if api.AccountID == "" {
+		return WorkerScriptResponse{}, errors.New("account ID required for enterprise only request")
+	}
+	uri := "/accounts/" + api.AccountID + "/workers/scripts/" + scriptName
+	res, err := api.makeRequest("DELETE", uri, nil)
+	var r WorkerScriptResponse
+	if err != nil {
+		return r, errors.Wrap(err, errMakeRequestError)
+	}
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return r, errors.Wrap(err, errUnmarshalError)
+	}
+	return r, nil
+}
+
+// DownloadWorker fetches the raw script content for your worker; the worker
+// JavaScript is returned in the Script field of the response.
+//
+// API reference: https://api.cloudflare.com/#worker-script-download-worker
+func (api *API) DownloadWorker(requestParams *WorkerRequestParams) (WorkerScriptResponse, error) {
+	if requestParams.ScriptName != "" {
+		return api.downloadWorkerWithName(requestParams.ScriptName)
+	}
+	uri := "/zones/" + requestParams.ZoneID + "/workers/script"
+	res, err := api.makeRequest("GET", uri, nil)
+	var r WorkerScriptResponse
+	if err != nil {
+		return r, errors.Wrap(err, errMakeRequestError)
+	}
+	r.Script = string(res)
+	r.Success = true
+	return r, nil
+}
+
+// downloadWorkerWithName fetches the raw script content for your worker; the
+// worker JavaScript is returned in the Script field of the response.
+// This is an enterprise only feature: https://developers.cloudflare.com/workers/api/config-api-for-enterprise/
+//
+// API reference: https://api.cloudflare.com/#worker-script-download-worker
+func (api *API) downloadWorkerWithName(scriptName string) (WorkerScriptResponse, error) {
+	if api.AccountID == "" {
+		return WorkerScriptResponse{}, errors.New("account ID required for enterprise only request")
+	}
+	uri := "/accounts/" + api.AccountID + "/workers/scripts/" + scriptName
+	res, err := api.makeRequest("GET", uri, nil)
+	var r WorkerScriptResponse
+	if err != nil {
+		return r, errors.Wrap(err, errMakeRequestError)
+	}
+	r.Script = string(res)
+	r.Success = true
+	return r, nil
+}
+
+// ListWorkerScripts returns a list of worker scripts for the given account.
+//
+// This is an enterprise only feature: https://developers.cloudflare.com/workers/api/config-api-for-enterprise
+//
+// API reference: https://developers.cloudflare.com/workers/api/config-api-for-enterprise/
+func (api *API) ListWorkerScripts() (WorkerListResponse, error) {
+	if api.AccountID == "" {
+		return WorkerListResponse{}, errors.New("account ID required for enterprise only request")
+	}
+	uri := "/accounts/" + api.AccountID + "/workers/scripts"
+	res, err := api.makeRequest("GET", uri, nil)
+	if err != nil {
+		return WorkerListResponse{}, errors.Wrap(err, errMakeRequestError)
+	}
+	var r WorkerListResponse
+	err = json.Unmarshal(res, &r)
+	if err != nil {
+		return WorkerListResponse{}, errors.Wrap(err, errUnmarshalError)
+	}
+	return r, nil
+}
+
+// UploadWorker pushes raw script content for your worker.
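+//
+// A caller-side sketch (the zone ID and script body are illustrative):
+//
+//	params := &cloudflare.WorkerRequestParams{ZoneID: zoneID}
+//	script := `addEventListener("fetch", event => { event.respondWith(fetch(event.request)) })`
+//	if _, err := api.UploadWorker(params, script); err != nil {
+//		log.Fatal(err)
+//	}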
+// +// API reference: https://api.cloudflare.com/#worker-script-upload-worker +func (api *API) UploadWorker(requestParams *WorkerRequestParams, data string) (WorkerScriptResponse, error) { + if requestParams.ScriptName != "" { + return api.uploadWorkerWithName(requestParams.ScriptName, data) + } + uri := "/zones/" + requestParams.ZoneID + "/workers/script" + headers := make(http.Header) + headers.Set("Content-Type", "application/javascript") + res, err := api.makeRequestWithHeaders("PUT", uri, []byte(data), headers) + var r WorkerScriptResponse + if err != nil { + return r, errors.Wrap(err, errMakeRequestError) + } + err = json.Unmarshal(res, &r) + if err != nil { + return r, errors.Wrap(err, errUnmarshalError) + } + return r, nil +} + +// UploadWorkerWithName push raw script content for your worker. +// +// This is an enterprise only feature https://developers.cloudflare.com/workers/api/config-api-for-enterprise/ +// +// API reference: https://api.cloudflare.com/#worker-script-upload-worker +func (api *API) uploadWorkerWithName(scriptName string, data string) (WorkerScriptResponse, error) { + if api.AccountID == "" { + return WorkerScriptResponse{}, errors.New("account ID required for enterprise only request") + } + uri := "/accounts/" + api.AccountID + "/workers/scripts/" + scriptName + headers := make(http.Header) + headers.Set("Content-Type", "application/javascript") + res, err := api.makeRequestWithHeaders("PUT", uri, []byte(data), headers) + var r WorkerScriptResponse + if err != nil { + return r, errors.Wrap(err, errMakeRequestError) + } + err = json.Unmarshal(res, &r) + if err != nil { + return r, errors.Wrap(err, errUnmarshalError) + } + return r, nil +} + +// CreateWorkerRoute creates worker route for a zone +// +// API reference: https://api.cloudflare.com/#worker-filters-create-filter +func (api *API) CreateWorkerRoute(zoneID string, route WorkerRoute) (WorkerRouteResponse, error) { + // Check whether a script name is defined in order to determine whether + // to use the single-script or multi-script endpoint. 
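+	// For example (illustrative values), {Pattern: "example.com/*", Script: "my-worker"}
+	// is sent to the enterprise "routes" endpoint, while a pattern-only route
+	// is sent to the single-script "filters" endpoint.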
+ pathComponent := "filters" + if route.Script != "" { + if api.AccountID == "" { + return WorkerRouteResponse{}, errors.New("account ID required for enterprise only request") + } + pathComponent = "routes" + } + + uri := "/zones/" + zoneID + "/workers/" + pathComponent + res, err := api.makeRequest("POST", uri, route) + if err != nil { + return WorkerRouteResponse{}, errors.Wrap(err, errMakeRequestError) + } + var r WorkerRouteResponse + err = json.Unmarshal(res, &r) + if err != nil { + return WorkerRouteResponse{}, errors.Wrap(err, errUnmarshalError) + } + return r, nil +} + +// DeleteWorkerRoute deletes worker route for a zone +// +// API reference: https://api.cloudflare.com/#worker-filters-delete-filter +func (api *API) DeleteWorkerRoute(zoneID string, routeID string) (WorkerRouteResponse, error) { + // For deleting a route, it doesn't matter whether we use the + // single-script or multi-script endpoint + uri := "/zones/" + zoneID + "/workers/filters/" + routeID + res, err := api.makeRequest("DELETE", uri, nil) + if err != nil { + return WorkerRouteResponse{}, errors.Wrap(err, errMakeRequestError) + } + var r WorkerRouteResponse + err = json.Unmarshal(res, &r) + if err != nil { + return WorkerRouteResponse{}, errors.Wrap(err, errUnmarshalError) + } + return r, nil +} + +// ListWorkerRoutes returns list of worker routes +// +// API reference: https://api.cloudflare.com/#worker-filters-list-filters +func (api *API) ListWorkerRoutes(zoneID string) (WorkerRoutesResponse, error) { + pathComponent := "filters" + if api.AccountID != "" { + pathComponent = "routes" + } + uri := "/zones/" + zoneID + "/workers/" + pathComponent + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return WorkerRoutesResponse{}, errors.Wrap(err, errMakeRequestError) + } + var r WorkerRoutesResponse + err = json.Unmarshal(res, &r) + if err != nil { + return WorkerRoutesResponse{}, errors.Wrap(err, errUnmarshalError) + } + for i := range r.Routes { + route := &r.Routes[i] + // The Enabled flag will not be set in the multi-script API response + // so we manually set it to true if the script name is not empty + // in case any multi-script customers rely on the Enabled field + if route.Script != "" { + route.Enabled = true + } + } + return r, nil +} + +// UpdateWorkerRoute updates worker route for a zone. +// +// API reference: https://api.cloudflare.com/#worker-filters-update-filter +func (api *API) UpdateWorkerRoute(zoneID string, routeID string, route WorkerRoute) (WorkerRouteResponse, error) { + // Check whether a script name is defined in order to determine whether + // to use the single-script or multi-script endpoint. 
+ pathComponent := "filters" + if route.Script != "" { + if api.AccountID == "" { + return WorkerRouteResponse{}, errors.New("account ID required for enterprise only request") + } + pathComponent = "routes" + } + uri := "/zones/" + zoneID + "/workers/" + pathComponent + "/" + routeID + res, err := api.makeRequest("PUT", uri, route) + if err != nil { + return WorkerRouteResponse{}, errors.Wrap(err, errMakeRequestError) + } + var r WorkerRouteResponse + err = json.Unmarshal(res, &r) + if err != nil { + return WorkerRouteResponse{}, errors.Wrap(err, errUnmarshalError) + } + return r, nil +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/workers_kv.go b/vendor/github.com/cloudflare/cloudflare-go/workers_kv.go new file mode 100644 index 000000000..92197af08 --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/workers_kv.go @@ -0,0 +1,192 @@ +package cloudflare + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "github.com/pkg/errors" +) + +// WorkersKVNamespaceRequest provides parameters for creating and updating storage namespaces +type WorkersKVNamespaceRequest struct { + Title string `json:"title"` +} + +// WorkersKVNamespaceResponse is the response received when creating storage namespaces +type WorkersKVNamespaceResponse struct { + Response + Result WorkersKVNamespace `json:"result"` +} + +// WorkersKVNamespace contains the unique identifier and title of a storage namespace +type WorkersKVNamespace struct { + ID string `json:"id"` + Title string `json:"title"` +} + +// ListWorkersKVNamespacesResponse contains a slice of storage namespaces associated with an +// account, pagination information, and an embedded response struct +type ListWorkersKVNamespacesResponse struct { + Response + Result []WorkersKVNamespace `json:"result"` + ResultInfo `json:"result_info"` +} + +// StorageKey is a key name used to identify a storage value +type StorageKey struct { + Name string `json:"name"` +} + +// ListStorageKeysResponse contains a slice of keys belonging to a storage namespace, +// pagination information, and an embedded response struct +type ListStorageKeysResponse struct { + Response + Result []StorageKey `json:"result"` + ResultInfo `json:"result_info"` +} + +// CreateWorkersKVNamespace creates a namespace under the given title. +// A 400 is returned if the account already owns a namespace with this title. +// A namespace must be explicitly deleted to be replaced. 
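+//
+// For example (the namespace title is illustrative):
+//
+//	resp, err := api.CreateWorkersKVNamespace(context.Background(),
+//		&cloudflare.WorkersKVNamespaceRequest{Title: "test_namespace"})
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	fmt.Println(resp.Result.ID)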
+// +// API reference: https://api.cloudflare.com/#workers-kv-namespace-create-a-namespace +func (api *API) CreateWorkersKVNamespace(ctx context.Context, req *WorkersKVNamespaceRequest) (WorkersKVNamespaceResponse, error) { + uri := fmt.Sprintf("/accounts/%s/storage/kv/namespaces", api.AccountID) + res, err := api.makeRequestContext(ctx, http.MethodPost, uri, req) + if err != nil { + return WorkersKVNamespaceResponse{}, errors.Wrap(err, errMakeRequestError) + } + + result := WorkersKVNamespaceResponse{} + if err := json.Unmarshal(res, &result); err != nil { + return result, errors.Wrap(err, errUnmarshalError) + } + + return result, err +} + +// ListWorkersKVNamespaces lists storage namespaces +// +// API reference: https://api.cloudflare.com/#workers-kv-namespace-list-namespaces +func (api *API) ListWorkersKVNamespaces(ctx context.Context) (ListWorkersKVNamespacesResponse, error) { + uri := fmt.Sprintf("/accounts/%s/storage/kv/namespaces", api.AccountID) + res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return ListWorkersKVNamespacesResponse{}, errors.Wrap(err, errMakeRequestError) + } + + result := ListWorkersKVNamespacesResponse{} + if err := json.Unmarshal(res, &result); err != nil { + return result, errors.Wrap(err, errUnmarshalError) + } + + return result, err +} + +// DeleteWorkersKVNamespace deletes the namespace corresponding to the given ID +// +// API reference: https://api.cloudflare.com/#workers-kv-namespace-remove-a-namespace +func (api *API) DeleteWorkersKVNamespace(ctx context.Context, namespaceID string) (Response, error) { + uri := fmt.Sprintf("/accounts/%s/storage/kv/namespaces/%s", api.AccountID, namespaceID) + res, err := api.makeRequestContext(ctx, http.MethodDelete, uri, nil) + if err != nil { + return Response{}, errors.Wrap(err, errMakeRequestError) + } + + result := Response{} + if err := json.Unmarshal(res, &result); err != nil { + return result, errors.Wrap(err, errUnmarshalError) + } + + return result, err +} + +// UpdateWorkersKVNamespace modifies a namespace's title +// +// API reference: https://api.cloudflare.com/#workers-kv-namespace-rename-a-namespace +func (api *API) UpdateWorkersKVNamespace(ctx context.Context, namespaceID string, req *WorkersKVNamespaceRequest) (Response, error) { + uri := fmt.Sprintf("/accounts/%s/storage/kv/namespaces/%s", api.AccountID, namespaceID) + res, err := api.makeRequestContext(ctx, http.MethodPut, uri, req) + if err != nil { + return Response{}, errors.Wrap(err, errMakeRequestError) + } + + result := Response{} + if err := json.Unmarshal(res, &result); err != nil { + return result, errors.Wrap(err, errUnmarshalError) + } + + return result, err +} + +// WriteWorkersKV writes a value identified by a key. 
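+//
+// The value is raw bytes, so any payload works; for example (namespaceID and
+// the key/value pair are illustrative):
+//
+//	_, err := api.WriteWorkersKV(context.Background(), namespaceID,
+//		"greeting", []byte("hello"))
+//	if err != nil {
+//		log.Fatal(err)
+//	}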
+// +// API reference: https://api.cloudflare.com/#workers-kv-namespace-write-key-value-pair +func (api *API) WriteWorkersKV(ctx context.Context, namespaceID, key string, value []byte) (Response, error) { + key = url.PathEscape(key) + uri := fmt.Sprintf("/accounts/%s/storage/kv/namespaces/%s/values/%s", api.AccountID, namespaceID, key) + res, err := api.makeRequestWithAuthTypeAndHeaders( + ctx, http.MethodPut, uri, value, api.authType, http.Header{"Content-Type": []string{"application/octet-stream"}}, + ) + if err != nil { + return Response{}, errors.Wrap(err, errMakeRequestError) + } + + result := Response{} + if err := json.Unmarshal(res, &result); err != nil { + return result, errors.Wrap(err, errUnmarshalError) + } + + return result, err +} + +// ReadWorkersKV returns the value associated with the given key in the given namespace +// +// API reference: https://api.cloudflare.com/#workers-kv-namespace-read-key-value-pair +func (api API) ReadWorkersKV(ctx context.Context, namespaceID, key string) ([]byte, error) { + key = url.PathEscape(key) + uri := fmt.Sprintf("/accounts/%s/storage/kv/namespaces/%s/values/%s", api.AccountID, namespaceID, key) + res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + return res, nil +} + +// DeleteWorkersKV deletes a key and value for a provided storage namespace +// +// API reference: https://api.cloudflare.com/#workers-kv-namespace-delete-key-value-pair +func (api API) DeleteWorkersKV(ctx context.Context, namespaceID, key string) (Response, error) { + key = url.PathEscape(key) + uri := fmt.Sprintf("/accounts/%s/storage/kv/namespaces/%s/values/%s", api.AccountID, namespaceID, key) + res, err := api.makeRequestContext(ctx, http.MethodDelete, uri, nil) + if err != nil { + return Response{}, errors.Wrap(err, errMakeRequestError) + } + + result := Response{} + if err := json.Unmarshal(res, &result); err != nil { + return result, errors.Wrap(err, errUnmarshalError) + } + return result, err +} + +// ListWorkersKVs lists a namespace's keys +// +// API Reference: https://api.cloudflare.com/#workers-kv-namespace-list-a-namespace-s-keys +func (api API) ListWorkersKVs(ctx context.Context, namespaceID string) (ListStorageKeysResponse, error) { + uri := fmt.Sprintf("/accounts/%s/storage/kv/namespaces/%s/keys", api.AccountID, namespaceID) + res, err := api.makeRequestContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return ListStorageKeysResponse{}, errors.Wrap(err, errMakeRequestError) + } + + result := ListStorageKeysResponse{} + if err := json.Unmarshal(res, &result); err != nil { + return result, errors.Wrap(err, errUnmarshalError) + } + return result, err +} diff --git a/vendor/github.com/cloudflare/cloudflare-go/zone.go b/vendor/github.com/cloudflare/cloudflare-go/zone.go new file mode 100644 index 000000000..28f54a5da --- /dev/null +++ b/vendor/github.com/cloudflare/cloudflare-go/zone.go @@ -0,0 +1,740 @@ +package cloudflare + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "sync" + "time" + + "github.com/pkg/errors" +) + +// Owner describes the resource owner. +type Owner struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + OwnerType string `json:"type"` +} + +// Zone describes a Cloudflare zone. +type Zone struct { + ID string `json:"id"` + Name string `json:"name"` + // DevMode contains the time in seconds until development expires (if + // positive) or since it expired (if negative). 
It will be 0 if never used. + DevMode int `json:"development_mode"` + OriginalNS []string `json:"original_name_servers"` + OriginalRegistrar string `json:"original_registrar"` + OriginalDNSHost string `json:"original_dnshost"` + CreatedOn time.Time `json:"created_on"` + ModifiedOn time.Time `json:"modified_on"` + NameServers []string `json:"name_servers"` + Owner Owner `json:"owner"` + Permissions []string `json:"permissions"` + Plan ZonePlan `json:"plan"` + PlanPending ZonePlan `json:"plan_pending,omitempty"` + Status string `json:"status"` + Paused bool `json:"paused"` + Type string `json:"type"` + Host struct { + Name string + Website string + } `json:"host"` + VanityNS []string `json:"vanity_name_servers"` + Betas []string `json:"betas"` + DeactReason string `json:"deactivation_reason"` + Meta ZoneMeta `json:"meta"` + Account Account `json:"account"` +} + +// ZoneMeta describes metadata about a zone. +type ZoneMeta struct { + // custom_certificate_quota is broken - sometimes it's a string, sometimes a number! + // CustCertQuota int `json:"custom_certificate_quota"` + PageRuleQuota int `json:"page_rule_quota"` + WildcardProxiable bool `json:"wildcard_proxiable"` + PhishingDetected bool `json:"phishing_detected"` +} + +// ZonePlan contains the plan information for a zone. +type ZonePlan struct { + ZonePlanCommon + IsSubscribed bool `json:"is_subscribed"` + CanSubscribe bool `json:"can_subscribe"` + LegacyID string `json:"legacy_id"` + LegacyDiscount bool `json:"legacy_discount"` + ExternallyManaged bool `json:"externally_managed"` +} + +// ZoneRatePlan contains the plan information for a zone. +type ZoneRatePlan struct { + ZonePlanCommon + Components []zoneRatePlanComponents `json:"components,omitempty"` +} + +// ZonePlanCommon contains fields used by various Plan endpoints +type ZonePlanCommon struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Price int `json:"price,omitempty"` + Currency string `json:"currency,omitempty"` + Frequency string `json:"frequency,omitempty"` +} + +type zoneRatePlanComponents struct { + Name string `json:"name"` + Default int `json:"Default"` + UnitPrice int `json:"unit_price"` +} + +// ZoneID contains only the zone ID. +type ZoneID struct { + ID string `json:"id"` +} + +// ZoneResponse represents the response from the Zone endpoint containing a single zone. +type ZoneResponse struct { + Response + Result Zone `json:"result"` +} + +// ZonesResponse represents the response from the Zone endpoint containing an array of zones. +type ZonesResponse struct { + Response + Result []Zone `json:"result"` + ResultInfo `json:"result_info"` +} + +// ZoneIDResponse represents the response from the Zone endpoint, containing only a zone ID. +type ZoneIDResponse struct { + Response + Result ZoneID `json:"result"` +} + +// AvailableZoneRatePlansResponse represents the response from the Available Rate Plans endpoint. +type AvailableZoneRatePlansResponse struct { + Response + Result []ZoneRatePlan `json:"result"` + ResultInfo `json:"result_info"` +} + +// AvailableZonePlansResponse represents the response from the Available Plans endpoint. +type AvailableZonePlansResponse struct { + Response + Result []ZonePlan `json:"result"` + ResultInfo +} + +// ZoneRatePlanResponse represents the response from the Plan Details endpoint. +type ZoneRatePlanResponse struct { + Response + Result ZoneRatePlan `json:"result"` +} + +// ZoneSetting contains settings for a zone. 
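+//
+// Value is deliberately untyped: depending on the setting it may be a string,
+// a number or a nested object, so callers typically type-assert, e.g.:
+//
+//	if s, ok := setting.Value.(string); ok {
+//		fmt.Println(setting.ID, s)
+//	}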
+type ZoneSetting struct { + ID string `json:"id"` + Editable bool `json:"editable"` + ModifiedOn string `json:"modified_on"` + Value interface{} `json:"value"` + TimeRemaining int `json:"time_remaining"` +} + +// ZoneSettingResponse represents the response from the Zone Setting endpoint. +type ZoneSettingResponse struct { + Response + Result []ZoneSetting `json:"result"` +} + +// ZoneSSLSetting contains ssl setting for a zone. +type ZoneSSLSetting struct { + ID string `json:"id"` + Editable bool `json:"editable"` + ModifiedOn string `json:"modified_on"` + Value string `json:"value"` + CertificateStatus string `json:"certificate_status"` +} + +// ZoneSSLSettingResponse represents the response from the Zone SSL Setting +// endpoint. +type ZoneSSLSettingResponse struct { + Response + Result ZoneSSLSetting `json:"result"` +} + +// ZoneAnalyticsData contains totals and timeseries analytics data for a zone. +type ZoneAnalyticsData struct { + Totals ZoneAnalytics `json:"totals"` + Timeseries []ZoneAnalytics `json:"timeseries"` +} + +// zoneAnalyticsDataResponse represents the response from the Zone Analytics Dashboard endpoint. +type zoneAnalyticsDataResponse struct { + Response + Result ZoneAnalyticsData `json:"result"` +} + +// ZoneAnalyticsColocation contains analytics data by datacenter. +type ZoneAnalyticsColocation struct { + ColocationID string `json:"colo_id"` + Timeseries []ZoneAnalytics `json:"timeseries"` +} + +// zoneAnalyticsColocationResponse represents the response from the Zone Analytics By Co-location endpoint. +type zoneAnalyticsColocationResponse struct { + Response + Result []ZoneAnalyticsColocation `json:"result"` +} + +// ZoneAnalytics contains analytics data for a zone. +type ZoneAnalytics struct { + Since time.Time `json:"since"` + Until time.Time `json:"until"` + Requests struct { + All int `json:"all"` + Cached int `json:"cached"` + Uncached int `json:"uncached"` + ContentType map[string]int `json:"content_type"` + Country map[string]int `json:"country"` + SSL struct { + Encrypted int `json:"encrypted"` + Unencrypted int `json:"unencrypted"` + } `json:"ssl"` + HTTPStatus map[string]int `json:"http_status"` + } `json:"requests"` + Bandwidth struct { + All int `json:"all"` + Cached int `json:"cached"` + Uncached int `json:"uncached"` + ContentType map[string]int `json:"content_type"` + Country map[string]int `json:"country"` + SSL struct { + Encrypted int `json:"encrypted"` + Unencrypted int `json:"unencrypted"` + } `json:"ssl"` + } `json:"bandwidth"` + Threats struct { + All int `json:"all"` + Country map[string]int `json:"country"` + Type map[string]int `json:"type"` + } `json:"threats"` + Pageviews struct { + All int `json:"all"` + SearchEngines map[string]int `json:"search_engines"` + } `json:"pageviews"` + Uniques struct { + All int `json:"all"` + } +} + +// ZoneAnalyticsOptions represents the optional parameters in Zone Analytics +// endpoint requests. +type ZoneAnalyticsOptions struct { + Since *time.Time + Until *time.Time + Continuous *bool +} + +// PurgeCacheRequest represents the request format made to the purge endpoint. +type PurgeCacheRequest struct { + Everything bool `json:"purge_everything,omitempty"` + // Purge by filepath (exact match). Limit of 30 + Files []string `json:"files,omitempty"` + // Purge by Tag (Enterprise only): + // https://support.cloudflare.com/hc/en-us/articles/206596608-How-to-Purge-Cache-Using-Cache-Tags-Enterprise-only- + Tags []string `json:"tags,omitempty"` + // Purge by hostname - e.g. 
"assets.example.com" + Hosts []string `json:"hosts,omitempty"` +} + +// PurgeCacheResponse represents the response from the purge endpoint. +type PurgeCacheResponse struct { + Response + Result struct { + ID string `json:"id"` + } `json:"result"` +} + +// newZone describes a new zone. +type newZone struct { + Name string `json:"name"` + JumpStart bool `json:"jump_start"` + Type string `json:"type"` + // We use a pointer to get a nil type when the field is empty. + // This allows us to completely omit this with json.Marshal(). + Account *Account `json:"organization,omitempty"` +} + +// FallbackOrigin describes a fallback origin +type FallbackOrigin struct { + Value string `json:"value"` + ID string `json:"id,omitempty"` +} + +// FallbackOriginResponse represents the response from the fallback_origin endpoint +type FallbackOriginResponse struct { + Response + Result FallbackOrigin `json:"result"` +} + +// CreateZone creates a zone on an account. +// +// Setting jumpstart to true will attempt to automatically scan for existing +// DNS records. Setting this to false will create the zone with no DNS records. +// +// If account is non-empty, it must have at least the ID field populated. +// This will add the new zone to the specified multi-user account. +// +// API reference: https://api.cloudflare.com/#zone-create-a-zone +func (api *API) CreateZone(name string, jumpstart bool, account Account, zoneType string) (Zone, error) { + var newzone newZone + newzone.Name = name + newzone.JumpStart = jumpstart + if account.ID != "" { + newzone.Account = &account + } + + if zoneType == "partial" { + newzone.Type = "partial" + } else { + newzone.Type = "full" + } + + res, err := api.makeRequest("POST", "/zones", newzone) + if err != nil { + return Zone{}, errors.Wrap(err, errMakeRequestError) + } + + var r ZoneResponse + err = json.Unmarshal(res, &r) + if err != nil { + return Zone{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// ZoneActivationCheck initiates another zone activation check for newly-created zones. +// +// API reference: https://api.cloudflare.com/#zone-initiate-another-zone-activation-check +func (api *API) ZoneActivationCheck(zoneID string) (Response, error) { + res, err := api.makeRequest("PUT", "/zones/"+zoneID+"/activation_check", nil) + if err != nil { + return Response{}, errors.Wrap(err, errMakeRequestError) + } + var r Response + err = json.Unmarshal(res, &r) + if err != nil { + return Response{}, errors.Wrap(err, errUnmarshalError) + } + return r, nil +} + +// ListZones lists zones on an account. Optionally takes a list of zone names +// to filter against. 
+// +// API reference: https://api.cloudflare.com/#zone-list-zones +func (api *API) ListZones(z ...string) ([]Zone, error) { + v := url.Values{} + var res []byte + var r ZonesResponse + var zones []Zone + var err error + if len(z) > 0 { + for _, zone := range z { + v.Set("name", zone) + res, err = api.makeRequest("GET", "/zones?"+v.Encode(), nil) + if err != nil { + return []Zone{}, errors.Wrap(err, errMakeRequestError) + } + err = json.Unmarshal(res, &r) + if err != nil { + return []Zone{}, errors.Wrap(err, errUnmarshalError) + } + if !r.Success { + // TODO: Provide an actual error message instead of always returning nil + return []Zone{}, err + } + for zi := range r.Result { + zones = append(zones, r.Result[zi]) + } + } + } else { + res, err = api.makeRequest("GET", "/zones?per_page=50", nil) + if err != nil { + return []Zone{}, errors.Wrap(err, errMakeRequestError) + } + err = json.Unmarshal(res, &r) + if err != nil { + return []Zone{}, errors.Wrap(err, errUnmarshalError) + } + + totalPageCount := r.TotalPages + var wg sync.WaitGroup + wg.Add(totalPageCount) + errc := make(chan error) + + for i := 1; i <= totalPageCount; i++ { + go func(pageNumber int) error { + res, err = api.makeRequest("GET", fmt.Sprintf("/zones?per_page=50&page=%d", pageNumber), nil) + if err != nil { + errc <- err + } + + err = json.Unmarshal(res, &r) + if err != nil { + errc <- err + } + + for _, zone := range r.Result { + zones = append(zones, zone) + } + + select { + case err := <-errc: + return err + default: + wg.Done() + } + + return nil + }(i) + } + + wg.Wait() + } + + return zones, nil +} + +// ListZonesContext lists zones on an account. Optionally takes a list of ReqOptions. +func (api *API) ListZonesContext(ctx context.Context, opts ...ReqOption) (r ZonesResponse, err error) { + var res []byte + opt := reqOption{ + params: url.Values{}, + } + for _, of := range opts { + of(&opt) + } + + res, err = api.makeRequestContext(ctx, "GET", "/zones?"+opt.params.Encode(), nil) + if err != nil { + return ZonesResponse{}, errors.Wrap(err, errMakeRequestError) + } + err = json.Unmarshal(res, &r) + if err != nil { + return ZonesResponse{}, errors.Wrap(err, errUnmarshalError) + } + + return r, nil +} + +// ZoneDetails fetches information about a zone. +// +// API reference: https://api.cloudflare.com/#zone-zone-details +func (api *API) ZoneDetails(zoneID string) (Zone, error) { + res, err := api.makeRequest("GET", "/zones/"+zoneID, nil) + if err != nil { + return Zone{}, errors.Wrap(err, errMakeRequestError) + } + var r ZoneResponse + err = json.Unmarshal(res, &r) + if err != nil { + return Zone{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// ZoneOptions is a subset of Zone, for editable options. +type ZoneOptions struct { + Paused *bool `json:"paused,omitempty"` + VanityNS []string `json:"vanity_name_servers,omitempty"` + Plan *ZonePlan `json:"plan,omitempty"` +} + +// ZoneSetPaused pauses Cloudflare service for the entire zone, sending all +// traffic direct to the origin. +func (api *API) ZoneSetPaused(zoneID string, paused bool) (Zone, error) { + zoneopts := ZoneOptions{Paused: &paused} + zone, err := api.EditZone(zoneID, zoneopts) + if err != nil { + return Zone{}, err + } + + return zone, nil +} + +// ZoneSetVanityNS sets custom nameservers for the zone. +// These names must be within the same zone. 
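+//
+// A hypothetical example, assuming zoneID identifies the zone for
+// example.com:
+//
+//    zone, err := api.ZoneSetVanityNS(zoneID, []string{
+//        "ns1.example.com",
+//        "ns2.example.com",
+//    })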
+func (api *API) ZoneSetVanityNS(zoneID string, ns []string) (Zone, error) { + zoneopts := ZoneOptions{VanityNS: ns} + zone, err := api.EditZone(zoneID, zoneopts) + if err != nil { + return Zone{}, err + } + + return zone, nil +} + +// ZoneSetPlan changes the zone plan. +func (api *API) ZoneSetPlan(zoneID string, plan ZonePlan) (Zone, error) { + zoneopts := ZoneOptions{Plan: &plan} + zone, err := api.EditZone(zoneID, zoneopts) + if err != nil { + return Zone{}, err + } + + return zone, nil +} + +// EditZone edits the given zone. +// +// This is usually called by ZoneSetPaused, ZoneSetVanityNS or ZoneSetPlan. +// +// API reference: https://api.cloudflare.com/#zone-edit-zone-properties +func (api *API) EditZone(zoneID string, zoneOpts ZoneOptions) (Zone, error) { + res, err := api.makeRequest("PATCH", "/zones/"+zoneID, zoneOpts) + if err != nil { + return Zone{}, errors.Wrap(err, errMakeRequestError) + } + var r ZoneResponse + err = json.Unmarshal(res, &r) + if err != nil { + return Zone{}, errors.Wrap(err, errUnmarshalError) + } + + return r.Result, nil +} + +// PurgeEverything purges the cache for the given zone. +// +// Note: this will substantially increase load on the origin server for that +// zone if there is a high cached vs. uncached request ratio. +// +// API reference: https://api.cloudflare.com/#zone-purge-all-files +func (api *API) PurgeEverything(zoneID string) (PurgeCacheResponse, error) { + uri := "/zones/" + zoneID + "/purge_cache" + res, err := api.makeRequest("POST", uri, PurgeCacheRequest{true, nil, nil, nil}) + if err != nil { + return PurgeCacheResponse{}, errors.Wrap(err, errMakeRequestError) + } + var r PurgeCacheResponse + err = json.Unmarshal(res, &r) + if err != nil { + return PurgeCacheResponse{}, errors.Wrap(err, errUnmarshalError) + } + return r, nil +} + +// PurgeCache purges the cache using the given PurgeCacheRequest (zone/url/tag). +// +// API reference: https://api.cloudflare.com/#zone-purge-individual-files-by-url-and-cache-tags +func (api *API) PurgeCache(zoneID string, pcr PurgeCacheRequest) (PurgeCacheResponse, error) { + uri := "/zones/" + zoneID + "/purge_cache" + res, err := api.makeRequest("POST", uri, pcr) + if err != nil { + return PurgeCacheResponse{}, errors.Wrap(err, errMakeRequestError) + } + var r PurgeCacheResponse + err = json.Unmarshal(res, &r) + if err != nil { + return PurgeCacheResponse{}, errors.Wrap(err, errUnmarshalError) + } + return r, nil +} + +// DeleteZone deletes the given zone. +// +// API reference: https://api.cloudflare.com/#zone-delete-a-zone +func (api *API) DeleteZone(zoneID string) (ZoneID, error) { + res, err := api.makeRequest("DELETE", "/zones/"+zoneID, nil) + if err != nil { + return ZoneID{}, errors.Wrap(err, errMakeRequestError) + } + var r ZoneIDResponse + err = json.Unmarshal(res, &r) + if err != nil { + return ZoneID{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// AvailableZoneRatePlans returns information about all plans available to the specified zone. 
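+//
+// A hypothetical example (zoneID is a placeholder):
+//
+//    plans, err := api.AvailableZoneRatePlans(zoneID)
+//    if err == nil {
+//        for _, plan := range plans {
+//            fmt.Printf("%+v\n", plan)
+//        }
+//    }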
+// +// API reference: https://api.cloudflare.com/#zone-plan-available-plans +func (api *API) AvailableZoneRatePlans(zoneID string) ([]ZoneRatePlan, error) { + uri := "/zones/" + zoneID + "/available_rate_plans" + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []ZoneRatePlan{}, errors.Wrap(err, errMakeRequestError) + } + var r AvailableZoneRatePlansResponse + err = json.Unmarshal(res, &r) + if err != nil { + return []ZoneRatePlan{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// AvailableZonePlans returns information about all plans available to the specified zone. +// +// API reference: https://api.cloudflare.com/#zone-rate-plan-list-available-plans +func (api *API) AvailableZonePlans(zoneID string) ([]ZonePlan, error) { + uri := "/zones/" + zoneID + "/available_plans" + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return []ZonePlan{}, errors.Wrap(err, errMakeRequestError) + } + var r AvailableZonePlansResponse + err = json.Unmarshal(res, &r) + if err != nil { + return []ZonePlan{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// encode encodes non-nil fields into URL encoded form. +func (o ZoneAnalyticsOptions) encode() string { + v := url.Values{} + if o.Since != nil { + v.Set("since", (*o.Since).Format(time.RFC3339)) + } + if o.Until != nil { + v.Set("until", (*o.Until).Format(time.RFC3339)) + } + if o.Continuous != nil { + v.Set("continuous", fmt.Sprintf("%t", *o.Continuous)) + } + return v.Encode() +} + +// ZoneAnalyticsDashboard returns zone analytics information. +// +// API reference: https://api.cloudflare.com/#zone-analytics-dashboard +func (api *API) ZoneAnalyticsDashboard(zoneID string, options ZoneAnalyticsOptions) (ZoneAnalyticsData, error) { + uri := "/zones/" + zoneID + "/analytics/dashboard" + "?" + options.encode() + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return ZoneAnalyticsData{}, errors.Wrap(err, errMakeRequestError) + } + var r zoneAnalyticsDataResponse + err = json.Unmarshal(res, &r) + if err != nil { + return ZoneAnalyticsData{}, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// ZoneAnalyticsByColocation returns zone analytics information by datacenter. +// +// API reference: https://api.cloudflare.com/#zone-analytics-analytics-by-co-locations +func (api *API) ZoneAnalyticsByColocation(zoneID string, options ZoneAnalyticsOptions) ([]ZoneAnalyticsColocation, error) { + uri := "/zones/" + zoneID + "/analytics/colos" + "?" + options.encode() + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + var r zoneAnalyticsColocationResponse + err = json.Unmarshal(res, &r) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + return r.Result, nil +} + +// ZoneSettings returns all of the settings for a given zone. +// +// API reference: https://api.cloudflare.com/#zone-settings-get-all-zone-settings +func (api *API) ZoneSettings(zoneID string) (*ZoneSettingResponse, error) { + uri := "/zones/" + zoneID + "/settings" + res, err := api.makeRequest("GET", uri, nil) + if err != nil { + return nil, errors.Wrap(err, errMakeRequestError) + } + + response := &ZoneSettingResponse{} + err = json.Unmarshal(res, &response) + if err != nil { + return nil, errors.Wrap(err, errUnmarshalError) + } + + return response, nil +} + +// UpdateZoneSettings updates the settings for a given zone. 
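+//
+// A hypothetical example; "always_online" stands in for any editable
+// setting ID, and zoneID is a placeholder:
+//
+//    resp, err := api.UpdateZoneSettings(zoneID, []ZoneSetting{
+//        {ID: "always_online", Value: "on"},
+//    })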
+//
+// API reference: https://api.cloudflare.com/#zone-settings-edit-zone-settings-info
+func (api *API) UpdateZoneSettings(zoneID string, settings []ZoneSetting) (*ZoneSettingResponse, error) {
+ uri := "/zones/" + zoneID + "/settings"
+ res, err := api.makeRequest("PATCH", uri, struct {
+ Items []ZoneSetting `json:"items"`
+ }{settings})
+ if err != nil {
+ return nil, errors.Wrap(err, errMakeRequestError)
+ }
+
+ response := &ZoneSettingResponse{}
+ err = json.Unmarshal(res, &response)
+ if err != nil {
+ return nil, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return response, nil
+}
+
+// ZoneSSLSettings returns the SSL setting for the specified zone.
+//
+// API reference: https://api.cloudflare.com/#zone-settings-get-ssl-setting
+func (api *API) ZoneSSLSettings(zoneID string) (ZoneSSLSetting, error) {
+ uri := "/zones/" + zoneID + "/settings/ssl"
+ res, err := api.makeRequest("GET", uri, nil)
+ if err != nil {
+ return ZoneSSLSetting{}, errors.Wrap(err, errMakeRequestError)
+ }
+ var r ZoneSSLSettingResponse
+ err = json.Unmarshal(res, &r)
+ if err != nil {
+ return ZoneSSLSetting{}, errors.Wrap(err, errUnmarshalError)
+ }
+ return r.Result, nil
+}
+
+// FallbackOrigin returns information about the fallback origin for the specified zone.
+//
+// API reference: https://developers.cloudflare.com/ssl/ssl-for-saas/api-calls/#fallback-origin-configuration
+func (api *API) FallbackOrigin(zoneID string) (FallbackOrigin, error) {
+ uri := "/zones/" + zoneID + "/fallback_origin"
+ res, err := api.makeRequest("GET", uri, nil)
+ if err != nil {
+ return FallbackOrigin{}, errors.Wrap(err, errMakeRequestError)
+ }
+
+ var r FallbackOriginResponse
+ err = json.Unmarshal(res, &r)
+ if err != nil {
+ return FallbackOrigin{}, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return r.Result, nil
+}
+
+// UpdateFallbackOrigin updates the fallback origin for a given zone. 
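+//
+// A hypothetical example (zoneID and the origin hostname are placeholders):
+//
+//    resp, err := api.UpdateFallbackOrigin(zoneID, FallbackOrigin{
+//        Value: "fallback.example.com",
+//    })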
+//
+// API reference: https://developers.cloudflare.com/ssl/ssl-for-saas/api-calls/#4-example-patch-to-change-fallback-origin
+func (api *API) UpdateFallbackOrigin(zoneID string, fbo FallbackOrigin) (*FallbackOriginResponse, error) {
+ uri := "/zones/" + zoneID + "/fallback_origin"
+ res, err := api.makeRequest("PATCH", uri, fbo)
+ if err != nil {
+ return nil, errors.Wrap(err, errMakeRequestError)
+ }
+
+ response := &FallbackOriginResponse{}
+ err = json.Unmarshal(res, &response)
+ if err != nil {
+ return nil, errors.Wrap(err, errUnmarshalError)
+ }
+
+ return response, nil
+}
diff --git a/vendor/github.com/karalabe/usb/appveyor.yml b/vendor/github.com/karalabe/usb/appveyor.yml
index 1d921ae51..73a9664ae 100644
--- a/vendor/github.com/karalabe/usb/appveyor.yml
+++ b/vendor/github.com/karalabe/usb/appveyor.yml
@@ -22,8 +22,8 @@ environment:

install:
- rmdir C:\go /s /q
- - appveyor DownloadFile https://storage.googleapis.com/golang/go1.12.6.windows-%GOARCH%.zip
- - 7z x go1.12.6.windows-%GOARCH%.zip -y -oC:\ > NUL
+ - appveyor DownloadFile https://storage.googleapis.com/golang/go1.12.9.windows-%GOARCH%.zip
+ - 7z x go1.12.9.windows-%GOARCH%.zip -y -oC:\ > NUL
- go version
- gcc --version
diff --git a/vendor/github.com/karalabe/usb/hidapi/windows/hid.c b/vendor/github.com/karalabe/usb/hidapi/windows/hid.c
index 4e92cc8bc..60da64608 100644
--- a/vendor/github.com/karalabe/usb/hidapi/windows/hid.c
+++ b/vendor/github.com/karalabe/usb/hidapi/windows/hid.c
@@ -74,6 +74,8 @@ extern "C" {
#pragma warning(disable:4996)
#endif

+#pragma GCC diagnostic ignored "-Wstringop-overflow"
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -428,7 +430,7 @@ struct hid_device_info HID_API_EXPORT * HID_API_CALL hid_enumerate(unsigned shor
if (str) {
len = strlen(str);
cur_dev->path = (char*) calloc(len+1, sizeof(char));
- strncpy(cur_dev->path, str, sizeof(cur_dev->path));
+ strncpy(cur_dev->path, str, len+1);
cur_dev->path[len] = '\0';
}
else
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/batch.go b/vendor/github.com/syndtr/goleveldb/leveldb/batch.go
index 225920002..823be93f9 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/batch.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/batch.go
@@ -238,6 +238,11 @@ func newBatch() interface{} {
return &Batch{}
}

+// MakeBatch returns an empty batch with a preallocated buffer.
+func MakeBatch(n int) *Batch {
+ return &Batch{data: make([]byte, 0, n)}
+}
+
func decodeBatch(data []byte, fn func(i int, index batchIndex) error) error {
var index batchIndex
for i, o := 0, 0; o < len(data); i++ {
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db.go b/vendor/github.com/syndtr/goleveldb/leveldb/db.go
index 0de5ffe8d..74e982695 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/db.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/db.go
@@ -38,6 +38,12 @@ type DB struct {
inWritePaused int32 // The indicator whether write operation is paused by compaction
aliveSnaps, aliveIters int32

+ // Compaction statistics
+ memComp uint32 // The cumulative number of memory compactions
+ level0Comp uint32 // The cumulative number of level0 compactions
+ nonLevel0Comp uint32 // The cumulative number of non-level0 compactions
+ seekComp uint32 // The cumulative number of seek compactions
+
// Session. 
s *session @@ -978,6 +984,8 @@ func (db *DB) GetProperty(name string) (value string, err error) { value += fmt.Sprintf(" Total | %10d | %13.5f | %13.5f | %13.5f | %13.5f\n", totalTables, float64(totalSize)/1048576.0, totalDuration.Seconds(), float64(totalRead)/1048576.0, float64(totalWrite)/1048576.0) + case p == "compcount": + value = fmt.Sprintf("MemComp:%d Level0Comp:%d NonLevel0Comp:%d SeekComp:%d", atomic.LoadUint32(&db.memComp), atomic.LoadUint32(&db.level0Comp), atomic.LoadUint32(&db.nonLevel0Comp), atomic.LoadUint32(&db.seekComp)) case p == "iostats": value = fmt.Sprintf("Read(MB):%.5f Write(MB):%.5f", float64(db.s.stor.reads())/1048576.0, @@ -1034,6 +1042,11 @@ type DBStats struct { LevelRead Sizes LevelWrite Sizes LevelDurations []time.Duration + + MemComp uint32 + Level0Comp uint32 + NonLevel0Comp uint32 + SeekComp uint32 } // Stats populates s with database statistics. @@ -1070,16 +1083,17 @@ func (db *DB) Stats(s *DBStats) error { for level, tables := range v.levels { duration, read, write := db.compStats.getStat(level) - if len(tables) == 0 && duration == 0 { - continue - } + s.LevelDurations = append(s.LevelDurations, duration) s.LevelRead = append(s.LevelRead, read) s.LevelWrite = append(s.LevelWrite, write) s.LevelSizes = append(s.LevelSizes, tables.size()) s.LevelTablesCounts = append(s.LevelTablesCounts, len(tables)) } - + s.MemComp = atomic.LoadUint32(&db.memComp) + s.Level0Comp = atomic.LoadUint32(&db.level0Comp) + s.NonLevel0Comp = atomic.LoadUint32(&db.nonLevel0Comp) + s.SeekComp = atomic.LoadUint32(&db.seekComp) return nil } diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_compaction.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_compaction.go index 56f3632a7..6b70eb2c9 100644 --- a/vendor/github.com/syndtr/goleveldb/leveldb/db_compaction.go +++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_compaction.go @@ -8,6 +8,7 @@ package leveldb import ( "sync" + "sync/atomic" "time" "github.com/syndtr/goleveldb/leveldb/errors" @@ -324,10 +325,12 @@ func (db *DB) memCompaction() { db.logf("memdb@flush committed F·%d T·%v", len(rec.addedTables), stats.duration) + // Save compaction stats for _, r := range rec.addedTables { stats.write += r.size } db.compStats.addStat(flushLevel, stats) + atomic.AddUint32(&db.memComp, 1) // Drop frozen memdb. 
db.dropFrozenMem()
@@ -588,6 +591,14 @@ func (db *DB) tableCompaction(c *compaction, noTrivial bool) {
for i := range stats {
db.compStats.addStat(c.sourceLevel+1, &stats[i])
}
+ switch c.typ {
+ case level0Compaction:
+ atomic.AddUint32(&db.level0Comp, 1)
+ case nonLevel0Compaction:
+ atomic.AddUint32(&db.nonLevel0Comp, 1)
+ case seekCompaction:
+ atomic.AddUint32(&db.seekComp, 1)
+ }
}

func (db *DB) tableRangeCompaction(level int, umin, umax []byte) error {
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_iter.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_iter.go
index 03c24cdab..e6e8ca59d 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/db_iter.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_iter.go
@@ -78,13 +78,17 @@ func (db *DB) newIterator(auxm *memDB, auxt tFiles, seq uint64, slice *util.Rang
}
rawIter := db.newRawIterator(auxm, auxt, islice, ro)
iter := &dbIter{
- db: db,
- icmp: db.s.icmp,
- iter: rawIter,
- seq: seq,
- strict: opt.GetStrict(db.s.o.Options, ro, opt.StrictReader),
- key: make([]byte, 0),
- value: make([]byte, 0),
+ db: db,
+ icmp: db.s.icmp,
+ iter: rawIter,
+ seq: seq,
+ strict: opt.GetStrict(db.s.o.Options, ro, opt.StrictReader),
+ disableSampling: db.s.o.GetDisableSeeksCompaction() || db.s.o.GetIteratorSamplingRate() <= 0,
+ key: make([]byte, 0),
+ value: make([]byte, 0),
+ }
+ if !iter.disableSampling {
+ iter.samplingGap = db.iterSamplingRate()
}
atomic.AddInt32(&db.aliveIters, 1)
runtime.SetFinalizer(iter, (*dbIter).Release)
@@ -107,13 +111,14 @@ const (

// dbIter represents iterator states over a database session.
type dbIter struct {
- db *DB
- icmp *iComparer
- iter iterator.Iterator
- seq uint64
- strict bool
+ db *DB
+ icmp *iComparer
+ iter iterator.Iterator
+ seq uint64
+ strict bool
+ disableSampling bool

- smaplingGap int
+ samplingGap int
dir dir
key []byte
value []byte
@@ -122,10 +127,14 @@ type dbIter struct {
}

func (i *dbIter) sampleSeek() {
+ if i.disableSampling {
+ return
+ }
+
ikey := i.iter.Key()
- i.smaplingGap -= len(ikey) + len(i.iter.Value())
- for i.smaplingGap < 0 {
- i.smaplingGap += i.db.iterSamplingRate()
+ i.samplingGap -= len(ikey) + len(i.iter.Value())
+ for i.samplingGap < 0 {
+ i.samplingGap += i.db.iterSamplingRate()
i.db.sampleSeek(ikey)
}
}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/db_transaction.go b/vendor/github.com/syndtr/goleveldb/leveldb/db_transaction.go
index f145b64fb..21d1e512f 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/db_transaction.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/db_transaction.go
@@ -69,6 +69,9 @@ func (tr *Transaction) Has(key []byte, ro *opt.ReadOptions) (bool, error) {
// DB. And a nil Range.Limit is treated as a key after all keys in
// the DB.
//
+// The returned iterator has locks on its own resources, so it can live beyond
+// the lifetime of the transaction that created it.
+//
// WARNING: The content of any slice returned by the iterator (e.g. a slice
// returned by calling the Iterator.Key() or Iterator.Value() methods) should
// not be modified unless noted otherwise.
@@ -252,13 +255,14 @@ func (tr *Transaction) discard() {
// Discard transaction.
for _, t := range tr.tables {
tr.db.logf("transaction@discard @%d", t.fd.Num)
- if err1 := tr.db.s.stor.Remove(t.fd); err1 == nil {
- tr.db.s.reuseFileNum(t.fd.Num)
- }
+ // Iterator may still use the table, so we use tOps.remove here.
+ tr.db.s.tops.remove(t.fd)
}
}

// Discard discards the transaction. 
+// This method is a no-op if the transaction has already been closed (either
+// committed or discarded).
//
// Other methods should not be called after transaction has been discarded.
func (tr *Transaction) Discard() {
@@ -282,8 +286,10 @@ func (db *DB) waitCompaction() error {
// until in-flight transaction is committed or discarded.
// The returned transaction handle is safe for concurrent use.
//
-// Transaction is expensive and can overwhelm compaction, especially if
+// Transaction is very expensive and can overwhelm compaction, especially if
// transaction size is small. Use with caution.
+// As a rule of thumb, use a transaction only if you need to merge at least
+// `Options.WriteBuffer` worth of data; otherwise, don't.
//
// The transaction must be closed once done, either by committing or discarding
// the transaction.
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom.go b/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom.go
index bab0e9970..56ccbfbec 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/filter/bloom.go
@@ -16,7 +16,7 @@ func bloomHash(key []byte) uint32 {

type bloomFilter int

-// The bloom filter serializes its parameters and is backward compatible
+// Name: The bloom filter serializes its parameters and is backward compatible
// with respect to them. Therefore, its parameters are not added to its
// name.
func (bloomFilter) Name() string {
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go
index b661c08a9..824e47f5f 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/memdb/memdb.go
@@ -397,6 +397,10 @@ func (p *DB) Find(key []byte) (rkey, value []byte, err error) {
// DB. And a nil Range.Limit is treated as a key after all keys in
// the DB.
//
+// WARNING: The content of any slice returned by the iterator (e.g. a slice
+// returned by calling the Iterator.Key() or Iterator.Value() methods) should
+// not be modified unless noted otherwise.
+//
// The iterator must be released after use, by calling Release method.
//
// Also read Iterator documentation of the leveldb/iterator package.
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/opt/options.go b/vendor/github.com/syndtr/goleveldb/leveldb/opt/options.go
index 528b16423..c02c1e978 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/opt/options.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/opt/options.go
@@ -278,6 +278,14 @@ type Options struct {
// The default is false.
DisableLargeBatchTransaction bool

+ // DisableSeeksCompaction allows disabling 'seeks triggered compaction'.
+ // The purpose of 'seeks triggered compaction' is to optimize the database
+ // so that 'level seeks' are minimized; however, this may generate many
+ // small compactions, which may not be preferable.
+ //
+ // The default is false.
+ DisableSeeksCompaction bool
+
// ErrorIfExist defines whether an error should be returned if the DB
// already exists.
//
// The default is false.
@@ -309,6 +317,8 @@ type Options struct {
// IteratorSamplingRate defines approximate gap (in bytes) between read
// sampling of an iterator. The samples will be used to determine when
// compaction should be triggered.
+ // Use a negative value to disable iterator sampling.
+ // Iterator sampling is also disabled if DisableSeeksCompaction is true.
//
// The default is 1MiB. 
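+ //
+ // For example (illustrative, from a caller's perspective; leveldb.OpenFile
+ // is the usual entry point), both seek-triggered compaction and iterator
+ // sampling can be switched off together:
+ //
+ //    db, err := leveldb.OpenFile("path/to/db", &opt.Options{
+ //        DisableSeeksCompaction: true,
+ //        IteratorSamplingRate:   -1,
+ //    })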
IteratorSamplingRate int @@ -526,6 +536,13 @@ func (o *Options) GetDisableLargeBatchTransaction() bool { return o.DisableLargeBatchTransaction } +func (o *Options) GetDisableSeeksCompaction() bool { + if o == nil { + return false + } + return o.DisableSeeksCompaction +} + func (o *Options) GetErrorIfExist() bool { if o == nil { return false @@ -548,8 +565,10 @@ func (o *Options) GetFilter() filter.Filter { } func (o *Options) GetIteratorSamplingRate() int { - if o == nil || o.IteratorSamplingRate <= 0 { + if o == nil || o.IteratorSamplingRate == 0 { return DefaultIteratorSamplingRate + } else if o.IteratorSamplingRate < 0 { + return 0 } return o.IteratorSamplingRate } diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/session_compaction.go b/vendor/github.com/syndtr/goleveldb/leveldb/session_compaction.go index f6030022d..4c1d336be 100644 --- a/vendor/github.com/syndtr/goleveldb/leveldb/session_compaction.go +++ b/vendor/github.com/syndtr/goleveldb/leveldb/session_compaction.go @@ -14,6 +14,13 @@ import ( "github.com/syndtr/goleveldb/leveldb/opt" ) +const ( + undefinedCompaction = iota + level0Compaction + nonLevel0Compaction + seekCompaction +) + func (s *session) pickMemdbLevel(umin, umax []byte, maxLevel int) int { v := s.version() defer v.release() @@ -50,6 +57,7 @@ func (s *session) pickCompaction() *compaction { var sourceLevel int var t0 tFiles + var typ int if v.cScore >= 1 { sourceLevel = v.cLevel cptr := s.getCompPtr(sourceLevel) @@ -63,18 +71,24 @@ func (s *session) pickCompaction() *compaction { if len(t0) == 0 { t0 = append(t0, tables[0]) } + if sourceLevel == 0 { + typ = level0Compaction + } else { + typ = nonLevel0Compaction + } } else { if p := atomic.LoadPointer(&v.cSeek); p != nil { ts := (*tSet)(p) sourceLevel = ts.level t0 = append(t0, ts.table) + typ = seekCompaction } else { v.release() return nil } } - return newCompaction(s, v, sourceLevel, t0) + return newCompaction(s, v, sourceLevel, t0, typ) } // Create compaction from given level and range; need external synchronization. @@ -109,13 +123,18 @@ func (s *session) getCompactionRange(sourceLevel int, umin, umax []byte, noLimit } } - return newCompaction(s, v, sourceLevel, t0) + typ := level0Compaction + if sourceLevel != 0 { + typ = nonLevel0Compaction + } + return newCompaction(s, v, sourceLevel, t0, typ) } -func newCompaction(s *session, v *version, sourceLevel int, t0 tFiles) *compaction { +func newCompaction(s *session, v *version, sourceLevel int, t0 tFiles, typ int) *compaction { c := &compaction{ s: s, v: v, + typ: typ, sourceLevel: sourceLevel, levels: [2]tFiles{t0, nil}, maxGPOverlaps: int64(s.o.GetCompactionGPOverlaps(sourceLevel)), @@ -131,6 +150,7 @@ type compaction struct { s *session v *version + typ int sourceLevel int levels [2]tFiles maxGPOverlaps int64 diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/session_util.go b/vendor/github.com/syndtr/goleveldb/leveldb/session_util.go index 67a6df979..fc56b631f 100644 --- a/vendor/github.com/syndtr/goleveldb/leveldb/session_util.go +++ b/vendor/github.com/syndtr/goleveldb/leveldb/session_util.go @@ -308,7 +308,7 @@ func (s *session) setNextFileNum(num int64) { func (s *session) markFileNum(num int64) { nextFileNum := num + 1 for { - old, x := s.stNextFileNum, nextFileNum + old, x := atomic.LoadInt64(&s.stNextFileNum), nextFileNum if old > x { x = old } @@ -326,7 +326,7 @@ func (s *session) allocFileNum() int64 { // Reuse given file number. 
func (s *session) reuseFileNum(num int64) {
for {
- old, x := s.stNextFileNum, num
+ old, x := atomic.LoadInt64(&s.stNextFileNum), num
if old != x+1 {
x = old
}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table.go b/vendor/github.com/syndtr/goleveldb/leveldb/table.go
index 5ad1f8054..b7759b2f5 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/table.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/table.go
@@ -493,6 +493,8 @@ func (t *tOps) remove(fd storage.FileDesc) {
if t.evictRemoved && t.bcache != nil {
t.bcache.EvictNS(uint64(fd.Num))
}
+ // Try to reuse the file num; useful for discarded transactions.
+ t.s.reuseFileNum(fd.Num)
})
}
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/table/reader.go b/vendor/github.com/syndtr/goleveldb/leveldb/table/reader.go
index 16cfbaa00..496feb6fb 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/table/reader.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/table/reader.go
@@ -787,6 +787,10 @@ func (r *Reader) getDataIterErr(dataBH blockHandle, slice *util.Range, verifyChe
// table. And a nil Range.Limit is treated as a key after all keys in
// the table.
//
+// WARNING: The content of any slice returned by the iterator (e.g. a slice
+// returned by calling the Iterator.Key() or Iterator.Value() methods) should
+// not be modified unless noted otherwise.
+//
// The returned iterator is not safe for concurrent use and should be released
// after use.
//
diff --git a/vendor/github.com/syndtr/goleveldb/leveldb/version.go b/vendor/github.com/syndtr/goleveldb/leveldb/version.go
index 2664560e1..9535e3591 100644
--- a/vendor/github.com/syndtr/goleveldb/leveldb/version.go
+++ b/vendor/github.com/syndtr/goleveldb/leveldb/version.go
@@ -144,6 +144,7 @@ func (v *version) get(aux tFiles, ikey internalKey, ro *opt.ReadOptions, noValue
}

ukey := ikey.ukey()
+ sampleSeeks := !v.s.o.GetDisableSeeksCompaction()

var (
tset *tSet
@@ -161,7 +162,7 @@ func (v *version) get(aux tFiles, ikey internalKey, ro *opt.ReadOptions, noValue
// Since entries never hop across level, finding key/value
// in smaller level make later levels irrelevant.
v.walkOverlapping(aux, ikey, func(level int, t *tFile) bool {
- if level >= 0 && !tseek {
+ if sampleSeeks && level >= 0 && !tseek {
if tset == nil {
tset = &tSet{level, t}
} else {
diff --git a/vendor/golang.org/x/time/LICENSE b/vendor/golang.org/x/time/LICENSE
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/vendor/golang.org/x/time/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/time/PATENTS b/vendor/golang.org/x/time/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/golang.org/x/time/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/time/rate/rate.go b/vendor/golang.org/x/time/rate/rate.go new file mode 100644 index 000000000..ae93e2471 --- /dev/null +++ b/vendor/golang.org/x/time/rate/rate.go @@ -0,0 +1,374 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rate provides a rate limiter. +package rate + +import ( + "context" + "fmt" + "math" + "sync" + "time" +) + +// Limit defines the maximum frequency of some events. +// Limit is represented as number of events per second. +// A zero Limit allows no events. +type Limit float64 + +// Inf is the infinite rate limit; it allows all events (even if burst is zero). +const Inf = Limit(math.MaxFloat64) + +// Every converts a minimum time interval between events to a Limit. +func Every(interval time.Duration) Limit { + if interval <= 0 { + return Inf + } + return 1 / Limit(interval.Seconds()) +} + +// A Limiter controls how frequently events are allowed to happen. +// It implements a "token bucket" of size b, initially full and refilled +// at rate r tokens per second. +// Informally, in any large enough time interval, the Limiter limits the +// rate to r tokens per second, with a maximum burst size of b events. +// As a special case, if r == Inf (the infinite rate), b is ignored. +// See https://en.wikipedia.org/wiki/Token_bucket for more about token buckets. 
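+//
+// For example (illustrative), a limiter that refills one token every 100ms
+// with a burst of 5 allows short spikes while sustaining 10 events/second:
+//
+//    lim := NewLimiter(Every(100*time.Millisecond), 5)
+//    if lim.Allow() {
+//        // event may proceed
+//    }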
+// +// The zero value is a valid Limiter, but it will reject all events. +// Use NewLimiter to create non-zero Limiters. +// +// Limiter has three main methods, Allow, Reserve, and Wait. +// Most callers should use Wait. +// +// Each of the three methods consumes a single token. +// They differ in their behavior when no token is available. +// If no token is available, Allow returns false. +// If no token is available, Reserve returns a reservation for a future token +// and the amount of time the caller must wait before using it. +// If no token is available, Wait blocks until one can be obtained +// or its associated context.Context is canceled. +// +// The methods AllowN, ReserveN, and WaitN consume n tokens. +type Limiter struct { + limit Limit + burst int + + mu sync.Mutex + tokens float64 + // last is the last time the limiter's tokens field was updated + last time.Time + // lastEvent is the latest time of a rate-limited event (past or future) + lastEvent time.Time +} + +// Limit returns the maximum overall event rate. +func (lim *Limiter) Limit() Limit { + lim.mu.Lock() + defer lim.mu.Unlock() + return lim.limit +} + +// Burst returns the maximum burst size. Burst is the maximum number of tokens +// that can be consumed in a single call to Allow, Reserve, or Wait, so higher +// Burst values allow more events to happen at once. +// A zero Burst allows no events, unless limit == Inf. +func (lim *Limiter) Burst() int { + return lim.burst +} + +// NewLimiter returns a new Limiter that allows events up to rate r and permits +// bursts of at most b tokens. +func NewLimiter(r Limit, b int) *Limiter { + return &Limiter{ + limit: r, + burst: b, + } +} + +// Allow is shorthand for AllowN(time.Now(), 1). +func (lim *Limiter) Allow() bool { + return lim.AllowN(time.Now(), 1) +} + +// AllowN reports whether n events may happen at time now. +// Use this method if you intend to drop / skip events that exceed the rate limit. +// Otherwise use Reserve or Wait. +func (lim *Limiter) AllowN(now time.Time, n int) bool { + return lim.reserveN(now, n, 0).ok +} + +// A Reservation holds information about events that are permitted by a Limiter to happen after a delay. +// A Reservation may be canceled, which may enable the Limiter to permit additional events. +type Reservation struct { + ok bool + lim *Limiter + tokens int + timeToAct time.Time + // This is the Limit at reservation time, it can change later. + limit Limit +} + +// OK returns whether the limiter can provide the requested number of tokens +// within the maximum wait time. If OK is false, Delay returns InfDuration, and +// Cancel does nothing. +func (r *Reservation) OK() bool { + return r.ok +} + +// Delay is shorthand for DelayFrom(time.Now()). +func (r *Reservation) Delay() time.Duration { + return r.DelayFrom(time.Now()) +} + +// InfDuration is the duration returned by Delay when a Reservation is not OK. +const InfDuration = time.Duration(1<<63 - 1) + +// DelayFrom returns the duration for which the reservation holder must wait +// before taking the reserved action. Zero duration means act immediately. +// InfDuration means the limiter cannot grant the tokens requested in this +// Reservation within the maximum wait time. +func (r *Reservation) DelayFrom(now time.Time) time.Duration { + if !r.ok { + return InfDuration + } + delay := r.timeToAct.Sub(now) + if delay < 0 { + return 0 + } + return delay +} + +// Cancel is shorthand for CancelAt(time.Now()). 
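+//
+// For example (illustrative), a caller that refuses to wait more than a
+// second can return its token instead of acting:
+//
+//    r := lim.Reserve()
+//    if !r.OK() {
+//        return // reservation impossible (burst exceeded)
+//    }
+//    if r.Delay() > time.Second {
+//        r.Cancel() // give the token back instead of waiting
+//        return
+//    }
+//    time.Sleep(r.Delay())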
+func (r *Reservation) Cancel() { + r.CancelAt(time.Now()) + return +} + +// CancelAt indicates that the reservation holder will not perform the reserved action +// and reverses the effects of this Reservation on the rate limit as much as possible, +// considering that other reservations may have already been made. +func (r *Reservation) CancelAt(now time.Time) { + if !r.ok { + return + } + + r.lim.mu.Lock() + defer r.lim.mu.Unlock() + + if r.lim.limit == Inf || r.tokens == 0 || r.timeToAct.Before(now) { + return + } + + // calculate tokens to restore + // The duration between lim.lastEvent and r.timeToAct tells us how many tokens were reserved + // after r was obtained. These tokens should not be restored. + restoreTokens := float64(r.tokens) - r.limit.tokensFromDuration(r.lim.lastEvent.Sub(r.timeToAct)) + if restoreTokens <= 0 { + return + } + // advance time to now + now, _, tokens := r.lim.advance(now) + // calculate new number of tokens + tokens += restoreTokens + if burst := float64(r.lim.burst); tokens > burst { + tokens = burst + } + // update state + r.lim.last = now + r.lim.tokens = tokens + if r.timeToAct == r.lim.lastEvent { + prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens))) + if !prevEvent.Before(now) { + r.lim.lastEvent = prevEvent + } + } + + return +} + +// Reserve is shorthand for ReserveN(time.Now(), 1). +func (lim *Limiter) Reserve() *Reservation { + return lim.ReserveN(time.Now(), 1) +} + +// ReserveN returns a Reservation that indicates how long the caller must wait before n events happen. +// The Limiter takes this Reservation into account when allowing future events. +// ReserveN returns false if n exceeds the Limiter's burst size. +// Usage example: +// r := lim.ReserveN(time.Now(), 1) +// if !r.OK() { +// // Not allowed to act! Did you remember to set lim.burst to be > 0 ? +// return +// } +// time.Sleep(r.Delay()) +// Act() +// Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events. +// If you need to respect a deadline or cancel the delay, use Wait instead. +// To drop or skip events exceeding rate limit, use Allow instead. +func (lim *Limiter) ReserveN(now time.Time, n int) *Reservation { + r := lim.reserveN(now, n, InfDuration) + return &r +} + +// Wait is shorthand for WaitN(ctx, 1). +func (lim *Limiter) Wait(ctx context.Context) (err error) { + return lim.WaitN(ctx, 1) +} + +// WaitN blocks until lim permits n events to happen. +// It returns an error if n exceeds the Limiter's burst size, the Context is +// canceled, or the expected wait time exceeds the Context's Deadline. +// The burst limit is ignored if the rate limit is Inf. +func (lim *Limiter) WaitN(ctx context.Context, n int) (err error) { + if n > lim.burst && lim.limit != Inf { + return fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, lim.burst) + } + // Check if ctx is already cancelled + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + // Determine wait limit + now := time.Now() + waitLimit := InfDuration + if deadline, ok := ctx.Deadline(); ok { + waitLimit = deadline.Sub(now) + } + // Reserve + r := lim.reserveN(now, n, waitLimit) + if !r.ok { + return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n) + } + // Wait if necessary + delay := r.DelayFrom(now) + if delay == 0 { + return nil + } + t := time.NewTimer(delay) + defer t.Stop() + select { + case <-t.C: + // We can proceed. + return nil + case <-ctx.Done(): + // Context was canceled before we could proceed. 
Cancel the + // reservation, which may permit other events to proceed sooner. + r.Cancel() + return ctx.Err() + } +} + +// SetLimit is shorthand for SetLimitAt(time.Now(), newLimit). +func (lim *Limiter) SetLimit(newLimit Limit) { + lim.SetLimitAt(time.Now(), newLimit) +} + +// SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated +// or underutilized by those which reserved (using Reserve or Wait) but did not yet act +// before SetLimitAt was called. +func (lim *Limiter) SetLimitAt(now time.Time, newLimit Limit) { + lim.mu.Lock() + defer lim.mu.Unlock() + + now, _, tokens := lim.advance(now) + + lim.last = now + lim.tokens = tokens + lim.limit = newLimit +} + +// reserveN is a helper method for AllowN, ReserveN, and WaitN. +// maxFutureReserve specifies the maximum reservation wait duration allowed. +// reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN. +func (lim *Limiter) reserveN(now time.Time, n int, maxFutureReserve time.Duration) Reservation { + lim.mu.Lock() + + if lim.limit == Inf { + lim.mu.Unlock() + return Reservation{ + ok: true, + lim: lim, + tokens: n, + timeToAct: now, + } + } + + now, last, tokens := lim.advance(now) + + // Calculate the remaining number of tokens resulting from the request. + tokens -= float64(n) + + // Calculate the wait duration + var waitDuration time.Duration + if tokens < 0 { + waitDuration = lim.limit.durationFromTokens(-tokens) + } + + // Decide result + ok := n <= lim.burst && waitDuration <= maxFutureReserve + + // Prepare reservation + r := Reservation{ + ok: ok, + lim: lim, + limit: lim.limit, + } + if ok { + r.tokens = n + r.timeToAct = now.Add(waitDuration) + } + + // Update state + if ok { + lim.last = now + lim.tokens = tokens + lim.lastEvent = r.timeToAct + } else { + lim.last = last + } + + lim.mu.Unlock() + return r +} + +// advance calculates and returns an updated state for lim resulting from the passage of time. +// lim is not changed. +func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) { + last := lim.last + if now.Before(last) { + last = now + } + + // Avoid making delta overflow below when last is very old. + maxElapsed := lim.limit.durationFromTokens(float64(lim.burst) - lim.tokens) + elapsed := now.Sub(last) + if elapsed > maxElapsed { + elapsed = maxElapsed + } + + // Calculate the new number of tokens, due to time that passed. + delta := lim.limit.tokensFromDuration(elapsed) + tokens := lim.tokens + delta + if burst := float64(lim.burst); tokens > burst { + tokens = burst + } + + return now, last, tokens +} + +// durationFromTokens is a unit conversion function from the number of tokens to the duration +// of time it takes to accumulate them at a rate of limit tokens per second. +func (limit Limit) durationFromTokens(tokens float64) time.Duration { + seconds := tokens / float64(limit) + return time.Nanosecond * time.Duration(1e9*seconds) +} + +// tokensFromDuration is a unit conversion function from a time duration to the number of tokens +// which could be accumulated during that duration at a rate of limit tokens per second. 
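+// For example, at limit = 10 tokens per second, tokensFromDuration(250ms)
+// yields 2.5 tokens, mirroring durationFromTokens(2.5) = 250ms above.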
+func (limit Limit) tokensFromDuration(d time.Duration) float64 { + return d.Seconds() * float64(limit) +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 33aab0e11..572d37401 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -56,6 +56,12 @@ "revision": "165db2f241fd235aec29ba6d9b1ccd5f1c14637c", "revisionTime": "2015-01-22T07:26:53Z" }, + { + "checksumSHA1": "WILMZlCPSNbyMzYRNo/RkDcUH2M=", + "path": "github.com/cloudflare/cloudflare-go", + "revision": "a80f83b9add9d67ca4098ccbf42cd865ebb36ffb", + "revisionTime": "2019-09-16T15:18:08Z" + }, { "checksumSHA1": "dvabztWVQX8f6oMLRyv4dLH+TGY=", "path": "github.com/davecgh/go-spew/spew", @@ -249,10 +255,10 @@ "revisionTime": "2017-04-30T22:20:11Z" }, { - "checksumSHA1": "X7ZY5gt+qBd/lafKNbPbouL819w=", + "checksumSHA1": "AkW2LisC8HZAFIthaamcxOVl3RU=", "path": "github.com/karalabe/usb", - "revision": "6a7de9d893feb2324aaef49331e923ce279c7973", - "revisionTime": "2019-07-03T09:51:11Z", + "revision": "51dc0efba3568b598359930901dc6647e9b2c6a1", + "revisionTime": "2019-09-19T08:00:40Z", "tree": true }, { @@ -449,76 +455,76 @@ "revisionTime": "2017-07-05T02:17:15Z" }, { - "checksumSHA1": "4NTmfUj7H5J59M2wCnp3/8FWt1I=", + "checksumSHA1": "Bl4KYAyUkgJSjcdEyv3VhHQ8PVs=", "path": "github.com/syndtr/goleveldb/leveldb", - "revision": "c3a204f8e96543bb0cc090385c001078f184fc46", - "revisionTime": "2019-03-18T03:00:20Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { "checksumSHA1": "mPNraL2edpk/2FYq26rSXfMHbJg=", "path": "github.com/syndtr/goleveldb/leveldb/cache", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { "checksumSHA1": "UA+PKDKWlDnE2OZblh23W6wZwbY=", "path": "github.com/syndtr/goleveldb/leveldb/comparer", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { "checksumSHA1": "1DRAxdlWzS4U0xKN/yQ/fdNN7f0=", "path": "github.com/syndtr/goleveldb/leveldb/errors", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { - "checksumSHA1": "eqKeD6DS7eNCtxVYZEHHRKkyZrw=", + "checksumSHA1": "iBorxU3FBbau81WSyVa8KwcutzA=", "path": "github.com/syndtr/goleveldb/leveldb/filter", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { "checksumSHA1": "hPyFsMiqZ1OB7MX+6wIAA6nsdtc=", "path": "github.com/syndtr/goleveldb/leveldb/iterator", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { "checksumSHA1": "gJY7bRpELtO0PJpZXgPQ2BYFJ88=", "path": "github.com/syndtr/goleveldb/leveldb/journal", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { - "checksumSHA1": "MtYY1b2234y/MlS+djL8tXVAcQs=", + "checksumSHA1": "2ncG38FDk2thSlrHd7JFmiuvnxA=", "path": "github.com/syndtr/goleveldb/leveldb/memdb", - "revision": 
"b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { - "checksumSHA1": "o2TorI3z+vc+EBMJ8XeFoUmXBtU=", + "checksumSHA1": "LC+WnyNq4O2J9SHuVfWL19wZH48=", "path": "github.com/syndtr/goleveldb/leveldb/opt", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { "checksumSHA1": "ZnyuciM+R19NG8L5YS3TIJdo1e8=", "path": "github.com/syndtr/goleveldb/leveldb/storage", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { - "checksumSHA1": "gWFPMz8OQeul0t54RM66yMTX49g=", + "checksumSHA1": "DS0i9KReIeZn3T1Bpu31xPMtzio=", "path": "github.com/syndtr/goleveldb/leveldb/table", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { "checksumSHA1": "V/Dh7NV0/fy/5jX1KaAjmGcNbzI=", "path": "github.com/syndtr/goleveldb/leveldb/util", - "revision": "b001fa50d6b27f3f0bb175a87d0cb55426d0a0ae", - "revisionTime": "2018-11-28T10:09:59Z" + "revision": "758128399b1df3a87e92df6c26c1d2063da8fabe", + "revisionTime": "2019-09-23T12:57:48Z" }, { "checksumSHA1": "SsMMqb3xn7hg1ZX5ugwZz5rzpx0=", @@ -844,6 +850,12 @@ "revision": "31e7599a6c37728c25ca34167be099d072ad335d", "revisionTime": "2019-04-05T05:38:27Z" }, + { + "checksumSHA1": "7Ev/X4Xe8P3961myez/hBKO05ig=", + "path": "golang.org/x/time/rate", + "revision": "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef", + "revisionTime": "2019-02-15T22:48:40Z" + }, { "checksumSHA1": "CEFTYXtWmgSh+3Ik1NmDaJcz4E0=", "path": "gopkg.in/check.v1",