Update dependencies

- uses newer version of go-ethereum required for go1.11
This commit is contained in:
Rob Mulholand 2018-09-05 10:36:14 -05:00
parent 939ead0c82
commit 560305f601
2356 changed files with 331681 additions and 128329 deletions

168
Gopkg.lock generated
View File

@ -5,13 +5,19 @@
branch = "master"
name = "github.com/aristanetworks/goarista"
packages = ["monotime"]
revision = "8d0e8f607a4080e7df3532e645440ed0900c64a4"
revision = "ff33da284e760fcdb03c33d37a719e5ed30ba844"
[[projects]]
branch = "master"
name = "github.com/btcsuite/btcd"
packages = ["btcec"]
revision = "2e60448ffcc6bf78332d1fe590260095f554dd78"
revision = "cff30e1d23fc9e800b2b5b4b41ef1817dda07e9f"
[[projects]]
name = "github.com/deckarep/golang-set"
packages = ["."]
revision = "1d4478f51bed434f1dadf96dcd9b43aabac66795"
version = "v1.7"
[[projects]]
name = "github.com/ethereum/go-ethereum"
@ -25,18 +31,10 @@
"common/hexutil",
"common/math",
"common/mclock",
"consensus",
"consensus/misc",
"core",
"core/state",
"core/rawdb",
"core/types",
"core/vm",
"crypto",
"crypto/bn256",
"crypto/bn256/cloudflare",
"crypto/bn256/google",
"crypto/ecies",
"crypto/randentropy",
"crypto/secp256k1",
"crypto/sha3",
"ethclient",
@ -54,8 +52,8 @@
"rpc",
"trie"
]
revision = "b8b9f7f4476a30a0aaf6077daade6ae77f969960"
version = "v1.8.2"
revision = "89451f7c382ad2185987ee369f16416f89c28a7d"
version = "v1.8.15"
[[projects]]
name = "github.com/fsnotify/fsnotify"
@ -66,37 +64,34 @@
[[projects]]
name = "github.com/go-stack/stack"
packages = ["."]
revision = "259ab82a6cad3992b4e21ff5cac294ccb06474bc"
version = "v1.7.0"
revision = "2fee6af1a9795aafbe0253a0cfbdf668e1fb8a9a"
version = "v1.8.0"
[[projects]]
branch = "master"
name = "github.com/golang/protobuf"
packages = ["proto"]
revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845"
revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5"
version = "v1.2.0"
[[projects]]
branch = "master"
name = "github.com/golang/snappy"
packages = ["."]
revision = "553a641470496b2327abcac10b36396bd98e45c9"
revision = "2e65f85255dbc3072edf28d6b5b8efc472979f5a"
[[projects]]
branch = "master"
name = "github.com/hashicorp/golang-lru"
packages = [
".",
"simplelru"
]
revision = "0fb14efe8c47ae851c0034ed7a448854d3d34cf3"
name = "github.com/google/uuid"
packages = ["."]
revision = "d460ce9f8df2e77fb1ba55ca87fafed96c607494"
version = "v1.0.0"
[[projects]]
branch = "master"
name = "github.com/hashicorp/hcl"
packages = [
".",
"hcl/ast",
"hcl/parser",
"hcl/printer",
"hcl/scanner",
"hcl/strconv",
"hcl/token",
@ -104,7 +99,20 @@
"json/scanner",
"json/token"
]
revision = "23c074d0eceb2b8a5bfdbb271ab780cde70f05a8"
revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241"
version = "v1.0.0"
[[projects]]
name = "github.com/hpcloud/tail"
packages = [
".",
"ratelimiter",
"util",
"watch",
"winfile"
]
revision = "a30252cb686a21eb2d0b98132633053ec2f7f1e5"
version = "v1.0.0"
[[projects]]
branch = "master"
@ -118,7 +126,7 @@
"soap",
"ssdp"
]
revision = "dceda08e705b2acee36aab47d765ed801f64cfc7"
revision = "1395d1447324cbea88d249fbfcfd70ea878fdfca"
[[projects]]
name = "github.com/inconshreveable/mousetrap"
@ -139,34 +147,34 @@
".",
"reflectx"
]
revision = "99f3ad6d85ae53d0fecf788ab62d0e9734b3c117"
revision = "0dae4fefe7c0e190f7b5a78dac28a1c82cc8d849"
[[projects]]
branch = "master"
name = "github.com/lib/pq"
packages = [
".",
"oid"
]
revision = "83612a56d3dd153a94a629cd64925371c9adad78"
revision = "4ded0e9383f75c197b3a2aaa6d590ac52df6fd79"
version = "v1.0.0"
[[projects]]
name = "github.com/magiconair/properties"
packages = ["."]
revision = "d419a98cdbed11a922bf76f257b7c4be79b50e73"
version = "v1.7.4"
revision = "c2353362d570a7bfa228149c62842019201cfb71"
version = "v1.8.0"
[[projects]]
branch = "master"
name = "github.com/mitchellh/go-homedir"
packages = ["."]
revision = "b8bc1bf767474819792c23f32d8286a45736f1c6"
revision = "ae18d6b8b3205b561c79e8e5f69bff09736185f4"
version = "v1.0.0"
[[projects]]
branch = "master"
name = "github.com/mitchellh/mapstructure"
packages = ["."]
revision = "b4575eea38cca1123ec2dc90c26529b5c5acfcff"
revision = "fa473d140ef3c6adf42d6b391fe76707f1f243c8"
version = "v1.0.0"
[[projects]]
name = "github.com/onsi/ginkgo"
@ -190,8 +198,8 @@
"reporters/stenographer/support/go-isatty",
"types"
]
revision = "9eda700730cba42af70d53180f9dcce9266bc2bc"
version = "v1.4.0"
revision = "3774a09d95489ccaa16032e0770d08ea77ba6184"
version = "v1.6.0"
[[projects]]
name = "github.com/onsi/gomega"
@ -210,20 +218,20 @@
"matchers/support/goraph/util",
"types"
]
revision = "c893efa28eb45626cdaa76c9f653b62488858837"
version = "v1.2.0"
revision = "7615b9433f86a8bdf29709bf288bc4fd0636a369"
version = "v1.4.2"
[[projects]]
name = "github.com/pborman/uuid"
packages = ["."]
revision = "e790cca94e6cc75c7064b1332e63811d4aae1a53"
version = "v1.1"
revision = "adf5a7427709b9deb95d29d3fa8a2bf9cfd388f1"
version = "v1.2"
[[projects]]
name = "github.com/pelletier/go-toml"
packages = ["."]
revision = "acdc4509485b587f5e675510c4f2c63e90ff68a8"
version = "v1.1.0"
revision = "c01d1270ff3e442a8a57cddc1c92dc1138598194"
version = "v1.2.0"
[[projects]]
name = "github.com/philhofer/fwd"
@ -234,14 +242,14 @@
[[projects]]
name = "github.com/rjeczalik/notify"
packages = ["."]
revision = "52ae50d8490436622a8941bd70c3dbe0acdd4bbf"
version = "v0.9.0"
revision = "0f065fa99b48b842c3fd3e2c8b194c6f2b69f6b8"
version = "v0.9.1"
[[projects]]
name = "github.com/rs/cors"
packages = ["."]
revision = "7af7a1e09ba336d2ea14b1ce73bf693c6837dbf6"
version = "v1.2"
revision = "3fb1b69b103a84de38a19c3c6ec073dd6caa4d3f"
version = "v1.5.0"
[[projects]]
name = "github.com/spf13/afero"
@ -249,38 +257,38 @@
".",
"mem"
]
revision = "bb8f1927f2a9d3ab41c9340aa034f6b803f4359c"
version = "v1.0.2"
revision = "d40851caa0d747393da1ffb28f7f9d8b4eeffebd"
version = "v1.1.2"
[[projects]]
name = "github.com/spf13/cast"
packages = ["."]
revision = "acbeb36b902d72a7a4c18e8f3241075e7ab763e4"
version = "v1.1.0"
revision = "8965335b8c7107321228e3e3702cab9832751bac"
version = "v1.2.0"
[[projects]]
name = "github.com/spf13/cobra"
packages = ["."]
revision = "7b2c5ac9fc04fc5efafb60700713d4fa609b777b"
version = "v0.0.1"
revision = "ef82de70bb3f60c65fb8eebacbb2d122ef517385"
version = "v0.0.3"
[[projects]]
branch = "master"
name = "github.com/spf13/jwalterweatherman"
packages = ["."]
revision = "7c0cea34c8ece3fbeb2b27ab9b59511d360fb394"
revision = "4a4406e478ca629068e7768fc33f3f044173c0a6"
version = "v1.0.0"
[[projects]]
name = "github.com/spf13/pflag"
packages = ["."]
revision = "e57e3eeb33f795204c1ca35f56c44f83227c6e66"
version = "v1.0.0"
revision = "9a97c102cda95a86cec2345a6f09f55a939babf5"
version = "v1.0.2"
[[projects]]
name = "github.com/spf13/viper"
packages = ["."]
revision = "25b30aa063fc18e48662b86996252eabdcf2f0c7"
version = "v1.0.0"
revision = "8fb642006536c8d3760c99d4fa2389f5e2205631"
version = "v1.2.0"
[[projects]]
branch = "master"
@ -299,7 +307,7 @@
"leveldb/table",
"leveldb/util"
]
revision = "adf24ef3f94bd13ec4163060b21a5678f22b429b"
revision = "ae2bd5eed72d46b28834ec3f60db3a3ebedd8dbd"
[[projects]]
name = "github.com/tinylib/msgp"
@ -312,10 +320,9 @@
name = "golang.org/x/crypto"
packages = [
"pbkdf2",
"ripemd160",
"scrypt"
]
revision = "613d6eafa307c6881a737a3c35c0e312e8d3a8c5"
revision = "0e37d006457bf46f9e6692014ba72ef82c33022c"
[[projects]]
branch = "master"
@ -327,7 +334,7 @@
"html/charset",
"websocket"
]
revision = "faacc1b5e36e3ff02cbec9661c69ac63dd5a83ad"
revision = "26e67e76b6c3f6ce91f7c52def5af501b4e0f3a2"
[[projects]]
branch = "master"
@ -342,10 +349,9 @@
"unix",
"windows"
]
revision = "a0f4589a76f1f83070cb9e5613809e1d07b97c13"
revision = "d0be0721c37eeb5299f245a996a483160fc36940"
[[projects]]
branch = "master"
name = "golang.org/x/text"
packages = [
"encoding",
@ -369,7 +375,8 @@
"unicode/cldr",
"unicode/norm"
]
revision = "be25de41fadfae372d6470bda81ca6beb55ef551"
revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0"
version = "v0.3.0"
[[projects]]
branch = "master"
@ -379,7 +386,7 @@
"imports",
"internal/fastwalk"
]
revision = "8cc4e8a6f4841aa92a8683fca47bc5d64b58875b"
revision = "677d2ff680c188ddb7dcd2bfa6bc7d3f2f2f75b2"
[[projects]]
name = "gopkg.in/DataDog/dd-trace-go.v1"
@ -389,14 +396,15 @@
"ddtrace/internal",
"ddtrace/tracer"
]
revision = "8efc9a798f2db99a9e00c7e57f45fc13611214e0"
version = "v1.2.3"
revision = "bcd20367df871708a36549e7fe36183ee5b4fc55"
version = "v1.3.0"
[[projects]]
name = "gopkg.in/fatih/set.v0"
name = "gopkg.in/fsnotify.v1"
packages = ["."]
revision = "57907de300222151a123d29255ed17f5ed43fad3"
version = "v0.1.0"
revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9"
source = "gopkg.in/fsnotify/fsnotify.v1"
version = "v1.4.7"
[[projects]]
branch = "v2"
@ -411,14 +419,20 @@
revision = "c1b8fa8bdccecb0b8db834ee0b92fdbcfa606dd6"
[[projects]]
branch = "v2"
branch = "v1"
name = "gopkg.in/tomb.v1"
packages = ["."]
revision = "dd632973f1e7218eb1089048e0798ec9ae7dceb8"
[[projects]]
name = "gopkg.in/yaml.v2"
packages = ["."]
revision = "287cf08546ab5e7e37d55a84f7ed3fd1db036de5"
revision = "5420a8b6744d3b0345ab293f6fcba19c978f1183"
version = "v2.2.1"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "7a913c984013e026536456baa75bd95e261bbb0d294b7de77785819ac182b465"
inputs-digest = "d1f001d06a295f55fcb3dd1f38744ca0d088b9c4bf88f251ef3e043c8852fe76"
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -21,6 +21,10 @@
# version = "2.4.0"
[[override]]
name = "gopkg.in/fsnotify.v1"
source = "gopkg.in/fsnotify/fsnotify.v1"
[[constraint]]
name = "github.com/onsi/ginkgo"
version = "1.4.0"
@ -39,4 +43,4 @@
[[constraint]]
name = "github.com/ethereum/go-ethereum"
version = "1.8"
version = "1.8.15"

View File

@ -36,5 +36,5 @@ func (l LevelDatabase) GetBlockReceipts(blockHash []byte, blockNumber int64) typ
func (l LevelDatabase) GetHeadBlockNumber() int64 {
h := l.reader.GetHeadBlockHash()
n := l.reader.GetBlockNumber(h)
return int64(n)
return int64(*n)
}

View File

@ -2,42 +2,42 @@ package level
import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
)
type Reader interface {
GetBlock(hash common.Hash, number uint64) *types.Block
GetBlockNumber(hash common.Hash) uint64
GetBlockNumber(hash common.Hash) *uint64
GetBlockReceipts(hash common.Hash, number uint64) types.Receipts
GetCanonicalHash(number uint64) common.Hash
GetHeadBlockHash() common.Hash
}
type LevelDatabaseReader struct {
reader core.DatabaseReader
reader rawdb.DatabaseReader
}
func NewLevelDatabaseReader(reader core.DatabaseReader) *LevelDatabaseReader {
func NewLevelDatabaseReader(reader rawdb.DatabaseReader) *LevelDatabaseReader {
return &LevelDatabaseReader{reader: reader}
}
func (ldbr *LevelDatabaseReader) GetBlock(hash common.Hash, number uint64) *types.Block {
return core.GetBlock(ldbr.reader, hash, number)
return rawdb.ReadBlock(ldbr.reader, hash, number)
}
func (ldbr *LevelDatabaseReader) GetBlockNumber(hash common.Hash) uint64 {
return core.GetBlockNumber(ldbr.reader, hash)
func (ldbr *LevelDatabaseReader) GetBlockNumber(hash common.Hash) *uint64 {
return rawdb.ReadHeaderNumber(ldbr.reader, hash)
}
func (ldbr *LevelDatabaseReader) GetBlockReceipts(hash common.Hash, number uint64) types.Receipts {
return core.GetBlockReceipts(ldbr.reader, hash, number)
return rawdb.ReadReceipts(ldbr.reader, hash, number)
}
func (ldbr *LevelDatabaseReader) GetCanonicalHash(number uint64) common.Hash {
return core.GetCanonicalHash(ldbr.reader, number)
return rawdb.ReadCanonicalHash(ldbr.reader, number)
}
func (ldbr *LevelDatabaseReader) GetHeadBlockHash() common.Hash {
return core.GetHeadBlockHash(ldbr.reader)
return rawdb.ReadHeadBlockHash(ldbr.reader)
}

View File

@ -82,10 +82,10 @@ func (mldr *MockLevelDatabaseReader) GetBlockReceipts(hash common.Hash, number u
return mldr.returnReceipts
}
func (mldr *MockLevelDatabaseReader) GetBlockNumber(hash common.Hash) uint64 {
func (mldr *MockLevelDatabaseReader) GetBlockNumber(hash common.Hash) *uint64 {
mldr.getBlockNumberCalled = true
mldr.getBlockNumberPassedHash = hash
return mldr.returnBlockNumber
return &mldr.returnBlockNumber
}
func (mldr *MockLevelDatabaseReader) GetCanonicalHash(number uint64) common.Hash {

View File

@ -218,7 +218,7 @@ var _ = Describe("Conversion of GethBlock to core.Block", func() {
CumulativeGasUsed: uint64(7996119),
GasUsed: uint64(21000),
Logs: []*types.Log{},
Status: uint(1),
Status: uint64(1),
TxHash: gethTransaction.Hash(),
}

View File

@ -54,7 +54,7 @@ var _ = Describe("Conversion of GethReceipt to core.Receipt", func() {
CumulativeGasUsed: uint64(7996119),
GasUsed: uint64(21000),
Logs: []*types.Log{},
Status: uint(1),
Status: uint64(1),
TxHash: common.HexToHash("0xe340558980f89d5f86045ac11e5cc34e4bcec20f9f1e2a427aa39d87114e8223"),
}

View File

@ -37,7 +37,7 @@ var (
)
var DentLog = types.Log{
Address: common.StringToAddress(shared.FlipperContractAddress),
Address: common.HexToAddress(shared.FlipperContractAddress),
Topics: []common.Hash{
common.HexToHash("0x5ff3a38200000000000000000000000000000000000000000000000000000000"),
common.HexToHash("0x00000000000000000000000064d922894153be9eef7b7218dc565d1d0ce2a092"),

View File

@ -38,7 +38,7 @@ var (
)
var TendLogNote = types.Log{
Address: common.StringToAddress(shared.FlipperContractAddress),
Address: common.HexToAddress(shared.FlipperContractAddress),
Topics: []common.Hash{
common.HexToHash("0x4b43ed1200000000000000000000000000000000000000000000000000000000"), //abbreviated tend function signature
common.HexToHash("0x0000000000000000000000007d7bee5fcfd8028cf7b00876c5b1421c800561a6"), //msg caller address

View File

@ -1,7 +1,8 @@
language: go
go:
- 1.9
- tip
- 1.10.x
- 1.x
- master
before_install:
- go get -v github.com/golang/lint/golint
- go get -v -t -d ./...

View File

@ -3,7 +3,7 @@
# that can be found in the COPYING file.
# TODO: move this to cmd/ockafka (https://github.com/docker/hub-feedback/issues/292)
FROM golang:1.7.3
FROM golang:1.10.3
RUN mkdir -p /go/src/github.com/aristanetworks/goarista/cmd
WORKDIR /go/src/github.com/aristanetworks/goarista

View File

@ -25,19 +25,20 @@ class of service flags to use for incoming connections. Requires `go1.9`.
## key
Provides a common type used across various Arista projects, named `key.Key`,
which is used to work around the fact that Go can't let one
use a non-hashable type as a key to a `map`, and we sometimes need to use
a `map[string]interface{}` (or something containing one) as a key to maps.
As a result, we frequently use `map[key.Key]interface{}` instead of just
`map[interface{}]interface{}` when we need a generic key-value collection.
Provides common types used across various Arista projects. The type `key.Key`
is used to work around the fact that Go can't let one use a non-hashable type
as a key to a `map`, and we sometimes need to use a `map[string]interface{}`
(or something containing one) as a key to maps. As a result, we frequently use
`map[key.Key]interface{}` instead of just `map[interface{}]interface{}` when we
need a generic key-value collection. The type `key.Path` is the representation
of a path broken down into individual elements, where each element is a `key.Key`.
The type `key.Pointer` represents a pointer to a `key.Path`.
## path
Provides a common type used across various Arista projects, named `path.Path`,
which is the representation of a path broken down into individual elements.
Each element is a `key.Key`. The type `path.Map` may be used for mapping paths
to values. It allows for some fuzzy matching.
Provides functions that can be used to manipulate `key.Path` objects. The type
`path.Map` may be used for mapping paths to values. It allows for some fuzzy
matching for paths containing `path.Wildcard` keys.
## lanz
A client for [LANZ](https://eos.arista.com/latency-analyzer-lanz-architectures-and-configuration/)
@ -55,9 +56,10 @@ A library to help expose monitoring metrics on top of the
`netns.Do(namespace, cb)` provides a handy mechanism to execute the given
callback `cb` in the given [network namespace](https://lwn.net/Articles/580893/).
## pathmap
## influxlib
DEPRECATED; use`path.Map` instead.
This is a influxdb library that provides easy methods of connecting to, writing to,
and reading from the service.
## test

View File

@ -16,10 +16,20 @@ under [GOPATH](https://golang.org/doc/code.html#GOPATH).
# Usage
```
$ gnmi [OPTIONS] [OPERATION]
```
When running on the switch in a non-default VRF:
```
$ ip netns exec ns-<VRF> gnmi [OPTIONS] [OPERATION]
```
## Options
* `-addr ADDR:PORT`
Address of the gNMI endpoint (REQUIRED)
* `-addr [<VRF-NAME>/]ADDR:PORT`
Address of the gNMI endpoint (REQUIRED) with VRF name (OPTIONAL)
* `-username USERNAME`
Username to authenticate with
* `-password PASSWORD`
@ -92,7 +102,7 @@ $ gnmi [OPTIONS] delete '/network-instances/network-instance[name=default]/proto
```
`update` and `replace` both take a path and a value in JSON
format. See
format. The JSON data may be provided in a file. See
[here](https://github.com/openconfig/reference/blob/master/rpc/gnmi/gnmi-specification.md#344-modes-of-update-replace-versus-update)
for documentation on the differences between `update` and `replace`.
@ -108,15 +118,46 @@ Replace the BGP global configuration:
gnmi [OPTIONS] replace '/network-instances/network-instance[name=default]/protocols/protocol[name=BGP][identifier=BGP]/bgp/global' '{"config":{"as": 1234, "router-id": "1.2.3.4"}}'
```
Note: String values must be quoted. For example, setting the hostname to `"tor13"`:
Note: String values need to be quoted if they look like JSON. For example, setting the login banner to `tor[13]`:
```
gnmi [OPTIONS] update '/system/config/hostname' '"tor13"'
gnmi [OPTIONS] update '/system/config/login-banner '"tor[13]"'
```
#### JSON in a file
The value argument to `update` and `replace` may be a file. The
content of the file is used to make the request.
Example:
File `path/to/subintf100.json` contains the following:
```
{
"subinterface": [
{
"config": {
"enabled": true,
"index": 100
},
"index": 100
}
]
}
```
Add subinterface 100 to interfaces Ethernet4/1/1 and Ethernet4/2/1 in
one transaction:
```
gnmi [OPTIONS] update '/interfaces/interface[name=Ethernet4/1/1]/subinterfaces' path/to/subintf100.json \
update '/interfaces/interface[name=Ethernet4/2/1]/subinterfaces' path/to/subintf100.json
```
### CLI requests
`gnmi` offers the ability to send CLI text inside an `update` or
`replace` operation. This is achieved by doing an `update` or
`replace` and using `"cli"` as the path and a set of configure-mode
`replace` and specifying `"origin=cli"` along with an empty path and a set of configure-mode
CLI commands separated by `\n`.
Example:
@ -127,6 +168,18 @@ gnmi [OPTIONS] update 'cli' 'management ssh
idle-timeout 300'
```
### P4 Config
`gnmi` offers the ability to send p4 config files inside a `replace` operation.
This is achieved by doing a `replace` and specifying `"origin=p4_config"`
along with the path of the p4 config file to send.
Example:
Send the config.p4 file
```
gnmi [OPTIONS] replace 'origin=p4_config' 'config.p4'
```
## Paths
Paths in `gnmi` use a simplified xpath style. Path elements are

View File

@ -9,6 +9,8 @@ import (
"flag"
"fmt"
"os"
"strings"
"time"
"github.com/aristanetworks/goarista/gnmi"
@ -18,22 +20,24 @@ import (
// TODO: Make this more clear
var help = `Usage of gnmi:
gnmi -addr ADDRESS:PORT [options...]
gnmi -addr [<VRF-NAME>/]ADDRESS:PORT [options...]
capabilities
get PATH+
subscribe PATH+
((update|replace PATH JSON)|(delete PATH))+
((update|replace (origin=ORIGIN) PATH JSON|FILE)|(delete (origin=ORIGIN) PATH))+
`
func exitWithError(s string) {
func usageAndExit(s string) {
flag.Usage()
fmt.Fprintln(os.Stderr, s)
if s != "" {
fmt.Fprintln(os.Stderr, s)
}
os.Exit(1)
}
func main() {
cfg := &gnmi.Config{}
flag.StringVar(&cfg.Addr, "addr", "", "Address of gNMI gRPC server")
flag.StringVar(&cfg.Addr, "addr", "", "Address of gNMI gRPC server with optional VRF name")
flag.StringVar(&cfg.CAFile, "cafile", "", "Path to server TLS certificate file")
flag.StringVar(&cfg.CertFile, "certfile", "", "Path to client TLS certificate file")
flag.StringVar(&cfg.KeyFile, "keyfile", "", "Path to client TLS private key file")
@ -41,26 +45,54 @@ func main() {
flag.StringVar(&cfg.Username, "username", "", "Username to authenticate with")
flag.BoolVar(&cfg.TLS, "tls", false, "Enable TLS")
subscribeOptions := &gnmi.SubscribeOptions{}
flag.StringVar(&subscribeOptions.Prefix, "prefix", "", "Subscribe prefix path")
flag.BoolVar(&subscribeOptions.UpdatesOnly, "updates_only", false,
"Subscribe to updates only (false | true)")
flag.StringVar(&subscribeOptions.Mode, "mode", "stream",
"Subscribe mode (stream | once | poll)")
flag.StringVar(&subscribeOptions.StreamMode, "stream_mode", "target_defined",
"Subscribe stream mode, only applies for stream subscriptions "+
"(target_defined | on_change | sample)")
sampleIntervalStr := flag.String("sample_interval", "0", "Subscribe sample interval, "+
"only applies for sample subscriptions (400ms, 2.5s, 1m, etc.)")
heartbeatIntervalStr := flag.String("heartbeat_interval", "0", "Subscribe heartbeat "+
"interval, only applies for on-change subscriptions (400ms, 2.5s, 1m, etc.)")
flag.Usage = func() {
fmt.Fprintln(os.Stderr, help)
flag.PrintDefaults()
}
flag.Parse()
if cfg.Addr == "" {
exitWithError("error: address not specified")
usageAndExit("error: address not specified")
}
var sampleInterval, heartbeatInterval time.Duration
var err error
if sampleInterval, err = time.ParseDuration(*sampleIntervalStr); err != nil {
usageAndExit(fmt.Sprintf("error: sample interval (%s) invalid", *sampleIntervalStr))
}
subscribeOptions.SampleInterval = uint64(sampleInterval)
if heartbeatInterval, err = time.ParseDuration(*heartbeatIntervalStr); err != nil {
usageAndExit(fmt.Sprintf("error: heartbeat interval (%s) invalid", *heartbeatIntervalStr))
}
subscribeOptions.HeartbeatInterval = uint64(heartbeatInterval)
args := flag.Args()
ctx := gnmi.NewContext(context.Background(), cfg)
client := gnmi.Dial(cfg)
client, err := gnmi.Dial(cfg)
if err != nil {
glog.Fatal(err)
}
var setOps []*gnmi.Operation
for i := 0; i < len(args); i++ {
switch args[i] {
case "capabilities":
if len(setOps) != 0 {
exitWithError("error: 'capabilities' not allowed after 'merge|replace|delete'")
usageAndExit("error: 'capabilities' not allowed after 'merge|replace|delete'")
}
err := gnmi.Capabilities(ctx, client)
if err != nil {
@ -69,7 +101,7 @@ func main() {
return
case "get":
if len(setOps) != 0 {
exitWithError("error: 'get' not allowed after 'merge|replace|delete'")
usageAndExit("error: 'get' not allowed after 'merge|replace|delete'")
}
err := gnmi.Get(ctx, client, gnmi.SplitPaths(args[i+1:]))
if err != nil {
@ -78,49 +110,55 @@ func main() {
return
case "subscribe":
if len(setOps) != 0 {
exitWithError("error: 'subscribe' not allowed after 'merge|replace|delete'")
usageAndExit("error: 'subscribe' not allowed after 'merge|replace|delete'")
}
respChan := make(chan *pb.SubscribeResponse)
errChan := make(chan error)
defer close(respChan)
defer close(errChan)
go gnmi.Subscribe(ctx, client, gnmi.SplitPaths(args[i+1:]), respChan, errChan)
subscribeOptions.Paths = gnmi.SplitPaths(args[i+1:])
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
for {
select {
case resp := <-respChan:
case resp, open := <-respChan:
if !open {
return
}
if err := gnmi.LogSubscribeResponse(resp); err != nil {
exitWithError(err.Error())
glog.Fatal(err)
}
case err := <-errChan:
exitWithError(err.Error())
glog.Fatal(err)
}
}
case "update", "replace", "delete":
if len(args) == i+1 {
exitWithError("error: missing path")
usageAndExit("error: missing path")
}
op := &gnmi.Operation{
Type: args[i],
}
i++
if strings.HasPrefix(args[i], "origin=") {
op.Origin = strings.TrimPrefix(args[i], "origin=")
i++
}
op.Path = gnmi.SplitPath(args[i])
if op.Type != "delete" {
if len(args) == i+1 {
exitWithError("error: missing JSON")
usageAndExit("error: missing JSON or FILEPATH to data")
}
i++
op.Val = args[i]
}
setOps = append(setOps, op)
default:
exitWithError(fmt.Sprintf("error: unknown operation %q", args[i]))
usageAndExit(fmt.Sprintf("error: unknown operation %q", args[i]))
}
}
if len(setOps) == 0 {
flag.Usage()
os.Exit(1)
usageAndExit("")
}
err := gnmi.Set(ctx, client, setOps)
err = gnmi.Set(ctx, client, setOps)
if err != nil {
glog.Fatal(err)
}

View File

@ -0,0 +1,120 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
// json2test reformats 'go test -json' output as text as if the -json
// flag were not passed to go test. It is useful if you want to
// analyze go test -json output, but still want a human readable test
// log.
//
// Usage:
//
// go test -json > out.txt; <analysis program> out.txt; cat out.txt | json2test
//
package main
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"os"
"time"
)
var errTestFailure = errors.New("testfailure")
func main() {
err := writeTestOutput(os.Stdin, os.Stdout)
if err == errTestFailure {
os.Exit(1)
} else if err != nil {
log.Fatal(err)
}
}
type testEvent struct {
Time time.Time // encodes as an RFC3339-format string
Action string
Package string
Test string
Elapsed float64 // seconds
Output string
}
type test struct {
pkg string
test string
}
type outputBuffer struct {
output []string
}
func (o *outputBuffer) push(s string) {
o.output = append(o.output, s)
}
type testFailure struct {
t test
o outputBuffer
}
func writeTestOutput(in io.Reader, out io.Writer) error {
testOutputBuffer := map[test]*outputBuffer{}
var failures []testFailure
d := json.NewDecoder(in)
buf := bufio.NewWriter(out)
defer buf.Flush()
for {
var e testEvent
if err := d.Decode(&e); err != nil {
break
}
switch e.Action {
default:
continue
case "run":
testOutputBuffer[test{pkg: e.Package, test: e.Test}] = new(outputBuffer)
case "pass":
// Don't hold onto text for passing
delete(testOutputBuffer, test{pkg: e.Package, test: e.Test})
case "fail":
// fail may be for a package, which won't have an entry in
// testOutputBuffer because packages don't have a "run"
// action.
t := test{pkg: e.Package, test: e.Test}
if o, ok := testOutputBuffer[t]; ok {
f := testFailure{t: t, o: *o}
delete(testOutputBuffer, t)
failures = append(failures, f)
}
case "output":
buf.WriteString(e.Output)
// output may be for a package, which won't have an entry
// in testOutputBuffer because packages don't have a "run"
// action.
if o, ok := testOutputBuffer[test{pkg: e.Package, test: e.Test}]; ok {
o.push(e.Output)
}
}
}
if len(failures) == 0 {
return nil
}
buf.WriteString("\nTest failures:\n")
for i, f := range failures {
fmt.Fprintf(buf, "[%d] %s.%s\n", i+1, f.t.pkg, f.t.test)
for _, s := range f.o.output {
buf.WriteString(s)
}
if i < len(failures)-1 {
buf.WriteByte('\n')
}
}
return errTestFailure
}

View File

@ -0,0 +1,41 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package main
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"testing"
)
func TestWriteTestOutput(t *testing.T) {
input, err := os.Open("testdata/input.txt")
if err != nil {
t.Fatal(err)
}
var out bytes.Buffer
if err := writeTestOutput(input, &out); err != errTestFailure {
t.Error("expected test failure")
}
gold, err := os.Open("testdata/gold.txt")
if err != nil {
t.Fatal(err)
}
expected, err := ioutil.ReadAll(gold)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(out.Bytes(), expected) {
t.Error("output does not match gold.txt")
fmt.Println("Expected:")
fmt.Println(string(expected))
fmt.Println("Got:")
fmt.Println(out.String())
}
}

View File

@ -0,0 +1,16 @@
? pkg/skipped [no test files]
=== RUN TestPass
--- PASS: TestPass (0.00s)
PASS
ok pkg/passed 0.013s
panic
FAIL pkg/panic 600.029s
--- FAIL: TestFail (0.18s)
Test failures:
[1] pkg/panic.TestPanic
panic
FAIL pkg/panic 600.029s
[2] pkg/failed.TestFail
--- FAIL: TestFail (0.18s)

View File

@ -0,0 +1,17 @@
{"Time":"2018-03-08T10:33:12.002692769-08:00","Action":"output","Package":"pkg/skipped","Output":"? \tpkg/skipped\t[no test files]\n"}
{"Time":"2018-03-08T10:33:12.003199228-08:00","Action":"skip","Package":"pkg/skipped","Elapsed":0.001}
{"Time":"2018-03-08T10:33:12.343866281-08:00","Action":"run","Package":"pkg/passed","Test":"TestPass"}
{"Time":"2018-03-08T10:33:12.34406622-08:00","Action":"output","Package":"pkg/passed","Test":"TestPass","Output":"=== RUN TestPass\n"}
{"Time":"2018-03-08T10:33:12.344139342-08:00","Action":"output","Package":"pkg/passed","Test":"TestPass","Output":"--- PASS: TestPass (0.00s)\n"}
{"Time":"2018-03-08T10:33:12.344165231-08:00","Action":"pass","Package":"pkg/passed","Test":"TestPass","Elapsed":0}
{"Time":"2018-03-08T10:33:12.344297059-08:00","Action":"output","Package":"pkg/passed","Output":"PASS\n"}
{"Time":"2018-03-08T10:33:12.345217622-08:00","Action":"output","Package":"pkg/passed","Output":"ok \tpkg/passed\t0.013s\n"}
{"Time":"2018-03-08T10:33:12.34533033-08:00","Action":"pass","Package":"pkg/passed","Elapsed":0.013}
{"Time":"2018-03-08T10:33:20.243866281-08:00","Action":"run","Package":"pkg/panic","Test":"TestPanic"}
{"Time":"2018-03-08T10:33:20.27231537-08:00","Action":"output","Package":"pkg/panic","Test":"TestPanic","Output":"panic\n"}
{"Time":"2018-03-08T10:33:20.272414481-08:00","Action":"output","Package":"pkg/panic","Test":"TestPanic","Output":"FAIL\tpkg/panic\t600.029s\n"}
{"Time":"2018-03-08T10:33:20.272440286-08:00","Action":"fail","Package":"pkg/panic","Test":"TestPanic","Elapsed":600.029}
{"Time":"2018-03-08T10:33:26.143866281-08:00","Action":"run","Package":"pkg/failed","Test":"TestFail"}
{"Time":"2018-03-08T10:33:27.158776469-08:00","Action":"output","Package":"pkg/failed","Test":"TestFail","Output":"--- FAIL: TestFail (0.18s)\n"}
{"Time":"2018-03-08T10:33:27.158860934-08:00","Action":"fail","Package":"pkg/failed","Test":"TestFail","Elapsed":0.18}
{"Time":"2018-03-08T10:33:27.161302093-08:00","Action":"fail","Package":"pkg/failed","Elapsed":0.204}

View File

@ -26,13 +26,12 @@ var keysFlag = flag.String("kafkakeys", "",
"Keys for kafka messages (comma-separated, default: the value of -addrs")
func newProducer(addresses []string, topic, key, dataset string) (producer.Producer, error) {
glog.Infof("Connected to Kafka brokers at %s", addresses)
encodedKey := sarama.StringEncoder(key)
p, err := producer.New(openconfig.NewEncoder(topic, encodedKey, dataset), addresses, nil)
if err != nil {
return nil, fmt.Errorf("Failed to create Kafka producer: %s", err)
return nil, fmt.Errorf("Failed to create Kafka brokers: %s", err)
}
glog.Infof("Connected to Kafka brokers at %s", addresses)
return p, nil
}

View File

@ -33,5 +33,5 @@ Prometheus 2.0 will probably support timestamps.
See the `-help` output, but here's an example to push all the metrics defined
in the sample config file:
```
ocprometheus -addrs <switch-hostname>:6042 -config sampleconfig.json
ocprometheus -addr <switch-hostname>:6042 -config sampleconfig.json
```

View File

@ -11,7 +11,7 @@ import (
"github.com/aristanetworks/glog"
"github.com/golang/protobuf/proto"
"github.com/openconfig/reference/rpc/openconfig"
pb "github.com/openconfig/gnmi/proto/gnmi"
"github.com/prometheus/client_golang/prometheus"
)
@ -24,8 +24,10 @@ type source struct {
// Since the labels are fixed per-path and per-device we can cache them here,
// to avoid recomputing them.
type labelledMetric struct {
metric prometheus.Metric
labels []string
metric prometheus.Metric
labels []string
defaultValue float64
stringMetric bool
}
type collector struct {
@ -43,13 +45,14 @@ func newCollector(config *Config) *collector {
}
}
// Process a notfication and update or create the corresponding metrics.
// Process a notification and update or create the corresponding metrics.
func (c *collector) update(addr string, message proto.Message) {
resp, ok := message.(*openconfig.SubscribeResponse)
resp, ok := message.(*pb.SubscribeResponse)
if !ok {
glog.Errorf("Unexpected type of message: %T", message)
return
}
notif := resp.GetUpdate()
if notif == nil {
return
@ -57,7 +60,6 @@ func (c *collector) update(addr string, message proto.Message) {
device := strings.Split(addr, ":")[0]
prefix := "/" + strings.Join(notif.Prefix.Element, "/")
// Process deletes first
for _, del := range notif.Delete {
path := prefix + "/" + strings.Join(del.Element, "/")
@ -70,7 +72,7 @@ func (c *collector) update(addr string, message proto.Message) {
// Process updates next
for _, update := range notif.Update {
// We only use JSON encoded values
if update.Value == nil || update.Value.Type != openconfig.Type_JSON {
if update.Value == nil || update.Value.Type != pb.Encoding_JSON {
glog.V(9).Infof("Ignoring incompatible update value in %s", update)
continue
}
@ -80,40 +82,81 @@ func (c *collector) update(addr string, message proto.Message) {
if !ok {
continue
}
var strUpdate bool
var floatVal float64
var strVal string
switch v := value.(type) {
case float64:
strUpdate = false
floatVal = v
case string:
strUpdate = true
strVal = v
}
if suffix != "" {
path += "/" + suffix
}
src := source{addr: device, path: path}
c.m.Lock()
// Use the cached labels and descriptor if available
if m, ok := c.metrics[src]; ok {
m.metric = prometheus.MustNewConstMetric(m.metric.Desc(), prometheus.GaugeValue, value,
m.labels...)
if strUpdate {
// Skip string updates for non string metrics
if !m.stringMetric {
c.m.Unlock()
continue
}
// Display a default value and replace the value label with the string value
floatVal = m.defaultValue
m.labels[len(m.labels)-1] = strVal
}
m.metric = prometheus.MustNewConstMetric(m.metric.Desc(), prometheus.GaugeValue,
floatVal, m.labels...)
c.m.Unlock()
continue
}
c.m.Unlock()
c.m.Unlock()
// Get the descriptor and labels for this source
desc, labelValues := c.config.getDescAndLabels(src)
if desc == nil {
metric := c.config.getMetricValues(src)
if metric == nil || metric.desc == nil {
glog.V(8).Infof("Ignoring unmatched update at %s:%s: %+v", device, path, update.Value)
continue
}
c.m.Lock()
if strUpdate {
if !metric.stringMetric {
// Skip string updates for non string metrics
continue
}
// Display a default value and replace the value label with the string value
floatVal = metric.defaultValue
metric.labels[len(metric.labels)-1] = strVal
}
// Save the metric and labels in the cache
metric := prometheus.MustNewConstMetric(desc, prometheus.GaugeValue, value, labelValues...)
c.m.Lock()
lm := prometheus.MustNewConstMetric(metric.desc, prometheus.GaugeValue,
floatVal, metric.labels...)
c.metrics[src] = &labelledMetric{
metric: metric,
labels: labelValues,
metric: lm,
labels: metric.labels,
defaultValue: metric.defaultValue,
stringMetric: metric.stringMetric,
}
c.m.Unlock()
}
}
func parseValue(update *openconfig.Update) (float64, string, bool) {
// All metrics in Prometheus are floats, so only try to unmarshal as float64.
// ParseValue takes in an update and parses a value and suffix
// Returns an interface that contains either a string or a float64 as well as a suffix
// Unparseable updates return (0, empty string, false)
func parseValue(update *pb.Update) (interface{}, string, bool) {
var intf interface{}
if err := json.Unmarshal(update.Value.Value, &intf); err != nil {
glog.Errorf("Can't parse value in update %v: %v", update, err)
@ -129,13 +172,16 @@ func parseValue(update *openconfig.Update) (float64, string, bool) {
return val, "value", true
}
}
// float64 or string expected as the return value
case bool:
if value {
return 1, "", true
return float64(1), "", true
}
return 0, "", true
return float64(0), "", true
case string:
return value, "", true
default:
glog.V(9).Infof("Ignorig non-numeric update: %v", update)
glog.V(9).Infof("Ignoring update with unexpected type: %T", value)
}
return 0, "", false

View File

@ -5,33 +5,80 @@
package main
import (
"fmt"
"strings"
"testing"
"github.com/aristanetworks/goarista/test"
"github.com/openconfig/reference/rpc/openconfig"
pb "github.com/openconfig/gnmi/proto/gnmi"
"github.com/prometheus/client_golang/prometheus"
)
func makeMetrics(cfg *Config, expValues map[source]float64) map[source]*labelledMetric {
func makeMetrics(cfg *Config, expValues map[source]float64, notification *pb.Notification,
prevMetrics map[source]*labelledMetric) map[source]*labelledMetric {
expMetrics := map[source]*labelledMetric{}
for k, v := range expValues {
desc, labels := cfg.getDescAndLabels(k)
if desc == nil || labels == nil {
panic("cfg.getDescAndLabels returned nil")
if prevMetrics != nil {
expMetrics = prevMetrics
}
for src, v := range expValues {
metric := cfg.getMetricValues(src)
if metric == nil || metric.desc == nil || metric.labels == nil {
panic("cfg.getMetricValues returned nil")
}
expMetrics[k] = &labelledMetric{
metric: prometheus.MustNewConstMetric(desc, prometheus.GaugeValue, v, labels...),
labels: labels,
// Preserve current value of labels
labels := metric.labels
if _, ok := expMetrics[src]; ok && expMetrics[src] != nil {
labels = expMetrics[src].labels
}
// Handle string updates
if notification.Update != nil {
if update, err := findUpdate(notification, src.path); err == nil {
val, _, ok := parseValue(update)
if !ok {
continue
}
if strVal, ok := val.(string); ok {
if !metric.stringMetric {
continue
}
v = metric.defaultValue
labels[len(labels)-1] = strVal
}
}
}
expMetrics[src] = &labelledMetric{
metric: prometheus.MustNewConstMetric(metric.desc, prometheus.GaugeValue, v,
labels...),
labels: labels,
defaultValue: metric.defaultValue,
stringMetric: metric.stringMetric,
}
}
// Handle deletion
for key := range expMetrics {
if _, ok := expValues[key]; !ok {
delete(expMetrics, key)
}
}
return expMetrics
}
func makeResponse(notif *openconfig.Notification) *openconfig.SubscribeResponse {
return &openconfig.SubscribeResponse{
Response: &openconfig.SubscribeResponse_Update{Update: notif},
func findUpdate(notif *pb.Notification, path string) (*pb.Update, error) {
prefix := notif.Prefix.Element
for _, v := range notif.Update {
fullPath := "/" + strings.Join(append(prefix, v.Path.Element...), "/")
if strings.Contains(path, fullPath) || path == fullPath {
return v, nil
}
}
return nil, fmt.Errorf("Failed to find matching update for path %v", path)
}
func makeResponse(notif *pb.Notification) *pb.SubscribeResponse {
return &pb.SubscribeResponse{
Response: &pb.SubscribeResponse_Update{Update: notif},
}
}
@ -49,6 +96,11 @@ subscriptions:
- /Sysdb/environment/power/status
- /Sysdb/bridging/igmpsnooping/forwarding/forwarding/status
metrics:
- name: fanName
path: /Sysdb/environment/cooling/status/fan/name
help: Fan Name
valuelabel: name
defaultvalue: 2.5
- name: intfCounter
path: /Sysdb/(lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter
help: Per-Interface Bytes/Errors/Discards Counters
@ -64,79 +116,101 @@ metrics:
}
coll := newCollector(cfg)
notif := &openconfig.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}},
Update: []*openconfig.Update{
notif := &pb.Notification{
Prefix: &pb.Path{Element: []string{"Sysdb"}},
Update: []*pb.Update{
{
Path: &openconfig.Path{
Path: &pb.Path{
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
},
Value: &openconfig.Value{
Type: openconfig.Type_JSON,
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("42"),
},
},
{
Path: &openconfig.Path{
Path: &pb.Path{
Element: []string{"environment", "cooling", "status", "fan", "speed"},
},
Value: &openconfig.Value{
Type: openconfig.Type_JSON,
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("{\"value\": 45}"),
},
},
{
Path: &openconfig.Path{
Path: &pb.Path{
Element: []string{"igmpsnooping", "vlanStatus", "2050", "ethGroup",
"01:00:5e:01:01:01", "intf", "Cpu"},
},
Value: &openconfig.Value{
Type: openconfig.Type_JSON,
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("true"),
},
},
{
Path: &pb.Path{
Element: []string{"environment", "cooling", "status", "fan", "name"},
},
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("\"Fan1.1\""),
},
},
},
}
expValues := map[source]float64{
source{
{
addr: "10.1.1.1",
path: "/Sysdb/lag/intfCounterDir/Ethernet1/intfCounter",
}: 42,
source{
{
addr: "10.1.1.1",
path: "/Sysdb/environment/cooling/status/fan/speed/value",
}: 45,
source{
{
addr: "10.1.1.1",
path: "/Sysdb/igmpsnooping/vlanStatus/2050/ethGroup/01:00:5e:01:01:01/intf/Cpu",
}: 1,
{
addr: "10.1.1.1",
path: "/Sysdb/environment/cooling/status/fan/name",
}: 2.5,
}
coll.update("10.1.1.1:6042", makeResponse(notif))
expMetrics := makeMetrics(cfg, expValues)
expMetrics := makeMetrics(cfg, expValues, notif, nil)
if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
}
// Update one value, and one path which is not a metric
notif = &openconfig.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}},
Update: []*openconfig.Update{
// Update two values, and one path which is not a metric
notif = &pb.Notification{
Prefix: &pb.Path{Element: []string{"Sysdb"}},
Update: []*pb.Update{
{
Path: &openconfig.Path{
Path: &pb.Path{
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
},
Value: &openconfig.Value{
Type: openconfig.Type_JSON,
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("52"),
},
},
{
Path: &openconfig.Path{
Path: &pb.Path{
Element: []string{"environment", "cooling", "status", "fan", "name"},
},
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("\"Fan2.1\""),
},
},
{
Path: &pb.Path{
Element: []string{"environment", "doesntexist", "status"},
},
Value: &openconfig.Value{
Type: openconfig.Type_JSON,
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("{\"value\": 45}"),
},
},
@ -149,21 +223,21 @@ metrics:
expValues[src] = 52
coll.update("10.1.1.1:6042", makeResponse(notif))
expMetrics = makeMetrics(cfg, expValues)
expMetrics = makeMetrics(cfg, expValues, notif, expMetrics)
if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
}
// Same path, different device
notif = &openconfig.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}},
Update: []*openconfig.Update{
notif = &pb.Notification{
Prefix: &pb.Path{Element: []string{"Sysdb"}},
Update: []*pb.Update{
{
Path: &openconfig.Path{
Path: &pb.Path{
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
},
Value: &openconfig.Value{
Type: openconfig.Type_JSON,
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("42"),
},
},
@ -173,15 +247,15 @@ metrics:
expValues[src] = 42
coll.update("10.1.1.2:6042", makeResponse(notif))
expMetrics = makeMetrics(cfg, expValues)
expMetrics = makeMetrics(cfg, expValues, notif, expMetrics)
if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
}
// Delete a path
notif = &openconfig.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}},
Delete: []*openconfig.Path{
notif = &pb.Notification{
Prefix: &pb.Path{Element: []string{"Sysdb"}},
Delete: []*pb.Path{
{
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
},
@ -191,21 +265,21 @@ metrics:
delete(expValues, src)
coll.update("10.1.1.1:6042", makeResponse(notif))
expMetrics = makeMetrics(cfg, expValues)
expMetrics = makeMetrics(cfg, expValues, notif, expMetrics)
if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
}
// Non-numeric update
notif = &openconfig.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}},
Update: []*openconfig.Update{
// Non-numeric update to path without value label
notif = &pb.Notification{
Prefix: &pb.Path{Element: []string{"Sysdb"}},
Update: []*pb.Update{
{
Path: &openconfig.Path{
Path: &pb.Path{
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
},
Value: &openconfig.Value{
Type: openconfig.Type_JSON,
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("\"test\""),
},
},
@ -213,6 +287,7 @@ metrics:
}
coll.update("10.1.1.1:6042", makeResponse(notif))
// Don't make new metrics as it should have no effect
if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
}

View File

@ -33,7 +33,7 @@ type MetricDef struct {
Path string
// Path compiled as a regexp.
re *regexp.Regexp
re *regexp.Regexp `deepequal:"ignore"`
// Metric name.
Name string
@ -41,6 +41,15 @@ type MetricDef struct {
// Metric help string.
Help string
// Label to store string values
ValueLabel string
// Default value to display for string values
DefaultValue float64
// Does the metric store a string value
stringMetric bool
// This map contains the metric descriptors for this metric for each device.
devDesc map[string]*prometheus.Desc
@ -48,6 +57,14 @@ type MetricDef struct {
desc *prometheus.Desc
}
// metricValues contains the values used in updating a metric
type metricValues struct {
desc *prometheus.Desc
labels []string
defaultValue float64
stringMetric bool
}
// Parses the config and creates the descriptors for each path and device.
func parseConfig(cfg []byte) (*Config, error) {
config := &Config{
@ -56,10 +73,8 @@ func parseConfig(cfg []byte) (*Config, error) {
if err := yaml.Unmarshal(cfg, config); err != nil {
return nil, fmt.Errorf("Failed to parse config: %v", err)
}
for _, def := range config.Metrics {
def.re = regexp.MustCompile(def.Path)
// Extract label names
reNames := def.re.SubexpNames()[1:]
labelNames := make([]string, len(reNames))
@ -69,7 +84,10 @@ func parseConfig(cfg []byte) (*Config, error) {
labelNames[i] = "unnamedLabel" + strconv.Itoa(i+1)
}
}
if def.ValueLabel != "" {
labelNames = append(labelNames, def.ValueLabel)
def.stringMetric = true
}
// Create a default descriptor only if there aren't any per-device labels,
// or if it's explicitly declared
if len(config.DeviceLabels) == 0 || len(config.DeviceLabels["*"]) > 0 {
@ -88,20 +106,25 @@ func parseConfig(cfg []byte) (*Config, error) {
return config, nil
}
// Returns the descriptor corresponding to the device and path, and labels extracted from the path.
// Returns a struct containing the descriptor corresponding to the device and path, labels
// extracted from the path, the default value for the metric and if it accepts string values.
// If the device and path doesn't match any metrics, returns nil.
func (c *Config) getDescAndLabels(s source) (*prometheus.Desc, []string) {
func (c *Config) getMetricValues(s source) *metricValues {
for _, def := range c.Metrics {
if groups := def.re.FindStringSubmatch(s.path); groups != nil {
if desc, ok := def.devDesc[s.addr]; ok {
return desc, groups[1:]
if def.ValueLabel != "" {
groups = append(groups, def.ValueLabel)
}
return def.desc, groups[1:]
desc, ok := def.devDesc[s.addr]
if !ok {
desc = def.desc
}
return &metricValues{desc: desc, labels: groups[1:], defaultValue: def.DefaultValue,
stringMetric: def.stringMetric}
}
}
return nil, nil
return nil
}
// Sends all the descriptors to the channel.

View File

@ -31,6 +31,11 @@ subscriptions:
- /Sysdb/environment/cooling/status
- /Sysdb/environment/power/status
metrics:
- name: fanName
path: /Sysdb/environment/cooling/status/fan/name
help: Fan Name
valuelabel: name
defaultvalue: 25
- name: intfCounter
path: /Sysdb/(lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter
help: Per-Interface Bytes/Errors/Discards Counters
@ -53,6 +58,26 @@ metrics:
"/Sysdb/environment/power/status",
},
Metrics: []*MetricDef{
{
Path: "/Sysdb/environment/cooling/status/fan/name",
re: regexp.MustCompile(
"/Sysdb/environment/cooling/status/fan/name"),
Name: "fanName",
Help: "Fan Name",
ValueLabel: "name",
DefaultValue: 25,
stringMetric: true,
devDesc: map[string]*prometheus.Desc{
"10.1.1.1": prometheus.NewDesc("fanName",
"Fan Name",
[]string{"name"},
prometheus.Labels{"lab1": "val1", "lab2": "val2"}),
},
desc: prometheus.NewDesc("fanName",
"Fan Name",
[]string{"name"},
prometheus.Labels{"lab1": "val3", "lab2": "val4"}),
},
{
Path: "/Sysdb/(lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter",
re: regexp.MustCompile(
@ -126,8 +151,9 @@ metrics:
prometheus.Labels{"lab1": "val3", "lab2": "val4"}),
},
{
Path: "/Sysdb/environment/cooling/fan/speed/value",
re: regexp.MustCompile("/Sysdb/environment/cooling/fan/speed/value"),
Path: "/Sysdb/environment/cooling/fan/speed/value",
re: regexp.MustCompile(
"/Sysdb/environment/cooling/fan/speed/value"),
Name: "fanSpeed",
Help: "Fan Speed",
devDesc: map[string]*prometheus.Desc{},
@ -180,7 +206,8 @@ metrics:
},
{
Path: "/Sysdb/environment/cooling/fan/speed/value",
re: regexp.MustCompile("/Sysdb/environment/cooling/fan/speed/value"),
re: regexp.MustCompile(
"/Sysdb/environment/cooling/fan/speed/value"),
Name: "fanSpeed",
Help: "Fan Speed",
devDesc: map[string]*prometheus.Desc{
@ -222,8 +249,9 @@ metrics:
[]string{"intf"}, prometheus.Labels{}),
},
{
Path: "/Sysdb/environment/cooling/fan/speed/value",
re: regexp.MustCompile("/Sysdb/environment/cooling/fan/speed/value"),
Path: "/Sysdb/environment/cooling/fan/speed/value",
re: regexp.MustCompile(
"/Sysdb/environment/cooling/fan/speed/value"),
Name: "fanSpeed",
Help: "Fan Speed",
devDesc: map[string]*prometheus.Desc{},
@ -247,7 +275,7 @@ metrics:
}
}
func TestGetDescAndLabels(t *testing.T) {
func TestGetMetricValues(t *testing.T) {
config := []byte(`
devicelabels:
10.1.1.1:
@ -317,12 +345,16 @@ metrics:
}
for i, c := range tCases {
desc, labels := cfg.getDescAndLabels(c.src)
if !test.DeepEqual(desc, c.desc) {
t.Errorf("Test case %d: desc mismatch %v", i+1, test.Diff(desc, c.desc))
metric := cfg.getMetricValues(c.src)
if metric == nil {
// Avoids error from trying to access metric.desc when metric is nil
metric = &metricValues{}
}
if !test.DeepEqual(labels, c.labels) {
t.Errorf("Test case %d: labels mismatch %v", i+1, test.Diff(labels, c.labels))
if !test.DeepEqual(metric.desc, c.desc) {
t.Errorf("Test case %d: desc mismatch %v", i+1, test.Diff(metric.desc, c.desc))
}
if !test.DeepEqual(metric.labels, c.labels) {
t.Errorf("Test case %d: labels mismatch %v", i+1, test.Diff(metric.labels, c.labels))
}
}
}

View File

@ -6,25 +6,39 @@
package main
import (
"context"
"flag"
"io/ioutil"
"net/http"
"sync"
"github.com/aristanetworks/goarista/openconfig/client"
"strings"
"github.com/aristanetworks/glog"
"github.com/aristanetworks/goarista/gnmi"
pb "github.com/openconfig/gnmi/proto/gnmi"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func main() {
// gNMI options
gNMIcfg := &gnmi.Config{}
flag.StringVar(&gNMIcfg.Addr, "addr", "localhost", "gNMI gRPC server `address`")
flag.StringVar(&gNMIcfg.CAFile, "cafile", "", "Path to server TLS certificate file")
flag.StringVar(&gNMIcfg.CertFile, "certfile", "", "Path to client TLS certificate file")
flag.StringVar(&gNMIcfg.KeyFile, "keyfile", "", "Path to client TLS private key file")
flag.StringVar(&gNMIcfg.Username, "username", "", "Username to authenticate with")
flag.StringVar(&gNMIcfg.Password, "password", "", "Password to authenticate with")
flag.BoolVar(&gNMIcfg.TLS, "tls", false, "Enable TLS")
subscribePaths := flag.String("subscribe", "/", "Comma-separated list of paths to subscribe to")
// program options
listenaddr := flag.String("listenaddr", ":8080", "Address on which to expose the metrics")
url := flag.String("url", "/metrics", "URL where to expose the metrics")
configFlag := flag.String("config", "",
"Config to turn OpenConfig telemetry into Prometheus metrics")
username, password, subscriptions, addrs, opts := client.ParseFlags()
flag.Parse()
subscriptions := strings.Split(*subscribePaths, ",")
if *configFlag == "" {
glog.Fatal("You need specify a config file using -config flag")
}
@ -39,7 +53,7 @@ func main() {
// Ignore the default "subscribe-to-everything" subscription of the
// -subscribe flag.
if subscriptions[0] == "" {
if subscriptions[0] == "/" {
subscriptions = subscriptions[1:]
}
// Add the subscriptions from the config file.
@ -47,14 +61,33 @@ func main() {
coll := newCollector(config)
prometheus.MustRegister(coll)
wg := new(sync.WaitGroup)
for _, addr := range addrs {
wg.Add(1)
c := client.New(username, password, addr, opts)
go c.Subscribe(wg, subscriptions, coll.update)
ctx := gnmi.NewContext(context.Background(), gNMIcfg)
client, err := gnmi.Dial(gNMIcfg)
if err != nil {
glog.Fatal(err)
}
respChan := make(chan *pb.SubscribeResponse)
errChan := make(chan error)
subscribeOptions := &gnmi.SubscribeOptions{
Mode: "stream",
StreamMode: "target_defined",
Paths: gnmi.SplitPaths(subscriptions),
}
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
go handleSubscription(respChan, errChan, coll, gNMIcfg.Addr)
http.Handle(*url, promhttp.Handler())
glog.Fatal(http.ListenAndServe(*listenaddr, nil))
}
func handleSubscription(respChan chan *pb.SubscribeResponse,
errChan chan error, coll *collector, addr string) {
for {
select {
case resp := <-respChan:
coll.update(addr, resp)
case err := <-errChan:
glog.Fatal(err)
}
}
}

View File

@ -0,0 +1,80 @@
# Per-device labels. Optional
# Exactly the same set of labels must be specified for each device.
# If device address is *, the labels apply to all devices not listed explicitly.
# If any explicit device if listed below, then you need to specify all devices you're subscribed to,
# or have a wildcard entry. Otherwise, updates from non-listed devices will be ignored.
#deviceLabels:
# 10.1.1.1:
# lab1: val1
# lab2: val2
# '*':
# lab1: val3
# lab2: val4
# Subscriptions to OpenConfig paths.
subscriptions:
- /Smash/counters/ethIntf
- /Smash/interface/counter/lag/current/counter
- /Sysdb/environment/archer/cooling/status
- /Sysdb/environment/archer/power/status
- /Sysdb/environment/archer/temperature/status
- /Sysdb/hardware/archer/xcvr/status
- /Sysdb/interface/config/eth
# Prometheus metrics configuration.
# If you use named capture groups in the path, they will be extracted into labels with the same name.
# All fields are mandatory.
metrics:
- name: interfaceDescription
path: /Sysdb/interface/config/eth/phy/slice/1/intfConfig/(?P<interface>Ethernet.)/description
help: Description
valuelabel: description
defaultvalue: 15
- name: intfCounter
path: /Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(Octets|Errors|Discards)
help: Per-Interface Bytes/Errors/Discards Counters
- name: intfLagCounter
path: /Smash/interface/counter/lag/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(Octets|Errors|Discards)
help: Per-Lag Bytes/Errors/Discards Counters
- name: intfPktCounter
path: /Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(?P<type>(?:Ucast|Multicast|Broadcast))(Pkt)
help: Per-Interface Unicast/Multicast/Broadcast Packer Counters
- name: intfLagPktCounter
path: /Smash/interface/counter/lag/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(?P<type>(?:Ucast|Multicast|Broadcast))(Pkt)
help: Per-Lag Unicast/Multicast/Broadcast Packer Counters
- name: intfPfcClassCounter
path: /Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/ethStatistics/(?P<direction>(?:in|out))(PfcClassFrames)
help: Per-Interface Input/Output PFC Frames Counters
- name: tempSensor
path: /Sysdb/(environment)/archer/temperature/status/(?P<sensor>.+)/((?:maxT|t)emperature)
help: Temperature and Maximum Temperature
- name: tempSensorAlert
path: /Sysdb/(environment)/archer/temperature/status/(?P<sensor>.+)/(alertRaisedCount)
help: Temperature Alerts Counter
- name: currentSensor
path: /Sysdb/(environment)/archer/power/status/currentSensor/(?P<sensor>.+)/(current)
help: Current Levels
- name: powerSensor
path: /Sysdb/(environment)/archer/(power)/status/powerSupply/(?P<sensor>.+)/(input|output)Power
help: Input/Output Power Levels
- name: voltageSensor
path: /Sysdb/(environment)/archer/power/status/voltageSensor/(?:cell/.+|system)/(?P<sensor>.+)/(voltage)
help: Voltage Levels
- name: railCurrentSensor
path: /Sysdb/(environment)/archer/power/status/voltageSensor/(?:cell/.+|system)/(?P<sensor>.+)/(current)
help: Rail Current Levels
- name: fanSpeed
path: /Sysdb/(environment)/archer/(cooling)/status/(?P<fan>.+)/speed
help: Fan Speed
- name: qsfpModularRxPower
path: /Sysdb/hardware/archer/(xcvr)/status/slice/(?P<linecard>.+)/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)
help: qsfpModularRxPower
- name: qsfpFixedRxPower
path: /Sysdb/hardware/archer/(xcvr)/status/all/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)
help: qsfpFixedRxPower
- name: sfpModularTemperature
path: /Sysdb/hardware/archer/(xcvr)/status/slice/(?P<linecard>.+)/(?P<intf>.+)/lastDomUpdateTime/(temperature)
help: sfpModularTemperature
- name: sfpFixedTemperature
path: /Sysdb/hardware/archer/(xcvr)/status/all/(?P<intf>.+)/lastDomUpdateTime/(temperature)
help: sfpFixedTemperature

View File

@ -18,11 +18,18 @@ subscriptions:
- /Sysdb/environment/temperature/status
- /Sysdb/interface/counter/eth/lag
- /Sysdb/interface/counter/eth/slice/phy
- /Sysdb/interface/config
- /Sysdb/interface/config/eth/phy/slice/1/intfConfig
# Prometheus metrics configuration.
# If you use named capture groups in the path, they will be extracted into labels with the same name.
# All fields are mandatory.
metrics:
- name: interfaceDescription
path: Sysdb/interface/config/eth/phy/slice/1/intfConfig/(?P<interface>Ethernet.)/description
help: Description
valuelabel: description
defaultvalue: 15
- name: intfCounter
path: /Sysdb/interface/counter/eth/(?:lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter/current/statistics/(?P<direction>(?:in|out))(?P<type>(Octets|Errors|Discards))
help: Per-Interface Bytes/Errors/Discards Counters

View File

@ -17,5 +17,5 @@ See the `-help` output, but here's an example to push all the temperature
sensors into Redis. You can also not pass any `-subscribe` flag to push
_everything_ into Redis.
```
ocredis -subscribe /Sysdb/environment/temperature -addrs <switch-hostname>:6042 -redis <redis-hostname>:6379
ocredis -subscribe /Sysdb/environment/temperature -addr <switch-hostname>:6042 -redis <redis-hostname>:6379
```

View File

@ -8,16 +8,15 @@
package main
import (
"context"
"encoding/json"
"flag"
"strings"
"sync"
occlient "github.com/aristanetworks/goarista/openconfig/client"
"github.com/aristanetworks/goarista/gnmi"
"github.com/aristanetworks/glog"
"github.com/golang/protobuf/proto"
"github.com/openconfig/reference/rpc/openconfig"
pb "github.com/openconfig/gnmi/proto/gnmi"
redis "gopkg.in/redis.v4"
)
@ -42,11 +41,23 @@ type baseClient interface {
var client baseClient
func main() {
username, password, subscriptions, hostAddrs, opts := occlient.ParseFlags()
// gNMI options
cfg := &gnmi.Config{}
flag.StringVar(&cfg.Addr, "addr", "localhost", "gNMI gRPC server `address`")
flag.StringVar(&cfg.CAFile, "cafile", "", "Path to server TLS certificate file")
flag.StringVar(&cfg.CertFile, "certfile", "", "Path to client TLS certificate file")
flag.StringVar(&cfg.KeyFile, "keyfile", "", "Path to client TLS private key file")
flag.StringVar(&cfg.Username, "username", "", "Username to authenticate with")
flag.StringVar(&cfg.Password, "password", "", "Password to authenticate with")
flag.BoolVar(&cfg.TLS, "tls", false, "Enable TLS")
subscribePaths := flag.String("subscribe", "/", "Comma-separated list of paths to subscribe to")
flag.Parse()
if *redisFlag == "" {
glog.Fatal("Specify the address of the Redis server to write to with -redis")
}
subscriptions := strings.Split(*subscribePaths, ",")
redisAddrs := strings.Split(*redisFlag, ",")
if !*clusterMode && len(redisAddrs) > 1 {
glog.Fatal("Please pass only 1 redis address in noncluster mode or enable cluster mode")
@ -72,25 +83,27 @@ func main() {
if err != nil {
glog.Fatal("Failed to connect to client: ", err)
}
ocPublish := func(addr string, message proto.Message) {
resp, ok := message.(*openconfig.SubscribeResponse)
if !ok {
glog.Errorf("Unexpected type of message: %T", message)
return
}
if notif := resp.GetUpdate(); notif != nil {
bufferToRedis(addr, notif)
ctx := gnmi.NewContext(context.Background(), cfg)
client, err := gnmi.Dial(cfg)
if err != nil {
glog.Fatal(err)
}
respChan := make(chan *pb.SubscribeResponse)
errChan := make(chan error)
subscribeOptions := &gnmi.SubscribeOptions{
Mode: "stream",
StreamMode: "target_defined",
Paths: gnmi.SplitPaths(subscriptions),
}
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
for {
select {
case resp := <-respChan:
bufferToRedis(cfg.Addr, resp.GetUpdate())
case err := <-errChan:
glog.Fatal(err)
}
}
wg := new(sync.WaitGroup)
for _, hostAddr := range hostAddrs {
wg.Add(1)
c := occlient.New(username, password, hostAddr, opts)
go c.Subscribe(wg, subscriptions, ocPublish)
}
wg.Wait()
}
type redisData struct {
@ -100,7 +113,12 @@ type redisData struct {
pub map[string]interface{}
}
func bufferToRedis(addr string, notif *openconfig.Notification) {
func bufferToRedis(addr string, notif *pb.Notification) {
if notif == nil {
// possible that this should be ignored silently
glog.Error("Nil notification ignored")
return
}
path := addr + "/" + joinPath(notif.Prefix)
data := &redisData{key: path}
@ -167,20 +185,21 @@ func redisPublish(path, kind string, payload interface{}) {
}
}
func joinPath(path *openconfig.Path) string {
func joinPath(path *pb.Path) string {
// path.Elem is empty for some reason so using path.Element instead
return strings.Join(path.Element, "/")
}
func convertUpdate(update *openconfig.Update) interface{} {
func convertUpdate(update *pb.Update) interface{} {
switch update.Value.Type {
case openconfig.Type_JSON:
case pb.Encoding_JSON:
var value interface{}
err := json.Unmarshal(update.Value.Value, &value)
if err != nil {
glog.Fatalf("Malformed JSON update %q in %s", update.Value.Value, update)
}
return value
case openconfig.Type_BYTES:
case pb.Encoding_BYTES:
return update.Value.Value
default:
glog.Fatalf("Unhandled type of value %v in %s", update.Value.Type, update)

View File

@ -14,9 +14,10 @@ import (
"strings"
"time"
"github.com/aristanetworks/glog"
"github.com/aristanetworks/goarista/gnmi"
"github.com/aristanetworks/splunk-hec-go"
"github.com/fuyufjh/splunk-hec-go"
pb "github.com/openconfig/gnmi/proto/gnmi"
)
@ -49,7 +50,10 @@ func main() {
ctx := gnmi.NewContext(context.Background(), cfg)
// Store the address without the port so it can be used as the host in the Splunk event.
addr := cfg.Addr
client := gnmi.Dial(cfg)
client, err := gnmi.Dial(cfg)
if err != nil {
glog.Fatal(err)
}
// Splunk connection
urls := strings.Split(*splunkURLs, ",")
@ -67,10 +71,14 @@ func main() {
// gNMI subscription
respChan := make(chan *pb.SubscribeResponse)
errChan := make(chan error)
defer close(respChan)
defer close(errChan)
paths := strings.Split(*subscribePaths, ",")
go gnmi.Subscribe(ctx, client, gnmi.SplitPaths(paths), respChan, errChan)
subscribeOptions := &gnmi.SubscribeOptions{
Mode: "stream",
StreamMode: "target_defined",
Paths: gnmi.SplitPaths(paths),
}
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
// Forward subscribe responses to Splunk
for {

View File

@ -6,28 +6,51 @@ dropped.
This tool requires a config file to specify how to map the path of the
notifications coming out of the OpenConfig gRPC interface onto OpenTSDB
metric names, and how to extract tags from the path. For example, the
following rule, excerpt from `sampleconfig.json`:
metric names, and how to extract tags from the path.
## Getting Started
To begin, a list of subscriptions is required (excerpt from `sampleconfig.json`):
```json
"metrics": {
"tempSensor": {
"path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/((?:maxT|t)emperature)/value"
},
...
"subscriptions": [
"/Sysdb/interface/counter/eth/lag",
"/Sysdb/interface/counter/eth/slice/phy",
"/Sysdb/environment/temperature/status",
"/Sysdb/environment/cooling/status",
"/Sysdb/environment/power/status",
"/Sysdb/hardware/xcvr/status/all/xcvrStatus"
],
...
```
Applied to an update for the path
`/Sysdb/environment/temperature/status/tempSensor/TempSensor1/temperature/value`
will lead to the metric name `environment.temperature` and tags `sensor=TempSensor1`.
Note that subscriptions should not end with a trailing `/` as that will cause
the subscription to fail.
Basically, un-named groups are used to make up the metric name, and named
groups are used to extract (optional) tags.
Afterwards, the metrics are defined (excerpt from `sampleconfig.json`):
```json
"metrics": {
"tempSensor": {
"path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/((?:maxT|t)emperature)"
},
...
}
```
In the metrics path, unnamed matched groups are used to make up the metric name, and named matched groups
are used to extract optional tags. Note that unnamed groups are required, otherwise the metric
name will be empty and the update will be silently dropped.
For example, using the above metrics path applied to an update for the path
`/Sysdb/environment/temperature/status/tempSensor/TempSensor1/temperature`
will lead to the metric name `environment.temperature` and tags `sensor=TempSensor1`.
## Usage
See the `-help` output, but here's an example to push all the metrics defined
in the sample config file:
```
octsdb -addrs <switch-hostname>:6042 -config sampleconfig.json -text | nc <tsd-hostname> 4242
octsdb -addr <switch-hostname>:6042 -config sampleconfig.json -text | nc <tsd-hostname> 4242
```

View File

@ -16,6 +16,9 @@ func TestConfig(t *testing.T) {
t.Fatal("Managed to load a nonexistent config!")
}
cfg, err = loadConfig("sampleconfig.json")
if err != nil {
t.Fatal("Failed to load config:", err)
}
testcases := []struct {
path string

View File

@ -7,22 +7,35 @@ package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/aristanetworks/goarista/openconfig/client"
"github.com/aristanetworks/goarista/gnmi"
"github.com/aristanetworks/glog"
"github.com/golang/protobuf/proto"
"github.com/openconfig/reference/rpc/openconfig"
pb "github.com/openconfig/gnmi/proto/gnmi"
)
func main() {
// gNMI options
cfg := &gnmi.Config{}
flag.StringVar(&cfg.Addr, "addr", "localhost", "gNMI gRPC server `address`")
flag.StringVar(&cfg.CAFile, "cafile", "", "Path to server TLS certificate file")
flag.StringVar(&cfg.CertFile, "certfile", "", "Path to client TLS certificate file")
flag.StringVar(&cfg.KeyFile, "keyfile", "", "Path to client TLS private key file")
flag.StringVar(&cfg.Username, "username", "", "Username to authenticate with")
flag.StringVar(&cfg.Password, "password", "", "Password to authenticate with")
flag.BoolVar(&cfg.TLS, "tls", false, "Enable TLS")
// Program options
subscribePaths := flag.String("paths", "/", "Comma-separated list of paths to subscribe to")
tsdbFlag := flag.String("tsdb", "",
"Address of the OpenTSDB server where to push telemetry to")
textFlag := flag.Bool("text", false,
@ -38,8 +51,8 @@ func main() {
" Clients and servers should have the same number.")
udpTimeoutFlag := flag.Duration("udptimeout", 2*time.Second,
"Timeout for each")
username, password, subscriptions, addrs, opts := client.ParseFlags()
flag.Parse()
if !(*tsdbFlag != "" || *textFlag || *udpAddrFlag != "") {
glog.Fatal("Specify the address of the OpenTSDB server to write to with -tsdb")
} else if *configFlag == "" {
@ -52,6 +65,7 @@ func main() {
}
// Ignore the default "subscribe-to-everything" subscription of the
// -subscribe flag.
subscriptions := strings.Split(*subscribePaths, ",")
if subscriptions[0] == "" {
subscriptions = subscriptions[1:]
}
@ -79,33 +93,37 @@ func main() {
// TODO: support HTTP(S).
c = newTelnetClient(*tsdbFlag)
}
wg := new(sync.WaitGroup)
for _, addr := range addrs {
wg.Add(1)
publish := func(addr string, message proto.Message) {
resp, ok := message.(*openconfig.SubscribeResponse)
if !ok {
glog.Errorf("Unexpected type of message: %T", message)
return
}
if notif := resp.GetUpdate(); notif != nil {
pushToOpenTSDB(addr, c, config, notif)
}
}
c := client.New(username, password, addr, opts)
go c.Subscribe(wg, subscriptions, publish)
ctx := gnmi.NewContext(context.Background(), cfg)
client, err := gnmi.Dial(cfg)
if err != nil {
glog.Fatal(err)
}
respChan := make(chan *pb.SubscribeResponse)
errChan := make(chan error)
subscribeOptions := &gnmi.SubscribeOptions{
Mode: "stream",
StreamMode: "target_defined",
Paths: gnmi.SplitPaths(subscriptions),
}
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
for {
select {
case resp := <-respChan:
pushToOpenTSDB(cfg.Addr, c, config, resp.GetUpdate())
case err := <-errChan:
glog.Fatal(err)
}
}
wg.Wait()
}
func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config,
notif *openconfig.Notification) {
func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config, notif *pb.Notification) {
if notif == nil {
glog.Error("Nil notification ignored")
return
}
if notif.Timestamp <= 0 {
glog.Fatalf("Invalid timestamp %d in %s", notif.Timestamp, notif)
}
host := addr[:strings.IndexRune(addr, ':')]
if host == "localhost" {
// TODO: On Linux this reads /proc/sys/kernel/hostname each time,
@ -118,18 +136,15 @@ func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config,
}
}
prefix := "/" + strings.Join(notif.Prefix.Element, "/")
for _, update := range notif.Update {
if update.Value == nil || update.Value.Type != openconfig.Type_JSON {
if update.Value == nil || update.Value.Type != pb.Encoding_JSON {
glog.V(9).Infof("Ignoring incompatible update value in %s", update)
continue
}
value := parseValue(update)
if value == nil {
continue
}
path := prefix + "/" + strings.Join(update.Path.Element, "/")
metricName, tags := config.Match(path)
if metricName == "" {
@ -137,7 +152,6 @@ func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config,
continue
}
tags["host"] = host
for i, v := range value {
if len(value) > 1 {
tags["index"] = strconv.Itoa(i)
@ -158,7 +172,7 @@ func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config,
// parseValue returns either an integer/floating point value of the given update, or if
// the value is a slice of integers/floating point values. If the value is neither of these
// or if any element in the slice is non numerical, parseValue returns nil.
func parseValue(update *openconfig.Update) []interface{} {
func parseValue(update *pb.Update) []interface{} {
var value interface{}
decoder := json.NewDecoder(bytes.NewReader(update.Value.Value))
@ -196,7 +210,7 @@ func parseValue(update *openconfig.Update) []interface{} {
}
// Convert our json.Number to either an int64, uint64, or float64.
func parseNumber(num json.Number, update *openconfig.Update) interface{} {
func parseNumber(num json.Number, update *pb.Update) interface{} {
var value interface{}
var err error
if value, err = num.Int64(); err != nil {

View File

@ -9,8 +9,7 @@ import (
"testing"
"github.com/aristanetworks/goarista/test"
"github.com/openconfig/reference/rpc/openconfig"
pb "github.com/openconfig/gnmi/proto/gnmi"
)
func TestParseValue(t *testing.T) { // Because parsing JSON sucks.
@ -35,8 +34,8 @@ func TestParseValue(t *testing.T) { // Because parsing JSON sucks.
},
}
for i, tcase := range testcases {
actual := parseValue(&openconfig.Update{
Value: &openconfig.Value{
actual := parseValue(&pb.Update{
Value: &pb.Value{
Value: []byte(tcase.input),
},
})

View File

@ -1,11 +1,14 @@
{
"comment": "This is a sample configuration",
"comment": "This is a sample configuration for EOS versions below 4.20",
"subscriptions": [
"/Sysdb/interface/counter/eth/lag",
"/Sysdb/interface/counter/eth/slice/phy",
"/Sysdb/environment/temperature/status",
"/Sysdb/environment/cooling/status",
"/Sysdb/environment/power/status",
"/Sysdb/environment/temperature/status",
"/Sysdb/interface/counter/eth/lag",
"/Sysdb/interface/counter/eth/slice/phy"
"/Sysdb/hardware/xcvr/status/all/xcvrStatus"
],
"metricPrefix": "eos",
"metrics": {
@ -20,25 +23,32 @@
},
"tempSensor": {
"path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/((?:maxT|t)emperature)/value"
"path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/((?:maxT|t)emperature)"
},
"tempSensorAlert": {
"path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/(alertRaisedCount)"
},
"currentSensor": {
"path": "/Sysdb/(environment)/power/status/currentSensor/(?P<sensor>.+)/(current)/value"
"path": "/Sysdb/(environment)/power/status/currentSensor/(?P<sensor>.+)/(current)"
},
"powerSensor": {
"path": "/Sysdb/(environment/power)/status/powerSupply/(?P<sensor>.+)/(input|output)Power/value"
"path": "/Sysdb/(environment/power)/status/powerSupply/(?P<sensor>.+)/(input|output)Power"
},
"voltageSensor": {
"path": "/Sysdb/(environment)/power/status/voltageSensor/(?P<sensor>.+)/(voltage)/value"
"path": "/Sysdb/(environment)/power/status/voltageSensor/(?P<sensor>.+)/(voltage)"
},
"railCurrentSensor": {
"path": "/Sysdb/(environment)/power/status/voltageSensor/(?P<sensor>.+)/(current)/value"
"path": "/Sysdb/(environment)/power/status/voltageSensor/(?P<sensor>.+)/(current)"
},
"fanSpeed": {
"path": "/Sysdb/(environment)/cooling/status/(fan)/(?P<fan>.+)/(speed)/value"
"path": "/Sysdb/(environment)/cooling/status/(fan)/(?P<fan>.+)/(speed)"
},
"qsfpRxPower": {
"path": "/Sysdb/hardware/(xcvr)/status/all/xcvrStatus/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)"
},
"sfpTemperature": {
"path": "/Sysdb/hardware/(xcvr)/status/all/xcvrStatus/(?P<intf>.+)/lastDomUpdateTime/(temperature)"
}
}
}

View File

@ -0,0 +1,66 @@
{
"comment": "This is a sample configuration for EOS versions above 4.20",
"subscriptions": [
"/Smash/counters/ethIntf",
"/Smash/interface/counter/lag/current/counter",
"/Sysdb/environment/archer/cooling/status",
"/Sysdb/environment/archer/power/status",
"/Sysdb/environment/archer/temperature/status",
"/Sysdb/hardware/archer/xcvr/status"
],
"metricPrefix": "eos",
"metrics": {
"intfCounter": {
"path": "/Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(Octets|Errors|Discards)"
},
"intfLagCounter": {
"path": "/Smash/interface/counter/lag/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(Octets|Errors|Discards)"
},
"intfPktCounter": {
"path": "/Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(?P<type>(?:Ucast|Multicast|Broadcast))(Pkt)"
},
"intfLagPktCounter": {
"path": "/Smash/interface/counter/lag/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(?P<type>(?:Ucast|Multicast|Broadcast))(Pkt)"
},
"intfPfcClassCounter": {
"path": "/Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/ethStatistics/(?P<direction>(?:in|out))(PfcClassFrames)"
},
"tempSensor": {
"path": "/Sysdb/(environment)/archer/temperature/status/(?P<sensor>.+)/((?:maxT|t)emperature)"
},
"tempSensorAlert": {
"path": "/Sysdb/(environment)/archer/temperature/status/(?P<sensor>.+)/(alertRaisedCount)"
},
"currentSensor": {
"path": "/Sysdb/(environment)/archer/power/status/currentSensor/(?P<sensor>.+)/(current)"
},
"powerSensor": {
"path": "/Sysdb/(environment)/archer/(power)/status/powerSupply/(?P<sensor>.+)/(input|output)Power"
},
"voltageSensor": {
"path": "/Sysdb/(environment)/archer/power/status/voltageSensor/(?:cell/.+|system)/(?P<sensor>.+)/(voltage)"
},
"railCurrentSensor": {
"path": "/Sysdb/(environment)/archer/power/status/voltageSensor/(?:cell/.+|system)/(?P<sensor>.+)/(current)"
},
"fanSpeed": {
"path": "/Sysdb/(environment)/archer/(cooling)/status/(?P<fan>.+)/speed"
},
"qsfpModularRxPower": {
"path": "/Sysdb/hardware/archer/(xcvr)/status/slice/(?P<linecard>.+)/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)"
},
"qsfpFixedRxPower": {
"path": "/Sysdb/hardware/archer/(xcvr)/status/all/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)"
},
"sfpModularTemperature": {
"path": "/Sysdb/hardware/archer/(xcvr)/status/slice/(?P<linecard>.+)/(?P<intf>.+)/lastDomUpdateTime/(temperature)"
},
"sfpFixedTemperature": {
"path": "/Sysdb/hardware/archer/(xcvr)/status/all/(?P<intf>.+)/lastDomUpdateTime/(temperature)"
}
}
}

View File

@ -0,0 +1,286 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
// test2influxdb writes results from 'go test -json' to an influxdb
// database.
//
// Example usage:
//
// go test -json | test2influxdb [options...]
//
// Points are written to influxdb with tags:
//
// package
// type "package" for a package result; "test" for a test result
// Additional tags set by -tags flag
//
// And fields:
//
// test string // "NONE" for whole package results
// elapsed float64 // in seconds
// pass float64 // 1 for PASS, 0 for FAIL
// Additional fields set by -fields flag
//
// "test" is a field instead of a tag to reduce cardinality of data.
//
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"github.com/aristanetworks/glog"
client "github.com/influxdata/influxdb/client/v2"
)
// tag is a single key=value pair supplied via the -tags flag.
type tag struct {
	key   string
	value string
}

// tags implements flag.Value for a comma-separated list of key=value pairs.
type tags []tag

// String renders the list back into the "k=v,k=v" form accepted by Set.
func (ts *tags) String() string {
	var b strings.Builder
	for i, tg := range *ts {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(tg.key)
		b.WriteByte('=')
		b.WriteString(tg.value)
	}
	return b.String()
}

// Set parses a comma-separated list of key=value pairs and appends them.
// Each item must contain exactly one '='; keys and values are trimmed of
// surrounding whitespace and must be non-empty.
func (ts *tags) Set(s string) error {
	for _, item := range strings.Split(s, ",") {
		parts := strings.Split(item, "=")
		if len(parts) != 2 {
			return fmt.Errorf("invalid tag, expecting one '=': %q", item)
		}
		k := strings.TrimSpace(parts[0])
		if k == "" {
			return fmt.Errorf("invalid tag key %q in %q", k, item)
		}
		v := strings.TrimSpace(parts[1])
		if v == "" {
			return fmt.Errorf("invalid tag value %q in %q", v, item)
		}
		*ts = append(*ts, tag{key: k, value: v})
	}
	return nil
}
// field is a single key=value pair supplied via the -fields flag; the
// value is a bool, float64, int64 or string depending on how it parsed.
type field struct {
	key   string
	value interface{}
}

// fields implements flag.Value for a comma-separated list of typed
// key=value pairs.
type fields []field

// String renders the list back into the "k=v,k=v" form accepted by Set,
// suffixing int64 values with "i" (influxdb line-protocol style).
func (fs *fields) String() string {
	var b strings.Builder
	for i, f := range *fs {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(f.key)
		b.WriteByte('=')
		switch v := f.value.(type) {
		case bool:
			b.WriteString(strconv.FormatBool(v))
		case float64:
			b.WriteString(strconv.FormatFloat(v, 'f', -1, 64))
		case int64:
			b.WriteString(strconv.FormatInt(v, 10))
			b.WriteByte('i')
		case string:
			b.WriteString(v)
		}
	}
	return b.String()
}

// Set parses a comma-separated list of key=value pairs. Values are tried
// as bool first, then float64, then "i"-suffixed int64, and fall back to
// plain strings. Keys and values are trimmed and must be non-empty.
func (fs *fields) Set(s string) error {
	for _, item := range strings.Split(s, ",") {
		parts := strings.Split(item, "=")
		if len(parts) != 2 {
			return fmt.Errorf("invalid field, expecting one '=': %q", item)
		}
		key := strings.TrimSpace(parts[0])
		if key == "" {
			return fmt.Errorf("invalid field key %q in %q", key, item)
		}
		val := strings.TrimSpace(parts[1])
		if val == "" {
			return fmt.Errorf("invalid field value %q in %q", val, item)
		}
		var value interface{}
		if b, err := strconv.ParseBool(val); err == nil {
			// It's a bool (note: "1"/"0" also parse as bools)
			value = b
		} else if f, err := strconv.ParseFloat(val, 64); err == nil {
			// It's a float64
			value = f
		} else if n, err := strconv.ParseInt(val[:len(val)-1], 0, 64); err == nil &&
			val[len(val)-1] == 'i' {
			// ints are suffixed with an "i"
			value = n
		} else {
			value = val
		}
		*fs = append(*fs, field{key: key, value: value})
	}
	return nil
}
// Command-line options. flagTags and flagFields are registered with
// flag.Var in init below.
var (
	// Fixed typo in the help text: "adddress" -> "address".
	flagAddr        = flag.String("addr", "http://localhost:8086", "address of influxdb database")
	flagDB          = flag.String("db", "gotest", "use `database` in influxdb")
	flagMeasurement = flag.String("m", "result", "`measurement` used in influxdb database")
	flagTags        tags
	flagFields      fields
)
// init registers the repeatable -tags and -fields options; they need
// flag.Var because they accumulate into custom flag.Value types.
func init() {
	flag.Var(&flagTags, "tags", "set additional `tags`. Ex: name=alice,food=pasta")
	flag.Var(&flagFields, "fields", "set additional `fields`. Ex: id=1234i,long=34.123,lat=72.234")
}
// main reads 'go test -json' events from stdin, converts them into
// influxdb points, and writes them to the configured database in a
// single batch. Any failure is fatal.
func main() {
	flag.Parse()
	c, err := client.NewHTTPClient(client.HTTPConfig{
		Addr: *flagAddr,
	})
	if err != nil {
		glog.Fatal(err)
	}
	// All points are accumulated into one batch and written at the end.
	batch, err := client.NewBatchPoints(client.BatchPointsConfig{Database: *flagDB})
	if err != nil {
		glog.Fatal(err)
	}
	if err := parseTestOutput(os.Stdin, batch); err != nil {
		glog.Fatal(err)
	}
	if err := c.Write(batch); err != nil {
		glog.Fatal(err)
	}
}
// See https://golang.org/cmd/test2json/ for a description of 'go test
// -json' output
type testEvent struct {
	Time    time.Time // encodes as an RFC3339-format string
	Action  string    // e.g. "run", "output", "pass", "fail", "skip"
	Package string
	Test    string // empty for whole-package events
	Elapsed float64 // seconds
	Output  string
}
// createTags builds the influxdb tag set for one test event: the
// user-supplied -tags plus "package" and "type" ("test" for a single
// test result, "package" for a whole-package result).
func createTags(e *testEvent) map[string]string {
	m := make(map[string]string, len(flagTags)+2)
	for _, extra := range flagTags {
		m[extra.key] = extra.value
	}
	m["package"] = e.Package
	// An event with no test name is a whole-package result.
	if e.Test == "" {
		m["type"] = "package"
	} else {
		m["type"] = "test"
	}
	return m
}
// createFields builds the influxdb field set for one test event: the
// user-supplied -fields plus "pass", "elapsed" and (for single tests)
// "test".
func createFields(e *testEvent) map[string]interface{} {
	m := make(map[string]interface{}, len(flagFields)+3)
	for _, extra := range flagFields {
		m[extra.key] = extra.value
	}
	// Use a float64 instead of a bool to be able to SUM test
	// successes in influxdb.
	pass := float64(0)
	if e.Action == "pass" {
		pass = 1
	}
	m["pass"] = pass
	m["elapsed"] = e.Elapsed
	if e.Test != "" {
		m["test"] = e.Test
	}
	return m
}
// parseTestOutput decodes a stream of 'go test -json' events from r and
// appends one influxdb point per "pass"/"fail" event to batch. Packages
// that produced test results but never reported their own pass/fail
// (e.g. because a test panicked) are synthesized as package failures.
func parseTestOutput(r io.Reader, batch client.BatchPoints) error {
	// pkgs holds packages seen in r. Unfortunately, if a test panics,
	// then there is no "fail" result from a package. To detect these
	// kind of failures, keep track of all the packages that never had
	// a "pass" or "fail".
	//
	// The last seen timestamp is stored with the package, so that
	// package result measurement written to influxdb can be later
	// than any test result for that package.
	pkgs := make(map[string]time.Time)
	d := json.NewDecoder(r)
	for {
		e := &testEvent{}
		if err := d.Decode(e); err != nil {
			if err != io.EOF {
				return err
			}
			break
		}
		// Only terminal results become points; "run"/"output"/"skip"
		// events are ignored.
		switch e.Action {
		case "pass", "fail":
		default:
			continue
		}
		if e.Test == "" {
			// A package has completed.
			delete(pkgs, e.Package)
		} else {
			pkgs[e.Package] = e.Time
		}
		point, err := client.NewPoint(
			*flagMeasurement,
			createTags(e),
			createFields(e),
			e.Time,
		)
		if err != nil {
			return err
		}
		batch.AddPoint(point)
	}
	// Anything left in pkgs produced test results but no package-level
	// pass/fail: record those packages as failures.
	for pkg, t := range pkgs {
		pkgFail := &testEvent{
			Action:  "fail",
			Package: pkg,
		}
		point, err := client.NewPoint(
			*flagMeasurement,
			createTags(pkgFail),
			createFields(pkgFail),
			// Fake a timestamp that is later than anything that
			// occurred for this package
			t.Add(time.Millisecond),
		)
		if err != nil {
			return err
		}
		batch.AddPoint(point)
	}
	return nil
}

View File

@ -0,0 +1,151 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package main
import (
"os"
"testing"
"time"
"github.com/aristanetworks/goarista/test"
"github.com/influxdata/influxdb/client/v2"
)
// newPoint builds a client.Point from its parts, parsing timeString as
// RFC3339Nano. Any parse or construction error fails the calling test
// immediately.
func newPoint(t *testing.T, measurement string, tags map[string]string,
	fields map[string]interface{}, timeString string) *client.Point {
	t.Helper()
	timestamp, err := time.Parse(time.RFC3339Nano, timeString)
	if err != nil {
		t.Fatal(err)
	}
	p, err := client.NewPoint(measurement, tags, fields, timestamp)
	if err != nil {
		t.Fatal(err)
	}
	return p
}
// TestParseTestOutput feeds the recorded 'go test -json' stream in
// testdata/output.txt through parseTestOutput and compares the resulting
// batch against hand-built expected points (including the synthesized
// failure for the panicked package, timestamped 1ms after its last event).
func TestParseTestOutput(t *testing.T) {
	// Verify tags and fields set by flags are set in records
	flagTags.Set("tag=foo")
	flagFields.Set("field=true")
	defer func() {
		// Reset package-level flag state so other tests are unaffected.
		flagTags = nil
		flagFields = nil
	}()
	// NOTE(review): f is never closed; harmless in a test but an
	// explicit defer f.Close() would be cleaner.
	f, err := os.Open("testdata/output.txt")
	if err != nil {
		t.Fatal(err)
	}
	// makeTags/makeFields mirror createTags/createFields for the
	// flag values set above.
	makeTags := func(pkg, resultType string) map[string]string {
		return map[string]string{"package": pkg, "type": resultType, "tag": "foo"}
	}
	makeFields := func(pass, elapsed float64, test string) map[string]interface{} {
		m := map[string]interface{}{"pass": pass, "elapsed": elapsed, "field": true}
		if test != "" {
			m["test"] = test
		}
		return m
	}
	// Timestamps below correspond to the fixture's event times.
	expected := []*client.Point{
		newPoint(t,
			"result",
			makeTags("pkg/passed", "test"),
			makeFields(1, 0, "TestPass"),
			"2018-03-08T10:33:12.344165231-08:00",
		),
		newPoint(t,
			"result",
			makeTags("pkg/passed", "package"),
			makeFields(1, 0.013, ""),
			"2018-03-08T10:33:12.34533033-08:00",
		),
		newPoint(t,
			"result",
			makeTags("pkg/panic", "test"),
			makeFields(0, 600.029, "TestPanic"),
			"2018-03-08T10:33:20.272440286-08:00",
		),
		newPoint(t,
			"result",
			makeTags("pkg/failed", "test"),
			makeFields(0, 0.18, "TestFail"),
			"2018-03-08T10:33:27.158860934-08:00",
		),
		newPoint(t,
			"result",
			makeTags("pkg/failed", "package"),
			makeFields(0, 0.204, ""),
			"2018-03-08T10:33:27.161302093-08:00",
		),
		// pkg/panic never reported a package result, so parseTestOutput
		// synthesizes a failure 1ms after the package's last event.
		newPoint(t,
			"result",
			makeTags("pkg/panic", "package"),
			makeFields(0, 0, ""),
			"2018-03-08T10:33:20.273440286-08:00",
		),
	}
	batch, err := client.NewBatchPoints(client.BatchPointsConfig{})
	if err != nil {
		t.Fatal(err)
	}
	if err := parseTestOutput(f, batch); err != nil {
		t.Fatal(err)
	}
	if diff := test.Diff(expected, batch.Points()); diff != "" {
		t.Errorf("unexpected diff: %s", diff)
	}
}
// TestTagsFlag checks that tags.Set parses key=value lists and that
// String round-trips the parsed result back to the input.
func TestTagsFlag(t *testing.T) {
	testcases := map[string]tags{
		"abc=def":         tags{tag{key: "abc", value: "def"}},
		"abc=def,ghi=klm": tags{tag{key: "abc", value: "def"}, tag{key: "ghi", value: "klm"}},
	}
	for input, want := range testcases {
		t.Run(input, func(t *testing.T) {
			var got tags
			got.Set(input)
			if diff := test.Diff(want, got); diff != "" {
				t.Errorf("unexpected diff from Set: %s", diff)
			}
			if s := got.String(); s != input {
				t.Errorf("unexpected diff from String: %q vs. %q", input, s)
			}
		})
	}
}
// TestFieldsFlag checks that fields.Set infers bool/float64/int64/string
// values and that String round-trips the parsed result back to the input.
func TestFieldsFlag(t *testing.T) {
	testcases := map[string]fields{
		"str=abc":        fields{field{key: "str", value: "abc"}},
		"bool=true":      fields{field{key: "bool", value: true}},
		"bool=false":     fields{field{key: "bool", value: false}},
		"float64=42":     fields{field{key: "float64", value: float64(42)}},
		"float64=42.123": fields{field{key: "float64", value: float64(42.123)}},
		"int64=42i":      fields{field{key: "int64", value: int64(42)}},
		"str=abc,bool=true,float64=42,int64=42i": fields{
			field{key: "str", value: "abc"},
			field{key: "bool", value: true},
			field{key: "float64", value: float64(42)},
			field{key: "int64", value: int64(42)},
		},
	}
	for input, want := range testcases {
		t.Run(input, func(t *testing.T) {
			var got fields
			got.Set(input)
			if diff := test.Diff(want, got); diff != "" {
				t.Errorf("unexpected diff from Set: %s", diff)
			}
			if s := got.String(); s != input {
				t.Errorf("unexpected diff from String: %q vs. %q", input, s)
			}
		})
	}
}

View File

@ -0,0 +1,15 @@
{"Time":"2018-03-08T10:33:12.002692769-08:00","Action":"output","Package":"pkg/skipped","Output":"? \tpkg/skipped\t[no test files]\n"}
{"Time":"2018-03-08T10:33:12.003199228-08:00","Action":"skip","Package":"pkg/skipped","Elapsed":0.001}
{"Time":"2018-03-08T10:33:12.343866281-08:00","Action":"run","Package":"pkg/passed","Test":"TestPass"}
{"Time":"2018-03-08T10:33:12.34406622-08:00","Action":"output","Package":"pkg/passed","Test":"TestPass","Output":"=== RUN TestPass\n"}
{"Time":"2018-03-08T10:33:12.344139342-08:00","Action":"output","Package":"pkg/passed","Test":"TestPass","Output":"--- PASS: TestPass (0.00s)\n"}
{"Time":"2018-03-08T10:33:12.344165231-08:00","Action":"pass","Package":"pkg/passed","Test":"TestPass","Elapsed":0}
{"Time":"2018-03-08T10:33:12.344297059-08:00","Action":"output","Package":"pkg/passed","Output":"PASS\n"}
{"Time":"2018-03-08T10:33:12.345217622-08:00","Action":"output","Package":"pkg/passed","Output":"ok \tpkg/passed\t0.013s\n"}
{"Time":"2018-03-08T10:33:12.34533033-08:00","Action":"pass","Package":"pkg/passed","Elapsed":0.013}
{"Time":"2018-03-08T10:33:20.27231537-08:00","Action":"output","Package":"pkg/panic","Test":"TestPanic","Output":"panic\n"}
{"Time":"2018-03-08T10:33:20.272414481-08:00","Action":"output","Package":"pkg/panic","Test":"TestPanic","Output":"FAIL\tpkg/panic\t600.029s\n"}
{"Time":"2018-03-08T10:33:20.272440286-08:00","Action":"fail","Package":"pkg/panic","Test":"TestPanic","Elapsed":600.029}
{"Time":"2018-03-08T10:33:27.158776469-08:00","Action":"output","Package":"pkg/failed","Test":"TestFail","Output":"--- FAIL: TestFail (0.18s)\n"}
{"Time":"2018-03-08T10:33:27.158860934-08:00","Action":"fail","Package":"pkg/failed","Test":"TestFail","Elapsed":0.18}
{"Time":"2018-03-08T10:33:27.161302093-08:00","Action":"fail","Package":"pkg/failed","Elapsed":0.204}

View File

@ -8,7 +8,6 @@ package dscp
import (
"fmt"
"net"
"reflect"
"time"
)
@ -20,8 +19,7 @@ func DialTCPWithTOS(laddr, raddr *net.TCPAddr, tos byte) (*net.TCPConn, error) {
if err != nil {
return nil, err
}
value := reflect.ValueOf(conn)
if err = setTOS(raddr.IP, value, tos); err != nil {
if err = setTOS(raddr.IP, conn, tos); err != nil {
conn.Close()
return nil, err
}
@ -54,7 +52,24 @@ func DialTimeoutWithTOS(network, address string, timeout time.Duration, tos byte
conn.Close()
return nil, fmt.Errorf("DialTimeoutWithTOS: cannot set TOS on a %s socket", network)
}
if err = setTOS(ip, reflect.ValueOf(conn), tos); err != nil {
if err = setTOS(ip, conn, tos); err != nil {
conn.Close()
return nil, err
}
return conn, err
}
// DialTCPTimeoutWithTOS is same as DialTimeoutWithTOS except for enforcing "tcp" and
// providing an option to specify local address (source)
func DialTCPTimeoutWithTOS(laddr, raddr *net.TCPAddr, tos byte, timeout time.Duration) (net.Conn,
error) {
d := net.Dialer{Timeout: timeout, LocalAddr: laddr}
conn, err := d.Dial("tcp", raddr.String())
if err != nil {
return nil, err
}
if err = setTOS(raddr.IP, conn, tos); err != nil {
conn.Close()
return nil, err
}

View File

@ -5,8 +5,11 @@
package dscp_test
import (
"fmt"
"net"
"strings"
"testing"
"time"
"github.com/aristanetworks/goarista/dscp"
)
@ -51,3 +54,50 @@ func TestDialTCPWithTOS(t *testing.T) {
conn.Write(buf)
<-done
}
// TestDialTCPTimeoutWithTOS dials a local listener via
// dscp.DialTCPTimeoutWithTOS and verifies the connection originates from
// the requested local address (with and without an explicit source port).
//
// Fixes two defects in the accept goroutine: t.Fatal must not be called
// from a goroutine other than the test's own (testing package contract),
// and a failed Accept previously left `done` unclosed, deadlocking the
// test on <-done until the suite timeout.
func TestDialTCPTimeoutWithTOS(t *testing.T) {
	raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
	for name, td := range map[string]*net.TCPAddr{
		"ipNoPort": &net.TCPAddr{
			IP: net.ParseIP("127.0.0.42"), Port: 0,
		},
		"ipWithPort": &net.TCPAddr{
			IP: net.ParseIP("127.0.0.42"), Port: 10001,
		},
	} {
		t.Run(name, func(t *testing.T) {
			l, err := net.ListenTCP("tcp", raddr)
			if err != nil {
				t.Fatal(err)
			}
			defer l.Close()
			var srcAddr net.Addr
			done := make(chan struct{})
			go func() {
				// Always unblock the main test goroutine, even on error.
				defer close(done)
				conn, err := l.Accept()
				if err != nil {
					// t.Fatal would call runtime.Goexit from the wrong
					// goroutine; t.Error marks the test failed safely.
					t.Error(err)
					return
				}
				defer conn.Close()
				srcAddr = conn.RemoteAddr()
			}()
			conn, err := dscp.DialTCPTimeoutWithTOS(td, l.Addr().(*net.TCPAddr), 40, 5*time.Second)
			if err != nil {
				t.Fatal("Connection failed:", err)
			}
			defer conn.Close()
			pfx := td.IP.String() + ":"
			if td.Port > 0 {
				pfx = fmt.Sprintf("%s%d", pfx, td.Port)
			}
			<-done
			if srcAddr == nil {
				// Accept failed; the goroutine already reported the error.
				t.Fatal("no connection was accepted")
			}
			if !strings.HasPrefix(srcAddr.String(), pfx) {
				t.Fatalf("DialTCPTimeoutWithTOS wrong address: %q instead of %q", srcAddr, pfx)
			}
		})
	}
}

View File

@ -7,7 +7,6 @@ package dscp
import (
"net"
"reflect"
)
// ListenTCPWithTOS is similar to net.ListenTCP but with the socket configured
@ -18,8 +17,7 @@ func ListenTCPWithTOS(address *net.TCPAddr, tos byte) (*net.TCPListener, error)
if err != nil {
return nil, err
}
value := reflect.ValueOf(lsnr)
if err = setTOS(address.IP, value, tos); err != nil {
if err = setTOS(address.IP, lsnr, tos); err != nil {
lsnr.Close()
return nil, err
}

View File

@ -5,19 +5,18 @@
package dscp
import (
"fmt"
"net"
"os"
"reflect"
"syscall"
"golang.org/x/sys/unix"
)
// This works for the UNIX implementation of netFD, i.e. not on Windows and Plan9.
// This kludge is needed until https://github.com/golang/go/issues/9661 is fixed.
// value can be the reflection of a connection or a dialer.
func setTOS(ip net.IP, value reflect.Value, tos byte) error {
netFD := value.Elem().FieldByName("fd").Elem()
fd := int(netFD.FieldByName("pfd").FieldByName("Sysfd").Int())
// conn must either implement syscall.Conn or be a TCPListener.
func setTOS(ip net.IP, conn interface{}, tos byte) error {
var proto, optname int
if ip.To4() != nil {
proto = unix.IPPROTO_IP
@ -26,8 +25,42 @@ func setTOS(ip net.IP, value reflect.Value, tos byte) error {
proto = unix.IPPROTO_IPV6
optname = unix.IPV6_TCLASS
}
switch c := conn.(type) {
case syscall.Conn:
return setTOSWithSyscallConn(proto, optname, c, tos)
case *net.TCPListener:
// This code is needed to support go1.9. In go1.10
// *net.TCPListener implements syscall.Conn.
return setTOSWithTCPListener(proto, optname, c, tos)
}
return fmt.Errorf("unsupported connection type: %T", conn)
}
// setTOSWithTCPListener sets the (proto, optname) socket option to tos on
// the socket underlying a *net.TCPListener by reaching into the net
// package's unexported fields with reflection.
func setTOSWithTCPListener(proto, optname int, conn *net.TCPListener, tos byte) error {
	// A kludge for pre-go1.10 to get the fd of a net.TCPListener
	// NOTE(review): depends on the internal layout fd.pfd.Sysfd — verify
	// against the Go versions this package supports.
	value := reflect.ValueOf(conn)
	netFD := value.Elem().FieldByName("fd").Elem()
	fd := int(netFD.FieldByName("pfd").FieldByName("Sysfd").Int())
	if err := unix.SetsockoptInt(fd, proto, optname, int(tos)); err != nil {
		return os.NewSyscallError("setsockopt", err)
	}
	return nil
}
// setTOSWithSyscallConn sets the (proto, optname) socket option to tos on
// a connection implementing syscall.Conn, using the raw Control callback
// so the runtime keeps ownership of the descriptor.
func setTOSWithSyscallConn(proto, optname int, conn syscall.Conn, tos byte) error {
	syscallConn, err := conn.SyscallConn()
	if err != nil {
		return err
	}
	var setsockoptErr error
	err = syscallConn.Control(func(fd uintptr) {
		if err := unix.SetsockoptInt(int(fd), proto, optname, int(tos)); err != nil {
			setsockoptErr = os.NewSyscallError("setsockopt", err)
		}
	})
	// Prefer the setsockopt failure over a Control error: it is the more
	// specific cause when both are set.
	if setsockoptErr != nil {
		return setsockoptErr
	}
	return err
}

View File

@ -8,10 +8,15 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"math"
"net"
"time"
"io/ioutil"
"strings"
"github.com/aristanetworks/glog"
"github.com/aristanetworks/goarista/netns"
pb "github.com/openconfig/gnmi/proto/gnmi"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@ -19,7 +24,7 @@ import (
)
const (
defaultPort = "6042"
defaultPort = "6030"
)
// Config is the gnmi.Client config
@ -33,19 +38,30 @@ type Config struct {
TLS bool
}
// SubscribeOptions is the gNMI subscription request options
type SubscribeOptions struct {
	// UpdatesOnly is forwarded to SubscriptionList.UpdatesOnly.
	UpdatesOnly bool
	// Prefix is split and parsed into the SubscriptionList prefix path.
	Prefix string
	// Mode selects the subscription list mode:
	// "once", "poll" or "stream".
	Mode string
	// StreamMode selects the per-subscription mode:
	// "on_change", "sample" or "target_defined".
	StreamMode string
	// SampleInterval and HeartbeatInterval are copied into each
	// pb.Subscription as-is.
	SampleInterval    uint64
	HeartbeatInterval uint64
	// Paths holds one list of path elements per subscription.
	Paths [][]string
}
// Dial connects to a gnmi service and returns a client
func Dial(cfg *Config) pb.GNMIClient {
func Dial(cfg *Config) (pb.GNMIClient, error) {
var opts []grpc.DialOption
if cfg.TLS || cfg.CAFile != "" || cfg.CertFile != "" {
tlsConfig := &tls.Config{}
if cfg.CAFile != "" {
b, err := ioutil.ReadFile(cfg.CAFile)
if err != nil {
glog.Fatal(err)
return nil, err
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(b) {
glog.Fatalf("credentials: failed to append certificates")
return nil, fmt.Errorf("credentials: failed to append certificates")
}
tlsConfig.RootCAs = cp
} else {
@ -53,11 +69,11 @@ func Dial(cfg *Config) pb.GNMIClient {
}
if cfg.CertFile != "" {
if cfg.KeyFile == "" {
glog.Fatalf("Please provide both -certfile and -keyfile")
return nil, fmt.Errorf("please provide both -certfile and -keyfile")
}
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
if err != nil {
glog.Fatal(err)
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
@ -69,12 +85,33 @@ func Dial(cfg *Config) pb.GNMIClient {
if !strings.ContainsRune(cfg.Addr, ':') {
cfg.Addr += ":" + defaultPort
}
conn, err := grpc.Dial(cfg.Addr, opts...)
if err != nil {
glog.Fatalf("Failed to dial: %s", err)
dial := func(addrIn string, time time.Duration) (net.Conn, error) {
var conn net.Conn
nsName, addr, err := netns.ParseAddress(addrIn)
if err != nil {
return nil, err
}
err = netns.Do(nsName, func() error {
var err error
conn, err = net.Dial("tcp", addr)
return err
})
return conn, err
}
return pb.NewGNMIClient(conn)
opts = append(opts,
grpc.WithDialer(dial),
// Allows received protobuf messages to be larger than 4MB
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
)
grpcconn, err := grpc.Dial(cfg.Addr, opts...)
if err != nil {
return nil, fmt.Errorf("failed to dial: %s", err)
}
return pb.NewGNMIClient(grpcconn), nil
}
// NewContext returns a new context with username and password
@ -104,17 +141,53 @@ func NewGetRequest(paths [][]string) (*pb.GetRequest, error) {
}
// NewSubscribeRequest returns a SubscribeRequest for the given paths
func NewSubscribeRequest(paths [][]string) (*pb.SubscribeRequest, error) {
subList := &pb.SubscriptionList{
Subscription: make([]*pb.Subscription, len(paths)),
func NewSubscribeRequest(subscribeOptions *SubscribeOptions) (*pb.SubscribeRequest, error) {
var mode pb.SubscriptionList_Mode
switch subscribeOptions.Mode {
case "once":
mode = pb.SubscriptionList_ONCE
case "poll":
mode = pb.SubscriptionList_POLL
case "stream":
mode = pb.SubscriptionList_STREAM
default:
return nil, fmt.Errorf("subscribe mode (%s) invalid", subscribeOptions.Mode)
}
for i, p := range paths {
var streamMode pb.SubscriptionMode
switch subscribeOptions.StreamMode {
case "on_change":
streamMode = pb.SubscriptionMode_ON_CHANGE
case "sample":
streamMode = pb.SubscriptionMode_SAMPLE
case "target_defined":
streamMode = pb.SubscriptionMode_TARGET_DEFINED
default:
return nil, fmt.Errorf("subscribe stream mode (%s) invalid", subscribeOptions.StreamMode)
}
prefixPath, err := ParseGNMIElements(SplitPath(subscribeOptions.Prefix))
if err != nil {
return nil, err
}
subList := &pb.SubscriptionList{
Subscription: make([]*pb.Subscription, len(subscribeOptions.Paths)),
Mode: mode,
UpdatesOnly: subscribeOptions.UpdatesOnly,
Prefix: prefixPath,
}
for i, p := range subscribeOptions.Paths {
gnmiPath, err := ParseGNMIElements(p)
if err != nil {
return nil, err
}
subList.Subscription[i] = &pb.Subscription{Path: gnmiPath}
subList.Subscription[i] = &pb.Subscription{
Path: gnmiPath,
Mode: streamMode,
SampleInterval: subscribeOptions.SampleInterval,
HeartbeatInterval: subscribeOptions.HeartbeatInterval,
}
}
return &pb.SubscribeRequest{
Request: &pb.SubscribeRequest_Subscribe{Subscribe: subList}}, nil
return &pb.SubscribeRequest{Request: &pb.SubscribeRequest_Subscribe{
Subscribe: subList}}, nil
}

View File

@ -17,7 +17,7 @@ func NotificationToMap(notif *gnmi.Notification) (map[string]interface{}, error)
updates := make(map[string]interface{}, len(notif.Update))
var err error
for _, update := range notif.Update {
updates[StrPath(update.Path)] = strUpdateVal(update)
updates[StrPath(update.Path)] = StrUpdateVal(update)
if err != nil {
return nil, err
}

View File

@ -5,14 +5,22 @@
package gnmi
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"strconv"
"strings"
"time"
"github.com/aristanetworks/glog"
pb "github.com/openconfig/gnmi/proto/gnmi"
"google.golang.org/grpc/codes"
)
@ -28,9 +36,10 @@ func Get(ctx context.Context, client pb.GNMIClient, paths [][]string) error {
return err
}
for _, notif := range resp.Notification {
prefix := StrPath(notif.Prefix)
for _, update := range notif.Update {
fmt.Printf("%s:\n", StrPath(update.Path))
fmt.Println(strUpdateVal(update))
fmt.Printf("%s:\n", path.Join(prefix, StrPath(update.Path)))
fmt.Println(StrUpdateVal(update))
}
}
return nil
@ -55,49 +64,104 @@ func Capabilities(ctx context.Context, client pb.GNMIClient) error {
// val may be a path to a file or it may be json. First see if it is a
// file, if so return its contents, otherwise return val
func extractJSON(val string) []byte {
jsonBytes, err := ioutil.ReadFile(val)
if err != nil {
jsonBytes = []byte(val)
if jsonBytes, err := ioutil.ReadFile(val); err == nil {
return jsonBytes
}
return jsonBytes
// Best effort check if the value might a string literal, in which
// case wrap it in quotes. This is to allow a user to do:
// gnmi update ../hostname host1234
// gnmi update ../description 'This is a description'
// instead of forcing them to quote the string:
// gnmi update ../hostname '"host1234"'
// gnmi update ../description '"This is a description"'
maybeUnquotedStringLiteral := func(s string) bool {
if s == "true" || s == "false" || s == "null" || // JSON reserved words
strings.ContainsAny(s, `"'{}[]`) { // Already quoted or is a JSON object or array
return false
} else if _, err := strconv.ParseInt(s, 0, 32); err == nil {
// Integer. Using byte size of 32 because larger integer
// types are supposed to be sent as strings in JSON.
return false
} else if _, err := strconv.ParseFloat(s, 64); err == nil {
// Float
return false
}
return true
}
if maybeUnquotedStringLiteral(val) {
out := make([]byte, len(val)+2)
out[0] = '"'
copy(out[1:], val)
out[len(out)-1] = '"'
return out
}
return []byte(val)
}
// strUpdateVal will return a string representing the value within the supplied update
func strUpdateVal(u *pb.Update) string {
// StrUpdateVal will return a string representing the value within the supplied update
func StrUpdateVal(u *pb.Update) string {
if u.Value != nil {
return string(u.Value.Value) // Backwards compatibility with pre-v0.4 gnmi
// Backwards compatibility with pre-v0.4 gnmi
switch u.Value.Type {
case pb.Encoding_JSON, pb.Encoding_JSON_IETF:
return strJSON(u.Value.Value)
case pb.Encoding_BYTES, pb.Encoding_PROTO:
return base64.StdEncoding.EncodeToString(u.Value.Value)
case pb.Encoding_ASCII:
return string(u.Value.Value)
default:
return string(u.Value.Value)
}
}
return strVal(u.Val)
return StrVal(u.Val)
}
// strVal will return a string representing the supplied value
func strVal(val *pb.TypedValue) string {
// StrVal will return a string representing the supplied value
func StrVal(val *pb.TypedValue) string {
switch v := val.GetValue().(type) {
case *pb.TypedValue_StringVal:
return v.StringVal
case *pb.TypedValue_JsonIetfVal:
return string(v.JsonIetfVal)
return strJSON(v.JsonIetfVal)
case *pb.TypedValue_JsonVal:
return strJSON(v.JsonVal)
case *pb.TypedValue_IntVal:
return fmt.Sprintf("%v", v.IntVal)
return strconv.FormatInt(v.IntVal, 10)
case *pb.TypedValue_UintVal:
return fmt.Sprintf("%v", v.UintVal)
return strconv.FormatUint(v.UintVal, 10)
case *pb.TypedValue_BoolVal:
return fmt.Sprintf("%v", v.BoolVal)
return strconv.FormatBool(v.BoolVal)
case *pb.TypedValue_BytesVal:
return string(v.BytesVal)
return base64.StdEncoding.EncodeToString(v.BytesVal)
case *pb.TypedValue_DecimalVal:
return strDecimal64(v.DecimalVal)
case *pb.TypedValue_FloatVal:
return strconv.FormatFloat(float64(v.FloatVal), 'g', -1, 32)
case *pb.TypedValue_LeaflistVal:
return strLeaflist(v.LeaflistVal)
case *pb.TypedValue_AsciiVal:
return v.AsciiVal
case *pb.TypedValue_AnyVal:
return v.AnyVal.String()
default:
panic(v)
}
}
func strJSON(inJSON []byte) string {
var out bytes.Buffer
err := json.Indent(&out, inJSON, "", " ")
if err != nil {
return fmt.Sprintf("(error unmarshalling json: %s)\n", err) + string(inJSON)
}
return out.String()
}
func strDecimal64(d *pb.Decimal64) string {
var i, frac uint64
var i, frac int64
if d.Precision > 0 {
div := uint64(10)
div := int64(10)
it := d.Precision - 1
for it > 0 {
div *= 10
@ -108,32 +172,25 @@ func strDecimal64(d *pb.Decimal64) string {
} else {
i = d.Digits
}
if frac < 0 {
frac = -frac
}
return fmt.Sprintf("%d.%d", i, frac)
}
// strLeafList builds a human-readable form of a leaf-list. e.g. [1,2,3] or [a,b,c]
// strLeafList builds a human-readable form of a leaf-list. e.g. [1, 2, 3] or [a, b, c]
func strLeaflist(v *pb.ScalarArray) string {
s := make([]string, 0, len(v.Element))
sz := 2 // []
var buf bytes.Buffer
buf.WriteByte('[')
// convert arbitrary TypedValues to string form
for _, elm := range v.Element {
str := strVal(elm)
s = append(s, str)
sz += len(str) + 1 // %v + ,
}
b := make([]byte, sz)
buf := bytes.NewBuffer(b)
buf.WriteRune('[')
for i := range v.Element {
buf.WriteString(s[i])
for i, elm := range v.Element {
buf.WriteString(StrVal(elm))
if i < len(v.Element)-1 {
buf.WriteRune(',')
buf.WriteString(", ")
}
}
buf.WriteRune(']')
buf.WriteByte(']')
return buf.String()
}
@ -143,9 +200,16 @@ func update(p *pb.Path, val string) *pb.Update {
case "":
v = &pb.TypedValue{
Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: extractJSON(val)}}
case "cli":
case "cli", "test-regen-cli":
v = &pb.TypedValue{
Value: &pb.TypedValue_AsciiVal{AsciiVal: val}}
case "p4_config":
b, err := ioutil.ReadFile(val)
if err != nil {
glog.Fatalf("Cannot read p4 file: %s", err)
}
v = &pb.TypedValue{
Value: &pb.TypedValue_ProtoBytes{ProtoBytes: b}}
default:
panic(fmt.Errorf("unexpected origin: %q", p.Origin))
}
@ -155,9 +219,10 @@ func update(p *pb.Path, val string) *pb.Update {
// Operation describes an gNMI operation.
type Operation struct {
Type string
Path []string
Val string
Type string
Origin string
Path []string
Val string
}
func newSetRequest(setOps []*Operation) (*pb.SetRequest, error) {
@ -167,6 +232,7 @@ func newSetRequest(setOps []*Operation) (*pb.SetRequest, error) {
if err != nil {
return nil, err
}
p.Origin = op.Origin
switch op.Type {
case "delete":
@ -199,16 +265,18 @@ func Set(ctx context.Context, client pb.GNMIClient, setOps []*Operation) error {
}
// Subscribe sends a SubscribeRequest to the given client.
func Subscribe(ctx context.Context, client pb.GNMIClient, paths [][]string,
func Subscribe(ctx context.Context, client pb.GNMIClient, subscribeOptions *SubscribeOptions,
respChan chan<- *pb.SubscribeResponse, errChan chan<- error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
defer close(respChan)
stream, err := client.Subscribe(ctx)
if err != nil {
errChan <- err
return
}
req, err := NewSubscribeRequest(paths)
req, err := NewSubscribeRequest(subscribeOptions)
if err != nil {
errChan <- err
return
@ -228,6 +296,26 @@ func Subscribe(ctx context.Context, client pb.GNMIClient, paths [][]string,
return
}
respChan <- resp
// For POLL subscriptions, initiate a poll request by pressing ENTER
if subscribeOptions.Mode == "poll" {
switch resp.Response.(type) {
case *pb.SubscribeResponse_SyncResponse:
fmt.Print("Press ENTER to send a poll request: ")
reader := bufio.NewReader(os.Stdin)
reader.ReadString('\n')
pollReq := &pb.SubscribeRequest{
Request: &pb.SubscribeRequest_Poll{
Poll: &pb.Poll{},
},
}
if err := stream.Send(pollReq); err != nil {
errChan <- err
return
}
}
}
}
}
@ -241,10 +329,16 @@ func LogSubscribeResponse(response *pb.SubscribeResponse) error {
return errors.New("initial sync failed")
}
case *pb.SubscribeResponse_Update:
t := time.Unix(0, resp.Update.Timestamp).UTC()
prefix := StrPath(resp.Update.Prefix)
for _, update := range resp.Update.Update {
fmt.Printf("%s = %s\n", path.Join(prefix, StrPath(update.Path)),
strUpdateVal(update))
fmt.Printf("[%s] %s = %s\n", t.Format(time.RFC3339Nano),
path.Join(prefix, StrPath(update.Path)),
StrUpdateVal(update))
}
for _, del := range resp.Update.Delete {
fmt.Printf("[%s] Deleted %s\n", t.Format(time.RFC3339Nano),
path.Join(prefix, StrPath(del)))
}
}
return nil

View File

@ -5,9 +5,14 @@
package gnmi
import (
"bytes"
"io/ioutil"
"os"
"testing"
"github.com/aristanetworks/goarista/test"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes/any"
pb "github.com/openconfig/gnmi/proto/gnmi"
)
@ -15,24 +20,41 @@ import (
func TestNewSetRequest(t *testing.T) {
pathFoo := &pb.Path{
Element: []string{"foo"},
Elem: []*pb.PathElem{&pb.PathElem{Name: "foo"}},
Elem: []*pb.PathElem{{Name: "foo"}},
}
pathCli := &pb.Path{
Origin: "cli",
}
pathP4 := &pb.Path{
Origin: "p4_config",
}
p4FileContent := "p4_config test"
p4TestFile, err := ioutil.TempFile("", "p4TestFile")
if err != nil {
t.Errorf("cannot create test file for p4_config")
}
p4Filename := p4TestFile.Name()
defer os.Remove(p4Filename)
if _, err := p4TestFile.WriteString(p4FileContent); err != nil {
t.Errorf("cannot write test file for p4_config")
}
p4TestFile.Close()
testCases := map[string]struct {
setOps []*Operation
exp pb.SetRequest
}{
"delete": {
setOps: []*Operation{&Operation{Type: "delete", Path: []string{"foo"}}},
setOps: []*Operation{{Type: "delete", Path: []string{"foo"}}},
exp: pb.SetRequest{Delete: []*pb.Path{pathFoo}},
},
"update": {
setOps: []*Operation{&Operation{Type: "update", Path: []string{"foo"}, Val: "true"}},
setOps: []*Operation{{Type: "update", Path: []string{"foo"}, Val: "true"}},
exp: pb.SetRequest{
Update: []*pb.Update{&pb.Update{
Update: []*pb.Update{{
Path: pathFoo,
Val: &pb.TypedValue{
Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: []byte("true")}},
@ -40,9 +62,9 @@ func TestNewSetRequest(t *testing.T) {
},
},
"replace": {
setOps: []*Operation{&Operation{Type: "replace", Path: []string{"foo"}, Val: "true"}},
setOps: []*Operation{{Type: "replace", Path: []string{"foo"}, Val: "true"}},
exp: pb.SetRequest{
Replace: []*pb.Update{&pb.Update{
Replace: []*pb.Update{{
Path: pathFoo,
Val: &pb.TypedValue{
Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: []byte("true")}},
@ -50,16 +72,27 @@ func TestNewSetRequest(t *testing.T) {
},
},
"cli-replace": {
setOps: []*Operation{&Operation{Type: "replace", Path: []string{"cli"},
setOps: []*Operation{{Type: "replace", Origin: "cli",
Val: "hostname foo\nip routing"}},
exp: pb.SetRequest{
Replace: []*pb.Update{&pb.Update{
Replace: []*pb.Update{{
Path: pathCli,
Val: &pb.TypedValue{
Value: &pb.TypedValue_AsciiVal{AsciiVal: "hostname foo\nip routing"}},
}},
},
},
"p4_config": {
setOps: []*Operation{{Type: "replace", Origin: "p4_config",
Val: p4Filename}},
exp: pb.SetRequest{
Replace: []*pb.Update{{
Path: pathP4,
Val: &pb.TypedValue{
Value: &pb.TypedValue_ProtoBytes{ProtoBytes: []byte(p4FileContent)}},
}},
},
},
}
for name, tc := range testCases {
@ -74,3 +107,183 @@ func TestNewSetRequest(t *testing.T) {
})
}
}
// TestStrUpdateVal exercises StrUpdateVal over both the deprecated
// pre-v0.4 Value field (one case per Encoding, plus an unknown one)
// and every supported TypedValue variant.
func TestStrUpdateVal(t *testing.T) {
	anyBytes, err := proto.Marshal(&pb.ModelData{Name: "foobar"})
	if err != nil {
		t.Fatal(err)
	}
	anyMessage := &any.Any{TypeUrl: "gnmi/ModelData", Value: anyBytes}
	anyString := proto.CompactTextString(anyMessage)

	for name, tc := range map[string]struct {
		update *pb.Update
		exp    string
	}{
		"JSON Value": {
			update: &pb.Update{
				Value: &pb.Value{
					Value: []byte(`{"foo":"bar"}`),
					Type:  pb.Encoding_JSON}},
			exp: `{
  "foo": "bar"
}`,
		},
		"JSON_IETF Value": {
			update: &pb.Update{
				Value: &pb.Value{
					Value: []byte(`{"foo":"bar"}`),
					Type:  pb.Encoding_JSON_IETF}},
			exp: `{
  "foo": "bar"
}`,
		},
		"BYTES Value": {
			update: &pb.Update{
				Value: &pb.Value{
					Value: []byte{0xde, 0xad},
					Type:  pb.Encoding_BYTES}},
			exp: "3q0=",
		},
		"PROTO Value": {
			update: &pb.Update{
				Value: &pb.Value{
					Value: []byte{0xde, 0xad},
					Type:  pb.Encoding_PROTO}},
			exp: "3q0=",
		},
		"ASCII Value": {
			update: &pb.Update{
				Value: &pb.Value{
					Value: []byte("foobar"),
					Type:  pb.Encoding_ASCII}},
			exp: "foobar",
		},
		"INVALID Value": {
			update: &pb.Update{
				Value: &pb.Value{
					Value: []byte("foobar"),
					Type:  pb.Encoding(42)}},
			exp: "foobar",
		},
		"StringVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_StringVal{StringVal: "foobar"}}},
			exp: "foobar",
		},
		"IntVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_IntVal{IntVal: -42}}},
			exp: "-42",
		},
		"UintVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_UintVal{UintVal: 42}}},
			exp: "42",
		},
		"BoolVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_BoolVal{BoolVal: true}}},
			exp: "true",
		},
		"BytesVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_BytesVal{BytesVal: []byte{0xde, 0xad}}}},
			exp: "3q0=",
		},
		"FloatVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_FloatVal{FloatVal: 3.14}}},
			exp: "3.14",
		},
		"DecimalVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_DecimalVal{
					DecimalVal: &pb.Decimal64{Digits: 314, Precision: 2},
				}}},
			exp: "3.14",
		},
		"LeafListVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_LeaflistVal{
					LeaflistVal: &pb.ScalarArray{Element: []*pb.TypedValue{
						// gofmt -s: the element type is implied by the
						// slice literal, so &pb.TypedValue is redundant.
						{Value: &pb.TypedValue_BoolVal{BoolVal: true}},
						{Value: &pb.TypedValue_AsciiVal{AsciiVal: "foobar"}},
					}},
				}}},
			exp: "[true, foobar]",
		},
		"AnyVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_AnyVal{AnyVal: anyMessage}}},
			exp: anyString,
		},
		"JsonVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_JsonVal{JsonVal: []byte(`{"foo":"bar"}`)}}},
			exp: `{
  "foo": "bar"
}`,
		},
		"JsonIetfVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: []byte(`{"foo":"bar"}`)}}},
			exp: `{
  "foo": "bar"
}`,
		},
		"AsciiVal": {
			update: &pb.Update{Val: &pb.TypedValue{
				Value: &pb.TypedValue_AsciiVal{AsciiVal: "foobar"}}},
			exp: "foobar",
		},
	} {
		t.Run(name, func(t *testing.T) {
			got := StrUpdateVal(tc.update)
			if got != tc.exp {
				t.Errorf("Expected: %q Got: %q", tc.exp, got)
			}
		})
	}
}
// TestExtractJSON checks extractJSON's three behaviors: reading a
// file when val names one, passing through values that already look
// like JSON (booleans, null, numbers, objects, arrays), and quoting
// bare string literals.
func TestExtractJSON(t *testing.T) {
	jsonFile, err := ioutil.TempFile("", "extractJSON")
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(jsonFile.Name())
	if _, err := jsonFile.Write([]byte(`"jsonFile"`)); err != nil {
		jsonFile.Close()
		t.Fatal(err)
	}
	if err := jsonFile.Close(); err != nil {
		t.Fatal(err)
	}
	for val, exp := range map[string][]byte{
		jsonFile.Name(): []byte(`"jsonFile"`),
		"foobar":        []byte(`"foobar"`),
		`"foobar"`:      []byte(`"foobar"`),
		"Val: true":     []byte(`"Val: true"`),
		"host42":        []byte(`"host42"`),
		"42":            []byte("42"),
		"-123.43":       []byte("-123.43"),
		"0xFFFF":        []byte("0xFFFF"),
		// Int larger than can fit in 32 bits should be quoted
		"0x8000000000":  []byte(`"0x8000000000"`),
		"-0x8000000000": []byte(`"-0x8000000000"`),
		"true":          []byte("true"),
		"false":         []byte("false"),
		"null":          []byte("null"),
		"{true: 42}":    []byte("{true: 42}"),
		"[]":            []byte("[]"),
	} {
		t.Run(val, func(t *testing.T) {
			got := extractJSON(val)
			if !bytes.Equal(exp, got) {
				t.Errorf("Unexpected diff. Expected: %q Got: %q", exp, got)
			}
		})
	}
}

View File

@ -126,11 +126,6 @@ func writeSafeString(buf *bytes.Buffer, s string, esc rune) {
// ParseGNMIElements builds up a gnmi path, from user-supplied text
func ParseGNMIElements(elms []string) (*pb.Path, error) {
if len(elms) == 1 && elms[0] == "cli" {
return &pb.Path{
Origin: "cli",
}, nil
}
var parsed []*pb.PathElem
for _, e := range elms {
n, keys, err := parseElement(e)

View File

@ -85,19 +85,6 @@ func TestStrPath(t *testing.T) {
}
}
func TestOriginCLIPath(t *testing.T) {
path := "cli"
sElms := SplitPath(path)
pbPath, err := ParseGNMIElements(sElms)
if err != nil {
t.Fatal(err)
}
expected := pb.Path{Origin: "cli"}
if !test.DeepEqual(expected, *pbPath) {
t.Errorf("want %v, got %v", expected, *pbPath)
}
}
func TestStrPathBackwardsCompat(t *testing.T) {
for i, tc := range []struct {
path *pb.Path

View File

@ -0,0 +1,20 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package influxlib
// Connection protocols accepted in InfluxConfig.Protocol.
const (
	HTTP = "HTTP"
	UDP  = "UDP"
)
// InfluxConfig is a configuration struct for influxlib.
type InfluxConfig struct {
	Hostname        string // host to connect to (no scheme or port)
	Port            uint16 // port of the InfluxDB endpoint
	Protocol        string // HTTP or UDP (see the protocol constants)
	Database        string // target database name
	RetentionPolicy string // retention policy to write under; may be empty
}

View File

@ -0,0 +1,37 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
/*
Package: influxlib
Title: Influx DB Library
Authors: ssdaily, manojm321, senkrish, kthommandra
Email: influxdb-dev@arista.com
Description: The main purpose of influxlib is to provide users with a simple
and easy interface through which to connect to influxdb. It removes a lot of
the need to run the same setup and tear-down code to connect to the service.
Example Code:
connection, err := influxlib.Connect(&influxlib.InfluxConfig {
Hostname: conf.Host,
Port: conf.Port,
Protocol: influxlib.UDP,
Database: conf.AlertDB,
})
tags := map[string]string {
"tag1": someStruct.Tag["host"],
"tag2": someStruct.Tag["tag2"],
}
fields := map[string]interface{} {
"field1": someStruct.Somefield,
"field2": someStruct.Somefield2,
}
connection.WritePoint("measurement", tags, fields)
*/
package influxlib

View File

@ -0,0 +1,173 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package influxlib
import (
"errors"
"fmt"
"time"
influxdb "github.com/influxdata/influxdb/client/v2"
)
// Row is defined as a map where the key (string) is the name of the
// column (field name) and the value is left as an interface to
// accept any value.
type Row map[string]interface{}
// InfluxDBConnection is an object that the wrapper uses.
// Holds a client of the type v2.Client and the configuration
type InfluxDBConnection struct {
Client influxdb.Client
Config *InfluxConfig
}
// Point represents a datapoint to be written.
// Measurement:
// The measurement to write to
// Tags:
// A dictionary of tags in the form string=string
// Fields:
// A dictionary of fields(keys) with their associated values
type Point struct {
Measurement string
Tags map[string]string
Fields map[string]interface{}
Timestamp time.Time
}
// Connect takes an InfluxConfig and establishes a connection
// to InfluxDB. It returns an InfluxDBConnection structure.
// The config must be non-nil and specify one of the supported
// protocols (HTTP or UDP); otherwise an error is returned.
func Connect(config *InfluxConfig) (*InfluxDBConnection, error) {
	// Guard against a nil config: dereferencing it below would panic.
	if config == nil {
		return nil, errors.New("influxlib: nil config")
	}
	var con influxdb.Client
	var err error
	switch config.Protocol {
	case HTTP:
		addr := fmt.Sprintf("http://%s:%v", config.Hostname, config.Port)
		con, err = influxdb.NewHTTPClient(influxdb.HTTPConfig{
			Addr:    addr,
			Timeout: 1 * time.Second,
		})
	case UDP:
		addr := fmt.Sprintf("%s:%v", config.Hostname, config.Port)
		con, err = influxdb.NewUDPClient(influxdb.UDPConfig{
			Addr: addr,
		})
	default:
		// Error strings are lowercase per Go convention; include the
		// offending value to aid debugging.
		return nil, fmt.Errorf("influxlib: unsupported protocol %q", config.Protocol)
	}
	if err != nil {
		return nil, err
	}
	return &InfluxDBConnection{Client: con, Config: config}, nil
}
// RecordBatchPoints takes in a slice of influxlib.Point and writes
// them to the database as a single batch. It returns the first error
// encountered while building or writing the batch.
func (conn *InfluxDBConnection) RecordBatchPoints(points []Point) error {
	bp, err := influxdb.NewBatchPoints(influxdb.BatchPointsConfig{
		Database:        conn.Config.Database,
		Precision:       "ns",
		RetentionPolicy: conn.Config.RetentionPolicy,
	})
	if err != nil {
		return err
	}
	// Pre-size the slice: the number of points is known up front.
	influxPoints := make([]*influxdb.Point, 0, len(points))
	for _, p := range points {
		// Default a zero timestamp to "now" so callers may omit it.
		if p.Timestamp.IsZero() {
			p.Timestamp = time.Now()
		}
		point, err := influxdb.NewPoint(p.Measurement, p.Tags, p.Fields,
			p.Timestamp)
		if err != nil {
			return err
		}
		influxPoints = append(influxPoints, point)
	}
	bp.AddPoints(influxPoints)
	return conn.Client.Write(bp)
}
// WritePoint stores a single datapoint to the database, stamped with
// the current time.
// Measurement:
//	The measurement to write to
// Tags:
//	A dictionary of tags in the form string=string
// Fields:
//	A dictionary of fields(keys) with their associated values
func (conn *InfluxDBConnection) WritePoint(measurement string,
	tags map[string]string, fields map[string]interface{}) error {
	p := Point{
		Measurement: measurement,
		Tags:        tags,
		Fields:      fields,
		Timestamp:   time.Now(),
	}
	return conn.RecordPoint(p)
}
// RecordPoint behaves like WritePoint but takes a Point struct
// as its argument instead. It writes the point as a batch of one.
func (conn *InfluxDBConnection) RecordPoint(p Point) error {
	batch := []Point{p}
	return conn.RecordBatchPoints(batch)
}
// Query sends a query to InfluxDB and returns a slice of rows.
// Rows are of type map[string]interface{}, keyed by column (or tag)
// name.
func (conn *InfluxDBConnection) Query(query string) ([]Row, error) {
	q := influxdb.NewQuery(query, conn.Config.Database, "ns")
	response, err := conn.Client.Query(q)
	if err != nil {
		return nil, err
	}
	if respErr := response.Error(); respErr != nil {
		return nil, respErr
	}
	// Flatten every series of every result into a single slice of
	// rows. Queries that use "group by" therefore come back as one
	// list, with each group's tag values merged into its rows —
	// similar to what you would expect from a SQL query.
	var rows []Row
	for _, result := range response.Results {
		for _, series := range result.Series {
			for _, values := range series.Values {
				row := make(Row)
				for columnIdx, value := range values {
					row[series.Columns[columnIdx]] = value
				}
				for tagKey, tagValue := range series.Tags {
					row[tagKey] = tagValue
				}
				rows = append(rows, row)
			}
		}
	}
	return rows, nil
}
// Close closes the connection opened by Connect().
// Any error returned by the underlying client's Close is discarded.
func (conn *InfluxDBConnection) Close() {
	conn.Client.Close()
}

View File

@ -0,0 +1,157 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package influxlib
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// testFields asserts that every field key/value pair appears in the
// line-protocol output. String values are matched in their quoted
// (%q) form, since that is how they are serialized.
func testFields(line string, fields map[string]interface{},
	t *testing.T) {
	for k, v := range fields {
		formatString := "%s=%v"
		if _, ok := v.(string); ok {
			formatString = "%s=%q"
		}
		assert.Contains(t, line, fmt.Sprintf(formatString, k, v),
			fmt.Sprintf(formatString+" expected in %s", k, v, line))
	}
}
// testTags asserts that every tag key=value pair appears in the
// line-protocol output.
func testTags(line string, tags map[string]string,
	t *testing.T) {
	for k, v := range tags {
		assert.Contains(t, line, fmt.Sprintf("%s=%s", k, v),
			fmt.Sprintf("%s=%s expected in %s", k, v, line))
	}
}
// TestBasicWrite writes a single point through the mock connection
// and verifies that the measurement and every tag and field show up
// in the buffered line-protocol output.
func TestBasicWrite(t *testing.T) {
	testConn, _ := NewMockConnection()
	measurement := "TestData"
	tags := map[string]string{
		"tag1": "Happy",
		"tag2": "Valentines",
		"tag3": "Day",
	}
	fields := map[string]interface{}{
		"Data1": 1234,
		"Data2": "apples",
		"Data3": 5.34,
	}
	err := testConn.WritePoint(measurement, tags, fields)
	assert.NoError(t, err)
	line, err := GetTestBuffer(testConn)
	assert.NoError(t, err)
	assert.Contains(t, line, measurement,
		fmt.Sprintf("%s does not appear in %s", measurement, line))
	testTags(line, tags, t)
	testFields(line, fields, t)
}
// TestConnectionToHostFailure verifies that Connect returns an error
// for hostnames that cannot form a valid HTTP address.
func TestConnectionToHostFailure(t *testing.T) {
	assert := assert.New(t)
	var err error
	config := &InfluxConfig{
		Port:     8086,
		Protocol: HTTP,
		Database: "test",
	}
	// Hostname containing spaces is not a valid address.
	config.Hostname = "this is fake.com"
	_, err = Connect(config)
	assert.Error(err)
	// Neither is one with a backslash.
	config.Hostname = "\\-Fake.Url.Com"
	_, err = Connect(config)
	assert.Error(err)
}
// TestWriteFailure checks that WritePoint succeeds against the fake
// client, then surfaces the client's error once failAll is set.
func TestWriteFailure(t *testing.T) {
	con, _ := NewMockConnection()
	measurement := "TestData"
	tags := map[string]string{
		"tag1": "hi",
	}
	data := map[string]interface{}{
		"Data1": "cats",
	}
	err := con.WritePoint(measurement, tags, data)
	assert.NoError(t, err)
	// Flip the fake client into failure mode; the next write must error.
	fc, _ := con.Client.(*fakeClient)
	fc.failAll = true
	err = con.WritePoint(measurement, tags, data)
	assert.Error(t, err)
}
// TestQuery runs a query through the mock connection and checks only
// that no error is returned (the fake client yields an empty response).
func TestQuery(t *testing.T) {
	query := "SELECT * FROM 'system' LIMIT 50;"
	con, _ := NewMockConnection()
	_, err := con.Query(query)
	assert.NoError(t, err)
}
// TestAddAndWriteBatchPoints writes two points as a single batch and
// verifies the measurement plus every tag and field of both points
// appear in the buffered line-protocol output.
func TestAddAndWriteBatchPoints(t *testing.T) {
	testConn, _ := NewMockConnection()
	measurement := "TestData"
	// gofmt -s: the element type Point is implied by the slice
	// literal, so repeating it per element is redundant.
	points := []Point{
		{
			Measurement: measurement,
			Tags: map[string]string{
				"tag1": "Happy",
				"tag2": "Valentines",
				"tag3": "Day",
			},
			Fields: map[string]interface{}{
				"Data1": 1234,
				"Data2": "apples",
				"Data3": 5.34,
			},
			Timestamp: time.Now(),
		},
		{
			Measurement: measurement,
			Tags: map[string]string{
				"tag1": "Happy",
				"tag2": "New",
				"tag3": "Year",
			},
			Fields: map[string]interface{}{
				"Data1": 5678,
				"Data2": "bananas",
				"Data3": 3.14,
			},
			Timestamp: time.Now(),
		},
	}
	err := testConn.RecordBatchPoints(points)
	assert.NoError(t, err)
	line, err := GetTestBuffer(testConn)
	assert.NoError(t, err)
	assert.Contains(t, line, measurement,
		fmt.Sprintf("%s does not appear in %s", measurement, line))
	for _, p := range points {
		testTags(line, p.Tags, t)
		testFields(line, p.Fields, t)
	}
}

View File

@ -0,0 +1,79 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package influxlib
import (
"bytes"
"errors"
"fmt"
"time"
influx "github.com/influxdata/influxdb/client/v2"
)
// This will serve as a fake client object to test off of.
// The idea is to provide a way to test the Influx Wrapper
// without having it connected to the database.
type fakeClient struct {
	writer  bytes.Buffer // accumulates Write's line-protocol output
	failAll bool         // when true, Query and Write return errors
}
// Ping implements influx.Client; the fake always reports success
// with a zero round-trip time and an empty version string.
func (w *fakeClient) Ping(timeout time.Duration) (time.Duration,
	string, error) {
	return 0, "", nil
}
// Query implements influx.Client. It returns an empty response, or
// an error when failAll is set.
func (w *fakeClient) Query(q influx.Query) (*influx.Response, error) {
	if w.failAll {
		// Fixed typo: "quering" -> "querying".
		return nil, errors.New("querying points failed")
	}
	return &influx.Response{Results: nil, Err: ""}, nil
}
// Close implements influx.Client; the fake has nothing to release.
func (w *fakeClient) Close() error {
	return nil
}
// Write implements influx.Client. It renders each point in line
// protocol into an in-memory buffer (resetting it first) so tests
// can inspect what would have been sent; it fails when failAll is set.
func (w *fakeClient) Write(bp influx.BatchPoints) error {
	if w.failAll {
		return errors.New("writing point failed")
	}
	w.writer.Reset()
	for _, p := range bp.Points() {
		// Don't pass the rendered point as a fmt format string
		// (a `go vet` printf finding): any '%' in a tag or field
		// value would be misinterpreted as a verb.
		w.writer.WriteString(p.String())
		w.writer.WriteByte('\n')
	}
	return nil
}
// qString returns everything buffered by Write so far.
func (w *fakeClient) qString() string {
	return w.writer.String()
}
/***************************************************/
// NewMockConnection returns an InfluxDBConnection backed by a "fake"
// client, suitable for offline testing. The error is always nil.
func NewMockConnection() (*InfluxDBConnection, error) {
	cfg := &InfluxConfig{
		Hostname: "localhost",
		Port:     8086,
		Protocol: HTTP,
		Database: "Test",
	}
	conn := &InfluxDBConnection{
		Client: new(fakeClient),
		Config: cfg,
	}
	return conn, nil
}
// GetTestBuffer returns the string that would normally be written to
// InfluxDB. It fails when the connection does not hold a fake client
// (i.e. it is a real connection).
func GetTestBuffer(con *InfluxDBConnection) (string, error) {
	fc, ok := con.Client.(*fakeClient)
	if !ok {
		// Lowercase error string per Go convention; fixed "recieved" typo.
		return "", errors.New("expected a fake client but received a real one")
	}
	return fc.qString(), nil
}

View File

@ -1,26 +0,0 @@
#!/bin/sh
# Copyright (c) 2016 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the COPYING file.
DEFAULT_PORT=6042
set -e
if [ "$#" -lt 1 ]
then
echo "usage: $0 <host> [<gNMI port>]"
exit 1
fi
echo "WARNING: if you're not using EOS-INT, EOS-REV-0-1 or EOS 4.18 or earlier please use -allowed_ips on the server instead."
host=$1
port=$DEFAULT_PORT
if [ "$#" -gt 1 ]
then
port=$2
fi
iptables="bash sudo iptables -A INPUT -p tcp --dport $port -j ACCEPT"
ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no $host "$iptables"
echo "opened TCP port $port on $host"

View File

@ -79,7 +79,7 @@ func (e *elasticsearchMessageEncoder) Encode(message proto.Message) ([]*sarama.P
}
glog.V(9).Infof("kafka: %s", updateJSON)
return []*sarama.ProducerMessage{
&sarama.ProducerMessage{
{
Topic: e.topic,
Key: e.key,
Value: sarama.ByteEncoder(updateJSON),

View File

@ -116,7 +116,7 @@ func (p *producer) produceNotifications(protoMessage proto.Message) error {
case <-p.done:
return nil
case p.kafkaProducer.Input() <- m:
glog.V(9).Infof("Message produced to Kafka: %s", m)
glog.V(9).Infof("Message produced to Kafka: %v", m)
}
}
return nil

View File

@ -5,59 +5,30 @@
package key
import (
"encoding/json"
"fmt"
"reflect"
"unsafe"
"github.com/aristanetworks/goarista/areflect"
)
// composite allows storing a map[string]interface{} as a key in a Go map.
// This is useful when the key isn't a fixed data structure known at compile
// time but rather something generic, like a bag of key-value pairs.
// Go does not allow storing a map inside the key of a map, because maps are
// not comparable or hashable, and keys in maps must be both. This file is
// a hack specific to the 'gc' implementation of Go (which is the one most
// people use when they use Go), to bypass this check, by abusing reflection
// to override how Go compares composite for equality or how it's hashed.
// The values allowed in this map are only the types whitelisted in New() as
// well as map[Key]interface{}.
//
// See also https://github.com/golang/go/issues/283
type composite struct {
// This value must always be set to the sentinel constant above.
sentinel uintptr
m map[string]interface{}
}
func (k composite) Key() interface{} {
return k.m
}
func (k composite) String() string {
return stringify(k.Key())
}
func (k composite) GoString() string {
return fmt.Sprintf("key.New(%#v)", k.Key())
}
func (k composite) MarshalJSON() ([]byte, error) {
return json.Marshal(k.Key())
}
func (k composite) Equal(other interface{}) bool {
o, ok := other.(composite)
return ok && mapStringEqual(k.m, o.m)
}
func hashInterface(v interface{}) uintptr {
switch v := v.(type) {
case map[string]interface{}:
return hashMapString(v)
case map[Key]interface{}:
return hashMapKey(v)
case []interface{}:
return hashSlice(v)
case Pointer:
// This case applies to pointers used
// as values in maps or slices (i.e.
// not wrapped in a key).
return hashSlice(pointerToSlice(v))
case Path:
// This case applies to paths used
// as values in maps or slices (i.e
// not wrapped in a key).
return hashSlice(pathToSlice(v))
default:
return _nilinterhash(v)
}
@ -78,9 +49,9 @@ func hashMapKey(m map[Key]interface{}) uintptr {
for k, v := range m {
// Use addition so that the order of iteration doesn't matter.
switch k := k.(type) {
case keyImpl:
case interfaceKey:
h += _nilinterhash(k.key)
case composite:
case compositeKey:
h += hashMapString(k.m)
}
h += hashInterface(v)
@ -88,28 +59,42 @@ func hashMapKey(m map[Key]interface{}) uintptr {
return h
}
// hashSlice computes a hash for a slice of arbitrary values. The hash
// is seeded with the slice length and accumulates each element's hash
// by addition, so it is cheap and deterministic.
func hashSlice(s []interface{}) uintptr {
	sum := uintptr(31 * (len(s) + 1))
	for i := range s {
		sum += hashInterface(s[i])
	}
	return sum
}
func hash(p unsafe.Pointer, seed uintptr) uintptr {
ck := *(*composite)(p)
ck := *(*compositeKey)(p)
if ck.sentinel != sentinel {
panic("use of unhashable type in a map")
}
return seed ^ hashMapString(ck.m)
if ck.m != nil {
return seed ^ hashMapString(ck.m)
}
return seed ^ hashSlice(ck.s)
}
func equal(a unsafe.Pointer, b unsafe.Pointer) bool {
ca := (*composite)(a)
cb := (*composite)(b)
ca := (*compositeKey)(a)
cb := (*compositeKey)(b)
if ca.sentinel != sentinel {
panic("use of uncomparable type on the lhs of ==")
}
if cb.sentinel != sentinel {
panic("use of uncomparable type on the rhs of ==")
}
return mapStringEqual(ca.m, cb.m)
if ca.m != nil {
return mapStringEqual(ca.m, cb.m)
}
return sliceEqual(ca.s, cb.s)
}
func init() {
typ := reflect.TypeOf(composite{})
typ := reflect.TypeOf(compositeKey{})
alg := reflect.ValueOf(typ).Elem().FieldByName("alg").Elem()
// Pretty certain that doing this voids your warranty.
// This overwrites the typeAlg of either alg_NOEQ64 (on 32-bit platforms)

View File

@ -19,7 +19,7 @@ type unhashable struct {
func TestBadComposite(t *testing.T) {
test.ShouldPanicWith(t, "use of unhashable type in a map", func() {
m := map[interface{}]struct{}{
unhashable{func() {}, 0x42}: struct{}{},
unhashable{func() {}, 0x42}: {},
}
// Use Key here to make sure init() is called.
if _, ok := m[New("foo")]; ok {

View File

@ -16,14 +16,35 @@ import (
// Key represents the Key in the updates and deletes of the Notification
// objects. The only reason this exists is that Go won't let us define
// our own hash function for non-hashable types, and unfortunately we
// need to be able to index maps by map[string]interface{} objects.
// need to be able to index maps by map[string]interface{} objects
// and slices by []interface{} objects.
type Key interface {
Key() interface{}
String() string
Equal(other interface{}) bool
}
type keyImpl struct {
// compositeKey allows storing a map[string]interface{} or []interface{} as a key
// in a Go map. This is useful when the key isn't a fixed data structure known
// at compile time but rather something generic, like a bag of key-value pairs
// or a list of elements. Go does not allow storing a map or slice inside the
// key of a map, because maps and slices are not comparable or hashable, and
// keys in maps and slice elements must be both. This file is a hack specific
// to the 'gc' implementation of Go (which is the one most people use when they
// use Go), to bypass this check, by abusing reflection to override how Go
// compares compositeKey for equality or how it's hashed. The values allowed in
// this map are only the types whitelisted in New() as well as map[Key]interface{}
// and []interface{}.
//
// See also https://github.com/golang/go/issues/283
type compositeKey struct {
// This value must always be set to the sentinel constant above.
sentinel uintptr
m map[string]interface{}
s []interface{}
}
type interfaceKey struct {
key interface{}
}
@ -44,13 +65,43 @@ type float64Key float64
type boolKey bool
type pointerKey compositeKey
type pathKey compositeKey
// pathToSlice flattens a Path into a []interface{} by extracting the
// underlying key value of each element, in order.
func pathToSlice(path Path) []interface{} {
	out := make([]interface{}, 0, len(path))
	for _, element := range path {
		out = append(out, element.Key())
	}
	return out
}
// sliceToPath rebuilds a Path from its flattened []interface{} form by
// wrapping each element back into a Key via New.
func sliceToPath(s []interface{}) Path {
	p := make(Path, 0, len(s))
	for _, intf := range s {
		p = append(p, New(intf))
	}
	return p
}
// pointerToSlice flattens a Pointer into the []interface{} form stored
// inside pointerKey, by flattening its target path.
func pointerToSlice(ptr Pointer) []interface{} {
	return pathToSlice(ptr.Pointer())
}

// sliceToPointer rebuilds a pointer from the flattened []interface{}
// form stored inside pointerKey.
func sliceToPointer(s []interface{}) pointer {
	return pointer(sliceToPath(s))
}
// New wraps the given value in a Key.
// This function panics if the value passed in isn't allowed in a Key or
// doesn't implement value.Value.
func New(intf interface{}) Key {
switch t := intf.(type) {
case map[string]interface{}:
return composite{sentinel, t}
return compositeKey{sentinel: sentinel, m: t}
case []interface{}:
return compositeKey{sentinel: sentinel, s: t}
case string:
return strKey(t)
case int8:
@ -76,31 +127,35 @@ func New(intf interface{}) Key {
case bool:
return boolKey(t)
case value.Value:
return keyImpl{key: intf}
return interfaceKey{key: intf}
case Pointer:
return pointerKey{sentinel: sentinel, s: pointerToSlice(t)}
case Path:
return pathKey{sentinel: sentinel, s: pathToSlice(t)}
default:
panic(fmt.Sprintf("Invalid type for key: %T", intf))
}
}
func (k keyImpl) Key() interface{} {
func (k interfaceKey) Key() interface{} {
return k.key
}
func (k keyImpl) String() string {
func (k interfaceKey) String() string {
return stringify(k.key)
}
func (k keyImpl) GoString() string {
func (k interfaceKey) GoString() string {
return fmt.Sprintf("key.New(%#v)", k.Key())
}
func (k keyImpl) MarshalJSON() ([]byte, error) {
func (k interfaceKey) MarshalJSON() ([]byte, error) {
return json.Marshal(k.Key())
}
func (k keyImpl) Equal(other interface{}) bool {
o, ok := other.(keyImpl)
return ok && keyEqual(k.key, o.key)
func (k interfaceKey) Equal(other interface{}) bool {
o, ok := other.(Key)
return ok && keyEqual(k.key, o.Key())
}
// Comparable types have an equality-testing method.
@ -121,6 +176,18 @@ func mapStringEqual(a, b map[string]interface{}) bool {
return true
}
// sliceEqual reports whether a and b have the same length and
// pairwise-equal elements (as determined by keyEqual).
func sliceEqual(a, b []interface{}) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if !keyEqual(a[i], b[i]) {
			return false
		}
	}
	return true
}
func keyEqual(a, b interface{}) bool {
switch a := a.(type) {
case map[string]interface{}:
@ -137,19 +204,56 @@ func keyEqual(a, b interface{}) bool {
}
}
return true
case []interface{}:
b, ok := b.([]interface{})
return ok && sliceEqual(a, b)
case Comparable:
return a.Equal(b)
case Pointer:
b, ok := b.(Pointer)
return ok && pointerEqual(a, b)
case Path:
b, ok := b.(Path)
return ok && pathEqual(a, b)
}
return a == b
}
// Key interface implementation for map[string]interface{} and []interface{}

// Key returns the wrapped map when one is present, otherwise the
// wrapped slice.
func (k compositeKey) Key() interface{} {
	if k.m == nil {
		return k.s
	}
	return k.m
}
// String returns a human-readable representation of the wrapped value.
func (k compositeKey) String() string {
	return stringify(k.Key())
}

// GoString returns a Go-syntax representation for %#v formatting.
func (k compositeKey) GoString() string {
	return fmt.Sprintf("key.New(%#v)", k.Key())
}

// MarshalJSON marshals the wrapped map or slice as JSON.
func (k compositeKey) MarshalJSON() ([]byte, error) {
	return json.Marshal(k.Key())
}
// Equal reports whether other is also a compositeKey wrapping an equal
// map (when k wraps a map) or an equal slice (otherwise).
func (k compositeKey) Equal(other interface{}) bool {
	o, ok := other.(compositeKey)
	if !ok {
		return false
	}
	if k.m != nil {
		return mapStringEqual(k.m, o.m)
	}
	return sliceEqual(k.s, o.s)
}
// Key returns the wrapped value as a plain string.
func (k strKey) Key() interface{} {
	return string(k)
}
func (k strKey) String() string {
return escape(string(k))
return string(k)
}
func (k strKey) GoString() string {
@ -157,7 +261,7 @@ func (k strKey) GoString() string {
}
func (k strKey) MarshalJSON() ([]byte, error) {
return json.Marshal(string(k))
return json.Marshal(escape(string(k)))
}
func (k strKey) Equal(other interface{}) bool {
@ -175,7 +279,7 @@ func (k int8Key) String() string {
}
func (k int8Key) GoString() string {
return fmt.Sprintf("key.New(%d)", int8(k))
return fmt.Sprintf("key.New(int8(%d))", int8(k))
}
func (k int8Key) MarshalJSON() ([]byte, error) {
@ -197,7 +301,7 @@ func (k int16Key) String() string {
}
func (k int16Key) GoString() string {
return fmt.Sprintf("key.New(%d)", int16(k))
return fmt.Sprintf("key.New(int16(%d))", int16(k))
}
func (k int16Key) MarshalJSON() ([]byte, error) {
@ -219,7 +323,7 @@ func (k int32Key) String() string {
}
func (k int32Key) GoString() string {
return fmt.Sprintf("key.New(%d)", int32(k))
return fmt.Sprintf("key.New(int32(%d))", int32(k))
}
func (k int32Key) MarshalJSON() ([]byte, error) {
@ -241,7 +345,7 @@ func (k int64Key) String() string {
}
func (k int64Key) GoString() string {
return fmt.Sprintf("key.New(%d)", int64(k))
return fmt.Sprintf("key.New(int64(%d))", int64(k))
}
func (k int64Key) MarshalJSON() ([]byte, error) {
@ -263,7 +367,7 @@ func (k uint8Key) String() string {
}
func (k uint8Key) GoString() string {
return fmt.Sprintf("key.New(%d)", uint8(k))
return fmt.Sprintf("key.New(uint8(%d))", uint8(k))
}
func (k uint8Key) MarshalJSON() ([]byte, error) {
@ -285,7 +389,7 @@ func (k uint16Key) String() string {
}
func (k uint16Key) GoString() string {
return fmt.Sprintf("key.New(%d)", uint16(k))
return fmt.Sprintf("key.New(uint16(%d))", uint16(k))
}
func (k uint16Key) MarshalJSON() ([]byte, error) {
@ -307,7 +411,7 @@ func (k uint32Key) String() string {
}
func (k uint32Key) GoString() string {
return fmt.Sprintf("key.New(%d)", uint32(k))
return fmt.Sprintf("key.New(uint32(%d))", uint32(k))
}
func (k uint32Key) MarshalJSON() ([]byte, error) {
@ -329,7 +433,7 @@ func (k uint64Key) String() string {
}
func (k uint64Key) GoString() string {
return fmt.Sprintf("key.New(%d)", uint64(k))
return fmt.Sprintf("key.New(uint64(%d))", uint64(k))
}
func (k uint64Key) MarshalJSON() ([]byte, error) {
@ -351,7 +455,7 @@ func (k float32Key) String() string {
}
func (k float32Key) GoString() string {
return fmt.Sprintf("key.New(%v)", float32(k))
return fmt.Sprintf("key.New(float32(%v))", float32(k))
}
func (k float32Key) MarshalJSON() ([]byte, error) {
@ -373,7 +477,7 @@ func (k float64Key) String() string {
}
func (k float64Key) GoString() string {
return fmt.Sprintf("key.New(%v)", float64(k))
return fmt.Sprintf("key.New(float64(%v))", float64(k))
}
func (k float64Key) MarshalJSON() ([]byte, error) {
@ -406,3 +510,59 @@ func (k boolKey) Equal(other interface{}) bool {
o, ok := other.(boolKey)
return ok && k == o
}
// Key interface implementation for Pointer

// Key reconstructs the Pointer from the stored flattened path elements.
func (k pointerKey) Key() interface{} {
	return sliceToPointer(k.s)
}

// String renders the pointer via the reconstructed Pointer's String.
func (k pointerKey) String() string {
	return sliceToPointer(k.s).String()
}

// GoString returns a Go-syntax representation for %#v formatting.
func (k pointerKey) GoString() string {
	return fmt.Sprintf("key.New(%#v)", k.s)
}

// MarshalJSON marshals via the reconstructed Pointer's MarshalJSON.
func (k pointerKey) MarshalJSON() ([]byte, error) {
	return sliceToPointer(k.s).MarshalJSON()
}
// Equal reports whether k and other represent the same pointer. When
// other is also a pointerKey the flattened slices are compared directly
// (fast path); otherwise other's Key() is compared via the
// reconstructed Pointer's Equal method.
func (k pointerKey) Equal(other interface{}) bool {
	if o, ok := other.(pointerKey); ok {
		return sliceEqual(k.s, o.s)
	}
	key, ok := other.(Key)
	if !ok {
		return false
	}
	// Note: the guard above makes a redundant "ok &&" unnecessary here.
	return sliceToPointer(k.s).Equal(key.Key())
}
// Key interface implementation for Path

// Key reconstructs the Path from the stored flattened elements.
func (k pathKey) Key() interface{} {
	return sliceToPath(k.s)
}

// String renders the path via the reconstructed Path's String.
func (k pathKey) String() string {
	return sliceToPath(k.s).String()
}

// GoString returns a Go-syntax representation for %#v formatting.
func (k pathKey) GoString() string {
	return fmt.Sprintf("key.New(%#v)", k.s)
}

// MarshalJSON marshals via the reconstructed Path's MarshalJSON.
func (k pathKey) MarshalJSON() ([]byte, error) {
	return sliceToPath(k.s).MarshalJSON()
}
// Equal reports whether k and other represent the same path. When other
// is also a pathKey the flattened slices are compared directly (fast
// path); otherwise other's Key() is compared via the reconstructed
// Path's Equal method.
func (k pathKey) Equal(other interface{}) bool {
	if o, ok := other.(pathKey); ok {
		return sliceEqual(k.s, o.s)
	}
	key, ok := other.(Key)
	if !ok {
		return false
	}
	// Note: the guard above makes a redundant "ok &&" unnecessary here.
	return sliceToPath(k.s).Equal(key.Key())
}

View File

@ -55,6 +55,34 @@ func TestKeyEqual(t *testing.T) {
a: New("foo"),
b: New("bar"),
result: false,
}, {
a: New([]interface{}{}),
b: New("bar"),
result: false,
}, {
a: New([]interface{}{}),
b: New([]interface{}{}),
result: true,
}, {
a: New([]interface{}{"a", "b"}),
b: New([]interface{}{"a"}),
result: false,
}, {
a: New([]interface{}{"a", "b"}),
b: New([]interface{}{"b", "a"}),
result: false,
}, {
a: New([]interface{}{"a", "b"}),
b: New([]interface{}{"a", "b"}),
result: true,
}, {
a: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
b: New([]interface{}{"a", map[string]interface{}{"c": "b"}}),
result: false,
}, {
a: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
b: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
result: true,
}, {
a: New(map[string]interface{}{}),
b: New("bar"),
@ -151,6 +179,32 @@ func TestGetFromMap(t *testing.T) {
k: New(uint32(37)),
m: map[Key]interface{}{},
found: false,
}, {
k: New([]interface{}{"a", "b"}),
m: map[Key]interface{}{
New([]interface{}{"a", "b"}): "foo",
},
v: "foo",
found: true,
}, {
k: New([]interface{}{"a", "b"}),
m: map[Key]interface{}{
New([]interface{}{"a", "b", "c"}): "foo",
},
found: false,
}, {
k: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
m: map[Key]interface{}{
New([]interface{}{"a", map[string]interface{}{"b": "c"}}): "foo",
},
v: "foo",
found: true,
}, {
k: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
m: map[Key]interface{}{
New([]interface{}{"a", map[string]interface{}{"c": "b"}}): "foo",
},
found: false,
}, {
k: New(map[string]interface{}{"a": "b", "c": uint64(4)}),
m: map[Key]interface{}{
@ -392,6 +446,53 @@ func TestSetToMap(t *testing.T) {
}
}
func TestGoString(t *testing.T) {
tcases := []struct {
in Key
out string
}{{
in: New(uint8(1)),
out: "key.New(uint8(1))",
}, {
in: New(uint16(1)),
out: "key.New(uint16(1))",
}, {
in: New(uint32(1)),
out: "key.New(uint32(1))",
}, {
in: New(uint64(1)),
out: "key.New(uint64(1))",
}, {
in: New(int8(1)),
out: "key.New(int8(1))",
}, {
in: New(int16(1)),
out: "key.New(int16(1))",
}, {
in: New(int32(1)),
out: "key.New(int32(1))",
}, {
in: New(int64(1)),
out: "key.New(int64(1))",
}, {
in: New(float32(1)),
out: "key.New(float32(1))",
}, {
in: New(float64(1)),
out: "key.New(float64(1))",
}, {
in: New(map[string]interface{}{"foo": true}),
out: `key.New(map[string]interface {}{"foo":true})`,
}}
for i, tcase := range tcases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
if out := fmt.Sprintf("%#v", tcase.in); out != tcase.out {
t.Errorf("Wanted Go representation %q but got %q", tcase.out, out)
}
})
}
}
func TestMisc(t *testing.T) {
k := New(map[string]interface{}{"foo": true})
js, err := json.Marshal(k)
@ -400,11 +501,6 @@ func TestMisc(t *testing.T) {
} else if expected := `{"foo":true}`; string(js) != expected {
t.Errorf("Wanted JSON %q but got %q", expected, js)
}
expected := `key.New(map[string]interface {}{"foo":true})`
gostr := fmt.Sprintf("%#v", k)
if expected != gostr {
t.Errorf("Wanted Go representation %q but got %q", expected, gostr)
}
test.ShouldPanic(t, func() { New(42) })

51
vendor/github.com/aristanetworks/goarista/key/path.go generated vendored Normal file
View File

@ -0,0 +1,51 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package key
import (
"bytes"
"fmt"
)
// Path represents a path decomposed into elements where each
// element is a Key. A Path can be interpreted as either
// absolute or relative depending on how it is used.
type Path []Key
// String returns the Path rendered as an absolute path string, with a
// leading '/' before each element. The empty Path renders as "/".
func (p Path) String() string {
	if len(p) == 0 {
		return "/"
	}
	out := make([]byte, 0, 8*len(p))
	for _, element := range p {
		out = append(out, '/')
		out = append(out, element.String()...)
	}
	return string(out)
}
// MarshalJSON marshals a Path to JSON as {"_path":"<string form>"}.
func (p Path) MarshalJSON() ([]byte, error) {
	return []byte(fmt.Sprintf(`{"_path":%q}`, p)), nil
}
// Equal returns whether a Path is equal to @other: other must also be
// a Path of the same length with pairwise-equal elements.
func (p Path) Equal(other interface{}) bool {
	o, ok := other.(Path)
	return ok && pathEqual(p, o)
}
// pathEqual reports whether a and b have the same length and
// pairwise-equal elements.
func pathEqual(a, b Path) bool {
	if len(a) != len(b) {
		return false
	}
	for i, elem := range a {
		if !elem.Equal(b[i]) {
			return false
		}
	}
	return true
}

View File

@ -0,0 +1,141 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package key_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/path"
)
// TestPath verifies that a Path wrapped via key.New round-trips
// through Key(), renders the expected string form, and marshals to the
// expected {"_path":...} JSON.
func TestPath(t *testing.T) {
	expected := path.New("leet")
	k := key.New(expected)
	keyPath, ok := k.Key().(key.Path)
	if !ok {
		t.Fatalf("key.Key() did not return a key.Path")
	}
	if !path.Equal(keyPath, expected) {
		t.Errorf("expected path from key is %#v, got %#v", expected, keyPath)
	}
	if expected, actual := "/leet", fmt.Sprint(k); actual != expected {
		t.Errorf("expected string from path key %q, got %q", expected, actual)
	}
	if js, err := json.Marshal(k); err != nil {
		t.Errorf("JSON marshalling error: %v", err)
	} else if expected, actual := `{"_path":"/leet"}`, string(js); actual != expected {
		t.Errorf("expected json output %q, got %q", expected, actual)
	}
}
func newPathKey(e ...interface{}) key.Key {
return key.New(path.New(e...))
}
type customPath string
func (p customPath) Key() interface{} {
return path.FromString(string(p))
}
func (p customPath) Equal(other interface{}) bool {
o, ok := other.(key.Key)
return ok && o.Equal(p)
}
func (p customPath) String() string { return string(p) }
func (p customPath) ToBuiltin() interface{} { panic("not impl") }
func (p customPath) MarshalJSON() ([]byte, error) { panic("not impl") }
func TestPathEqual(t *testing.T) {
tests := []struct {
a key.Key
b key.Key
result bool
}{{
a: newPathKey(),
b: nil,
result: false,
}, {
a: newPathKey(),
b: newPathKey(),
result: true,
}, {
a: newPathKey(),
b: newPathKey("foo"),
result: false,
}, {
a: newPathKey("foo"),
b: newPathKey(),
result: false,
}, {
a: newPathKey(int16(1337)),
b: newPathKey(int64(1337)),
result: false,
}, {
a: newPathKey(path.Wildcard, "bar"),
b: newPathKey("foo", path.Wildcard),
result: false,
}, {
a: newPathKey(map[string]interface{}{"a": "x", "b": "y"}),
b: newPathKey(map[string]interface{}{"b": "y", "a": "x"}),
result: true,
}, {
a: newPathKey(map[string]interface{}{"a": "x", "b": "y"}),
b: newPathKey(map[string]interface{}{"x": "x", "y": "y"}),
result: false,
}, {
a: newPathKey("foo", "bar"),
b: customPath("/foo/bar"),
result: true,
}, {
a: customPath("/foo/bar"),
b: newPathKey("foo", "bar"),
result: true,
}, {
a: newPathKey("foo"),
b: key.New(customPath("/bar")),
result: false,
}, {
a: key.New(customPath("/foo")),
b: newPathKey("foo", "bar"),
result: false,
}}
for i, tc := range tests {
if a, b := tc.a, tc.b; a.Equal(b) != tc.result {
t.Errorf("result not as expected for test case %d", i)
}
}
}
func TestPathAsKey(t *testing.T) {
a := newPathKey("foo", path.Wildcard, map[string]interface{}{
"bar": map[key.Key]interface{}{
// Should be able to embed a path key and value
newPathKey("path", "to", "something"): path.New("else"),
},
})
m := map[key.Key]string{
a: "thats a complex key!",
}
if s, ok := m[a]; !ok {
t.Error("complex key not found in map")
} else if s != "thats a complex key!" {
t.Errorf("incorrect value in map: %s", s)
}
// preserve custom path implementations
b := key.New(customPath("/foo/bar"))
if _, ok := b.Key().(customPath); !ok {
t.Errorf("customPath implementation not preserved: %T", b.Key())
}
}

View File

@ -0,0 +1,46 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package key
import (
"fmt"
)
// Pointer is a pointer to a path.
type Pointer interface {
Pointer() Path
}
// NewPointer creates a new pointer to a path.
func NewPointer(path Path) Pointer {
return pointer(path)
}
// This is the type returned by pointerKey.Key. Returning this is a
// lot faster than having pointerKey implement Pointer, since it is
// a compositeKey and thus would require reconstructing a Path from
// []interface{} any time the Pointer method is called.
type pointer Path

// Pointer returns the Path this pointer refers to.
func (ptr pointer) Pointer() Path {
	return Path(ptr)
}
// String renders the pointer as its target path wrapped in braces,
// e.g. "{/foo}".
func (ptr pointer) String() string {
	return "{" + ptr.Pointer().String() + "}"
}

// MarshalJSON marshals the pointer as {"_ptr":"<path string>"}.
func (ptr pointer) MarshalJSON() ([]byte, error) {
	return []byte(fmt.Sprintf(`{"_ptr":%q}`, ptr.Pointer().String())), nil
}

// Equal reports whether other is also a Pointer whose target path is
// equal to ptr's.
func (ptr pointer) Equal(other interface{}) bool {
	o, ok := other.(Pointer)
	return ok && pointerEqual(ptr, o)
}

// pointerEqual compares two Pointers by their target paths.
func pointerEqual(a, b Pointer) bool {
	return pathEqual(a.Pointer(), b.Pointer())
}

View File

@ -0,0 +1,189 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package key_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/path"
)
func TestPointer(t *testing.T) {
p := key.NewPointer(path.New("foo"))
if expected, actual := path.New("foo"), p.Pointer(); !path.Equal(expected, actual) {
t.Errorf("Expected %#v but got %#v", expected, actual)
}
if expected, actual := "{/foo}", fmt.Sprintf("%s", p); actual != expected {
t.Errorf("Expected %q but got %q", expected, actual)
}
if js, err := json.Marshal(p); err != nil {
t.Errorf("JSON marshaling failed: %s", err)
} else if expected, actual := `{"_ptr":"/foo"}`, string(js); actual != expected {
t.Errorf("Expected %q but got %q", expected, actual)
}
}
type pointer string
func (ptr pointer) Pointer() key.Path {
return path.FromString(string(ptr))
}
func (ptr pointer) ToBuiltin() interface{} {
panic("NOT IMPLEMENTED")
}
func (ptr pointer) String() string {
panic("NOT IMPLEMENTED")
}
func (ptr pointer) MarshalJSON() ([]byte, error) {
panic("NOT IMPLEMENTED")
}
func TestPointerEqual(t *testing.T) {
tests := []struct {
a key.Pointer
b key.Pointer
result bool
}{{
a: key.NewPointer(nil),
b: key.NewPointer(path.New()),
result: true,
}, {
a: key.NewPointer(path.New()),
b: key.NewPointer(nil),
result: true,
}, {
a: key.NewPointer(path.New("foo")),
b: key.NewPointer(path.New("foo")),
result: true,
}, {
a: key.NewPointer(path.New("foo")),
b: key.NewPointer(path.New("bar")),
result: false,
}, {
a: key.NewPointer(path.New(true)),
b: key.NewPointer(path.New("true")),
result: false,
}, {
a: key.NewPointer(path.New(int8(0))),
b: key.NewPointer(path.New(int16(0))),
result: false,
}, {
a: key.NewPointer(path.New(path.Wildcard, "bar")),
b: key.NewPointer(path.New("foo", path.Wildcard)),
result: false,
}, {
a: key.NewPointer(path.New(map[string]interface{}{"a": "x", "b": "y"})),
b: key.NewPointer(path.New(map[string]interface{}{"b": "y", "a": "x"})),
result: true,
}, {
a: key.NewPointer(path.New(map[string]interface{}{"a": "x", "b": "y"})),
b: key.NewPointer(path.New(map[string]interface{}{"x": "x", "y": "y"})),
result: false,
}, {
a: key.NewPointer(path.New("foo")),
b: pointer("/foo"),
result: true,
}, {
a: pointer("/foo"),
b: key.NewPointer(path.New("foo")),
result: true,
}, {
a: key.NewPointer(path.New("foo")),
b: pointer("/foo/bar"),
result: false,
}, {
a: pointer("/foo/bar"),
b: key.NewPointer(path.New("foo")),
result: false,
}}
for i, tcase := range tests {
if key.New(tcase.a).Equal(key.New(tcase.b)) != tcase.result {
t.Errorf("Error in pointer comparison for test %d", i)
}
}
}
func TestPointerAsKey(t *testing.T) {
a := key.NewPointer(path.New("foo", path.Wildcard, map[string]interface{}{
"bar": map[key.Key]interface{}{
// Should be able to embed pointer key.
key.New(key.NewPointer(path.New("baz"))):
// Should be able to embed pointer value.
key.NewPointer(path.New("baz")),
},
}))
m := map[key.Key]string{
key.New(a): "a",
}
if s, ok := m[key.New(a)]; !ok {
t.Error("pointer to path not keyed in map")
} else if s != "a" {
t.Errorf("pointer to path not mapped to correct value in map: %s", s)
}
// Ensure that we preserve custom pointer implementations.
b := key.New(pointer("/foo/bar"))
if _, ok := b.Key().(pointer); !ok {
t.Errorf("pointer implementation not preserved: %T", b.Key())
}
}
func BenchmarkPointer(b *testing.B) {
benchmarks := []key.Path{
path.New(),
path.New("foo"),
path.New("foo", "bar"),
path.New("foo", "bar", "baz"),
path.New("foo", "bar", "baz", "qux"),
}
for i, benchmark := range benchmarks {
b.Run(fmt.Sprintf("%d", i), func(b *testing.B) {
for i := 0; i < b.N; i++ {
key.NewPointer(benchmark)
}
})
}
}
func BenchmarkPointerAsKey(b *testing.B) {
benchmarks := []key.Pointer{
key.NewPointer(path.New()),
key.NewPointer(path.New("foo")),
key.NewPointer(path.New("foo", "bar")),
key.NewPointer(path.New("foo", "bar", "baz")),
key.NewPointer(path.New("foo", "bar", "baz", "qux")),
}
for i, benchmark := range benchmarks {
b.Run(fmt.Sprintf("%d", i), func(b *testing.B) {
for i := 0; i < b.N; i++ {
key.New(benchmark)
}
})
}
}
// BenchmarkEmbeddedPointerAsKey measures key.New on []interface{}
// values that embed a Pointer, exercising the compositeKey slice path,
// for paths of increasing length.
func BenchmarkEmbeddedPointerAsKey(b *testing.B) {
	// gofmt -s: the inner []interface{} element type is implied by the
	// outer [][]interface{} literal and need not be repeated.
	benchmarks := [][]interface{}{
		{key.NewPointer(path.New())},
		{key.NewPointer(path.New("foo"))},
		{key.NewPointer(path.New("foo", "bar"))},
		{key.NewPointer(path.New("foo", "bar", "baz"))},
		{key.NewPointer(path.New("foo", "bar", "baz", "qux"))},
	}
	for i, benchmark := range benchmarks {
		b.Run(fmt.Sprintf("%d", i), func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				key.New(benchmark)
			}
		})
	}
}

View File

@ -5,11 +5,13 @@
package key
import (
"encoding/base64"
"errors"
"fmt"
"math"
"strconv"
"strings"
"unicode/utf8"
"github.com/aristanetworks/goarista/value"
)
@ -67,7 +69,16 @@ func StringifyInterface(key interface{}) (string, error) {
keys[i] = stringify(k) + "=" + stringify(m[k])
}
str = strings.Join(keys, "_")
case []interface{}:
elements := make([]string, len(key))
for i, element := range key {
elements[i] = stringify(element)
}
str = strings.Join(elements, ",")
case Pointer:
return "{" + key.Pointer().String() + "}", nil
case Path:
return "[" + key.String() + "]", nil
case value.Value:
return key.String(), nil
@ -78,15 +89,14 @@ func StringifyInterface(key interface{}) (string, error) {
return str, nil
}
// escape checks if the string is a valid utf-8 string.
// If it is, it will return the string as is.
// If it is not, it will return the base64 representation of the byte array string
func escape(str string) string {
for i := 0; i < len(str); i++ {
if chr := str[i]; chr < 0x20 || chr > 0x7E {
str = strconv.QuoteToASCII(str)
str = str[1 : len(str)-1] // Drop the leading and trailing quotes.
break
}
if utf8.ValidString(str) {
return str
}
return str
return base64.StdEncoding.EncodeToString([]byte(str))
}
func stringify(key interface{}) string {

View File

@ -27,9 +27,21 @@ func TestStringify(t *testing.T) {
input: "foobar",
output: "foobar",
}, {
name: "non-ASCII string",
name: "valid non-ASCII UTF-8 string",
input: "日本語",
output: `\u65e5\u672c\u8a9e`,
output: "日本語",
}, {
name: "invalid UTF-8 string 1",
input: string([]byte{0xef, 0xbf, 0xbe, 0xbe, 0xbe, 0xbe, 0xbe}),
output: "77++vr6+vg==",
}, {
name: "invalid UTF-8 string 2",
input: string([]byte{0xef, 0xbf, 0xbe, 0xbe, 0xbe, 0xbe, 0xbe, 0x23}),
output: "77++vr6+viM=",
}, {
name: "invalid UTF-8 string 3",
input: string([]byte{0xef, 0xbf, 0xbe, 0xbe, 0xbe, 0xbe, 0xbe, 0x23, 0x24}),
output: "77++vr6+viMk",
}, {
name: "uint8",
input: uint8(43),
@ -107,6 +119,22 @@ func TestStringify(t *testing.T) {
"n": nil,
},
output: "Unable to stringify nil",
}, {
name: "[]interface{}",
input: []interface{}{
uint32(42),
true,
"foo",
map[Key]interface{}{
New("a"): "b",
New("b"): "c",
},
},
output: "42,true,foo,a=b_b=c",
}, {
name: "pointer",
input: NewPointer(Path{New("foo"), New("bar")}),
output: "{/foo/bar}",
}}
for _, tcase := range testcases {

View File

@ -8,8 +8,6 @@ package netns
import (
"fmt"
"os"
"runtime"
)
const (
@ -45,54 +43,3 @@ func setNsByName(nsName string) error {
}
return nil
}
// Do takes a function which it will call in the network namespace specified by nsName.
// The goroutine that calls this will lock itself to its current OS thread, hop
// namespaces, call the given function, hop back to its original namespace, and then
// unlock itself from its current OS thread.
// Do returns an error if an error occurs at any point besides in the invocation of
// the given function, or if the given function itself returns an error.
//
// The callback function is expected to do something simple such as just
// creating a socket / opening a connection, and you should not kick off any
// complex logic from the callback or call any complicated code or create any
// new goroutine from the callback. The callback should not panic or use defer.
// The behavior of this function is undefined if the callback doesn't conform
// these demands.
//go:nosplit
func Do(nsName string, cb Callback) error {
// If destNS is empty, the function is called in the caller's namespace
if nsName == "" {
return cb()
}
// Get the file descriptor to the current namespace
currNsFd, err := getNs(selfNsFile)
if os.IsNotExist(err) {
return fmt.Errorf("File descriptor to current namespace does not exist: %s", err)
} else if err != nil {
return fmt.Errorf("Failed to open %s: %s", selfNsFile, err)
}
runtime.LockOSThread()
// Jump to the new network namespace
if err := setNsByName(nsName); err != nil {
runtime.UnlockOSThread()
currNsFd.close()
return fmt.Errorf("Failed to set the namespace to %s: %s", nsName, err)
}
// Call the given function
cbErr := cb()
// Come back to the original namespace
if err = setNs(currNsFd); err != nil {
cbErr = fmt.Errorf("Failed to return to the original namespace: %s (callback returned %v)",
err, cbErr)
}
runtime.UnlockOSThread()
currNsFd.close()
return cbErr
}

View File

@ -0,0 +1,60 @@
// Copyright (c) 2016 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
// +build go1.10
package netns
import (
"fmt"
"os"
"runtime"
)
// Do takes a function which it will call in the network namespace specified by nsName.
// The goroutine that calls this will lock itself to its current OS thread, hop
// namespaces, call the given function, hop back to its original namespace, and then
// unlock itself from its current OS thread.
// Do returns an error if an error occurs at any point besides in the invocation of
// the given function, or if the given function itself returns an error.
//
// The callback function is expected to do something simple such as just
// creating a socket / opening a connection, as it's not desirable to start
// complex logic in a goroutine that is pinned to the current OS thread.
// Also any goroutine started from the callback function may or may not
// execute in the desired namespace.
func Do(nsName string, cb Callback) error {
	// If destNS is empty, the function is called in the caller's namespace
	if nsName == "" {
		return cb()
	}
	// Get the file descriptor to the current namespace
	currNsFd, err := getNs(selfNsFile)
	if os.IsNotExist(err) {
		return fmt.Errorf("File descriptor to current namespace does not exist: %s", err)
	} else if err != nil {
		return fmt.Errorf("Failed to open %s: %s", selfNsFile, err)
	}
	// Defers run LIFO: UnlockOSThread fires before currNsFd.close().
	defer currNsFd.close()
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	// Jump to the new network namespace
	if err := setNsByName(nsName); err != nil {
		return fmt.Errorf("Failed to set the namespace to %s: %s", nsName, err)
	}
	// Call the given function
	cbErr := cb()
	// Come back to the original namespace
	if err = setNs(currNsFd); err != nil {
		// Preserve the callback's result in the error message since we
		// are replacing it with the namespace-restore failure.
		return fmt.Errorf("Failed to return to the original namespace: %s (callback returned %v)",
			err, cbErr)
	}
	return cbErr
}

View File

@ -0,0 +1,64 @@
// Copyright (c) 2016 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
// +build !go1.10
package netns
import (
"fmt"
"os"
"runtime"
)
// Do takes a function which it will call in the network namespace specified by nsName.
// The goroutine that calls this will lock itself to its current OS thread, hop
// namespaces, call the given function, hop back to its original namespace, and then
// unlock itself from its current OS thread.
// Do returns an error if an error occurs at any point besides in the invocation of
// the given function, or if the given function itself returns an error.
//
// The callback function is expected to do something simple such as just
// creating a socket / opening a connection, and you should not kick off any
// complex logic from the callback or call any complicated code or create any
// new goroutine from the callback. The callback should not panic or use defer.
// The behavior of this function is undefined if the callback doesn't conform
// these demands.
//go:nosplit
func Do(nsName string, cb Callback) error {
	// If destNS is empty, the function is called in the caller's namespace
	if nsName == "" {
		return cb()
	}
	// Get the file descriptor to the current namespace
	currNsFd, err := getNs(selfNsFile)
	if os.IsNotExist(err) {
		return fmt.Errorf("File descriptor to current namespace does not exist: %s", err)
	} else if err != nil {
		return fmt.Errorf("Failed to open %s: %s", selfNsFile, err)
	}
	runtime.LockOSThread()
	// Jump to the new network namespace
	if err := setNsByName(nsName); err != nil {
		// Cleanup is done manually on every exit path; this function
		// deliberately avoids defer — presumably to stay compatible
		// with the //go:nosplit pragma above (TODO confirm).
		runtime.UnlockOSThread()
		currNsFd.close()
		return fmt.Errorf("Failed to set the namespace to %s: %s", nsName, err)
	}
	// Call the given function
	cbErr := cb()
	// Come back to the original namespace
	if err = setNs(currNsFd); err != nil {
		// Fold the callback's result into the error so it isn't lost.
		cbErr = fmt.Errorf("Failed to return to the original namespace: %s (callback returned %v)",
			err, cbErr)
	}
	runtime.UnlockOSThread()
	currNsFd.close()
	return cbErr
}

View File

@ -18,7 +18,7 @@ import (
"google.golang.org/grpc/metadata"
)
const defaultPort = "6042"
const defaultPort = "6030"
// PublishFunc is the method to publish responses
type PublishFunc func(addr string, message proto.Message)
@ -36,6 +36,9 @@ func New(username, password, addr string, opts []grpc.DialOption) *Client {
if !strings.ContainsRune(addr, ':') {
addr += ":" + defaultPort
}
// Make sure we don't move past the grpc.Dial() call until we actually
// established an HTTP/2 connection successfully.
opts = append(opts, grpc.WithBlock(), grpc.WithWaitForHandshake())
conn, err := grpc.Dial(addr, opts...)
if err != nil {
glog.Fatalf("Failed to dial: %s", err)
@ -93,7 +96,7 @@ func (c *Client) Subscribe(wg *sync.WaitGroup, subscriptions []string,
Request: &openconfig.SubscribeRequest_Subscribe{
Subscribe: &openconfig.SubscriptionList{
Subscription: []*openconfig.Subscription{
&openconfig.Subscription{
{
Path: &openconfig.Path{Element: strings.Split(path, "/")},
},
},

View File

@ -23,7 +23,7 @@ func ParseFlags() (username string, password string, subscriptions, addrs []stri
opts []grpc.DialOption) {
var (
addrsFlag = flag.String("addrs", "localhost:6042",
addrsFlag = flag.String("addrs", "localhost:6030",
"Comma-separated list of addresses of OpenConfig gRPC servers")
caFileFlag = flag.String("cafile", "",

View File

@ -80,11 +80,11 @@ func TestNotificationToMap(t *testing.T) {
},
},
Delete: []*openconfig.Path{
&openconfig.Path{
{
Element: []string{
"route", "237.255.255.250_0.0.0.0",
}},
&openconfig.Path{
{
Element: []string{
"route", "238.255.255.250_0.0.0.0",
},

View File

@ -10,221 +10,288 @@ import (
"sort"
"github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/pathmap"
)
// Map associates Paths to values. It allows wildcards. The
// primary use of Map is to be able to register handlers to paths
// that can be efficiently looked up every time a path is updated.
//
// For example:
//
// m.Set({key.New("interfaces"), key.New("*"), key.New("adminStatus")}, AdminStatusHandler)
// m.Set({key.New("interface"), key.New("Management1"), key.New("adminStatus")},
// Management1AdminStatusHandler)
//
// m.Visit(Path{key.New("interfaces"), key.New("Ethernet3/32/1"), key.New("adminStatus")},
// HandlerExecutor)
// >> AdminStatusHandler gets passed to HandlerExecutor
// m.Visit(Path{key.New("interfaces"), key.New("Management1"), key.New("adminStatus")},
// HandlerExecutor)
// >> AdminStatusHandler and Management1AdminStatusHandler gets passed to HandlerExecutor
//
// Note, Visit performance is typically linearly with the length of
// the path. But, it can be as bad as O(2^len(Path)) when TreeMap
// nodes have children and a wildcard associated with it. For example,
// if these paths were registered:
//
// m.Set(Path{key.New("foo"), key.New("bar"), key.New("baz")}, 1)
// m.Set(Path{key.New("*"), key.New("bar"), key.New("baz")}, 2)
// m.Set(Path{key.New("*"), key.New("*"), key.New("baz")}, 3)
// m.Set(Path{key.New("*"), key.New("*"), key.New("*")}, 4)
// m.Set(Path{key.New("foo"), key.New("*"), key.New("*")}, 5)
// m.Set(Path{key.New("foo"), key.New("bar"), key.New("*")}, 6)
// m.Set(Path{key.New("foo"), key.New("*"), key.New("baz")}, 7)
// m.Set(Path{key.New("*"), key.New("bar"), key.New("*")}, 8)
//
// m.Visit(Path{key.New("foo"),key.New("bar"),key.New("baz")}, Foo) // 2^3 nodes traversed
//
// This shouldn't be a concern with our paths because it is likely
// that a TreeMap node will either have a wildcard or children, not
// both. A TreeMap node that corresponds to a collection will often be a
// wildcard, otherwise it will have specific children.
type Map interface {
// Visit calls f for every registration in the Map that
// matches path. For example,
//
// m.Set(Path{key.New("foo"), key.New("bar")}, 1)
// m.Set(Path{key.New("*"), key.New("bar")}, 2)
//
// m.Visit(Path{key.New("foo"), key.New("bar")}, Printer)
// >> Calls Printer(1) and Printer(2)
Visit(p Path, f pathmap.VisitorFunc) error
// VisitPrefix calls f for every registration in the Map that
// is a prefix of path. For example,
//
// m.Set(Path{}, 0)
// m.Set(Path{key.New("foo")}, 1)
// m.Set(Path{key.New("foo"), key.New("bar")}, 2)
// m.Set(Path{key.New("foo"), key.New("quux")}, 3)
// m.Set(Path{key.New("*"), key.New("bar")}, 4)
//
// m.VisitPrefix(Path{key.New("foo"), key.New("bar"), key.New("baz")}, Printer)
// >> Calls Printer on values 0, 1, 2, and 4
VisitPrefix(p Path, f pathmap.VisitorFunc) error
// Get returns the mapping for path. This returns the exact
// mapping for path. For example, if you register two paths
//
// m.Set(Path{key.New("foo"), key.New("bar")}, 1)
// m.Set(Path{key.New("*"), key.New("bar")}, 2)
//
// m.Get(Path{key.New("foo"), key.New("bar")}) => 1
// m.Get(Path{key.New("*"), key.New("bar")}) => 2
Get(p Path) interface{}
// Set a mapping of path to value. Path may contain wildcards. Set
// replaces what was there before.
Set(p Path, v interface{})
// Delete removes the mapping for path
Delete(p Path) bool
}
// Wildcard is a special key representing any possible path
var Wildcard key.Key = key.New("*")
type node struct {
// Map associates paths to values. It allows wildcards. A Map
// is primarily used to register handlers with paths that can
// be easily looked up each time a path is updated.
type Map struct {
val interface{}
wildcard *node
children map[key.Key]*node
ok bool
wildcard *Map
children map[key.Key]*Map
}
// NewMap creates a new Map
func NewMap() Map {
return &node{}
}
// VisitorFunc is a function that handles the value associated
// with a path in a Map. Note that only the value is passed in
// as an argument since the path can be stored inside the value
// if needed.
type VisitorFunc func(v interface{}) error
// Visit calls f for every matching registration in the Map
func (n *node) Visit(p Path, f pathmap.VisitorFunc) error {
// Visit calls a function fn for every value in the Map
// that is registered with a match of a path p. In the
// general case, time complexity is linear with respect
// to the length of p but it can be as bad as O(2^len(p))
// if there are a lot of paths with wildcards registered.
//
// Example:
//
// a := path.New("foo", "bar", "baz")
// b := path.New("foo", path.Wildcard, "baz")
// c := path.New(path.Wildcard, "bar", "baz")
// d := path.New("foo", "bar", path.Wildcard)
// e := path.New(path.Wildcard, path.Wildcard, "baz")
// f := path.New(path.Wildcard, "bar", path.Wildcard)
// g := path.New("foo", path.Wildcard, path.Wildcard)
// h := path.New(path.Wildcard, path.Wildcard, path.Wildcard)
//
// m.Set(a, 1)
// m.Set(b, 2)
// m.Set(c, 3)
// m.Set(d, 4)
// m.Set(e, 5)
// m.Set(f, 6)
// m.Set(g, 7)
// m.Set(h, 8)
//
// p := path.New("foo", "bar", "baz")
//
// m.Visit(p, fn)
//
// Result: fn(1), fn(2), fn(3), fn(4), fn(5), fn(6), fn(7) and fn(8)
func (m *Map) Visit(p key.Path, fn VisitorFunc) error {
for i, element := range p {
if n.wildcard != nil {
if err := n.wildcard.Visit(p[i+1:], f); err != nil {
if m.wildcard != nil {
if err := m.wildcard.Visit(p[i+1:], fn); err != nil {
return err
}
}
next, ok := n.children[element]
next, ok := m.children[element]
if !ok {
return nil
}
n = next
m = next
}
if n.val == nil {
if !m.ok {
return nil
}
return f(n.val)
return fn(m.val)
}
// VisitPrefix calls f for every registered path that is a prefix of
// the path
func (n *node) VisitPrefix(p Path, f pathmap.VisitorFunc) error {
// VisitPrefixes calls a function fn for every value in the
// Map that is registered with a prefix of a path p.
//
// Example:
//
// a := path.New()
// b := path.New("foo")
// c := path.New("foo", "bar")
// d := path.New("foo", "baz")
// e := path.New(path.Wildcard, "bar")
//
// m.Set(a, 1)
// m.Set(b, 2)
// m.Set(c, 3)
// m.Set(d, 4)
// m.Set(e, 5)
//
// p := path.New("foo", "bar", "baz")
//
// m.VisitPrefixes(p, fn)
//
// Result: fn(1), fn(2), fn(3), fn(5)
func (m *Map) VisitPrefixes(p key.Path, fn VisitorFunc) error {
for i, element := range p {
// Call f on each node we visit
if n.val != nil {
if err := f(n.val); err != nil {
if m.ok {
if err := fn(m.val); err != nil {
return err
}
}
if n.wildcard != nil {
if err := n.wildcard.VisitPrefix(p[i+1:], f); err != nil {
if m.wildcard != nil {
if err := m.wildcard.VisitPrefixes(p[i+1:], fn); err != nil {
return err
}
}
next, ok := n.children[element]
next, ok := m.children[element]
if !ok {
return nil
}
n = next
m = next
}
if n.val == nil {
if !m.ok {
return nil
}
// Call f on the final node
return f(n.val)
return fn(m.val)
}
// Get returns the mapping for path
func (n *node) Get(p Path) interface{} {
for _, element := range p {
if element.Equal(Wildcard) {
if n.wildcard == nil {
return nil
// VisitPrefixed calls fn for every value in the map that is
// registerd with a path that is prefixed by p. This method
// can be used to visit every registered path if p is the
// empty path (or root path) which prefixes all paths.
//
// Example:
//
// a := path.New("foo")
// b := path.New("foo", "bar")
// c := path.New("foo", "bar", "baz")
// d := path.New("foo", path.Wildcard)
//
// m.Set(a, 1)
// m.Set(b, 2)
// m.Set(c, 3)
// m.Set(d, 4)
//
// p := path.New("foo", "bar")
//
// m.VisitPrefixed(p, fn)
//
// Result: fn(2), fn(3), fn(4)
func (m *Map) VisitPrefixed(p key.Path, fn VisitorFunc) error {
for i, element := range p {
if m.wildcard != nil {
if err := m.wildcard.VisitPrefixed(p[i+1:], fn); err != nil {
return err
}
n = n.wildcard
continue
}
next, ok := n.children[element]
next, ok := m.children[element]
if !ok {
return nil
}
n = next
m = next
}
return n.val
return m.visitSubtree(fn)
}
// Set a mapping of path to value. Path may contain wildcards. Set
// replaces what was there before.
func (n *node) Set(p Path, v interface{}) {
func (m *Map) visitSubtree(fn VisitorFunc) error {
if m.ok {
if err := fn(m.val); err != nil {
return err
}
}
if m.wildcard != nil {
if err := m.wildcard.visitSubtree(fn); err != nil {
return err
}
}
for _, next := range m.children {
if err := next.visitSubtree(fn); err != nil {
return err
}
}
return nil
}
// Get returns the value registered with an exact match of a
// path p. If there is no exact match for p, Get returns nil
// and false. If p has an exact match and it is set to true,
// Get returns nil and true.
//
// Example:
//
// m.Set(path.New("foo", "bar"), 1)
// m.Set(path.New("baz", "qux"), nil)
//
// a := m.Get(path.New("foo", "bar"))
// b := m.Get(path.New("foo", path.Wildcard))
// c, ok := m.Get(path.New("baz", "qux"))
//
// Result: a == 1, b == nil, c == nil and ok == true
func (m *Map) Get(p key.Path) (interface{}, bool) {
for _, element := range p {
if element.Equal(Wildcard) {
if n.wildcard == nil {
n.wildcard = &node{}
if m.wildcard == nil {
return nil, false
}
n = n.wildcard
m = m.wildcard
continue
}
if n.children == nil {
n.children = map[key.Key]*node{}
}
next, ok := n.children[element]
next, ok := m.children[element]
if !ok {
next = &node{}
n.children[element] = next
return nil, false
}
n = next
m = next
}
n.val = v
return m.val, m.ok
}
// Delete removes the mapping for path
func (n *node) Delete(p Path) bool {
nodes := make([]*node, len(p)+1)
for i, element := range p {
nodes[i] = n
// Set registers a path p with a value. If the path was already
// registered with a value it returns true and false otherwise.
//
// Example:
//
// p := path.New("foo", "bar")
//
// a := m.Set(p, 0)
// b := m.Set(p, 1)
//
// v := m.Get(p)
//
// Result: a == false, b == true and v == 1
func (m *Map) Set(p key.Path, v interface{}) bool {
for _, element := range p {
if element.Equal(Wildcard) {
if n.wildcard == nil {
if m.wildcard == nil {
m.wildcard = &Map{}
}
m = m.wildcard
continue
}
if m.children == nil {
m.children = map[key.Key]*Map{}
}
next, ok := m.children[element]
if !ok {
next = &Map{}
m.children[element] = next
}
m = next
}
set := !m.ok
m.val, m.ok = v, true
return set
}
// Delete unregisters the value registered with a path. It
// returns true if a value was deleted and false otherwise.
//
// Example:
//
// p := path.New("foo", "bar")
//
// m.Set(p, 0)
//
// a := m.Delete(p)
// b := m.Delete(p)
//
// Result: a == true and b == false
func (m *Map) Delete(p key.Path) bool {
maps := make([]*Map, len(p)+1)
for i, element := range p {
maps[i] = m
if element.Equal(Wildcard) {
if m.wildcard == nil {
return false
}
n = n.wildcard
m = m.wildcard
continue
}
next, ok := n.children[element]
next, ok := m.children[element]
if !ok {
return false
}
n = next
m = next
}
n.val = nil
nodes[len(p)] = n
deleted := m.ok
m.val, m.ok = nil, false
maps[len(p)] = m
// See if we can delete any node objects
// Remove any empty maps.
for i := len(p); i > 0; i-- {
n = nodes[i]
if n.val != nil || n.wildcard != nil || len(n.children) > 0 {
m = maps[i]
if m.ok || m.wildcard != nil || len(m.children) > 0 {
break
}
parent := nodes[i-1]
parent := maps[i-1]
element := p[i-1]
if element.Equal(Wildcard) {
parent.wildcard = nil
@ -232,28 +299,28 @@ func (n *node) Delete(p Path) bool {
delete(parent.children, element)
}
}
return true
return deleted
}
func (n *node) String() string {
func (m *Map) String() string {
var b bytes.Buffer
n.write(&b, "")
m.write(&b, "")
return b.String()
}
func (n *node) write(b *bytes.Buffer, indent string) {
if n.val != nil {
func (m *Map) write(b *bytes.Buffer, indent string) {
if m.ok {
b.WriteString(indent)
fmt.Fprintf(b, "Val: %v", n.val)
fmt.Fprintf(b, "Val: %v", m.val)
b.WriteString("\n")
}
if n.wildcard != nil {
if m.wildcard != nil {
b.WriteString(indent)
fmt.Fprintf(b, "Child %q:\n", Wildcard)
n.wildcard.write(b, indent+" ")
m.wildcard.write(b, indent+" ")
}
children := make([]key.Key, 0, len(n.children))
for key := range n.children {
children := make([]key.Key, 0, len(m.children))
for key := range m.children {
children = append(children, key)
}
sort.Slice(children, func(i, j int) bool {
@ -261,7 +328,7 @@ func (n *node) write(b *bytes.Buffer, indent string) {
})
for _, key := range children {
child := n.children[key]
child := m.children[key]
b.WriteString(indent)
fmt.Fprintf(b, "Child %q:\n", key.String())
child.write(b, indent+" ")

View File

@ -10,68 +10,76 @@ import (
"testing"
"github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/pathmap"
"github.com/aristanetworks/goarista/test"
)
func accumulator(counter map[int]int) pathmap.VisitorFunc {
func accumulator(counter map[int]int) VisitorFunc {
return func(val interface{}) error {
counter[val.(int)]++
return nil
}
}
func TestVisit(t *testing.T) {
m := NewMap()
m.Set(Path{key.New("foo"), key.New("bar"), key.New("baz")}, 1)
m.Set(Path{key.New("*"), key.New("bar"), key.New("baz")}, 2)
m.Set(Path{key.New("*"), key.New("*"), key.New("baz")}, 3)
m.Set(Path{key.New("*"), key.New("*"), key.New("*")}, 4)
m.Set(Path{key.New("foo"), key.New("*"), key.New("*")}, 5)
m.Set(Path{key.New("foo"), key.New("bar"), key.New("*")}, 6)
m.Set(Path{key.New("foo"), key.New("*"), key.New("baz")}, 7)
m.Set(Path{key.New("*"), key.New("bar"), key.New("*")}, 8)
func TestMapSet(t *testing.T) {
m := Map{}
a := m.Set(key.Path{key.New("foo")}, 0)
b := m.Set(key.Path{key.New("foo")}, 1)
if !a || b {
t.Fatal("Map.Set not working properly")
}
}
m.Set(Path{}, 10)
func TestMapVisit(t *testing.T) {
m := Map{}
m.Set(key.Path{key.New("foo"), key.New("bar"), key.New("baz")}, 1)
m.Set(key.Path{Wildcard, key.New("bar"), key.New("baz")}, 2)
m.Set(key.Path{Wildcard, Wildcard, key.New("baz")}, 3)
m.Set(key.Path{Wildcard, Wildcard, Wildcard}, 4)
m.Set(key.Path{key.New("foo"), Wildcard, Wildcard}, 5)
m.Set(key.Path{key.New("foo"), key.New("bar"), Wildcard}, 6)
m.Set(key.Path{key.New("foo"), Wildcard, key.New("baz")}, 7)
m.Set(key.Path{Wildcard, key.New("bar"), Wildcard}, 8)
m.Set(Path{key.New("*")}, 20)
m.Set(Path{key.New("foo")}, 21)
m.Set(key.Path{}, 10)
m.Set(Path{key.New("zap"), key.New("zip")}, 30)
m.Set(Path{key.New("zap"), key.New("zip")}, 31)
m.Set(key.Path{Wildcard}, 20)
m.Set(key.Path{key.New("foo")}, 21)
m.Set(Path{key.New("zip"), key.New("*")}, 40)
m.Set(Path{key.New("zip"), key.New("*")}, 41)
m.Set(key.Path{key.New("zap"), key.New("zip")}, 30)
m.Set(key.Path{key.New("zap"), key.New("zip")}, 31)
m.Set(key.Path{key.New("zip"), Wildcard}, 40)
m.Set(key.Path{key.New("zip"), Wildcard}, 41)
testCases := []struct {
path Path
path key.Path
expected map[int]int
}{{
path: Path{key.New("foo"), key.New("bar"), key.New("baz")},
path: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
expected: map[int]int{1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1},
}, {
path: Path{key.New("qux"), key.New("bar"), key.New("baz")},
path: key.Path{key.New("qux"), key.New("bar"), key.New("baz")},
expected: map[int]int{2: 1, 3: 1, 4: 1, 8: 1},
}, {
path: Path{key.New("foo"), key.New("qux"), key.New("baz")},
path: key.Path{key.New("foo"), key.New("qux"), key.New("baz")},
expected: map[int]int{3: 1, 4: 1, 5: 1, 7: 1},
}, {
path: Path{key.New("foo"), key.New("bar"), key.New("qux")},
path: key.Path{key.New("foo"), key.New("bar"), key.New("qux")},
expected: map[int]int{4: 1, 5: 1, 6: 1, 8: 1},
}, {
path: Path{},
path: key.Path{},
expected: map[int]int{10: 1},
}, {
path: Path{key.New("foo")},
path: key.Path{key.New("foo")},
expected: map[int]int{20: 1, 21: 1},
}, {
path: Path{key.New("foo"), key.New("bar")},
path: key.Path{key.New("foo"), key.New("bar")},
expected: map[int]int{},
}, {
path: Path{key.New("zap"), key.New("zip")},
path: key.Path{key.New("zap"), key.New("zip")},
expected: map[int]int{31: 1},
}, {
path: Path{key.New("zip"), key.New("zap")},
path: key.Path{key.New("zip"), key.New("zap")},
expected: map[int]int{41: 1},
}}
@ -84,135 +92,160 @@ func TestVisit(t *testing.T) {
}
}
func TestVisitError(t *testing.T) {
m := NewMap()
m.Set(Path{key.New("foo"), key.New("bar")}, 1)
m.Set(Path{key.New("*"), key.New("bar")}, 2)
func TestMapVisitError(t *testing.T) {
m := Map{}
m.Set(key.Path{key.New("foo"), key.New("bar")}, 1)
m.Set(key.Path{Wildcard, key.New("bar")}, 2)
errTest := errors.New("Test")
err := m.Visit(Path{key.New("foo"), key.New("bar")},
err := m.Visit(key.Path{key.New("foo"), key.New("bar")},
func(v interface{}) error { return errTest })
if err != errTest {
t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err)
}
err = m.VisitPrefix(Path{key.New("foo"), key.New("bar"), key.New("baz")},
err = m.VisitPrefixes(key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
func(v interface{}) error { return errTest })
if err != errTest {
t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err)
}
}
func TestGet(t *testing.T) {
m := NewMap()
m.Set(Path{}, 0)
m.Set(Path{key.New("foo"), key.New("bar")}, 1)
m.Set(Path{key.New("foo"), key.New("*")}, 2)
m.Set(Path{key.New("*"), key.New("bar")}, 3)
m.Set(Path{key.New("zap"), key.New("zip")}, 4)
func TestMapGet(t *testing.T) {
m := Map{}
m.Set(key.Path{}, 0)
m.Set(key.Path{key.New("foo"), key.New("bar")}, 1)
m.Set(key.Path{key.New("foo"), Wildcard}, 2)
m.Set(key.Path{Wildcard, key.New("bar")}, 3)
m.Set(key.Path{key.New("zap"), key.New("zip")}, 4)
m.Set(key.Path{key.New("baz"), key.New("qux")}, nil)
testCases := []struct {
path Path
expected interface{}
path key.Path
v interface{}
ok bool
}{{
path: Path{},
expected: 0,
path: key.Path{},
v: 0,
ok: true,
}, {
path: Path{key.New("foo"), key.New("bar")},
expected: 1,
path: key.Path{key.New("foo"), key.New("bar")},
v: 1,
ok: true,
}, {
path: Path{key.New("foo"), key.New("*")},
expected: 2,
path: key.Path{key.New("foo"), Wildcard},
v: 2,
ok: true,
}, {
path: Path{key.New("*"), key.New("bar")},
expected: 3,
path: key.Path{Wildcard, key.New("bar")},
v: 3,
ok: true,
}, {
path: Path{key.New("bar"), key.New("foo")},
expected: nil,
path: key.Path{key.New("baz"), key.New("qux")},
v: nil,
ok: true,
}, {
path: Path{key.New("zap"), key.New("*")},
expected: nil,
path: key.Path{key.New("bar"), key.New("foo")},
v: nil,
}, {
path: key.Path{key.New("zap"), Wildcard},
v: nil,
}}
for _, tc := range testCases {
got := m.Get(tc.path)
if got != tc.expected {
t.Errorf("Test case %v: Expected %v, Got %v",
tc.path, tc.expected, got)
v, ok := m.Get(tc.path)
if v != tc.v || ok != tc.ok {
t.Errorf("Test case %v: Expected (v: %v, ok: %t), Got (v: %v, ok: %t)",
tc.path, tc.v, tc.ok, v, ok)
}
}
}
func countNodes(n *node) int {
if n == nil {
func countNodes(m *Map) int {
if m == nil {
return 0
}
count := 1
count += countNodes(n.wildcard)
for _, child := range n.children {
count += countNodes(m.wildcard)
for _, child := range m.children {
count += countNodes(child)
}
return count
}
func TestDelete(t *testing.T) {
m := NewMap()
m.Set(Path{}, 0)
m.Set(Path{key.New("*")}, 1)
m.Set(Path{key.New("foo"), key.New("bar")}, 2)
m.Set(Path{key.New("foo"), key.New("*")}, 3)
func TestMapDelete(t *testing.T) {
m := Map{}
m.Set(key.Path{}, 0)
m.Set(key.Path{Wildcard}, 1)
m.Set(key.Path{key.New("foo"), key.New("bar")}, 2)
m.Set(key.Path{key.New("foo"), Wildcard}, 3)
m.Set(key.Path{key.New("foo")}, 4)
n := countNodes(m.(*node))
n := countNodes(&m)
if n != 5 {
t.Errorf("Initial count wrong. Expected: 5, Got: %d", n)
}
testCases := []struct {
del Path // Path to delete
del key.Path // key.Path to delete
expected bool // expected return value of Delete
visit Path // Path to Visit
visit key.Path // key.Path to Visit
before map[int]int // Expected to find items before deletion
after map[int]int // Expected to find items after deletion
count int // Count of nodes
}{{
del: Path{key.New("zap")}, // A no-op Delete
del: key.Path{key.New("zap")}, // A no-op Delete
expected: false,
visit: Path{key.New("foo"), key.New("bar")},
visit: key.Path{key.New("foo"), key.New("bar")},
before: map[int]int{2: 1, 3: 1},
after: map[int]int{2: 1, 3: 1},
count: 5,
}, {
del: Path{key.New("foo"), key.New("bar")},
del: key.Path{key.New("foo"), key.New("bar")},
expected: true,
visit: Path{key.New("foo"), key.New("bar")},
visit: key.Path{key.New("foo"), key.New("bar")},
before: map[int]int{2: 1, 3: 1},
after: map[int]int{3: 1},
count: 4,
}, {
del: Path{key.New("*")},
del: key.Path{key.New("foo")},
expected: true,
visit: Path{key.New("foo")},
visit: key.Path{key.New("foo")},
before: map[int]int{1: 1, 4: 1},
after: map[int]int{1: 1},
count: 4,
}, {
del: key.Path{key.New("foo")},
expected: false,
visit: key.Path{key.New("foo")},
before: map[int]int{1: 1},
after: map[int]int{1: 1},
count: 4,
}, {
del: key.Path{Wildcard},
expected: true,
visit: key.Path{key.New("foo")},
before: map[int]int{1: 1},
after: map[int]int{},
count: 3,
}, {
del: Path{key.New("*")},
del: key.Path{Wildcard},
expected: false,
visit: Path{key.New("foo")},
visit: key.Path{key.New("foo")},
before: map[int]int{},
after: map[int]int{},
count: 3,
}, {
del: Path{key.New("foo"), key.New("*")},
del: key.Path{key.New("foo"), Wildcard},
expected: true,
visit: Path{key.New("foo"), key.New("bar")},
visit: key.Path{key.New("foo"), key.New("bar")},
before: map[int]int{3: 1},
after: map[int]int{},
count: 1, // Should have deleted "foo" and "bar" nodes
}, {
del: Path{},
del: key.Path{},
expected: true,
visit: Path{},
visit: key.Path{},
before: map[int]int{0: 1},
after: map[int]int{},
count: 1, // Root node can't be deleted
@ -238,53 +271,102 @@ func TestDelete(t *testing.T) {
}
}
func TestVisitPrefix(t *testing.T) {
m := NewMap()
m.Set(Path{}, 0)
m.Set(Path{key.New("foo")}, 1)
m.Set(Path{key.New("foo"), key.New("bar")}, 2)
m.Set(Path{key.New("foo"), key.New("bar"), key.New("baz")}, 3)
m.Set(Path{key.New("foo"), key.New("bar"), key.New("baz"), key.New("quux")}, 4)
m.Set(Path{key.New("quux"), key.New("bar")}, 5)
m.Set(Path{key.New("foo"), key.New("quux")}, 6)
m.Set(Path{key.New("*")}, 7)
m.Set(Path{key.New("foo"), key.New("*")}, 8)
m.Set(Path{key.New("*"), key.New("bar")}, 9)
m.Set(Path{key.New("*"), key.New("quux")}, 10)
m.Set(Path{key.New("quux"), key.New("quux"), key.New("quux"), key.New("quux")}, 11)
func TestMapVisitPrefixes(t *testing.T) {
m := Map{}
m.Set(key.Path{}, 0)
m.Set(key.Path{key.New("foo")}, 1)
m.Set(key.Path{key.New("foo"), key.New("bar")}, 2)
m.Set(key.Path{key.New("foo"), key.New("bar"), key.New("baz")}, 3)
m.Set(key.Path{key.New("foo"), key.New("bar"), key.New("baz"), key.New("quux")}, 4)
m.Set(key.Path{key.New("quux"), key.New("bar")}, 5)
m.Set(key.Path{key.New("foo"), key.New("quux")}, 6)
m.Set(key.Path{Wildcard}, 7)
m.Set(key.Path{key.New("foo"), Wildcard}, 8)
m.Set(key.Path{Wildcard, key.New("bar")}, 9)
m.Set(key.Path{Wildcard, key.New("quux")}, 10)
m.Set(key.Path{key.New("quux"), key.New("quux"), key.New("quux"), key.New("quux")}, 11)
testCases := []struct {
path Path
path key.Path
expected map[int]int
}{{
path: Path{key.New("foo"), key.New("bar"), key.New("baz")},
path: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
expected: map[int]int{0: 1, 1: 1, 2: 1, 3: 1, 7: 1, 8: 1, 9: 1},
}, {
path: Path{key.New("zip"), key.New("zap")},
path: key.Path{key.New("zip"), key.New("zap")},
expected: map[int]int{0: 1, 7: 1},
}, {
path: Path{key.New("foo"), key.New("zap")},
path: key.Path{key.New("foo"), key.New("zap")},
expected: map[int]int{0: 1, 1: 1, 8: 1, 7: 1},
}, {
path: Path{key.New("quux"), key.New("quux"), key.New("quux")},
path: key.Path{key.New("quux"), key.New("quux"), key.New("quux")},
expected: map[int]int{0: 1, 7: 1, 10: 1},
}}
for _, tc := range testCases {
result := make(map[int]int, len(tc.expected))
m.VisitPrefix(tc.path, accumulator(result))
m.VisitPrefixes(tc.path, accumulator(result))
if diff := test.Diff(tc.expected, result); diff != "" {
t.Errorf("Test case %v: %s", tc.path, diff)
}
}
}
func TestString(t *testing.T) {
m := NewMap()
m.Set(Path{}, 0)
m.Set(Path{key.New("foo"), key.New("bar")}, 1)
m.Set(Path{key.New("foo"), key.New("quux")}, 2)
m.Set(Path{key.New("foo"), key.New("*")}, 3)
func TestMapVisitPrefixed(t *testing.T) {
m := Map{}
m.Set(key.Path{}, 0)
m.Set(key.Path{key.New("qux")}, 1)
m.Set(key.Path{key.New("foo")}, 2)
m.Set(key.Path{key.New("foo"), key.New("qux")}, 3)
m.Set(key.Path{key.New("foo"), key.New("bar")}, 4)
m.Set(key.Path{Wildcard, key.New("bar")}, 5)
m.Set(key.Path{key.New("foo"), Wildcard}, 6)
m.Set(key.Path{key.New("qux"), key.New("foo"), key.New("bar")}, 7)
testCases := []struct {
in key.Path
out map[int]int
}{{
in: key.Path{},
out: map[int]int{0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1},
}, {
in: key.Path{key.New("qux")},
out: map[int]int{1: 1, 5: 1, 7: 1},
}, {
in: key.Path{key.New("foo")},
out: map[int]int{2: 1, 3: 1, 4: 1, 5: 1, 6: 1},
}, {
in: key.Path{key.New("foo"), key.New("qux")},
out: map[int]int{3: 1, 6: 1},
}, {
in: key.Path{key.New("foo"), key.New("bar")},
out: map[int]int{4: 1, 5: 1, 6: 1},
}, {
in: key.Path{key.New(int64(0))},
out: map[int]int{5: 1},
}, {
in: key.Path{Wildcard},
out: map[int]int{5: 1},
}, {
in: key.Path{Wildcard, Wildcard},
out: map[int]int{},
}}
for _, tc := range testCases {
out := make(map[int]int, len(tc.out))
m.VisitPrefixed(tc.in, accumulator(out))
if diff := test.Diff(tc.out, out); diff != "" {
t.Errorf("Test case %v: %s", tc.out, diff)
}
}
}
func TestMapString(t *testing.T) {
m := Map{}
m.Set(key.Path{}, 0)
m.Set(key.Path{key.New("foo"), key.New("bar")}, 1)
m.Set(key.Path{key.New("foo"), key.New("quux")}, 2)
m.Set(key.Path{key.New("foo"), Wildcard}, 3)
expected := `Val: 0
Child "foo":
@ -295,19 +377,19 @@ Child "foo":
Child "quux":
Val: 2
`
got := fmt.Sprint(m)
got := fmt.Sprint(&m)
if expected != got {
t.Errorf("Unexpected string. Expected:\n\n%s\n\nGot:\n\n%s", expected, got)
}
}
func genWords(count, wordLength int) Path {
func genWords(count, wordLength int) key.Path {
chars := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
if count+wordLength > len(chars) {
panic("need more chars")
}
result := make(Path, count)
result := make(key.Path, count)
for i := 0; i < count; i++ {
result[i] = key.New(string(chars[i : i+wordLength]))
}
@ -315,18 +397,16 @@ func genWords(count, wordLength int) Path {
}
func benchmarkPathMap(pathLength, pathDepth int, b *testing.B) {
m := NewMap()
// Push pathDepth paths, each of length pathLength
path := genWords(pathLength, 10)
words := genWords(pathDepth, 10)
n := m.(*node)
m := &Map{}
for _, element := range path {
n.children = map[key.Key]*node{}
m.children = map[key.Key]*Map{}
for _, word := range words {
n.children[word] = &node{}
m.children[word] = &Map{}
}
n = n.children[element]
m = m.children[element]
}
b.ResetTimer()
for i := 0; i < b.N; i++ {

View File

@ -2,109 +2,167 @@
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
// Package path provides functionality for dealing with absolute paths elementally.
// Package path contains methods for dealing with key.Paths.
package path
import (
"bytes"
"fmt"
"strings"
"github.com/aristanetworks/goarista/key"
)
// Path is an absolute path broken down into elements where each element is a key.Key.
type Path []key.Key
func copyElements(path Path, elements ...interface{}) {
for i, element := range elements {
switch val := element.(type) {
case string:
path[i] = key.New(val)
case key.Key:
path[i] = val
default:
panic(fmt.Errorf("unsupported type: %T", element))
}
}
// New constructs a path from a variable number of elements.
// Each element may either be a key.Key or a value that can
// be wrapped by a key.Key.
func New(elements ...interface{}) key.Path {
result := make(key.Path, len(elements))
copyElements(result, elements...)
return result
}
// New constructs a Path from a variable number of elements.
// Each element may either be a string or a key.Key.
func New(elements ...interface{}) Path {
path := make(Path, len(elements))
copyElements(path, elements...)
return path
}
// FromString constructs a Path from the elements resulting
// from a split of the input string by "/". The string MUST
// begin with a '/' character unless it is the empty string
// in which case an empty Path is returned.
func FromString(str string) Path {
if str == "" {
return Path{}
} else if str[0] != '/' {
panic(fmt.Errorf("not an absolute path: %q", str))
}
elements := strings.Split(str, "/")[1:]
path := make(Path, len(elements))
for i, element := range elements {
path[i] = key.New(element)
}
return path
}
// Append appends a variable number of elements to a Path.
// Each element may either be a string or a key.Key.
func Append(path Path, elements ...interface{}) Path {
// Append appends a variable number of elements to a path.
// Each element may either be a key.Key or a value that can
// be wrapped by a key.Key. Note that calling Append on a
// single path returns that same path, whereas in all other
// cases a new path is returned.
func Append(path key.Path, elements ...interface{}) key.Path {
if len(elements) == 0 {
return path
}
n := len(path)
p := make(Path, n+len(elements))
copy(p, path)
copyElements(p[n:], elements...)
return p
result := make(key.Path, n+len(elements))
copy(result, path)
copyElements(result[n:], elements...)
return result
}
// String returns the Path as a string.
func (p Path) String() string {
if len(p) == 0 {
return ""
// Join joins a variable number of paths together. Each path
// in the joining is treated as a subpath of its predecessor.
// Calling Join with no or only empty paths returns nil.
func Join(paths ...key.Path) key.Path {
n := 0
for _, path := range paths {
n += len(path)
}
var buf bytes.Buffer
for _, element := range p {
buf.WriteByte('/')
buf.WriteString(element.String())
if n == 0 {
return nil
}
return buf.String()
result, i := make(key.Path, n), 0
for _, path := range paths {
i += copy(result[i:], path)
}
return result
}
// Equal returns whether the Path contains the same elements as the other Path.
// This method implements key.Comparable.
func (p Path) Equal(other interface{}) bool {
o, ok := other.(Path)
if !ok {
return false
// Parent returns all but the last element of the path. If
// the path is empty, Parent returns nil.
func Parent(path key.Path) key.Path {
	if len(path) > 0 {
		// Note: this is a re-slice, so the result shares its
		// backing array with the input path.
		return path[:len(path)-1]
	}
	return nil
}
// HasPrefix returns whether the Path is prefixed by the other Path.
func (p Path) HasPrefix(prefix Path) bool {
if len(prefix) > len(p) {
return false
// Base returns the last element of the path. If the path is
// empty, Base returns nil.
func Base(path key.Path) key.Key {
	if len(path) > 0 {
		return path[len(path)-1]
	}
	return nil
}
func (p Path) hasPrefix(prefix Path) bool {
for i := range prefix {
if !prefix[i].Equal(p[i]) {
// Clone returns a new path holding a copy of the elements of
// the provided path; mutating one does not affect the other.
func Clone(path key.Path) key.Path {
	return append(make(key.Path, 0, len(path)), path...)
}
// Equal returns whether path a and path b are the same
// length and whether each element in b corresponds to the
// same element in a.
func Equal(a, b key.Path) bool {
	if len(a) != len(b) {
		return false
	}
	return hasPrefix(a, b)
}
// HasElement returns whether element b exists in path a.
func HasElement(a key.Path, b key.Key) bool {
	for i := range a {
		if a[i].Equal(b) {
			return true
		}
	}
	return false
}
// HasPrefix returns whether path b is a prefix of path a.
// It checks that b is at most the length of path a and
// whether each element in b corresponds to the same element
// in a from the first element.
func HasPrefix(a, b key.Path) bool {
	if len(a) < len(b) {
		return false
	}
	return hasPrefix(a, b)
}
// Match returns whether path a and path b are the same
// length and whether each element in b corresponds to the
// same element or a wildcard in a.
func Match(a, b key.Path) bool {
	if len(a) != len(b) {
		return false
	}
	return matchPrefix(a, b)
}
// MatchPrefix returns whether path b is a prefix of path a
// where path a may contain wildcards.
// It checks that b is at most the length of path a and
// whether each element in b corresponds to the same element
// or a wildcard in a from the first element.
func MatchPrefix(a, b key.Path) bool {
	if len(a) < len(b) {
		return false
	}
	return matchPrefix(a, b)
}
// FromString constructs a path from the elements resulting
// from a split of the input string by "/". Strings that do
// not lead with a '/' are accepted but not reconstructable
// with key.Path.String. Both "" and "/" are treated as a
// key.Path{}.
func FromString(str string) key.Path {
	switch str {
	case "", "/":
		return key.Path{}
	}
	// Strip at most one leading slash before splitting.
	elements := strings.Split(strings.TrimPrefix(str, "/"), "/")
	result := make(key.Path, len(elements))
	for i, element := range elements {
		result[i] = key.New(element)
	}
	return result
}
// copyElements fills dest with the given elements, wrapping
// any element that is not already a key.Key with key.New.
// It assumes len(dest) >= len(elements).
func copyElements(dest key.Path, elements ...interface{}) {
	for i, element := range elements {
		if k, ok := element.(key.Key); ok {
			dest[i] = k
		} else {
			dest[i] = key.New(element)
		}
	}
}
// hasPrefix reports whether the first len(b) elements of a
// equal the elements of b. Callers guarantee len(a) >= len(b).
func hasPrefix(a, b key.Path) bool {
	for i, element := range b {
		if !element.Equal(a[i]) {
			return false
		}
	}
	return true
}
func matchPrefix(a, b key.Path) bool {
for i := range b {
if !a[i].Equal(Wildcard) && !b[i].Equal(a[i]) {
return false
}
}

View File

@ -12,79 +12,179 @@ import (
"github.com/aristanetworks/goarista/value"
)
func TestNewPath(t *testing.T) {
func TestNew(t *testing.T) {
tcases := []struct {
in []interface{}
out Path
out key.Path
}{
{
in: nil,
out: nil,
out: key.Path{},
}, {
in: []interface{}{},
out: Path{},
out: key.Path{},
}, {
in: []interface{}{""},
out: Path{key.New("")},
in: []interface{}{"foo", key.New("bar"), true},
out: key.Path{key.New("foo"), key.New("bar"), key.New(true)},
}, {
in: []interface{}{key.New("")},
out: Path{key.New("")},
in: []interface{}{int8(5), int16(5), int32(5), int64(5)},
out: key.Path{key.New(int8(5)), key.New(int16(5)), key.New(int32(5)),
key.New(int64(5))},
}, {
in: []interface{}{"foo"},
out: Path{key.New("foo")},
in: []interface{}{uint8(5), uint16(5), uint32(5), uint64(5)},
out: key.Path{key.New(uint8(5)), key.New(uint16(5)), key.New(uint32(5)),
key.New(uint64(5))},
}, {
in: []interface{}{key.New("foo")},
out: Path{key.New("foo")},
in: []interface{}{float32(5), float64(5)},
out: key.Path{key.New(float32(5)), key.New(float64(5))},
}, {
in: []interface{}{"foo", key.New("bar")},
out: Path{key.New("foo"), key.New("bar")},
}, {
in: []interface{}{key.New("foo"), "bar", key.New("baz")},
out: Path{key.New("foo"), key.New("bar"), key.New("baz")},
in: []interface{}{customKey{i: &a}, map[string]interface{}{}},
out: key.Path{key.New(customKey{i: &a}), key.New(map[string]interface{}{})},
},
}
for i, tcase := range tcases {
if p := New(tcase.in...); !p.Equal(tcase.out) {
if p := New(tcase.in...); !Equal(p, tcase.out) {
t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.out)
}
}
}
func TestAppendPath(t *testing.T) {
func TestClone(t *testing.T) {
if !Equal(Clone(key.Path{}), key.Path{}) {
t.Error("Clone(key.Path{}) != key.Path{}")
}
a := key.Path{key.New("foo"), key.New("bar")}
b, c := Clone(a), Clone(a)
b[1] = key.New("baz")
if Equal(a, b) || !Equal(a, c) {
t.Error("Clone is not making a copied path")
}
}
func TestAppend(t *testing.T) {
tcases := []struct {
base Path
elements []interface{}
expected Path
a key.Path
b []interface{}
result key.Path
}{
{
base: Path{},
elements: []interface{}{},
expected: Path{},
a: key.Path{},
b: []interface{}{},
result: key.Path{},
}, {
base: Path{},
elements: []interface{}{""},
expected: Path{key.New("")},
a: key.Path{key.New("foo")},
b: []interface{}{},
result: key.Path{key.New("foo")},
}, {
base: Path{},
elements: []interface{}{key.New("")},
expected: Path{key.New("")},
a: key.Path{},
b: []interface{}{"foo", key.New("bar")},
result: key.Path{key.New("foo"), key.New("bar")},
}, {
base: Path{},
elements: []interface{}{"foo", key.New("bar")},
expected: Path{key.New("foo"), key.New("bar")},
}, {
base: Path{key.New("foo")},
elements: []interface{}{key.New("bar"), "baz"},
expected: Path{key.New("foo"), key.New("bar"), key.New("baz")},
}, {
base: Path{key.New("foo"), key.New("bar")},
elements: []interface{}{key.New("baz")},
expected: Path{key.New("foo"), key.New("bar"), key.New("baz")},
a: key.Path{key.New("foo")},
b: []interface{}{int64(0), key.New("bar")},
result: key.Path{key.New("foo"), key.New(int64(0)), key.New("bar")},
},
}
for i, tcase := range tcases {
if p := Append(tcase.base, tcase.elements...); !p.Equal(tcase.expected) {
t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.expected)
if p := Append(tcase.a, tcase.b...); !Equal(p, tcase.result) {
t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.result)
}
}
}
func TestJoin(t *testing.T) {
tcases := []struct {
paths []key.Path
result key.Path
}{
{
paths: nil,
result: nil,
}, {
paths: []key.Path{},
result: nil,
}, {
paths: []key.Path{key.Path{}},
result: nil,
}, {
paths: []key.Path{key.Path{key.New(true)}, key.Path{}},
result: key.Path{key.New(true)},
}, {
paths: []key.Path{key.Path{}, key.Path{key.New(true)}},
result: key.Path{key.New(true)},
}, {
paths: []key.Path{key.Path{key.New("foo")}, key.Path{key.New("bar")}},
result: key.Path{key.New("foo"), key.New("bar")},
}, {
paths: []key.Path{key.Path{key.New("bar")}, key.Path{key.New("foo")}},
result: key.Path{key.New("bar"), key.New("foo")},
}, {
paths: []key.Path{
key.Path{key.New(uint32(0)), key.New(uint64(0))},
key.Path{key.New(int8(0))},
key.Path{key.New(int16(0)), key.New(int32(0))},
key.Path{key.New(int64(0)), key.New(uint8(0)), key.New(uint16(0))},
},
result: key.Path{
key.New(uint32(0)), key.New(uint64(0)),
key.New(int8(0)), key.New(int16(0)),
key.New(int32(0)), key.New(int64(0)),
key.New(uint8(0)), key.New(uint16(0)),
},
},
}
for i, tcase := range tcases {
if p := Join(tcase.paths...); !Equal(p, tcase.result) {
t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.result)
}
}
}
func TestParent(t *testing.T) {
if Parent(key.Path{}) != nil {
t.Fatal("Parent of empty key.Path should be nil")
}
tcases := []struct {
in key.Path
out key.Path
}{
{
in: key.Path{key.New("foo")},
out: key.Path{},
}, {
in: key.Path{key.New("foo"), key.New("bar")},
out: key.Path{key.New("foo")},
}, {
in: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
out: key.Path{key.New("foo"), key.New("bar")},
},
}
for _, tcase := range tcases {
if !Equal(Parent(tcase.in), tcase.out) {
t.Fatalf("Parent of %#v != %#v", tcase.in, tcase.out)
}
}
}
func TestBase(t *testing.T) {
if Base(key.Path{}) != nil {
t.Fatal("Base of empty key.Path should be nil")
}
tcases := []struct {
in key.Path
out key.Key
}{
{
in: key.Path{key.New("foo")},
out: key.New("foo"),
}, {
in: key.Path{key.New("foo"), key.New("bar")},
out: key.New("bar"),
},
}
for _, tcase := range tcases {
if !Base(tcase.in).Equal(tcase.out) {
t.Fatalf("Base of %#v != %#v", tcase.in, tcase.out)
}
}
}
@ -117,176 +217,407 @@ var (
b = 1
)
func TestPathEquality(t *testing.T) {
func TestEqual(t *testing.T) {
tcases := []struct {
base Path
other Path
expected bool
a key.Path
b key.Path
result bool
}{
{
base: Path{},
other: Path{},
expected: true,
a: nil,
b: nil,
result: true,
}, {
base: Path{},
other: Path{key.New("")},
expected: false,
a: nil,
b: key.Path{},
result: true,
}, {
base: Path{key.New("foo")},
other: Path{key.New("foo")},
expected: true,
a: key.Path{},
b: nil,
result: true,
}, {
base: Path{key.New("foo")},
other: Path{key.New("bar")},
expected: false,
a: key.Path{},
b: key.Path{},
result: true,
}, {
base: Path{key.New("foo"), key.New("bar")},
other: Path{key.New("foo")},
expected: false,
a: key.Path{},
b: key.Path{key.New("")},
result: false,
}, {
base: Path{key.New("foo"), key.New("bar")},
other: Path{key.New("bar"), key.New("foo")},
expected: false,
a: key.Path{Wildcard},
b: key.Path{key.New("foo")},
result: false,
}, {
base: Path{key.New("foo"), key.New("bar"), key.New("baz")},
other: Path{key.New("foo"), key.New("bar"), key.New("baz")},
expected: true,
a: key.Path{Wildcard},
b: key.Path{Wildcard},
result: true,
}, {
a: key.Path{key.New("foo")},
b: key.Path{key.New("foo")},
result: true,
}, {
a: key.Path{key.New(true)},
b: key.Path{key.New(false)},
result: false,
}, {
a: key.Path{key.New(int32(5))},
b: key.Path{key.New(int64(5))},
result: false,
}, {
a: key.Path{key.New("foo")},
b: key.Path{key.New("foo"), key.New("bar")},
result: false,
}, {
a: key.Path{key.New("foo"), key.New("bar")},
b: key.Path{key.New("foo")},
result: false,
}, {
a: key.Path{key.New(uint8(0)), key.New(int8(0))},
b: key.Path{key.New(int8(0)), key.New(uint8(0))},
result: false,
},
// Ensure that we check deep equality.
{
base: Path{key.New(map[string]interface{}{})},
other: Path{key.New(map[string]interface{}{})},
expected: true,
a: key.Path{key.New(map[string]interface{}{})},
b: key.Path{key.New(map[string]interface{}{})},
result: true,
}, {
base: Path{key.New(customKey{i: &a})},
other: Path{key.New(customKey{i: &b})},
expected: true,
a: key.Path{key.New(customKey{i: &a})},
b: key.Path{key.New(customKey{i: &b})},
result: true,
},
}
for i, tcase := range tcases {
if result := tcase.base.Equal(tcase.other); result != tcase.expected {
t.Fatalf("Test %d failed: base: %#v; other: %#v, expected: %t",
i, tcase.base, tcase.other, tcase.expected)
if result := Equal(tcase.a, tcase.b); result != tcase.result {
t.Fatalf("Test %d failed: a: %#v; b: %#v, result: %t",
i, tcase.a, tcase.b, tcase.result)
}
}
}
func TestPathHasPrefix(t *testing.T) {
func TestMatch(t *testing.T) {
tcases := []struct {
base Path
prefix Path
expected bool
a key.Path
b key.Path
result bool
}{
{
base: Path{},
prefix: Path{},
expected: true,
a: nil,
b: nil,
result: true,
}, {
base: Path{key.New("foo")},
prefix: Path{},
expected: true,
a: nil,
b: key.Path{},
result: true,
}, {
base: Path{key.New("foo"), key.New("bar")},
prefix: Path{key.New("foo")},
expected: true,
a: key.Path{},
b: nil,
result: true,
}, {
base: Path{key.New("foo"), key.New("bar")},
prefix: Path{key.New("bar")},
expected: false,
a: key.Path{},
b: key.Path{},
result: true,
}, {
base: Path{key.New("foo"), key.New("bar")},
prefix: Path{key.New("bar"), key.New("foo")},
expected: false,
a: key.Path{},
b: key.Path{key.New("foo")},
result: false,
}, {
base: Path{key.New("foo"), key.New("bar")},
prefix: Path{key.New("foo"), key.New("bar")},
expected: true,
a: key.Path{Wildcard},
b: key.Path{key.New("foo")},
result: true,
}, {
base: Path{key.New("foo"), key.New("bar")},
prefix: Path{key.New("foo"), key.New("bar"), key.New("baz")},
expected: false,
a: key.Path{key.New("foo")},
b: key.Path{Wildcard},
result: false,
}, {
a: key.Path{Wildcard},
b: key.Path{key.New("foo"), key.New("bar")},
result: false,
}, {
a: key.Path{Wildcard, Wildcard},
b: key.Path{key.New(int64(0))},
result: false,
}, {
a: key.Path{Wildcard, Wildcard},
b: key.Path{key.New(int64(0)), key.New(int32(0))},
result: true,
}, {
a: key.Path{Wildcard, key.New(false)},
b: key.Path{key.New(true), Wildcard},
result: false,
},
}
for i, tcase := range tcases {
if result := tcase.base.HasPrefix(tcase.prefix); result != tcase.expected {
t.Fatalf("Test %d failed: base: %#v; prefix: %#v, expected: %t",
i, tcase.base, tcase.prefix, tcase.expected)
if result := Match(tcase.a, tcase.b); result != tcase.result {
t.Fatalf("Test %d failed: a: %#v; b: %#v, result: %t",
i, tcase.a, tcase.b, tcase.result)
}
}
}
func TestPathFromString(t *testing.T) {
func TestHasElement(t *testing.T) {
tcases := []struct {
a key.Path
b key.Key
result bool
}{
{
a: nil,
b: nil,
result: false,
}, {
a: nil,
b: key.New("foo"),
result: false,
}, {
a: key.Path{},
b: nil,
result: false,
}, {
a: key.Path{key.New("foo")},
b: nil,
result: false,
}, {
a: key.Path{key.New("foo")},
b: key.New("foo"),
result: true,
}, {
a: key.Path{key.New(true)},
b: key.New("true"),
result: false,
}, {
a: key.Path{key.New("foo"), key.New("bar")},
b: key.New("bar"),
result: true,
}, {
a: key.Path{key.New(map[string]interface{}{})},
b: key.New(map[string]interface{}{}),
result: true,
}, {
a: key.Path{key.New(map[string]interface{}{"foo": "a"})},
b: key.New(map[string]interface{}{"bar": "a"}),
result: false,
},
}
for i, tcase := range tcases {
if result := HasElement(tcase.a, tcase.b); result != tcase.result {
t.Errorf("Test %d failed: a: %#v; b: %#v, result: %t, expected: %t",
i, tcase.a, tcase.b, result, tcase.result)
}
}
}
func TestHasPrefix(t *testing.T) {
tcases := []struct {
a key.Path
b key.Path
result bool
}{
{
a: nil,
b: nil,
result: true,
}, {
a: nil,
b: key.Path{},
result: true,
}, {
a: key.Path{},
b: nil,
result: true,
}, {
a: key.Path{},
b: key.Path{},
result: true,
}, {
a: key.Path{},
b: key.Path{key.New("foo")},
result: false,
}, {
a: key.Path{key.New("foo")},
b: key.Path{},
result: true,
}, {
a: key.Path{key.New(true)},
b: key.Path{key.New(false)},
result: false,
}, {
a: key.Path{key.New("foo"), key.New("bar")},
b: key.Path{key.New("bar"), key.New("foo")},
result: false,
}, {
a: key.Path{key.New(int8(0)), key.New(uint8(0))},
b: key.Path{key.New(uint8(0)), key.New(uint8(0))},
result: false,
}, {
a: key.Path{key.New(true), key.New(true)},
b: key.Path{key.New(true), key.New(true), key.New(true)},
result: false,
}, {
a: key.Path{key.New(true), key.New(true), key.New(true)},
b: key.Path{key.New(true), key.New(true)},
result: true,
}, {
a: key.Path{Wildcard, key.New(int32(0)), Wildcard},
b: key.Path{key.New(int64(0)), Wildcard},
result: false,
},
}
for i, tcase := range tcases {
if result := HasPrefix(tcase.a, tcase.b); result != tcase.result {
t.Fatalf("Test %d failed: a: %#v; b: %#v, result: %t",
i, tcase.a, tcase.b, tcase.result)
}
}
}
func TestMatchPrefix(t *testing.T) {
tcases := []struct {
a key.Path
b key.Path
result bool
}{
{
a: nil,
b: nil,
result: true,
}, {
a: nil,
b: key.Path{},
result: true,
}, {
a: key.Path{},
b: nil,
result: true,
}, {
a: key.Path{},
b: key.Path{},
result: true,
}, {
a: key.Path{},
b: key.Path{key.New("foo")},
result: false,
}, {
a: key.Path{key.New("foo")},
b: key.Path{},
result: true,
}, {
a: key.Path{key.New("foo")},
b: key.Path{Wildcard},
result: false,
}, {
a: key.Path{Wildcard},
b: key.Path{key.New("foo")},
result: true,
}, {
a: key.Path{Wildcard},
b: key.Path{key.New("foo"), key.New("bar")},
result: false,
}, {
a: key.Path{Wildcard, key.New(true)},
b: key.Path{key.New(false), Wildcard},
result: false,
}, {
a: key.Path{Wildcard, key.New(int32(0)), key.New(int16(0))},
b: key.Path{key.New(int64(0)), key.New(int32(0))},
result: true,
},
}
for i, tcase := range tcases {
if result := MatchPrefix(tcase.a, tcase.b); result != tcase.result {
t.Fatalf("Test %d failed: a: %#v; b: %#v, result: %t",
i, tcase.a, tcase.b, tcase.result)
}
}
}
func TestFromString(t *testing.T) {
tcases := []struct {
in string
out Path
out key.Path
}{
{
in: "",
out: Path{},
out: key.Path{},
}, {
in: "/",
out: Path{key.New("")},
out: key.Path{},
}, {
in: "//",
out: Path{key.New(""), key.New("")},
out: key.Path{key.New(""), key.New("")},
}, {
in: "foo",
out: key.Path{key.New("foo")},
}, {
in: "/foo",
out: Path{key.New("foo")},
out: key.Path{key.New("foo")},
}, {
in: "foo/bar",
out: key.Path{key.New("foo"), key.New("bar")},
}, {
in: "/foo/bar",
out: Path{key.New("foo"), key.New("bar")},
out: key.Path{key.New("foo"), key.New("bar")},
}, {
in: "foo/bar/baz",
out: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
}, {
in: "/foo/bar/baz",
out: Path{key.New("foo"), key.New("bar"), key.New("baz")},
out: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
}, {
in: "0/123/456/789",
out: key.Path{key.New("0"), key.New("123"), key.New("456"), key.New("789")},
}, {
in: "/0/123/456/789",
out: Path{key.New("0"), key.New("123"), key.New("456"), key.New("789")},
out: key.Path{key.New("0"), key.New("123"), key.New("456"), key.New("789")},
}, {
in: "`~!@#$%^&*()_+{}\\/|[];':\"<>?,./",
out: key.Path{key.New("`~!@#$%^&*()_+{}\\"), key.New("|[];':\"<>?,."), key.New("")},
}, {
in: "/`~!@#$%^&*()_+{}\\/|[];':\"<>?,./",
out: Path{key.New("`~!@#$%^&*()_+{}\\"), key.New("|[];':\"<>?,."), key.New("")},
out: key.Path{key.New("`~!@#$%^&*()_+{}\\"), key.New("|[];':\"<>?,."), key.New("")},
},
}
for i, tcase := range tcases {
if p := FromString(tcase.in); !p.Equal(tcase.out) {
if p := FromString(tcase.in); !Equal(p, tcase.out) {
t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.out)
}
}
}
func TestPathToString(t *testing.T) {
func TestString(t *testing.T) {
tcases := []struct {
in Path
in key.Path
out string
}{
{
in: Path{},
out: "",
}, {
in: Path{key.New("")},
in: key.Path{},
out: "/",
}, {
in: Path{key.New("foo")},
in: key.Path{key.New("")},
out: "/",
}, {
in: key.Path{key.New("foo")},
out: "/foo",
}, {
in: Path{key.New("foo"), key.New("bar")},
in: key.Path{key.New("foo"), key.New("bar")},
out: "/foo/bar",
}, {
in: Path{key.New("/foo"), key.New("bar")},
in: key.Path{key.New("/foo"), key.New("bar")},
out: "//foo/bar",
}, {
in: Path{key.New("foo"), key.New("bar/")},
in: key.Path{key.New("foo"), key.New("bar/")},
out: "/foo/bar/",
}, {
in: Path{key.New(""), key.New("foo"), key.New("bar")},
in: key.Path{key.New(""), key.New("foo"), key.New("bar")},
out: "//foo/bar",
}, {
in: Path{key.New("foo"), key.New("bar"), key.New("")},
in: key.Path{key.New("foo"), key.New("bar"), key.New("")},
out: "/foo/bar/",
}, {
in: Path{key.New("/"), key.New("foo"), key.New("bar")},
in: key.Path{key.New("/"), key.New("foo"), key.New("bar")},
out: "///foo/bar",
}, {
in: Path{key.New("foo"), key.New("bar"), key.New("/")},
in: key.Path{key.New("foo"), key.New("bar"), key.New("/")},
out: "/foo/bar//",
},
}
@ -296,3 +627,59 @@ func TestPathToString(t *testing.T) {
}
}
}
func BenchmarkJoin(b *testing.B) {
generate := func(n int) []key.Path {
paths := make([]key.Path, 0, n)
for i := 0; i < n; i++ {
paths = append(paths, key.Path{key.New("foo")})
}
return paths
}
benchmarks := map[string][]key.Path{
"10 key.Paths": generate(10),
"100 key.Paths": generate(100),
"1000 key.Paths": generate(1000),
"10000 key.Paths": generate(10000),
}
for name, benchmark := range benchmarks {
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Join(benchmark...)
}
})
}
}
func BenchmarkHasElement(b *testing.B) {
element := key.New("waldo")
generate := func(n, loc int) key.Path {
path := make(key.Path, n)
for i := 0; i < n; i++ {
if i == loc {
path[i] = element
} else {
path[i] = key.New(int8(0))
}
}
return path
}
benchmarks := map[string]key.Path{
"10 Elements Index 0": generate(10, 0),
"10 Elements Index 4": generate(10, 4),
"10 Elements Index 9": generate(10, 9),
"100 Elements Index 0": generate(100, 0),
"100 Elements Index 49": generate(100, 49),
"100 Elements Index 99": generate(100, 99),
"1000 Elements Index 0": generate(1000, 0),
"1000 Elements Index 499": generate(1000, 499),
"1000 Elements Index 999": generate(1000, 999),
}
for name, benchmark := range benchmarks {
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
HasElement(benchmark, element)
}
})
}
}

View File

@ -0,0 +1,36 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package path
import "github.com/aristanetworks/goarista/key"
// Wildcard is a special element in a path that is used by Map
// and the Match* functions to match any other element.
var Wildcard = key.New(WildcardType{})

// WildcardType is the type used to construct a Wildcard. It
// implements the value.Value interface so it can be used as
// a key.Key.
type WildcardType struct{}

// String implements fmt.Stringer, rendering the wildcard as "*".
func (w WildcardType) String() string {
	return "*"
}
// Equal implements the key.Comparable interface. Every
// WildcardType value compares equal to any other WildcardType
// value and to nothing else.
func (w WildcardType) Equal(other interface{}) bool {
	switch other.(type) {
	case WildcardType:
		return true
	default:
		return false
	}
}
// ToBuiltin implements the value.Value interface. WildcardType
// carries no state, so it returns a fresh zero value.
func (w WildcardType) ToBuiltin() interface{} {
	return WildcardType{}
}
// MarshalJSON implements the value.Value interface. The
// wildcard serializes to a fixed sentinel JSON object.
func (w WildcardType) MarshalJSON() ([]byte, error) {
	return []byte(`{"_wildcard":{}}`), nil
}

View File

@ -0,0 +1,79 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package path
import (
"encoding/json"
"testing"
"github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/value"
)
type pseudoWildcard struct{}
func (w pseudoWildcard) Key() interface{} {
return struct{}{}
}
func (w pseudoWildcard) String() string {
return "*"
}
func (w pseudoWildcard) Equal(other interface{}) bool {
o, ok := other.(pseudoWildcard)
return ok && w == o
}
func TestWildcardUniqueness(t *testing.T) {
if Wildcard.Equal(pseudoWildcard{}) {
t.Fatal("Wildcard is not unique")
}
if Wildcard.Equal(struct{}{}) {
t.Fatal("Wildcard is not unique")
}
if Wildcard.Equal(key.New("*")) {
t.Fatal("Wildcard is not unique")
}
}
func TestWildcardTypeIsNotAKey(t *testing.T) {
var intf interface{} = WildcardType{}
_, ok := intf.(key.Key)
if ok {
t.Error("WildcardType should not implement key.Key")
}
}
func TestWildcardTypeEqual(t *testing.T) {
k1 := key.New(WildcardType{})
k2 := key.New(WildcardType{})
if !k1.Equal(k2) {
t.Error("They should be equal")
}
if !Wildcard.Equal(k1) {
t.Error("They should be equal")
}
}
func TestWildcardTypeAsValue(t *testing.T) {
var k value.Value = WildcardType{}
w := WildcardType{}
if k.ToBuiltin() != w {
t.Error("WildcardType.ToBuiltin is not correct")
}
}
func TestWildcardMarshalJSON(t *testing.T) {
b, err := json.Marshal(Wildcard)
if err != nil {
t.Fatal(err)
}
expected := `{"_wildcard":{}}`
if string(b) != expected {
t.Errorf("Invalid Wildcard json representation.\nExpected: %s\nReceived: %s",
expected, string(b))
}
}

View File

@ -1,265 +0,0 @@
// Copyright (c) 2016 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package pathmap
import (
"bytes"
"fmt"
"sort"
)
// PathMap associates Paths to a values. It allows wildcards. The
// primary use of PathMap is to be able to register handlers to paths
// that can be efficiently looked up every time a path is updated.
//
// For example:
//
// m.Set({"interfaces", "*", "adminStatus"}, AdminStatusHandler)
// m.Set({"interface", "Management1", "adminStatus"}, Management1AdminStatusHandler)
//
// m.Visit({"interfaces", "Ethernet3/32/1", "adminStatus"}, HandlerExecutor)
// >> AdminStatusHandler gets passed to HandlerExecutor
// m.Visit({"interfaces", "Management1", "adminStatus"}, HandlerExecutor)
// >> AdminStatusHandler and Management1AdminStatusHandler gets passed to HandlerExecutor
//
// Note, Visit performance is typically linearly with the length of
// the path. But, it can be as bad as O(2^len(Path)) when TreeMap
// nodes have children and a wildcard associated with it. For example,
// if these paths were registered:
//
// m.Set([]string{"foo", "bar", "baz"}, 1)
// m.Set([]string{"*", "bar", "baz"}, 2)
// m.Set([]string{"*", "*", "baz"}, 3)
// m.Set([]string{"*", "*", "*"}, 4)
// m.Set([]string{"foo", "*", "*"}, 5)
// m.Set([]string{"foo", "bar", "*"}, 6)
// m.Set([]string{"foo", "*", "baz"}, 7)
// m.Set([]string{"*", "bar", "*"}, 8)
//
// m.Visit([]{"foo","bar","baz"}, Foo) // 2^3 nodes traversed
//
// This shouldn't be a concern with our paths because it is likely
// that a TreeMap node will either have a wildcard or children, not
// both. A TreeMap node that corresponds to a collection will often be a
// wildcard, otherwise it will have specific children.
type PathMap interface {
// Visit calls f for every registration in the PathMap that
// matches path. For example,
//
// m.Set({"foo", "bar"}, 1)
// m.Set({"*", "bar"}, 2)
//
// m.Visit({"foo", "bar"}, Printer)
// >> Calls Printer(1) and Printer(2)
Visit(path []string, f VisitorFunc) error
// VisitPrefix calls f for every registration in the PathMap that
// is a prefix of path. For example,
//
// m.Set({}, 0)
// m.Set({"foo"}, 1)
// m.Set({"foo", "bar"}, 2)
// m.Set({"foo", "quux"}, 3)
// m.Set({"*", "bar"}, 4)
//
// m.VisitPrefix({"foo", "bar", "baz"}, Printer)
// >> Calls Printer on values 0, 1, 2, and 4
VisitPrefix(path []string, f VisitorFunc) error
// Get returns the mapping for path. This returns the exact
// mapping for path. For example, if you register two paths
//
// m.Set({"foo", "bar"}, 1)
// m.Set({"*", "bar"}, 2)
//
// m.Get({"foo", "bar"}) => 1
// m.Get({"*", "bar"}) => 2
Get(path []string) interface{}
// Set a mapping of path to value. Path may contain wildcards. Set
// replaces what was there before.
Set(path []string, v interface{})
// Delete removes the mapping for path
Delete(path []string) bool
}
// Wildcard is a special string representing any possible path
const Wildcard string = "*"
// node is a single vertex in the tree backing a PathMap.
type node struct {
	val      interface{}      // value registered at this exact path, nil if none
	wildcard *node            // subtree for the Wildcard ("*") element, nil if none
	children map[string]*node // subtrees keyed by literal path element
}
// New creates a new PathMap backed by an empty tree whose
// root corresponds to the empty path.
func New() PathMap {
	return &node{}
}
// VisitorFunc is the func type passed to Visit.
type VisitorFunc func(v interface{}) error

// Visit calls f for every registration in the PathMap that
// matches path, where wildcard registrations match any element
// at their position. Traversal stops early and returns the
// first non-nil error from f.
func (n *node) Visit(path []string, f VisitorFunc) error {
	for i, element := range path {
		// A wildcard child matches this element regardless of its
		// value, so also explore that subtree with the rest of path.
		if n.wildcard != nil {
			if err := n.wildcard.Visit(path[i+1:], f); err != nil {
				return err
			}
		}
		next, ok := n.children[element]
		if !ok {
			// No literal child for this element: nothing more matches.
			return nil
		}
		n = next
	}
	// Only call f if a value was registered at this exact depth.
	if n.val == nil {
		return nil
	}
	return f(n.val)
}
// VisitPrefix calls f for every registered path that is a
// prefix of the path, including wildcard registrations.
// Traversal stops early on the first non-nil error from f.
func (n *node) VisitPrefix(path []string, f VisitorFunc) error {
	for i, element := range path {
		// Call f on each node we visit: any value on an ancestor
		// is by definition registered at a prefix of path.
		if n.val != nil {
			if err := f(n.val); err != nil {
				return err
			}
		}
		// A wildcard child matches this element, so also collect
		// prefixes from that subtree.
		if n.wildcard != nil {
			if err := n.wildcard.VisitPrefix(path[i+1:], f); err != nil {
				return err
			}
		}
		next, ok := n.children[element]
		if !ok {
			return nil
		}
		n = next
	}
	if n.val == nil {
		return nil
	}
	// Call f on the final node
	return f(n.val)
}
// Get returns the exact mapping registered for path. Wildcard
// elements in path select the wildcard subtree literally; no
// wildcard expansion against other registrations is performed.
func (n *node) Get(path []string) interface{} {
	cur := n
	for _, element := range path {
		var next *node
		if element == Wildcard {
			next = cur.wildcard
		} else {
			next = cur.children[element]
		}
		if next == nil {
			return nil
		}
		cur = next
	}
	return cur.val
}
// Set a mapping of path to value. Path may contain wildcards. Set
// replaces what was there before.
func (n *node) Set(path []string, v interface{}) {
	cur := n
	for _, element := range path {
		if element == Wildcard {
			// Lazily create the wildcard subtree on first use.
			if cur.wildcard == nil {
				cur.wildcard = &node{}
			}
			cur = cur.wildcard
			continue
		}
		// Lazily create the children map and the child node.
		if cur.children == nil {
			cur.children = make(map[string]*node)
		}
		child := cur.children[element]
		if child == nil {
			child = &node{}
			cur.children[element] = child
		}
		cur = child
	}
	cur.val = v
}
// Delete removes the mapping for path, returning whether a
// matching branch existed. It also prunes any tree nodes left
// empty by the removal.
func (n *node) Delete(path []string) bool {
	// Record the node visited at each depth so empty nodes can
	// later be unlinked from their parents.
	nodes := make([]*node, len(path)+1)
	for i, element := range path {
		nodes[i] = n
		if element == Wildcard {
			if n.wildcard == nil {
				return false
			}
			n = n.wildcard
			continue
		}
		next, ok := n.children[element]
		if !ok {
			return false
		}
		n = next
	}
	n.val = nil
	nodes[len(path)] = n
	// See if we can delete any node objects: walk back up from
	// the leaf, unlinking nodes that hold no value, no wildcard
	// subtree, and no children. Stop at the first non-empty node.
	for i := len(path); i > 0; i-- {
		n = nodes[i]
		if n.val != nil || n.wildcard != nil || len(n.children) > 0 {
			break
		}
		parent := nodes[i-1]
		element := path[i-1]
		if element == Wildcard {
			parent.wildcard = nil
		} else {
			delete(parent.children, element)
		}
	}
	return true
}
// String returns an indented, human-readable dump of the tree
// rooted at n, mainly useful for debugging.
func (n *node) String() string {
	var b bytes.Buffer
	n.write(&b, "")
	return b.String()
}
// write recursively renders the subtree rooted at n into b,
// one line per value or child, indenting one level per depth.
func (n *node) write(b *bytes.Buffer, indent string) {
	if n.val != nil {
		b.WriteString(indent)
		fmt.Fprintf(b, "Val: %v", n.val)
		b.WriteString("\n")
	}
	if n.wildcard != nil {
		b.WriteString(indent)
		fmt.Fprintf(b, "Child %q:\n", Wildcard)
		n.wildcard.write(b, indent+"  ")
	}
	// Sort child names so the output is deterministic despite
	// Go's randomized map iteration order.
	children := make([]string, 0, len(n.children))
	for name := range n.children {
		children = append(children, name)
	}
	sort.Strings(children)
	for _, name := range children {
		child := n.children[name]
		b.WriteString(indent)
		fmt.Fprintf(b, "Child %q:\n", name)
		child.write(b, indent+"  ")
	}
}

View File

@ -1,335 +0,0 @@
// Copyright (c) 2016 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package pathmap
import (
"errors"
"fmt"
"testing"
"github.com/aristanetworks/goarista/test"
)
func accumulator(counter map[int]int) VisitorFunc {
return func(val interface{}) error {
counter[val.(int)]++
return nil
}
}
func TestVisit(t *testing.T) {
m := New()
m.Set([]string{"foo", "bar", "baz"}, 1)
m.Set([]string{"*", "bar", "baz"}, 2)
m.Set([]string{"*", "*", "baz"}, 3)
m.Set([]string{"*", "*", "*"}, 4)
m.Set([]string{"foo", "*", "*"}, 5)
m.Set([]string{"foo", "bar", "*"}, 6)
m.Set([]string{"foo", "*", "baz"}, 7)
m.Set([]string{"*", "bar", "*"}, 8)
m.Set([]string{}, 10)
m.Set([]string{"*"}, 20)
m.Set([]string{"foo"}, 21)
m.Set([]string{"zap", "zip"}, 30)
m.Set([]string{"zap", "zip"}, 31)
m.Set([]string{"zip", "*"}, 40)
m.Set([]string{"zip", "*"}, 41)
testCases := []struct {
path []string
expected map[int]int
}{{
path: []string{"foo", "bar", "baz"},
expected: map[int]int{1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1},
}, {
path: []string{"qux", "bar", "baz"},
expected: map[int]int{2: 1, 3: 1, 4: 1, 8: 1},
}, {
path: []string{"foo", "qux", "baz"},
expected: map[int]int{3: 1, 4: 1, 5: 1, 7: 1},
}, {
path: []string{"foo", "bar", "qux"},
expected: map[int]int{4: 1, 5: 1, 6: 1, 8: 1},
}, {
path: []string{},
expected: map[int]int{10: 1},
}, {
path: []string{"foo"},
expected: map[int]int{20: 1, 21: 1},
}, {
path: []string{"foo", "bar"},
expected: map[int]int{},
}, {
path: []string{"zap", "zip"},
expected: map[int]int{31: 1},
}, {
path: []string{"zip", "zap"},
expected: map[int]int{41: 1},
}}
for _, tc := range testCases {
result := make(map[int]int, len(tc.expected))
m.Visit(tc.path, accumulator(result))
if diff := test.Diff(tc.expected, result); diff != "" {
t.Errorf("Test case %v: %s", tc.path, diff)
}
}
}
// TestVisitError verifies that an error returned by the visitor
// func is propagated back out of both Visit and VisitPrefix.
func TestVisitError(t *testing.T) {
	m := New()
	m.Set([]string{"foo", "bar"}, 1)
	m.Set([]string{"*", "bar"}, 2)
	errTest := errors.New("Test")
	err := m.Visit([]string{"foo", "bar"}, func(v interface{}) error { return errTest })
	if err != errTest {
		t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err)
	}
	err = m.VisitPrefix([]string{"foo", "bar", "baz"}, func(v interface{}) error { return errTest })
	if err != errTest {
		t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err)
	}
}
// TestGet checks exact-match lookup: Get returns the value stored at
// precisely the given path and nil when nothing was stored there (note
// the {"zap", "*"} case: Get does no wildcard expansion).
func TestGet(t *testing.T) {
	m := New()
	m.Set([]string{}, 0)
	m.Set([]string{"foo", "bar"}, 1)
	m.Set([]string{"foo", "*"}, 2)
	m.Set([]string{"*", "bar"}, 3)
	m.Set([]string{"zap", "zip"}, 4)

	for _, tt := range []struct {
		path []string
		want interface{}
	}{
		{path: []string{}, want: 0},
		{path: []string{"foo", "bar"}, want: 1},
		{path: []string{"foo", "*"}, want: 2},
		{path: []string{"*", "bar"}, want: 3},
		{path: []string{"bar", "foo"}, want: nil},
		{path: []string{"zap", "*"}, want: nil},
	} {
		if got := m.Get(tt.path); got != tt.want {
			t.Errorf("Test case %v: Expected %v, Got %v",
				tt.path, tt.want, got)
		}
	}
}
// countNodes returns the number of nodes reachable from n, counting n
// itself, its wildcard child, and all named children recursively.
func countNodes(n *node) int {
	if n == nil {
		return 0
	}
	total := 1 + countNodes(n.wildcard)
	for _, c := range n.children {
		total += countNodes(c)
	}
	return total
}
// TestDelete exercises Delete on exact paths, the wildcard entry, and the
// root.  The cases run sequentially against the same map, so each case's
// "before" expectation reflects the deletions made by earlier cases.
func TestDelete(t *testing.T) {
	m := New()
	m.Set([]string{}, 0)
	m.Set([]string{"*"}, 1)
	m.Set([]string{"foo", "bar"}, 2)
	m.Set([]string{"foo", "*"}, 3)
	// root + "*" + "foo" + "bar" + foo's wildcard = 5 nodes
	n := countNodes(m.(*node))
	if n != 5 {
		t.Errorf("Initial count wrong. Expected: 5, Got: %d", n)
	}
	testCases := []struct {
		del      []string    // Path to delete
		expected bool        // expected return value of Delete
		visit    []string    // Path to Visit
		before   map[int]int // Expected to find items before deletion
		after    map[int]int // Expected to find items after deletion
		count    int         // Count of nodes
	}{{
		del:      []string{"zap"}, // A no-op Delete
		expected: false,
		visit:    []string{"foo", "bar"},
		before:   map[int]int{2: 1, 3: 1},
		after:    map[int]int{2: 1, 3: 1},
		count:    5,
	}, {
		del:      []string{"foo", "bar"},
		expected: true,
		visit:    []string{"foo", "bar"},
		before:   map[int]int{2: 1, 3: 1},
		after:    map[int]int{3: 1},
		count:    4,
	}, {
		del:      []string{"*"},
		expected: true,
		visit:    []string{"foo"},
		before:   map[int]int{1: 1},
		after:    map[int]int{},
		count:    3,
	}, {
		del:      []string{"*"},
		expected: false,
		visit:    []string{"foo"},
		before:   map[int]int{},
		after:    map[int]int{},
		count:    3,
	}, {
		del:      []string{"foo", "*"},
		expected: true,
		visit:    []string{"foo", "bar"},
		before:   map[int]int{3: 1},
		after:    map[int]int{},
		count:    1, // Should have deleted "foo" and "bar" nodes
	}, {
		del:      []string{},
		expected: true,
		visit:    []string{},
		before:   map[int]int{0: 1},
		after:    map[int]int{},
		count:    1, // Root node can't be deleted
	}}
	for i, tc := range testCases {
		beforeResult := make(map[int]int, len(tc.before))
		m.Visit(tc.visit, accumulator(beforeResult))
		if diff := test.Diff(tc.before, beforeResult); diff != "" {
			t.Errorf("Test case %d (%v): %s", i, tc.del, diff)
		}
		if got := m.Delete(tc.del); got != tc.expected {
			t.Errorf("Test case %d (%v): Unexpected return. Expected %t, Got: %t",
				i, tc.del, tc.expected, got)
		}
		afterResult := make(map[int]int, len(tc.after))
		m.Visit(tc.visit, accumulator(afterResult))
		if diff := test.Diff(tc.after, afterResult); diff != "" {
			t.Errorf("Test case %d (%v): %s", i, tc.del, diff)
		}
		// Fix: the count column was populated with expected node counts but
		// never checked, so node pruning after Delete was not verified at all.
		if got := countNodes(m.(*node)); got != tc.count {
			t.Errorf("Test case %d (%v): Wrong node count. Expected: %d, Got: %d",
				i, tc.del, tc.count, got)
		}
	}
}
// TestVisitPrefix verifies that VisitPrefix invokes the function for every
// registered pattern (exact or wildcard) matching any prefix of the given
// path, including the empty prefix (value 0) and the full path itself.
func TestVisitPrefix(t *testing.T) {
	m := New()
	m.Set([]string{}, 0)
	m.Set([]string{"foo"}, 1)
	m.Set([]string{"foo", "bar"}, 2)
	m.Set([]string{"foo", "bar", "baz"}, 3)
	m.Set([]string{"foo", "bar", "baz", "quux"}, 4)
	m.Set([]string{"quux", "bar"}, 5)
	m.Set([]string{"foo", "quux"}, 6)
	m.Set([]string{"*"}, 7)
	m.Set([]string{"foo", "*"}, 8)
	m.Set([]string{"*", "bar"}, 9)
	m.Set([]string{"*", "quux"}, 10)
	m.Set([]string{"quux", "quux", "quux", "quux"}, 11)

	for _, tt := range []struct {
		path []string
		want map[int]int
	}{
		{path: []string{"foo", "bar", "baz"},
			want: map[int]int{0: 1, 1: 1, 2: 1, 3: 1, 7: 1, 8: 1, 9: 1}},
		{path: []string{"zip", "zap"},
			want: map[int]int{0: 1, 7: 1}},
		{path: []string{"foo", "zap"},
			want: map[int]int{0: 1, 1: 1, 8: 1, 7: 1}},
		{path: []string{"quux", "quux", "quux"},
			want: map[int]int{0: 1, 7: 1, 10: 1}},
	} {
		got := make(map[int]int, len(tt.want))
		m.VisitPrefix(tt.path, accumulator(got))
		if diff := test.Diff(tt.want, got); diff != "" {
			t.Errorf("Test case %v: %s", tt.path, diff)
		}
	}
}
// TestString checks the human-readable rendering of the map produced via
// fmt.Sprint: the root value followed by each child subtree, with "*"
// listed before the named children.
// NOTE(review): the leading whitespace inside the expected raw string is
// significant to this comparison, and the upstream indentation may have
// been lost in transit — verify against the upstream goarista source.
func TestString(t *testing.T) {
	m := New()
	m.Set([]string{}, 0)
	m.Set([]string{"foo", "bar"}, 1)
	m.Set([]string{"foo", "quux"}, 2)
	m.Set([]string{"foo", "*"}, 3)
	expected := `Val: 0
Child "foo":
Child "*":
Val: 3
Child "bar":
Val: 1
Child "quux":
Val: 2
`
	got := fmt.Sprint(m)
	if expected != got {
		t.Errorf("Unexpected string. Expected:\n\n%s\n\nGot:\n\n%s", expected, got)
	}
}
// genWords returns count words of wordLength characters each, carved out
// of a fixed alphanumeric alphabet as overlapping windows: word i is
// chars[i : i+wordLength].  It panics when the alphabet is too short to
// produce count such windows.
func genWords(count, wordLength int) []string {
	const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if count+wordLength > len(chars) {
		panic("need more chars")
	}
	words := make([]string, count)
	for i := range words {
		words[i] = chars[i : i+wordLength]
	}
	return words
}
// benchmarkPathMap builds a tree pathLength levels deep in which every node
// on the benchmark path has pathDepth children, then measures the cost of
// Visit on one full-depth path.
func benchmarkPathMap(pathLength, pathDepth int, b *testing.B) {
	m := New()
	// Push pathDepth paths, each of length pathLength
	path := genWords(pathLength, 10)
	words := genWords(pathDepth, 10)
	n := m.(*node)
	for _, element := range path {
		// Give the current node a fresh set of pathDepth children,
		// then descend along the benchmark path.
		n.children = map[string]*node{}
		for _, word := range words {
			n.children[word] = &node{}
		}
		// NOTE(review): assumes element is among words, i.e.
		// pathDepth >= pathLength — true for the benchmarks below,
		// otherwise n would become nil here.
		n = n.children[element]
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		m.Visit(path, func(v interface{}) error { return nil })
	}
}
// Benchmarks of Visit at several path lengths.
// NOTE(review): BenchmarkPathMap10x50 and BenchmarkPathMap20x50 pass a
// pathDepth of 25, not 50 as their names suggest — confirm intent upstream
// before renaming (benchmark names are compared across runs).
func BenchmarkPathMap1x25(b *testing.B)  { benchmarkPathMap(1, 25, b) }
func BenchmarkPathMap10x50(b *testing.B) { benchmarkPathMap(10, 25, b) }
func BenchmarkPathMap20x50(b *testing.B) { benchmarkPathMap(20, 25, b) }

View File

@ -0,0 +1,387 @@
// Copyright (c) 2017 Arista Networks, Inc. All rights reserved.
// Arista Networks, Inc. Confidential and Proprietary.
// Subject to Arista Networks, Inc.'s EULA.
// FOR INTERNAL USE ONLY. NOT FOR DISTRIBUTION.
package sizeof
import (
"errors"
"reflect"
"unsafe"
"github.com/aristanetworks/goarista/areflect"
)
// block describes a half-open memory area [start, end) that has already
// been visited, used to avoid double-counting aliased memory.
type block struct {
	start uintptr // address of the first byte of the area
	end   uintptr // address one past the last byte of the area
}

// size returns the length in bytes of the memory area [b.start, b.end).
func (b block) size() uintptr {
	return b.end - b.start
}
// DeepSizeof returns total memory occupied by each type for val.
// The value passed in argument must be a pointer.
func DeepSizeof(val interface{}) (map[string]uintptr, error) {
	v := reflect.ValueOf(val)
	// We want to force val to be a pointer to the original value, because if we
	// get a copy, we can get some pointers that will point back to our original
	// value.
	if v.Kind() != reflect.Ptr {
		return nil, errors.New("cannot get the deep size of a non-pointer value")
	}
	sizes := make(map[string]uintptr)
	typesPerAddr := make(map[uintptr]map[string]struct{})
	sizeof(v.Elem(), sizes, typesPerAddr, false, block{start: uintptr(v.Pointer())}, nil)
	return sizes, nil
}
// isOverlapping reports whether curBlock and tmpBlock touch or overlap,
// i.e. neither block lies strictly past the other.
func isOverlapping(curBlock, tmpBlock block) bool {
	return !(tmpBlock.end < curBlock.start || curBlock.end < tmpBlock.start)
}
// getOverlappingBlocks returns the entries of seen that overlap curBlock,
// together with the index at which curBlock would be inserted (the first
// entry starting strictly after curBlock.end, or len(seen)).  seen must be
// sorted by ascending start address.
func getOverlappingBlocks(curBlock block, seen []block) ([]block, int) {
	var overlapping []block
	for i, b := range seen {
		if b.start > curBlock.end {
			return overlapping, i
		}
		if isOverlapping(curBlock, b) {
			overlapping = append(overlapping, b)
		}
	}
	return overlapping, len(seen)
}
func insertBlock(curBlock block, idxToInsert int, seen []block) []block {
seen = append(seen, block{})
copy(seen[idxToInsert+1:], seen[idxToInsert:])
seen[idxToInsert] = curBlock
return seen
}
// getUnseenSizeOfCurrentBlock returns the number of bytes of curBlock that
// are not covered by overlappingBlocks: the part before the first
// overlapping block, the gaps between consecutive overlapping blocks, and
// the part after the last one.  overlappingBlocks must be the
// address-ordered subset of seen blocks intersecting curBlock.
func getUnseenSizeOfCurrentBlock(curBlock block, overlappingBlocks []block) uintptr {
	var size uintptr
	for idx, a := range overlappingBlocks {
		if idx == 0 && curBlock.start < a.start {
			// Bytes of curBlock before the first overlapping block.
			size += a.start - curBlock.start
		}
		if idx == len(overlappingBlocks)-1 {
			// Bytes of curBlock past the end of the last overlapping block.
			if curBlock.end > a.end {
				size += curBlock.end - a.end
			}
		} else {
			// Gap between this overlapping block and the next one.
			size += overlappingBlocks[idx+1].start - a.end
		}
	}
	return size
}
// updateSeenBlocks merges curBlock into the sorted, non-overlapping list of
// seen blocks and returns the updated list along with the number of bytes
// of curBlock that had not been seen before.  All blocks overlapping
// curBlock are coalesced into a single entry so the invariant holds.
func updateSeenBlocks(curBlock block, seen []block) ([]block, uintptr) {
	if len(seen) == 0 {
		return []block{curBlock}, curBlock.size()
	}
	overlappingBlocks, idx := getOverlappingBlocks(curBlock, seen)
	if len(overlappingBlocks) == 0 {
		// No overlap, so we will insert our new block in our list.
		return insertBlock(curBlock, idx, seen), curBlock.size()
	}
	unseenSize := getUnseenSizeOfCurrentBlock(curBlock, overlappingBlocks)
	// Collapse curBlock plus every overlapping entry into the first
	// overlapping entry, widening its bounds as needed.
	idxFirstOverlappingBlock := idx - len(overlappingBlocks)
	firstOverlappingBlock := &seen[idxFirstOverlappingBlock]
	lastOverlappingBlock := seen[idx-1]
	if firstOverlappingBlock.start > curBlock.start {
		firstOverlappingBlock.start = curBlock.start
	}
	if lastOverlappingBlock.end < curBlock.end {
		firstOverlappingBlock.end = curBlock.end
	} else {
		firstOverlappingBlock.end = lastOverlappingBlock.end
	}
	// Shift the non-overlapping tail left over the now-merged entries and
	// truncate the slice accordingly.
	tailLen := len(seen[idx:])
	copy(seen[idxFirstOverlappingBlock+1:], seen[idx:])
	return seen[:idxFirstOverlappingBlock+1+tailLen], unseenSize
}
// isKnownBlock reports whether curBlock is fully contained in one of the
// blocks of seen.  Because seen is sorted by ascending start address, the
// scan can stop at the first entry starting past curBlock.start.
func isKnownBlock(curBlock block, seen []block) bool {
	for _, b := range seen {
		if b.start > curBlock.start {
			// No later block can start at or before curBlock.start,
			// so none can contain it.
			return false
		}
		if b.end >= curBlock.end {
			// b.start <= curBlock.start here, so curBlock is fully
			// contained in a block we already know.
			return true
		}
	}
	return false
}
// sizeof recursively accumulates into m (keyed by type name) the memory
// occupied by v.  ptrsTypes records, per address, the set of type names
// already counted there (two pointers may alias the same area under
// different types — see tests struct_5 and struct_6).  counted is true
// when the caller has already accounted for v's own bytes (a struct field,
// array element, etc.).  curBlock is the memory area occupied by v itself
// (start 0 when unknown); seen is the sorted list of areas already
// visited and is returned updated.
func sizeof(v reflect.Value, m map[string]uintptr, ptrsTypes map[uintptr]map[string]struct{},
	counted bool, curBlock block, seen []block) []block {
	if !v.IsValid() {
		return seen
	}
	vn := v.Type().String()
	vs := v.Type().Size()
	curBlock.end = vs + curBlock.start
	if counted {
		// already accounted for the size (field in struct, in array, etc)
		vs = 0
	}
	if curBlock.start != 0 {
		// A pointer can point to the same memory area than a previous pointer,
		// but its type should be different (see tests struct_5 and struct_6).
		if types, ok := ptrsTypes[curBlock.start]; ok {
			if _, ok := types[vn]; ok {
				// This address was already counted under this type name.
				return seen
			}
			types[vn] = struct{}{}
		} else {
			ptrsTypes[curBlock.start] = make(map[string]struct{})
		}
		if isKnownBlock(curBlock, seen) {
			// we don't want to count this size if we have a known block
			vs = 0
		} else {
			var tmpVs uintptr
			seen, tmpVs = updateSeenBlocks(curBlock, seen)
			if !counted {
				// Only charge the bytes not already covered by seen blocks.
				vs = tmpVs
			}
		}
	}
	switch v.Kind() {
	case reflect.Interface:
		seen = sizeof(v.Elem(), m, ptrsTypes, false, block{}, seen)
	case reflect.Ptr:
		if v.IsNil() {
			break
		}
		seen = sizeof(v.Elem(), m, ptrsTypes, false, block{start: uintptr(v.Pointer())}, seen)
	case reflect.Array:
		// get size of all elements in the array in case there are pointers
		l := v.Len()
		for i := 0; i < l; i++ {
			seen = sizeof(v.Index(i), m, ptrsTypes, true, block{}, seen)
		}
	case reflect.Slice:
		// get size of all elements in the slice in case there are pointers
		// TODO: count elements that are not accessible after reslicing
		l := v.Len()
		vLen := v.Type().Elem().Size()
		for i := 0; i < l; i++ {
			e := v.Index(i)
			eStart := uintptr(e.UnsafeAddr())
			eBlock := block{
				start: eStart,
				end:   eStart + vLen,
			}
			if !isKnownBlock(eBlock, seen) {
				vs += vLen
				seen = sizeof(e, m, ptrsTypes, true, eBlock, seen)
			}
		}
		// Also account for the unused capacity of the backing array
		// (the bytes between len and cap).
		capStart := uintptr(v.Pointer()) + (v.Type().Elem().Size() * uintptr(v.Len()))
		capEnd := uintptr(v.Pointer()) + (v.Type().Elem().Size() * uintptr(v.Cap()))
		capBlock := block{start: capStart, end: capEnd}
		if isKnownBlock(capBlock, seen) {
			break
		}
		var tmpSize uintptr
		seen, tmpSize = updateSeenBlocks(capBlock, seen)
		vs += tmpSize
	case reflect.Map:
		if v.IsNil() {
			break
		}
		var tmpSize uintptr
		if tmpSize, seen = sizeofmap(v, seen); tmpSize == 0 {
			// we saw this map
			break
		}
		vs += tmpSize
		// Recurse into all keys and values to count pointed-to data.
		for _, k := range v.MapKeys() {
			kv := v.MapIndex(k)
			seen = sizeof(k, m, ptrsTypes, true, block{}, seen)
			seen = sizeof(kv, m, ptrsTypes, true, block{}, seen)
		}
	case reflect.Struct:
		for i, n := 0, v.NumField(); i < n; i++ {
			// ForceExport makes unexported fields readable.
			vf := areflect.ForceExport(v.Field(i))
			seen = sizeof(vf, m, ptrsTypes, true, block{}, seen)
		}
	case reflect.String:
		// Count the backing bytes of the string, deduplicated via seen.
		str := v.String()
		strHdr := (*reflect.StringHeader)(unsafe.Pointer(&str))
		tmpSize := uintptr(strHdr.Len)
		strBlock := block{start: strHdr.Data, end: strHdr.Data + tmpSize}
		if isKnownBlock(strBlock, seen) {
			break
		}
		seen, tmpSize = updateSeenBlocks(strBlock, seen)
		vs += tmpSize
	case reflect.Chan:
		var tmpSize uintptr
		tmpSize, seen = sizeofChan(v, m, ptrsTypes, seen)
		vs += tmpSize
	}
	if vs != 0 {
		m[vn] += vs
	}
	return seen
}
// typesByString returns the type descriptors whose string representation
// matches s.  It is linked against the unexported reflect.typesByString
// via go:linkname (which is why this package imports unsafe).
//go:linkname typesByString reflect.typesByString
func typesByString(s string) []unsafe.Pointer
// sizeofmap returns the footprint of the map header plus its bucket arrays
// (hmap, buckets, overflow buckets and, mid-growth, oldbuckets) and the
// updated seen list.  It returns 0 when this map was already seen.  Keys
// and values themselves are counted by the caller.
func sizeofmap(v reflect.Value, seen []block) (uintptr, []block) {
	// get field typ *rtype of our Value v and store it in an interface
	var ti interface{} = v.Type()
	tp := (*unsafe.Pointer)(unsafe.Pointer(&ti))
	// we know that this pointer rtype points at the beginning of struct
	// mapType defined in /go/src/reflect/type.go, so we can change the underlying
	// type of the interface to be a pointer to runtime.maptype because it has the
	// exact same definition as reflect.mapType.
	*tp = typesByString("*runtime.maptype")[0]
	maptypev := reflect.ValueOf(ti)
	maptypev = reflect.Indirect(maptypev)
	// now we can access field bucketsize in struct maptype
	bucketsize := maptypev.FieldByName("bucketsize").Uint()
	// get hmap
	var m interface{} = v.Interface()
	hmap := (*unsafe.Pointer)(unsafe.Pointer(&m))
	*hmap = typesByString("*runtime.hmap")[0]
	hmapv := reflect.ValueOf(m)
	// account for the size of the hmap, buckets and oldbuckets
	hmapv = reflect.Indirect(hmapv)
	mapBlock := block{
		start: hmapv.UnsafeAddr(),
		end:   hmapv.UnsafeAddr() + hmapv.Type().Size(),
	}
	// is it a map we already saw?
	if isKnownBlock(mapBlock, seen) {
		return 0, seen
	}
	seen, _ = updateSeenBlocks(mapBlock, seen)
	B := hmapv.FieldByName("B").Uint()
	oldbuckets := hmapv.FieldByName("oldbuckets").Pointer()
	buckets := hmapv.FieldByName("buckets").Pointer()
	noverflow := int16(hmapv.FieldByName("noverflow").Uint())
	// NOTE(review): for B > 0 this counts 2<<B = 2^(B+1) buckets rather than
	// the runtime's 2^B — presumably headroom for overflow buckets; confirm
	// against the runtime's map implementation.
	nb := 2
	if B == 0 {
		nb = 1
	}
	size := uint64((nb << B)) * bucketsize
	if noverflow != 0 {
		size += uint64(noverflow) * bucketsize
	}
	seen, _ = updateSeenBlocks(block{start: buckets, end: buckets + uintptr(size)},
		seen)
	// As defined in /go/src/runtime/hashmap.go in struct hmap, oldbuckets is the
	// previous bucket array that is half the size of the current one. We need to
	// also take that in consideration since there is still a pointer to this previous bucket.
	if oldbuckets != 0 {
		// NOTE(review): assumes B > 0 whenever oldbuckets is non-nil
		// (B-1 would wrap around otherwise) — confirm against the runtime.
		tmp := (2 << (B - 1)) * bucketsize
		size += tmp
		seen, _ = updateSeenBlocks(block{
			start: oldbuckets,
			end:   oldbuckets + uintptr(tmp),
		}, seen)
	}
	return hmapv.Type().Size() + uintptr(size), seen
}
func getSliceToChanBuffer(buff unsafe.Pointer, buffLen uint, dataSize uint) []byte {
var slice []byte
sliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
sliceHdr.Len = int(buffLen * dataSize)
sliceHdr.Cap = sliceHdr.Len
sliceHdr.Data = uintptr(buff)
return slice
}
// sizeofChan returns the size of the hchan header plus the channel's ring
// buffer, and recurses into every element currently queued so pointed-to
// data is counted as well.  It returns 0 when this channel was already
// seen.  m, ptrsTypes and seen have the same meaning as in sizeof.
func sizeofChan(v reflect.Value, m map[string]uintptr, ptrsTypes map[uintptr]map[string]struct{},
	seen []block) (uintptr, []block) {
	// Reinterpret the channel as a *runtime.hchan to read its internals.
	var c interface{} = v.Interface()
	hchan := (*unsafe.Pointer)(unsafe.Pointer(&c))
	*hchan = typesByString("*runtime.hchan")[0]
	hchanv := reflect.ValueOf(c)
	hchanv = reflect.Indirect(hchanv)
	chanBlock := block{
		start: hchanv.UnsafeAddr(),
		end:   hchanv.UnsafeAddr() + hchanv.Type().Size(),
	}
	// is it a chan we already saw?
	if isKnownBlock(chanBlock, seen) {
		return 0, seen
	}
	seen, _ = updateSeenBlocks(chanBlock, seen)
	elemType := unsafe.Pointer(hchanv.FieldByName("elemtype").Pointer())
	buff := unsafe.Pointer(hchanv.FieldByName("buf").Pointer())
	buffLen := hchanv.FieldByName("dataqsiz").Uint()
	elemSize := uint16(hchanv.FieldByName("elemsize").Uint())
	seen, _ = updateSeenBlocks(block{
		start: uintptr(buff),
		end:   uintptr(buff) + uintptr(buffLen*uint64(elemSize)),
	}, seen)
	buffSlice := getSliceToChanBuffer(buff, uint(buffLen), uint(elemSize))
	recvx := hchanv.FieldByName("recvx").Uint()
	qcount := hchanv.FieldByName("qcount").Uint()
	// Build an interface value whose type descriptor is the channel's
	// element type; its data pointer is patched per element below.
	var tmp interface{}
	eface := (*struct {
		typ unsafe.Pointer
		ptr unsafe.Pointer
	})(unsafe.Pointer(&tmp))
	eface.typ = elemType
	// Walk the ring buffer from the receive index over the queued elements.
	for i := uint64(0); buffLen > 0 && i < qcount; i++ {
		idx := (recvx + i) % buffLen
		// get the pointer to the data inside the chan buffer.
		elem := unsafe.Pointer(&buffSlice[uint64(elemSize)*idx])
		eface.ptr = elem
		ev := reflect.ValueOf(tmp)
		var blk block
		k := ev.Kind()
		if k == reflect.Ptr || k == reflect.Chan || k == reflect.Map || k == reflect.Func {
			// let's say our chan is a chan *whatEver, or chan chan whatEver or
			// chan map[whatEver]whatEver. In this case elemType will
			// be either of type *whatEver, chan whatEver or map[whatEver]whatEver
			// but when we set eface.ptr = elem above, we make it point to a pointer
			// to where the data is stored in the buffer of our channel.
			// So the interface tmp would look like:
			// chan *whatEver -> (type=*whatEver, ptr=**whatEver)
			// chan chan whatEver -> (type= chan whatEver, ptr=*chan whatEver)
			// chan map[whatEver]whatEver -> (type= map[whatEver]whatEver{},
			// ptr=*map[whatEver]whatEver)
			// So we need to take the ptr which is stored into the buffer and replace
			// the ptr to the data of our interface tmp.
			ptr := (*unsafe.Pointer)(elem)
			eface.ptr = *ptr
			ev = reflect.ValueOf(tmp)
			ev = reflect.Indirect(ev)
			blk.start = uintptr(*ptr)
		}
		// It seems that when the chan is of type chan *whatEver, the type in eface
		// will be whatEver and not *whatEver, but ev.Kind() will be a reflect.ptr.
		// So if k is a reflect.Ptr (i.e. a pointer) to a struct, then we want to take
		// the size of the struct into account because
		// vs := v.Type().Size() will return us the size of the struct and not the size
		// of the pointer that is in the channel's buffer.
		seen = sizeof(ev, m, ptrsTypes, true && k != reflect.Ptr, blk, seen)
	}
	return hchanv.Type().Size() + uintptr(uint64(elemSize)*buffLen), seen
}

View File

@ -0,0 +1,8 @@
// Copyright (c) 2017 Arista Networks, Inc. All rights reserved.
// Arista Networks, Inc. Confidential and Proprietary.
// Subject to Arista Networks, Inc.'s EULA.
// FOR INTERNAL USE ONLY. NOT FOR DISTRIBUTION.
// This file is intentionally empty.
// It's a workaround for https://github.com/golang/go/issues/15006

View File

@ -0,0 +1,573 @@
// Copyright (c) 2017 Arista Networks, Inc. All rights reserved.
// Arista Networks, Inc. Confidential and Proprietary.
// Subject to Arista Networks, Inc.'s EULA.
// FOR INTERNAL USE ONLY. NOT FOR DISTRIBUTION.
package sizeof
import (
"fmt"
"strconv"
"testing"
"unsafe"
"github.com/aristanetworks/goarista/test"
)
// yolo is a small test struct with mixed alignment (an int32, a byte
// array and a pointer) so its size exercises struct padding rules.
type yolo struct {
	i int32
	a [3]int8
	p unsafe.Pointer
}

// String implements fmt.Stringer so yolo values can be stored behind a
// fmt.Stringer interface in the tests below.
func (y yolo) String() string {
	return "Yolo"
}
// TestDeepSizeof runs DeepSizeof over values of every basic kind plus
// strings, slices, maps, channels and structs with aliasing pointers, and
// compares the summed per-type sizes against a hand-computed expectation
// built from the platform pointer size and the runtime's internal
// hmap/hchan header layouts.
// NOTE(review): the hmap/hchan layouts are hard-coded here from the Go
// runtime sources referenced in the comments; they can drift across Go
// releases.
func TestDeepSizeof(t *testing.T) {
	ptrSize := uintptr(unsafe.Sizeof(unsafe.Pointer(t)))
	// hmapStructSize represent the size of struct hmap defined in
	// file /go/src/runtime/hashmap.go
	hmapStructSize := uintptr(unsafe.Sizeof(int(0)) + 2*1 + 2 + 4 +
		2*ptrSize + ptrSize + ptrSize)
	var alignement uintptr = 4
	if ptrSize == 4 {
		alignement = 0
	}
	strHdrSize := unsafe.Sizeof("") // int + ptr to data
	sliceHdrSize := 3 * ptrSize     // ptr to data + 2 * int
	// struct hchan is defined in /go/src/runtime/chan.go
	chanHdrSize := 2*ptrSize + ptrSize + 2 + 2 /* padding */ + 4 + ptrSize + 2*ptrSize +
		2*(2*ptrSize) + ptrSize
	yoloSize := unsafe.Sizeof(yolo{})
	interfaceSize := 2 * ptrSize
	topHashSize := uintptr(8)
	tests := map[string]struct {
		getStruct    func() interface{} // builds the value handed to DeepSizeof
		expectedSize uintptr            // expected sum over all per-type sizes
	}{
		"bool": {
			getStruct: func() interface{} {
				var test bool
				return &test
			},
			expectedSize: 1,
		},
		"int8": {
			getStruct: func() interface{} {
				test := int8(4)
				return &test
			},
			expectedSize: 1,
		},
		"int16": {
			getStruct: func() interface{} {
				test := int16(4)
				return &test
			},
			expectedSize: 2,
		},
		"int32": {
			getStruct: func() interface{} {
				test := int32(4)
				return &test
			},
			expectedSize: 4,
		},
		"int64": {
			getStruct: func() interface{} {
				test := int64(4)
				return &test
			},
			expectedSize: 8,
		},
		"uint": {
			getStruct: func() interface{} {
				test := uint(4)
				return &test
			},
			expectedSize: ptrSize,
		},
		"uint8": {
			getStruct: func() interface{} {
				test := uint8(4)
				return &test
			},
			expectedSize: 1,
		},
		"uint16": {
			getStruct: func() interface{} {
				test := uint16(4)
				return &test
			},
			expectedSize: 2,
		},
		"uint32": {
			getStruct: func() interface{} {
				test := uint32(4)
				return &test
			},
			expectedSize: 4,
		},
		"uint64": {
			getStruct: func() interface{} {
				test := uint64(4)
				return &test
			},
			expectedSize: 8,
		},
		"uintptr": {
			getStruct: func() interface{} {
				test := uintptr(4)
				return &test
			},
			expectedSize: ptrSize,
		},
		"float32": {
			getStruct: func() interface{} {
				test := float32(4)
				return &test
			},
			expectedSize: 4,
		},
		"float64": {
			getStruct: func() interface{} {
				test := float64(4)
				return &test
			},
			expectedSize: 8,
		},
		"complex64": {
			getStruct: func() interface{} {
				test := complex64(4 + 1i)
				return &test
			},
			expectedSize: 8,
		},
		"complex128": {
			getStruct: func() interface{} {
				test := complex128(4 + 1i)
				return &test
			},
			expectedSize: 16,
		},
		"string": {
			getStruct: func() interface{} {
				test := "Hello Dolly!"
				return &test
			},
			expectedSize: strHdrSize + 12,
		},
		"unsafe_Pointer": {
			getStruct: func() interface{} {
				tmp := uint64(54)
				var test unsafe.Pointer
				test = unsafe.Pointer(&tmp)
				return &test
			},
			expectedSize: ptrSize,
		}, "rune": {
			getStruct: func() interface{} {
				test := rune('A')
				return &test
			},
			expectedSize: 4,
		}, "intPtr": {
			getStruct: func() interface{} {
				test := int(4)
				return &test
			},
			expectedSize: ptrSize,
		}, "FuncPtr": {
			getStruct: func() interface{} {
				test := TestDeepSizeof
				return &test
			},
			expectedSize: ptrSize,
		}, "struct_1": {
			getStruct: func() interface{} {
				v := struct {
					a uint32
					b *uint32
					c struct {
						e [8]byte
						d string
					}
					f string
				}{
					a: 10,
					c: struct {
						e [8]byte
						d string
					}{
						e: [8]byte{0, 1, 2, 3, 4, 5, 6, 7},
						d: "Hello Test!",
					},
					f: "Hello Test!",
				}
				a := uint32(47)
				v.b = &a
				return &v
			},
			expectedSize: 4 + alignement + ptrSize + 8 + strHdrSize*2 +
				11 /* "Hello Test!" */ + 4, /* uint32(47) */
		}, "struct_2": {
			getStruct: func() interface{} {
				// a and b are reslices of c: only c's backing array
				// may be counted once.
				v := struct {
					a []byte
					b []byte
					c []byte
				}{
					c: make([]byte, 32, 64),
				}
				v.a = v.c[20:32]
				v.b = v.c[10:20]
				return &v
			},
			expectedSize: 3*sliceHdrSize + 64, /*slice capacity*/
		}, "struct_3": {
			getStruct: func() interface{} {
				type test struct {
					a *byte
					c []byte
				}
				tmp := make([]byte, 64, 128)
				v := (*test)(unsafe.Pointer(&tmp[16]))
				v.c = tmp
				v.a = (*byte)(unsafe.Pointer(&tmp[5]))
				return v
			},
			// we expect to see 128 bytes as struct test is part of the bytes slice
			// and field c point to it.
			expectedSize: 128,
		}, "map_string_interface": {
			getStruct: func() interface{} {
				return &map[string]interface{}{}
			},
			expectedSize: ptrSize + topHashSize + hmapStructSize +
				(8*(strHdrSize+interfaceSize) + ptrSize),
		}, "map_interface_interface": {
			getStruct: func() interface{} {
				// All the map will use only one bucket because there is less than 8
				// entries in each map. Also for amd64 and i386 the bucket size is
				// computed like in function bucketOf in /go/src/reflect/type.go:
				return &map[interface{}]interface{}{
					// 2 + (8 + 4) 386
					// 2 + (16 + 4) amd64
					uint16(123): "yolo",
					// (4 + 8) + (4 + 28 + 140) + (4 for SWAG) 386
					// (4 + 16) + (8 + 48 + 272) + (4 for SWAG) amd64
					"meow": map[string]string{"SWAG": "yolo"},
					// (12) + (4 + 12) 386
					// (16) + (8 + 16) amd64
					yolo{i: 523}: &yolo{i: 126},
					// (12) + (12) 386
					// (16) + (16) amd64
					fmt.Stringer(yolo{i: 123}): yolo{i: 234},
				}
			},
			// Total
			// 386: (4 + 28 + 140) + 2 + (8 + 4) + (4 + 8) + (4 + 28 + 140) + 4 + (12) +
			// (4 + 12) + 12 + 12
			// amd64: (8 + 48 + 272) + 2 + (16 + 4) + (4 + 16) + (8 + 48 + 272) + (4) +
			// (16) + (8 + 16) + (16) + (16)
			expectedSize: (ptrSize + topHashSize + hmapStructSize +
				(8*(2*interfaceSize) + ptrSize)) +
				(unsafe.Sizeof(uint16(123)) + strHdrSize + 4 /* "yolo" */) +
				(strHdrSize + 4 /* "meow" */ +
					(ptrSize + hmapStructSize + topHashSize +
						(8*(2*strHdrSize) + ptrSize)) /*map[string]string*/ +
					4 /* "SWAG" */) +
				(yoloSize /* obj: */ + (ptrSize + yoloSize) /* &obj */) +
				yoloSize*2,
		}, "struct_4": {
			getStruct: func() interface{} {
				return &struct {
					a map[interface{}]interface{}
					c string
					d []string
				}{
					a: map[interface{}]interface{}{
						uint16(123):                "yolo",
						"meow":                     map[string]string{"SWAG": "yolo"},
						yolo{i: 127}:               &yolo{i: 124},
						fmt.Stringer(yolo{i: 123}): yolo{i: 234},
					}, // 4 (386) or 8 (amd64)
					c: "Hello",                                // 8 (386) or 16 (amd64)
					d: []string{"Bonjour", "Hello", "Hola"},   // 12 (386) or 24 (amd64)
				}
			},
			// Total
			// 386: sizeof(tmp map) + 8(test.c) + 12(test.d) +
			// 3 * 8 (strSlice) + 16(len("Bonjour") + len("Hello")...)
			// amd64: sizeof(tmp map) + 8 (test.b) + 16(test.c) + 24(test.d) +
			// 3 * 16 (strSlice) + 16(len("Bonjour") + len("Hello")...)
			expectedSize: (ptrSize + strHdrSize + sliceHdrSize) + (hmapStructSize +
				topHashSize + (8*(2*2*ptrSize /* interface size */) + ptrSize) +
				unsafe.Sizeof(uint16(123)) +
				strHdrSize + 4 /* "yolo" */ + strHdrSize + 4 /* "meow" */ +
				+(ptrSize + hmapStructSize + topHashSize + (8*(2*strHdrSize) + ptrSize) +
					4 /* "SWAG" */) + yoloSize /* obj */ + (ptrSize + yoloSize) /* &obj */ +
				yoloSize*2) + 5 /* "Hello" */ +
				3*strHdrSize /*strings in strSlice*/ + 11, /* "Bonjour" + "Hola" */
		}, "chan_int": {
			getStruct: func() interface{} {
				test := make(chan int)
				return &test
			},
			// The expected size should be equal to the size of the struct hchan
			// defined in /go/src/runtime/chan.go
			expectedSize: ptrSize + chanHdrSize,
		}, "chan_int_16": {
			getStruct: func() interface{} {
				test := make(chan int, 16)
				return &test
			},
			expectedSize: ptrSize + chanHdrSize + 16*ptrSize,
		}, "chan_yoloPtr_16": {
			getStruct: func() interface{} {
				test := make(chan *yolo, 16)
				for i := 0; i < 16; i++ {
					tmp := &yolo{
						i: int32(i),
					}
					tmp.p = unsafe.Pointer(&tmp.i)
					test <- tmp
				}
				return &test
			},
			expectedSize: ptrSize + chanHdrSize + 16*(ptrSize+yoloSize),
		}, "struct_5": {
			getStruct: func() interface{} {
				// b aliases the start of a's backing array under a
				// different type, through which &bob is reachable.
				tmp := make([]byte, 32)
				test := struct {
					a []byte
					b **uint32
				}{
					a: tmp,
				}
				bob := uint32(42)
				ptrInt := (*uintptr)(unsafe.Pointer(&tmp[0]))
				*ptrInt = uintptr(unsafe.Pointer(&bob))
				test.b = (**uint32)(unsafe.Pointer(&tmp[0]))
				return &test
			},
			expectedSize: sliceHdrSize + ptrSize + 32 + 4,
		}, "struct_6": {
			getStruct: func() interface{} {
				// Two differently-typed pointers (B and A) alias the same
				// address inside tmp's backing array.
				type A struct {
					a uintptr
					b *yolo
				}
				type B struct {
					a *A
					b uintptr
				}
				tmp := make([]byte, 32)
				test := struct {
					a []byte
					b *B
				}{
					a: tmp,
				}
				y := yolo{i: 42}
				test.b = (*B)(unsafe.Pointer(&tmp[0]))
				test.b.a = (*A)(unsafe.Pointer(&tmp[0]))
				test.b.a.b = &y
				return &test
			},
			expectedSize: sliceHdrSize + ptrSize + 32 + yoloSize,
		}, "chan_chan_int_16": {
			getStruct: func() interface{} {
				test := make(chan chan int, 16)
				for i := 0; i < 16; i++ {
					tmp := make(chan int)
					test <- tmp
				}
				return &test
			},
			expectedSize: ptrSize + chanHdrSize*17 + 16*ptrSize,
		}, "chan_yolo_16": {
			getStruct: func() interface{} {
				test := make(chan yolo, 16)
				for i := 0; i < 16; i++ {
					tmp := yolo{
						i: int32(i),
					}
					test <- tmp
				}
				return &test
			},
			expectedSize: ptrSize + chanHdrSize + 16*yoloSize,
		}, "chan_map_string_interface_16)": {
			getStruct: func() interface{} {
				test := make(chan map[string]interface{}, 16)
				for i := 0; i < 16; i++ {
					tmp := make(map[string]interface{})
					test <- tmp
				}
				return &test
			},
			expectedSize: ptrSize + chanHdrSize + 16*(ptrSize+hmapStructSize+
				(8*(1+strHdrSize+interfaceSize)+ptrSize)),
		}, "chan_unsafe_Pointer_16": {
			getStruct: func() interface{} {
				test := make(chan unsafe.Pointer, 16)
				for i := 0; i < 16; i++ {
					var a int
					ptrToA := (unsafe.Pointer)(unsafe.Pointer(&a))
					test <- ptrToA
				}
				return &test
			},
			expectedSize: ptrSize + chanHdrSize + 16*ptrSize,
		}, "chan_[]int_16": {
			getStruct: func() interface{} {
				test := make(chan []int, 16)
				for i := 0; i < 8; i++ {
					intSlice := make([]int, 16)
					test <- intSlice
				}
				return &test
			},
			expectedSize: ptrSize + chanHdrSize + 16*sliceHdrSize + 8*16*ptrSize,
		}, "chan_func": {
			getStruct: func() interface{} {
				test := make(chan func(), 16)
				f := func() {
					fmt.Printf("Hello!")
				}
				for i := 0; i < 8; i++ {
					test <- f
				}
				return &test
			},
			expectedSize: ptrSize + chanHdrSize + 16*ptrSize,
		},
	}
	for key, tcase := range tests {
		t.Run(key, func(t *testing.T) {
			v := tcase.getStruct()
			m, err := DeepSizeof(v)
			if err != nil {
				t.Fatal(err)
			}
			// Sum the per-type sizes and compare with the expected total.
			var totalSize uintptr
			for _, size := range m {
				totalSize += size
			}
			expectedSize := tcase.expectedSize
			if totalSize != expectedSize {
				t.Fatalf("Expected size: %v, but got %v", expectedSize, totalSize)
			}
		})
	}
}
// TestUpdateSeenAreas checks updateSeenBlocks: inserting a disjoint block,
// extending an existing block on either side, and coalescing several
// overlapping blocks into one — verifying both the resulting sorted block
// list and the number of previously-unseen bytes reported.
func TestUpdateSeenAreas(t *testing.T) {
	tests := []struct {
		seen         []block // state before the update (sorted by start)
		expectedSeen []block // expected state after merging update
		expectedSize uintptr // expected count of newly seen bytes
		update       block   // block being merged in
	}{{
		// Disjoint block after the existing one: plain insert.
		seen: []block{
			{start: 0x100000, end: 0x100050},
		},
		expectedSeen: []block{
			{start: 0x100000, end: 0x100050},
			{start: 0x100100, end: 0x100150},
		},
		expectedSize: 0x50,
		update:       block{start: 0x100100, end: 0x100150},
	}, {
		// Disjoint block before the existing one: insert at the front.
		seen: []block{
			{start: 0x100000, end: 0x100050},
		},
		expectedSeen: []block{
			{start: 0x100, end: 0x150},
			{start: 0x100000, end: 0x100050},
		},
		expectedSize: 0x50,
		update:       block{start: 0x100, end: 0x150},
	}, {
		// Overlap on the right: existing block grows; only the tail counts.
		seen: []block{
			{start: 0x100000, end: 0x100500},
		},
		expectedSeen: []block{
			{start: 0x100000, end: 0x100750},
		},
		expectedSize: 0x250,
		update:       block{start: 0x100250, end: 0x100750},
	}, {
		// Overlap on the left: existing block grows; only the head counts.
		seen: []block{
			{start: 0x100250, end: 0x100750},
		},
		expectedSeen: []block{
			{start: 0x100000, end: 0x100750},
		},
		expectedSize: 0x250,
		update:       block{start: 0x100000, end: 0x100500},
	}, {
		// Bridges two blocks: they coalesce; only the gap counts.
		seen: []block{
			{start: 0x1000, end: 0x1250},
			{start: 0x1500, end: 0x1750},
		},
		expectedSeen: []block{
			{start: 0x1000, end: 0x1750},
		},
		expectedSize: 0x2B0,
		update:       block{start: 0x1200, end: 0x1700},
	}, {
		// Swallows the first two blocks, leaving the third untouched.
		seen: []block{
			{start: 0x1000, end: 0x1250},
			{start: 0x1500, end: 0x1750},
			{start: 0x1F50, end: 0x21A0},
		},
		expectedSeen: []block{
			{start: 0xF00, end: 0x1F00},
			{start: 0x1F50, end: 0x21A0},
		},
		expectedSize: 0xB60,
		update:       block{start: 0xF00, end: 0x1F00},
	}, {
		// Touches all three blocks: everything coalesces into one.
		seen: []block{
			{start: 0x1000, end: 0x1250},
			{start: 0x1500, end: 0x1750},
			{start: 0x1F00, end: 0x2150},
		},
		expectedSeen: []block{
			{start: 0xF00, end: 0x2150},
		},
		expectedSize: 0xB60,
		update:       block{start: 0xF00, end: 0x1F00},
	}, {
		// Exactly fills the gap between the first two blocks.
		seen: []block{
			{start: 0x1000, end: 0x1250},
			{start: 0x1500, end: 0x1750},
			{start: 0x1F00, end: 0x2150},
		},
		expectedSeen: []block{
			{start: 0x1000, end: 0x1750},
			{start: 0x1F00, end: 0x2150},
		},
		expectedSize: 0x2B0,
		update:       block{start: 0x1250, end: 0x1500},
	}}
	for i, tcase := range tests {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			seen, size := updateSeenBlocks(tcase.update, tcase.seen)
			if !test.DeepEqual(seen, tcase.expectedSeen) {
				t.Fatalf("seen blocks %x for iterration %v are different than the "+
					"one expected:\n %x", seen, i, tcase.expectedSeen)
			}
			if size != tcase.expectedSize {
				t.Fatalf("Size does not match, expected 0x%x got 0x%x",
					tcase.expectedSize, size)
			}
		})
	}
}

View File

@ -5,7 +5,6 @@
package test
import (
"regexp"
"testing"
"github.com/aristanetworks/goarista/key"
@ -357,21 +356,21 @@ func getDeepEqualTests(t *testing.T) []deepEqualTestCase {
b: complexCompare{},
}, {
a: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 42}},
m: map[builtinCompare]int8{{1, "foo"}: 42}},
b: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 42}},
m: map[builtinCompare]int8{{1, "foo"}: 42}},
}, {
a: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 42}},
m: map[builtinCompare]int8{{1, "foo"}: 42}},
b: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 51}},
m: map[builtinCompare]int8{{1, "foo"}: 51}},
diff: `attributes "m" are different: for key test.builtinCompare{a:uint32(1),` +
` b:"foo"} in map, values are different: int8(42) != int8(51)`,
}, {
a: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 42}},
m: map[builtinCompare]int8{{1, "foo"}: 42}},
b: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "bar"}: 42}},
m: map[builtinCompare]int8{{1, "bar"}: 42}},
diff: `attributes "m" are different: key test.builtinCompare{a:uint32(1),` +
` b:"foo"} in map is missing in the actual map`,
}, {
@ -404,16 +403,16 @@ func getDeepEqualTests(t *testing.T) []deepEqualTestCase {
a: partialCompare{a: 42, b: "foo"},
b: partialCompare{a: 42, b: "bar"},
}, {
a: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 42},
b: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 42},
a: map[*builtinCompare]uint32{{1, "foo"}: 42},
b: map[*builtinCompare]uint32{{1, "foo"}: 42},
}, {
a: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 42},
b: map[*builtinCompare]uint32{&builtinCompare{2, "foo"}: 42},
a: map[*builtinCompare]uint32{{1, "foo"}: 42},
b: map[*builtinCompare]uint32{{2, "foo"}: 42},
diff: `complex key *test.builtinCompare{a:uint32(1), b:"foo"}` +
` in map is missing in the actual map`,
}, {
a: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 42},
b: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 51},
a: map[*builtinCompare]uint32{{1, "foo"}: 42},
b: map[*builtinCompare]uint32{{1, "foo"}: 51},
diff: `for complex key *test.builtinCompare{a:uint32(1), b:"foo"}` +
` in map, values are different: uint32(42) != uint32(51)`,
}, {
@ -436,10 +435,11 @@ func getDeepEqualTests(t *testing.T) []deepEqualTestCase {
b: key.New(map[string]interface{}{
"a": map[key.Key]interface{}{key.New(map[string]interface{}{"k": 51}): true}}),
diff: `Comparable types are different: ` +
`key.composite{sentinel:uintptr(18379810577513696751), m:map[string]interface {}` +
`{"a":map[key.Key]interface {}{<max_depth>:<max_depth>}}} vs` +
` key.composite{sentinel:uintptr(18379810577513696751), m:map[string]interface {}` +
`{"a":map[key.Key]interface {}{<max_depth>:<max_depth>}}}`,
`key.compositeKey{sentinel:uintptr(18379810577513696751), m:map[string]interface {}` +
`{"a":map[key.Key]interface {}{<max_depth>:<max_depth>}}, s:[]interface {}{}}` +
` vs key.compositeKey{sentinel:uintptr(18379810577513696751), ` +
`m:map[string]interface {}{"a":map[key.Key]interface {}` +
`{<max_depth>:<max_depth>}}, s:[]interface {}{}}`,
}, {
a: code(42),
b: code(42),
@ -464,8 +464,5 @@ func getDeepEqualTests(t *testing.T) []deepEqualTestCase {
}, {
a: embedder{builtinCompare: builtinCompare{}},
b: embedder{builtinCompare: builtinCompare{}},
}, {
a: regexp.MustCompile("foo.*bar"),
b: regexp.MustCompile("foo.*bar"),
}}
}

View File

@ -6,12 +6,14 @@ package test
import (
"io"
"io/ioutil"
"os"
"testing"
)
// CopyFile copies a file
func CopyFile(t *testing.T, srcPath, dstPath string) {
t.Helper()
src, err := os.Open(srcPath)
if err != nil {
t.Fatal(err)
@ -27,3 +29,14 @@ func CopyFile(t *testing.T, srcPath, dstPath string) {
t.Fatal(err)
}
}
// TempDir makes a fresh temporary directory (rooted at the platform's default
// temp location, see os.TempDir) using dirName as the name prefix and returns
// its path. The test is failed immediately if the directory cannot be created.
func TempDir(t *testing.T, dirName string) string {
	t.Helper()
	dir, err := ioutil.TempDir("", dirName)
	if err != nil {
		t.Fatal(err)
	}
	return dir
}

View File

@ -12,7 +12,9 @@ import (
// ShouldPanic will test is a function is panicking
func ShouldPanic(t *testing.T, fn func()) {
t.Helper()
defer func() {
t.Helper()
if r := recover(); r == nil {
t.Errorf("%sThe function %p should have panicked",
getCallerInfo(), fn)
@ -24,10 +26,12 @@ func ShouldPanic(t *testing.T, fn func()) {
// ShouldPanicWith will test is a function is panicking with a specific message
func ShouldPanicWith(t *testing.T, msg interface{}, fn func()) {
t.Helper()
defer func() {
t.Helper()
if r := recover(); r == nil {
t.Errorf("%sThe function %p should have panicked",
getCallerInfo(), fn)
t.Errorf("%sThe function %p should have panicked with %#v",
getCallerInfo(), fn, msg)
} else if d := Diff(msg, r); len(d) != 0 {
t.Errorf("%sThe function %p panicked with the wrong message.\n"+
"Expected: %#v\nReceived: %#v\nDiff:%s",

View File

@ -1,7 +1,7 @@
language: go
go:
- 1.8.x
- 1.9.x
- "1.9.5"
- "1.10.1"
sudo: false
install:
- GLIDE_TAG=v0.12.3
@ -10,8 +10,8 @@ install:
- export PATH=$PATH:$PWD/linux-amd64/
- glide install
- go install . ./cmd/...
- go get -v github.com/alecthomas/gometalinter
- gometalinter --install
- go get -u gopkg.in/alecthomas/gometalinter.v2
- gometalinter.v2 --install
script:
- export PATH=$PATH:$HOME/gopath/bin
- ./goclean.sh

View File

@ -570,8 +570,8 @@ Changes in 0.8.0-beta (Sun May 25 2014)
- btcctl utility changes:
- Add createencryptedwallet command
- Add getblockchaininfo command
- Add importwallet commmand
- Add addmultisigaddress commmand
- Add importwallet command
- Add addmultisigaddress command
- Add setgenerate command
- Accept --testnet and --wallet flags which automatically select
the appropriate port and TLS certificates needed to communicate

View File

@ -92,7 +92,7 @@ $ go install . ./cmd/...
## Getting Started
btcd has several configuration options avilable to tweak how it runs, but all
btcd has several configuration options available to tweak how it runs, but all
of the basic operations described in the intro section work with zero
configuration.

View File

@ -63,7 +63,7 @@ is by no means exhaustive:
* [ProcessBlock Example](http://godoc.org/github.com/btcsuite/btcd/blockchain#example-BlockChain-ProcessBlock)
Demonstrates how to create a new chain instance and use ProcessBlock to
attempt to attempt add a block to the chain. This example intentionally
attempt to add a block to the chain. This example intentionally
attempts to insert a duplicate genesis block to illustrate how an invalid
block is handled.
@ -73,7 +73,7 @@ is by no means exhaustive:
typical hex notation.
* [BigToCompact Example](http://godoc.org/github.com/btcsuite/btcd/blockchain#example-BigToCompact)
Demonstrates how to convert how to convert a target difficulty into the
Demonstrates how to convert a target difficulty into the
compact "bits" in a block header which represent that target difficulty.
## GPG Verification Key

View File

@ -54,23 +54,24 @@ func (b *BlockChain) maybeAcceptBlock(block *btcutil.Block, flags BehaviorFlags)
// such as making blocks that never become part of the main chain or
// blocks that fail to connect available for further analysis.
err = b.db.Update(func(dbTx database.Tx) error {
return dbMaybeStoreBlock(dbTx, block)
return dbStoreBlock(dbTx, block)
})
if err != nil {
return false, err
}
// Create a new block node for the block and add it to the in-memory
// block chain (could be either a side chain or the main chain).
// Create a new block node for the block and add it to the node index. Even
// if the block ultimately gets connected to the main chain, it starts out
// on a side chain.
blockHeader := &block.MsgBlock().Header
newNode := newBlockNode(blockHeader, blockHeight)
newNode := newBlockNode(blockHeader, prevNode)
newNode.status = statusDataStored
if prevNode != nil {
newNode.parent = prevNode
newNode.height = blockHeight
newNode.workSum.Add(prevNode.workSum, newNode.workSum)
}
b.index.AddNode(newNode)
err = b.index.flushToDB()
if err != nil {
return false, err
}
// Connect the passed block to the chain while respecting proper chain
// selection according to the chain with the most proof of work. This

View File

@ -101,33 +101,33 @@ type blockNode struct {
status blockStatus
}
// initBlockNode initializes a block node from the given header and height. The
// node is completely disconnected from the chain and the workSum value is just
// the work for the passed block. The work sum must be updated accordingly when
// the node is inserted into a chain.
//
// initBlockNode initializes a block node from the given header and parent node,
// calculating the height and workSum from the respective fields on the parent.
// This function is NOT safe for concurrent access. It must only be called when
// initially creating a node.
func initBlockNode(node *blockNode, blockHeader *wire.BlockHeader, height int32) {
func initBlockNode(node *blockNode, blockHeader *wire.BlockHeader, parent *blockNode) {
*node = blockNode{
hash: blockHeader.BlockHash(),
workSum: CalcWork(blockHeader.Bits),
height: height,
version: blockHeader.Version,
bits: blockHeader.Bits,
nonce: blockHeader.Nonce,
timestamp: blockHeader.Timestamp.Unix(),
merkleRoot: blockHeader.MerkleRoot,
}
if parent != nil {
node.parent = parent
node.height = parent.height + 1
node.workSum = node.workSum.Add(parent.workSum, node.workSum)
}
}
// newBlockNode returns a new block node for the given block header. It is
// completely disconnected from the chain and the workSum value is just the work
// for the passed block. The work sum must be updated accordingly when the node
// is inserted into a chain.
func newBlockNode(blockHeader *wire.BlockHeader, height int32) *blockNode {
// newBlockNode returns a new block node for the given block header and parent
// node, calculating the height and workSum from the respective fields on the
// parent. This function is NOT safe for concurrent access.
func newBlockNode(blockHeader *wire.BlockHeader, parent *blockNode) *blockNode {
var node blockNode
initBlockNode(&node, blockHeader, height)
initBlockNode(&node, blockHeader, parent)
return &node
}
@ -136,7 +136,7 @@ func newBlockNode(blockHeader *wire.BlockHeader, height int32) *blockNode {
// This function is safe for concurrent access.
func (node *blockNode) Header() wire.BlockHeader {
// No lock is needed because all accessed fields are immutable.
prevHash := zeroHash
prevHash := &zeroHash
if node.parent != nil {
prevHash = &node.parent.hash
}
@ -231,6 +231,7 @@ type blockIndex struct {
sync.RWMutex
index map[chainhash.Hash]*blockNode
dirty map[*blockNode]struct{}
}
// newBlockIndex returns a new empty instance of a block index. The index will
@ -241,6 +242,7 @@ func newBlockIndex(db database.DB, chainParams *chaincfg.Params) *blockIndex {
db: db,
chainParams: chainParams,
index: make(map[chainhash.Hash]*blockNode),
dirty: make(map[*blockNode]struct{}),
}
}
@ -265,16 +267,25 @@ func (bi *blockIndex) LookupNode(hash *chainhash.Hash) *blockNode {
return node
}
// AddNode adds the provided node to the block index. Duplicate entries are not
// checked so it is up to caller to avoid adding them.
// AddNode adds the provided node to the block index and marks it as dirty.
// Duplicate entries are not checked so it is up to caller to avoid adding them.
//
// This function is safe for concurrent access.
func (bi *blockIndex) AddNode(node *blockNode) {
	bi.Lock()
	// addNode performs the index insertion; writing bi.index[node.hash]
	// directly here as well would be a redundant duplicate of that work,
	// so the insertion is done only once, via the helper.
	bi.addNode(node)
	bi.dirty[node] = struct{}{}
	bi.Unlock()
}
// addNode adds the provided node to the block index, but does not mark it as
// dirty. This can be used while initializing the block index.
//
// This function is NOT safe for concurrent access; the caller is expected to
// hold the index lock, as AddNode does.
func (bi *blockIndex) addNode(node *blockNode) {
	bi.index[node.hash] = node
}
// NodeStatus provides concurrent-safe access to the status field of a node.
//
// This function is safe for concurrent access.
@ -293,6 +304,7 @@ func (bi *blockIndex) NodeStatus(node *blockNode) blockStatus {
func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) {
bi.Lock()
node.status |= flags
bi.dirty[node] = struct{}{}
bi.Unlock()
}
@ -303,5 +315,34 @@ func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) {
func (bi *blockIndex) UnsetStatusFlags(node *blockNode, flags blockStatus) {
bi.Lock()
node.status &^= flags
bi.dirty[node] = struct{}{}
bi.Unlock()
}
// flushToDB writes all dirty block nodes to the database. If all writes
// succeed, this clears the dirty set.
func (bi *blockIndex) flushToDB() error {
	bi.Lock()
	defer bi.Unlock()

	// Fast path: nothing has changed since the last successful flush.
	if len(bi.dirty) == 0 {
		return nil
	}

	// Persist every dirty node inside a single database transaction.
	err := bi.db.Update(func(dbTx database.Tx) error {
		for node := range bi.dirty {
			if err := dbStoreBlockNode(dbTx, node); err != nil {
				return err
			}
		}
		return nil
	})

	// Only forget the dirty set once everything reached the database.
	if err == nil {
		bi.dirty = make(map[*blockNode]struct{})
	}
	return err
}

View File

@ -1,4 +1,5 @@
// Copyright (c) 2013-2017 The btcsuite developers
// Copyright (c) 2013-2018 The btcsuite developers
// Copyright (c) 2015-2018 The Decred developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
@ -399,7 +400,7 @@ func (b *BlockChain) calcSequenceLock(node *blockNode, tx *btcutil.Tx, utxoView
nextHeight := node.height + 1
for txInIndex, txIn := range mTx.TxIn {
utxo := utxoView.LookupEntry(&txIn.PreviousOutPoint.Hash)
utxo := utxoView.LookupEntry(txIn.PreviousOutPoint)
if utxo == nil {
str := fmt.Sprintf("output %v referenced from "+
"transaction %s:%d either does not exist or "+
@ -495,6 +496,8 @@ func LockTimeToSequence(isSeconds bool, locktime uint32) uint32 {
// passed node is the new end of the main chain. The lists will be empty if the
// passed node is not on a side chain.
//
// This function may modify node statuses in the block index without flushing.
//
// This function MUST be called with the chain state lock held (for reads).
func (b *BlockChain) getReorganizeNodes(node *blockNode) (*list.List, *list.List) {
attachNodes := list.New()
@ -544,20 +547,6 @@ func (b *BlockChain) getReorganizeNodes(node *blockNode) (*list.List, *list.List
return detachNodes, attachNodes
}
// dbMaybeStoreBlock stores the provided block in the database if it's not
// already there.
func dbMaybeStoreBlock(dbTx database.Tx, block *btcutil.Block) error {
hasBlock, err := dbTx.HasBlock(block.Hash())
if err != nil {
return err
}
if hasBlock {
return nil
}
return dbTx.StoreBlock(block)
}
// connectBlock handles connecting the passed node/block to the end of the main
// (best) chain.
//
@ -569,7 +558,9 @@ func dbMaybeStoreBlock(dbTx database.Tx, block *btcutil.Block) error {
// it would be inefficient to repeat it.
//
// This function MUST be called with the chain state lock held (for writes).
func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block, view *UtxoViewpoint, stxos []spentTxOut) error {
func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block,
view *UtxoViewpoint, stxos []SpentTxOut) error {
// Make sure it's extending the end of the best chain.
prevHash := &block.MsgBlock().Header.PrevBlock
if !prevHash.IsEqual(&b.bestChain.Tip().hash) {
@ -599,6 +590,12 @@ func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block, view *U
}
}
// Write any block status changes to DB before updating best state.
err := b.index.flushToDB()
if err != nil {
return err
}
// Generate a new best state snapshot that will be used to update the
// database and later memory if all database updates are successful.
b.stateLock.RLock()
@ -611,7 +608,7 @@ func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block, view *U
curTotalTxns+numTxns, node.CalcPastMedianTime())
// Atomically insert info into the database.
err := b.db.Update(func(dbTx database.Tx) error {
err = b.db.Update(func(dbTx database.Tx) error {
// Update best block state.
err := dbPutBestState(dbTx, state, node.workSum)
if err != nil {
@ -644,7 +641,7 @@ func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block, view *U
// optional indexes with the block being connected so they can
// update themselves accordingly.
if b.indexManager != nil {
err := b.indexManager.ConnectBlock(dbTx, block, view)
err := b.indexManager.ConnectBlock(dbTx, block, stxos)
if err != nil {
return err
}
@ -705,6 +702,12 @@ func (b *BlockChain) disconnectBlock(node *blockNode, block *btcutil.Block, view
return err
}
// Write any block status changes to DB before updating best state.
err = b.index.flushToDB()
if err != nil {
return err
}
// Generate a new best state snapshot that will be used to update the
// database and later memory if all database updates are successful.
b.stateLock.RLock()
@ -739,8 +742,15 @@ func (b *BlockChain) disconnectBlock(node *blockNode, block *btcutil.Block, view
return err
}
// Before we delete the spend journal entry for this back,
// we'll fetch it as is so the indexers can utilize if needed.
stxos, err := dbFetchSpendJournalEntry(dbTx, block)
if err != nil {
return err
}
// Update the transaction spend journal by removing the record
// that contains all txos spent by the block .
// that contains all txos spent by the block.
err = dbRemoveSpendJournalEntry(dbTx, block.Hash())
if err != nil {
return err
@ -750,7 +760,7 @@ func (b *BlockChain) disconnectBlock(node *blockNode, block *btcutil.Block, view
// optional indexes with the block being disconnected so they
// can update themselves accordingly.
if b.indexManager != nil {
err := b.indexManager.DisconnectBlock(dbTx, block, view)
err := b.indexManager.DisconnectBlock(dbTx, block, stxos)
if err != nil {
return err
}
@ -806,15 +816,49 @@ func countSpentOutputs(block *btcutil.Block) int {
// the chain) and nodes the are being attached must be in forwards order
// (think pushing them onto the end of the chain).
//
// This function may modify node statuses in the block index without flushing.
//
// This function MUST be called with the chain state lock held (for writes).
func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error {
// Nothing to do if no reorganize nodes were provided.
if detachNodes.Len() == 0 && attachNodes.Len() == 0 {
return nil
}
// Ensure the provided nodes match the current best chain.
tip := b.bestChain.Tip()
if detachNodes.Len() != 0 {
firstDetachNode := detachNodes.Front().Value.(*blockNode)
if firstDetachNode.hash != tip.hash {
return AssertError(fmt.Sprintf("reorganize nodes to detach are "+
"not for the current best chain -- first detach node %v, "+
"current chain %v", &firstDetachNode.hash, &tip.hash))
}
}
// Ensure the provided nodes are for the same fork point.
if attachNodes.Len() != 0 && detachNodes.Len() != 0 {
firstAttachNode := attachNodes.Front().Value.(*blockNode)
lastDetachNode := detachNodes.Back().Value.(*blockNode)
if firstAttachNode.parent.hash != lastDetachNode.parent.hash {
return AssertError(fmt.Sprintf("reorganize nodes do not have the "+
"same fork point -- first attach parent %v, last detach "+
"parent %v", &firstAttachNode.parent.hash,
&lastDetachNode.parent.hash))
}
}
// Track the old and new best chains heads.
oldBest := tip
newBest := tip
// All of the blocks to detach and related spend journal entries needed
// to unspend transaction outputs in the blocks being disconnected must
// be loaded from the database during the reorg check phase below and
// then they are needed again when doing the actual database updates.
// Rather than doing two loads, cache the loaded data into these slices.
detachBlocks := make([]*btcutil.Block, 0, detachNodes.Len())
detachSpentTxOuts := make([][]spentTxOut, 0, detachNodes.Len())
detachSpentTxOuts := make([][]SpentTxOut, 0, detachNodes.Len())
attachBlocks := make([]*btcutil.Block, 0, attachNodes.Len())
// Disconnect all of the blocks back to the point of the fork. This
@ -822,7 +866,7 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// database and using that information to unspend all of the spent txos
// and remove the utxos created by the blocks.
view := NewUtxoViewpoint()
view.SetBestHash(&b.bestChain.Tip().hash)
view.SetBestHash(&oldBest.hash)
for e := detachNodes.Front(); e != nil; e = e.Next() {
n := e.Value.(*blockNode)
var block *btcutil.Block
@ -834,6 +878,11 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
if err != nil {
return err
}
if n.hash != *block.Hash() {
return AssertError(fmt.Sprintf("detach block node hash %v (height "+
"%v) does not match previous parent block hash %v", &n.hash,
n.height, block.Hash()))
}
// Load all of the utxos referenced by the block that aren't
// already in the view.
@ -844,9 +893,9 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// Load all of the spent txos for the block from the spend
// journal.
var stxos []spentTxOut
var stxos []SpentTxOut
err = b.db.View(func(dbTx database.Tx) error {
stxos, err = dbFetchSpendJournalEntry(dbTx, block, view)
stxos, err = dbFetchSpendJournalEntry(dbTx, block)
return err
})
if err != nil {
@ -857,10 +906,19 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
detachBlocks = append(detachBlocks, block)
detachSpentTxOuts = append(detachSpentTxOuts, stxos)
err = view.disconnectTransactions(block, stxos)
err = view.disconnectTransactions(b.db, block, stxos)
if err != nil {
return err
}
newBest = n.parent
}
// Set the fork point only if there are nodes to attach since otherwise
// blocks are only being disconnected and thus there is no fork point.
var forkNode *blockNode
if attachNodes.Len() > 0 {
forkNode = newBest
}
// Perform several checks to verify each block that needs to be attached
@ -875,17 +933,9 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// at least a couple of ways accomplish that rollback, but both involve
// tweaking the chain and/or database. This approach catches these
// issues before ever modifying the chain.
var validationError error
for e := attachNodes.Front(); e != nil; e = e.Next() {
n := e.Value.(*blockNode)
// If any previous nodes in attachNodes failed validation,
// mark this one as having an invalid ancestor.
if validationError != nil {
b.index.SetStatusFlags(n, statusInvalidAncestor)
continue
}
var block *btcutil.Block
err := b.db.View(func(dbTx database.Tx) error {
var err error
@ -911,6 +961,8 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
if err != nil {
return err
}
newBest = n
continue
}
@ -918,23 +970,24 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// thus will not be generated. This is done because the state
// is not being immediately written to the database, so it is
// not needed.
//
// In the case the block is determined to be invalid due to a
// rule violation, mark it as invalid and mark all of its
// descendants as having an invalid ancestor.
err = b.checkConnectBlock(n, block, view, nil)
if err != nil {
// If the block failed validation mark it as invalid, then
// continue to loop through remaining nodes, marking them as
// having an invalid ancestor.
if _, ok := err.(RuleError); ok {
b.index.SetStatusFlags(n, statusValidateFailed)
validationError = err
continue
for de := e.Next(); de != nil; de = de.Next() {
dn := de.Value.(*blockNode)
b.index.SetStatusFlags(dn, statusInvalidAncestor)
}
}
return err
}
b.index.SetStatusFlags(n, statusValid)
}
if validationError != nil {
return validationError
newBest = n
}
// Reset the view for the actual connection code below. This is
@ -959,7 +1012,8 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// Update the view to unspend all of the spent txos and remove
// the utxos created by the block.
err = view.disconnectTransactions(block, detachSpentTxOuts[i])
err = view.disconnectTransactions(b.db, block,
detachSpentTxOuts[i])
if err != nil {
return err
}
@ -987,7 +1041,7 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// as spent and add all transactions being created by this block
// to it. Also, provide an stxo slice so the spent txout
// details are generated.
stxos := make([]spentTxOut, 0, countSpentOutputs(block))
stxos := make([]SpentTxOut, 0, countSpentOutputs(block))
err = view.connectTransactions(block, &stxos)
if err != nil {
return err
@ -1002,12 +1056,14 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// Log the point where the chain forked and old and new best chain
// heads.
firstAttachNode := attachNodes.Front().Value.(*blockNode)
firstDetachNode := detachNodes.Front().Value.(*blockNode)
lastAttachNode := attachNodes.Back().Value.(*blockNode)
log.Infof("REORGANIZE: Chain forks at %v", firstAttachNode.parent.hash)
log.Infof("REORGANIZE: Old best chain head was %v", firstDetachNode.hash)
log.Infof("REORGANIZE: New best chain head is %v", lastAttachNode.hash)
if forkNode != nil {
log.Infof("REORGANIZE: Chain forks at %v (height %v)", forkNode.hash,
forkNode.height)
}
log.Infof("REORGANIZE: Old best chain head was %v (height %v)",
&oldBest.hash, oldBest.height)
log.Infof("REORGANIZE: New best chain head is %v (height %v)",
newBest.hash, newBest.height)
return nil
}
@ -1029,6 +1085,17 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
func (b *BlockChain) connectBestChain(node *blockNode, block *btcutil.Block, flags BehaviorFlags) (bool, error) {
fastAdd := flags&BFFastAdd == BFFastAdd
flushIndexState := func() {
// Intentionally ignore errors writing updated node status to DB. If
// it fails to write, it's not the end of the world. If the block is
// valid, we flush in connectBlock and if the block is invalid, the
// worst that can happen is we revalidate the block after a restart.
if writeErr := b.index.flushToDB(); writeErr != nil {
log.Warnf("Error flushing block index changes to disk: %v",
writeErr)
}
}
// We are extending the main (best) chain with a new block. This is the
// most common case.
parentHash := &block.MsgBlock().Header.PrevBlock
@ -1041,16 +1108,22 @@ func (b *BlockChain) connectBestChain(node *blockNode, block *btcutil.Block, fla
// actually connecting the block.
view := NewUtxoViewpoint()
view.SetBestHash(parentHash)
stxos := make([]spentTxOut, 0, countSpentOutputs(block))
stxos := make([]SpentTxOut, 0, countSpentOutputs(block))
if !fastAdd {
err := b.checkConnectBlock(node, block, view, &stxos)
if err != nil {
if _, ok := err.(RuleError); ok {
b.index.SetStatusFlags(node, statusValidateFailed)
}
if err == nil {
b.index.SetStatusFlags(node, statusValid)
} else if _, ok := err.(RuleError); ok {
b.index.SetStatusFlags(node, statusValidateFailed)
} else {
return false, err
}
flushIndexState()
if err != nil {
return false, err
}
b.index.SetStatusFlags(node, statusValid)
}
// In the fast add case the code to check the block connection
@ -1071,9 +1144,28 @@ func (b *BlockChain) connectBestChain(node *blockNode, block *btcutil.Block, fla
// Connect the block to the main chain.
err := b.connectBlock(node, block, view, stxos)
if err != nil {
// If we got hit with a rule error, then we'll mark
// that status of the block as invalid and flush the
// index state to disk before returning with the error.
if _, ok := err.(RuleError); ok {
b.index.SetStatusFlags(
node, statusValidateFailed,
)
}
flushIndexState()
return false, err
}
// If this is fast add, or this block node isn't yet marked as
// valid, then we'll update its status and flush the state to
// disk again.
if fastAdd || !b.index.NodeStatus(node).KnownValid() {
b.index.SetStatusFlags(node, statusValid)
flushIndexState()
}
return true, nil
}
if fastAdd {
@ -1111,11 +1203,16 @@ func (b *BlockChain) connectBestChain(node *blockNode, block *btcutil.Block, fla
// Reorganize the chain.
log.Infof("REORGANIZE: Block %v is causing a reorganize.", node.hash)
err := b.reorganizeChain(detachNodes, attachNodes)
if err != nil {
return false, err
// Either getReorganizeNodes or reorganizeChain could have made unsaved
// changes to the block index, so flush regardless of whether there was an
// error. The index would only be dirty if the block failed to connect, so
// we can ignore any errors writing.
if writeErr := b.index.flushToDB(); writeErr != nil {
log.Warnf("Error flushing block index changes to disk: %v", writeErr)
}
return true, nil
return err == nil, err
}
// isCurrent returns whether or not the chain believes it is current. Several
@ -1168,25 +1265,17 @@ func (b *BlockChain) BestSnapshot() *BestState {
return snapshot
}
// FetchHeader returns the block header identified by the given hash or an error
// if it doesn't exist.
func (b *BlockChain) FetchHeader(hash *chainhash.Hash) (wire.BlockHeader, error) {
// Reconstruct the header from the block index if possible.
if node := b.index.LookupNode(hash); node != nil {
return node.Header(), nil
}
// Fall back to loading it from the database.
var header *wire.BlockHeader
err := b.db.View(func(dbTx database.Tx) error {
var err error
header, err = dbFetchHeaderByHash(dbTx, hash)
return err
})
if err != nil {
// HeaderByHash returns the block header identified by the given hash or an
// error if it doesn't exist. Note that this will return headers from both the
// main and side chains.
func (b *BlockChain) HeaderByHash(hash *chainhash.Hash) (wire.BlockHeader, error) {
	// The header is reconstructed from the in-memory block index, so no
	// database access is required.
	node := b.index.LookupNode(hash)
	if node == nil {
		err := fmt.Errorf("block %s is not known", hash)
		return wire.BlockHeader{}, err
	}

	return node.Header(), nil
}
// MainChainHasBlock returns whether or not the block with the given hash is in
@ -1302,6 +1391,87 @@ func (b *BlockChain) HeightRange(startHeight, endHeight int32) ([]chainhash.Hash
return hashes, nil
}
// HeightToHashRange returns a range of block hashes for the given start height
// and end hash, inclusive on both ends. The hashes are for all blocks that are
// ancestors of endHash with height greater than or equal to startHeight. The
// end hash must belong to a block that is known to be valid.
//
// This function is safe for concurrent access.
func (b *BlockChain) HeightToHashRange(startHeight int32,
	endHash *chainhash.Hash, maxResults int) ([]chainhash.Hash, error) {

	// The end block must exist in the index and be fully validated.
	endNode := b.index.LookupNode(endHash)
	if endNode == nil {
		return nil, fmt.Errorf("no known block header with hash %v", endHash)
	}
	if !b.index.NodeStatus(endNode).KnownValid() {
		return nil, fmt.Errorf("block %v is not yet validated", endHash)
	}
	endHeight := endNode.height

	// Sanity-check the requested range.
	switch {
	case startHeight < 0:
		return nil, fmt.Errorf("start height (%d) is below 0", startHeight)
	case startHeight > endHeight:
		return nil, fmt.Errorf("start height (%d) is past end height (%d)",
			startHeight, endHeight)
	}

	count := int(endHeight-startHeight) + 1
	if count > maxResults {
		return nil, fmt.Errorf("number of results (%d) would exceed max (%d)",
			count, maxResults)
	}

	// Walk backwards from endNode towards startHeight, filling the result
	// slice from the back so the hashes end up in ascending height order.
	hashes := make([]chainhash.Hash, count)
	node := endNode
	for i := count - 1; i >= 0; i-- {
		hashes[i] = node.hash
		node = node.parent
	}
	return hashes, nil
}
// IntervalBlockHashes returns hashes for all blocks that are ancestors of
// endHash where the block height is a positive multiple of interval.
//
// This function is safe for concurrent access.
func (b *BlockChain) IntervalBlockHashes(endHash *chainhash.Hash, interval int,
) ([]chainhash.Hash, error) {

	// The end block must exist in the index and be fully validated.
	endNode := b.index.LookupNode(endHash)
	if endNode == nil {
		return nil, fmt.Errorf("no known block header with hash %v", endHash)
	}
	if !b.index.NodeStatus(endNode).KnownValid() {
		return nil, fmt.Errorf("block %v is not yet validated", endHash)
	}

	n := int(endNode.height) / interval
	hashes := make([]chainhash.Hash, n)

	b.bestChain.mtx.Lock()
	defer b.bestChain.mtx.Unlock()

	// Walk from the highest multiple of interval down to the first,
	// preferring the bestChain chainView for faster lookups once the
	// walk intersects the best chain.
	node := endNode
	for i := n; i > 0; i-- {
		height := int32(i * interval)
		if b.bestChain.contains(node) {
			node = b.bestChain.nodeByHeight(height)
		} else {
			node = node.Ancestor(height)
		}
		hashes[i-1] = node.hash
	}
	return hashes, nil
}
// locateInventory returns the node of the block after the first known block in
// the locator along with the number of subsequent nodes needed to either reach
// the provided stop hash or the provided max number of entries.
@ -1467,12 +1637,16 @@ type IndexManager interface {
Init(*BlockChain, <-chan struct{}) error
// ConnectBlock is invoked when a new block has been connected to the
// main chain.
ConnectBlock(database.Tx, *btcutil.Block, *UtxoViewpoint) error
// main chain. The set of output spent within a block is also passed in
// so indexers can access the previous output scripts input spent if
// required.
ConnectBlock(database.Tx, *btcutil.Block, []SpentTxOut) error
// DisconnectBlock is invoked when a block has been disconnected from
// the main chain.
DisconnectBlock(database.Tx, *btcutil.Block, *UtxoViewpoint) error
// the main chain. The set of outputs scripts that were spent within
// this block is also returned so indexers can clean up the prior index
// state for this block.
DisconnectBlock(database.Tx, *btcutil.Block, []SpentTxOut) error
}
// Config is a descriptor which specifies the blockchain instance configuration.
@ -1601,6 +1775,11 @@ func New(config *Config) (*BlockChain, error) {
return nil, err
}
// Perform any upgrades to the various chain-specific buckets as needed.
if err := b.maybeUpgradeDbBuckets(config.Interrupt); err != nil {
return nil, err
}
// Initialize and catch up all of the currently active optional indexes
// as needed.
if config.IndexManager != nil {

View File

@ -800,3 +800,167 @@ func TestLocateInventory(t *testing.T) {
}
}
}
// TestHeightToHashRange ensures that fetching a range of block hashes by start
// height and end hash works as expected.
func TestHeightToHashRange(t *testing.T) {
	// Construct a synthetic block chain with a block index consisting of
	// the following structure.
	// genesis -> 1 -> 2 -> ... -> 15 -> 16 -> 17 -> 18
	// \-> 16a -> 17a -> 18a (unvalidated)
	tip := tstTip
	chain := newFakeChain(&chaincfg.MainNetParams)
	branch0Nodes := chainedNodes(chain.bestChain.Genesis(), 18)
	branch1Nodes := chainedNodes(branch0Nodes[14], 3)
	for _, node := range branch0Nodes {
		chain.index.SetStatusFlags(node, statusValid)
		chain.index.AddNode(node)
	}
	// Leave the tip of the side branch unvalidated to exercise the
	// "unvalidated block" error case below.
	for _, node := range branch1Nodes {
		if node.height < 18 {
			chain.index.SetStatusFlags(node, statusValid)
		}
		chain.index.AddNode(node)
	}
	chain.bestChain.SetTip(tip(branch0Nodes))

	tests := []struct {
		name        string
		startHeight int32            // start height of the requested range
		endHash     chainhash.Hash   // end hash of the requested range
		maxResults  int              // max number of results allowed
		hashes      []chainhash.Hash // expected located hashes
		expectError bool
	}{
		{
			name:        "blocks below tip",
			startHeight: 11,
			endHash:     branch0Nodes[14].hash,
			maxResults:  10,
			hashes:      nodeHashes(branch0Nodes, 10, 11, 12, 13, 14),
		},
		{
			name:        "blocks on main chain",
			startHeight: 15,
			endHash:     branch0Nodes[17].hash,
			maxResults:  10,
			hashes:      nodeHashes(branch0Nodes, 14, 15, 16, 17),
		},
		{
			name:        "blocks on stale chain",
			startHeight: 15,
			endHash:     branch1Nodes[1].hash,
			maxResults:  10,
			hashes: append(nodeHashes(branch0Nodes, 14),
				nodeHashes(branch1Nodes, 0, 1)...),
		},
		{
			name:        "invalid start height",
			startHeight: 19,
			endHash:     branch0Nodes[17].hash,
			maxResults:  10,
			expectError: true,
		},
		{
			name:        "too many results",
			startHeight: 1,
			endHash:     branch0Nodes[17].hash,
			maxResults:  10,
			expectError: true,
		},
		{
			name:        "unvalidated block",
			startHeight: 15,
			endHash:     branch1Nodes[2].hash,
			maxResults:  10,
			expectError: true,
		},
	}

	for _, test := range tests {
		hashes, err := chain.HeightToHashRange(test.startHeight, &test.endHash,
			test.maxResults)
		if err != nil {
			if !test.expectError {
				t.Errorf("%s: unexpected error: %v", test.name, err)
			}
			continue
		}

		if !reflect.DeepEqual(hashes, test.hashes) {
			t.Errorf("%s: unexpected hashes -- got %v, want %v",
				test.name, hashes, test.hashes)
		}
	}
}
// TestIntervalBlockHashes ensures that fetching block hashes at specified
// intervals by end hash works as expected.
func TestIntervalBlockHashes(t *testing.T) {
	// Construct a synthetic block chain with a block index consisting of
	// the following structure.
	// 	genesis -> 1 -> 2 -> ... -> 15 -> 16  -> 17  -> 18
	// 	                                \-> 16a -> 17a -> 18a (unvalidated)
	tip := tstTip
	chain := newFakeChain(&chaincfg.MainNetParams)
	branch0Nodes := chainedNodes(chain.bestChain.Genesis(), 18)
	branch1Nodes := chainedNodes(branch0Nodes[14], 3)
	for _, node := range branch0Nodes {
		chain.index.SetStatusFlags(node, statusValid)
		chain.index.AddNode(node)
	}
	for _, node := range branch1Nodes {
		// Leave the tip of the side chain unvalidated so the
		// unvalidated-block error path can be exercised below.
		if node.height < 18 {
			chain.index.SetStatusFlags(node, statusValid)
		}
		chain.index.AddNode(node)
	}
	chain.bestChain.SetTip(tip(branch0Nodes))

	tests := []struct {
		name        string
		endHash     chainhash.Hash
		interval    int
		hashes      []chainhash.Hash
		expectError bool
	}{
		{
			name:     "blocks on main chain",
			endHash:  branch0Nodes[17].hash,
			interval: 8,
			hashes:   nodeHashes(branch0Nodes, 7, 15),
		},
		{
			name:     "blocks on stale chain",
			endHash:  branch1Nodes[1].hash,
			interval: 8,
			hashes: append(nodeHashes(branch0Nodes, 7),
				nodeHashes(branch1Nodes, 0)...),
		},
		{
			name:     "no results",
			endHash:  branch0Nodes[17].hash,
			interval: 20,
			hashes:   []chainhash.Hash{},
		},
		{
			name:        "unvalidated block",
			endHash:     branch1Nodes[2].hash,
			interval:    8,
			expectError: true,
		},
	}
	for _, test := range tests {
		hashes, err := chain.IntervalBlockHashes(&test.endHash, test.interval)
		if err != nil {
			if !test.expectError {
				t.Errorf("%s: unexpected error: %v", test.name, err)
			}
			continue
		}

		if !reflect.DeepEqual(hashes, test.hashes) {
			// NOTE: fixed typo in the failure message ("unxpected").
			t.Errorf("%s: unexpected hashes -- got %v, want %v",
				test.name, hashes, test.hashes)
		}
	}
}

File diff suppressed because it is too large Load Diff

View File

@ -11,7 +11,6 @@ import (
"reflect"
"testing"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/database"
"github.com/btcsuite/btcd/wire"
)
@ -38,19 +37,6 @@ func TestErrNotInMainChain(t *testing.T) {
}
}
// maybeDecompress decompresses the amount and public key script fields of the
// stxo and marks it decompressed if needed.
func (o *spentTxOut) maybeDecompress(version int32) {
// Nothing to do if it's not compressed.
if !o.compressed {
return
}
o.amount = int64(decompressTxOutAmount(uint64(o.amount)))
o.pkScript = decompressScript(o.pkScript, version)
o.compressed = false
}
// TestStxoSerialization ensures serializing and deserializing spent transaction
// output entries works as expected.
func TestStxoSerialization(t *testing.T) {
@ -58,43 +44,38 @@ func TestStxoSerialization(t *testing.T) {
tests := []struct {
name string
stxo spentTxOut
txVersion int32 // When the txout is not fully spent.
stxo SpentTxOut
serialized []byte
}{
// From block 170 in main blockchain.
{
name: "Spends last output of coinbase",
stxo: spentTxOut{
amount: 5000000000,
pkScript: hexToBytes("410411db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5cb2e0eaddfb84ccf9744464f82e160bfa9b8b64f9d4c03f999b8643f656b412a3ac"),
isCoinBase: true,
height: 9,
version: 1,
stxo: SpentTxOut{
Amount: 5000000000,
PkScript: hexToBytes("410411db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5cb2e0eaddfb84ccf9744464f82e160bfa9b8b64f9d4c03f999b8643f656b412a3ac"),
IsCoinBase: true,
Height: 9,
},
serialized: hexToBytes("1301320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5c"),
serialized: hexToBytes("1300320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5c"),
},
// Adapted from block 100025 in main blockchain.
{
name: "Spends last output of non coinbase",
stxo: spentTxOut{
amount: 13761000000,
pkScript: hexToBytes("76a914b2fb57eadf61e106a100a7445a8c3f67898841ec88ac"),
isCoinBase: false,
height: 100024,
version: 1,
stxo: SpentTxOut{
Amount: 13761000000,
PkScript: hexToBytes("76a914b2fb57eadf61e106a100a7445a8c3f67898841ec88ac"),
IsCoinBase: false,
Height: 100024,
},
serialized: hexToBytes("8b99700186c64700b2fb57eadf61e106a100a7445a8c3f67898841ec"),
serialized: hexToBytes("8b99700086c64700b2fb57eadf61e106a100a7445a8c3f67898841ec"),
},
// Adapted from block 100025 in main blockchain.
{
name: "Does not spend last output",
stxo: spentTxOut{
amount: 34405000000,
pkScript: hexToBytes("76a9146edbc6c4d31bae9f1ccc38538a114bf42de65e8688ac"),
version: 1,
name: "Does not spend last output, legacy format",
stxo: SpentTxOut{
Amount: 34405000000,
PkScript: hexToBytes("76a9146edbc6c4d31bae9f1ccc38538a114bf42de65e8688ac"),
},
txVersion: 1,
serialized: hexToBytes("0091f20f006edbc6c4d31bae9f1ccc38538a114bf42de65e86"),
},
}
@ -104,7 +85,7 @@ func TestStxoSerialization(t *testing.T) {
// actually serializing it is calculated properly.
gotSize := spentTxOutSerializeSize(&test.stxo)
if gotSize != len(test.serialized) {
t.Errorf("spentTxOutSerializeSize (%s): did not get "+
t.Errorf("SpentTxOutSerializeSize (%s): did not get "+
"expected size - got %d, want %d", test.name,
gotSize, len(test.serialized))
continue
@ -129,15 +110,13 @@ func TestStxoSerialization(t *testing.T) {
// Ensure the serialized bytes are decoded back to the expected
// stxo.
var gotStxo spentTxOut
gotBytesRead, err := decodeSpentTxOut(test.serialized, &gotStxo,
test.txVersion)
var gotStxo SpentTxOut
gotBytesRead, err := decodeSpentTxOut(test.serialized, &gotStxo)
if err != nil {
t.Errorf("decodeSpentTxOut (%s): unexpected error: %v",
test.name, err)
continue
}
gotStxo.maybeDecompress(test.stxo.version)
if !reflect.DeepEqual(gotStxo, test.stxo) {
t.Errorf("decodeSpentTxOut (%s) mismatched entries - "+
"got %v, want %v", test.name, gotStxo, test.stxo)
@ -159,53 +138,43 @@ func TestStxoDecodeErrors(t *testing.T) {
tests := []struct {
name string
stxo spentTxOut
txVersion int32 // When the txout is not fully spent.
stxo SpentTxOut
serialized []byte
bytesRead int // Expected number of bytes read.
errType error
}{
{
name: "nothing serialized",
stxo: spentTxOut{},
stxo: SpentTxOut{},
serialized: hexToBytes(""),
errType: errDeserialize(""),
bytesRead: 0,
},
{
name: "no data after header code w/o version",
stxo: spentTxOut{},
name: "no data after header code w/o reserved",
stxo: SpentTxOut{},
serialized: hexToBytes("00"),
errType: errDeserialize(""),
bytesRead: 1,
},
{
name: "no data after header code with version",
stxo: spentTxOut{},
name: "no data after header code with reserved",
stxo: SpentTxOut{},
serialized: hexToBytes("13"),
errType: errDeserialize(""),
bytesRead: 1,
},
{
name: "no data after version",
stxo: spentTxOut{},
serialized: hexToBytes("1301"),
name: "no data after reserved",
stxo: SpentTxOut{},
serialized: hexToBytes("1300"),
errType: errDeserialize(""),
bytesRead: 2,
},
{
name: "no serialized tx version and passed -1",
stxo: spentTxOut{},
txVersion: -1,
serialized: hexToBytes("003205"),
errType: AssertError(""),
bytesRead: 1,
},
{
name: "incomplete compressed txout",
stxo: spentTxOut{},
txVersion: 1,
serialized: hexToBytes("0032"),
stxo: SpentTxOut{},
serialized: hexToBytes("1332"),
errType: errDeserialize(""),
bytesRead: 2,
},
@ -214,7 +183,7 @@ func TestStxoDecodeErrors(t *testing.T) {
for _, test := range tests {
// Ensure the expected error type is returned.
gotBytesRead, err := decodeSpentTxOut(test.serialized,
&test.stxo, test.txVersion)
&test.stxo)
if reflect.TypeOf(err) != reflect.TypeOf(test.errType) {
t.Errorf("decodeSpentTxOut (%s): expected error type "+
"does not match - got %T, want %T", test.name,
@ -239,9 +208,8 @@ func TestSpendJournalSerialization(t *testing.T) {
tests := []struct {
name string
entry []spentTxOut
entry []SpentTxOut
blockTxns []*wire.MsgTx
utxoView *UtxoViewpoint
serialized []byte
}{
// From block 2 in main blockchain.
@ -249,18 +217,16 @@ func TestSpendJournalSerialization(t *testing.T) {
name: "No spends",
entry: nil,
blockTxns: nil,
utxoView: NewUtxoViewpoint(),
serialized: nil,
},
// From block 170 in main blockchain.
{
name: "One tx with one input spends last output of coinbase",
entry: []spentTxOut{{
amount: 5000000000,
pkScript: hexToBytes("410411db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5cb2e0eaddfb84ccf9744464f82e160bfa9b8b64f9d4c03f999b8643f656b412a3ac"),
isCoinBase: true,
height: 9,
version: 1,
entry: []SpentTxOut{{
Amount: 5000000000,
PkScript: hexToBytes("410411db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5cb2e0eaddfb84ccf9744464f82e160bfa9b8b64f9d4c03f999b8643f656b412a3ac"),
IsCoinBase: true,
Height: 9,
}},
blockTxns: []*wire.MsgTx{{ // Coinbase omitted.
Version: 1,
@ -281,22 +247,21 @@ func TestSpendJournalSerialization(t *testing.T) {
}},
LockTime: 0,
}},
utxoView: NewUtxoViewpoint(),
serialized: hexToBytes("1301320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5c"),
serialized: hexToBytes("1300320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5c"),
},
// Adapted from block 100025 in main blockchain.
{
name: "Two txns when one spends last output, one doesn't",
entry: []spentTxOut{{
amount: 34405000000,
pkScript: hexToBytes("76a9146edbc6c4d31bae9f1ccc38538a114bf42de65e8688ac"),
version: 1,
entry: []SpentTxOut{{
Amount: 34405000000,
PkScript: hexToBytes("76a9146edbc6c4d31bae9f1ccc38538a114bf42de65e8688ac"),
IsCoinBase: false,
Height: 100024,
}, {
amount: 13761000000,
pkScript: hexToBytes("76a914b2fb57eadf61e106a100a7445a8c3f67898841ec88ac"),
isCoinBase: false,
height: 100024,
version: 1,
Amount: 13761000000,
PkScript: hexToBytes("76a914b2fb57eadf61e106a100a7445a8c3f67898841ec88ac"),
IsCoinBase: false,
Height: 100024,
}},
blockTxns: []*wire.MsgTx{{ // Coinbase omitted.
Version: 1,
@ -335,73 +300,7 @@ func TestSpendJournalSerialization(t *testing.T) {
}},
LockTime: 0,
}},
utxoView: &UtxoViewpoint{entries: map[chainhash.Hash]*UtxoEntry{
*newHashFromStr("c0ed017828e59ad5ed3cf70ee7c6fb0f426433047462477dc7a5d470f987a537"): {
version: 1,
isCoinBase: false,
blockHeight: 100024,
sparseOutputs: map[uint32]*utxoOutput{
1: {
amount: 34405000000,
pkScript: hexToBytes("76a9142084541c3931677527a7eafe56fd90207c344eb088ac"),
},
},
},
}},
serialized: hexToBytes("8b99700186c64700b2fb57eadf61e106a100a7445a8c3f67898841ec0091f20f006edbc6c4d31bae9f1ccc38538a114bf42de65e86"),
},
// Hand crafted.
{
name: "One tx, two inputs from same tx, neither spend last output",
entry: []spentTxOut{{
amount: 165125632,
pkScript: hexToBytes("51"),
version: 1,
}, {
amount: 154370000,
pkScript: hexToBytes("51"),
version: 1,
}},
blockTxns: []*wire.MsgTx{{ // Coinbase omitted.
Version: 1,
TxIn: []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{
Hash: *newHashFromStr("c0ed017828e59ad5ed3cf70ee7c6fb0f426433047462477dc7a5d470f987a537"),
Index: 1,
},
SignatureScript: hexToBytes(""),
Sequence: 0xffffffff,
}, {
PreviousOutPoint: wire.OutPoint{
Hash: *newHashFromStr("c0ed017828e59ad5ed3cf70ee7c6fb0f426433047462477dc7a5d470f987a537"),
Index: 2,
},
SignatureScript: hexToBytes(""),
Sequence: 0xffffffff,
}},
TxOut: []*wire.TxOut{{
Value: 165125632,
PkScript: hexToBytes("51"),
}, {
Value: 154370000,
PkScript: hexToBytes("51"),
}},
LockTime: 0,
}},
utxoView: &UtxoViewpoint{entries: map[chainhash.Hash]*UtxoEntry{
*newHashFromStr("c0ed017828e59ad5ed3cf70ee7c6fb0f426433047462477dc7a5d470f987a537"): {
version: 1,
isCoinBase: false,
blockHeight: 100000,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 165712179,
pkScript: hexToBytes("51"),
},
},
},
}},
serialized: hexToBytes("0087bc3707510084c3d19a790751"),
serialized: hexToBytes("8b99700086c64700b2fb57eadf61e106a100a7445a8c3f67898841ec8b99700091f20f006edbc6c4d31bae9f1ccc38538a114bf42de65e86"),
},
}
@ -417,16 +316,12 @@ func TestSpendJournalSerialization(t *testing.T) {
// Deserialize to a spend journal entry.
gotEntry, err := deserializeSpendJournalEntry(test.serialized,
test.blockTxns, test.utxoView)
test.blockTxns)
if err != nil {
t.Errorf("deserializeSpendJournalEntry #%d (%s) "+
"unexpected error: %v", i, test.name, err)
continue
}
for stxoIdx := range gotEntry {
stxo := &gotEntry[stxoIdx]
stxo.maybeDecompress(test.entry[stxoIdx].version)
}
// Ensure that the deserialized spend journal entry has the
// correct properties.
@ -447,7 +342,6 @@ func TestSpendJournalErrors(t *testing.T) {
tests := []struct {
name string
blockTxns []*wire.MsgTx
utxoView *UtxoViewpoint
serialized []byte
errType error
}{
@ -466,7 +360,6 @@ func TestSpendJournalErrors(t *testing.T) {
}},
LockTime: 0,
}},
utxoView: NewUtxoViewpoint(),
serialized: hexToBytes(""),
errType: AssertError(""),
},
@ -484,7 +377,6 @@ func TestSpendJournalErrors(t *testing.T) {
}},
LockTime: 0,
}},
utxoView: NewUtxoViewpoint(),
serialized: hexToBytes("1301320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a"),
errType: errDeserialize(""),
},
@ -494,7 +386,7 @@ func TestSpendJournalErrors(t *testing.T) {
// Ensure the expected error type is returned and the returned
// slice is nil.
stxos, err := deserializeSpendJournalEntry(test.serialized,
test.blockTxns, test.utxoView)
test.blockTxns)
if reflect.TypeOf(err) != reflect.TypeOf(test.errType) {
t.Errorf("deserializeSpendJournalEntry (%s): expected "+
"error type does not match - got %T, want %T",
@ -521,186 +413,52 @@ func TestUtxoSerialization(t *testing.T) {
serialized []byte
}{
// From tx in main blockchain:
// 0e3e2357e806b6cdb1f70b54c3a3a17b6714ee1f0e68bebb44a74b1efd512098
// 0e3e2357e806b6cdb1f70b54c3a3a17b6714ee1f0e68bebb44a74b1efd512098:0
{
name: "Only output 0, coinbase",
name: "height 1, coinbase",
entry: &UtxoEntry{
version: 1,
isCoinBase: true,
amount: 5000000000,
pkScript: hexToBytes("410496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52da7589379515d4e0a604f8141781e62294721166bf621e73a82cbf2342c858eeac"),
blockHeight: 1,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 5000000000,
pkScript: hexToBytes("410496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52da7589379515d4e0a604f8141781e62294721166bf621e73a82cbf2342c858eeac"),
},
},
packedFlags: tfCoinBase,
},
serialized: hexToBytes("010103320496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52"),
serialized: hexToBytes("03320496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52"),
},
// From tx in main blockchain:
// 8131ffb0a2c945ecaf9b9063e59558784f9c3a74741ce6ae2a18d0571dac15bb
// 0e3e2357e806b6cdb1f70b54c3a3a17b6714ee1f0e68bebb44a74b1efd512098:0
{
name: "Only output 1, not coinbase",
name: "height 1, coinbase, spent",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 100001,
sparseOutputs: map[uint32]*utxoOutput{
1: {
amount: 1000000,
pkScript: hexToBytes("76a914ee8bd501094a7d5ca318da2506de35e1cb025ddc88ac"),
},
},
},
serialized: hexToBytes("01858c21040700ee8bd501094a7d5ca318da2506de35e1cb025ddc"),
},
// Adapted from tx in main blockchain:
// df3f3f442d9699857f7f49de4ff0b5d0f3448bec31cdc7b5bf6d25f2abd637d5
{
name: "Only output 2, coinbase",
entry: &UtxoEntry{
version: 1,
isCoinBase: true,
blockHeight: 99004,
sparseOutputs: map[uint32]*utxoOutput{
2: {
amount: 100937281,
pkScript: hexToBytes("76a914da33f77cee27c2a975ed5124d7e4f7f97513510188ac"),
},
},
},
serialized: hexToBytes("0185843c010182b095bf4100da33f77cee27c2a975ed5124d7e4f7f975135101"),
},
// Adapted from tx in main blockchain:
// 4a16969aa4764dd7507fc1de7f0baa4850a246de90c45e59a3207f9a26b5036f
{
name: "outputs 0 and 2 not coinbase",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 113931,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 20000000,
pkScript: hexToBytes("76a914e2ccd6ec7c6e2e581349c77e067385fa8236bf8a88ac"),
},
2: {
amount: 15000000,
pkScript: hexToBytes("76a914b8025be1b3efc63b0ad48e7f9f10e87544528d5888ac"),
},
},
},
serialized: hexToBytes("0185f90b0a011200e2ccd6ec7c6e2e581349c77e067385fa8236bf8a800900b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
// Adapted from tx in main blockchain:
// 4a16969aa4764dd7507fc1de7f0baa4850a246de90c45e59a3207f9a26b5036f
{
name: "outputs 0 and 2, not coinbase, 1 marked spent",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 113931,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 20000000,
pkScript: hexToBytes("76a914e2ccd6ec7c6e2e581349c77e067385fa8236bf8a88ac"),
},
1: { // This won't be serialized.
spent: true,
amount: 1000000,
pkScript: hexToBytes("76a914e43031c3e46f20bf1ccee9553ce815de5a48467588ac"),
},
2: {
amount: 15000000,
pkScript: hexToBytes("76a914b8025be1b3efc63b0ad48e7f9f10e87544528d5888ac"),
},
},
},
serialized: hexToBytes("0185f90b0a011200e2ccd6ec7c6e2e581349c77e067385fa8236bf8a800900b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
// Adapted from tx in main blockchain:
// 4a16969aa4764dd7507fc1de7f0baa4850a246de90c45e59a3207f9a26b5036f
{
name: "outputs 0 and 2, not coinbase, output 2 compressed",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 113931,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 20000000,
pkScript: hexToBytes("76a914e2ccd6ec7c6e2e581349c77e067385fa8236bf8a88ac"),
},
2: {
// Uncompressed Amount: 15000000
// Uncompressed PkScript: 76a914b8025be1b3efc63b0ad48e7f9f10e87544528d5888ac
compressed: true,
amount: 137,
pkScript: hexToBytes("00b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
},
},
serialized: hexToBytes("0185f90b0a011200e2ccd6ec7c6e2e581349c77e067385fa8236bf8a800900b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
// Adapted from tx in main blockchain:
// 4a16969aa4764dd7507fc1de7f0baa4850a246de90c45e59a3207f9a26b5036f
{
name: "outputs 0 and 2, not coinbase, output 2 compressed, packed indexes reversed",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 113931,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 20000000,
pkScript: hexToBytes("76a914e2ccd6ec7c6e2e581349c77e067385fa8236bf8a88ac"),
},
2: {
// Uncompressed Amount: 15000000
// Uncompressed PkScript: 76a914b8025be1b3efc63b0ad48e7f9f10e87544528d5888ac
compressed: true,
amount: 137,
pkScript: hexToBytes("00b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
},
},
serialized: hexToBytes("0185f90b0a011200e2ccd6ec7c6e2e581349c77e067385fa8236bf8a800900b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
// From tx in main blockchain:
// 0e3e2357e806b6cdb1f70b54c3a3a17b6714ee1f0e68bebb44a74b1efd512098
{
name: "Only output 0, coinbase, fully spent",
entry: &UtxoEntry{
version: 1,
isCoinBase: true,
amount: 5000000000,
pkScript: hexToBytes("410496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52da7589379515d4e0a604f8141781e62294721166bf621e73a82cbf2342c858eeac"),
blockHeight: 1,
sparseOutputs: map[uint32]*utxoOutput{
0: {
spent: true,
amount: 5000000000,
pkScript: hexToBytes("410496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52da7589379515d4e0a604f8141781e62294721166bf621e73a82cbf2342c858eeac"),
},
},
packedFlags: tfCoinBase | tfSpent,
},
serialized: nil,
},
// Adapted from tx in main blockchain:
// 1b02d1c8cfef60a189017b9a420c682cf4a0028175f2f563209e4ff61c8c3620
// From tx in main blockchain:
// 8131ffb0a2c945ecaf9b9063e59558784f9c3a74741ce6ae2a18d0571dac15bb:1
{
name: "Only output 22, not coinbase",
name: "height 100001, not coinbase",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 338156,
sparseOutputs: map[uint32]*utxoOutput{
22: {
spent: false,
amount: 366875659,
pkScript: hexToBytes("a9141dd46a006572d820e448e12d2bbb38640bc718e687"),
},
},
amount: 1000000,
pkScript: hexToBytes("76a914ee8bd501094a7d5ca318da2506de35e1cb025ddc88ac"),
blockHeight: 100001,
packedFlags: 0,
},
serialized: hexToBytes("0193d06c100000108ba5b9e763011dd46a006572d820e448e12d2bbb38640bc718e6"),
serialized: hexToBytes("8b99420700ee8bd501094a7d5ca318da2506de35e1cb025ddc"),
},
// From tx in main blockchain:
// 8131ffb0a2c945ecaf9b9063e59558784f9c3a74741ce6ae2a18d0571dac15bb:1
{
name: "height 100001, not coinbase, spent",
entry: &UtxoEntry{
amount: 1000000,
pkScript: hexToBytes("76a914ee8bd501094a7d5ca318da2506de35e1cb025ddc88ac"),
blockHeight: 100001,
packedFlags: tfSpent,
},
serialized: nil,
},
}
@ -719,9 +477,9 @@ func TestUtxoSerialization(t *testing.T) {
continue
}
// Don't try to deserialize if the test entry was fully spent
// since it will have a nil serialization.
if test.entry.IsFullySpent() {
// Don't try to deserialize if the test entry was spent since it
// will have a nil serialization.
if test.entry.IsSpent() {
continue
}
@ -733,12 +491,33 @@ func TestUtxoSerialization(t *testing.T) {
continue
}
// Ensure that the deserialized utxo entry has the same
// properties for the containing transaction and block height.
if utxoEntry.Version() != test.entry.Version() {
// The deserialized entry must not be marked spent since unspent
// entries are not serialized.
if utxoEntry.IsSpent() {
t.Errorf("deserializeUtxoEntry #%d (%s) output should "+
"not be marked spent", i, test.name)
continue
}
// Ensure the deserialized entry has the same properties as the
// ones in the test entry.
if utxoEntry.Amount() != test.entry.Amount() {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"version: got %d, want %d", i, test.name,
utxoEntry.Version(), test.entry.Version())
"amounts: got %d, want %d", i, test.name,
utxoEntry.Amount(), test.entry.Amount())
continue
}
if !bytes.Equal(utxoEntry.PkScript(), test.entry.PkScript()) {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"scripts: got %x, want %x", i, test.name,
utxoEntry.PkScript(), test.entry.PkScript())
continue
}
if utxoEntry.BlockHeight() != test.entry.BlockHeight() {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"block height: got %d, want %d", i, test.name,
utxoEntry.BlockHeight(), test.entry.BlockHeight())
continue
}
if utxoEntry.IsCoinBase() != test.entry.IsCoinBase() {
@ -747,71 +526,6 @@ func TestUtxoSerialization(t *testing.T) {
utxoEntry.IsCoinBase(), test.entry.IsCoinBase())
continue
}
if utxoEntry.BlockHeight() != test.entry.BlockHeight() {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"block height: got %d, want %d", i, test.name,
utxoEntry.BlockHeight(),
test.entry.BlockHeight())
continue
}
if utxoEntry.IsFullySpent() != test.entry.IsFullySpent() {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"fully spent: got %v, want %v", i, test.name,
utxoEntry.IsFullySpent(),
test.entry.IsFullySpent())
continue
}
// Ensure all of the outputs in the test entry match the
// spentness of the output in the deserialized entry and the
// deserialized entry does not contain any additional utxos.
var numUnspent int
for outputIndex := range test.entry.sparseOutputs {
gotSpent := utxoEntry.IsOutputSpent(outputIndex)
wantSpent := test.entry.IsOutputSpent(outputIndex)
if !wantSpent {
numUnspent++
}
if gotSpent != wantSpent {
t.Errorf("deserializeUtxoEntry #%d (%s) output "+
"#%d: mismatched spent: got %v, want "+
"%v", i, test.name, outputIndex,
gotSpent, wantSpent)
continue
}
}
if len(utxoEntry.sparseOutputs) != numUnspent {
t.Errorf("deserializeUtxoEntry #%d (%s): mismatched "+
"number of unspent outputs: got %d, want %d", i,
test.name, len(utxoEntry.sparseOutputs),
numUnspent)
continue
}
// Ensure all of the amounts and scripts of the utxos in the
// deserialized entry match the ones in the test entry.
for outputIndex := range utxoEntry.sparseOutputs {
gotAmount := utxoEntry.AmountByIndex(outputIndex)
wantAmount := test.entry.AmountByIndex(outputIndex)
if gotAmount != wantAmount {
t.Errorf("deserializeUtxoEntry #%d (%s) "+
"output #%d: mismatched amounts: got "+
"%d, want %d", i, test.name,
outputIndex, gotAmount, wantAmount)
continue
}
gotPkScript := utxoEntry.PkScriptByIndex(outputIndex)
wantPkScript := test.entry.PkScriptByIndex(outputIndex)
if !bytes.Equal(gotPkScript, wantPkScript) {
t.Errorf("deserializeUtxoEntry #%d (%s) "+
"output #%d mismatched scripts: got "+
"%x, want %x", i, test.name,
outputIndex, gotPkScript, wantPkScript)
continue
}
}
}
}
@ -821,23 +535,21 @@ func TestUtxoEntryHeaderCodeErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
entry *UtxoEntry
code uint64
bytesRead int // Expected number of bytes read.
errType error
name string
entry *UtxoEntry
code uint64
errType error
}{
{
name: "Force assertion due to fully spent tx",
entry: &UtxoEntry{},
errType: AssertError(""),
bytesRead: 0,
name: "Force assertion due to spent output",
entry: &UtxoEntry{packedFlags: tfSpent},
errType: AssertError(""),
},
}
for _, test := range tests {
// Ensure the expected error type is returned and the code is 0.
code, gotBytesRead, err := utxoEntryHeaderCode(test.entry, 0)
code, err := utxoEntryHeaderCode(test.entry)
if reflect.TypeOf(err) != reflect.TypeOf(test.errType) {
t.Errorf("utxoEntryHeaderCode (%s): expected error "+
"type does not match - got %T, want %T",
@ -849,14 +561,6 @@ func TestUtxoEntryHeaderCodeErrors(t *testing.T) {
"on error - got %d, want 0", test.name, code)
continue
}
// Ensure the expected number of bytes read is returned.
if gotBytesRead != test.bytesRead {
t.Errorf("utxoEntryHeaderCode (%s): unexpected number "+
"of bytes read - got %d, want %d", test.name,
gotBytesRead, test.bytesRead)
continue
}
}
}
@ -870,29 +574,14 @@ func TestUtxoEntryDeserializeErrors(t *testing.T) {
serialized []byte
errType error
}{
{
name: "no data after version",
serialized: hexToBytes("01"),
errType: errDeserialize(""),
},
{
name: "no data after block height",
serialized: hexToBytes("0101"),
errType: errDeserialize(""),
},
{
name: "no data after header code",
serialized: hexToBytes("010102"),
errType: errDeserialize(""),
},
{
name: "not enough bytes for unspentness bitmap",
serialized: hexToBytes("01017800"),
serialized: hexToBytes("02"),
errType: errDeserialize(""),
},
{
name: "incomplete compressed txout",
serialized: hexToBytes("01010232"),
serialized: hexToBytes("0232"),
errType: errDeserialize(""),
},
}

View File

@ -27,16 +27,11 @@ func chainedNodes(parent *blockNode, numNodes int) []*blockNode {
// This is invalid, but all that is needed is enough to get the
// synthetic tests to work.
header := wire.BlockHeader{Nonce: testNoncePrng.Uint32()}
height := int32(0)
if tip != nil {
header.PrevBlock = tip.hash
height = tip.height + 1
}
node := newBlockNode(&header, height)
node.parent = tip
tip = node
nodes[i] = node
nodes[i] = newBlockNode(&header, tip)
tip = nodes[i]
}
return nodes
}
@ -74,7 +69,7 @@ func zipLocators(locators ...BlockLocator) BlockLocator {
}
// TestChainView ensures all of the exported functionality of chain views works
// as intended with the expection of some special cases which are handled in
// as intended with the exception of some special cases which are handled in
// other tests.
func TestChainView(t *testing.T) {
// Construct a synthetic block index consisting of the following

View File

@ -190,10 +190,10 @@ func chainSetup(dbName string, params *chaincfg.Params) (*BlockChain, func(), er
// loadUtxoView returns a utxo view loaded from a file.
func loadUtxoView(filename string) (*UtxoViewpoint, error) {
// The utxostore file format is:
// <tx hash><serialized utxo len><serialized utxo>
// <tx hash><output index><serialized utxo len><serialized utxo>
//
// The serialized utxo len is a little endian uint32 and the serialized
// utxo uses the format described in chainio.go.
// The output index and serialized utxo len are little endian uint32s
// and the serialized utxo uses the format described in chainio.go.
filename = filepath.Join("testdata", filename)
fi, err := os.Open(filename)
@ -223,7 +223,14 @@ func loadUtxoView(filename string) (*UtxoViewpoint, error) {
return nil, err
}
// Num of serialize utxo entry bytes.
// Output index of the utxo entry.
var index uint32
err = binary.Read(r, binary.LittleEndian, &index)
if err != nil {
return nil, err
}
// Num of serialized utxo entry bytes.
var numBytes uint32
err = binary.Read(r, binary.LittleEndian, &numBytes)
if err != nil {
@ -238,16 +245,98 @@ func loadUtxoView(filename string) (*UtxoViewpoint, error) {
}
// Deserialize it and add it to the view.
utxoEntry, err := deserializeUtxoEntry(serialized)
entry, err := deserializeUtxoEntry(serialized)
if err != nil {
return nil, err
}
view.Entries()[hash] = utxoEntry
view.Entries()[wire.OutPoint{Hash: hash, Index: index}] = entry
}
return view, nil
}
// convertUtxoStore reads a utxostore from the legacy format and writes it back
// out using the latest format. It is only useful for converting utxostore data
// used in the tests, which has already been done. However, the code is left
// available for future reference.
func convertUtxoStore(r io.Reader, w io.Writer) error {
	// The old utxostore file format was:
	//   <tx hash><serialized utxo len><serialized utxo>
	//
	// The serialized utxo len was a little endian uint32 and the serialized
	// utxo uses the format described in upgrade.go.
	for {
		// Transaction hash for this group of utxos. A clean EOF at
		// this offset means the end of the store was reached.
		var txHash chainhash.Hash
		if _, err := io.ReadFull(r, txHash[:]); err != nil {
			if err == io.EOF {
				break
			}
			return err
		}

		// Length of the serialized legacy utxo entry.
		var entryLen uint32
		if err := binary.Read(r, binary.LittleEndian, &entryLen); err != nil {
			return err
		}

		// The serialized legacy utxo entry itself.
		serializedV0 := make([]byte, entryLen)
		if _, err := io.ReadFull(r, serializedV0); err != nil {
			return err
		}

		// Deserialize the legacy entry, which produces an entry keyed
		// by output index.
		v0Entries, err := deserializeUtxoEntryV0(serializedV0)
		if err != nil {
			return err
		}

		// Write every utxo back out in the new per-output format:
		// <tx hash><output index><serialized utxo len><serialized utxo>
		for idx, entry := range v0Entries {
			// Reserialize the entry using the new format.
			reserialized, err := serializeUtxoEntry(entry)
			if err != nil {
				return err
			}

			if _, err := w.Write(txHash[:]); err != nil {
				return err
			}
			if err := binary.Write(w, binary.LittleEndian, idx); err != nil {
				return err
			}
			outLen := uint32(len(reserialized))
			if err := binary.Write(w, binary.LittleEndian, outLen); err != nil {
				return err
			}
			if _, err := w.Write(reserialized); err != nil {
				return err
			}
		}
	}

	return nil
}
// TstSetCoinbaseMaturity makes the ability to set the coinbase maturity
// available when running tests.
func (b *BlockChain) TstSetCoinbaseMaturity(maturity uint16) {
@ -261,7 +350,7 @@ func (b *BlockChain) TstSetCoinbaseMaturity(maturity uint16) {
func newFakeChain(params *chaincfg.Params) *BlockChain {
// Create a genesis block node and block index index populated with it
// for use when creating the fake chain below.
node := newBlockNode(&params.GenesisBlock.Header, 0)
node := newBlockNode(&params.GenesisBlock.Header, nil)
index := newBlockIndex(nil, params)
index.AddNode(node)
@ -291,8 +380,5 @@ func newFakeNode(parent *blockNode, blockVersion int32, bits uint32, timestamp t
Bits: bits,
Timestamp: timestamp,
}
node := newBlockNode(header, parent.height+1)
node.parent = parent
node.workSum.Add(parent.workSum, node.workSum)
return node
return newBlockNode(header, parent)
}

View File

@ -241,7 +241,7 @@ func isPubKey(script []byte) (bool, []byte) {
// compressedScriptSize returns the number of bytes the passed script would take
// when encoded with the domain specific compression algorithm described above.
func compressedScriptSize(pkScript []byte, version int32) int {
func compressedScriptSize(pkScript []byte) int {
// Pay-to-pubkey-hash script.
if valid, _ := isPubKeyHash(pkScript); valid {
return 21
@ -268,7 +268,7 @@ func compressedScriptSize(pkScript []byte, version int32) int {
// script, possibly followed by other data, and returns the number of bytes it
// occupies taking into account the special encoding of the script size by the
// domain specific compression algorithm described above.
func decodeCompressedScriptSize(serialized []byte, version int32) int {
func decodeCompressedScriptSize(serialized []byte) int {
scriptSize, bytesRead := deserializeVLQ(serialized)
if bytesRead == 0 {
return 0
@ -296,7 +296,7 @@ func decodeCompressedScriptSize(serialized []byte, version int32) int {
// target byte slice. The target byte slice must be at least large enough to
// handle the number of bytes returned by the compressedScriptSize function or
// it will panic.
func putCompressedScript(target, pkScript []byte, version int32) int {
func putCompressedScript(target, pkScript []byte) int {
// Pay-to-pubkey-hash script.
if valid, hash := isPubKeyHash(pkScript); valid {
target[0] = cstPayToPubKeyHash
@ -344,7 +344,7 @@ func putCompressedScript(target, pkScript []byte, version int32) int {
// NOTE: The script parameter must already have been proven to be long enough
// to contain the number of bytes returned by decodeCompressedScriptSize or it
// will panic. This is acceptable since it is only an internal function.
func decompressScript(compressedPkScript []byte, version int32) []byte {
func decompressScript(compressedPkScript []byte) []byte {
// In practice this function will not be called with a zero-length or
// nil script since the nil script encoding includes the length, however
// the code below assumes the length exists, so just return nil now if
@ -542,43 +542,27 @@ func decompressTxOutAmount(amount uint64) uint64 {
// -----------------------------------------------------------------------------
// compressedTxOutSize returns the number of bytes the passed transaction output
// fields would take when encoded with the format described above. The
// preCompressed flag indicates the provided amount and script are already
// compressed. This is useful since loaded utxo entries are not decompressed
// until the output is accessed.
func compressedTxOutSize(amount uint64, pkScript []byte, version int32, preCompressed bool) int {
if preCompressed {
return serializeSizeVLQ(amount) + len(pkScript)
}
// fields would take when encoded with the format described above.
func compressedTxOutSize(amount uint64, pkScript []byte) int {
return serializeSizeVLQ(compressTxOutAmount(amount)) +
compressedScriptSize(pkScript, version)
compressedScriptSize(pkScript)
}
// putCompressedTxOut potentially compresses the passed amount and script
// according to their domain specific compression algorithms and encodes them
// directly into the passed target byte slice with the format described above.
// The preCompressed flag indicates the provided amount and script are already
// compressed in which case the values are not modified. This is useful since
// loaded utxo entries are not decompressed until the output is accessed. The
// target byte slice must be at least large enough to handle the number of bytes
// returned by the compressedTxOutSize function or it will panic.
func putCompressedTxOut(target []byte, amount uint64, pkScript []byte, version int32, preCompressed bool) int {
if preCompressed {
offset := putVLQ(target, amount)
copy(target[offset:], pkScript)
return offset + len(pkScript)
}
// putCompressedTxOut compresses the passed amount and script according to their
// domain specific compression algorithms and encodes them directly into the
// passed target byte slice with the format described above. The target byte
// slice must be at least large enough to handle the number of bytes returned by
// the compressedTxOutSize function or it will panic.
func putCompressedTxOut(target []byte, amount uint64, pkScript []byte) int {
offset := putVLQ(target, compressTxOutAmount(amount))
offset += putCompressedScript(target[offset:], pkScript, version)
offset += putCompressedScript(target[offset:], pkScript)
return offset
}
// decodeCompressedTxOut decodes the passed compressed txout, possibly followed
// by other data, into its compressed amount and compressed script and returns
// them along with the number of bytes they occupied.
func decodeCompressedTxOut(serialized []byte, version int32) (uint64, []byte, int, error) {
// by other data, into its uncompressed amount and script and returns them along
// with the number of bytes they occupied prior to decompression.
func decodeCompressedTxOut(serialized []byte) (uint64, []byte, int, error) {
// Deserialize the compressed amount and ensure there are bytes
// remaining for the compressed script.
compressedAmount, bytesRead := deserializeVLQ(serialized)
@ -589,15 +573,14 @@ func decodeCompressedTxOut(serialized []byte, version int32) (uint64, []byte, in
// Decode the compressed script size and ensure there are enough bytes
// left in the slice for it.
scriptSize := decodeCompressedScriptSize(serialized[bytesRead:], version)
scriptSize := decodeCompressedScriptSize(serialized[bytesRead:])
if len(serialized[bytesRead:]) < scriptSize {
return 0, nil, bytesRead, errDeserialize("unexpected end of " +
"data after script size")
}
// Make a copy of the compressed script so the original serialized data
// can be released as soon as possible.
compressedScript := make([]byte, scriptSize)
copy(compressedScript, serialized[bytesRead:bytesRead+scriptSize])
return compressedAmount, compressedScript, bytesRead + scriptSize, nil
// Decompress and return the amount and script.
amount := decompressTxOutAmount(compressedAmount)
script := decompressScript(serialized[bytesRead : bytesRead+scriptSize])
return amount, script, bytesRead + scriptSize, nil
}

Some files were not shown because too many files have changed in this diff Show More