Update dependencies

- uses newer version of go-ethereum required for go1.11
This commit is contained in:
Rob Mulholand 2018-09-05 10:36:14 -05:00
parent 939ead0c82
commit 560305f601
2356 changed files with 331681 additions and 128329 deletions

168
Gopkg.lock generated
View File

@ -5,13 +5,19 @@
branch = "master" branch = "master"
name = "github.com/aristanetworks/goarista" name = "github.com/aristanetworks/goarista"
packages = ["monotime"] packages = ["monotime"]
revision = "8d0e8f607a4080e7df3532e645440ed0900c64a4" revision = "ff33da284e760fcdb03c33d37a719e5ed30ba844"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "github.com/btcsuite/btcd" name = "github.com/btcsuite/btcd"
packages = ["btcec"] packages = ["btcec"]
revision = "2e60448ffcc6bf78332d1fe590260095f554dd78" revision = "cff30e1d23fc9e800b2b5b4b41ef1817dda07e9f"
[[projects]]
name = "github.com/deckarep/golang-set"
packages = ["."]
revision = "1d4478f51bed434f1dadf96dcd9b43aabac66795"
version = "v1.7"
[[projects]] [[projects]]
name = "github.com/ethereum/go-ethereum" name = "github.com/ethereum/go-ethereum"
@ -25,18 +31,10 @@
"common/hexutil", "common/hexutil",
"common/math", "common/math",
"common/mclock", "common/mclock",
"consensus", "core/rawdb",
"consensus/misc",
"core",
"core/state",
"core/types", "core/types",
"core/vm",
"crypto", "crypto",
"crypto/bn256",
"crypto/bn256/cloudflare",
"crypto/bn256/google",
"crypto/ecies", "crypto/ecies",
"crypto/randentropy",
"crypto/secp256k1", "crypto/secp256k1",
"crypto/sha3", "crypto/sha3",
"ethclient", "ethclient",
@ -54,8 +52,8 @@
"rpc", "rpc",
"trie" "trie"
] ]
revision = "b8b9f7f4476a30a0aaf6077daade6ae77f969960" revision = "89451f7c382ad2185987ee369f16416f89c28a7d"
version = "v1.8.2" version = "v1.8.15"
[[projects]] [[projects]]
name = "github.com/fsnotify/fsnotify" name = "github.com/fsnotify/fsnotify"
@ -66,37 +64,34 @@
[[projects]] [[projects]]
name = "github.com/go-stack/stack" name = "github.com/go-stack/stack"
packages = ["."] packages = ["."]
revision = "259ab82a6cad3992b4e21ff5cac294ccb06474bc" revision = "2fee6af1a9795aafbe0253a0cfbdf668e1fb8a9a"
version = "v1.7.0" version = "v1.8.0"
[[projects]] [[projects]]
branch = "master"
name = "github.com/golang/protobuf" name = "github.com/golang/protobuf"
packages = ["proto"] packages = ["proto"]
revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5"
version = "v1.2.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "github.com/golang/snappy" name = "github.com/golang/snappy"
packages = ["."] packages = ["."]
revision = "553a641470496b2327abcac10b36396bd98e45c9" revision = "2e65f85255dbc3072edf28d6b5b8efc472979f5a"
[[projects]] [[projects]]
branch = "master" name = "github.com/google/uuid"
name = "github.com/hashicorp/golang-lru" packages = ["."]
packages = [ revision = "d460ce9f8df2e77fb1ba55ca87fafed96c607494"
".", version = "v1.0.0"
"simplelru"
]
revision = "0fb14efe8c47ae851c0034ed7a448854d3d34cf3"
[[projects]] [[projects]]
branch = "master"
name = "github.com/hashicorp/hcl" name = "github.com/hashicorp/hcl"
packages = [ packages = [
".", ".",
"hcl/ast", "hcl/ast",
"hcl/parser", "hcl/parser",
"hcl/printer",
"hcl/scanner", "hcl/scanner",
"hcl/strconv", "hcl/strconv",
"hcl/token", "hcl/token",
@ -104,7 +99,20 @@
"json/scanner", "json/scanner",
"json/token" "json/token"
] ]
revision = "23c074d0eceb2b8a5bfdbb271ab780cde70f05a8" revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241"
version = "v1.0.0"
[[projects]]
name = "github.com/hpcloud/tail"
packages = [
".",
"ratelimiter",
"util",
"watch",
"winfile"
]
revision = "a30252cb686a21eb2d0b98132633053ec2f7f1e5"
version = "v1.0.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -118,7 +126,7 @@
"soap", "soap",
"ssdp" "ssdp"
] ]
revision = "dceda08e705b2acee36aab47d765ed801f64cfc7" revision = "1395d1447324cbea88d249fbfcfd70ea878fdfca"
[[projects]] [[projects]]
name = "github.com/inconshreveable/mousetrap" name = "github.com/inconshreveable/mousetrap"
@ -139,34 +147,34 @@
".", ".",
"reflectx" "reflectx"
] ]
revision = "99f3ad6d85ae53d0fecf788ab62d0e9734b3c117" revision = "0dae4fefe7c0e190f7b5a78dac28a1c82cc8d849"
[[projects]] [[projects]]
branch = "master"
name = "github.com/lib/pq" name = "github.com/lib/pq"
packages = [ packages = [
".", ".",
"oid" "oid"
] ]
revision = "83612a56d3dd153a94a629cd64925371c9adad78" revision = "4ded0e9383f75c197b3a2aaa6d590ac52df6fd79"
version = "v1.0.0"
[[projects]] [[projects]]
name = "github.com/magiconair/properties" name = "github.com/magiconair/properties"
packages = ["."] packages = ["."]
revision = "d419a98cdbed11a922bf76f257b7c4be79b50e73" revision = "c2353362d570a7bfa228149c62842019201cfb71"
version = "v1.7.4" version = "v1.8.0"
[[projects]] [[projects]]
branch = "master"
name = "github.com/mitchellh/go-homedir" name = "github.com/mitchellh/go-homedir"
packages = ["."] packages = ["."]
revision = "b8bc1bf767474819792c23f32d8286a45736f1c6" revision = "ae18d6b8b3205b561c79e8e5f69bff09736185f4"
version = "v1.0.0"
[[projects]] [[projects]]
branch = "master"
name = "github.com/mitchellh/mapstructure" name = "github.com/mitchellh/mapstructure"
packages = ["."] packages = ["."]
revision = "b4575eea38cca1123ec2dc90c26529b5c5acfcff" revision = "fa473d140ef3c6adf42d6b391fe76707f1f243c8"
version = "v1.0.0"
[[projects]] [[projects]]
name = "github.com/onsi/ginkgo" name = "github.com/onsi/ginkgo"
@ -190,8 +198,8 @@
"reporters/stenographer/support/go-isatty", "reporters/stenographer/support/go-isatty",
"types" "types"
] ]
revision = "9eda700730cba42af70d53180f9dcce9266bc2bc" revision = "3774a09d95489ccaa16032e0770d08ea77ba6184"
version = "v1.4.0" version = "v1.6.0"
[[projects]] [[projects]]
name = "github.com/onsi/gomega" name = "github.com/onsi/gomega"
@ -210,20 +218,20 @@
"matchers/support/goraph/util", "matchers/support/goraph/util",
"types" "types"
] ]
revision = "c893efa28eb45626cdaa76c9f653b62488858837" revision = "7615b9433f86a8bdf29709bf288bc4fd0636a369"
version = "v1.2.0" version = "v1.4.2"
[[projects]] [[projects]]
name = "github.com/pborman/uuid" name = "github.com/pborman/uuid"
packages = ["."] packages = ["."]
revision = "e790cca94e6cc75c7064b1332e63811d4aae1a53" revision = "adf5a7427709b9deb95d29d3fa8a2bf9cfd388f1"
version = "v1.1" version = "v1.2"
[[projects]] [[projects]]
name = "github.com/pelletier/go-toml" name = "github.com/pelletier/go-toml"
packages = ["."] packages = ["."]
revision = "acdc4509485b587f5e675510c4f2c63e90ff68a8" revision = "c01d1270ff3e442a8a57cddc1c92dc1138598194"
version = "v1.1.0" version = "v1.2.0"
[[projects]] [[projects]]
name = "github.com/philhofer/fwd" name = "github.com/philhofer/fwd"
@ -234,14 +242,14 @@
[[projects]] [[projects]]
name = "github.com/rjeczalik/notify" name = "github.com/rjeczalik/notify"
packages = ["."] packages = ["."]
revision = "52ae50d8490436622a8941bd70c3dbe0acdd4bbf" revision = "0f065fa99b48b842c3fd3e2c8b194c6f2b69f6b8"
version = "v0.9.0" version = "v0.9.1"
[[projects]] [[projects]]
name = "github.com/rs/cors" name = "github.com/rs/cors"
packages = ["."] packages = ["."]
revision = "7af7a1e09ba336d2ea14b1ce73bf693c6837dbf6" revision = "3fb1b69b103a84de38a19c3c6ec073dd6caa4d3f"
version = "v1.2" version = "v1.5.0"
[[projects]] [[projects]]
name = "github.com/spf13/afero" name = "github.com/spf13/afero"
@ -249,38 +257,38 @@
".", ".",
"mem" "mem"
] ]
revision = "bb8f1927f2a9d3ab41c9340aa034f6b803f4359c" revision = "d40851caa0d747393da1ffb28f7f9d8b4eeffebd"
version = "v1.0.2" version = "v1.1.2"
[[projects]] [[projects]]
name = "github.com/spf13/cast" name = "github.com/spf13/cast"
packages = ["."] packages = ["."]
revision = "acbeb36b902d72a7a4c18e8f3241075e7ab763e4" revision = "8965335b8c7107321228e3e3702cab9832751bac"
version = "v1.1.0" version = "v1.2.0"
[[projects]] [[projects]]
name = "github.com/spf13/cobra" name = "github.com/spf13/cobra"
packages = ["."] packages = ["."]
revision = "7b2c5ac9fc04fc5efafb60700713d4fa609b777b" revision = "ef82de70bb3f60c65fb8eebacbb2d122ef517385"
version = "v0.0.1" version = "v0.0.3"
[[projects]] [[projects]]
branch = "master"
name = "github.com/spf13/jwalterweatherman" name = "github.com/spf13/jwalterweatherman"
packages = ["."] packages = ["."]
revision = "7c0cea34c8ece3fbeb2b27ab9b59511d360fb394" revision = "4a4406e478ca629068e7768fc33f3f044173c0a6"
version = "v1.0.0"
[[projects]] [[projects]]
name = "github.com/spf13/pflag" name = "github.com/spf13/pflag"
packages = ["."] packages = ["."]
revision = "e57e3eeb33f795204c1ca35f56c44f83227c6e66" revision = "9a97c102cda95a86cec2345a6f09f55a939babf5"
version = "v1.0.0" version = "v1.0.2"
[[projects]] [[projects]]
name = "github.com/spf13/viper" name = "github.com/spf13/viper"
packages = ["."] packages = ["."]
revision = "25b30aa063fc18e48662b86996252eabdcf2f0c7" revision = "8fb642006536c8d3760c99d4fa2389f5e2205631"
version = "v1.0.0" version = "v1.2.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -299,7 +307,7 @@
"leveldb/table", "leveldb/table",
"leveldb/util" "leveldb/util"
] ]
revision = "adf24ef3f94bd13ec4163060b21a5678f22b429b" revision = "ae2bd5eed72d46b28834ec3f60db3a3ebedd8dbd"
[[projects]] [[projects]]
name = "github.com/tinylib/msgp" name = "github.com/tinylib/msgp"
@ -312,10 +320,9 @@
name = "golang.org/x/crypto" name = "golang.org/x/crypto"
packages = [ packages = [
"pbkdf2", "pbkdf2",
"ripemd160",
"scrypt" "scrypt"
] ]
revision = "613d6eafa307c6881a737a3c35c0e312e8d3a8c5" revision = "0e37d006457bf46f9e6692014ba72ef82c33022c"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -327,7 +334,7 @@
"html/charset", "html/charset",
"websocket" "websocket"
] ]
revision = "faacc1b5e36e3ff02cbec9661c69ac63dd5a83ad" revision = "26e67e76b6c3f6ce91f7c52def5af501b4e0f3a2"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -342,10 +349,9 @@
"unix", "unix",
"windows" "windows"
] ]
revision = "a0f4589a76f1f83070cb9e5613809e1d07b97c13" revision = "d0be0721c37eeb5299f245a996a483160fc36940"
[[projects]] [[projects]]
branch = "master"
name = "golang.org/x/text" name = "golang.org/x/text"
packages = [ packages = [
"encoding", "encoding",
@ -369,7 +375,8 @@
"unicode/cldr", "unicode/cldr",
"unicode/norm" "unicode/norm"
] ]
revision = "be25de41fadfae372d6470bda81ca6beb55ef551" revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0"
version = "v0.3.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -379,7 +386,7 @@
"imports", "imports",
"internal/fastwalk" "internal/fastwalk"
] ]
revision = "8cc4e8a6f4841aa92a8683fca47bc5d64b58875b" revision = "677d2ff680c188ddb7dcd2bfa6bc7d3f2f2f75b2"
[[projects]] [[projects]]
name = "gopkg.in/DataDog/dd-trace-go.v1" name = "gopkg.in/DataDog/dd-trace-go.v1"
@ -389,14 +396,15 @@
"ddtrace/internal", "ddtrace/internal",
"ddtrace/tracer" "ddtrace/tracer"
] ]
revision = "8efc9a798f2db99a9e00c7e57f45fc13611214e0" revision = "bcd20367df871708a36549e7fe36183ee5b4fc55"
version = "v1.2.3" version = "v1.3.0"
[[projects]] [[projects]]
name = "gopkg.in/fatih/set.v0" name = "gopkg.in/fsnotify.v1"
packages = ["."] packages = ["."]
revision = "57907de300222151a123d29255ed17f5ed43fad3" revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9"
version = "v0.1.0" source = "gopkg.in/fsnotify/fsnotify.v1"
version = "v1.4.7"
[[projects]] [[projects]]
branch = "v2" branch = "v2"
@ -411,14 +419,20 @@
revision = "c1b8fa8bdccecb0b8db834ee0b92fdbcfa606dd6" revision = "c1b8fa8bdccecb0b8db834ee0b92fdbcfa606dd6"
[[projects]] [[projects]]
branch = "v2" branch = "v1"
name = "gopkg.in/tomb.v1"
packages = ["."]
revision = "dd632973f1e7218eb1089048e0798ec9ae7dceb8"
[[projects]]
name = "gopkg.in/yaml.v2" name = "gopkg.in/yaml.v2"
packages = ["."] packages = ["."]
revision = "287cf08546ab5e7e37d55a84f7ed3fd1db036de5" revision = "5420a8b6744d3b0345ab293f6fcba19c978f1183"
version = "v2.2.1"
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
inputs-digest = "7a913c984013e026536456baa75bd95e261bbb0d294b7de77785819ac182b465" inputs-digest = "d1f001d06a295f55fcb3dd1f38744ca0d088b9c4bf88f251ef3e043c8852fe76"
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

View File

@ -21,6 +21,10 @@
# version = "2.4.0" # version = "2.4.0"
[[override]]
name = "gopkg.in/fsnotify.v1"
source = "gopkg.in/fsnotify/fsnotify.v1"
[[constraint]] [[constraint]]
name = "github.com/onsi/ginkgo" name = "github.com/onsi/ginkgo"
version = "1.4.0" version = "1.4.0"
@ -39,4 +43,4 @@
[[constraint]] [[constraint]]
name = "github.com/ethereum/go-ethereum" name = "github.com/ethereum/go-ethereum"
version = "1.8" version = "1.8.15"

View File

@ -36,5 +36,5 @@ func (l LevelDatabase) GetBlockReceipts(blockHash []byte, blockNumber int64) typ
func (l LevelDatabase) GetHeadBlockNumber() int64 { func (l LevelDatabase) GetHeadBlockNumber() int64 {
h := l.reader.GetHeadBlockHash() h := l.reader.GetHeadBlockHash()
n := l.reader.GetBlockNumber(h) n := l.reader.GetBlockNumber(h)
return int64(n) return int64(*n)
} }

View File

@ -2,42 +2,42 @@ package level
import ( import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
) )
type Reader interface { type Reader interface {
GetBlock(hash common.Hash, number uint64) *types.Block GetBlock(hash common.Hash, number uint64) *types.Block
GetBlockNumber(hash common.Hash) uint64 GetBlockNumber(hash common.Hash) *uint64
GetBlockReceipts(hash common.Hash, number uint64) types.Receipts GetBlockReceipts(hash common.Hash, number uint64) types.Receipts
GetCanonicalHash(number uint64) common.Hash GetCanonicalHash(number uint64) common.Hash
GetHeadBlockHash() common.Hash GetHeadBlockHash() common.Hash
} }
type LevelDatabaseReader struct { type LevelDatabaseReader struct {
reader core.DatabaseReader reader rawdb.DatabaseReader
} }
func NewLevelDatabaseReader(reader core.DatabaseReader) *LevelDatabaseReader { func NewLevelDatabaseReader(reader rawdb.DatabaseReader) *LevelDatabaseReader {
return &LevelDatabaseReader{reader: reader} return &LevelDatabaseReader{reader: reader}
} }
func (ldbr *LevelDatabaseReader) GetBlock(hash common.Hash, number uint64) *types.Block { func (ldbr *LevelDatabaseReader) GetBlock(hash common.Hash, number uint64) *types.Block {
return core.GetBlock(ldbr.reader, hash, number) return rawdb.ReadBlock(ldbr.reader, hash, number)
} }
func (ldbr *LevelDatabaseReader) GetBlockNumber(hash common.Hash) uint64 { func (ldbr *LevelDatabaseReader) GetBlockNumber(hash common.Hash) *uint64 {
return core.GetBlockNumber(ldbr.reader, hash) return rawdb.ReadHeaderNumber(ldbr.reader, hash)
} }
func (ldbr *LevelDatabaseReader) GetBlockReceipts(hash common.Hash, number uint64) types.Receipts { func (ldbr *LevelDatabaseReader) GetBlockReceipts(hash common.Hash, number uint64) types.Receipts {
return core.GetBlockReceipts(ldbr.reader, hash, number) return rawdb.ReadReceipts(ldbr.reader, hash, number)
} }
func (ldbr *LevelDatabaseReader) GetCanonicalHash(number uint64) common.Hash { func (ldbr *LevelDatabaseReader) GetCanonicalHash(number uint64) common.Hash {
return core.GetCanonicalHash(ldbr.reader, number) return rawdb.ReadCanonicalHash(ldbr.reader, number)
} }
func (ldbr *LevelDatabaseReader) GetHeadBlockHash() common.Hash { func (ldbr *LevelDatabaseReader) GetHeadBlockHash() common.Hash {
return core.GetHeadBlockHash(ldbr.reader) return rawdb.ReadHeadBlockHash(ldbr.reader)
} }

View File

@ -82,10 +82,10 @@ func (mldr *MockLevelDatabaseReader) GetBlockReceipts(hash common.Hash, number u
return mldr.returnReceipts return mldr.returnReceipts
} }
func (mldr *MockLevelDatabaseReader) GetBlockNumber(hash common.Hash) uint64 { func (mldr *MockLevelDatabaseReader) GetBlockNumber(hash common.Hash) *uint64 {
mldr.getBlockNumberCalled = true mldr.getBlockNumberCalled = true
mldr.getBlockNumberPassedHash = hash mldr.getBlockNumberPassedHash = hash
return mldr.returnBlockNumber return &mldr.returnBlockNumber
} }
func (mldr *MockLevelDatabaseReader) GetCanonicalHash(number uint64) common.Hash { func (mldr *MockLevelDatabaseReader) GetCanonicalHash(number uint64) common.Hash {

View File

@ -218,7 +218,7 @@ var _ = Describe("Conversion of GethBlock to core.Block", func() {
CumulativeGasUsed: uint64(7996119), CumulativeGasUsed: uint64(7996119),
GasUsed: uint64(21000), GasUsed: uint64(21000),
Logs: []*types.Log{}, Logs: []*types.Log{},
Status: uint(1), Status: uint64(1),
TxHash: gethTransaction.Hash(), TxHash: gethTransaction.Hash(),
} }

View File

@ -54,7 +54,7 @@ var _ = Describe("Conversion of GethReceipt to core.Receipt", func() {
CumulativeGasUsed: uint64(7996119), CumulativeGasUsed: uint64(7996119),
GasUsed: uint64(21000), GasUsed: uint64(21000),
Logs: []*types.Log{}, Logs: []*types.Log{},
Status: uint(1), Status: uint64(1),
TxHash: common.HexToHash("0xe340558980f89d5f86045ac11e5cc34e4bcec20f9f1e2a427aa39d87114e8223"), TxHash: common.HexToHash("0xe340558980f89d5f86045ac11e5cc34e4bcec20f9f1e2a427aa39d87114e8223"),
} }

View File

@ -37,7 +37,7 @@ var (
) )
var DentLog = types.Log{ var DentLog = types.Log{
Address: common.StringToAddress(shared.FlipperContractAddress), Address: common.HexToAddress(shared.FlipperContractAddress),
Topics: []common.Hash{ Topics: []common.Hash{
common.HexToHash("0x5ff3a38200000000000000000000000000000000000000000000000000000000"), common.HexToHash("0x5ff3a38200000000000000000000000000000000000000000000000000000000"),
common.HexToHash("0x00000000000000000000000064d922894153be9eef7b7218dc565d1d0ce2a092"), common.HexToHash("0x00000000000000000000000064d922894153be9eef7b7218dc565d1d0ce2a092"),

View File

@ -38,7 +38,7 @@ var (
) )
var TendLogNote = types.Log{ var TendLogNote = types.Log{
Address: common.StringToAddress(shared.FlipperContractAddress), Address: common.HexToAddress(shared.FlipperContractAddress),
Topics: []common.Hash{ Topics: []common.Hash{
common.HexToHash("0x4b43ed1200000000000000000000000000000000000000000000000000000000"), //abbreviated tend function signature common.HexToHash("0x4b43ed1200000000000000000000000000000000000000000000000000000000"), //abbreviated tend function signature
common.HexToHash("0x0000000000000000000000007d7bee5fcfd8028cf7b00876c5b1421c800561a6"), //msg caller address common.HexToHash("0x0000000000000000000000007d7bee5fcfd8028cf7b00876c5b1421c800561a6"), //msg caller address

View File

@ -1,7 +1,8 @@
language: go language: go
go: go:
- 1.9 - 1.10.x
- tip - 1.x
- master
before_install: before_install:
- go get -v github.com/golang/lint/golint - go get -v github.com/golang/lint/golint
- go get -v -t -d ./... - go get -v -t -d ./...

View File

@ -3,7 +3,7 @@
# that can be found in the COPYING file. # that can be found in the COPYING file.
# TODO: move this to cmd/ockafka (https://github.com/docker/hub-feedback/issues/292) # TODO: move this to cmd/ockafka (https://github.com/docker/hub-feedback/issues/292)
FROM golang:1.7.3 FROM golang:1.10.3
RUN mkdir -p /go/src/github.com/aristanetworks/goarista/cmd RUN mkdir -p /go/src/github.com/aristanetworks/goarista/cmd
WORKDIR /go/src/github.com/aristanetworks/goarista WORKDIR /go/src/github.com/aristanetworks/goarista

View File

@ -25,19 +25,20 @@ class of service flags to use for incoming connections. Requires `go1.9`.
## key ## key
Provides a common type used across various Arista projects, named `key.Key`, Provides common types used across various Arista projects. The type `key.Key`
which is used to work around the fact that Go can't let one is used to work around the fact that Go can't let one use a non-hashable type
use a non-hashable type as a key to a `map`, and we sometimes need to use as a key to a `map`, and we sometimes need to use a `map[string]interface{}`
a `map[string]interface{}` (or something containing one) as a key to maps. (or something containing one) as a key to maps. As a result, we frequently use
As a result, we frequently use `map[key.Key]interface{}` instead of just `map[key.Key]interface{}` instead of just `map[interface{}]interface{}` when we
`map[interface{}]interface{}` when we need a generic key-value collection. need a generic key-value collection. The type `key.Path` is the representation
of a path broken down into individual elements, where each element is a `key.Key`.
The type `key.Pointer` represents a pointer to a `key.Path`.
## path ## path
Provides a common type used across various Arista projects, named `path.Path`, Provides functions that can be used to manipulate `key.Path` objects. The type
which is the representation of a path broken down into individual elements. `path.Map` may be used for mapping paths to values. It allows for some fuzzy
Each element is a `key.Key`. The type `path.Map` may be used for mapping paths matching for paths containing `path.Wildcard` keys.
to values. It allows for some fuzzy matching.
## lanz ## lanz
A client for [LANZ](https://eos.arista.com/latency-analyzer-lanz-architectures-and-configuration/) A client for [LANZ](https://eos.arista.com/latency-analyzer-lanz-architectures-and-configuration/)
@ -55,9 +56,10 @@ A library to help expose monitoring metrics on top of the
`netns.Do(namespace, cb)` provides a handy mechanism to execute the given `netns.Do(namespace, cb)` provides a handy mechanism to execute the given
callback `cb` in the given [network namespace](https://lwn.net/Articles/580893/). callback `cb` in the given [network namespace](https://lwn.net/Articles/580893/).
## pathmap ## influxlib
DEPRECATED; use`path.Map` instead. This is a influxdb library that provides easy methods of connecting to, writing to,
and reading from the service.
## test ## test

View File

@ -16,10 +16,20 @@ under [GOPATH](https://golang.org/doc/code.html#GOPATH).
# Usage # Usage
```
$ gnmi [OPTIONS] [OPERATION]
```
When running on the switch in a non-default VRF:
```
$ ip netns exec ns-<VRF> gnmi [OPTIONS] [OPERATION]
```
## Options ## Options
* `-addr ADDR:PORT` * `-addr [<VRF-NAME>/]ADDR:PORT`
Address of the gNMI endpoint (REQUIRED) Address of the gNMI endpoint (REQUIRED) with VRF name (OPTIONAL)
* `-username USERNAME` * `-username USERNAME`
Username to authenticate with Username to authenticate with
* `-password PASSWORD` * `-password PASSWORD`
@ -92,7 +102,7 @@ $ gnmi [OPTIONS] delete '/network-instances/network-instance[name=default]/proto
``` ```
`update` and `replace` both take a path and a value in JSON `update` and `replace` both take a path and a value in JSON
format. See format. The JSON data may be provided in a file. See
[here](https://github.com/openconfig/reference/blob/master/rpc/gnmi/gnmi-specification.md#344-modes-of-update-replace-versus-update) [here](https://github.com/openconfig/reference/blob/master/rpc/gnmi/gnmi-specification.md#344-modes-of-update-replace-versus-update)
for documentation on the differences between `update` and `replace`. for documentation on the differences between `update` and `replace`.
@ -108,15 +118,46 @@ Replace the BGP global configuration:
gnmi [OPTIONS] replace '/network-instances/network-instance[name=default]/protocols/protocol[name=BGP][identifier=BGP]/bgp/global' '{"config":{"as": 1234, "router-id": "1.2.3.4"}}' gnmi [OPTIONS] replace '/network-instances/network-instance[name=default]/protocols/protocol[name=BGP][identifier=BGP]/bgp/global' '{"config":{"as": 1234, "router-id": "1.2.3.4"}}'
``` ```
Note: String values must be quoted. For example, setting the hostname to `"tor13"`: Note: String values need to be quoted if they look like JSON. For example, setting the login banner to `tor[13]`:
``` ```
gnmi [OPTIONS] update '/system/config/hostname' '"tor13"' gnmi [OPTIONS] update '/system/config/login-banner '"tor[13]"'
```
#### JSON in a file
The value argument to `update` and `replace` may be a file. The
content of the file is used to make the request.
Example:
File `path/to/subintf100.json` contains the following:
```
{
"subinterface": [
{
"config": {
"enabled": true,
"index": 100
},
"index": 100
}
]
}
```
Add subinterface 100 to interfaces Ethernet4/1/1 and Ethernet4/2/1 in
one transaction:
```
gnmi [OPTIONS] update '/interfaces/interface[name=Ethernet4/1/1]/subinterfaces' path/to/subintf100.json \
update '/interfaces/interface[name=Ethernet4/2/1]/subinterfaces' path/to/subintf100.json
``` ```
### CLI requests ### CLI requests
`gnmi` offers the ability to send CLI text inside an `update` or `gnmi` offers the ability to send CLI text inside an `update` or
`replace` operation. This is achieved by doing an `update` or `replace` operation. This is achieved by doing an `update` or
`replace` and using `"cli"` as the path and a set of configure-mode `replace` and specifying `"origin=cli"` along with an empty path and a set of configure-mode
CLI commands separated by `\n`. CLI commands separated by `\n`.
Example: Example:
@ -127,6 +168,18 @@ gnmi [OPTIONS] update 'cli' 'management ssh
idle-timeout 300' idle-timeout 300'
``` ```
### P4 Config
`gnmi` offers the ability to send p4 config files inside a `replace` operation.
This is achieved by doing a `replace` and specifying `"origin=p4_config"`
along with the path of the p4 config file to send.
Example:
Send the config.p4 file
```
gnmi [OPTIONS] replace 'origin=p4_config' 'config.p4'
```
## Paths ## Paths
Paths in `gnmi` use a simplified xpath style. Path elements are Paths in `gnmi` use a simplified xpath style. Path elements are

View File

@ -9,6 +9,8 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"strings"
"time"
"github.com/aristanetworks/goarista/gnmi" "github.com/aristanetworks/goarista/gnmi"
@ -18,22 +20,24 @@ import (
// TODO: Make this more clear // TODO: Make this more clear
var help = `Usage of gnmi: var help = `Usage of gnmi:
gnmi -addr ADDRESS:PORT [options...] gnmi -addr [<VRF-NAME>/]ADDRESS:PORT [options...]
capabilities capabilities
get PATH+ get PATH+
subscribe PATH+ subscribe PATH+
((update|replace PATH JSON)|(delete PATH))+ ((update|replace (origin=ORIGIN) PATH JSON|FILE)|(delete (origin=ORIGIN) PATH))+
` `
func exitWithError(s string) { func usageAndExit(s string) {
flag.Usage() flag.Usage()
if s != "" {
fmt.Fprintln(os.Stderr, s) fmt.Fprintln(os.Stderr, s)
}
os.Exit(1) os.Exit(1)
} }
func main() { func main() {
cfg := &gnmi.Config{} cfg := &gnmi.Config{}
flag.StringVar(&cfg.Addr, "addr", "", "Address of gNMI gRPC server") flag.StringVar(&cfg.Addr, "addr", "", "Address of gNMI gRPC server with optional VRF name")
flag.StringVar(&cfg.CAFile, "cafile", "", "Path to server TLS certificate file") flag.StringVar(&cfg.CAFile, "cafile", "", "Path to server TLS certificate file")
flag.StringVar(&cfg.CertFile, "certfile", "", "Path to client TLS certificate file") flag.StringVar(&cfg.CertFile, "certfile", "", "Path to client TLS certificate file")
flag.StringVar(&cfg.KeyFile, "keyfile", "", "Path to client TLS private key file") flag.StringVar(&cfg.KeyFile, "keyfile", "", "Path to client TLS private key file")
@ -41,26 +45,54 @@ func main() {
flag.StringVar(&cfg.Username, "username", "", "Username to authenticate with") flag.StringVar(&cfg.Username, "username", "", "Username to authenticate with")
flag.BoolVar(&cfg.TLS, "tls", false, "Enable TLS") flag.BoolVar(&cfg.TLS, "tls", false, "Enable TLS")
subscribeOptions := &gnmi.SubscribeOptions{}
flag.StringVar(&subscribeOptions.Prefix, "prefix", "", "Subscribe prefix path")
flag.BoolVar(&subscribeOptions.UpdatesOnly, "updates_only", false,
"Subscribe to updates only (false | true)")
flag.StringVar(&subscribeOptions.Mode, "mode", "stream",
"Subscribe mode (stream | once | poll)")
flag.StringVar(&subscribeOptions.StreamMode, "stream_mode", "target_defined",
"Subscribe stream mode, only applies for stream subscriptions "+
"(target_defined | on_change | sample)")
sampleIntervalStr := flag.String("sample_interval", "0", "Subscribe sample interval, "+
"only applies for sample subscriptions (400ms, 2.5s, 1m, etc.)")
heartbeatIntervalStr := flag.String("heartbeat_interval", "0", "Subscribe heartbeat "+
"interval, only applies for on-change subscriptions (400ms, 2.5s, 1m, etc.)")
flag.Usage = func() { flag.Usage = func() {
fmt.Fprintln(os.Stderr, help) fmt.Fprintln(os.Stderr, help)
flag.PrintDefaults() flag.PrintDefaults()
} }
flag.Parse() flag.Parse()
if cfg.Addr == "" { if cfg.Addr == "" {
exitWithError("error: address not specified") usageAndExit("error: address not specified")
} }
var sampleInterval, heartbeatInterval time.Duration
var err error
if sampleInterval, err = time.ParseDuration(*sampleIntervalStr); err != nil {
usageAndExit(fmt.Sprintf("error: sample interval (%s) invalid", *sampleIntervalStr))
}
subscribeOptions.SampleInterval = uint64(sampleInterval)
if heartbeatInterval, err = time.ParseDuration(*heartbeatIntervalStr); err != nil {
usageAndExit(fmt.Sprintf("error: heartbeat interval (%s) invalid", *heartbeatIntervalStr))
}
subscribeOptions.HeartbeatInterval = uint64(heartbeatInterval)
args := flag.Args() args := flag.Args()
ctx := gnmi.NewContext(context.Background(), cfg) ctx := gnmi.NewContext(context.Background(), cfg)
client := gnmi.Dial(cfg) client, err := gnmi.Dial(cfg)
if err != nil {
glog.Fatal(err)
}
var setOps []*gnmi.Operation var setOps []*gnmi.Operation
for i := 0; i < len(args); i++ { for i := 0; i < len(args); i++ {
switch args[i] { switch args[i] {
case "capabilities": case "capabilities":
if len(setOps) != 0 { if len(setOps) != 0 {
exitWithError("error: 'capabilities' not allowed after 'merge|replace|delete'") usageAndExit("error: 'capabilities' not allowed after 'merge|replace|delete'")
} }
err := gnmi.Capabilities(ctx, client) err := gnmi.Capabilities(ctx, client)
if err != nil { if err != nil {
@ -69,7 +101,7 @@ func main() {
return return
case "get": case "get":
if len(setOps) != 0 { if len(setOps) != 0 {
exitWithError("error: 'get' not allowed after 'merge|replace|delete'") usageAndExit("error: 'get' not allowed after 'merge|replace|delete'")
} }
err := gnmi.Get(ctx, client, gnmi.SplitPaths(args[i+1:])) err := gnmi.Get(ctx, client, gnmi.SplitPaths(args[i+1:]))
if err != nil { if err != nil {
@ -78,49 +110,55 @@ func main() {
return return
case "subscribe": case "subscribe":
if len(setOps) != 0 { if len(setOps) != 0 {
exitWithError("error: 'subscribe' not allowed after 'merge|replace|delete'") usageAndExit("error: 'subscribe' not allowed after 'merge|replace|delete'")
} }
respChan := make(chan *pb.SubscribeResponse) respChan := make(chan *pb.SubscribeResponse)
errChan := make(chan error) errChan := make(chan error)
defer close(respChan)
defer close(errChan) defer close(errChan)
go gnmi.Subscribe(ctx, client, gnmi.SplitPaths(args[i+1:]), respChan, errChan) subscribeOptions.Paths = gnmi.SplitPaths(args[i+1:])
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
for { for {
select { select {
case resp := <-respChan: case resp, open := <-respChan:
if !open {
return
}
if err := gnmi.LogSubscribeResponse(resp); err != nil { if err := gnmi.LogSubscribeResponse(resp); err != nil {
exitWithError(err.Error()) glog.Fatal(err)
} }
case err := <-errChan: case err := <-errChan:
exitWithError(err.Error()) glog.Fatal(err)
} }
} }
case "update", "replace", "delete": case "update", "replace", "delete":
if len(args) == i+1 { if len(args) == i+1 {
exitWithError("error: missing path") usageAndExit("error: missing path")
} }
op := &gnmi.Operation{ op := &gnmi.Operation{
Type: args[i], Type: args[i],
} }
i++ i++
if strings.HasPrefix(args[i], "origin=") {
op.Origin = strings.TrimPrefix(args[i], "origin=")
i++
}
op.Path = gnmi.SplitPath(args[i]) op.Path = gnmi.SplitPath(args[i])
if op.Type != "delete" { if op.Type != "delete" {
if len(args) == i+1 { if len(args) == i+1 {
exitWithError("error: missing JSON") usageAndExit("error: missing JSON or FILEPATH to data")
} }
i++ i++
op.Val = args[i] op.Val = args[i]
} }
setOps = append(setOps, op) setOps = append(setOps, op)
default: default:
exitWithError(fmt.Sprintf("error: unknown operation %q", args[i])) usageAndExit(fmt.Sprintf("error: unknown operation %q", args[i]))
} }
} }
if len(setOps) == 0 { if len(setOps) == 0 {
flag.Usage() usageAndExit("")
os.Exit(1)
} }
err := gnmi.Set(ctx, client, setOps) err = gnmi.Set(ctx, client, setOps)
if err != nil { if err != nil {
glog.Fatal(err) glog.Fatal(err)
} }

View File

@ -0,0 +1,120 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
// json2test reformats 'go test -json' output as text as if the -json
// flag were not passed to go test. It is useful if you want to
// analyze go test -json output, but still want a human readable test
// log.
//
// Usage:
//
// go test -json > out.txt; <analysis program> out.txt; cat out.txt | json2test
//
package main
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"os"
"time"
)
var errTestFailure = errors.New("testfailure")
func main() {
err := writeTestOutput(os.Stdin, os.Stdout)
if err == errTestFailure {
os.Exit(1)
} else if err != nil {
log.Fatal(err)
}
}
type testEvent struct {
Time time.Time // encodes as an RFC3339-format string
Action string
Package string
Test string
Elapsed float64 // seconds
Output string
}
type test struct {
pkg string
test string
}
type outputBuffer struct {
output []string
}
func (o *outputBuffer) push(s string) {
o.output = append(o.output, s)
}
type testFailure struct {
t test
o outputBuffer
}
func writeTestOutput(in io.Reader, out io.Writer) error {
testOutputBuffer := map[test]*outputBuffer{}
var failures []testFailure
d := json.NewDecoder(in)
buf := bufio.NewWriter(out)
defer buf.Flush()
for {
var e testEvent
if err := d.Decode(&e); err != nil {
break
}
switch e.Action {
default:
continue
case "run":
testOutputBuffer[test{pkg: e.Package, test: e.Test}] = new(outputBuffer)
case "pass":
// Don't hold onto text for passing
delete(testOutputBuffer, test{pkg: e.Package, test: e.Test})
case "fail":
// fail may be for a package, which won't have an entry in
// testOutputBuffer because packages don't have a "run"
// action.
t := test{pkg: e.Package, test: e.Test}
if o, ok := testOutputBuffer[t]; ok {
f := testFailure{t: t, o: *o}
delete(testOutputBuffer, t)
failures = append(failures, f)
}
case "output":
buf.WriteString(e.Output)
// output may be for a package, which won't have an entry
// in testOutputBuffer because packages don't have a "run"
// action.
if o, ok := testOutputBuffer[test{pkg: e.Package, test: e.Test}]; ok {
o.push(e.Output)
}
}
}
if len(failures) == 0 {
return nil
}
buf.WriteString("\nTest failures:\n")
for i, f := range failures {
fmt.Fprintf(buf, "[%d] %s.%s\n", i+1, f.t.pkg, f.t.test)
for _, s := range f.o.output {
buf.WriteString(s)
}
if i < len(failures)-1 {
buf.WriteByte('\n')
}
}
return errTestFailure
}

View File

@ -0,0 +1,41 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package main
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"testing"
)
func TestWriteTestOutput(t *testing.T) {
input, err := os.Open("testdata/input.txt")
if err != nil {
t.Fatal(err)
}
var out bytes.Buffer
if err := writeTestOutput(input, &out); err != errTestFailure {
t.Error("expected test failure")
}
gold, err := os.Open("testdata/gold.txt")
if err != nil {
t.Fatal(err)
}
expected, err := ioutil.ReadAll(gold)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(out.Bytes(), expected) {
t.Error("output does not match gold.txt")
fmt.Println("Expected:")
fmt.Println(string(expected))
fmt.Println("Got:")
fmt.Println(out.String())
}
}

View File

@ -0,0 +1,16 @@
? pkg/skipped [no test files]
=== RUN TestPass
--- PASS: TestPass (0.00s)
PASS
ok pkg/passed 0.013s
panic
FAIL pkg/panic 600.029s
--- FAIL: TestFail (0.18s)
Test failures:
[1] pkg/panic.TestPanic
panic
FAIL pkg/panic 600.029s
[2] pkg/failed.TestFail
--- FAIL: TestFail (0.18s)

View File

@ -0,0 +1,17 @@
{"Time":"2018-03-08T10:33:12.002692769-08:00","Action":"output","Package":"pkg/skipped","Output":"? \tpkg/skipped\t[no test files]\n"}
{"Time":"2018-03-08T10:33:12.003199228-08:00","Action":"skip","Package":"pkg/skipped","Elapsed":0.001}
{"Time":"2018-03-08T10:33:12.343866281-08:00","Action":"run","Package":"pkg/passed","Test":"TestPass"}
{"Time":"2018-03-08T10:33:12.34406622-08:00","Action":"output","Package":"pkg/passed","Test":"TestPass","Output":"=== RUN TestPass\n"}
{"Time":"2018-03-08T10:33:12.344139342-08:00","Action":"output","Package":"pkg/passed","Test":"TestPass","Output":"--- PASS: TestPass (0.00s)\n"}
{"Time":"2018-03-08T10:33:12.344165231-08:00","Action":"pass","Package":"pkg/passed","Test":"TestPass","Elapsed":0}
{"Time":"2018-03-08T10:33:12.344297059-08:00","Action":"output","Package":"pkg/passed","Output":"PASS\n"}
{"Time":"2018-03-08T10:33:12.345217622-08:00","Action":"output","Package":"pkg/passed","Output":"ok \tpkg/passed\t0.013s\n"}
{"Time":"2018-03-08T10:33:12.34533033-08:00","Action":"pass","Package":"pkg/passed","Elapsed":0.013}
{"Time":"2018-03-08T10:33:20.243866281-08:00","Action":"run","Package":"pkg/panic","Test":"TestPanic"}
{"Time":"2018-03-08T10:33:20.27231537-08:00","Action":"output","Package":"pkg/panic","Test":"TestPanic","Output":"panic\n"}
{"Time":"2018-03-08T10:33:20.272414481-08:00","Action":"output","Package":"pkg/panic","Test":"TestPanic","Output":"FAIL\tpkg/panic\t600.029s\n"}
{"Time":"2018-03-08T10:33:20.272440286-08:00","Action":"fail","Package":"pkg/panic","Test":"TestPanic","Elapsed":600.029}
{"Time":"2018-03-08T10:33:26.143866281-08:00","Action":"run","Package":"pkg/failed","Test":"TestFail"}
{"Time":"2018-03-08T10:33:27.158776469-08:00","Action":"output","Package":"pkg/failed","Test":"TestFail","Output":"--- FAIL: TestFail (0.18s)\n"}
{"Time":"2018-03-08T10:33:27.158860934-08:00","Action":"fail","Package":"pkg/failed","Test":"TestFail","Elapsed":0.18}
{"Time":"2018-03-08T10:33:27.161302093-08:00","Action":"fail","Package":"pkg/failed","Elapsed":0.204}

View File

@ -26,13 +26,12 @@ var keysFlag = flag.String("kafkakeys", "",
"Keys for kafka messages (comma-separated, default: the value of -addrs") "Keys for kafka messages (comma-separated, default: the value of -addrs")
func newProducer(addresses []string, topic, key, dataset string) (producer.Producer, error) { func newProducer(addresses []string, topic, key, dataset string) (producer.Producer, error) {
glog.Infof("Connected to Kafka brokers at %s", addresses)
encodedKey := sarama.StringEncoder(key) encodedKey := sarama.StringEncoder(key)
p, err := producer.New(openconfig.NewEncoder(topic, encodedKey, dataset), addresses, nil) p, err := producer.New(openconfig.NewEncoder(topic, encodedKey, dataset), addresses, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to create Kafka producer: %s", err) return nil, fmt.Errorf("Failed to create Kafka brokers: %s", err)
} }
glog.Infof("Connected to Kafka brokers at %s", addresses)
return p, nil return p, nil
} }

View File

@ -33,5 +33,5 @@ Prometheus 2.0 will probably support timestamps.
See the `-help` output, but here's an example to push all the metrics defined See the `-help` output, but here's an example to push all the metrics defined
in the sample config file: in the sample config file:
``` ```
ocprometheus -addrs <switch-hostname>:6042 -config sampleconfig.json ocprometheus -addr <switch-hostname>:6042 -config sampleconfig.json
``` ```

View File

@ -11,7 +11,7 @@ import (
"github.com/aristanetworks/glog" "github.com/aristanetworks/glog"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/openconfig/reference/rpc/openconfig" pb "github.com/openconfig/gnmi/proto/gnmi"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
@ -26,6 +26,8 @@ type source struct {
type labelledMetric struct { type labelledMetric struct {
metric prometheus.Metric metric prometheus.Metric
labels []string labels []string
defaultValue float64
stringMetric bool
} }
type collector struct { type collector struct {
@ -43,13 +45,14 @@ func newCollector(config *Config) *collector {
} }
} }
// Process a notfication and update or create the corresponding metrics. // Process a notification and update or create the corresponding metrics.
func (c *collector) update(addr string, message proto.Message) { func (c *collector) update(addr string, message proto.Message) {
resp, ok := message.(*openconfig.SubscribeResponse) resp, ok := message.(*pb.SubscribeResponse)
if !ok { if !ok {
glog.Errorf("Unexpected type of message: %T", message) glog.Errorf("Unexpected type of message: %T", message)
return return
} }
notif := resp.GetUpdate() notif := resp.GetUpdate()
if notif == nil { if notif == nil {
return return
@ -57,7 +60,6 @@ func (c *collector) update(addr string, message proto.Message) {
device := strings.Split(addr, ":")[0] device := strings.Split(addr, ":")[0]
prefix := "/" + strings.Join(notif.Prefix.Element, "/") prefix := "/" + strings.Join(notif.Prefix.Element, "/")
// Process deletes first // Process deletes first
for _, del := range notif.Delete { for _, del := range notif.Delete {
path := prefix + "/" + strings.Join(del.Element, "/") path := prefix + "/" + strings.Join(del.Element, "/")
@ -70,7 +72,7 @@ func (c *collector) update(addr string, message proto.Message) {
// Process updates next // Process updates next
for _, update := range notif.Update { for _, update := range notif.Update {
// We only use JSON encoded values // We only use JSON encoded values
if update.Value == nil || update.Value.Type != openconfig.Type_JSON { if update.Value == nil || update.Value.Type != pb.Encoding_JSON {
glog.V(9).Infof("Ignoring incompatible update value in %s", update) glog.V(9).Infof("Ignoring incompatible update value in %s", update)
continue continue
} }
@ -80,40 +82,81 @@ func (c *collector) update(addr string, message proto.Message) {
if !ok { if !ok {
continue continue
} }
var strUpdate bool
var floatVal float64
var strVal string
switch v := value.(type) {
case float64:
strUpdate = false
floatVal = v
case string:
strUpdate = true
strVal = v
}
if suffix != "" { if suffix != "" {
path += "/" + suffix path += "/" + suffix
} }
src := source{addr: device, path: path} src := source{addr: device, path: path}
c.m.Lock() c.m.Lock()
// Use the cached labels and descriptor if available // Use the cached labels and descriptor if available
if m, ok := c.metrics[src]; ok { if m, ok := c.metrics[src]; ok {
m.metric = prometheus.MustNewConstMetric(m.metric.Desc(), prometheus.GaugeValue, value, if strUpdate {
m.labels...) // Skip string updates for non string metrics
if !m.stringMetric {
c.m.Unlock() c.m.Unlock()
continue continue
} }
c.m.Unlock() // Display a default value and replace the value label with the string value
floatVal = m.defaultValue
m.labels[len(m.labels)-1] = strVal
}
m.metric = prometheus.MustNewConstMetric(m.metric.Desc(), prometheus.GaugeValue,
floatVal, m.labels...)
c.m.Unlock()
continue
}
c.m.Unlock()
// Get the descriptor and labels for this source // Get the descriptor and labels for this source
desc, labelValues := c.config.getDescAndLabels(src) metric := c.config.getMetricValues(src)
if desc == nil { if metric == nil || metric.desc == nil {
glog.V(8).Infof("Ignoring unmatched update at %s:%s: %+v", device, path, update.Value) glog.V(8).Infof("Ignoring unmatched update at %s:%s: %+v", device, path, update.Value)
continue continue
} }
c.m.Lock() if strUpdate {
if !metric.stringMetric {
// Skip string updates for non string metrics
continue
}
// Display a default value and replace the value label with the string value
floatVal = metric.defaultValue
metric.labels[len(metric.labels)-1] = strVal
}
// Save the metric and labels in the cache // Save the metric and labels in the cache
metric := prometheus.MustNewConstMetric(desc, prometheus.GaugeValue, value, labelValues...) c.m.Lock()
lm := prometheus.MustNewConstMetric(metric.desc, prometheus.GaugeValue,
floatVal, metric.labels...)
c.metrics[src] = &labelledMetric{ c.metrics[src] = &labelledMetric{
metric: metric, metric: lm,
labels: labelValues, labels: metric.labels,
defaultValue: metric.defaultValue,
stringMetric: metric.stringMetric,
} }
c.m.Unlock() c.m.Unlock()
} }
} }
func parseValue(update *openconfig.Update) (float64, string, bool) { // ParseValue takes in an update and parses a value and suffix
// All metrics in Prometheus are floats, so only try to unmarshal as float64. // Returns an interface that contains either a string or a float64 as well as a suffix
// Unparseable updates return (0, empty string, false)
func parseValue(update *pb.Update) (interface{}, string, bool) {
var intf interface{} var intf interface{}
if err := json.Unmarshal(update.Value.Value, &intf); err != nil { if err := json.Unmarshal(update.Value.Value, &intf); err != nil {
glog.Errorf("Can't parse value in update %v: %v", update, err) glog.Errorf("Can't parse value in update %v: %v", update, err)
@ -129,13 +172,16 @@ func parseValue(update *openconfig.Update) (float64, string, bool) {
return val, "value", true return val, "value", true
} }
} }
// float64 or string expected as the return value
case bool: case bool:
if value { if value {
return 1, "", true return float64(1), "", true
} }
return 0, "", true return float64(0), "", true
case string:
return value, "", true
default: default:
glog.V(9).Infof("Ignorig non-numeric update: %v", update) glog.V(9).Infof("Ignoring update with unexpected type: %T", value)
} }
return 0, "", false return 0, "", false

View File

@ -5,33 +5,80 @@
package main package main
import ( import (
"fmt"
"strings"
"testing" "testing"
"github.com/aristanetworks/goarista/test" "github.com/aristanetworks/goarista/test"
pb "github.com/openconfig/gnmi/proto/gnmi"
"github.com/openconfig/reference/rpc/openconfig"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
func makeMetrics(cfg *Config, expValues map[source]float64) map[source]*labelledMetric { func makeMetrics(cfg *Config, expValues map[source]float64, notification *pb.Notification,
prevMetrics map[source]*labelledMetric) map[source]*labelledMetric {
expMetrics := map[source]*labelledMetric{} expMetrics := map[source]*labelledMetric{}
for k, v := range expValues { if prevMetrics != nil {
desc, labels := cfg.getDescAndLabels(k) expMetrics = prevMetrics
if desc == nil || labels == nil {
panic("cfg.getDescAndLabels returned nil")
} }
expMetrics[k] = &labelledMetric{ for src, v := range expValues {
metric: prometheus.MustNewConstMetric(desc, prometheus.GaugeValue, v, labels...), metric := cfg.getMetricValues(src)
labels: labels, if metric == nil || metric.desc == nil || metric.labels == nil {
panic("cfg.getMetricValues returned nil")
} }
// Preserve current value of labels
labels := metric.labels
if _, ok := expMetrics[src]; ok && expMetrics[src] != nil {
labels = expMetrics[src].labels
} }
// Handle string updates
if notification.Update != nil {
if update, err := findUpdate(notification, src.path); err == nil {
val, _, ok := parseValue(update)
if !ok {
continue
}
if strVal, ok := val.(string); ok {
if !metric.stringMetric {
continue
}
v = metric.defaultValue
labels[len(labels)-1] = strVal
}
}
}
expMetrics[src] = &labelledMetric{
metric: prometheus.MustNewConstMetric(metric.desc, prometheus.GaugeValue, v,
labels...),
labels: labels,
defaultValue: metric.defaultValue,
stringMetric: metric.stringMetric,
}
}
// Handle deletion
for key := range expMetrics {
if _, ok := expValues[key]; !ok {
delete(expMetrics, key)
}
}
return expMetrics return expMetrics
} }
func makeResponse(notif *openconfig.Notification) *openconfig.SubscribeResponse { func findUpdate(notif *pb.Notification, path string) (*pb.Update, error) {
return &openconfig.SubscribeResponse{ prefix := notif.Prefix.Element
Response: &openconfig.SubscribeResponse_Update{Update: notif}, for _, v := range notif.Update {
fullPath := "/" + strings.Join(append(prefix, v.Path.Element...), "/")
if strings.Contains(path, fullPath) || path == fullPath {
return v, nil
}
}
return nil, fmt.Errorf("Failed to find matching update for path %v", path)
}
func makeResponse(notif *pb.Notification) *pb.SubscribeResponse {
return &pb.SubscribeResponse{
Response: &pb.SubscribeResponse_Update{Update: notif},
} }
} }
@ -49,6 +96,11 @@ subscriptions:
- /Sysdb/environment/power/status - /Sysdb/environment/power/status
- /Sysdb/bridging/igmpsnooping/forwarding/forwarding/status - /Sysdb/bridging/igmpsnooping/forwarding/forwarding/status
metrics: metrics:
- name: fanName
path: /Sysdb/environment/cooling/status/fan/name
help: Fan Name
valuelabel: name
defaultvalue: 2.5
- name: intfCounter - name: intfCounter
path: /Sysdb/(lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter path: /Sysdb/(lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter
help: Per-Interface Bytes/Errors/Discards Counters help: Per-Interface Bytes/Errors/Discards Counters
@ -64,79 +116,101 @@ metrics:
} }
coll := newCollector(cfg) coll := newCollector(cfg)
notif := &openconfig.Notification{ notif := &pb.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}}, Prefix: &pb.Path{Element: []string{"Sysdb"}},
Update: []*openconfig.Update{ Update: []*pb.Update{
{ {
Path: &openconfig.Path{ Path: &pb.Path{
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"}, Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
}, },
Value: &openconfig.Value{ Value: &pb.Value{
Type: openconfig.Type_JSON, Type: pb.Encoding_JSON,
Value: []byte("42"), Value: []byte("42"),
}, },
}, },
{ {
Path: &openconfig.Path{ Path: &pb.Path{
Element: []string{"environment", "cooling", "status", "fan", "speed"}, Element: []string{"environment", "cooling", "status", "fan", "speed"},
}, },
Value: &openconfig.Value{ Value: &pb.Value{
Type: openconfig.Type_JSON, Type: pb.Encoding_JSON,
Value: []byte("{\"value\": 45}"), Value: []byte("{\"value\": 45}"),
}, },
}, },
{ {
Path: &openconfig.Path{ Path: &pb.Path{
Element: []string{"igmpsnooping", "vlanStatus", "2050", "ethGroup", Element: []string{"igmpsnooping", "vlanStatus", "2050", "ethGroup",
"01:00:5e:01:01:01", "intf", "Cpu"}, "01:00:5e:01:01:01", "intf", "Cpu"},
}, },
Value: &openconfig.Value{ Value: &pb.Value{
Type: openconfig.Type_JSON, Type: pb.Encoding_JSON,
Value: []byte("true"), Value: []byte("true"),
}, },
}, },
{
Path: &pb.Path{
Element: []string{"environment", "cooling", "status", "fan", "name"},
},
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("\"Fan1.1\""),
},
},
}, },
} }
expValues := map[source]float64{ expValues := map[source]float64{
source{ {
addr: "10.1.1.1", addr: "10.1.1.1",
path: "/Sysdb/lag/intfCounterDir/Ethernet1/intfCounter", path: "/Sysdb/lag/intfCounterDir/Ethernet1/intfCounter",
}: 42, }: 42,
source{ {
addr: "10.1.1.1", addr: "10.1.1.1",
path: "/Sysdb/environment/cooling/status/fan/speed/value", path: "/Sysdb/environment/cooling/status/fan/speed/value",
}: 45, }: 45,
source{ {
addr: "10.1.1.1", addr: "10.1.1.1",
path: "/Sysdb/igmpsnooping/vlanStatus/2050/ethGroup/01:00:5e:01:01:01/intf/Cpu", path: "/Sysdb/igmpsnooping/vlanStatus/2050/ethGroup/01:00:5e:01:01:01/intf/Cpu",
}: 1, }: 1,
{
addr: "10.1.1.1",
path: "/Sysdb/environment/cooling/status/fan/name",
}: 2.5,
} }
coll.update("10.1.1.1:6042", makeResponse(notif)) coll.update("10.1.1.1:6042", makeResponse(notif))
expMetrics := makeMetrics(cfg, expValues) expMetrics := makeMetrics(cfg, expValues, notif, nil)
if !test.DeepEqual(expMetrics, coll.metrics) { if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics)) t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
} }
// Update one value, and one path which is not a metric // Update two values, and one path which is not a metric
notif = &openconfig.Notification{ notif = &pb.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}}, Prefix: &pb.Path{Element: []string{"Sysdb"}},
Update: []*openconfig.Update{ Update: []*pb.Update{
{ {
Path: &openconfig.Path{ Path: &pb.Path{
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"}, Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
}, },
Value: &openconfig.Value{ Value: &pb.Value{
Type: openconfig.Type_JSON, Type: pb.Encoding_JSON,
Value: []byte("52"), Value: []byte("52"),
}, },
}, },
{ {
Path: &openconfig.Path{ Path: &pb.Path{
Element: []string{"environment", "cooling", "status", "fan", "name"},
},
Value: &pb.Value{
Type: pb.Encoding_JSON,
Value: []byte("\"Fan2.1\""),
},
},
{
Path: &pb.Path{
Element: []string{"environment", "doesntexist", "status"}, Element: []string{"environment", "doesntexist", "status"},
}, },
Value: &openconfig.Value{ Value: &pb.Value{
Type: openconfig.Type_JSON, Type: pb.Encoding_JSON,
Value: []byte("{\"value\": 45}"), Value: []byte("{\"value\": 45}"),
}, },
}, },
@ -149,21 +223,21 @@ metrics:
expValues[src] = 52 expValues[src] = 52
coll.update("10.1.1.1:6042", makeResponse(notif)) coll.update("10.1.1.1:6042", makeResponse(notif))
expMetrics = makeMetrics(cfg, expValues) expMetrics = makeMetrics(cfg, expValues, notif, expMetrics)
if !test.DeepEqual(expMetrics, coll.metrics) { if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics)) t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
} }
// Same path, different device // Same path, different device
notif = &openconfig.Notification{ notif = &pb.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}}, Prefix: &pb.Path{Element: []string{"Sysdb"}},
Update: []*openconfig.Update{ Update: []*pb.Update{
{ {
Path: &openconfig.Path{ Path: &pb.Path{
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"}, Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
}, },
Value: &openconfig.Value{ Value: &pb.Value{
Type: openconfig.Type_JSON, Type: pb.Encoding_JSON,
Value: []byte("42"), Value: []byte("42"),
}, },
}, },
@ -173,15 +247,15 @@ metrics:
expValues[src] = 42 expValues[src] = 42
coll.update("10.1.1.2:6042", makeResponse(notif)) coll.update("10.1.1.2:6042", makeResponse(notif))
expMetrics = makeMetrics(cfg, expValues) expMetrics = makeMetrics(cfg, expValues, notif, expMetrics)
if !test.DeepEqual(expMetrics, coll.metrics) { if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics)) t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
} }
// Delete a path // Delete a path
notif = &openconfig.Notification{ notif = &pb.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}}, Prefix: &pb.Path{Element: []string{"Sysdb"}},
Delete: []*openconfig.Path{ Delete: []*pb.Path{
{ {
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"}, Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
}, },
@ -191,21 +265,21 @@ metrics:
delete(expValues, src) delete(expValues, src)
coll.update("10.1.1.1:6042", makeResponse(notif)) coll.update("10.1.1.1:6042", makeResponse(notif))
expMetrics = makeMetrics(cfg, expValues) expMetrics = makeMetrics(cfg, expValues, notif, expMetrics)
if !test.DeepEqual(expMetrics, coll.metrics) { if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics)) t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
} }
// Non-numeric update // Non-numeric update to path without value label
notif = &openconfig.Notification{ notif = &pb.Notification{
Prefix: &openconfig.Path{Element: []string{"Sysdb"}}, Prefix: &pb.Path{Element: []string{"Sysdb"}},
Update: []*openconfig.Update{ Update: []*pb.Update{
{ {
Path: &openconfig.Path{ Path: &pb.Path{
Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"}, Element: []string{"lag", "intfCounterDir", "Ethernet1", "intfCounter"},
}, },
Value: &openconfig.Value{ Value: &pb.Value{
Type: openconfig.Type_JSON, Type: pb.Encoding_JSON,
Value: []byte("\"test\""), Value: []byte("\"test\""),
}, },
}, },
@ -213,6 +287,7 @@ metrics:
} }
coll.update("10.1.1.1:6042", makeResponse(notif)) coll.update("10.1.1.1:6042", makeResponse(notif))
// Don't make new metrics as it should have no effect
if !test.DeepEqual(expMetrics, coll.metrics) { if !test.DeepEqual(expMetrics, coll.metrics) {
t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics)) t.Errorf("Mismatched metrics: %v", test.Diff(expMetrics, coll.metrics))
} }

View File

@ -33,7 +33,7 @@ type MetricDef struct {
Path string Path string
// Path compiled as a regexp. // Path compiled as a regexp.
re *regexp.Regexp re *regexp.Regexp `deepequal:"ignore"`
// Metric name. // Metric name.
Name string Name string
@ -41,6 +41,15 @@ type MetricDef struct {
// Metric help string. // Metric help string.
Help string Help string
// Label to store string values
ValueLabel string
// Default value to display for string values
DefaultValue float64
// Does the metric store a string value
stringMetric bool
// This map contains the metric descriptors for this metric for each device. // This map contains the metric descriptors for this metric for each device.
devDesc map[string]*prometheus.Desc devDesc map[string]*prometheus.Desc
@ -48,6 +57,14 @@ type MetricDef struct {
desc *prometheus.Desc desc *prometheus.Desc
} }
// metricValues contains the values used in updating a metric
type metricValues struct {
desc *prometheus.Desc
labels []string
defaultValue float64
stringMetric bool
}
// Parses the config and creates the descriptors for each path and device. // Parses the config and creates the descriptors for each path and device.
func parseConfig(cfg []byte) (*Config, error) { func parseConfig(cfg []byte) (*Config, error) {
config := &Config{ config := &Config{
@ -56,10 +73,8 @@ func parseConfig(cfg []byte) (*Config, error) {
if err := yaml.Unmarshal(cfg, config); err != nil { if err := yaml.Unmarshal(cfg, config); err != nil {
return nil, fmt.Errorf("Failed to parse config: %v", err) return nil, fmt.Errorf("Failed to parse config: %v", err)
} }
for _, def := range config.Metrics { for _, def := range config.Metrics {
def.re = regexp.MustCompile(def.Path) def.re = regexp.MustCompile(def.Path)
// Extract label names // Extract label names
reNames := def.re.SubexpNames()[1:] reNames := def.re.SubexpNames()[1:]
labelNames := make([]string, len(reNames)) labelNames := make([]string, len(reNames))
@ -69,7 +84,10 @@ func parseConfig(cfg []byte) (*Config, error) {
labelNames[i] = "unnamedLabel" + strconv.Itoa(i+1) labelNames[i] = "unnamedLabel" + strconv.Itoa(i+1)
} }
} }
if def.ValueLabel != "" {
labelNames = append(labelNames, def.ValueLabel)
def.stringMetric = true
}
// Create a default descriptor only if there aren't any per-device labels, // Create a default descriptor only if there aren't any per-device labels,
// or if it's explicitly declared // or if it's explicitly declared
if len(config.DeviceLabels) == 0 || len(config.DeviceLabels["*"]) > 0 { if len(config.DeviceLabels) == 0 || len(config.DeviceLabels["*"]) > 0 {
@ -88,20 +106,25 @@ func parseConfig(cfg []byte) (*Config, error) {
return config, nil return config, nil
} }
// Returns the descriptor corresponding to the device and path, and labels extracted from the path. // Returns a struct containing the descriptor corresponding to the device and path, labels
// extracted from the path, the default value for the metric and if it accepts string values.
// If the device and path doesn't match any metrics, returns nil. // If the device and path doesn't match any metrics, returns nil.
func (c *Config) getDescAndLabels(s source) (*prometheus.Desc, []string) { func (c *Config) getMetricValues(s source) *metricValues {
for _, def := range c.Metrics { for _, def := range c.Metrics {
if groups := def.re.FindStringSubmatch(s.path); groups != nil { if groups := def.re.FindStringSubmatch(s.path); groups != nil {
if desc, ok := def.devDesc[s.addr]; ok { if def.ValueLabel != "" {
return desc, groups[1:] groups = append(groups, def.ValueLabel)
} }
desc, ok := def.devDesc[s.addr]
return def.desc, groups[1:] if !ok {
desc = def.desc
}
return &metricValues{desc: desc, labels: groups[1:], defaultValue: def.DefaultValue,
stringMetric: def.stringMetric}
} }
} }
return nil, nil return nil
} }
// Sends all the descriptors to the channel. // Sends all the descriptors to the channel.

View File

@ -31,6 +31,11 @@ subscriptions:
- /Sysdb/environment/cooling/status - /Sysdb/environment/cooling/status
- /Sysdb/environment/power/status - /Sysdb/environment/power/status
metrics: metrics:
- name: fanName
path: /Sysdb/environment/cooling/status/fan/name
help: Fan Name
valuelabel: name
defaultvalue: 25
- name: intfCounter - name: intfCounter
path: /Sysdb/(lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter path: /Sysdb/(lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter
help: Per-Interface Bytes/Errors/Discards Counters help: Per-Interface Bytes/Errors/Discards Counters
@ -53,6 +58,26 @@ metrics:
"/Sysdb/environment/power/status", "/Sysdb/environment/power/status",
}, },
Metrics: []*MetricDef{ Metrics: []*MetricDef{
{
Path: "/Sysdb/environment/cooling/status/fan/name",
re: regexp.MustCompile(
"/Sysdb/environment/cooling/status/fan/name"),
Name: "fanName",
Help: "Fan Name",
ValueLabel: "name",
DefaultValue: 25,
stringMetric: true,
devDesc: map[string]*prometheus.Desc{
"10.1.1.1": prometheus.NewDesc("fanName",
"Fan Name",
[]string{"name"},
prometheus.Labels{"lab1": "val1", "lab2": "val2"}),
},
desc: prometheus.NewDesc("fanName",
"Fan Name",
[]string{"name"},
prometheus.Labels{"lab1": "val3", "lab2": "val4"}),
},
{ {
Path: "/Sysdb/(lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter", Path: "/Sysdb/(lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter",
re: regexp.MustCompile( re: regexp.MustCompile(
@ -127,7 +152,8 @@ metrics:
}, },
{ {
Path: "/Sysdb/environment/cooling/fan/speed/value", Path: "/Sysdb/environment/cooling/fan/speed/value",
re: regexp.MustCompile("/Sysdb/environment/cooling/fan/speed/value"), re: regexp.MustCompile(
"/Sysdb/environment/cooling/fan/speed/value"),
Name: "fanSpeed", Name: "fanSpeed",
Help: "Fan Speed", Help: "Fan Speed",
devDesc: map[string]*prometheus.Desc{}, devDesc: map[string]*prometheus.Desc{},
@ -180,7 +206,8 @@ metrics:
}, },
{ {
Path: "/Sysdb/environment/cooling/fan/speed/value", Path: "/Sysdb/environment/cooling/fan/speed/value",
re: regexp.MustCompile("/Sysdb/environment/cooling/fan/speed/value"), re: regexp.MustCompile(
"/Sysdb/environment/cooling/fan/speed/value"),
Name: "fanSpeed", Name: "fanSpeed",
Help: "Fan Speed", Help: "Fan Speed",
devDesc: map[string]*prometheus.Desc{ devDesc: map[string]*prometheus.Desc{
@ -223,7 +250,8 @@ metrics:
}, },
{ {
Path: "/Sysdb/environment/cooling/fan/speed/value", Path: "/Sysdb/environment/cooling/fan/speed/value",
re: regexp.MustCompile("/Sysdb/environment/cooling/fan/speed/value"), re: regexp.MustCompile(
"/Sysdb/environment/cooling/fan/speed/value"),
Name: "fanSpeed", Name: "fanSpeed",
Help: "Fan Speed", Help: "Fan Speed",
devDesc: map[string]*prometheus.Desc{}, devDesc: map[string]*prometheus.Desc{},
@ -247,7 +275,7 @@ metrics:
} }
} }
func TestGetDescAndLabels(t *testing.T) { func TestGetMetricValues(t *testing.T) {
config := []byte(` config := []byte(`
devicelabels: devicelabels:
10.1.1.1: 10.1.1.1:
@ -317,12 +345,16 @@ metrics:
} }
for i, c := range tCases { for i, c := range tCases {
desc, labels := cfg.getDescAndLabels(c.src) metric := cfg.getMetricValues(c.src)
if !test.DeepEqual(desc, c.desc) { if metric == nil {
t.Errorf("Test case %d: desc mismatch %v", i+1, test.Diff(desc, c.desc)) // Avoids error from trying to access metric.desc when metric is nil
metric = &metricValues{}
} }
if !test.DeepEqual(labels, c.labels) { if !test.DeepEqual(metric.desc, c.desc) {
t.Errorf("Test case %d: labels mismatch %v", i+1, test.Diff(labels, c.labels)) t.Errorf("Test case %d: desc mismatch %v", i+1, test.Diff(metric.desc, c.desc))
}
if !test.DeepEqual(metric.labels, c.labels) {
t.Errorf("Test case %d: labels mismatch %v", i+1, test.Diff(metric.labels, c.labels))
} }
} }
} }

View File

@ -6,25 +6,39 @@
package main package main
import ( import (
"context"
"flag" "flag"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"sync" "strings"
"github.com/aristanetworks/goarista/openconfig/client"
"github.com/aristanetworks/glog" "github.com/aristanetworks/glog"
"github.com/aristanetworks/goarista/gnmi"
pb "github.com/openconfig/gnmi/proto/gnmi"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
) )
func main() { func main() {
// gNMI options
gNMIcfg := &gnmi.Config{}
flag.StringVar(&gNMIcfg.Addr, "addr", "localhost", "gNMI gRPC server `address`")
flag.StringVar(&gNMIcfg.CAFile, "cafile", "", "Path to server TLS certificate file")
flag.StringVar(&gNMIcfg.CertFile, "certfile", "", "Path to client TLS certificate file")
flag.StringVar(&gNMIcfg.KeyFile, "keyfile", "", "Path to client TLS private key file")
flag.StringVar(&gNMIcfg.Username, "username", "", "Username to authenticate with")
flag.StringVar(&gNMIcfg.Password, "password", "", "Password to authenticate with")
flag.BoolVar(&gNMIcfg.TLS, "tls", false, "Enable TLS")
subscribePaths := flag.String("subscribe", "/", "Comma-separated list of paths to subscribe to")
// program options
listenaddr := flag.String("listenaddr", ":8080", "Address on which to expose the metrics") listenaddr := flag.String("listenaddr", ":8080", "Address on which to expose the metrics")
url := flag.String("url", "/metrics", "URL where to expose the metrics") url := flag.String("url", "/metrics", "URL where to expose the metrics")
configFlag := flag.String("config", "", configFlag := flag.String("config", "",
"Config to turn OpenConfig telemetry into Prometheus metrics") "Config to turn OpenConfig telemetry into Prometheus metrics")
username, password, subscriptions, addrs, opts := client.ParseFlags()
flag.Parse()
subscriptions := strings.Split(*subscribePaths, ",")
if *configFlag == "" { if *configFlag == "" {
glog.Fatal("You need specify a config file using -config flag") glog.Fatal("You need specify a config file using -config flag")
} }
@ -39,7 +53,7 @@ func main() {
// Ignore the default "subscribe-to-everything" subscription of the // Ignore the default "subscribe-to-everything" subscription of the
// -subscribe flag. // -subscribe flag.
if subscriptions[0] == "" { if subscriptions[0] == "/" {
subscriptions = subscriptions[1:] subscriptions = subscriptions[1:]
} }
// Add the subscriptions from the config file. // Add the subscriptions from the config file.
@ -47,14 +61,33 @@ func main() {
coll := newCollector(config) coll := newCollector(config)
prometheus.MustRegister(coll) prometheus.MustRegister(coll)
ctx := gnmi.NewContext(context.Background(), gNMIcfg)
wg := new(sync.WaitGroup) client, err := gnmi.Dial(gNMIcfg)
for _, addr := range addrs { if err != nil {
wg.Add(1) glog.Fatal(err)
c := client.New(username, password, addr, opts)
go c.Subscribe(wg, subscriptions, coll.update)
} }
respChan := make(chan *pb.SubscribeResponse)
errChan := make(chan error)
subscribeOptions := &gnmi.SubscribeOptions{
Mode: "stream",
StreamMode: "target_defined",
Paths: gnmi.SplitPaths(subscriptions),
}
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
go handleSubscription(respChan, errChan, coll, gNMIcfg.Addr)
http.Handle(*url, promhttp.Handler()) http.Handle(*url, promhttp.Handler())
glog.Fatal(http.ListenAndServe(*listenaddr, nil)) glog.Fatal(http.ListenAndServe(*listenaddr, nil))
} }
func handleSubscription(respChan chan *pb.SubscribeResponse,
errChan chan error, coll *collector, addr string) {
for {
select {
case resp := <-respChan:
coll.update(addr, resp)
case err := <-errChan:
glog.Fatal(err)
}
}
}

View File

@ -0,0 +1,80 @@
# Per-device labels. Optional
# Exactly the same set of labels must be specified for each device.
# If device address is *, the labels apply to all devices not listed explicitly.
# If any explicit device if listed below, then you need to specify all devices you're subscribed to,
# or have a wildcard entry. Otherwise, updates from non-listed devices will be ignored.
#deviceLabels:
# 10.1.1.1:
# lab1: val1
# lab2: val2
# '*':
# lab1: val3
# lab2: val4
# Subscriptions to OpenConfig paths.
subscriptions:
- /Smash/counters/ethIntf
- /Smash/interface/counter/lag/current/counter
- /Sysdb/environment/archer/cooling/status
- /Sysdb/environment/archer/power/status
- /Sysdb/environment/archer/temperature/status
- /Sysdb/hardware/archer/xcvr/status
- /Sysdb/interface/config/eth
# Prometheus metrics configuration.
# If you use named capture groups in the path, they will be extracted into labels with the same name.
# All fields are mandatory.
metrics:
- name: interfaceDescription
path: /Sysdb/interface/config/eth/phy/slice/1/intfConfig/(?P<interface>Ethernet.)/description
help: Description
valuelabel: description
defaultvalue: 15
- name: intfCounter
path: /Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(Octets|Errors|Discards)
help: Per-Interface Bytes/Errors/Discards Counters
- name: intfLagCounter
path: /Smash/interface/counter/lag/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(Octets|Errors|Discards)
help: Per-Lag Bytes/Errors/Discards Counters
- name: intfPktCounter
path: /Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(?P<type>(?:Ucast|Multicast|Broadcast))(Pkt)
help: Per-Interface Unicast/Multicast/Broadcast Packer Counters
- name: intfLagPktCounter
path: /Smash/interface/counter/lag/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(?P<type>(?:Ucast|Multicast|Broadcast))(Pkt)
help: Per-Lag Unicast/Multicast/Broadcast Packer Counters
- name: intfPfcClassCounter
path: /Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/ethStatistics/(?P<direction>(?:in|out))(PfcClassFrames)
help: Per-Interface Input/Output PFC Frames Counters
- name: tempSensor
path: /Sysdb/(environment)/archer/temperature/status/(?P<sensor>.+)/((?:maxT|t)emperature)
help: Temperature and Maximum Temperature
- name: tempSensorAlert
path: /Sysdb/(environment)/archer/temperature/status/(?P<sensor>.+)/(alertRaisedCount)
help: Temperature Alerts Counter
- name: currentSensor
path: /Sysdb/(environment)/archer/power/status/currentSensor/(?P<sensor>.+)/(current)
help: Current Levels
- name: powerSensor
path: /Sysdb/(environment)/archer/(power)/status/powerSupply/(?P<sensor>.+)/(input|output)Power
help: Input/Output Power Levels
- name: voltageSensor
path: /Sysdb/(environment)/archer/power/status/voltageSensor/(?:cell/.+|system)/(?P<sensor>.+)/(voltage)
help: Voltage Levels
- name: railCurrentSensor
path: /Sysdb/(environment)/archer/power/status/voltageSensor/(?:cell/.+|system)/(?P<sensor>.+)/(current)
help: Rail Current Levels
- name: fanSpeed
path: /Sysdb/(environment)/archer/(cooling)/status/(?P<fan>.+)/speed
help: Fan Speed
- name: qsfpModularRxPower
path: /Sysdb/hardware/archer/(xcvr)/status/slice/(?P<linecard>.+)/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)
help: qsfpModularRxPower
- name: qsfpFixedRxPower
path: /Sysdb/hardware/archer/(xcvr)/status/all/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)
help: qsfpFixedRxPower
- name: sfpModularTemperature
path: /Sysdb/hardware/archer/(xcvr)/status/slice/(?P<linecard>.+)/(?P<intf>.+)/lastDomUpdateTime/(temperature)
help: sfpModularTemperature
- name: sfpFixedTemperature
path: /Sysdb/hardware/archer/(xcvr)/status/all/(?P<intf>.+)/lastDomUpdateTime/(temperature)
help: sfpFixedTemperature

View File

@ -18,11 +18,18 @@ subscriptions:
- /Sysdb/environment/temperature/status - /Sysdb/environment/temperature/status
- /Sysdb/interface/counter/eth/lag - /Sysdb/interface/counter/eth/lag
- /Sysdb/interface/counter/eth/slice/phy - /Sysdb/interface/counter/eth/slice/phy
- /Sysdb/interface/config
- /Sysdb/interface/config/eth/phy/slice/1/intfConfig
# Prometheus metrics configuration. # Prometheus metrics configuration.
# If you use named capture groups in the path, they will be extracted into labels with the same name. # If you use named capture groups in the path, they will be extracted into labels with the same name.
# All fields are mandatory. # All fields are mandatory.
metrics: metrics:
- name: interfaceDescription
path: Sysdb/interface/config/eth/phy/slice/1/intfConfig/(?P<interface>Ethernet.)/description
help: Description
valuelabel: description
defaultvalue: 15
- name: intfCounter - name: intfCounter
path: /Sysdb/interface/counter/eth/(?:lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter/current/statistics/(?P<direction>(?:in|out))(?P<type>(Octets|Errors|Discards)) path: /Sysdb/interface/counter/eth/(?:lag|slice/phy/.+)/intfCounterDir/(?P<intf>.+)/intfCounter/current/statistics/(?P<direction>(?:in|out))(?P<type>(Octets|Errors|Discards))
help: Per-Interface Bytes/Errors/Discards Counters help: Per-Interface Bytes/Errors/Discards Counters

View File

@ -17,5 +17,5 @@ See the `-help` output, but here's an example to push all the temperature
sensors into Redis. You can also not pass any `-subscribe` flag to push sensors into Redis. You can also not pass any `-subscribe` flag to push
_everything_ into Redis. _everything_ into Redis.
``` ```
ocredis -subscribe /Sysdb/environment/temperature -addrs <switch-hostname>:6042 -redis <redis-hostname>:6379 ocredis -subscribe /Sysdb/environment/temperature -addr <switch-hostname>:6042 -redis <redis-hostname>:6379
``` ```

View File

@ -8,16 +8,15 @@
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"flag" "flag"
"strings" "strings"
"sync"
occlient "github.com/aristanetworks/goarista/openconfig/client" "github.com/aristanetworks/goarista/gnmi"
"github.com/aristanetworks/glog" "github.com/aristanetworks/glog"
"github.com/golang/protobuf/proto" pb "github.com/openconfig/gnmi/proto/gnmi"
"github.com/openconfig/reference/rpc/openconfig"
redis "gopkg.in/redis.v4" redis "gopkg.in/redis.v4"
) )
@ -42,11 +41,23 @@ type baseClient interface {
var client baseClient var client baseClient
func main() { func main() {
username, password, subscriptions, hostAddrs, opts := occlient.ParseFlags()
// gNMI options
cfg := &gnmi.Config{}
flag.StringVar(&cfg.Addr, "addr", "localhost", "gNMI gRPC server `address`")
flag.StringVar(&cfg.CAFile, "cafile", "", "Path to server TLS certificate file")
flag.StringVar(&cfg.CertFile, "certfile", "", "Path to client TLS certificate file")
flag.StringVar(&cfg.KeyFile, "keyfile", "", "Path to client TLS private key file")
flag.StringVar(&cfg.Username, "username", "", "Username to authenticate with")
flag.StringVar(&cfg.Password, "password", "", "Password to authenticate with")
flag.BoolVar(&cfg.TLS, "tls", false, "Enable TLS")
subscribePaths := flag.String("subscribe", "/", "Comma-separated list of paths to subscribe to")
flag.Parse()
if *redisFlag == "" { if *redisFlag == "" {
glog.Fatal("Specify the address of the Redis server to write to with -redis") glog.Fatal("Specify the address of the Redis server to write to with -redis")
} }
subscriptions := strings.Split(*subscribePaths, ",")
redisAddrs := strings.Split(*redisFlag, ",") redisAddrs := strings.Split(*redisFlag, ",")
if !*clusterMode && len(redisAddrs) > 1 { if !*clusterMode && len(redisAddrs) > 1 {
glog.Fatal("Please pass only 1 redis address in noncluster mode or enable cluster mode") glog.Fatal("Please pass only 1 redis address in noncluster mode or enable cluster mode")
@ -72,25 +83,27 @@ func main() {
if err != nil { if err != nil {
glog.Fatal("Failed to connect to client: ", err) glog.Fatal("Failed to connect to client: ", err)
} }
ctx := gnmi.NewContext(context.Background(), cfg)
ocPublish := func(addr string, message proto.Message) { client, err := gnmi.Dial(cfg)
resp, ok := message.(*openconfig.SubscribeResponse) if err != nil {
if !ok { glog.Fatal(err)
glog.Errorf("Unexpected type of message: %T", message)
return
} }
if notif := resp.GetUpdate(); notif != nil { respChan := make(chan *pb.SubscribeResponse)
bufferToRedis(addr, notif) errChan := make(chan error)
subscribeOptions := &gnmi.SubscribeOptions{
Mode: "stream",
StreamMode: "target_defined",
Paths: gnmi.SplitPaths(subscriptions),
}
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
for {
select {
case resp := <-respChan:
bufferToRedis(cfg.Addr, resp.GetUpdate())
case err := <-errChan:
glog.Fatal(err)
} }
} }
wg := new(sync.WaitGroup)
for _, hostAddr := range hostAddrs {
wg.Add(1)
c := occlient.New(username, password, hostAddr, opts)
go c.Subscribe(wg, subscriptions, ocPublish)
}
wg.Wait()
} }
type redisData struct { type redisData struct {
@ -100,7 +113,12 @@ type redisData struct {
pub map[string]interface{} pub map[string]interface{}
} }
func bufferToRedis(addr string, notif *openconfig.Notification) { func bufferToRedis(addr string, notif *pb.Notification) {
if notif == nil {
// possible that this should be ignored silently
glog.Error("Nil notification ignored")
return
}
path := addr + "/" + joinPath(notif.Prefix) path := addr + "/" + joinPath(notif.Prefix)
data := &redisData{key: path} data := &redisData{key: path}
@ -167,20 +185,21 @@ func redisPublish(path, kind string, payload interface{}) {
} }
} }
func joinPath(path *openconfig.Path) string { func joinPath(path *pb.Path) string {
// path.Elem is empty for some reason so using path.Element instead
return strings.Join(path.Element, "/") return strings.Join(path.Element, "/")
} }
func convertUpdate(update *openconfig.Update) interface{} { func convertUpdate(update *pb.Update) interface{} {
switch update.Value.Type { switch update.Value.Type {
case openconfig.Type_JSON: case pb.Encoding_JSON:
var value interface{} var value interface{}
err := json.Unmarshal(update.Value.Value, &value) err := json.Unmarshal(update.Value.Value, &value)
if err != nil { if err != nil {
glog.Fatalf("Malformed JSON update %q in %s", update.Value.Value, update) glog.Fatalf("Malformed JSON update %q in %s", update.Value.Value, update)
} }
return value return value
case openconfig.Type_BYTES: case pb.Encoding_BYTES:
return update.Value.Value return update.Value.Value
default: default:
glog.Fatalf("Unhandled type of value %v in %s", update.Value.Type, update) glog.Fatalf("Unhandled type of value %v in %s", update.Value.Type, update)

View File

@ -14,9 +14,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/aristanetworks/glog"
"github.com/aristanetworks/goarista/gnmi" "github.com/aristanetworks/goarista/gnmi"
"github.com/aristanetworks/splunk-hec-go"
"github.com/fuyufjh/splunk-hec-go"
pb "github.com/openconfig/gnmi/proto/gnmi" pb "github.com/openconfig/gnmi/proto/gnmi"
) )
@ -49,7 +50,10 @@ func main() {
ctx := gnmi.NewContext(context.Background(), cfg) ctx := gnmi.NewContext(context.Background(), cfg)
// Store the address without the port so it can be used as the host in the Splunk event. // Store the address without the port so it can be used as the host in the Splunk event.
addr := cfg.Addr addr := cfg.Addr
client := gnmi.Dial(cfg) client, err := gnmi.Dial(cfg)
if err != nil {
glog.Fatal(err)
}
// Splunk connection // Splunk connection
urls := strings.Split(*splunkURLs, ",") urls := strings.Split(*splunkURLs, ",")
@ -67,10 +71,14 @@ func main() {
// gNMI subscription // gNMI subscription
respChan := make(chan *pb.SubscribeResponse) respChan := make(chan *pb.SubscribeResponse)
errChan := make(chan error) errChan := make(chan error)
defer close(respChan)
defer close(errChan) defer close(errChan)
paths := strings.Split(*subscribePaths, ",") paths := strings.Split(*subscribePaths, ",")
go gnmi.Subscribe(ctx, client, gnmi.SplitPaths(paths), respChan, errChan) subscribeOptions := &gnmi.SubscribeOptions{
Mode: "stream",
StreamMode: "target_defined",
Paths: gnmi.SplitPaths(paths),
}
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
// Forward subscribe responses to Splunk // Forward subscribe responses to Splunk
for { for {

View File

@ -6,28 +6,51 @@ dropped.
This tool requires a config file to specify how to map the path of the This tool requires a config file to specify how to map the path of the
notificatons coming out of the OpenConfig gRPC interface onto OpenTSDB notificatons coming out of the OpenConfig gRPC interface onto OpenTSDB
metric names, and how to extract tags from the path. For example, the metric names, and how to extract tags from the path.
following rule, excerpt from `sampleconfig.json`:
## Getting Started
To begin, a list of subscriptions is required (excerpt from `sampleconfig.json`):
```json
"subscriptions": [
"/Sysdb/interface/counter/eth/lag",
"/Sysdb/interface/counter/eth/slice/phy",
"/Sysdb/environment/temperature/status",
"/Sysdb/environment/cooling/status",
"/Sysdb/environment/power/status",
"/Sysdb/hardware/xcvr/status/all/xcvrStatus"
],
...
```
Note that subscriptions should not end with a trailing `/` as that will cause
the subscription to fail.
Afterwards, the metrics are defined (excerpt from `sampleconfig.json`):
```json ```json
"metrics": { "metrics": {
"tempSensor": { "tempSensor": {
"path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/((?:maxT|t)emperature)/value" "path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/((?:maxT|t)emperature)"
}, },
... ...
}
``` ```
Applied to an update for the path In the metrics path, unnamed matched groups are used to make up the metric name, and named matched groups
`/Sysdb/environment/temperature/status/tempSensor/TempSensor1/temperature/value` are used to extract optional tags. Note that unnamed groups are required, otherwise the metric
will lead to the metric name `environment.temperature` and tags `sensor=TempSensor1`. name will be empty and the update will be silently dropped.
Basically, un-named groups are used to make up the metric name, and named For example, using the above metrics path applied to an update for the path
groups are used to extract (optional) tags. `/Sysdb/environment/temperature/status/tempSensor/TempSensor1/temperature`
will lead to the metric name `environment.temperature` and tags `sensor=TempSensor1`.
## Usage ## Usage
See the `-help` output, but here's an example to push all the metrics defined See the `-help` output, but here's an example to push all the metrics defined
in the sample config file: in the sample config file:
``` ```
octsdb -addrs <switch-hostname>:6042 -config sampleconfig.json -text | nc <tsd-hostname> 4242 octsdb -addr <switch-hostname>:6042 -config sampleconfig.json -text | nc <tsd-hostname> 4242
``` ```

View File

@ -16,6 +16,9 @@ func TestConfig(t *testing.T) {
t.Fatal("Managed to load a nonexistent config!") t.Fatal("Managed to load a nonexistent config!")
} }
cfg, err = loadConfig("sampleconfig.json") cfg, err = loadConfig("sampleconfig.json")
if err != nil {
t.Fatal("Failed to load config:", err)
}
testcases := []struct { testcases := []struct {
path string path string

View File

@ -7,22 +7,35 @@ package main
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"flag" "flag"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/aristanetworks/goarista/openconfig/client" "github.com/aristanetworks/goarista/gnmi"
"github.com/aristanetworks/glog" "github.com/aristanetworks/glog"
"github.com/golang/protobuf/proto" pb "github.com/openconfig/gnmi/proto/gnmi"
"github.com/openconfig/reference/rpc/openconfig"
) )
func main() { func main() {
// gNMI options
cfg := &gnmi.Config{}
flag.StringVar(&cfg.Addr, "addr", "localhost", "gNMI gRPC server `address`")
flag.StringVar(&cfg.CAFile, "cafile", "", "Path to server TLS certificate file")
flag.StringVar(&cfg.CertFile, "certfile", "", "Path to client TLS certificate file")
flag.StringVar(&cfg.KeyFile, "keyfile", "", "Path to client TLS private key file")
flag.StringVar(&cfg.Username, "username", "", "Username to authenticate with")
flag.StringVar(&cfg.Password, "password", "", "Password to authenticate with")
flag.BoolVar(&cfg.TLS, "tls", false, "Enable TLS")
// Program options
subscribePaths := flag.String("paths", "/", "Comma-separated list of paths to subscribe to")
tsdbFlag := flag.String("tsdb", "", tsdbFlag := flag.String("tsdb", "",
"Address of the OpenTSDB server where to push telemetry to") "Address of the OpenTSDB server where to push telemetry to")
textFlag := flag.Bool("text", false, textFlag := flag.Bool("text", false,
@ -38,8 +51,8 @@ func main() {
" Clients and servers should have the same number.") " Clients and servers should have the same number.")
udpTimeoutFlag := flag.Duration("udptimeout", 2*time.Second, udpTimeoutFlag := flag.Duration("udptimeout", 2*time.Second,
"Timeout for each") "Timeout for each")
username, password, subscriptions, addrs, opts := client.ParseFlags()
flag.Parse()
if !(*tsdbFlag != "" || *textFlag || *udpAddrFlag != "") { if !(*tsdbFlag != "" || *textFlag || *udpAddrFlag != "") {
glog.Fatal("Specify the address of the OpenTSDB server to write to with -tsdb") glog.Fatal("Specify the address of the OpenTSDB server to write to with -tsdb")
} else if *configFlag == "" { } else if *configFlag == "" {
@ -52,6 +65,7 @@ func main() {
} }
// Ignore the default "subscribe-to-everything" subscription of the // Ignore the default "subscribe-to-everything" subscription of the
// -subscribe flag. // -subscribe flag.
subscriptions := strings.Split(*subscribePaths, ",")
if subscriptions[0] == "" { if subscriptions[0] == "" {
subscriptions = subscriptions[1:] subscriptions = subscriptions[1:]
} }
@ -79,33 +93,37 @@ func main() {
// TODO: support HTTP(S). // TODO: support HTTP(S).
c = newTelnetClient(*tsdbFlag) c = newTelnetClient(*tsdbFlag)
} }
ctx := gnmi.NewContext(context.Background(), cfg)
wg := new(sync.WaitGroup) client, err := gnmi.Dial(cfg)
for _, addr := range addrs { if err != nil {
wg.Add(1) glog.Fatal(err)
publish := func(addr string, message proto.Message) {
resp, ok := message.(*openconfig.SubscribeResponse)
if !ok {
glog.Errorf("Unexpected type of message: %T", message)
return
} }
if notif := resp.GetUpdate(); notif != nil { respChan := make(chan *pb.SubscribeResponse)
pushToOpenTSDB(addr, c, config, notif) errChan := make(chan error)
subscribeOptions := &gnmi.SubscribeOptions{
Mode: "stream",
StreamMode: "target_defined",
Paths: gnmi.SplitPaths(subscriptions),
}
go gnmi.Subscribe(ctx, client, subscribeOptions, respChan, errChan)
for {
select {
case resp := <-respChan:
pushToOpenTSDB(cfg.Addr, c, config, resp.GetUpdate())
case err := <-errChan:
glog.Fatal(err)
} }
} }
c := client.New(username, password, addr, opts)
go c.Subscribe(wg, subscriptions, publish)
}
wg.Wait()
} }
func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config, func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config, notif *pb.Notification) {
notif *openconfig.Notification) { if notif == nil {
glog.Error("Nil notification ignored")
return
}
if notif.Timestamp <= 0 { if notif.Timestamp <= 0 {
glog.Fatalf("Invalid timestamp %d in %s", notif.Timestamp, notif) glog.Fatalf("Invalid timestamp %d in %s", notif.Timestamp, notif)
} }
host := addr[:strings.IndexRune(addr, ':')] host := addr[:strings.IndexRune(addr, ':')]
if host == "localhost" { if host == "localhost" {
// TODO: On Linux this reads /proc/sys/kernel/hostname each time, // TODO: On Linux this reads /proc/sys/kernel/hostname each time,
@ -118,18 +136,15 @@ func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config,
} }
} }
prefix := "/" + strings.Join(notif.Prefix.Element, "/") prefix := "/" + strings.Join(notif.Prefix.Element, "/")
for _, update := range notif.Update { for _, update := range notif.Update {
if update.Value == nil || update.Value.Type != openconfig.Type_JSON { if update.Value == nil || update.Value.Type != pb.Encoding_JSON {
glog.V(9).Infof("Ignoring incompatible update value in %s", update) glog.V(9).Infof("Ignoring incompatible update value in %s", update)
continue continue
} }
value := parseValue(update) value := parseValue(update)
if value == nil { if value == nil {
continue continue
} }
path := prefix + "/" + strings.Join(update.Path.Element, "/") path := prefix + "/" + strings.Join(update.Path.Element, "/")
metricName, tags := config.Match(path) metricName, tags := config.Match(path)
if metricName == "" { if metricName == "" {
@ -137,7 +152,6 @@ func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config,
continue continue
} }
tags["host"] = host tags["host"] = host
for i, v := range value { for i, v := range value {
if len(value) > 1 { if len(value) > 1 {
tags["index"] = strconv.Itoa(i) tags["index"] = strconv.Itoa(i)
@ -158,7 +172,7 @@ func pushToOpenTSDB(addr string, conn OpenTSDBConn, config *Config,
// parseValue returns either an integer/floating point value of the given update, or if // parseValue returns either an integer/floating point value of the given update, or if
// the value is a slice of integers/floating point values. If the value is neither of these // the value is a slice of integers/floating point values. If the value is neither of these
// or if any element in the slice is non numerical, parseValue returns nil. // or if any element in the slice is non numerical, parseValue returns nil.
func parseValue(update *openconfig.Update) []interface{} { func parseValue(update *pb.Update) []interface{} {
var value interface{} var value interface{}
decoder := json.NewDecoder(bytes.NewReader(update.Value.Value)) decoder := json.NewDecoder(bytes.NewReader(update.Value.Value))
@ -196,7 +210,7 @@ func parseValue(update *openconfig.Update) []interface{} {
} }
// Convert our json.Number to either an int64, uint64, or float64. // Convert our json.Number to either an int64, uint64, or float64.
func parseNumber(num json.Number, update *openconfig.Update) interface{} { func parseNumber(num json.Number, update *pb.Update) interface{} {
var value interface{} var value interface{}
var err error var err error
if value, err = num.Int64(); err != nil { if value, err = num.Int64(); err != nil {

View File

@ -9,8 +9,7 @@ import (
"testing" "testing"
"github.com/aristanetworks/goarista/test" "github.com/aristanetworks/goarista/test"
pb "github.com/openconfig/gnmi/proto/gnmi"
"github.com/openconfig/reference/rpc/openconfig"
) )
func TestParseValue(t *testing.T) { // Because parsing JSON sucks. func TestParseValue(t *testing.T) { // Because parsing JSON sucks.
@ -35,8 +34,8 @@ func TestParseValue(t *testing.T) { // Because parsing JSON sucks.
}, },
} }
for i, tcase := range testcases { for i, tcase := range testcases {
actual := parseValue(&openconfig.Update{ actual := parseValue(&pb.Update{
Value: &openconfig.Value{ Value: &pb.Value{
Value: []byte(tcase.input), Value: []byte(tcase.input),
}, },
}) })

View File

@ -1,11 +1,14 @@
{ {
"comment": "This is a sample configuration", "comment": "This is a sample configuration for EOS versions below 4.20",
"subscriptions": [ "subscriptions": [
"/Sysdb/interface/counter/eth/lag",
"/Sysdb/interface/counter/eth/slice/phy",
"/Sysdb/environment/temperature/status",
"/Sysdb/environment/cooling/status", "/Sysdb/environment/cooling/status",
"/Sysdb/environment/power/status", "/Sysdb/environment/power/status",
"/Sysdb/environment/temperature/status",
"/Sysdb/interface/counter/eth/lag", "/Sysdb/hardware/xcvr/status/all/xcvrStatus"
"/Sysdb/interface/counter/eth/slice/phy"
], ],
"metricPrefix": "eos", "metricPrefix": "eos",
"metrics": { "metrics": {
@ -20,25 +23,32 @@
}, },
"tempSensor": { "tempSensor": {
"path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/((?:maxT|t)emperature)/value" "path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/((?:maxT|t)emperature)"
}, },
"tempSensorAlert": { "tempSensorAlert": {
"path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/(alertRaisedCount)" "path": "/Sysdb/(environment)/temperature/status/tempSensor/(?P<sensor>.+)/(alertRaisedCount)"
}, },
"currentSensor": { "currentSensor": {
"path": "/Sysdb/(environment)/power/status/currentSensor/(?P<sensor>.+)/(current)/value" "path": "/Sysdb/(environment)/power/status/currentSensor/(?P<sensor>.+)/(current)"
}, },
"powerSensor": { "powerSensor": {
"path": "/Sysdb/(environment/power)/status/powerSupply/(?P<sensor>.+)/(input|output)Power/value" "path": "/Sysdb/(environment/power)/status/powerSupply/(?P<sensor>.+)/(input|output)Power"
}, },
"voltageSensor": { "voltageSensor": {
"path": "/Sysdb/(environment)/power/status/voltageSensor/(?P<sensor>.+)/(voltage)/value" "path": "/Sysdb/(environment)/power/status/voltageSensor/(?P<sensor>.+)/(voltage)"
}, },
"railCurrentSensor": { "railCurrentSensor": {
"path": "/Sysdb/(environment)/power/status/voltageSensor/(?P<sensor>.+)/(current)/value" "path": "/Sysdb/(environment)/power/status/voltageSensor/(?P<sensor>.+)/(current)"
}, },
"fanSpeed": { "fanSpeed": {
"path": "/Sysdb/(environment)/cooling/status/(fan)/(?P<fan>.+)/(speed)/value" "path": "/Sysdb/(environment)/cooling/status/(fan)/(?P<fan>.+)/(speed)"
},
"qsfpRxPower": {
"path": "/Sysdb/hardware/(xcvr)/status/all/xcvrStatus/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)"
},
"sfpTemperature": {
"path": "/Sysdb/hardware/(xcvr)/status/all/xcvrStatus/(?P<intf>.+)/lastDomUpdateTime/(temperature)"
} }
} }
} }

View File

@ -0,0 +1,66 @@
{
"comment": "This is a sample configuration for EOS versions above 4.20",
"subscriptions": [
"/Smash/counters/ethIntf",
"/Smash/interface/counter/lag/current/counter",
"/Sysdb/environment/archer/cooling/status",
"/Sysdb/environment/archer/power/status",
"/Sysdb/environment/archer/temperature/status",
"/Sysdb/hardware/archer/xcvr/status"
],
"metricPrefix": "eos",
"metrics": {
"intfCounter": {
"path": "/Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(Octets|Errors|Discards)"
},
"intfLagCounter": {
"path": "/Smash/interface/counter/lag/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(Octets|Errors|Discards)"
},
"intfPktCounter": {
"path": "/Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(?P<type>(?:Ucast|Multicast|Broadcast))(Pkt)"
},
"intfLagPktCounter": {
"path": "/Smash/interface/counter/lag/current/(counter)/(?P<intf>.+)/statistics/(?P<direction>(?:in|out))(?P<type>(?:Ucast|Multicast|Broadcast))(Pkt)"
},
"intfPfcClassCounter": {
"path": "/Smash/counters/ethIntf/FocalPointV2/current/(counter)/(?P<intf>.+)/ethStatistics/(?P<direction>(?:in|out))(PfcClassFrames)"
},
"tempSensor": {
"path": "/Sysdb/(environment)/archer/temperature/status/(?P<sensor>.+)/((?:maxT|t)emperature)"
},
"tempSensorAlert": {
"path": "/Sysdb/(environment)/archer/temperature/status/(?P<sensor>.+)/(alertRaisedCount)"
},
"currentSensor": {
"path": "/Sysdb/(environment)/archer/power/status/currentSensor/(?P<sensor>.+)/(current)"
},
"powerSensor": {
"path": "/Sysdb/(environment)/archer/(power)/status/powerSupply/(?P<sensor>.+)/(input|output)Power"
},
"voltageSensor": {
"path": "/Sysdb/(environment)/archer/power/status/voltageSensor/(?:cell/.+|system)/(?P<sensor>.+)/(voltage)"
},
"railCurrentSensor": {
"path": "/Sysdb/(environment)/archer/power/status/voltageSensor/(?:cell/.+|system)/(?P<sensor>.+)/(current)"
},
"fanSpeed": {
"path": "/Sysdb/(environment)/archer/(cooling)/status/(?P<fan>.+)/speed"
},
"qsfpModularRxPower": {
"path": "/Sysdb/hardware/archer/(xcvr)/status/slice/(?P<linecard>.+)/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)"
},
"qsfpFixedRxPower": {
"path": "/Sysdb/hardware/archer/(xcvr)/status/all/(?P<intf>.+)/domRegisterData/lane(?P<lane>\\d)(OpticalRxPower)"
},
"sfpModularTemperature": {
"path": "/Sysdb/hardware/archer/(xcvr)/status/slice/(?P<linecard>.+)/(?P<intf>.+)/lastDomUpdateTime/(temperature)"
},
"sfpFixedTemperature": {
"path": "/Sysdb/hardware/archer/(xcvr)/status/all/(?P<intf>.+)/lastDomUpdateTime/(temperature)"
}
}
}

View File

@ -0,0 +1,286 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
// test2influxdb writes results from 'go test -json' to an influxdb
// database.
//
// Example usage:
//
// go test -json | test2influxdb [options...]
//
// Points are written to influxdb with tags:
//
// package
// type "package" for a package result; "test" for a test result
// Additional tags set by -tags flag
//
// And fields:
//
// test string // "NONE" for whole package results
// elapsed float64 // in seconds
// pass float64 // 1 for PASS, 0 for FAIL
// Additional fields set by -fields flag
//
// "test" is a field instead of a tag to reduce cardinality of data.
//
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"github.com/aristanetworks/glog"
client "github.com/influxdata/influxdb/client/v2"
)
type tag struct {
key string
value string
}
type tags []tag
func (ts *tags) String() string {
s := make([]string, len(*ts))
for i, t := range *ts {
s[i] = t.key + "=" + t.value
}
return strings.Join(s, ",")
}
func (ts *tags) Set(s string) error {
for _, fieldString := range strings.Split(s, ",") {
kv := strings.Split(fieldString, "=")
if len(kv) != 2 {
return fmt.Errorf("invalid tag, expecting one '=': %q", fieldString)
}
key := strings.TrimSpace(kv[0])
if key == "" {
return fmt.Errorf("invalid tag key %q in %q", key, fieldString)
}
val := strings.TrimSpace(kv[1])
if val == "" {
return fmt.Errorf("invalid tag value %q in %q", val, fieldString)
}
*ts = append(*ts, tag{key: key, value: val})
}
return nil
}
type field struct {
key string
value interface{}
}
type fields []field
func (fs *fields) String() string {
s := make([]string, len(*fs))
for i, f := range *fs {
var valString string
switch v := f.value.(type) {
case bool:
valString = strconv.FormatBool(v)
case float64:
valString = strconv.FormatFloat(v, 'f', -1, 64)
case int64:
valString = strconv.FormatInt(v, 10) + "i"
case string:
valString = v
}
s[i] = f.key + "=" + valString
}
return strings.Join(s, ",")
}
func (fs *fields) Set(s string) error {
for _, fieldString := range strings.Split(s, ",") {
kv := strings.Split(fieldString, "=")
if len(kv) != 2 {
return fmt.Errorf("invalid field, expecting one '=': %q", fieldString)
}
key := strings.TrimSpace(kv[0])
if key == "" {
return fmt.Errorf("invalid field key %q in %q", key, fieldString)
}
val := strings.TrimSpace(kv[1])
if val == "" {
return fmt.Errorf("invalid field value %q in %q", val, fieldString)
}
var value interface{}
var err error
if value, err = strconv.ParseBool(val); err == nil {
// It's a bool
} else if value, err = strconv.ParseFloat(val, 64); err == nil {
// It's a float64
} else if value, err = strconv.ParseInt(val[:len(val)-1], 0, 64); err == nil &&
val[len(val)-1] == 'i' {
// ints are suffixed with an "i"
} else {
value = val
}
*fs = append(*fs, field{key: key, value: value})
}
return nil
}
var (
flagAddr = flag.String("addr", "http://localhost:8086", "adddress of influxdb database")
flagDB = flag.String("db", "gotest", "use `database` in influxdb")
flagMeasurement = flag.String("m", "result", "`measurement` used in influxdb database")
flagTags tags
flagFields fields
)
func init() {
flag.Var(&flagTags, "tags", "set additional `tags`. Ex: name=alice,food=pasta")
flag.Var(&flagFields, "fields", "set additional `fields`. Ex: id=1234i,long=34.123,lat=72.234")
}
func main() {
flag.Parse()
c, err := client.NewHTTPClient(client.HTTPConfig{
Addr: *flagAddr,
})
if err != nil {
glog.Fatal(err)
}
batch, err := client.NewBatchPoints(client.BatchPointsConfig{Database: *flagDB})
if err != nil {
glog.Fatal(err)
}
if err := parseTestOutput(os.Stdin, batch); err != nil {
glog.Fatal(err)
}
if err := c.Write(batch); err != nil {
glog.Fatal(err)
}
}
// See https://golang.org/cmd/test2json/ for a description of 'go test
// -json' output
type testEvent struct {
Time time.Time // encodes as an RFC3339-format string
Action string
Package string
Test string
Elapsed float64 // seconds
Output string
}
func createTags(e *testEvent) map[string]string {
tags := make(map[string]string, len(flagTags)+2)
for _, t := range flagTags {
tags[t.key] = t.value
}
resultType := "test"
if e.Test == "" {
resultType = "package"
}
tags["package"] = e.Package
tags["type"] = resultType
return tags
}
func createFields(e *testEvent) map[string]interface{} {
fields := make(map[string]interface{}, len(flagFields)+3)
for _, f := range flagFields {
fields[f.key] = f.value
}
// Use a float64 instead of a bool to be able to SUM test
// successes in influxdb.
var pass float64
if e.Action == "pass" {
pass = 1
}
fields["pass"] = pass
fields["elapsed"] = e.Elapsed
if e.Test != "" {
fields["test"] = e.Test
}
return fields
}
func parseTestOutput(r io.Reader, batch client.BatchPoints) error {
// pkgs holds packages seen in r. Unfortunately, if a test panics,
// then there is no "fail" result from a package. To detect these
// kind of failures, keep track of all the packages that never had
// a "pass" or "fail".
//
// The last seen timestamp is stored with the package, so that
// package result measurement written to influxdb can be later
// than any test result for that package.
pkgs := make(map[string]time.Time)
d := json.NewDecoder(r)
for {
e := &testEvent{}
if err := d.Decode(e); err != nil {
if err != io.EOF {
return err
}
break
}
switch e.Action {
case "pass", "fail":
default:
continue
}
if e.Test == "" {
// A package has completed.
delete(pkgs, e.Package)
} else {
pkgs[e.Package] = e.Time
}
point, err := client.NewPoint(
*flagMeasurement,
createTags(e),
createFields(e),
e.Time,
)
if err != nil {
return err
}
batch.AddPoint(point)
}
for pkg, t := range pkgs {
pkgFail := &testEvent{
Action: "fail",
Package: pkg,
}
point, err := client.NewPoint(
*flagMeasurement,
createTags(pkgFail),
createFields(pkgFail),
// Fake a timestamp that is later than anything that
// occurred for this package
t.Add(time.Millisecond),
)
if err != nil {
return err
}
batch.AddPoint(point)
}
return nil
}

View File

@ -0,0 +1,151 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package main
import (
"os"
"testing"
"time"
"github.com/aristanetworks/goarista/test"
"github.com/influxdata/influxdb/client/v2"
)
func newPoint(t *testing.T, measurement string, tags map[string]string,
fields map[string]interface{}, timeString string) *client.Point {
t.Helper()
timestamp, err := time.Parse(time.RFC3339Nano, timeString)
if err != nil {
t.Fatal(err)
}
p, err := client.NewPoint(measurement, tags, fields, timestamp)
if err != nil {
t.Fatal(err)
}
return p
}
func TestParseTestOutput(t *testing.T) {
// Verify tags and fields set by flags are set in records
flagTags.Set("tag=foo")
flagFields.Set("field=true")
defer func() {
flagTags = nil
flagFields = nil
}()
f, err := os.Open("testdata/output.txt")
if err != nil {
t.Fatal(err)
}
makeTags := func(pkg, resultType string) map[string]string {
return map[string]string{"package": pkg, "type": resultType, "tag": "foo"}
}
makeFields := func(pass, elapsed float64, test string) map[string]interface{} {
m := map[string]interface{}{"pass": pass, "elapsed": elapsed, "field": true}
if test != "" {
m["test"] = test
}
return m
}
expected := []*client.Point{
newPoint(t,
"result",
makeTags("pkg/passed", "test"),
makeFields(1, 0, "TestPass"),
"2018-03-08T10:33:12.344165231-08:00",
),
newPoint(t,
"result",
makeTags("pkg/passed", "package"),
makeFields(1, 0.013, ""),
"2018-03-08T10:33:12.34533033-08:00",
),
newPoint(t,
"result",
makeTags("pkg/panic", "test"),
makeFields(0, 600.029, "TestPanic"),
"2018-03-08T10:33:20.272440286-08:00",
),
newPoint(t,
"result",
makeTags("pkg/failed", "test"),
makeFields(0, 0.18, "TestFail"),
"2018-03-08T10:33:27.158860934-08:00",
),
newPoint(t,
"result",
makeTags("pkg/failed", "package"),
makeFields(0, 0.204, ""),
"2018-03-08T10:33:27.161302093-08:00",
),
newPoint(t,
"result",
makeTags("pkg/panic", "package"),
makeFields(0, 0, ""),
"2018-03-08T10:33:20.273440286-08:00",
),
}
batch, err := client.NewBatchPoints(client.BatchPointsConfig{})
if err != nil {
t.Fatal(err)
}
if err := parseTestOutput(f, batch); err != nil {
t.Fatal(err)
}
if diff := test.Diff(expected, batch.Points()); diff != "" {
t.Errorf("unexpected diff: %s", diff)
}
}
func TestTagsFlag(t *testing.T) {
for tc, expected := range map[string]tags{
"abc=def": tags{tag{key: "abc", value: "def"}},
"abc=def,ghi=klm": tags{tag{key: "abc", value: "def"}, tag{key: "ghi", value: "klm"}},
} {
t.Run(tc, func(t *testing.T) {
var ts tags
ts.Set(tc)
if diff := test.Diff(expected, ts); diff != "" {
t.Errorf("unexpected diff from Set: %s", diff)
}
if s := ts.String(); s != tc {
t.Errorf("unexpected diff from String: %q vs. %q", tc, s)
}
})
}
}
func TestFieldsFlag(t *testing.T) {
for tc, expected := range map[string]fields{
"str=abc": fields{field{key: "str", value: "abc"}},
"bool=true": fields{field{key: "bool", value: true}},
"bool=false": fields{field{key: "bool", value: false}},
"float64=42": fields{field{key: "float64", value: float64(42)}},
"float64=42.123": fields{field{key: "float64", value: float64(42.123)}},
"int64=42i": fields{field{key: "int64", value: int64(42)}},
"str=abc,bool=true,float64=42,int64=42i": fields{field{key: "str", value: "abc"},
field{key: "bool", value: true},
field{key: "float64", value: float64(42)},
field{key: "int64", value: int64(42)}},
} {
t.Run(tc, func(t *testing.T) {
var fs fields
fs.Set(tc)
if diff := test.Diff(expected, fs); diff != "" {
t.Errorf("unexpected diff from Set: %s", diff)
}
if s := fs.String(); s != tc {
t.Errorf("unexpected diff from String: %q vs. %q", tc, s)
}
})
}
}

View File

@ -0,0 +1,15 @@
{"Time":"2018-03-08T10:33:12.002692769-08:00","Action":"output","Package":"pkg/skipped","Output":"? \tpkg/skipped\t[no test files]\n"}
{"Time":"2018-03-08T10:33:12.003199228-08:00","Action":"skip","Package":"pkg/skipped","Elapsed":0.001}
{"Time":"2018-03-08T10:33:12.343866281-08:00","Action":"run","Package":"pkg/passed","Test":"TestPass"}
{"Time":"2018-03-08T10:33:12.34406622-08:00","Action":"output","Package":"pkg/passed","Test":"TestPass","Output":"=== RUN TestPass\n"}
{"Time":"2018-03-08T10:33:12.344139342-08:00","Action":"output","Package":"pkg/passed","Test":"TestPass","Output":"--- PASS: TestPass (0.00s)\n"}
{"Time":"2018-03-08T10:33:12.344165231-08:00","Action":"pass","Package":"pkg/passed","Test":"TestPass","Elapsed":0}
{"Time":"2018-03-08T10:33:12.344297059-08:00","Action":"output","Package":"pkg/passed","Output":"PASS\n"}
{"Time":"2018-03-08T10:33:12.345217622-08:00","Action":"output","Package":"pkg/passed","Output":"ok \tpkg/passed\t0.013s\n"}
{"Time":"2018-03-08T10:33:12.34533033-08:00","Action":"pass","Package":"pkg/passed","Elapsed":0.013}
{"Time":"2018-03-08T10:33:20.27231537-08:00","Action":"output","Package":"pkg/panic","Test":"TestPanic","Output":"panic\n"}
{"Time":"2018-03-08T10:33:20.272414481-08:00","Action":"output","Package":"pkg/panic","Test":"TestPanic","Output":"FAIL\tpkg/panic\t600.029s\n"}
{"Time":"2018-03-08T10:33:20.272440286-08:00","Action":"fail","Package":"pkg/panic","Test":"TestPanic","Elapsed":600.029}
{"Time":"2018-03-08T10:33:27.158776469-08:00","Action":"output","Package":"pkg/failed","Test":"TestFail","Output":"--- FAIL: TestFail (0.18s)\n"}
{"Time":"2018-03-08T10:33:27.158860934-08:00","Action":"fail","Package":"pkg/failed","Test":"TestFail","Elapsed":0.18}
{"Time":"2018-03-08T10:33:27.161302093-08:00","Action":"fail","Package":"pkg/failed","Elapsed":0.204}

View File

@ -8,7 +8,6 @@ package dscp
import ( import (
"fmt" "fmt"
"net" "net"
"reflect"
"time" "time"
) )
@ -20,8 +19,7 @@ func DialTCPWithTOS(laddr, raddr *net.TCPAddr, tos byte) (*net.TCPConn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
value := reflect.ValueOf(conn) if err = setTOS(raddr.IP, conn, tos); err != nil {
if err = setTOS(raddr.IP, value, tos); err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }
@ -54,7 +52,24 @@ func DialTimeoutWithTOS(network, address string, timeout time.Duration, tos byte
conn.Close() conn.Close()
return nil, fmt.Errorf("DialTimeoutWithTOS: cannot set TOS on a %s socket", network) return nil, fmt.Errorf("DialTimeoutWithTOS: cannot set TOS on a %s socket", network)
} }
if err = setTOS(ip, reflect.ValueOf(conn), tos); err != nil { if err = setTOS(ip, conn, tos); err != nil {
conn.Close()
return nil, err
}
return conn, err
}
// DialTCPTimeoutWithTOS is same as DialTimeoutWithTOS except for enforcing "tcp" and
// providing an option to specify local address (source)
func DialTCPTimeoutWithTOS(laddr, raddr *net.TCPAddr, tos byte, timeout time.Duration) (net.Conn,
error) {
d := net.Dialer{Timeout: timeout, LocalAddr: laddr}
conn, err := d.Dial("tcp", raddr.String())
if err != nil {
return nil, err
}
if err = setTOS(raddr.IP, conn, tos); err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }

View File

@ -5,8 +5,11 @@
package dscp_test package dscp_test
import ( import (
"fmt"
"net" "net"
"strings"
"testing" "testing"
"time"
"github.com/aristanetworks/goarista/dscp" "github.com/aristanetworks/goarista/dscp"
) )
@ -51,3 +54,50 @@ func TestDialTCPWithTOS(t *testing.T) {
conn.Write(buf) conn.Write(buf)
<-done <-done
} }
func TestDialTCPTimeoutWithTOS(t *testing.T) {
raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
for name, td := range map[string]*net.TCPAddr{
"ipNoPort": &net.TCPAddr{
IP: net.ParseIP("127.0.0.42"), Port: 0,
},
"ipWithPort": &net.TCPAddr{
IP: net.ParseIP("127.0.0.42"), Port: 10001,
},
} {
t.Run(name, func(t *testing.T) {
l, err := net.ListenTCP("tcp", raddr)
if err != nil {
t.Fatal(err)
}
defer l.Close()
var srcAddr net.Addr
done := make(chan struct{})
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
srcAddr = conn.RemoteAddr()
close(done)
}()
conn, err := dscp.DialTCPTimeoutWithTOS(td, l.Addr().(*net.TCPAddr), 40, 5*time.Second)
if err != nil {
t.Fatal("Connection failed:", err)
}
defer conn.Close()
pfx := td.IP.String() + ":"
if td.Port > 0 {
pfx = fmt.Sprintf("%s%d", pfx, td.Port)
}
<-done
if !strings.HasPrefix(srcAddr.String(), pfx) {
t.Fatalf("DialTCPTimeoutWithTOS wrong address: %q instead of %q", srcAddr, pfx)
}
})
}
}

View File

@ -7,7 +7,6 @@ package dscp
import ( import (
"net" "net"
"reflect"
) )
// ListenTCPWithTOS is similar to net.ListenTCP but with the socket configured // ListenTCPWithTOS is similar to net.ListenTCP but with the socket configured
@ -18,8 +17,7 @@ func ListenTCPWithTOS(address *net.TCPAddr, tos byte) (*net.TCPListener, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
value := reflect.ValueOf(lsnr) if err = setTOS(address.IP, lsnr, tos); err != nil {
if err = setTOS(address.IP, value, tos); err != nil {
lsnr.Close() lsnr.Close()
return nil, err return nil, err
} }

View File

@ -5,19 +5,18 @@
package dscp package dscp
import ( import (
"fmt"
"net" "net"
"os" "os"
"reflect" "reflect"
"syscall"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// This works for the UNIX implementation of netFD, i.e. not on Windows and Plan9. // This works for the UNIX implementation of netFD, i.e. not on Windows and Plan9.
// This kludge is needed until https://github.com/golang/go/issues/9661 is fixed. // conn must either implement syscall.Conn or be a TCPListener.
// value can be the reflection of a connection or a dialer. func setTOS(ip net.IP, conn interface{}, tos byte) error {
func setTOS(ip net.IP, value reflect.Value, tos byte) error {
netFD := value.Elem().FieldByName("fd").Elem()
fd := int(netFD.FieldByName("pfd").FieldByName("Sysfd").Int())
var proto, optname int var proto, optname int
if ip.To4() != nil { if ip.To4() != nil {
proto = unix.IPPROTO_IP proto = unix.IPPROTO_IP
@ -26,8 +25,42 @@ func setTOS(ip net.IP, value reflect.Value, tos byte) error {
proto = unix.IPPROTO_IPV6 proto = unix.IPPROTO_IPV6
optname = unix.IPV6_TCLASS optname = unix.IPV6_TCLASS
} }
switch c := conn.(type) {
case syscall.Conn:
return setTOSWithSyscallConn(proto, optname, c, tos)
case *net.TCPListener:
// This code is needed to support go1.9. In go1.10
// *net.TCPListener implements syscall.Conn.
return setTOSWithTCPListener(proto, optname, c, tos)
}
return fmt.Errorf("unsupported connection type: %T", conn)
}
func setTOSWithTCPListener(proto, optname int, conn *net.TCPListener, tos byte) error {
// A kludge for pre-go1.10 to get the fd of a net.TCPListener
value := reflect.ValueOf(conn)
netFD := value.Elem().FieldByName("fd").Elem()
fd := int(netFD.FieldByName("pfd").FieldByName("Sysfd").Int())
if err := unix.SetsockoptInt(fd, proto, optname, int(tos)); err != nil { if err := unix.SetsockoptInt(fd, proto, optname, int(tos)); err != nil {
return os.NewSyscallError("setsockopt", err) return os.NewSyscallError("setsockopt", err)
} }
return nil return nil
} }
func setTOSWithSyscallConn(proto, optname int, conn syscall.Conn, tos byte) error {
syscallConn, err := conn.SyscallConn()
if err != nil {
return err
}
var setsockoptErr error
err = syscallConn.Control(func(fd uintptr) {
if err := unix.SetsockoptInt(int(fd), proto, optname, int(tos)); err != nil {
setsockoptErr = os.NewSyscallError("setsockopt", err)
}
})
if setsockoptErr != nil {
return setsockoptErr
}
return err
}

View File

@ -8,10 +8,15 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt"
"math"
"net"
"time"
"io/ioutil" "io/ioutil"
"strings" "strings"
"github.com/aristanetworks/glog" "github.com/aristanetworks/goarista/netns"
pb "github.com/openconfig/gnmi/proto/gnmi" pb "github.com/openconfig/gnmi/proto/gnmi"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
@ -19,7 +24,7 @@ import (
) )
const ( const (
defaultPort = "6042" defaultPort = "6030"
) )
// Config is the gnmi.Client config // Config is the gnmi.Client config
@ -33,19 +38,30 @@ type Config struct {
TLS bool TLS bool
} }
// SubscribeOptions is the gNMI subscription request options
type SubscribeOptions struct {
UpdatesOnly bool
Prefix string
Mode string
StreamMode string
SampleInterval uint64
HeartbeatInterval uint64
Paths [][]string
}
// Dial connects to a gnmi service and returns a client // Dial connects to a gnmi service and returns a client
func Dial(cfg *Config) pb.GNMIClient { func Dial(cfg *Config) (pb.GNMIClient, error) {
var opts []grpc.DialOption var opts []grpc.DialOption
if cfg.TLS || cfg.CAFile != "" || cfg.CertFile != "" { if cfg.TLS || cfg.CAFile != "" || cfg.CertFile != "" {
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
if cfg.CAFile != "" { if cfg.CAFile != "" {
b, err := ioutil.ReadFile(cfg.CAFile) b, err := ioutil.ReadFile(cfg.CAFile)
if err != nil { if err != nil {
glog.Fatal(err) return nil, err
} }
cp := x509.NewCertPool() cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(b) { if !cp.AppendCertsFromPEM(b) {
glog.Fatalf("credentials: failed to append certificates") return nil, fmt.Errorf("credentials: failed to append certificates")
} }
tlsConfig.RootCAs = cp tlsConfig.RootCAs = cp
} else { } else {
@ -53,11 +69,11 @@ func Dial(cfg *Config) pb.GNMIClient {
} }
if cfg.CertFile != "" { if cfg.CertFile != "" {
if cfg.KeyFile == "" { if cfg.KeyFile == "" {
glog.Fatalf("Please provide both -certfile and -keyfile") return nil, fmt.Errorf("please provide both -certfile and -keyfile")
} }
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
if err != nil { if err != nil {
glog.Fatal(err) return nil, err
} }
tlsConfig.Certificates = []tls.Certificate{cert} tlsConfig.Certificates = []tls.Certificate{cert}
} }
@ -69,12 +85,33 @@ func Dial(cfg *Config) pb.GNMIClient {
if !strings.ContainsRune(cfg.Addr, ':') { if !strings.ContainsRune(cfg.Addr, ':') {
cfg.Addr += ":" + defaultPort cfg.Addr += ":" + defaultPort
} }
conn, err := grpc.Dial(cfg.Addr, opts...)
dial := func(addrIn string, time time.Duration) (net.Conn, error) {
var conn net.Conn
nsName, addr, err := netns.ParseAddress(addrIn)
if err != nil { if err != nil {
glog.Fatalf("Failed to dial: %s", err) return nil, err
} }
return pb.NewGNMIClient(conn) err = netns.Do(nsName, func() error {
var err error
conn, err = net.Dial("tcp", addr)
return err
})
return conn, err
}
opts = append(opts,
grpc.WithDialer(dial),
// Allows received protobuf messages to be larger than 4MB
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
)
grpcconn, err := grpc.Dial(cfg.Addr, opts...)
if err != nil {
return nil, fmt.Errorf("failed to dial: %s", err)
}
return pb.NewGNMIClient(grpcconn), nil
} }
// NewContext returns a new context with username and password // NewContext returns a new context with username and password
@ -104,17 +141,53 @@ func NewGetRequest(paths [][]string) (*pb.GetRequest, error) {
} }
// NewSubscribeRequest returns a SubscribeRequest for the given paths // NewSubscribeRequest returns a SubscribeRequest for the given paths
func NewSubscribeRequest(paths [][]string) (*pb.SubscribeRequest, error) { func NewSubscribeRequest(subscribeOptions *SubscribeOptions) (*pb.SubscribeRequest, error) {
subList := &pb.SubscriptionList{ var mode pb.SubscriptionList_Mode
Subscription: make([]*pb.Subscription, len(paths)), switch subscribeOptions.Mode {
case "once":
mode = pb.SubscriptionList_ONCE
case "poll":
mode = pb.SubscriptionList_POLL
case "stream":
mode = pb.SubscriptionList_STREAM
default:
return nil, fmt.Errorf("subscribe mode (%s) invalid", subscribeOptions.Mode)
} }
for i, p := range paths {
var streamMode pb.SubscriptionMode
switch subscribeOptions.StreamMode {
case "on_change":
streamMode = pb.SubscriptionMode_ON_CHANGE
case "sample":
streamMode = pb.SubscriptionMode_SAMPLE
case "target_defined":
streamMode = pb.SubscriptionMode_TARGET_DEFINED
default:
return nil, fmt.Errorf("subscribe stream mode (%s) invalid", subscribeOptions.StreamMode)
}
prefixPath, err := ParseGNMIElements(SplitPath(subscribeOptions.Prefix))
if err != nil {
return nil, err
}
subList := &pb.SubscriptionList{
Subscription: make([]*pb.Subscription, len(subscribeOptions.Paths)),
Mode: mode,
UpdatesOnly: subscribeOptions.UpdatesOnly,
Prefix: prefixPath,
}
for i, p := range subscribeOptions.Paths {
gnmiPath, err := ParseGNMIElements(p) gnmiPath, err := ParseGNMIElements(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
subList.Subscription[i] = &pb.Subscription{Path: gnmiPath} subList.Subscription[i] = &pb.Subscription{
Path: gnmiPath,
Mode: streamMode,
SampleInterval: subscribeOptions.SampleInterval,
HeartbeatInterval: subscribeOptions.HeartbeatInterval,
} }
return &pb.SubscribeRequest{ }
Request: &pb.SubscribeRequest_Subscribe{Subscribe: subList}}, nil return &pb.SubscribeRequest{Request: &pb.SubscribeRequest_Subscribe{
Subscribe: subList}}, nil
} }

View File

@ -17,7 +17,7 @@ func NotificationToMap(notif *gnmi.Notification) (map[string]interface{}, error)
updates := make(map[string]interface{}, len(notif.Update)) updates := make(map[string]interface{}, len(notif.Update))
var err error var err error
for _, update := range notif.Update { for _, update := range notif.Update {
updates[StrPath(update.Path)] = strUpdateVal(update) updates[StrPath(update.Path)] = StrUpdateVal(update)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -5,14 +5,22 @@
package gnmi package gnmi
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os"
"path" "path"
"strconv"
"strings"
"time"
"github.com/aristanetworks/glog"
pb "github.com/openconfig/gnmi/proto/gnmi" pb "github.com/openconfig/gnmi/proto/gnmi"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
) )
@ -28,9 +36,10 @@ func Get(ctx context.Context, client pb.GNMIClient, paths [][]string) error {
return err return err
} }
for _, notif := range resp.Notification { for _, notif := range resp.Notification {
prefix := StrPath(notif.Prefix)
for _, update := range notif.Update { for _, update := range notif.Update {
fmt.Printf("%s:\n", StrPath(update.Path)) fmt.Printf("%s:\n", path.Join(prefix, StrPath(update.Path)))
fmt.Println(strUpdateVal(update)) fmt.Println(StrUpdateVal(update))
} }
} }
return nil return nil
@ -55,49 +64,104 @@ func Capabilities(ctx context.Context, client pb.GNMIClient) error {
// val may be a path to a file or it may be json. First see if it is a // val may be a path to a file or it may be json. First see if it is a
// file, if so return its contents, otherwise return val // file, if so return its contents, otherwise return val
func extractJSON(val string) []byte { func extractJSON(val string) []byte {
jsonBytes, err := ioutil.ReadFile(val) if jsonBytes, err := ioutil.ReadFile(val); err == nil {
if err != nil {
jsonBytes = []byte(val)
}
return jsonBytes return jsonBytes
}
// strUpdateVal will return a string representing the value within the supplied update
func strUpdateVal(u *pb.Update) string {
if u.Value != nil {
return string(u.Value.Value) // Backwards compatibility with pre-v0.4 gnmi
} }
return strVal(u.Val) // Best effort check if the value might a string literal, in which
// case wrap it in quotes. This is to allow a user to do:
// gnmi update ../hostname host1234
// gnmi update ../description 'This is a description'
// instead of forcing them to quote the string:
// gnmi update ../hostname '"host1234"'
// gnmi update ../description '"This is a description"'
maybeUnquotedStringLiteral := func(s string) bool {
if s == "true" || s == "false" || s == "null" || // JSON reserved words
strings.ContainsAny(s, `"'{}[]`) { // Already quoted or is a JSON object or array
return false
} else if _, err := strconv.ParseInt(s, 0, 32); err == nil {
// Integer. Using byte size of 32 because larger integer
// types are supposed to be sent as strings in JSON.
return false
} else if _, err := strconv.ParseFloat(s, 64); err == nil {
// Float
return false
}
return true
}
if maybeUnquotedStringLiteral(val) {
out := make([]byte, len(val)+2)
out[0] = '"'
copy(out[1:], val)
out[len(out)-1] = '"'
return out
}
return []byte(val)
} }
// strVal will return a string representing the supplied value // StrUpdateVal will return a string representing the value within the supplied update
func strVal(val *pb.TypedValue) string { func StrUpdateVal(u *pb.Update) string {
if u.Value != nil {
// Backwards compatibility with pre-v0.4 gnmi
switch u.Value.Type {
case pb.Encoding_JSON, pb.Encoding_JSON_IETF:
return strJSON(u.Value.Value)
case pb.Encoding_BYTES, pb.Encoding_PROTO:
return base64.StdEncoding.EncodeToString(u.Value.Value)
case pb.Encoding_ASCII:
return string(u.Value.Value)
default:
return string(u.Value.Value)
}
}
return StrVal(u.Val)
}
// StrVal will return a string representing the supplied value
func StrVal(val *pb.TypedValue) string {
switch v := val.GetValue().(type) { switch v := val.GetValue().(type) {
case *pb.TypedValue_StringVal: case *pb.TypedValue_StringVal:
return v.StringVal return v.StringVal
case *pb.TypedValue_JsonIetfVal: case *pb.TypedValue_JsonIetfVal:
return string(v.JsonIetfVal) return strJSON(v.JsonIetfVal)
case *pb.TypedValue_JsonVal:
return strJSON(v.JsonVal)
case *pb.TypedValue_IntVal: case *pb.TypedValue_IntVal:
return fmt.Sprintf("%v", v.IntVal) return strconv.FormatInt(v.IntVal, 10)
case *pb.TypedValue_UintVal: case *pb.TypedValue_UintVal:
return fmt.Sprintf("%v", v.UintVal) return strconv.FormatUint(v.UintVal, 10)
case *pb.TypedValue_BoolVal: case *pb.TypedValue_BoolVal:
return fmt.Sprintf("%v", v.BoolVal) return strconv.FormatBool(v.BoolVal)
case *pb.TypedValue_BytesVal: case *pb.TypedValue_BytesVal:
return string(v.BytesVal) return base64.StdEncoding.EncodeToString(v.BytesVal)
case *pb.TypedValue_DecimalVal: case *pb.TypedValue_DecimalVal:
return strDecimal64(v.DecimalVal) return strDecimal64(v.DecimalVal)
case *pb.TypedValue_FloatVal:
return strconv.FormatFloat(float64(v.FloatVal), 'g', -1, 32)
case *pb.TypedValue_LeaflistVal: case *pb.TypedValue_LeaflistVal:
return strLeaflist(v.LeaflistVal) return strLeaflist(v.LeaflistVal)
case *pb.TypedValue_AsciiVal:
return v.AsciiVal
case *pb.TypedValue_AnyVal:
return v.AnyVal.String()
default: default:
panic(v) panic(v)
} }
} }
func strJSON(inJSON []byte) string {
var out bytes.Buffer
err := json.Indent(&out, inJSON, "", " ")
if err != nil {
return fmt.Sprintf("(error unmarshalling json: %s)\n", err) + string(inJSON)
}
return out.String()
}
func strDecimal64(d *pb.Decimal64) string { func strDecimal64(d *pb.Decimal64) string {
var i, frac uint64 var i, frac int64
if d.Precision > 0 { if d.Precision > 0 {
div := uint64(10) div := int64(10)
it := d.Precision - 1 it := d.Precision - 1
for it > 0 { for it > 0 {
div *= 10 div *= 10
@ -108,32 +172,25 @@ func strDecimal64(d *pb.Decimal64) string {
} else { } else {
i = d.Digits i = d.Digits
} }
if frac < 0 {
frac = -frac
}
return fmt.Sprintf("%d.%d", i, frac) return fmt.Sprintf("%d.%d", i, frac)
} }
// strLeafList builds a human-readable form of a leaf-list. e.g. [1,2,3] or [a,b,c] // strLeafList builds a human-readable form of a leaf-list. e.g. [1, 2, 3] or [a, b, c]
func strLeaflist(v *pb.ScalarArray) string { func strLeaflist(v *pb.ScalarArray) string {
s := make([]string, 0, len(v.Element)) var buf bytes.Buffer
sz := 2 // [] buf.WriteByte('[')
// convert arbitrary TypedValues to string form for i, elm := range v.Element {
for _, elm := range v.Element { buf.WriteString(StrVal(elm))
str := strVal(elm)
s = append(s, str)
sz += len(str) + 1 // %v + ,
}
b := make([]byte, sz)
buf := bytes.NewBuffer(b)
buf.WriteRune('[')
for i := range v.Element {
buf.WriteString(s[i])
if i < len(v.Element)-1 { if i < len(v.Element)-1 {
buf.WriteRune(',') buf.WriteString(", ")
} }
} }
buf.WriteRune(']')
buf.WriteByte(']')
return buf.String() return buf.String()
} }
@ -143,9 +200,16 @@ func update(p *pb.Path, val string) *pb.Update {
case "": case "":
v = &pb.TypedValue{ v = &pb.TypedValue{
Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: extractJSON(val)}} Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: extractJSON(val)}}
case "cli": case "cli", "test-regen-cli":
v = &pb.TypedValue{ v = &pb.TypedValue{
Value: &pb.TypedValue_AsciiVal{AsciiVal: val}} Value: &pb.TypedValue_AsciiVal{AsciiVal: val}}
case "p4_config":
b, err := ioutil.ReadFile(val)
if err != nil {
glog.Fatalf("Cannot read p4 file: %s", err)
}
v = &pb.TypedValue{
Value: &pb.TypedValue_ProtoBytes{ProtoBytes: b}}
default: default:
panic(fmt.Errorf("unexpected origin: %q", p.Origin)) panic(fmt.Errorf("unexpected origin: %q", p.Origin))
} }
@ -156,6 +220,7 @@ func update(p *pb.Path, val string) *pb.Update {
// Operation describes an gNMI operation. // Operation describes an gNMI operation.
type Operation struct { type Operation struct {
Type string Type string
Origin string
Path []string Path []string
Val string Val string
} }
@ -167,6 +232,7 @@ func newSetRequest(setOps []*Operation) (*pb.SetRequest, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.Origin = op.Origin
switch op.Type { switch op.Type {
case "delete": case "delete":
@ -199,16 +265,18 @@ func Set(ctx context.Context, client pb.GNMIClient, setOps []*Operation) error {
} }
// Subscribe sends a SubscribeRequest to the given client. // Subscribe sends a SubscribeRequest to the given client.
func Subscribe(ctx context.Context, client pb.GNMIClient, paths [][]string, func Subscribe(ctx context.Context, client pb.GNMIClient, subscribeOptions *SubscribeOptions,
respChan chan<- *pb.SubscribeResponse, errChan chan<- error) { respChan chan<- *pb.SubscribeResponse, errChan chan<- error) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
defer close(respChan)
stream, err := client.Subscribe(ctx) stream, err := client.Subscribe(ctx)
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
} }
req, err := NewSubscribeRequest(paths) req, err := NewSubscribeRequest(subscribeOptions)
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
@ -228,6 +296,26 @@ func Subscribe(ctx context.Context, client pb.GNMIClient, paths [][]string,
return return
} }
respChan <- resp respChan <- resp
// For POLL subscriptions, initiate a poll request by pressing ENTER
if subscribeOptions.Mode == "poll" {
switch resp.Response.(type) {
case *pb.SubscribeResponse_SyncResponse:
fmt.Print("Press ENTER to send a poll request: ")
reader := bufio.NewReader(os.Stdin)
reader.ReadString('\n')
pollReq := &pb.SubscribeRequest{
Request: &pb.SubscribeRequest_Poll{
Poll: &pb.Poll{},
},
}
if err := stream.Send(pollReq); err != nil {
errChan <- err
return
}
}
}
} }
} }
@ -241,10 +329,16 @@ func LogSubscribeResponse(response *pb.SubscribeResponse) error {
return errors.New("initial sync failed") return errors.New("initial sync failed")
} }
case *pb.SubscribeResponse_Update: case *pb.SubscribeResponse_Update:
t := time.Unix(0, resp.Update.Timestamp).UTC()
prefix := StrPath(resp.Update.Prefix) prefix := StrPath(resp.Update.Prefix)
for _, update := range resp.Update.Update { for _, update := range resp.Update.Update {
fmt.Printf("%s = %s\n", path.Join(prefix, StrPath(update.Path)), fmt.Printf("[%s] %s = %s\n", t.Format(time.RFC3339Nano),
strUpdateVal(update)) path.Join(prefix, StrPath(update.Path)),
StrUpdateVal(update))
}
for _, del := range resp.Update.Delete {
fmt.Printf("[%s] Deleted %s\n", t.Format(time.RFC3339Nano),
path.Join(prefix, StrPath(del)))
} }
} }
return nil return nil

View File

@ -5,9 +5,14 @@
package gnmi package gnmi
import ( import (
"bytes"
"io/ioutil"
"os"
"testing" "testing"
"github.com/aristanetworks/goarista/test" "github.com/aristanetworks/goarista/test"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes/any"
pb "github.com/openconfig/gnmi/proto/gnmi" pb "github.com/openconfig/gnmi/proto/gnmi"
) )
@ -15,24 +20,41 @@ import (
func TestNewSetRequest(t *testing.T) { func TestNewSetRequest(t *testing.T) {
pathFoo := &pb.Path{ pathFoo := &pb.Path{
Element: []string{"foo"}, Element: []string{"foo"},
Elem: []*pb.PathElem{&pb.PathElem{Name: "foo"}}, Elem: []*pb.PathElem{{Name: "foo"}},
} }
pathCli := &pb.Path{ pathCli := &pb.Path{
Origin: "cli", Origin: "cli",
} }
pathP4 := &pb.Path{
Origin: "p4_config",
}
p4FileContent := "p4_config test"
p4TestFile, err := ioutil.TempFile("", "p4TestFile")
if err != nil {
t.Errorf("cannot create test file for p4_config")
}
p4Filename := p4TestFile.Name()
defer os.Remove(p4Filename)
if _, err := p4TestFile.WriteString(p4FileContent); err != nil {
t.Errorf("cannot write test file for p4_config")
}
p4TestFile.Close()
testCases := map[string]struct { testCases := map[string]struct {
setOps []*Operation setOps []*Operation
exp pb.SetRequest exp pb.SetRequest
}{ }{
"delete": { "delete": {
setOps: []*Operation{&Operation{Type: "delete", Path: []string{"foo"}}}, setOps: []*Operation{{Type: "delete", Path: []string{"foo"}}},
exp: pb.SetRequest{Delete: []*pb.Path{pathFoo}}, exp: pb.SetRequest{Delete: []*pb.Path{pathFoo}},
}, },
"update": { "update": {
setOps: []*Operation{&Operation{Type: "update", Path: []string{"foo"}, Val: "true"}}, setOps: []*Operation{{Type: "update", Path: []string{"foo"}, Val: "true"}},
exp: pb.SetRequest{ exp: pb.SetRequest{
Update: []*pb.Update{&pb.Update{ Update: []*pb.Update{{
Path: pathFoo, Path: pathFoo,
Val: &pb.TypedValue{ Val: &pb.TypedValue{
Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: []byte("true")}}, Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: []byte("true")}},
@ -40,9 +62,9 @@ func TestNewSetRequest(t *testing.T) {
}, },
}, },
"replace": { "replace": {
setOps: []*Operation{&Operation{Type: "replace", Path: []string{"foo"}, Val: "true"}}, setOps: []*Operation{{Type: "replace", Path: []string{"foo"}, Val: "true"}},
exp: pb.SetRequest{ exp: pb.SetRequest{
Replace: []*pb.Update{&pb.Update{ Replace: []*pb.Update{{
Path: pathFoo, Path: pathFoo,
Val: &pb.TypedValue{ Val: &pb.TypedValue{
Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: []byte("true")}}, Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: []byte("true")}},
@ -50,16 +72,27 @@ func TestNewSetRequest(t *testing.T) {
}, },
}, },
"cli-replace": { "cli-replace": {
setOps: []*Operation{&Operation{Type: "replace", Path: []string{"cli"}, setOps: []*Operation{{Type: "replace", Origin: "cli",
Val: "hostname foo\nip routing"}}, Val: "hostname foo\nip routing"}},
exp: pb.SetRequest{ exp: pb.SetRequest{
Replace: []*pb.Update{&pb.Update{ Replace: []*pb.Update{{
Path: pathCli, Path: pathCli,
Val: &pb.TypedValue{ Val: &pb.TypedValue{
Value: &pb.TypedValue_AsciiVal{AsciiVal: "hostname foo\nip routing"}}, Value: &pb.TypedValue_AsciiVal{AsciiVal: "hostname foo\nip routing"}},
}}, }},
}, },
}, },
"p4_config": {
setOps: []*Operation{{Type: "replace", Origin: "p4_config",
Val: p4Filename}},
exp: pb.SetRequest{
Replace: []*pb.Update{{
Path: pathP4,
Val: &pb.TypedValue{
Value: &pb.TypedValue_ProtoBytes{ProtoBytes: []byte(p4FileContent)}},
}},
},
},
} }
for name, tc := range testCases { for name, tc := range testCases {
@ -74,3 +107,183 @@ func TestNewSetRequest(t *testing.T) {
}) })
} }
} }
func TestStrUpdateVal(t *testing.T) {
anyBytes, err := proto.Marshal(&pb.ModelData{Name: "foobar"})
if err != nil {
t.Fatal(err)
}
anyMessage := &any.Any{TypeUrl: "gnmi/ModelData", Value: anyBytes}
anyString := proto.CompactTextString(anyMessage)
for name, tc := range map[string]struct {
update *pb.Update
exp string
}{
"JSON Value": {
update: &pb.Update{
Value: &pb.Value{
Value: []byte(`{"foo":"bar"}`),
Type: pb.Encoding_JSON}},
exp: `{
"foo": "bar"
}`,
},
"JSON_IETF Value": {
update: &pb.Update{
Value: &pb.Value{
Value: []byte(`{"foo":"bar"}`),
Type: pb.Encoding_JSON_IETF}},
exp: `{
"foo": "bar"
}`,
},
"BYTES Value": {
update: &pb.Update{
Value: &pb.Value{
Value: []byte{0xde, 0xad},
Type: pb.Encoding_BYTES}},
exp: "3q0=",
},
"PROTO Value": {
update: &pb.Update{
Value: &pb.Value{
Value: []byte{0xde, 0xad},
Type: pb.Encoding_PROTO}},
exp: "3q0=",
},
"ASCII Value": {
update: &pb.Update{
Value: &pb.Value{
Value: []byte("foobar"),
Type: pb.Encoding_ASCII}},
exp: "foobar",
},
"INVALID Value": {
update: &pb.Update{
Value: &pb.Value{
Value: []byte("foobar"),
Type: pb.Encoding(42)}},
exp: "foobar",
},
"StringVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_StringVal{StringVal: "foobar"}}},
exp: "foobar",
},
"IntVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_IntVal{IntVal: -42}}},
exp: "-42",
},
"UintVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_UintVal{UintVal: 42}}},
exp: "42",
},
"BoolVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_BoolVal{BoolVal: true}}},
exp: "true",
},
"BytesVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_BytesVal{BytesVal: []byte{0xde, 0xad}}}},
exp: "3q0=",
},
"FloatVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_FloatVal{FloatVal: 3.14}}},
exp: "3.14",
},
"DecimalVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_DecimalVal{
DecimalVal: &pb.Decimal64{Digits: 314, Precision: 2},
}}},
exp: "3.14",
},
"LeafListVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_LeaflistVal{
LeaflistVal: &pb.ScalarArray{Element: []*pb.TypedValue{
&pb.TypedValue{Value: &pb.TypedValue_BoolVal{BoolVal: true}},
&pb.TypedValue{Value: &pb.TypedValue_AsciiVal{AsciiVal: "foobar"}},
}},
}}},
exp: "[true, foobar]",
},
"AnyVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_AnyVal{AnyVal: anyMessage}}},
exp: anyString,
},
"JsonVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_JsonVal{JsonVal: []byte(`{"foo":"bar"}`)}}},
exp: `{
"foo": "bar"
}`,
},
"JsonIetfVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_JsonIetfVal{JsonIetfVal: []byte(`{"foo":"bar"}`)}}},
exp: `{
"foo": "bar"
}`,
},
"AsciiVal": {
update: &pb.Update{Val: &pb.TypedValue{
Value: &pb.TypedValue_AsciiVal{AsciiVal: "foobar"}}},
exp: "foobar",
},
} {
t.Run(name, func(t *testing.T) {
got := StrUpdateVal(tc.update)
if got != tc.exp {
t.Errorf("Expected: %q Got: %q", tc.exp, got)
}
})
}
}
func TestExtractJSON(t *testing.T) {
jsonFile, err := ioutil.TempFile("", "extractJSON")
if err != nil {
t.Fatal(err)
}
defer os.Remove(jsonFile.Name())
if _, err := jsonFile.Write([]byte(`"jsonFile"`)); err != nil {
jsonFile.Close()
t.Fatal(err)
}
if err := jsonFile.Close(); err != nil {
t.Fatal(err)
}
for val, exp := range map[string][]byte{
jsonFile.Name(): []byte(`"jsonFile"`),
"foobar": []byte(`"foobar"`),
`"foobar"`: []byte(`"foobar"`),
"Val: true": []byte(`"Val: true"`),
"host42": []byte(`"host42"`),
"42": []byte("42"),
"-123.43": []byte("-123.43"),
"0xFFFF": []byte("0xFFFF"),
// Int larger than can fit in 32 bits should be quoted
"0x8000000000": []byte(`"0x8000000000"`),
"-0x8000000000": []byte(`"-0x8000000000"`),
"true": []byte("true"),
"false": []byte("false"),
"null": []byte("null"),
"{true: 42}": []byte("{true: 42}"),
"[]": []byte("[]"),
} {
t.Run(val, func(t *testing.T) {
got := extractJSON(val)
if !bytes.Equal(exp, got) {
t.Errorf("Unexpected diff. Expected: %q Got: %q", exp, got)
}
})
}
}

View File

@ -126,11 +126,6 @@ func writeSafeString(buf *bytes.Buffer, s string, esc rune) {
// ParseGNMIElements builds up a gnmi path, from user-supplied text // ParseGNMIElements builds up a gnmi path, from user-supplied text
func ParseGNMIElements(elms []string) (*pb.Path, error) { func ParseGNMIElements(elms []string) (*pb.Path, error) {
if len(elms) == 1 && elms[0] == "cli" {
return &pb.Path{
Origin: "cli",
}, nil
}
var parsed []*pb.PathElem var parsed []*pb.PathElem
for _, e := range elms { for _, e := range elms {
n, keys, err := parseElement(e) n, keys, err := parseElement(e)

View File

@ -85,19 +85,6 @@ func TestStrPath(t *testing.T) {
} }
} }
func TestOriginCLIPath(t *testing.T) {
path := "cli"
sElms := SplitPath(path)
pbPath, err := ParseGNMIElements(sElms)
if err != nil {
t.Fatal(err)
}
expected := pb.Path{Origin: "cli"}
if !test.DeepEqual(expected, *pbPath) {
t.Errorf("want %v, got %v", expected, *pbPath)
}
}
func TestStrPathBackwardsCompat(t *testing.T) { func TestStrPathBackwardsCompat(t *testing.T) {
for i, tc := range []struct { for i, tc := range []struct {
path *pb.Path path *pb.Path

View File

@ -0,0 +1,20 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package influxlib
//Connection type.
const (
HTTP = "HTTP"
UDP = "UDP"
)
//InfluxConfig is a configuration struct for influxlib.
type InfluxConfig struct {
Hostname string
Port uint16
Protocol string
Database string
RetentionPolicy string
}

View File

@ -0,0 +1,37 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
/*
Package: influxlib
Title: Influx DB Library
Authors: ssdaily, manojm321, senkrish, kthommandra
Email: influxdb-dev@arista.com
Description: The main purpose of influxlib is to provide users with a simple
and easy interface through which to connect to influxdb. It removed a lot of
the need to run the same setup and tear down code to connect the the service.
Example Code:
connection, err := influxlib.Connect(&influxlib.InfluxConfig {
Hostname: conf.Host,
Port: conf.Port,
Protocol: influxlib.UDP,
Database, conf.AlertDB,
})
tags := map[string]string {
"tag1": someStruct.Tag["host"],
"tag2": someStruct.Tag["tag2"],
}
fields := map[string]interface{} {
"field1": someStruct.Somefield,
"field2": someStruct.Somefield2,
}
connection.WritePoint("measurement", tags, fields)
*/
package influxlib

View File

@ -0,0 +1,173 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package influxlib
import (
"errors"
"fmt"
"time"
influxdb "github.com/influxdata/influxdb/client/v2"
)
// Row is defined as a map where the key (string) is the name of the
// column (field name) and the value is left as an interface to
// accept any value.
type Row map[string]interface{}
// InfluxDBConnection is an object that the wrapper uses.
// Holds a client of the type v2.Client and the configuration
type InfluxDBConnection struct {
Client influxdb.Client
Config *InfluxConfig
}
// Point represents a datapoint to be written.
// Measurement:
// The measurement to write to
// Tags:
// A dictionary of tags in the form string=string
// Fields:
// A dictionary of fields(keys) with their associated values
type Point struct {
Measurement string
Tags map[string]string
Fields map[string]interface{}
Timestamp time.Time
}
// Connect takes an InfluxConfig and establishes a connection
// to InfluxDB. It returns an InfluxDBConnection structure.
// InfluxConfig may be nil for a default connection.
func Connect(config *InfluxConfig) (*InfluxDBConnection, error) {
var con influxdb.Client
var err error
switch config.Protocol {
case HTTP:
addr := fmt.Sprintf("http://%s:%v", config.Hostname, config.Port)
con, err = influxdb.NewHTTPClient(influxdb.HTTPConfig{
Addr: addr,
Timeout: 1 * time.Second,
})
case UDP:
addr := fmt.Sprintf("%s:%v", config.Hostname, config.Port)
con, err = influxdb.NewUDPClient(influxdb.UDPConfig{
Addr: addr,
})
default:
return nil, errors.New("Invalid Protocol")
}
if err != nil {
return nil, err
}
return &InfluxDBConnection{Client: con, Config: config}, nil
}
// RecordBatchPoints takes in a slice of influxlib.Point and writes them to the
// database.
func (conn *InfluxDBConnection) RecordBatchPoints(points []Point) error {
var err error
bp, err := influxdb.NewBatchPoints(influxdb.BatchPointsConfig{
Database: conn.Config.Database,
Precision: "ns",
RetentionPolicy: conn.Config.RetentionPolicy,
})
if err != nil {
return err
}
var influxPoints []*influxdb.Point
for _, p := range points {
if p.Timestamp.IsZero() {
p.Timestamp = time.Now()
}
point, err := influxdb.NewPoint(p.Measurement, p.Tags, p.Fields,
p.Timestamp)
if err != nil {
return err
}
influxPoints = append(influxPoints, point)
}
bp.AddPoints(influxPoints)
if err = conn.Client.Write(bp); err != nil {
return err
}
return nil
}
// WritePoint stores a datapoint to the database.
// Measurement:
// The measurement to write to
// Tags:
// A dictionary of tags in the form string=string
// Fields:
// A dictionary of fields(keys) with their associated values
func (conn *InfluxDBConnection) WritePoint(measurement string,
tags map[string]string, fields map[string]interface{}) error {
return conn.RecordPoint(Point{
Measurement: measurement,
Tags: tags,
Fields: fields,
Timestamp: time.Now(),
})
}
// RecordPoint implements the same as WritePoint but used a point struct
// as the argument instead.
func (conn *InfluxDBConnection) RecordPoint(p Point) error {
return conn.RecordBatchPoints([]Point{p})
}
// Query sends a query to the influxCli and returns a slice of
// rows. Rows are of type map[string]interface{}
func (conn *InfluxDBConnection) Query(query string) ([]Row, error) {
q := influxdb.NewQuery(query, conn.Config.Database, "ns")
var rows []Row
var index = 0
response, err := conn.Client.Query(q)
if err != nil {
return nil, err
}
if response.Error() != nil {
return nil, response.Error()
}
// The intent here is to combine the separate client v2
// series into a single array. As a result queries that
// utilize "group by" will be combined into a single
// array. And the tag value will be added to the query.
// Similar to what you would expect from a SQL query
for _, result := range response.Results {
for _, series := range result.Series {
columnNames := series.Columns
for _, row := range series.Values {
rows = append(rows, make(Row))
for columnIdx, value := range row {
rows[index][columnNames[columnIdx]] = value
}
for tagKey, tagValue := range series.Tags {
rows[index][tagKey] = tagValue
}
index++
}
}
}
return rows, nil
}
// Close closes the connection opened by Connect()
func (conn *InfluxDBConnection) Close() {
conn.Client.Close()
}

View File

@ -0,0 +1,157 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package influxlib
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func testFields(line string, fields map[string]interface{},
t *testing.T) {
for k, v := range fields {
formatString := "%s=%v"
if _, ok := v.(string); ok {
formatString = "%s=%q"
}
assert.Contains(t, line, fmt.Sprintf(formatString, k, v),
fmt.Sprintf(formatString+" expected in %s", k, v, line))
}
}
func testTags(line string, tags map[string]string,
t *testing.T) {
for k, v := range tags {
assert.Contains(t, line, fmt.Sprintf("%s=%s", k, v),
fmt.Sprintf("%s=%s expected in %s", k, v, line))
}
}
func TestBasicWrite(t *testing.T) {
testConn, _ := NewMockConnection()
measurement := "TestData"
tags := map[string]string{
"tag1": "Happy",
"tag2": "Valentines",
"tag3": "Day",
}
fields := map[string]interface{}{
"Data1": 1234,
"Data2": "apples",
"Data3": 5.34,
}
err := testConn.WritePoint(measurement, tags, fields)
assert.NoError(t, err)
line, err := GetTestBuffer(testConn)
assert.NoError(t, err)
assert.Contains(t, line, measurement,
fmt.Sprintf("%s does not appear in %s", measurement, line))
testTags(line, tags, t)
testFields(line, fields, t)
}
func TestConnectionToHostFailure(t *testing.T) {
assert := assert.New(t)
var err error
config := &InfluxConfig{
Port: 8086,
Protocol: HTTP,
Database: "test",
}
config.Hostname = "this is fake.com"
_, err = Connect(config)
assert.Error(err)
config.Hostname = "\\-Fake.Url.Com"
_, err = Connect(config)
assert.Error(err)
}
func TestWriteFailure(t *testing.T) {
con, _ := NewMockConnection()
measurement := "TestData"
tags := map[string]string{
"tag1": "hi",
}
data := map[string]interface{}{
"Data1": "cats",
}
err := con.WritePoint(measurement, tags, data)
assert.NoError(t, err)
fc, _ := con.Client.(*fakeClient)
fc.failAll = true
err = con.WritePoint(measurement, tags, data)
assert.Error(t, err)
}
func TestQuery(t *testing.T) {
query := "SELECT * FROM 'system' LIMIT 50;"
con, _ := NewMockConnection()
_, err := con.Query(query)
assert.NoError(t, err)
}
func TestAddAndWriteBatchPoints(t *testing.T) {
testConn, _ := NewMockConnection()
measurement := "TestData"
points := []Point{
Point{
Measurement: measurement,
Tags: map[string]string{
"tag1": "Happy",
"tag2": "Valentines",
"tag3": "Day",
},
Fields: map[string]interface{}{
"Data1": 1234,
"Data2": "apples",
"Data3": 5.34,
},
Timestamp: time.Now(),
},
Point{
Measurement: measurement,
Tags: map[string]string{
"tag1": "Happy",
"tag2": "New",
"tag3": "Year",
},
Fields: map[string]interface{}{
"Data1": 5678,
"Data2": "bananas",
"Data3": 3.14,
},
Timestamp: time.Now(),
},
}
err := testConn.RecordBatchPoints(points)
assert.NoError(t, err)
line, err := GetTestBuffer(testConn)
assert.NoError(t, err)
assert.Contains(t, line, measurement,
fmt.Sprintf("%s does not appear in %s", measurement, line))
for _, p := range points {
testTags(line, p.Tags, t)
testFields(line, p.Fields, t)
}
}

View File

@ -0,0 +1,79 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package influxlib
import (
"bytes"
"errors"
"fmt"
"time"
influx "github.com/influxdata/influxdb/client/v2"
)
// This will serve as a fake client object to test off of.
// The idea is to provide a way to test the Influx Wrapper
// without having it connected to the database.
type fakeClient struct {
writer bytes.Buffer
failAll bool
}
func (w *fakeClient) Ping(timeout time.Duration) (time.Duration,
string, error) {
return 0, "", nil
}
func (w *fakeClient) Query(q influx.Query) (*influx.Response, error) {
if w.failAll {
return nil, errors.New("quering points failed")
}
return &influx.Response{Results: nil, Err: ""}, nil
}
func (w *fakeClient) Close() error {
return nil
}
func (w *fakeClient) Write(bp influx.BatchPoints) error {
if w.failAll {
return errors.New("writing point failed")
}
w.writer.Reset()
for _, p := range bp.Points() {
fmt.Fprintf(&w.writer, p.String()+"\n")
}
return nil
}
func (w *fakeClient) qString() string {
return w.writer.String()
}
/***************************************************/
// NewMockConnection returns an influxDBConnection with
// a "fake" client for offline testing.
func NewMockConnection() (*InfluxDBConnection, error) {
client := new(fakeClient)
config := &InfluxConfig{
Hostname: "localhost",
Port: 8086,
Protocol: HTTP,
Database: "Test",
}
return &InfluxDBConnection{client, config}, nil
}
// GetTestBuffer returns the string that would normally
// be written to influx DB
func GetTestBuffer(con *InfluxDBConnection) (string, error) {
fc, ok := con.Client.(*fakeClient)
if !ok {
return "", errors.New("Expected a fake client but recieved a real one")
}
return fc.qString(), nil
}

View File

@ -1,26 +0,0 @@
#!/bin/sh
# Copyright (c) 2016 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the COPYING file.
DEFAULT_PORT=6042
set -e
if [ "$#" -lt 1 ]
then
echo "usage: $0 <host> [<gNMI port>]"
exit 1
fi
echo "WARNING: if you're not using EOS-INT, EOS-REV-0-1 or EOS 4.18 or earlier please use -allowed_ips on the server instead."
host=$1
port=$DEFAULT_PORT
if [ "$#" -gt 1 ]
then
port=$2
fi
iptables="bash sudo iptables -A INPUT -p tcp --dport $port -j ACCEPT"
ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no $host "$iptables"
echo "opened TCP port $port on $host"

View File

@ -79,7 +79,7 @@ func (e *elasticsearchMessageEncoder) Encode(message proto.Message) ([]*sarama.P
} }
glog.V(9).Infof("kafka: %s", updateJSON) glog.V(9).Infof("kafka: %s", updateJSON)
return []*sarama.ProducerMessage{ return []*sarama.ProducerMessage{
&sarama.ProducerMessage{ {
Topic: e.topic, Topic: e.topic,
Key: e.key, Key: e.key,
Value: sarama.ByteEncoder(updateJSON), Value: sarama.ByteEncoder(updateJSON),

View File

@ -116,7 +116,7 @@ func (p *producer) produceNotifications(protoMessage proto.Message) error {
case <-p.done: case <-p.done:
return nil return nil
case p.kafkaProducer.Input() <- m: case p.kafkaProducer.Input() <- m:
glog.V(9).Infof("Message produced to Kafka: %s", m) glog.V(9).Infof("Message produced to Kafka: %v", m)
} }
} }
return nil return nil

View File

@ -5,59 +5,30 @@
package key package key
import ( import (
"encoding/json"
"fmt"
"reflect" "reflect"
"unsafe" "unsafe"
"github.com/aristanetworks/goarista/areflect" "github.com/aristanetworks/goarista/areflect"
) )
// composite allows storing a map[string]interface{} as a key in a Go map.
// This is useful when the key isn't a fixed data structure known at compile
// time but rather something generic, like a bag of key-value pairs.
// Go does not allow storing a map inside the key of a map, because maps are
// not comparable or hashable, and keys in maps must be both. This file is
// a hack specific to the 'gc' implementation of Go (which is the one most
// people use when they use Go), to bypass this check, by abusing reflection
// to override how Go compares composite for equality or how it's hashed.
// The values allowed in this map are only the types whitelisted in New() as
// well as map[Key]interface{}.
//
// See also https://github.com/golang/go/issues/283
type composite struct {
// This value must always be set to the sentinel constant above.
sentinel uintptr
m map[string]interface{}
}
func (k composite) Key() interface{} {
return k.m
}
func (k composite) String() string {
return stringify(k.Key())
}
func (k composite) GoString() string {
return fmt.Sprintf("key.New(%#v)", k.Key())
}
func (k composite) MarshalJSON() ([]byte, error) {
return json.Marshal(k.Key())
}
func (k composite) Equal(other interface{}) bool {
o, ok := other.(composite)
return ok && mapStringEqual(k.m, o.m)
}
func hashInterface(v interface{}) uintptr { func hashInterface(v interface{}) uintptr {
switch v := v.(type) { switch v := v.(type) {
case map[string]interface{}: case map[string]interface{}:
return hashMapString(v) return hashMapString(v)
case map[Key]interface{}: case map[Key]interface{}:
return hashMapKey(v) return hashMapKey(v)
case []interface{}:
return hashSlice(v)
case Pointer:
// This case applies to pointers used
// as values in maps or slices (i.e.
// not wrapped in a key).
return hashSlice(pointerToSlice(v))
case Path:
// This case applies to paths used
// as values in maps or slices (i.e
// not wrapped in a kay).
return hashSlice(pathToSlice(v))
default: default:
return _nilinterhash(v) return _nilinterhash(v)
} }
@ -78,9 +49,9 @@ func hashMapKey(m map[Key]interface{}) uintptr {
for k, v := range m { for k, v := range m {
// Use addition so that the order of iteration doesn't matter. // Use addition so that the order of iteration doesn't matter.
switch k := k.(type) { switch k := k.(type) {
case keyImpl: case interfaceKey:
h += _nilinterhash(k.key) h += _nilinterhash(k.key)
case composite: case compositeKey:
h += hashMapString(k.m) h += hashMapString(k.m)
} }
h += hashInterface(v) h += hashInterface(v)
@ -88,28 +59,42 @@ func hashMapKey(m map[Key]interface{}) uintptr {
return h return h
} }
func hashSlice(s []interface{}) uintptr {
h := uintptr(31 * (len(s) + 1))
for _, v := range s {
h += hashInterface(v)
}
return h
}
func hash(p unsafe.Pointer, seed uintptr) uintptr { func hash(p unsafe.Pointer, seed uintptr) uintptr {
ck := *(*composite)(p) ck := *(*compositeKey)(p)
if ck.sentinel != sentinel { if ck.sentinel != sentinel {
panic("use of unhashable type in a map") panic("use of unhashable type in a map")
} }
if ck.m != nil {
return seed ^ hashMapString(ck.m) return seed ^ hashMapString(ck.m)
}
return seed ^ hashSlice(ck.s)
} }
func equal(a unsafe.Pointer, b unsafe.Pointer) bool { func equal(a unsafe.Pointer, b unsafe.Pointer) bool {
ca := (*composite)(a) ca := (*compositeKey)(a)
cb := (*composite)(b) cb := (*compositeKey)(b)
if ca.sentinel != sentinel { if ca.sentinel != sentinel {
panic("use of uncomparable type on the lhs of ==") panic("use of uncomparable type on the lhs of ==")
} }
if cb.sentinel != sentinel { if cb.sentinel != sentinel {
panic("use of uncomparable type on the rhs of ==") panic("use of uncomparable type on the rhs of ==")
} }
if ca.m != nil {
return mapStringEqual(ca.m, cb.m) return mapStringEqual(ca.m, cb.m)
}
return sliceEqual(ca.s, cb.s)
} }
func init() { func init() {
typ := reflect.TypeOf(composite{}) typ := reflect.TypeOf(compositeKey{})
alg := reflect.ValueOf(typ).Elem().FieldByName("alg").Elem() alg := reflect.ValueOf(typ).Elem().FieldByName("alg").Elem()
// Pretty certain that doing this voids your warranty. // Pretty certain that doing this voids your warranty.
// This overwrites the typeAlg of either alg_NOEQ64 (on 32-bit platforms) // This overwrites the typeAlg of either alg_NOEQ64 (on 32-bit platforms)

View File

@ -19,7 +19,7 @@ type unhashable struct {
func TestBadComposite(t *testing.T) { func TestBadComposite(t *testing.T) {
test.ShouldPanicWith(t, "use of unhashable type in a map", func() { test.ShouldPanicWith(t, "use of unhashable type in a map", func() {
m := map[interface{}]struct{}{ m := map[interface{}]struct{}{
unhashable{func() {}, 0x42}: struct{}{}, unhashable{func() {}, 0x42}: {},
} }
// Use Key here to make sure init() is called. // Use Key here to make sure init() is called.
if _, ok := m[New("foo")]; ok { if _, ok := m[New("foo")]; ok {

View File

@ -16,14 +16,35 @@ import (
// Key represents the Key in the updates and deletes of the Notification // Key represents the Key in the updates and deletes of the Notification
// objects. The only reason this exists is that Go won't let us define // objects. The only reason this exists is that Go won't let us define
// our own hash function for non-hashable types, and unfortunately we // our own hash function for non-hashable types, and unfortunately we
// need to be able to index maps by map[string]interface{} objects. // need to be able to index maps by map[string]interface{} objects
// and slices by []interface{} objects.
type Key interface { type Key interface {
Key() interface{} Key() interface{}
String() string String() string
Equal(other interface{}) bool Equal(other interface{}) bool
} }
type keyImpl struct { // compositeKey allows storing a map[string]interface{} or []interface{} as a key
// in a Go map. This is useful when the key isn't a fixed data structure known
// at compile time but rather something generic, like a bag of key-value pairs
// or a list of elements. Go does not allow storing a map or slice inside the
// key of a map, because maps and slices are not comparable or hashable, and
// keys in maps and slice elements must be both. This file is a hack specific
// to the 'gc' implementation of Go (which is the one most people use when they
// use Go), to bypass this check, by abusing reflection to override how Go
// compares compositeKey for equality or how it's hashed. The values allowed in
// this map are only the types whitelisted in New() as well as map[Key]interface{}
// and []interface{}.
//
// See also https://github.com/golang/go/issues/283
type compositeKey struct {
// This value must always be set to the sentinel constant above.
sentinel uintptr
m map[string]interface{}
s []interface{}
}
type interfaceKey struct {
key interface{} key interface{}
} }
@ -44,13 +65,43 @@ type float64Key float64
type boolKey bool type boolKey bool
type pointerKey compositeKey
type pathKey compositeKey
func pathToSlice(path Path) []interface{} {
s := make([]interface{}, len(path))
for i, element := range path {
s[i] = element.Key()
}
return s
}
func sliceToPath(s []interface{}) Path {
path := make(Path, len(s))
for i, intf := range s {
path[i] = New(intf)
}
return path
}
func pointerToSlice(ptr Pointer) []interface{} {
return pathToSlice(ptr.Pointer())
}
func sliceToPointer(s []interface{}) pointer {
return pointer(sliceToPath(s))
}
// New wraps the given value in a Key. // New wraps the given value in a Key.
// This function panics if the value passed in isn't allowed in a Key or // This function panics if the value passed in isn't allowed in a Key or
// doesn't implement value.Value. // doesn't implement value.Value.
func New(intf interface{}) Key { func New(intf interface{}) Key {
switch t := intf.(type) { switch t := intf.(type) {
case map[string]interface{}: case map[string]interface{}:
return composite{sentinel, t} return compositeKey{sentinel: sentinel, m: t}
case []interface{}:
return compositeKey{sentinel: sentinel, s: t}
case string: case string:
return strKey(t) return strKey(t)
case int8: case int8:
@ -76,31 +127,35 @@ func New(intf interface{}) Key {
case bool: case bool:
return boolKey(t) return boolKey(t)
case value.Value: case value.Value:
return keyImpl{key: intf} return interfaceKey{key: intf}
case Pointer:
return pointerKey{sentinel: sentinel, s: pointerToSlice(t)}
case Path:
return pathKey{sentinel: sentinel, s: pathToSlice(t)}
default: default:
panic(fmt.Sprintf("Invalid type for key: %T", intf)) panic(fmt.Sprintf("Invalid type for key: %T", intf))
} }
} }
func (k keyImpl) Key() interface{} { func (k interfaceKey) Key() interface{} {
return k.key return k.key
} }
func (k keyImpl) String() string { func (k interfaceKey) String() string {
return stringify(k.key) return stringify(k.key)
} }
func (k keyImpl) GoString() string { func (k interfaceKey) GoString() string {
return fmt.Sprintf("key.New(%#v)", k.Key()) return fmt.Sprintf("key.New(%#v)", k.Key())
} }
func (k keyImpl) MarshalJSON() ([]byte, error) { func (k interfaceKey) MarshalJSON() ([]byte, error) {
return json.Marshal(k.Key()) return json.Marshal(k.Key())
} }
func (k keyImpl) Equal(other interface{}) bool { func (k interfaceKey) Equal(other interface{}) bool {
o, ok := other.(keyImpl) o, ok := other.(Key)
return ok && keyEqual(k.key, o.key) return ok && keyEqual(k.key, o.Key())
} }
// Comparable types have an equality-testing method. // Comparable types have an equality-testing method.
@ -121,6 +176,18 @@ func mapStringEqual(a, b map[string]interface{}) bool {
return true return true
} }
func sliceEqual(a, b []interface{}) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if !keyEqual(v, b[i]) {
return false
}
}
return true
}
func keyEqual(a, b interface{}) bool { func keyEqual(a, b interface{}) bool {
switch a := a.(type) { switch a := a.(type) {
case map[string]interface{}: case map[string]interface{}:
@ -137,19 +204,56 @@ func keyEqual(a, b interface{}) bool {
} }
} }
return true return true
case []interface{}:
b, ok := b.([]interface{})
return ok && sliceEqual(a, b)
case Comparable: case Comparable:
return a.Equal(b) return a.Equal(b)
case Pointer:
b, ok := b.(Pointer)
return ok && pointerEqual(a, b)
case Path:
b, ok := b.(Path)
return ok && pathEqual(a, b)
} }
return a == b return a == b
} }
// Key interface implementation for map[string]interface{} and []interface{}
func (k compositeKey) Key() interface{} {
if k.m != nil {
return k.m
}
return k.s
}
func (k compositeKey) String() string {
return stringify(k.Key())
}
func (k compositeKey) GoString() string {
return fmt.Sprintf("key.New(%#v)", k.Key())
}
func (k compositeKey) MarshalJSON() ([]byte, error) {
return json.Marshal(k.Key())
}
func (k compositeKey) Equal(other interface{}) bool {
o, ok := other.(compositeKey)
if k.m != nil {
return ok && mapStringEqual(k.m, o.m)
}
return ok && sliceEqual(k.s, o.s)
}
func (k strKey) Key() interface{} { func (k strKey) Key() interface{} {
return string(k) return string(k)
} }
func (k strKey) String() string { func (k strKey) String() string {
return escape(string(k)) return string(k)
} }
func (k strKey) GoString() string { func (k strKey) GoString() string {
@ -157,7 +261,7 @@ func (k strKey) GoString() string {
} }
func (k strKey) MarshalJSON() ([]byte, error) { func (k strKey) MarshalJSON() ([]byte, error) {
return json.Marshal(string(k)) return json.Marshal(escape(string(k)))
} }
func (k strKey) Equal(other interface{}) bool { func (k strKey) Equal(other interface{}) bool {
@ -175,7 +279,7 @@ func (k int8Key) String() string {
} }
func (k int8Key) GoString() string { func (k int8Key) GoString() string {
return fmt.Sprintf("key.New(%d)", int8(k)) return fmt.Sprintf("key.New(int8(%d))", int8(k))
} }
func (k int8Key) MarshalJSON() ([]byte, error) { func (k int8Key) MarshalJSON() ([]byte, error) {
@ -197,7 +301,7 @@ func (k int16Key) String() string {
} }
func (k int16Key) GoString() string { func (k int16Key) GoString() string {
return fmt.Sprintf("key.New(%d)", int16(k)) return fmt.Sprintf("key.New(int16(%d))", int16(k))
} }
func (k int16Key) MarshalJSON() ([]byte, error) { func (k int16Key) MarshalJSON() ([]byte, error) {
@ -219,7 +323,7 @@ func (k int32Key) String() string {
} }
func (k int32Key) GoString() string { func (k int32Key) GoString() string {
return fmt.Sprintf("key.New(%d)", int32(k)) return fmt.Sprintf("key.New(int32(%d))", int32(k))
} }
func (k int32Key) MarshalJSON() ([]byte, error) { func (k int32Key) MarshalJSON() ([]byte, error) {
@ -241,7 +345,7 @@ func (k int64Key) String() string {
} }
func (k int64Key) GoString() string { func (k int64Key) GoString() string {
return fmt.Sprintf("key.New(%d)", int64(k)) return fmt.Sprintf("key.New(int64(%d))", int64(k))
} }
func (k int64Key) MarshalJSON() ([]byte, error) { func (k int64Key) MarshalJSON() ([]byte, error) {
@ -263,7 +367,7 @@ func (k uint8Key) String() string {
} }
func (k uint8Key) GoString() string { func (k uint8Key) GoString() string {
return fmt.Sprintf("key.New(%d)", uint8(k)) return fmt.Sprintf("key.New(uint8(%d))", uint8(k))
} }
func (k uint8Key) MarshalJSON() ([]byte, error) { func (k uint8Key) MarshalJSON() ([]byte, error) {
@ -285,7 +389,7 @@ func (k uint16Key) String() string {
} }
func (k uint16Key) GoString() string { func (k uint16Key) GoString() string {
return fmt.Sprintf("key.New(%d)", uint16(k)) return fmt.Sprintf("key.New(uint16(%d))", uint16(k))
} }
func (k uint16Key) MarshalJSON() ([]byte, error) { func (k uint16Key) MarshalJSON() ([]byte, error) {
@ -307,7 +411,7 @@ func (k uint32Key) String() string {
} }
func (k uint32Key) GoString() string { func (k uint32Key) GoString() string {
return fmt.Sprintf("key.New(%d)", uint32(k)) return fmt.Sprintf("key.New(uint32(%d))", uint32(k))
} }
func (k uint32Key) MarshalJSON() ([]byte, error) { func (k uint32Key) MarshalJSON() ([]byte, error) {
@ -329,7 +433,7 @@ func (k uint64Key) String() string {
} }
func (k uint64Key) GoString() string { func (k uint64Key) GoString() string {
return fmt.Sprintf("key.New(%d)", uint64(k)) return fmt.Sprintf("key.New(uint64(%d))", uint64(k))
} }
func (k uint64Key) MarshalJSON() ([]byte, error) { func (k uint64Key) MarshalJSON() ([]byte, error) {
@ -351,7 +455,7 @@ func (k float32Key) String() string {
} }
func (k float32Key) GoString() string { func (k float32Key) GoString() string {
return fmt.Sprintf("key.New(%v)", float32(k)) return fmt.Sprintf("key.New(float32(%v))", float32(k))
} }
func (k float32Key) MarshalJSON() ([]byte, error) { func (k float32Key) MarshalJSON() ([]byte, error) {
@ -373,7 +477,7 @@ func (k float64Key) String() string {
} }
func (k float64Key) GoString() string { func (k float64Key) GoString() string {
return fmt.Sprintf("key.New(%v)", float64(k)) return fmt.Sprintf("key.New(float64(%v))", float64(k))
} }
func (k float64Key) MarshalJSON() ([]byte, error) { func (k float64Key) MarshalJSON() ([]byte, error) {
@ -406,3 +510,59 @@ func (k boolKey) Equal(other interface{}) bool {
o, ok := other.(boolKey) o, ok := other.(boolKey)
return ok && k == o return ok && k == o
} }
// Key interface implementation for Pointer
func (k pointerKey) Key() interface{} {
return sliceToPointer(k.s)
}
func (k pointerKey) String() string {
return sliceToPointer(k.s).String()
}
func (k pointerKey) GoString() string {
return fmt.Sprintf("key.New(%#v)", k.s)
}
func (k pointerKey) MarshalJSON() ([]byte, error) {
return sliceToPointer(k.s).MarshalJSON()
}
func (k pointerKey) Equal(other interface{}) bool {
if o, ok := other.(pointerKey); ok {
return sliceEqual(k.s, o.s)
}
key, ok := other.(Key)
if !ok {
return false
}
return ok && sliceToPointer(k.s).Equal(key.Key())
}
// Key interface implementation for Path
func (k pathKey) Key() interface{} {
return sliceToPath(k.s)
}
func (k pathKey) String() string {
return sliceToPath(k.s).String()
}
func (k pathKey) GoString() string {
return fmt.Sprintf("key.New(%#v)", k.s)
}
func (k pathKey) MarshalJSON() ([]byte, error) {
return sliceToPath(k.s).MarshalJSON()
}
func (k pathKey) Equal(other interface{}) bool {
if o, ok := other.(pathKey); ok {
return sliceEqual(k.s, o.s)
}
key, ok := other.(Key)
if !ok {
return false
}
return ok && sliceToPath(k.s).Equal(key.Key())
}

View File

@ -55,6 +55,34 @@ func TestKeyEqual(t *testing.T) {
a: New("foo"), a: New("foo"),
b: New("bar"), b: New("bar"),
result: false, result: false,
}, {
a: New([]interface{}{}),
b: New("bar"),
result: false,
}, {
a: New([]interface{}{}),
b: New([]interface{}{}),
result: true,
}, {
a: New([]interface{}{"a", "b"}),
b: New([]interface{}{"a"}),
result: false,
}, {
a: New([]interface{}{"a", "b"}),
b: New([]interface{}{"b", "a"}),
result: false,
}, {
a: New([]interface{}{"a", "b"}),
b: New([]interface{}{"a", "b"}),
result: true,
}, {
a: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
b: New([]interface{}{"a", map[string]interface{}{"c": "b"}}),
result: false,
}, {
a: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
b: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
result: true,
}, { }, {
a: New(map[string]interface{}{}), a: New(map[string]interface{}{}),
b: New("bar"), b: New("bar"),
@ -151,6 +179,32 @@ func TestGetFromMap(t *testing.T) {
k: New(uint32(37)), k: New(uint32(37)),
m: map[Key]interface{}{}, m: map[Key]interface{}{},
found: false, found: false,
}, {
k: New([]interface{}{"a", "b"}),
m: map[Key]interface{}{
New([]interface{}{"a", "b"}): "foo",
},
v: "foo",
found: true,
}, {
k: New([]interface{}{"a", "b"}),
m: map[Key]interface{}{
New([]interface{}{"a", "b", "c"}): "foo",
},
found: false,
}, {
k: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
m: map[Key]interface{}{
New([]interface{}{"a", map[string]interface{}{"b": "c"}}): "foo",
},
v: "foo",
found: true,
}, {
k: New([]interface{}{"a", map[string]interface{}{"b": "c"}}),
m: map[Key]interface{}{
New([]interface{}{"a", map[string]interface{}{"c": "b"}}): "foo",
},
found: false,
}, { }, {
k: New(map[string]interface{}{"a": "b", "c": uint64(4)}), k: New(map[string]interface{}{"a": "b", "c": uint64(4)}),
m: map[Key]interface{}{ m: map[Key]interface{}{
@ -392,6 +446,53 @@ func TestSetToMap(t *testing.T) {
} }
} }
func TestGoString(t *testing.T) {
tcases := []struct {
in Key
out string
}{{
in: New(uint8(1)),
out: "key.New(uint8(1))",
}, {
in: New(uint16(1)),
out: "key.New(uint16(1))",
}, {
in: New(uint32(1)),
out: "key.New(uint32(1))",
}, {
in: New(uint64(1)),
out: "key.New(uint64(1))",
}, {
in: New(int8(1)),
out: "key.New(int8(1))",
}, {
in: New(int16(1)),
out: "key.New(int16(1))",
}, {
in: New(int32(1)),
out: "key.New(int32(1))",
}, {
in: New(int64(1)),
out: "key.New(int64(1))",
}, {
in: New(float32(1)),
out: "key.New(float32(1))",
}, {
in: New(float64(1)),
out: "key.New(float64(1))",
}, {
in: New(map[string]interface{}{"foo": true}),
out: `key.New(map[string]interface {}{"foo":true})`,
}}
for i, tcase := range tcases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
if out := fmt.Sprintf("%#v", tcase.in); out != tcase.out {
t.Errorf("Wanted Go representation %q but got %q", tcase.out, out)
}
})
}
}
func TestMisc(t *testing.T) { func TestMisc(t *testing.T) {
k := New(map[string]interface{}{"foo": true}) k := New(map[string]interface{}{"foo": true})
js, err := json.Marshal(k) js, err := json.Marshal(k)
@ -400,11 +501,6 @@ func TestMisc(t *testing.T) {
} else if expected := `{"foo":true}`; string(js) != expected { } else if expected := `{"foo":true}`; string(js) != expected {
t.Errorf("Wanted JSON %q but got %q", expected, js) t.Errorf("Wanted JSON %q but got %q", expected, js)
} }
expected := `key.New(map[string]interface {}{"foo":true})`
gostr := fmt.Sprintf("%#v", k)
if expected != gostr {
t.Errorf("Wanted Go representation %q but got %q", expected, gostr)
}
test.ShouldPanic(t, func() { New(42) }) test.ShouldPanic(t, func() { New(42) })

51
vendor/github.com/aristanetworks/goarista/key/path.go generated vendored Normal file
View File

@ -0,0 +1,51 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package key
import (
"bytes"
"fmt"
)
// Path represents a path decomposed into elements where each
// element is a Key. A Path can be interpreted as either
// absolute or relative depending on how it is used.
type Path []Key
// String returns the Path as an absolute path string.
func (p Path) String() string {
if len(p) == 0 {
return "/"
}
var buf bytes.Buffer
for _, element := range p {
buf.WriteByte('/')
buf.WriteString(element.String())
}
return buf.String()
}
// MarshalJSON marshals a Path to JSON.
func (p Path) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf(`{"_path":%q}`, p)), nil
}
// Equal returns whether a Path is equal to @other.
func (p Path) Equal(other interface{}) bool {
o, ok := other.(Path)
return ok && pathEqual(p, o)
}
func pathEqual(a, b Path) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !a[i].Equal(b[i]) {
return false
}
}
return true
}

View File

@ -0,0 +1,141 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package key_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/path"
)
func TestPath(t *testing.T) {
expected := path.New("leet")
k := key.New(expected)
keyPath, ok := k.Key().(key.Path)
if !ok {
t.Fatalf("key.Key() did not return a key.Path")
}
if !path.Equal(keyPath, expected) {
t.Errorf("expected path from key is %#v, got %#v", expected, keyPath)
}
if expected, actual := "/leet", fmt.Sprint(k); actual != expected {
t.Errorf("expected string from path key %q, got %q", expected, actual)
}
if js, err := json.Marshal(k); err != nil {
t.Errorf("JSON marshalling error: %v", err)
} else if expected, actual := `{"_path":"/leet"}`, string(js); actual != expected {
t.Errorf("expected json output %q, got %q", expected, actual)
}
}
func newPathKey(e ...interface{}) key.Key {
return key.New(path.New(e...))
}
type customPath string
func (p customPath) Key() interface{} {
return path.FromString(string(p))
}
func (p customPath) Equal(other interface{}) bool {
o, ok := other.(key.Key)
return ok && o.Equal(p)
}
func (p customPath) String() string { return string(p) }
func (p customPath) ToBuiltin() interface{} { panic("not impl") }
func (p customPath) MarshalJSON() ([]byte, error) { panic("not impl") }
func TestPathEqual(t *testing.T) {
tests := []struct {
a key.Key
b key.Key
result bool
}{{
a: newPathKey(),
b: nil,
result: false,
}, {
a: newPathKey(),
b: newPathKey(),
result: true,
}, {
a: newPathKey(),
b: newPathKey("foo"),
result: false,
}, {
a: newPathKey("foo"),
b: newPathKey(),
result: false,
}, {
a: newPathKey(int16(1337)),
b: newPathKey(int64(1337)),
result: false,
}, {
a: newPathKey(path.Wildcard, "bar"),
b: newPathKey("foo", path.Wildcard),
result: false,
}, {
a: newPathKey(map[string]interface{}{"a": "x", "b": "y"}),
b: newPathKey(map[string]interface{}{"b": "y", "a": "x"}),
result: true,
}, {
a: newPathKey(map[string]interface{}{"a": "x", "b": "y"}),
b: newPathKey(map[string]interface{}{"x": "x", "y": "y"}),
result: false,
}, {
a: newPathKey("foo", "bar"),
b: customPath("/foo/bar"),
result: true,
}, {
a: customPath("/foo/bar"),
b: newPathKey("foo", "bar"),
result: true,
}, {
a: newPathKey("foo"),
b: key.New(customPath("/bar")),
result: false,
}, {
a: key.New(customPath("/foo")),
b: newPathKey("foo", "bar"),
result: false,
}}
for i, tc := range tests {
if a, b := tc.a, tc.b; a.Equal(b) != tc.result {
t.Errorf("result not as expected for test case %d", i)
}
}
}
func TestPathAsKey(t *testing.T) {
a := newPathKey("foo", path.Wildcard, map[string]interface{}{
"bar": map[key.Key]interface{}{
// Should be able to embed a path key and value
newPathKey("path", "to", "something"): path.New("else"),
},
})
m := map[key.Key]string{
a: "thats a complex key!",
}
if s, ok := m[a]; !ok {
t.Error("complex key not found in map")
} else if s != "thats a complex key!" {
t.Errorf("incorrect value in map: %s", s)
}
// preserve custom path implementations
b := key.New(customPath("/foo/bar"))
if _, ok := b.Key().(customPath); !ok {
t.Errorf("customPath implementation not preserved: %T", b.Key())
}
}

View File

@ -0,0 +1,46 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package key
import (
"fmt"
)
// Pointer is a pointer to a path.
type Pointer interface {
Pointer() Path
}
// NewPointer creates a new pointer to a path.
func NewPointer(path Path) Pointer {
return pointer(path)
}
// This is the type returned by pointerKey.Key. Returning this is a
// lot faster than having pointerKey implement Pointer, since it is
// a compositeKey and thus would require reconstructing a Path from
// []interface{} any time the Pointer method is called.
type pointer Path
func (ptr pointer) Pointer() Path {
return Path(ptr)
}
func (ptr pointer) String() string {
return "{" + ptr.Pointer().String() + "}"
}
func (ptr pointer) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf(`{"_ptr":%q}`, ptr.Pointer().String())), nil
}
func (ptr pointer) Equal(other interface{}) bool {
o, ok := other.(Pointer)
return ok && pointerEqual(ptr, o)
}
func pointerEqual(a, b Pointer) bool {
return pathEqual(a.Pointer(), b.Pointer())
}

View File

@ -0,0 +1,189 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package key_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/path"
)
func TestPointer(t *testing.T) {
p := key.NewPointer(path.New("foo"))
if expected, actual := path.New("foo"), p.Pointer(); !path.Equal(expected, actual) {
t.Errorf("Expected %#v but got %#v", expected, actual)
}
if expected, actual := "{/foo}", fmt.Sprintf("%s", p); actual != expected {
t.Errorf("Expected %q but got %q", expected, actual)
}
if js, err := json.Marshal(p); err != nil {
t.Errorf("JSON marshaling failed: %s", err)
} else if expected, actual := `{"_ptr":"/foo"}`, string(js); actual != expected {
t.Errorf("Expected %q but got %q", expected, actual)
}
}
type pointer string
func (ptr pointer) Pointer() key.Path {
return path.FromString(string(ptr))
}
func (ptr pointer) ToBuiltin() interface{} {
panic("NOT IMPLEMENTED")
}
func (ptr pointer) String() string {
panic("NOT IMPLEMENTED")
}
func (ptr pointer) MarshalJSON() ([]byte, error) {
panic("NOT IMPLEMENTED")
}
func TestPointerEqual(t *testing.T) {
tests := []struct {
a key.Pointer
b key.Pointer
result bool
}{{
a: key.NewPointer(nil),
b: key.NewPointer(path.New()),
result: true,
}, {
a: key.NewPointer(path.New()),
b: key.NewPointer(nil),
result: true,
}, {
a: key.NewPointer(path.New("foo")),
b: key.NewPointer(path.New("foo")),
result: true,
}, {
a: key.NewPointer(path.New("foo")),
b: key.NewPointer(path.New("bar")),
result: false,
}, {
a: key.NewPointer(path.New(true)),
b: key.NewPointer(path.New("true")),
result: false,
}, {
a: key.NewPointer(path.New(int8(0))),
b: key.NewPointer(path.New(int16(0))),
result: false,
}, {
a: key.NewPointer(path.New(path.Wildcard, "bar")),
b: key.NewPointer(path.New("foo", path.Wildcard)),
result: false,
}, {
a: key.NewPointer(path.New(map[string]interface{}{"a": "x", "b": "y"})),
b: key.NewPointer(path.New(map[string]interface{}{"b": "y", "a": "x"})),
result: true,
}, {
a: key.NewPointer(path.New(map[string]interface{}{"a": "x", "b": "y"})),
b: key.NewPointer(path.New(map[string]interface{}{"x": "x", "y": "y"})),
result: false,
}, {
a: key.NewPointer(path.New("foo")),
b: pointer("/foo"),
result: true,
}, {
a: pointer("/foo"),
b: key.NewPointer(path.New("foo")),
result: true,
}, {
a: key.NewPointer(path.New("foo")),
b: pointer("/foo/bar"),
result: false,
}, {
a: pointer("/foo/bar"),
b: key.NewPointer(path.New("foo")),
result: false,
}}
for i, tcase := range tests {
if key.New(tcase.a).Equal(key.New(tcase.b)) != tcase.result {
t.Errorf("Error in pointer comparison for test %d", i)
}
}
}
func TestPointerAsKey(t *testing.T) {
a := key.NewPointer(path.New("foo", path.Wildcard, map[string]interface{}{
"bar": map[key.Key]interface{}{
// Should be able to embed pointer key.
key.New(key.NewPointer(path.New("baz"))):
// Should be able to embed pointer value.
key.NewPointer(path.New("baz")),
},
}))
m := map[key.Key]string{
key.New(a): "a",
}
if s, ok := m[key.New(a)]; !ok {
t.Error("pointer to path not keyed in map")
} else if s != "a" {
t.Errorf("pointer to path not mapped to correct value in map: %s", s)
}
// Ensure that we preserve custom pointer implementations.
b := key.New(pointer("/foo/bar"))
if _, ok := b.Key().(pointer); !ok {
t.Errorf("pointer implementation not preserved: %T", b.Key())
}
}
func BenchmarkPointer(b *testing.B) {
benchmarks := []key.Path{
path.New(),
path.New("foo"),
path.New("foo", "bar"),
path.New("foo", "bar", "baz"),
path.New("foo", "bar", "baz", "qux"),
}
for i, benchmark := range benchmarks {
b.Run(fmt.Sprintf("%d", i), func(b *testing.B) {
for i := 0; i < b.N; i++ {
key.NewPointer(benchmark)
}
})
}
}
func BenchmarkPointerAsKey(b *testing.B) {
benchmarks := []key.Pointer{
key.NewPointer(path.New()),
key.NewPointer(path.New("foo")),
key.NewPointer(path.New("foo", "bar")),
key.NewPointer(path.New("foo", "bar", "baz")),
key.NewPointer(path.New("foo", "bar", "baz", "qux")),
}
for i, benchmark := range benchmarks {
b.Run(fmt.Sprintf("%d", i), func(b *testing.B) {
for i := 0; i < b.N; i++ {
key.New(benchmark)
}
})
}
}
func BenchmarkEmbeddedPointerAsKey(b *testing.B) {
benchmarks := [][]interface{}{
[]interface{}{key.NewPointer(path.New())},
[]interface{}{key.NewPointer(path.New("foo"))},
[]interface{}{key.NewPointer(path.New("foo", "bar"))},
[]interface{}{key.NewPointer(path.New("foo", "bar", "baz"))},
[]interface{}{key.NewPointer(path.New("foo", "bar", "baz", "qux"))},
}
for i, benchmark := range benchmarks {
b.Run(fmt.Sprintf("%d", i), func(b *testing.B) {
for i := 0; i < b.N; i++ {
key.New(benchmark)
}
})
}
}

View File

@ -5,11 +5,13 @@
package key package key
import ( import (
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"math" "math"
"strconv" "strconv"
"strings" "strings"
"unicode/utf8"
"github.com/aristanetworks/goarista/value" "github.com/aristanetworks/goarista/value"
) )
@ -67,7 +69,16 @@ func StringifyInterface(key interface{}) (string, error) {
keys[i] = stringify(k) + "=" + stringify(m[k]) keys[i] = stringify(k) + "=" + stringify(m[k])
} }
str = strings.Join(keys, "_") str = strings.Join(keys, "_")
case []interface{}:
elements := make([]string, len(key))
for i, element := range key {
elements[i] = stringify(element)
}
str = strings.Join(elements, ",")
case Pointer:
return "{" + key.Pointer().String() + "}", nil
case Path:
return "[" + key.String() + "]", nil
case value.Value: case value.Value:
return key.String(), nil return key.String(), nil
@ -78,15 +89,14 @@ func StringifyInterface(key interface{}) (string, error) {
return str, nil return str, nil
} }
// escape checks if the string is a valid utf-8 string.
// If it is, it will return the string as is.
// If it is not, it will return the base64 representation of the byte array string
func escape(str string) string { func escape(str string) string {
for i := 0; i < len(str); i++ { if utf8.ValidString(str) {
if chr := str[i]; chr < 0x20 || chr > 0x7E {
str = strconv.QuoteToASCII(str)
str = str[1 : len(str)-1] // Drop the leading and trailing quotes.
break
}
}
return str return str
}
return base64.StdEncoding.EncodeToString([]byte(str))
} }
func stringify(key interface{}) string { func stringify(key interface{}) string {

View File

@ -27,9 +27,21 @@ func TestStringify(t *testing.T) {
input: "foobar", input: "foobar",
output: "foobar", output: "foobar",
}, { }, {
name: "non-ASCII string", name: "valid non-ASCII UTF-8 string",
input: "日本語", input: "日本語",
output: `\u65e5\u672c\u8a9e`, output: "日本語",
}, {
name: "invalid UTF-8 string 1",
input: string([]byte{0xef, 0xbf, 0xbe, 0xbe, 0xbe, 0xbe, 0xbe}),
output: "77++vr6+vg==",
}, {
name: "invalid UTF-8 string 2",
input: string([]byte{0xef, 0xbf, 0xbe, 0xbe, 0xbe, 0xbe, 0xbe, 0x23}),
output: "77++vr6+viM=",
}, {
name: "invalid UTF-8 string 3",
input: string([]byte{0xef, 0xbf, 0xbe, 0xbe, 0xbe, 0xbe, 0xbe, 0x23, 0x24}),
output: "77++vr6+viMk",
}, { }, {
name: "uint8", name: "uint8",
input: uint8(43), input: uint8(43),
@ -107,6 +119,22 @@ func TestStringify(t *testing.T) {
"n": nil, "n": nil,
}, },
output: "Unable to stringify nil", output: "Unable to stringify nil",
}, {
name: "[]interface{}",
input: []interface{}{
uint32(42),
true,
"foo",
map[Key]interface{}{
New("a"): "b",
New("b"): "c",
},
},
output: "42,true,foo,a=b_b=c",
}, {
name: "pointer",
input: NewPointer(Path{New("foo"), New("bar")}),
output: "{/foo/bar}",
}} }}
for _, tcase := range testcases { for _, tcase := range testcases {

View File

@ -8,8 +8,6 @@ package netns
import ( import (
"fmt" "fmt"
"os"
"runtime"
) )
const ( const (
@ -45,54 +43,3 @@ func setNsByName(nsName string) error {
} }
return nil return nil
} }
// Do takes a function which it will call in the network namespace specified by nsName.
// The goroutine that calls this will lock itself to its current OS thread, hop
// namespaces, call the given function, hop back to its original namespace, and then
// unlock itself from its current OS thread.
// Do returns an error if an error occurs at any point besides in the invocation of
// the given function, or if the given function itself returns an error.
//
// The callback function is expected to do something simple such as just
// creating a socket / opening a connection, and you should not kick off any
// complex logic from the callback or call any complicated code or create any
// new goroutine from the callback. The callback should not panic or use defer.
// The behavior of this function is undefined if the callback doesn't conform
// these demands.
//go:nosplit
func Do(nsName string, cb Callback) error {
// If destNS is empty, the function is called in the caller's namespace
if nsName == "" {
return cb()
}
// Get the file descriptor to the current namespace
currNsFd, err := getNs(selfNsFile)
if os.IsNotExist(err) {
return fmt.Errorf("File descriptor to current namespace does not exist: %s", err)
} else if err != nil {
return fmt.Errorf("Failed to open %s: %s", selfNsFile, err)
}
runtime.LockOSThread()
// Jump to the new network namespace
if err := setNsByName(nsName); err != nil {
runtime.UnlockOSThread()
currNsFd.close()
return fmt.Errorf("Failed to set the namespace to %s: %s", nsName, err)
}
// Call the given function
cbErr := cb()
// Come back to the original namespace
if err = setNs(currNsFd); err != nil {
cbErr = fmt.Errorf("Failed to return to the original namespace: %s (callback returned %v)",
err, cbErr)
}
runtime.UnlockOSThread()
currNsFd.close()
return cbErr
}

View File

@ -0,0 +1,60 @@
// Copyright (c) 2016 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
// +build go1.10
package netns
import (
"fmt"
"os"
"runtime"
)
// Do takes a function which it will call in the network namespace specified by nsName.
// The goroutine that calls this will lock itself to its current OS thread, hop
// namespaces, call the given function, hop back to its original namespace, and then
// unlock itself from its current OS thread.
// Do returns an error if an error occurs at any point besides in the invocation of
// the given function, or if the given function itself returns an error.
//
// The callback function is expected to do something simple such as just
// creating a socket / opening a connection, as it's not desirable to start
// complex logic in a goroutine that is pinned to the current OS thread.
// Also any goroutine started from the callback function may or may not
// execute in the desired namespace.
func Do(nsName string, cb Callback) error {
// If destNS is empty, the function is called in the caller's namespace
if nsName == "" {
return cb()
}
// Get the file descriptor to the current namespace
currNsFd, err := getNs(selfNsFile)
if os.IsNotExist(err) {
return fmt.Errorf("File descriptor to current namespace does not exist: %s", err)
} else if err != nil {
return fmt.Errorf("Failed to open %s: %s", selfNsFile, err)
}
defer currNsFd.close()
runtime.LockOSThread()
defer runtime.UnlockOSThread()
// Jump to the new network namespace
if err := setNsByName(nsName); err != nil {
return fmt.Errorf("Failed to set the namespace to %s: %s", nsName, err)
}
// Call the given function
cbErr := cb()
// Come back to the original namespace
if err = setNs(currNsFd); err != nil {
return fmt.Errorf("Failed to return to the original namespace: %s (callback returned %v)",
err, cbErr)
}
return cbErr
}

View File

@ -0,0 +1,64 @@
// Copyright (c) 2016 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
// +build !go1.10
package netns
import (
"fmt"
"os"
"runtime"
)
// Do takes a function which it will call in the network namespace specified by nsName.
// The goroutine that calls this will lock itself to its current OS thread, hop
// namespaces, call the given function, hop back to its original namespace, and then
// unlock itself from its current OS thread.
// Do returns an error if an error occurs at any point besides in the invocation of
// the given function, or if the given function itself returns an error.
//
// The callback function is expected to do something simple such as just
// creating a socket / opening a connection, and you should not kick off any
// complex logic from the callback or call any complicated code or create any
// new goroutine from the callback. The callback should not panic or use defer.
// The behavior of this function is undefined if the callback doesn't conform
// these demands.
//go:nosplit
func Do(nsName string, cb Callback) error {
// If destNS is empty, the function is called in the caller's namespace
if nsName == "" {
return cb()
}
// Get the file descriptor to the current namespace
currNsFd, err := getNs(selfNsFile)
if os.IsNotExist(err) {
return fmt.Errorf("File descriptor to current namespace does not exist: %s", err)
} else if err != nil {
return fmt.Errorf("Failed to open %s: %s", selfNsFile, err)
}
runtime.LockOSThread()
// Jump to the new network namespace
if err := setNsByName(nsName); err != nil {
runtime.UnlockOSThread()
currNsFd.close()
return fmt.Errorf("Failed to set the namespace to %s: %s", nsName, err)
}
// Call the given function
cbErr := cb()
// Come back to the original namespace
if err = setNs(currNsFd); err != nil {
cbErr = fmt.Errorf("Failed to return to the original namespace: %s (callback returned %v)",
err, cbErr)
}
runtime.UnlockOSThread()
currNsFd.close()
return cbErr
}

View File

@ -18,7 +18,7 @@ import (
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
const defaultPort = "6042" const defaultPort = "6030"
// PublishFunc is the method to publish responses // PublishFunc is the method to publish responses
type PublishFunc func(addr string, message proto.Message) type PublishFunc func(addr string, message proto.Message)
@ -36,6 +36,9 @@ func New(username, password, addr string, opts []grpc.DialOption) *Client {
if !strings.ContainsRune(addr, ':') { if !strings.ContainsRune(addr, ':') {
addr += ":" + defaultPort addr += ":" + defaultPort
} }
// Make sure we don't move past the grpc.Dial() call until we actually
// established an HTTP/2 connection successfully.
opts = append(opts, grpc.WithBlock(), grpc.WithWaitForHandshake())
conn, err := grpc.Dial(addr, opts...) conn, err := grpc.Dial(addr, opts...)
if err != nil { if err != nil {
glog.Fatalf("Failed to dial: %s", err) glog.Fatalf("Failed to dial: %s", err)
@ -93,7 +96,7 @@ func (c *Client) Subscribe(wg *sync.WaitGroup, subscriptions []string,
Request: &openconfig.SubscribeRequest_Subscribe{ Request: &openconfig.SubscribeRequest_Subscribe{
Subscribe: &openconfig.SubscriptionList{ Subscribe: &openconfig.SubscriptionList{
Subscription: []*openconfig.Subscription{ Subscription: []*openconfig.Subscription{
&openconfig.Subscription{ {
Path: &openconfig.Path{Element: strings.Split(path, "/")}, Path: &openconfig.Path{Element: strings.Split(path, "/")},
}, },
}, },

View File

@ -23,7 +23,7 @@ func ParseFlags() (username string, password string, subscriptions, addrs []stri
opts []grpc.DialOption) { opts []grpc.DialOption) {
var ( var (
addrsFlag = flag.String("addrs", "localhost:6042", addrsFlag = flag.String("addrs", "localhost:6030",
"Comma-separated list of addresses of OpenConfig gRPC servers") "Comma-separated list of addresses of OpenConfig gRPC servers")
caFileFlag = flag.String("cafile", "", caFileFlag = flag.String("cafile", "",

View File

@ -80,11 +80,11 @@ func TestNotificationToMap(t *testing.T) {
}, },
}, },
Delete: []*openconfig.Path{ Delete: []*openconfig.Path{
&openconfig.Path{ {
Element: []string{ Element: []string{
"route", "237.255.255.250_0.0.0.0", "route", "237.255.255.250_0.0.0.0",
}}, }},
&openconfig.Path{ {
Element: []string{ Element: []string{
"route", "238.255.255.250_0.0.0.0", "route", "238.255.255.250_0.0.0.0",
}, },

View File

@ -10,221 +10,288 @@ import (
"sort" "sort"
"github.com/aristanetworks/goarista/key" "github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/pathmap"
) )
// Map associates Paths to values. It allows wildcards. The // Map associates paths to values. It allows wildcards. A Map
// primary use of Map is to be able to register handlers to paths // is primarily used to register handlers with paths that can
// that can be efficiently looked up every time a path is updated. // be easily looked up each time a path is updated.
// type Map struct {
// For example:
//
// m.Set({key.New("interfaces"), key.New("*"), key.New("adminStatus")}, AdminStatusHandler)
// m.Set({key.New("interface"), key.New("Management1"), key.New("adminStatus")},
// Management1AdminStatusHandler)
//
// m.Visit(Path{key.New("interfaces"), key.New("Ethernet3/32/1"), key.New("adminStatus")},
// HandlerExecutor)
// >> AdminStatusHandler gets passed to HandlerExecutor
// m.Visit(Path{key.New("interfaces"), key.New("Management1"), key.New("adminStatus")},
// HandlerExecutor)
// >> AdminStatusHandler and Management1AdminStatusHandler gets passed to HandlerExecutor
//
// Note, Visit performance is typically linearly with the length of
// the path. But, it can be as bad as O(2^len(Path)) when TreeMap
// nodes have children and a wildcard associated with it. For example,
// if these paths were registered:
//
// m.Set(Path{key.New("foo"), key.New("bar"), key.New("baz")}, 1)
// m.Set(Path{key.New("*"), key.New("bar"), key.New("baz")}, 2)
// m.Set(Path{key.New("*"), key.New("*"), key.New("baz")}, 3)
// m.Set(Path{key.New("*"), key.New("*"), key.New("*")}, 4)
// m.Set(Path{key.New("foo"), key.New("*"), key.New("*")}, 5)
// m.Set(Path{key.New("foo"), key.New("bar"), key.New("*")}, 6)
// m.Set(Path{key.New("foo"), key.New("*"), key.New("baz")}, 7)
// m.Set(Path{key.New("*"), key.New("bar"), key.New("*")}, 8)
//
// m.Visit(Path{key.New("foo"),key.New("bar"),key.New("baz")}, Foo) // 2^3 nodes traversed
//
// This shouldn't be a concern with our paths because it is likely
// that a TreeMap node will either have a wildcard or children, not
// both. A TreeMap node that corresponds to a collection will often be a
// wildcard, otherwise it will have specific children.
type Map interface {
// Visit calls f for every registration in the Map that
// matches path. For example,
//
// m.Set(Path{key.New("foo"), key.New("bar")}, 1)
// m.Set(Path{key.New("*"), key.New("bar")}, 2)
//
// m.Visit(Path{key.New("foo"), key.New("bar")}, Printer)
// >> Calls Printer(1) and Printer(2)
Visit(p Path, f pathmap.VisitorFunc) error
// VisitPrefix calls f for every registration in the Map that
// is a prefix of path. For example,
//
// m.Set(Path{}, 0)
// m.Set(Path{key.New("foo")}, 1)
// m.Set(Path{key.New("foo"), key.New("bar")}, 2)
// m.Set(Path{key.New("foo"), key.New("quux")}, 3)
// m.Set(Path{key.New("*"), key.New("bar")}, 4)
//
// m.VisitPrefix(Path{key.New("foo"), key.New("bar"), key.New("baz")}, Printer)
// >> Calls Printer on values 0, 1, 2, and 4
VisitPrefix(p Path, f pathmap.VisitorFunc) error
// Get returns the mapping for path. This returns the exact
// mapping for path. For example, if you register two paths
//
// m.Set(Path{key.New("foo"), key.New("bar")}, 1)
// m.Set(Path{key.New("*"), key.New("bar")}, 2)
//
// m.Get(Path{key.New("foo"), key.New("bar")}) => 1
// m.Get(Path{key.New("*"), key.New("bar")}) => 2
Get(p Path) interface{}
// Set a mapping of path to value. Path may contain wildcards. Set
// replaces what was there before.
Set(p Path, v interface{})
// Delete removes the mapping for path
Delete(p Path) bool
}
// Wildcard is a special key representing any possible path
var Wildcard key.Key = key.New("*")
type node struct {
val interface{} val interface{}
wildcard *node ok bool
children map[key.Key]*node wildcard *Map
children map[key.Key]*Map
} }
// NewMap creates a new Map // VisitorFunc is a function that handles the value associated
func NewMap() Map { // with a path in a Map. Note that only the value is passed in
return &node{} // as an argument since the path can be stored inside the value
} // if needed.
type VisitorFunc func(v interface{}) error
// Visit calls f for every matching registration in the Map // Visit calls a function fn for every value in the Map
func (n *node) Visit(p Path, f pathmap.VisitorFunc) error { // that is registered with a match of a path p. In the
// general case, time complexity is linear with respect
// to the length of p but it can be as bad as O(2^len(p))
// if there are a lot of paths with wildcards registered.
//
// Example:
//
// a := path.New("foo", "bar", "baz")
// b := path.New("foo", path.Wildcard, "baz")
// c := path.New(path.Wildcard, "bar", "baz")
// d := path.New("foo", "bar", path.Wildcard)
// e := path.New(path.Wildcard, path.Wildcard, "baz")
// f := path.New(path.Wildcard, "bar", path.Wildcard)
// g := path.New("foo", path.Wildcard, path.Wildcard)
// h := path.New(path.Wildcard, path.Wildcard, path.Wildcard)
//
// m.Set(a, 1)
// m.Set(b, 2)
// m.Set(c, 3)
// m.Set(d, 4)
// m.Set(e, 5)
// m.Set(f, 6)
// m.Set(g, 7)
// m.Set(h, 8)
//
// p := path.New("foo", "bar", "baz")
//
// m.Visit(p, fn)
//
// Result: fn(1), fn(2), fn(3), fn(4), fn(5), fn(6), fn(7) and fn(8)
func (m *Map) Visit(p key.Path, fn VisitorFunc) error {
for i, element := range p { for i, element := range p {
if n.wildcard != nil { if m.wildcard != nil {
if err := n.wildcard.Visit(p[i+1:], f); err != nil { if err := m.wildcard.Visit(p[i+1:], fn); err != nil {
return err return err
} }
} }
next, ok := n.children[element] next, ok := m.children[element]
if !ok { if !ok {
return nil return nil
} }
n = next m = next
} }
if n.val == nil { if !m.ok {
return nil return nil
} }
return f(n.val) return fn(m.val)
} }
// VisitPrefix calls f for every registered path that is a prefix of // VisitPrefixes calls a function fn for every value in the
// the path // Map that is registered with a prefix of a path p.
func (n *node) VisitPrefix(p Path, f pathmap.VisitorFunc) error { //
// Example:
//
// a := path.New()
// b := path.New("foo")
// c := path.New("foo", "bar")
// d := path.New("foo", "baz")
// e := path.New(path.Wildcard, "bar")
//
// m.Set(a, 1)
// m.Set(b, 2)
// m.Set(c, 3)
// m.Set(d, 4)
// m.Set(e, 5)
//
// p := path.New("foo", "bar", "baz")
//
// m.VisitPrefixes(p, fn)
//
// Result: fn(1), fn(2), fn(3), fn(5)
func (m *Map) VisitPrefixes(p key.Path, fn VisitorFunc) error {
for i, element := range p { for i, element := range p {
// Call f on each node we visit if m.ok {
if n.val != nil { if err := fn(m.val); err != nil {
if err := f(n.val); err != nil {
return err return err
} }
} }
if n.wildcard != nil { if m.wildcard != nil {
if err := n.wildcard.VisitPrefix(p[i+1:], f); err != nil { if err := m.wildcard.VisitPrefixes(p[i+1:], fn); err != nil {
return err return err
} }
} }
next, ok := n.children[element] next, ok := m.children[element]
if !ok { if !ok {
return nil return nil
} }
n = next m = next
} }
if n.val == nil { if !m.ok {
return nil return nil
} }
// Call f on the final node return fn(m.val)
return f(n.val)
} }
// Get returns the mapping for path // VisitPrefixed calls fn for every value in the map that is
func (n *node) Get(p Path) interface{} { // registerd with a path that is prefixed by p. This method
// can be used to visit every registered path if p is the
// empty path (or root path) which prefixes all paths.
//
// Example:
//
// a := path.New("foo")
// b := path.New("foo", "bar")
// c := path.New("foo", "bar", "baz")
// d := path.New("foo", path.Wildcard)
//
// m.Set(a, 1)
// m.Set(b, 2)
// m.Set(c, 3)
// m.Set(d, 4)
//
// p := path.New("foo", "bar")
//
// m.VisitPrefixed(p, fn)
//
// Result: fn(2), fn(3), fn(4)
func (m *Map) VisitPrefixed(p key.Path, fn VisitorFunc) error {
for i, element := range p {
if m.wildcard != nil {
if err := m.wildcard.VisitPrefixed(p[i+1:], fn); err != nil {
return err
}
}
next, ok := m.children[element]
if !ok {
return nil
}
m = next
}
return m.visitSubtree(fn)
}
func (m *Map) visitSubtree(fn VisitorFunc) error {
if m.ok {
if err := fn(m.val); err != nil {
return err
}
}
if m.wildcard != nil {
if err := m.wildcard.visitSubtree(fn); err != nil {
return err
}
}
for _, next := range m.children {
if err := next.visitSubtree(fn); err != nil {
return err
}
}
return nil
}
// Get returns the value registered with an exact match of a
// path p. If there is no exact match for p, Get returns nil
// and false. If p has an exact match and it is set to true,
// Get returns nil and true.
//
// Example:
//
// m.Set(path.New("foo", "bar"), 1)
// m.Set(path.New("baz", "qux"), nil)
//
// a := m.Get(path.New("foo", "bar"))
// b := m.Get(path.New("foo", path.Wildcard))
// c, ok := m.Get(path.New("baz", "qux"))
//
// Result: a == 1, b == nil, c == nil and ok == true
func (m *Map) Get(p key.Path) (interface{}, bool) {
for _, element := range p { for _, element := range p {
if element.Equal(Wildcard) { if element.Equal(Wildcard) {
if n.wildcard == nil { if m.wildcard == nil {
return nil return nil, false
} }
n = n.wildcard m = m.wildcard
continue continue
} }
next, ok := n.children[element] next, ok := m.children[element]
if !ok { if !ok {
return nil return nil, false
} }
n = next m = next
} }
return n.val return m.val, m.ok
} }
// Set a mapping of path to value. Path may contain wildcards. Set // Set registers a path p with a value. If the path was already
// replaces what was there before. // registered with a value it returns true and false otherwise.
func (n *node) Set(p Path, v interface{}) { //
// Example:
//
// p := path.New("foo", "bar")
//
// a := m.Set(p, 0)
// b := m.Set(p, 1)
//
// v := m.Get(p)
//
// Result: a == false, b == true and v == 1
func (m *Map) Set(p key.Path, v interface{}) bool {
for _, element := range p { for _, element := range p {
if element.Equal(Wildcard) { if element.Equal(Wildcard) {
if n.wildcard == nil { if m.wildcard == nil {
n.wildcard = &node{} m.wildcard = &Map{}
} }
n = n.wildcard m = m.wildcard
continue continue
} }
if n.children == nil { if m.children == nil {
n.children = map[key.Key]*node{} m.children = map[key.Key]*Map{}
} }
next, ok := n.children[element] next, ok := m.children[element]
if !ok { if !ok {
next = &node{} next = &Map{}
n.children[element] = next m.children[element] = next
} }
n = next m = next
} }
n.val = v set := !m.ok
m.val, m.ok = v, true
return set
} }
// Delete removes the mapping for path // Delete unregisters the value registered with a path. It
func (n *node) Delete(p Path) bool { // returns true if a value was deleted and false otherwise.
nodes := make([]*node, len(p)+1) //
// Example:
//
// p := path.New("foo", "bar")
//
// m.Set(p, 0)
//
// a := m.Delete(p)
// b := m.Delete(p)
//
// Result: a == true and b == false
func (m *Map) Delete(p key.Path) bool {
maps := make([]*Map, len(p)+1)
for i, element := range p { for i, element := range p {
nodes[i] = n maps[i] = m
if element.Equal(Wildcard) { if element.Equal(Wildcard) {
if n.wildcard == nil { if m.wildcard == nil {
return false return false
} }
n = n.wildcard m = m.wildcard
continue continue
} }
next, ok := n.children[element] next, ok := m.children[element]
if !ok { if !ok {
return false return false
} }
n = next m = next
} }
n.val = nil deleted := m.ok
nodes[len(p)] = n m.val, m.ok = nil, false
maps[len(p)] = m
// See if we can delete any node objects // Remove any empty maps.
for i := len(p); i > 0; i-- { for i := len(p); i > 0; i-- {
n = nodes[i] m = maps[i]
if n.val != nil || n.wildcard != nil || len(n.children) > 0 { if m.ok || m.wildcard != nil || len(m.children) > 0 {
break break
} }
parent := nodes[i-1] parent := maps[i-1]
element := p[i-1] element := p[i-1]
if element.Equal(Wildcard) { if element.Equal(Wildcard) {
parent.wildcard = nil parent.wildcard = nil
@ -232,28 +299,28 @@ func (n *node) Delete(p Path) bool {
delete(parent.children, element) delete(parent.children, element)
} }
} }
return true return deleted
} }
func (n *node) String() string { func (m *Map) String() string {
var b bytes.Buffer var b bytes.Buffer
n.write(&b, "") m.write(&b, "")
return b.String() return b.String()
} }
func (n *node) write(b *bytes.Buffer, indent string) { func (m *Map) write(b *bytes.Buffer, indent string) {
if n.val != nil { if m.ok {
b.WriteString(indent) b.WriteString(indent)
fmt.Fprintf(b, "Val: %v", n.val) fmt.Fprintf(b, "Val: %v", m.val)
b.WriteString("\n") b.WriteString("\n")
} }
if n.wildcard != nil { if m.wildcard != nil {
b.WriteString(indent) b.WriteString(indent)
fmt.Fprintf(b, "Child %q:\n", Wildcard) fmt.Fprintf(b, "Child %q:\n", Wildcard)
n.wildcard.write(b, indent+" ") m.wildcard.write(b, indent+" ")
} }
children := make([]key.Key, 0, len(n.children)) children := make([]key.Key, 0, len(m.children))
for key := range n.children { for key := range m.children {
children = append(children, key) children = append(children, key)
} }
sort.Slice(children, func(i, j int) bool { sort.Slice(children, func(i, j int) bool {
@ -261,7 +328,7 @@ func (n *node) write(b *bytes.Buffer, indent string) {
}) })
for _, key := range children { for _, key := range children {
child := n.children[key] child := m.children[key]
b.WriteString(indent) b.WriteString(indent)
fmt.Fprintf(b, "Child %q:\n", key.String()) fmt.Fprintf(b, "Child %q:\n", key.String())
child.write(b, indent+" ") child.write(b, indent+" ")

View File

@ -10,68 +10,76 @@ import (
"testing" "testing"
"github.com/aristanetworks/goarista/key" "github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/pathmap"
"github.com/aristanetworks/goarista/test" "github.com/aristanetworks/goarista/test"
) )
func accumulator(counter map[int]int) pathmap.VisitorFunc { func accumulator(counter map[int]int) VisitorFunc {
return func(val interface{}) error { return func(val interface{}) error {
counter[val.(int)]++ counter[val.(int)]++
return nil return nil
} }
} }
func TestVisit(t *testing.T) { func TestMapSet(t *testing.T) {
m := NewMap() m := Map{}
m.Set(Path{key.New("foo"), key.New("bar"), key.New("baz")}, 1) a := m.Set(key.Path{key.New("foo")}, 0)
m.Set(Path{key.New("*"), key.New("bar"), key.New("baz")}, 2) b := m.Set(key.Path{key.New("foo")}, 1)
m.Set(Path{key.New("*"), key.New("*"), key.New("baz")}, 3) if !a || b {
m.Set(Path{key.New("*"), key.New("*"), key.New("*")}, 4) t.Fatal("Map.Set not working properly")
m.Set(Path{key.New("foo"), key.New("*"), key.New("*")}, 5) }
m.Set(Path{key.New("foo"), key.New("bar"), key.New("*")}, 6) }
m.Set(Path{key.New("foo"), key.New("*"), key.New("baz")}, 7)
m.Set(Path{key.New("*"), key.New("bar"), key.New("*")}, 8)
m.Set(Path{}, 10) func TestMapVisit(t *testing.T) {
m := Map{}
m.Set(key.Path{key.New("foo"), key.New("bar"), key.New("baz")}, 1)
m.Set(key.Path{Wildcard, key.New("bar"), key.New("baz")}, 2)
m.Set(key.Path{Wildcard, Wildcard, key.New("baz")}, 3)
m.Set(key.Path{Wildcard, Wildcard, Wildcard}, 4)
m.Set(key.Path{key.New("foo"), Wildcard, Wildcard}, 5)
m.Set(key.Path{key.New("foo"), key.New("bar"), Wildcard}, 6)
m.Set(key.Path{key.New("foo"), Wildcard, key.New("baz")}, 7)
m.Set(key.Path{Wildcard, key.New("bar"), Wildcard}, 8)
m.Set(Path{key.New("*")}, 20) m.Set(key.Path{}, 10)
m.Set(Path{key.New("foo")}, 21)
m.Set(Path{key.New("zap"), key.New("zip")}, 30) m.Set(key.Path{Wildcard}, 20)
m.Set(Path{key.New("zap"), key.New("zip")}, 31) m.Set(key.Path{key.New("foo")}, 21)
m.Set(Path{key.New("zip"), key.New("*")}, 40) m.Set(key.Path{key.New("zap"), key.New("zip")}, 30)
m.Set(Path{key.New("zip"), key.New("*")}, 41) m.Set(key.Path{key.New("zap"), key.New("zip")}, 31)
m.Set(key.Path{key.New("zip"), Wildcard}, 40)
m.Set(key.Path{key.New("zip"), Wildcard}, 41)
testCases := []struct { testCases := []struct {
path Path path key.Path
expected map[int]int expected map[int]int
}{{ }{{
path: Path{key.New("foo"), key.New("bar"), key.New("baz")}, path: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
expected: map[int]int{1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1}, expected: map[int]int{1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1},
}, { }, {
path: Path{key.New("qux"), key.New("bar"), key.New("baz")}, path: key.Path{key.New("qux"), key.New("bar"), key.New("baz")},
expected: map[int]int{2: 1, 3: 1, 4: 1, 8: 1}, expected: map[int]int{2: 1, 3: 1, 4: 1, 8: 1},
}, { }, {
path: Path{key.New("foo"), key.New("qux"), key.New("baz")}, path: key.Path{key.New("foo"), key.New("qux"), key.New("baz")},
expected: map[int]int{3: 1, 4: 1, 5: 1, 7: 1}, expected: map[int]int{3: 1, 4: 1, 5: 1, 7: 1},
}, { }, {
path: Path{key.New("foo"), key.New("bar"), key.New("qux")}, path: key.Path{key.New("foo"), key.New("bar"), key.New("qux")},
expected: map[int]int{4: 1, 5: 1, 6: 1, 8: 1}, expected: map[int]int{4: 1, 5: 1, 6: 1, 8: 1},
}, { }, {
path: Path{}, path: key.Path{},
expected: map[int]int{10: 1}, expected: map[int]int{10: 1},
}, { }, {
path: Path{key.New("foo")}, path: key.Path{key.New("foo")},
expected: map[int]int{20: 1, 21: 1}, expected: map[int]int{20: 1, 21: 1},
}, { }, {
path: Path{key.New("foo"), key.New("bar")}, path: key.Path{key.New("foo"), key.New("bar")},
expected: map[int]int{}, expected: map[int]int{},
}, { }, {
path: Path{key.New("zap"), key.New("zip")}, path: key.Path{key.New("zap"), key.New("zip")},
expected: map[int]int{31: 1}, expected: map[int]int{31: 1},
}, { }, {
path: Path{key.New("zip"), key.New("zap")}, path: key.Path{key.New("zip"), key.New("zap")},
expected: map[int]int{41: 1}, expected: map[int]int{41: 1},
}} }}
@ -84,135 +92,160 @@ func TestVisit(t *testing.T) {
} }
} }
func TestVisitError(t *testing.T) { func TestMapVisitError(t *testing.T) {
m := NewMap() m := Map{}
m.Set(Path{key.New("foo"), key.New("bar")}, 1) m.Set(key.Path{key.New("foo"), key.New("bar")}, 1)
m.Set(Path{key.New("*"), key.New("bar")}, 2) m.Set(key.Path{Wildcard, key.New("bar")}, 2)
errTest := errors.New("Test") errTest := errors.New("Test")
err := m.Visit(Path{key.New("foo"), key.New("bar")}, err := m.Visit(key.Path{key.New("foo"), key.New("bar")},
func(v interface{}) error { return errTest }) func(v interface{}) error { return errTest })
if err != errTest { if err != errTest {
t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err) t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err)
} }
err = m.VisitPrefix(Path{key.New("foo"), key.New("bar"), key.New("baz")}, err = m.VisitPrefixes(key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
func(v interface{}) error { return errTest }) func(v interface{}) error { return errTest })
if err != errTest { if err != errTest {
t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err) t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err)
} }
} }
func TestGet(t *testing.T) { func TestMapGet(t *testing.T) {
m := NewMap() m := Map{}
m.Set(Path{}, 0) m.Set(key.Path{}, 0)
m.Set(Path{key.New("foo"), key.New("bar")}, 1) m.Set(key.Path{key.New("foo"), key.New("bar")}, 1)
m.Set(Path{key.New("foo"), key.New("*")}, 2) m.Set(key.Path{key.New("foo"), Wildcard}, 2)
m.Set(Path{key.New("*"), key.New("bar")}, 3) m.Set(key.Path{Wildcard, key.New("bar")}, 3)
m.Set(Path{key.New("zap"), key.New("zip")}, 4) m.Set(key.Path{key.New("zap"), key.New("zip")}, 4)
m.Set(key.Path{key.New("baz"), key.New("qux")}, nil)
testCases := []struct { testCases := []struct {
path Path path key.Path
expected interface{} v interface{}
ok bool
}{{ }{{
path: Path{}, path: key.Path{},
expected: 0, v: 0,
ok: true,
}, { }, {
path: Path{key.New("foo"), key.New("bar")}, path: key.Path{key.New("foo"), key.New("bar")},
expected: 1, v: 1,
ok: true,
}, { }, {
path: Path{key.New("foo"), key.New("*")}, path: key.Path{key.New("foo"), Wildcard},
expected: 2, v: 2,
ok: true,
}, { }, {
path: Path{key.New("*"), key.New("bar")}, path: key.Path{Wildcard, key.New("bar")},
expected: 3, v: 3,
ok: true,
}, { }, {
path: Path{key.New("bar"), key.New("foo")}, path: key.Path{key.New("baz"), key.New("qux")},
expected: nil, v: nil,
ok: true,
}, { }, {
path: Path{key.New("zap"), key.New("*")}, path: key.Path{key.New("bar"), key.New("foo")},
expected: nil, v: nil,
}, {
path: key.Path{key.New("zap"), Wildcard},
v: nil,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
got := m.Get(tc.path) v, ok := m.Get(tc.path)
if got != tc.expected { if v != tc.v || ok != tc.ok {
t.Errorf("Test case %v: Expected %v, Got %v", t.Errorf("Test case %v: Expected (v: %v, ok: %t), Got (v: %v, ok: %t)",
tc.path, tc.expected, got) tc.path, tc.v, tc.ok, v, ok)
} }
} }
} }
func countNodes(n *node) int { func countNodes(m *Map) int {
if n == nil { if m == nil {
return 0 return 0
} }
count := 1 count := 1
count += countNodes(n.wildcard) count += countNodes(m.wildcard)
for _, child := range n.children { for _, child := range m.children {
count += countNodes(child) count += countNodes(child)
} }
return count return count
} }
func TestDelete(t *testing.T) { func TestMapDelete(t *testing.T) {
m := NewMap() m := Map{}
m.Set(Path{}, 0) m.Set(key.Path{}, 0)
m.Set(Path{key.New("*")}, 1) m.Set(key.Path{Wildcard}, 1)
m.Set(Path{key.New("foo"), key.New("bar")}, 2) m.Set(key.Path{key.New("foo"), key.New("bar")}, 2)
m.Set(Path{key.New("foo"), key.New("*")}, 3) m.Set(key.Path{key.New("foo"), Wildcard}, 3)
m.Set(key.Path{key.New("foo")}, 4)
n := countNodes(m.(*node)) n := countNodes(&m)
if n != 5 { if n != 5 {
t.Errorf("Initial count wrong. Expected: 5, Got: %d", n) t.Errorf("Initial count wrong. Expected: 5, Got: %d", n)
} }
testCases := []struct { testCases := []struct {
del Path // Path to delete del key.Path // key.Path to delete
expected bool // expected return value of Delete expected bool // expected return value of Delete
visit Path // Path to Visit visit key.Path // key.Path to Visit
before map[int]int // Expected to find items before deletion before map[int]int // Expected to find items before deletion
after map[int]int // Expected to find items after deletion after map[int]int // Expected to find items after deletion
count int // Count of nodes count int // Count of nodes
}{{ }{{
del: Path{key.New("zap")}, // A no-op Delete del: key.Path{key.New("zap")}, // A no-op Delete
expected: false, expected: false,
visit: Path{key.New("foo"), key.New("bar")}, visit: key.Path{key.New("foo"), key.New("bar")},
before: map[int]int{2: 1, 3: 1}, before: map[int]int{2: 1, 3: 1},
after: map[int]int{2: 1, 3: 1}, after: map[int]int{2: 1, 3: 1},
count: 5, count: 5,
}, { }, {
del: Path{key.New("foo"), key.New("bar")}, del: key.Path{key.New("foo"), key.New("bar")},
expected: true, expected: true,
visit: Path{key.New("foo"), key.New("bar")}, visit: key.Path{key.New("foo"), key.New("bar")},
before: map[int]int{2: 1, 3: 1}, before: map[int]int{2: 1, 3: 1},
after: map[int]int{3: 1}, after: map[int]int{3: 1},
count: 4, count: 4,
}, { }, {
del: Path{key.New("*")}, del: key.Path{key.New("foo")},
expected: true, expected: true,
visit: Path{key.New("foo")}, visit: key.Path{key.New("foo")},
before: map[int]int{1: 1, 4: 1},
after: map[int]int{1: 1},
count: 4,
}, {
del: key.Path{key.New("foo")},
expected: false,
visit: key.Path{key.New("foo")},
before: map[int]int{1: 1},
after: map[int]int{1: 1},
count: 4,
}, {
del: key.Path{Wildcard},
expected: true,
visit: key.Path{key.New("foo")},
before: map[int]int{1: 1}, before: map[int]int{1: 1},
after: map[int]int{}, after: map[int]int{},
count: 3, count: 3,
}, { }, {
del: Path{key.New("*")}, del: key.Path{Wildcard},
expected: false, expected: false,
visit: Path{key.New("foo")}, visit: key.Path{key.New("foo")},
before: map[int]int{}, before: map[int]int{},
after: map[int]int{}, after: map[int]int{},
count: 3, count: 3,
}, { }, {
del: Path{key.New("foo"), key.New("*")}, del: key.Path{key.New("foo"), Wildcard},
expected: true, expected: true,
visit: Path{key.New("foo"), key.New("bar")}, visit: key.Path{key.New("foo"), key.New("bar")},
before: map[int]int{3: 1}, before: map[int]int{3: 1},
after: map[int]int{}, after: map[int]int{},
count: 1, // Should have deleted "foo" and "bar" nodes count: 1, // Should have deleted "foo" and "bar" nodes
}, { }, {
del: Path{}, del: key.Path{},
expected: true, expected: true,
visit: Path{}, visit: key.Path{},
before: map[int]int{0: 1}, before: map[int]int{0: 1},
after: map[int]int{}, after: map[int]int{},
count: 1, // Root node can't be deleted count: 1, // Root node can't be deleted
@ -238,53 +271,102 @@ func TestDelete(t *testing.T) {
} }
} }
func TestVisitPrefix(t *testing.T) { func TestMapVisitPrefixes(t *testing.T) {
m := NewMap() m := Map{}
m.Set(Path{}, 0) m.Set(key.Path{}, 0)
m.Set(Path{key.New("foo")}, 1) m.Set(key.Path{key.New("foo")}, 1)
m.Set(Path{key.New("foo"), key.New("bar")}, 2) m.Set(key.Path{key.New("foo"), key.New("bar")}, 2)
m.Set(Path{key.New("foo"), key.New("bar"), key.New("baz")}, 3) m.Set(key.Path{key.New("foo"), key.New("bar"), key.New("baz")}, 3)
m.Set(Path{key.New("foo"), key.New("bar"), key.New("baz"), key.New("quux")}, 4) m.Set(key.Path{key.New("foo"), key.New("bar"), key.New("baz"), key.New("quux")}, 4)
m.Set(Path{key.New("quux"), key.New("bar")}, 5) m.Set(key.Path{key.New("quux"), key.New("bar")}, 5)
m.Set(Path{key.New("foo"), key.New("quux")}, 6) m.Set(key.Path{key.New("foo"), key.New("quux")}, 6)
m.Set(Path{key.New("*")}, 7) m.Set(key.Path{Wildcard}, 7)
m.Set(Path{key.New("foo"), key.New("*")}, 8) m.Set(key.Path{key.New("foo"), Wildcard}, 8)
m.Set(Path{key.New("*"), key.New("bar")}, 9) m.Set(key.Path{Wildcard, key.New("bar")}, 9)
m.Set(Path{key.New("*"), key.New("quux")}, 10) m.Set(key.Path{Wildcard, key.New("quux")}, 10)
m.Set(Path{key.New("quux"), key.New("quux"), key.New("quux"), key.New("quux")}, 11) m.Set(key.Path{key.New("quux"), key.New("quux"), key.New("quux"), key.New("quux")}, 11)
testCases := []struct { testCases := []struct {
path Path path key.Path
expected map[int]int expected map[int]int
}{{ }{{
path: Path{key.New("foo"), key.New("bar"), key.New("baz")}, path: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
expected: map[int]int{0: 1, 1: 1, 2: 1, 3: 1, 7: 1, 8: 1, 9: 1}, expected: map[int]int{0: 1, 1: 1, 2: 1, 3: 1, 7: 1, 8: 1, 9: 1},
}, { }, {
path: Path{key.New("zip"), key.New("zap")}, path: key.Path{key.New("zip"), key.New("zap")},
expected: map[int]int{0: 1, 7: 1}, expected: map[int]int{0: 1, 7: 1},
}, { }, {
path: Path{key.New("foo"), key.New("zap")}, path: key.Path{key.New("foo"), key.New("zap")},
expected: map[int]int{0: 1, 1: 1, 8: 1, 7: 1}, expected: map[int]int{0: 1, 1: 1, 8: 1, 7: 1},
}, { }, {
path: Path{key.New("quux"), key.New("quux"), key.New("quux")}, path: key.Path{key.New("quux"), key.New("quux"), key.New("quux")},
expected: map[int]int{0: 1, 7: 1, 10: 1}, expected: map[int]int{0: 1, 7: 1, 10: 1},
}} }}
for _, tc := range testCases { for _, tc := range testCases {
result := make(map[int]int, len(tc.expected)) result := make(map[int]int, len(tc.expected))
m.VisitPrefix(tc.path, accumulator(result)) m.VisitPrefixes(tc.path, accumulator(result))
if diff := test.Diff(tc.expected, result); diff != "" { if diff := test.Diff(tc.expected, result); diff != "" {
t.Errorf("Test case %v: %s", tc.path, diff) t.Errorf("Test case %v: %s", tc.path, diff)
} }
} }
} }
func TestString(t *testing.T) { func TestMapVisitPrefixed(t *testing.T) {
m := NewMap() m := Map{}
m.Set(Path{}, 0) m.Set(key.Path{}, 0)
m.Set(Path{key.New("foo"), key.New("bar")}, 1) m.Set(key.Path{key.New("qux")}, 1)
m.Set(Path{key.New("foo"), key.New("quux")}, 2) m.Set(key.Path{key.New("foo")}, 2)
m.Set(Path{key.New("foo"), key.New("*")}, 3) m.Set(key.Path{key.New("foo"), key.New("qux")}, 3)
m.Set(key.Path{key.New("foo"), key.New("bar")}, 4)
m.Set(key.Path{Wildcard, key.New("bar")}, 5)
m.Set(key.Path{key.New("foo"), Wildcard}, 6)
m.Set(key.Path{key.New("qux"), key.New("foo"), key.New("bar")}, 7)
testCases := []struct {
in key.Path
out map[int]int
}{{
in: key.Path{},
out: map[int]int{0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1},
}, {
in: key.Path{key.New("qux")},
out: map[int]int{1: 1, 5: 1, 7: 1},
}, {
in: key.Path{key.New("foo")},
out: map[int]int{2: 1, 3: 1, 4: 1, 5: 1, 6: 1},
}, {
in: key.Path{key.New("foo"), key.New("qux")},
out: map[int]int{3: 1, 6: 1},
}, {
in: key.Path{key.New("foo"), key.New("bar")},
out: map[int]int{4: 1, 5: 1, 6: 1},
}, {
in: key.Path{key.New(int64(0))},
out: map[int]int{5: 1},
}, {
in: key.Path{Wildcard},
out: map[int]int{5: 1},
}, {
in: key.Path{Wildcard, Wildcard},
out: map[int]int{},
}}
for _, tc := range testCases {
out := make(map[int]int, len(tc.out))
m.VisitPrefixed(tc.in, accumulator(out))
if diff := test.Diff(tc.out, out); diff != "" {
t.Errorf("Test case %v: %s", tc.out, diff)
}
}
}
func TestMapString(t *testing.T) {
m := Map{}
m.Set(key.Path{}, 0)
m.Set(key.Path{key.New("foo"), key.New("bar")}, 1)
m.Set(key.Path{key.New("foo"), key.New("quux")}, 2)
m.Set(key.Path{key.New("foo"), Wildcard}, 3)
expected := `Val: 0 expected := `Val: 0
Child "foo": Child "foo":
@ -295,19 +377,19 @@ Child "foo":
Child "quux": Child "quux":
Val: 2 Val: 2
` `
got := fmt.Sprint(m) got := fmt.Sprint(&m)
if expected != got { if expected != got {
t.Errorf("Unexpected string. Expected:\n\n%s\n\nGot:\n\n%s", expected, got) t.Errorf("Unexpected string. Expected:\n\n%s\n\nGot:\n\n%s", expected, got)
} }
} }
func genWords(count, wordLength int) Path { func genWords(count, wordLength int) key.Path {
chars := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") chars := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
if count+wordLength > len(chars) { if count+wordLength > len(chars) {
panic("need more chars") panic("need more chars")
} }
result := make(Path, count) result := make(key.Path, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
result[i] = key.New(string(chars[i : i+wordLength])) result[i] = key.New(string(chars[i : i+wordLength]))
} }
@ -315,18 +397,16 @@ func genWords(count, wordLength int) Path {
} }
func benchmarkPathMap(pathLength, pathDepth int, b *testing.B) { func benchmarkPathMap(pathLength, pathDepth int, b *testing.B) {
m := NewMap()
// Push pathDepth paths, each of length pathLength // Push pathDepth paths, each of length pathLength
path := genWords(pathLength, 10) path := genWords(pathLength, 10)
words := genWords(pathDepth, 10) words := genWords(pathDepth, 10)
n := m.(*node) m := &Map{}
for _, element := range path { for _, element := range path {
n.children = map[key.Key]*node{} m.children = map[key.Key]*Map{}
for _, word := range words { for _, word := range words {
n.children[word] = &node{} m.children[word] = &Map{}
} }
n = n.children[element] m = m.children[element]
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View File

@ -2,109 +2,167 @@
// Use of this source code is governed by the Apache License 2.0 // Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file. // that can be found in the COPYING file.
// Package path provides functionality for dealing with absolute paths elementally. // Package path contains methods for dealing with key.Paths.
package path package path
import ( import (
"bytes"
"fmt"
"strings" "strings"
"github.com/aristanetworks/goarista/key" "github.com/aristanetworks/goarista/key"
) )
// Path is an absolute path broken down into elements where each element is a key.Key. // New constructs a path from a variable number of elements.
type Path []key.Key // Each element may either be a key.Key or a value that can
// be wrapped by a key.Key.
func copyElements(path Path, elements ...interface{}) { func New(elements ...interface{}) key.Path {
for i, element := range elements { result := make(key.Path, len(elements))
switch val := element.(type) { copyElements(result, elements...)
case string: return result
path[i] = key.New(val)
case key.Key:
path[i] = val
default:
panic(fmt.Errorf("unsupported type: %T", element))
}
}
} }
// New constructs a Path from a variable number of elements. // Append appends a variable number of elements to a path.
// Each element may either be a string or a key.Key. // Each element may either be a key.Key or a value that can
func New(elements ...interface{}) Path { // be wrapped by a key.Key. Note that calling Append on a
path := make(Path, len(elements)) // single path returns that same path, whereas in all other
copyElements(path, elements...) // cases a new path is returned.
return path func Append(path key.Path, elements ...interface{}) key.Path {
}
// FromString constructs a Path from the elements resulting
// from a split of the input string by "/". The string MUST
// begin with a '/' character unless it is the empty string
// in which case an empty Path is returned.
func FromString(str string) Path {
if str == "" {
return Path{}
} else if str[0] != '/' {
panic(fmt.Errorf("not an absolute path: %q", str))
}
elements := strings.Split(str, "/")[1:]
path := make(Path, len(elements))
for i, element := range elements {
path[i] = key.New(element)
}
return path
}
// Append appends a variable number of elements to a Path.
// Each element may either be a string or a key.Key.
func Append(path Path, elements ...interface{}) Path {
if len(elements) == 0 { if len(elements) == 0 {
return path return path
} }
n := len(path) n := len(path)
p := make(Path, n+len(elements)) result := make(key.Path, n+len(elements))
copy(p, path) copy(result, path)
copyElements(p[n:], elements...) copyElements(result[n:], elements...)
return p return result
} }
// String returns the Path as a string. // Join joins a variable number of paths together. Each path
func (p Path) String() string { // in the joining is treated as a subpath of its predecessor.
if len(p) == 0 { // Calling Join with no or only empty paths returns nil.
return "" func Join(paths ...key.Path) key.Path {
n := 0
for _, path := range paths {
n += len(path)
} }
var buf bytes.Buffer if n == 0 {
for _, element := range p { return nil
buf.WriteByte('/')
buf.WriteString(element.String())
} }
return buf.String() result, i := make(key.Path, n), 0
for _, path := range paths {
i += copy(result[i:], path)
}
return result
} }
// Equal returns whether the Path contains the same elements as the other Path. // Parent returns all but the last element of the path. If
// This method implements key.Comparable. // the path is empty, Parent returns nil.
func (p Path) Equal(other interface{}) bool { func Parent(path key.Path) key.Path {
o, ok := other.(Path) if len(path) > 0 {
if !ok { return path[:len(path)-1]
}
return nil
}
// Base returns the last element of the path. If the path is
// empty, Base returns nil.
func Base(path key.Path) key.Key {
if len(path) > 0 {
return path[len(path)-1]
}
return nil
}
// Clone returns a new path with the same elements as in the
// provided path.
func Clone(path key.Path) key.Path {
result := make(key.Path, len(path))
copy(result, path)
return result
}
// Equal returns whether path a and path b are the same
// length and whether each element in b corresponds to the
// same element in a.
func Equal(a, b key.Path) bool {
return len(a) == len(b) && hasPrefix(a, b)
}
// HasElement returns whether element b exists in path a.
func HasElement(a key.Path, b key.Key) bool {
for _, element := range a {
if element.Equal(b) {
return true
}
}
return false return false
}
if len(o) != len(p) {
return false
}
return o.hasPrefix(p)
} }
// HasPrefix returns whether the Path is prefixed by the other Path. // HasPrefix returns whether path b is a prefix of path a.
func (p Path) HasPrefix(prefix Path) bool { // It checks that b is at most the length of path a and
if len(prefix) > len(p) { // whether each element in b corresponds to the same element
return false // in a from the first element.
} func HasPrefix(a, b key.Path) bool {
return p.hasPrefix(prefix) return len(a) >= len(b) && hasPrefix(a, b)
} }
func (p Path) hasPrefix(prefix Path) bool { // Match returns whether path a and path b are the same
for i := range prefix { // length and whether each element in b corresponds to the
if !prefix[i].Equal(p[i]) { // same element or a wildcard in a.
func Match(a, b key.Path) bool {
return len(a) == len(b) && matchPrefix(a, b)
}
// MatchPrefix returns whether path b is a prefix of path a
// where path a may contain wildcards.
// It checks that b is at most the length of path a and
// whether each element in b corresponds to the same element
// or a wildcard in a from the first element.
func MatchPrefix(a, b key.Path) bool {
return len(a) >= len(b) && matchPrefix(a, b)
}
// FromString constructs a path from the elements resulting
// from a split of the input string by "/". Strings that do
// not lead with a '/' are accepted but not reconstructable
// with key.Path.String. Both "" and "/" are treated as a
// key.Path{}.
func FromString(str string) key.Path {
if str == "" || str == "/" {
return key.Path{}
} else if str[0] == '/' {
str = str[1:]
}
elements := strings.Split(str, "/")
result := make(key.Path, len(elements))
for i, element := range elements {
result[i] = key.New(element)
}
return result
}
func copyElements(dest key.Path, elements ...interface{}) {
for i, element := range elements {
switch val := element.(type) {
case key.Key:
dest[i] = val
default:
dest[i] = key.New(val)
}
}
}
func hasPrefix(a, b key.Path) bool {
for i := range b {
if !b[i].Equal(a[i]) {
return false
}
}
return true
}
func matchPrefix(a, b key.Path) bool {
for i := range b {
if !a[i].Equal(Wildcard) && !b[i].Equal(a[i]) {
return false return false
} }
} }

View File

@ -12,79 +12,179 @@ import (
"github.com/aristanetworks/goarista/value" "github.com/aristanetworks/goarista/value"
) )
func TestNewPath(t *testing.T) { func TestNew(t *testing.T) {
tcases := []struct { tcases := []struct {
in []interface{} in []interface{}
out Path out key.Path
}{ }{
{ {
in: nil, in: nil,
out: nil, out: key.Path{},
}, { }, {
in: []interface{}{}, in: []interface{}{},
out: Path{}, out: key.Path{},
}, { }, {
in: []interface{}{""}, in: []interface{}{"foo", key.New("bar"), true},
out: Path{key.New("")}, out: key.Path{key.New("foo"), key.New("bar"), key.New(true)},
}, { }, {
in: []interface{}{key.New("")}, in: []interface{}{int8(5), int16(5), int32(5), int64(5)},
out: Path{key.New("")}, out: key.Path{key.New(int8(5)), key.New(int16(5)), key.New(int32(5)),
key.New(int64(5))},
}, { }, {
in: []interface{}{"foo"}, in: []interface{}{uint8(5), uint16(5), uint32(5), uint64(5)},
out: Path{key.New("foo")}, out: key.Path{key.New(uint8(5)), key.New(uint16(5)), key.New(uint32(5)),
key.New(uint64(5))},
}, { }, {
in: []interface{}{key.New("foo")}, in: []interface{}{float32(5), float64(5)},
out: Path{key.New("foo")}, out: key.Path{key.New(float32(5)), key.New(float64(5))},
}, { }, {
in: []interface{}{"foo", key.New("bar")}, in: []interface{}{customKey{i: &a}, map[string]interface{}{}},
out: Path{key.New("foo"), key.New("bar")}, out: key.Path{key.New(customKey{i: &a}), key.New(map[string]interface{}{})},
}, {
in: []interface{}{key.New("foo"), "bar", key.New("baz")},
out: Path{key.New("foo"), key.New("bar"), key.New("baz")},
}, },
} }
for i, tcase := range tcases { for i, tcase := range tcases {
if p := New(tcase.in...); !p.Equal(tcase.out) { if p := New(tcase.in...); !Equal(p, tcase.out) {
t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.out) t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.out)
} }
} }
} }
func TestAppendPath(t *testing.T) { func TestClone(t *testing.T) {
if !Equal(Clone(key.Path{}), key.Path{}) {
t.Error("Clone(key.Path{}) != key.Path{}")
}
a := key.Path{key.New("foo"), key.New("bar")}
b, c := Clone(a), Clone(a)
b[1] = key.New("baz")
if Equal(a, b) || !Equal(a, c) {
t.Error("Clone is not making a copied path")
}
}
func TestAppend(t *testing.T) {
tcases := []struct { tcases := []struct {
base Path a key.Path
elements []interface{} b []interface{}
expected Path result key.Path
}{ }{
{ {
base: Path{}, a: key.Path{},
elements: []interface{}{}, b: []interface{}{},
expected: Path{}, result: key.Path{},
}, { }, {
base: Path{}, a: key.Path{key.New("foo")},
elements: []interface{}{""}, b: []interface{}{},
expected: Path{key.New("")}, result: key.Path{key.New("foo")},
}, { }, {
base: Path{}, a: key.Path{},
elements: []interface{}{key.New("")}, b: []interface{}{"foo", key.New("bar")},
expected: Path{key.New("")}, result: key.Path{key.New("foo"), key.New("bar")},
}, { }, {
base: Path{}, a: key.Path{key.New("foo")},
elements: []interface{}{"foo", key.New("bar")}, b: []interface{}{int64(0), key.New("bar")},
expected: Path{key.New("foo"), key.New("bar")}, result: key.Path{key.New("foo"), key.New(int64(0)), key.New("bar")},
}, {
base: Path{key.New("foo")},
elements: []interface{}{key.New("bar"), "baz"},
expected: Path{key.New("foo"), key.New("bar"), key.New("baz")},
}, {
base: Path{key.New("foo"), key.New("bar")},
elements: []interface{}{key.New("baz")},
expected: Path{key.New("foo"), key.New("bar"), key.New("baz")},
}, },
} }
for i, tcase := range tcases { for i, tcase := range tcases {
if p := Append(tcase.base, tcase.elements...); !p.Equal(tcase.expected) { if p := Append(tcase.a, tcase.b...); !Equal(p, tcase.result) {
t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.expected) t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.result)
}
}
}
func TestJoin(t *testing.T) {
tcases := []struct {
paths []key.Path
result key.Path
}{
{
paths: nil,
result: nil,
}, {
paths: []key.Path{},
result: nil,
}, {
paths: []key.Path{key.Path{}},
result: nil,
}, {
paths: []key.Path{key.Path{key.New(true)}, key.Path{}},
result: key.Path{key.New(true)},
}, {
paths: []key.Path{key.Path{}, key.Path{key.New(true)}},
result: key.Path{key.New(true)},
}, {
paths: []key.Path{key.Path{key.New("foo")}, key.Path{key.New("bar")}},
result: key.Path{key.New("foo"), key.New("bar")},
}, {
paths: []key.Path{key.Path{key.New("bar")}, key.Path{key.New("foo")}},
result: key.Path{key.New("bar"), key.New("foo")},
}, {
paths: []key.Path{
key.Path{key.New(uint32(0)), key.New(uint64(0))},
key.Path{key.New(int8(0))},
key.Path{key.New(int16(0)), key.New(int32(0))},
key.Path{key.New(int64(0)), key.New(uint8(0)), key.New(uint16(0))},
},
result: key.Path{
key.New(uint32(0)), key.New(uint64(0)),
key.New(int8(0)), key.New(int16(0)),
key.New(int32(0)), key.New(int64(0)),
key.New(uint8(0)), key.New(uint16(0)),
},
},
}
for i, tcase := range tcases {
if p := Join(tcase.paths...); !Equal(p, tcase.result) {
t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.result)
}
}
}
func TestParent(t *testing.T) {
if Parent(key.Path{}) != nil {
t.Fatal("Parent of empty key.Path should be nil")
}
tcases := []struct {
in key.Path
out key.Path
}{
{
in: key.Path{key.New("foo")},
out: key.Path{},
}, {
in: key.Path{key.New("foo"), key.New("bar")},
out: key.Path{key.New("foo")},
}, {
in: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
out: key.Path{key.New("foo"), key.New("bar")},
},
}
for _, tcase := range tcases {
if !Equal(Parent(tcase.in), tcase.out) {
t.Fatalf("Parent of %#v != %#v", tcase.in, tcase.out)
}
}
}
func TestBase(t *testing.T) {
if Base(key.Path{}) != nil {
t.Fatal("Base of empty key.Path should be nil")
}
tcases := []struct {
in key.Path
out key.Key
}{
{
in: key.Path{key.New("foo")},
out: key.New("foo"),
}, {
in: key.Path{key.New("foo"), key.New("bar")},
out: key.New("bar"),
},
}
for _, tcase := range tcases {
if !Base(tcase.in).Equal(tcase.out) {
t.Fatalf("Base of %#v != %#v", tcase.in, tcase.out)
} }
} }
} }
@ -117,176 +217,407 @@ var (
b = 1 b = 1
) )
func TestPathEquality(t *testing.T) { func TestEqual(t *testing.T) {
tcases := []struct { tcases := []struct {
base Path a key.Path
other Path b key.Path
expected bool result bool
}{ }{
{ {
base: Path{}, a: nil,
other: Path{}, b: nil,
expected: true, result: true,
}, { }, {
base: Path{}, a: nil,
other: Path{key.New("")}, b: key.Path{},
expected: false, result: true,
}, { }, {
base: Path{key.New("foo")}, a: key.Path{},
other: Path{key.New("foo")}, b: nil,
expected: true, result: true,
}, { }, {
base: Path{key.New("foo")}, a: key.Path{},
other: Path{key.New("bar")}, b: key.Path{},
expected: false, result: true,
}, { }, {
base: Path{key.New("foo"), key.New("bar")}, a: key.Path{},
other: Path{key.New("foo")}, b: key.Path{key.New("")},
expected: false, result: false,
}, { }, {
base: Path{key.New("foo"), key.New("bar")}, a: key.Path{Wildcard},
other: Path{key.New("bar"), key.New("foo")}, b: key.Path{key.New("foo")},
expected: false, result: false,
}, { }, {
base: Path{key.New("foo"), key.New("bar"), key.New("baz")}, a: key.Path{Wildcard},
other: Path{key.New("foo"), key.New("bar"), key.New("baz")}, b: key.Path{Wildcard},
expected: true, result: true,
}, {
a: key.Path{key.New("foo")},
b: key.Path{key.New("foo")},
result: true,
}, {
a: key.Path{key.New(true)},
b: key.Path{key.New(false)},
result: false,
}, {
a: key.Path{key.New(int32(5))},
b: key.Path{key.New(int64(5))},
result: false,
}, {
a: key.Path{key.New("foo")},
b: key.Path{key.New("foo"), key.New("bar")},
result: false,
}, {
a: key.Path{key.New("foo"), key.New("bar")},
b: key.Path{key.New("foo")},
result: false,
}, {
a: key.Path{key.New(uint8(0)), key.New(int8(0))},
b: key.Path{key.New(int8(0)), key.New(uint8(0))},
result: false,
}, },
// Ensure that we check deep equality. // Ensure that we check deep equality.
{ {
base: Path{key.New(map[string]interface{}{})}, a: key.Path{key.New(map[string]interface{}{})},
other: Path{key.New(map[string]interface{}{})}, b: key.Path{key.New(map[string]interface{}{})},
expected: true, result: true,
}, { }, {
base: Path{key.New(customKey{i: &a})}, a: key.Path{key.New(customKey{i: &a})},
other: Path{key.New(customKey{i: &b})}, b: key.Path{key.New(customKey{i: &b})},
expected: true, result: true,
}, },
} }
for i, tcase := range tcases { for i, tcase := range tcases {
if result := tcase.base.Equal(tcase.other); result != tcase.expected { if result := Equal(tcase.a, tcase.b); result != tcase.result {
t.Fatalf("Test %d failed: base: %#v; other: %#v, expected: %t", t.Fatalf("Test %d failed: a: %#v; b: %#v, result: %t",
i, tcase.base, tcase.other, tcase.expected) i, tcase.a, tcase.b, tcase.result)
} }
} }
} }
func TestPathHasPrefix(t *testing.T) { func TestMatch(t *testing.T) {
tcases := []struct { tcases := []struct {
base Path a key.Path
prefix Path b key.Path
expected bool result bool
}{ }{
{ {
base: Path{}, a: nil,
prefix: Path{}, b: nil,
expected: true, result: true,
}, { }, {
base: Path{key.New("foo")}, a: nil,
prefix: Path{}, b: key.Path{},
expected: true, result: true,
}, { }, {
base: Path{key.New("foo"), key.New("bar")}, a: key.Path{},
prefix: Path{key.New("foo")}, b: nil,
expected: true, result: true,
}, { }, {
base: Path{key.New("foo"), key.New("bar")}, a: key.Path{},
prefix: Path{key.New("bar")}, b: key.Path{},
expected: false, result: true,
}, { }, {
base: Path{key.New("foo"), key.New("bar")}, a: key.Path{},
prefix: Path{key.New("bar"), key.New("foo")}, b: key.Path{key.New("foo")},
expected: false, result: false,
}, { }, {
base: Path{key.New("foo"), key.New("bar")}, a: key.Path{Wildcard},
prefix: Path{key.New("foo"), key.New("bar")}, b: key.Path{key.New("foo")},
expected: true, result: true,
}, { }, {
base: Path{key.New("foo"), key.New("bar")}, a: key.Path{key.New("foo")},
prefix: Path{key.New("foo"), key.New("bar"), key.New("baz")}, b: key.Path{Wildcard},
expected: false, result: false,
}, {
a: key.Path{Wildcard},
b: key.Path{key.New("foo"), key.New("bar")},
result: false,
}, {
a: key.Path{Wildcard, Wildcard},
b: key.Path{key.New(int64(0))},
result: false,
}, {
a: key.Path{Wildcard, Wildcard},
b: key.Path{key.New(int64(0)), key.New(int32(0))},
result: true,
}, {
a: key.Path{Wildcard, key.New(false)},
b: key.Path{key.New(true), Wildcard},
result: false,
}, },
} }
for i, tcase := range tcases { for i, tcase := range tcases {
if result := tcase.base.HasPrefix(tcase.prefix); result != tcase.expected { if result := Match(tcase.a, tcase.b); result != tcase.result {
t.Fatalf("Test %d failed: base: %#v; prefix: %#v, expected: %t", t.Fatalf("Test %d failed: a: %#v; b: %#v, result: %t",
i, tcase.base, tcase.prefix, tcase.expected) i, tcase.a, tcase.b, tcase.result)
} }
} }
} }
func TestPathFromString(t *testing.T) { func TestHasElement(t *testing.T) {
tcases := []struct {
a key.Path
b key.Key
result bool
}{
{
a: nil,
b: nil,
result: false,
}, {
a: nil,
b: key.New("foo"),
result: false,
}, {
a: key.Path{},
b: nil,
result: false,
}, {
a: key.Path{key.New("foo")},
b: nil,
result: false,
}, {
a: key.Path{key.New("foo")},
b: key.New("foo"),
result: true,
}, {
a: key.Path{key.New(true)},
b: key.New("true"),
result: false,
}, {
a: key.Path{key.New("foo"), key.New("bar")},
b: key.New("bar"),
result: true,
}, {
a: key.Path{key.New(map[string]interface{}{})},
b: key.New(map[string]interface{}{}),
result: true,
}, {
a: key.Path{key.New(map[string]interface{}{"foo": "a"})},
b: key.New(map[string]interface{}{"bar": "a"}),
result: false,
},
}
for i, tcase := range tcases {
if result := HasElement(tcase.a, tcase.b); result != tcase.result {
t.Errorf("Test %d failed: a: %#v; b: %#v, result: %t, expected: %t",
i, tcase.a, tcase.b, result, tcase.result)
}
}
}
func TestHasPrefix(t *testing.T) {
tcases := []struct {
a key.Path
b key.Path
result bool
}{
{
a: nil,
b: nil,
result: true,
}, {
a: nil,
b: key.Path{},
result: true,
}, {
a: key.Path{},
b: nil,
result: true,
}, {
a: key.Path{},
b: key.Path{},
result: true,
}, {
a: key.Path{},
b: key.Path{key.New("foo")},
result: false,
}, {
a: key.Path{key.New("foo")},
b: key.Path{},
result: true,
}, {
a: key.Path{key.New(true)},
b: key.Path{key.New(false)},
result: false,
}, {
a: key.Path{key.New("foo"), key.New("bar")},
b: key.Path{key.New("bar"), key.New("foo")},
result: false,
}, {
a: key.Path{key.New(int8(0)), key.New(uint8(0))},
b: key.Path{key.New(uint8(0)), key.New(uint8(0))},
result: false,
}, {
a: key.Path{key.New(true), key.New(true)},
b: key.Path{key.New(true), key.New(true), key.New(true)},
result: false,
}, {
a: key.Path{key.New(true), key.New(true), key.New(true)},
b: key.Path{key.New(true), key.New(true)},
result: true,
}, {
a: key.Path{Wildcard, key.New(int32(0)), Wildcard},
b: key.Path{key.New(int64(0)), Wildcard},
result: false,
},
}
for i, tcase := range tcases {
if result := HasPrefix(tcase.a, tcase.b); result != tcase.result {
t.Fatalf("Test %d failed: a: %#v; b: %#v, result: %t",
i, tcase.a, tcase.b, tcase.result)
}
}
}
func TestMatchPrefix(t *testing.T) {
tcases := []struct {
a key.Path
b key.Path
result bool
}{
{
a: nil,
b: nil,
result: true,
}, {
a: nil,
b: key.Path{},
result: true,
}, {
a: key.Path{},
b: nil,
result: true,
}, {
a: key.Path{},
b: key.Path{},
result: true,
}, {
a: key.Path{},
b: key.Path{key.New("foo")},
result: false,
}, {
a: key.Path{key.New("foo")},
b: key.Path{},
result: true,
}, {
a: key.Path{key.New("foo")},
b: key.Path{Wildcard},
result: false,
}, {
a: key.Path{Wildcard},
b: key.Path{key.New("foo")},
result: true,
}, {
a: key.Path{Wildcard},
b: key.Path{key.New("foo"), key.New("bar")},
result: false,
}, {
a: key.Path{Wildcard, key.New(true)},
b: key.Path{key.New(false), Wildcard},
result: false,
}, {
a: key.Path{Wildcard, key.New(int32(0)), key.New(int16(0))},
b: key.Path{key.New(int64(0)), key.New(int32(0))},
result: true,
},
}
for i, tcase := range tcases {
if result := MatchPrefix(tcase.a, tcase.b); result != tcase.result {
t.Fatalf("Test %d failed: a: %#v; b: %#v, result: %t",
i, tcase.a, tcase.b, tcase.result)
}
}
}
func TestFromString(t *testing.T) {
tcases := []struct { tcases := []struct {
in string in string
out Path out key.Path
}{ }{
{ {
in: "", in: "",
out: Path{}, out: key.Path{},
}, { }, {
in: "/", in: "/",
out: Path{key.New("")}, out: key.Path{},
}, { }, {
in: "//", in: "//",
out: Path{key.New(""), key.New("")}, out: key.Path{key.New(""), key.New("")},
}, {
in: "foo",
out: key.Path{key.New("foo")},
}, { }, {
in: "/foo", in: "/foo",
out: Path{key.New("foo")}, out: key.Path{key.New("foo")},
}, {
in: "foo/bar",
out: key.Path{key.New("foo"), key.New("bar")},
}, { }, {
in: "/foo/bar", in: "/foo/bar",
out: Path{key.New("foo"), key.New("bar")}, out: key.Path{key.New("foo"), key.New("bar")},
}, {
in: "foo/bar/baz",
out: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
}, { }, {
in: "/foo/bar/baz", in: "/foo/bar/baz",
out: Path{key.New("foo"), key.New("bar"), key.New("baz")}, out: key.Path{key.New("foo"), key.New("bar"), key.New("baz")},
}, {
in: "0/123/456/789",
out: key.Path{key.New("0"), key.New("123"), key.New("456"), key.New("789")},
}, { }, {
in: "/0/123/456/789", in: "/0/123/456/789",
out: Path{key.New("0"), key.New("123"), key.New("456"), key.New("789")}, out: key.Path{key.New("0"), key.New("123"), key.New("456"), key.New("789")},
}, {
in: "`~!@#$%^&*()_+{}\\/|[];':\"<>?,./",
out: key.Path{key.New("`~!@#$%^&*()_+{}\\"), key.New("|[];':\"<>?,."), key.New("")},
}, { }, {
in: "/`~!@#$%^&*()_+{}\\/|[];':\"<>?,./", in: "/`~!@#$%^&*()_+{}\\/|[];':\"<>?,./",
out: Path{key.New("`~!@#$%^&*()_+{}\\"), key.New("|[];':\"<>?,."), key.New("")}, out: key.Path{key.New("`~!@#$%^&*()_+{}\\"), key.New("|[];':\"<>?,."), key.New("")},
}, },
} }
for i, tcase := range tcases { for i, tcase := range tcases {
if p := FromString(tcase.in); !p.Equal(tcase.out) { if p := FromString(tcase.in); !Equal(p, tcase.out) {
t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.out) t.Fatalf("Test %d failed: %#v != %#v", i, p, tcase.out)
} }
} }
} }
func TestPathToString(t *testing.T) { func TestString(t *testing.T) {
tcases := []struct { tcases := []struct {
in Path in key.Path
out string out string
}{ }{
{ {
in: Path{}, in: key.Path{},
out: "",
}, {
in: Path{key.New("")},
out: "/", out: "/",
}, { }, {
in: Path{key.New("foo")}, in: key.Path{key.New("")},
out: "/",
}, {
in: key.Path{key.New("foo")},
out: "/foo", out: "/foo",
}, { }, {
in: Path{key.New("foo"), key.New("bar")}, in: key.Path{key.New("foo"), key.New("bar")},
out: "/foo/bar", out: "/foo/bar",
}, { }, {
in: Path{key.New("/foo"), key.New("bar")}, in: key.Path{key.New("/foo"), key.New("bar")},
out: "//foo/bar", out: "//foo/bar",
}, { }, {
in: Path{key.New("foo"), key.New("bar/")}, in: key.Path{key.New("foo"), key.New("bar/")},
out: "/foo/bar/", out: "/foo/bar/",
}, { }, {
in: Path{key.New(""), key.New("foo"), key.New("bar")}, in: key.Path{key.New(""), key.New("foo"), key.New("bar")},
out: "//foo/bar", out: "//foo/bar",
}, { }, {
in: Path{key.New("foo"), key.New("bar"), key.New("")}, in: key.Path{key.New("foo"), key.New("bar"), key.New("")},
out: "/foo/bar/", out: "/foo/bar/",
}, { }, {
in: Path{key.New("/"), key.New("foo"), key.New("bar")}, in: key.Path{key.New("/"), key.New("foo"), key.New("bar")},
out: "///foo/bar", out: "///foo/bar",
}, { }, {
in: Path{key.New("foo"), key.New("bar"), key.New("/")}, in: key.Path{key.New("foo"), key.New("bar"), key.New("/")},
out: "/foo/bar//", out: "/foo/bar//",
}, },
} }
@ -296,3 +627,59 @@ func TestPathToString(t *testing.T) {
} }
} }
} }
func BenchmarkJoin(b *testing.B) {
generate := func(n int) []key.Path {
paths := make([]key.Path, 0, n)
for i := 0; i < n; i++ {
paths = append(paths, key.Path{key.New("foo")})
}
return paths
}
benchmarks := map[string][]key.Path{
"10 key.Paths": generate(10),
"100 key.Paths": generate(100),
"1000 key.Paths": generate(1000),
"10000 key.Paths": generate(10000),
}
for name, benchmark := range benchmarks {
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Join(benchmark...)
}
})
}
}
func BenchmarkHasElement(b *testing.B) {
element := key.New("waldo")
generate := func(n, loc int) key.Path {
path := make(key.Path, n)
for i := 0; i < n; i++ {
if i == loc {
path[i] = element
} else {
path[i] = key.New(int8(0))
}
}
return path
}
benchmarks := map[string]key.Path{
"10 Elements Index 0": generate(10, 0),
"10 Elements Index 4": generate(10, 4),
"10 Elements Index 9": generate(10, 9),
"100 Elements Index 0": generate(100, 0),
"100 Elements Index 49": generate(100, 49),
"100 Elements Index 99": generate(100, 99),
"1000 Elements Index 0": generate(1000, 0),
"1000 Elements Index 499": generate(1000, 499),
"1000 Elements Index 999": generate(1000, 999),
}
for name, benchmark := range benchmarks {
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
HasElement(benchmark, element)
}
})
}
}

View File

@ -0,0 +1,36 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package path
import "github.com/aristanetworks/goarista/key"
// Wildcard is a special element in a path that is used by Map
// and the Match* functions to match any other element.
var Wildcard = key.New(WildcardType{})
// WildcardType is the type used to construct a Wildcard. It
// implements the value.Value interface so it can be used as
// a key.Key.
type WildcardType struct{}
func (w WildcardType) String() string {
return "*"
}
// Equal implements the key.Comparable interface.
func (w WildcardType) Equal(other interface{}) bool {
_, ok := other.(WildcardType)
return ok
}
// ToBuiltin implements the value.Value interface.
func (w WildcardType) ToBuiltin() interface{} {
return WildcardType{}
}
// MarshalJSON implements the value.Value interface.
func (w WildcardType) MarshalJSON() ([]byte, error) {
return []byte(`{"_wildcard":{}}`), nil
}

View File

@ -0,0 +1,79 @@
// Copyright (c) 2018 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package path
import (
"encoding/json"
"testing"
"github.com/aristanetworks/goarista/key"
"github.com/aristanetworks/goarista/value"
)
type pseudoWildcard struct{}
func (w pseudoWildcard) Key() interface{} {
return struct{}{}
}
func (w pseudoWildcard) String() string {
return "*"
}
func (w pseudoWildcard) Equal(other interface{}) bool {
o, ok := other.(pseudoWildcard)
return ok && w == o
}
func TestWildcardUniqueness(t *testing.T) {
if Wildcard.Equal(pseudoWildcard{}) {
t.Fatal("Wildcard is not unique")
}
if Wildcard.Equal(struct{}{}) {
t.Fatal("Wildcard is not unique")
}
if Wildcard.Equal(key.New("*")) {
t.Fatal("Wildcard is not unique")
}
}
func TestWildcardTypeIsNotAKey(t *testing.T) {
var intf interface{} = WildcardType{}
_, ok := intf.(key.Key)
if ok {
t.Error("WildcardType should not implement key.Key")
}
}
func TestWildcardTypeEqual(t *testing.T) {
k1 := key.New(WildcardType{})
k2 := key.New(WildcardType{})
if !k1.Equal(k2) {
t.Error("They should be equal")
}
if !Wildcard.Equal(k1) {
t.Error("They should be equal")
}
}
func TestWildcardTypeAsValue(t *testing.T) {
var k value.Value = WildcardType{}
w := WildcardType{}
if k.ToBuiltin() != w {
t.Error("WildcardType.ToBuiltin is not correct")
}
}
func TestWildcardMarshalJSON(t *testing.T) {
b, err := json.Marshal(Wildcard)
if err != nil {
t.Fatal(err)
}
expected := `{"_wildcard":{}}`
if string(b) != expected {
t.Errorf("Invalid Wildcard json representation.\nExpected: %s\nReceived: %s",
expected, string(b))
}
}

View File

@ -1,265 +0,0 @@
// Copyright (c) 2016 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package pathmap
import (
"bytes"
"fmt"
"sort"
)
// PathMap associates Paths to a values. It allows wildcards. The
// primary use of PathMap is to be able to register handlers to paths
// that can be efficiently looked up every time a path is updated.
//
// For example:
//
// m.Set({"interfaces", "*", "adminStatus"}, AdminStatusHandler)
// m.Set({"interface", "Management1", "adminStatus"}, Management1AdminStatusHandler)
//
// m.Visit({"interfaces", "Ethernet3/32/1", "adminStatus"}, HandlerExecutor)
// >> AdminStatusHandler gets passed to HandlerExecutor
// m.Visit({"interfaces", "Management1", "adminStatus"}, HandlerExecutor)
// >> AdminStatusHandler and Management1AdminStatusHandler gets passed to HandlerExecutor
//
// Note, Visit performance is typically linearly with the length of
// the path. But, it can be as bad as O(2^len(Path)) when TreeMap
// nodes have children and a wildcard associated with it. For example,
// if these paths were registered:
//
// m.Set([]string{"foo", "bar", "baz"}, 1)
// m.Set([]string{"*", "bar", "baz"}, 2)
// m.Set([]string{"*", "*", "baz"}, 3)
// m.Set([]string{"*", "*", "*"}, 4)
// m.Set([]string{"foo", "*", "*"}, 5)
// m.Set([]string{"foo", "bar", "*"}, 6)
// m.Set([]string{"foo", "*", "baz"}, 7)
// m.Set([]string{"*", "bar", "*"}, 8)
//
// m.Visit([]{"foo","bar","baz"}, Foo) // 2^3 nodes traversed
//
// This shouldn't be a concern with our paths because it is likely
// that a TreeMap node will either have a wildcard or children, not
// both. A TreeMap node that corresponds to a collection will often be a
// wildcard, otherwise it will have specific children.
type PathMap interface {
// Visit calls f for every registration in the PathMap that
// matches path. For example,
//
// m.Set({"foo", "bar"}, 1)
// m.Set({"*", "bar"}, 2)
//
// m.Visit({"foo", "bar"}, Printer)
// >> Calls Printer(1) and Printer(2)
Visit(path []string, f VisitorFunc) error
// VisitPrefix calls f for every registration in the PathMap that
// is a prefix of path. For example,
//
// m.Set({}, 0)
// m.Set({"foo"}, 1)
// m.Set({"foo", "bar"}, 2)
// m.Set({"foo", "quux"}, 3)
// m.Set({"*", "bar"}, 4)
//
// m.VisitPrefix({"foo", "bar", "baz"}, Printer)
// >> Calls Printer on values 0, 1, 2, and 4
VisitPrefix(path []string, f VisitorFunc) error
// Get returns the mapping for path. This returns the exact
// mapping for path. For example, if you register two paths
//
// m.Set({"foo", "bar"}, 1)
// m.Set({"*", "bar"}, 2)
//
// m.Get({"foo", "bar"}) => 1
// m.Get({"*", "bar"}) => 2
Get(path []string) interface{}
// Set a mapping of path to value. Path may contain wildcards. Set
// replaces what was there before.
Set(path []string, v interface{})
// Delete removes the mapping for path
Delete(path []string) bool
}
// Wildcard is a special string representing any possible path
const Wildcard string = "*"
type node struct {
val interface{}
wildcard *node
children map[string]*node
}
// New creates a new PathMap
func New() PathMap {
return &node{}
}
// VisitorFunc is the func type passed to Visit
type VisitorFunc func(v interface{}) error
// Visit calls f for every matching registration in the PathMap
func (n *node) Visit(path []string, f VisitorFunc) error {
for i, element := range path {
if n.wildcard != nil {
if err := n.wildcard.Visit(path[i+1:], f); err != nil {
return err
}
}
next, ok := n.children[element]
if !ok {
return nil
}
n = next
}
if n.val == nil {
return nil
}
return f(n.val)
}
// VisitPrefix calls f for every registered path that is a prefix of
// the path
func (n *node) VisitPrefix(path []string, f VisitorFunc) error {
for i, element := range path {
// Call f on each node we visit
if n.val != nil {
if err := f(n.val); err != nil {
return err
}
}
if n.wildcard != nil {
if err := n.wildcard.VisitPrefix(path[i+1:], f); err != nil {
return err
}
}
next, ok := n.children[element]
if !ok {
return nil
}
n = next
}
if n.val == nil {
return nil
}
// Call f on the final node
return f(n.val)
}
// Get returns the mapping for path
func (n *node) Get(path []string) interface{} {
for _, element := range path {
if element == Wildcard {
if n.wildcard == nil {
return nil
}
n = n.wildcard
continue
}
next, ok := n.children[element]
if !ok {
return nil
}
n = next
}
return n.val
}
// Set a mapping of path to value. Path may contain wildcards. Set
// replaces what was there before.
func (n *node) Set(path []string, v interface{}) {
for _, element := range path {
if element == Wildcard {
if n.wildcard == nil {
n.wildcard = &node{}
}
n = n.wildcard
continue
}
if n.children == nil {
n.children = map[string]*node{}
}
next, ok := n.children[element]
if !ok {
next = &node{}
n.children[element] = next
}
n = next
}
n.val = v
}
// Delete removes the mapping for path
func (n *node) Delete(path []string) bool {
nodes := make([]*node, len(path)+1)
for i, element := range path {
nodes[i] = n
if element == Wildcard {
if n.wildcard == nil {
return false
}
n = n.wildcard
continue
}
next, ok := n.children[element]
if !ok {
return false
}
n = next
}
n.val = nil
nodes[len(path)] = n
// See if we can delete any node objects
for i := len(path); i > 0; i-- {
n = nodes[i]
if n.val != nil || n.wildcard != nil || len(n.children) > 0 {
break
}
parent := nodes[i-1]
element := path[i-1]
if element == Wildcard {
parent.wildcard = nil
} else {
delete(parent.children, element)
}
}
return true
}
func (n *node) String() string {
var b bytes.Buffer
n.write(&b, "")
return b.String()
}
func (n *node) write(b *bytes.Buffer, indent string) {
if n.val != nil {
b.WriteString(indent)
fmt.Fprintf(b, "Val: %v", n.val)
b.WriteString("\n")
}
if n.wildcard != nil {
b.WriteString(indent)
fmt.Fprintf(b, "Child %q:\n", Wildcard)
n.wildcard.write(b, indent+" ")
}
children := make([]string, 0, len(n.children))
for name := range n.children {
children = append(children, name)
}
sort.Strings(children)
for _, name := range children {
child := n.children[name]
b.WriteString(indent)
fmt.Fprintf(b, "Child %q:\n", name)
child.write(b, indent+" ")
}
}

View File

@ -1,335 +0,0 @@
// Copyright (c) 2016 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
package pathmap
import (
"errors"
"fmt"
"testing"
"github.com/aristanetworks/goarista/test"
)
func accumulator(counter map[int]int) VisitorFunc {
return func(val interface{}) error {
counter[val.(int)]++
return nil
}
}
func TestVisit(t *testing.T) {
m := New()
m.Set([]string{"foo", "bar", "baz"}, 1)
m.Set([]string{"*", "bar", "baz"}, 2)
m.Set([]string{"*", "*", "baz"}, 3)
m.Set([]string{"*", "*", "*"}, 4)
m.Set([]string{"foo", "*", "*"}, 5)
m.Set([]string{"foo", "bar", "*"}, 6)
m.Set([]string{"foo", "*", "baz"}, 7)
m.Set([]string{"*", "bar", "*"}, 8)
m.Set([]string{}, 10)
m.Set([]string{"*"}, 20)
m.Set([]string{"foo"}, 21)
m.Set([]string{"zap", "zip"}, 30)
m.Set([]string{"zap", "zip"}, 31)
m.Set([]string{"zip", "*"}, 40)
m.Set([]string{"zip", "*"}, 41)
testCases := []struct {
path []string
expected map[int]int
}{{
path: []string{"foo", "bar", "baz"},
expected: map[int]int{1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1},
}, {
path: []string{"qux", "bar", "baz"},
expected: map[int]int{2: 1, 3: 1, 4: 1, 8: 1},
}, {
path: []string{"foo", "qux", "baz"},
expected: map[int]int{3: 1, 4: 1, 5: 1, 7: 1},
}, {
path: []string{"foo", "bar", "qux"},
expected: map[int]int{4: 1, 5: 1, 6: 1, 8: 1},
}, {
path: []string{},
expected: map[int]int{10: 1},
}, {
path: []string{"foo"},
expected: map[int]int{20: 1, 21: 1},
}, {
path: []string{"foo", "bar"},
expected: map[int]int{},
}, {
path: []string{"zap", "zip"},
expected: map[int]int{31: 1},
}, {
path: []string{"zip", "zap"},
expected: map[int]int{41: 1},
}}
for _, tc := range testCases {
result := make(map[int]int, len(tc.expected))
m.Visit(tc.path, accumulator(result))
if diff := test.Diff(tc.expected, result); diff != "" {
t.Errorf("Test case %v: %s", tc.path, diff)
}
}
}
func TestVisitError(t *testing.T) {
m := New()
m.Set([]string{"foo", "bar"}, 1)
m.Set([]string{"*", "bar"}, 2)
errTest := errors.New("Test")
err := m.Visit([]string{"foo", "bar"}, func(v interface{}) error { return errTest })
if err != errTest {
t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err)
}
err = m.VisitPrefix([]string{"foo", "bar", "baz"}, func(v interface{}) error { return errTest })
if err != errTest {
t.Errorf("Unexpected error. Expected: %v, Got: %v", errTest, err)
}
}
func TestGet(t *testing.T) {
m := New()
m.Set([]string{}, 0)
m.Set([]string{"foo", "bar"}, 1)
m.Set([]string{"foo", "*"}, 2)
m.Set([]string{"*", "bar"}, 3)
m.Set([]string{"zap", "zip"}, 4)
testCases := []struct {
path []string
expected interface{}
}{{
path: []string{},
expected: 0,
}, {
path: []string{"foo", "bar"},
expected: 1,
}, {
path: []string{"foo", "*"},
expected: 2,
}, {
path: []string{"*", "bar"},
expected: 3,
}, {
path: []string{"bar", "foo"},
expected: nil,
}, {
path: []string{"zap", "*"},
expected: nil,
}}
for _, tc := range testCases {
got := m.Get(tc.path)
if got != tc.expected {
t.Errorf("Test case %v: Expected %v, Got %v",
tc.path, tc.expected, got)
}
}
}
func countNodes(n *node) int {
if n == nil {
return 0
}
count := 1
count += countNodes(n.wildcard)
for _, child := range n.children {
count += countNodes(child)
}
return count
}
func TestDelete(t *testing.T) {
m := New()
m.Set([]string{}, 0)
m.Set([]string{"*"}, 1)
m.Set([]string{"foo", "bar"}, 2)
m.Set([]string{"foo", "*"}, 3)
n := countNodes(m.(*node))
if n != 5 {
t.Errorf("Initial count wrong. Expected: 5, Got: %d", n)
}
testCases := []struct {
del []string // Path to delete
expected bool // expected return value of Delete
visit []string // Path to Visit
before map[int]int // Expected to find items before deletion
after map[int]int // Expected to find items after deletion
count int // Count of nodes
}{{
del: []string{"zap"}, // A no-op Delete
expected: false,
visit: []string{"foo", "bar"},
before: map[int]int{2: 1, 3: 1},
after: map[int]int{2: 1, 3: 1},
count: 5,
}, {
del: []string{"foo", "bar"},
expected: true,
visit: []string{"foo", "bar"},
before: map[int]int{2: 1, 3: 1},
after: map[int]int{3: 1},
count: 4,
}, {
del: []string{"*"},
expected: true,
visit: []string{"foo"},
before: map[int]int{1: 1},
after: map[int]int{},
count: 3,
}, {
del: []string{"*"},
expected: false,
visit: []string{"foo"},
before: map[int]int{},
after: map[int]int{},
count: 3,
}, {
del: []string{"foo", "*"},
expected: true,
visit: []string{"foo", "bar"},
before: map[int]int{3: 1},
after: map[int]int{},
count: 1, // Should have deleted "foo" and "bar" nodes
}, {
del: []string{},
expected: true,
visit: []string{},
before: map[int]int{0: 1},
after: map[int]int{},
count: 1, // Root node can't be deleted
}}
for i, tc := range testCases {
beforeResult := make(map[int]int, len(tc.before))
m.Visit(tc.visit, accumulator(beforeResult))
if diff := test.Diff(tc.before, beforeResult); diff != "" {
t.Errorf("Test case %d (%v): %s", i, tc.del, diff)
}
if got := m.Delete(tc.del); got != tc.expected {
t.Errorf("Test case %d (%v): Unexpected return. Expected %t, Got: %t",
i, tc.del, tc.expected, got)
}
afterResult := make(map[int]int, len(tc.after))
m.Visit(tc.visit, accumulator(afterResult))
if diff := test.Diff(tc.after, afterResult); diff != "" {
t.Errorf("Test case %d (%v): %s", i, tc.del, diff)
}
}
}
func TestVisitPrefix(t *testing.T) {
m := New()
m.Set([]string{}, 0)
m.Set([]string{"foo"}, 1)
m.Set([]string{"foo", "bar"}, 2)
m.Set([]string{"foo", "bar", "baz"}, 3)
m.Set([]string{"foo", "bar", "baz", "quux"}, 4)
m.Set([]string{"quux", "bar"}, 5)
m.Set([]string{"foo", "quux"}, 6)
m.Set([]string{"*"}, 7)
m.Set([]string{"foo", "*"}, 8)
m.Set([]string{"*", "bar"}, 9)
m.Set([]string{"*", "quux"}, 10)
m.Set([]string{"quux", "quux", "quux", "quux"}, 11)
testCases := []struct {
path []string
expected map[int]int
}{{
path: []string{"foo", "bar", "baz"},
expected: map[int]int{0: 1, 1: 1, 2: 1, 3: 1, 7: 1, 8: 1, 9: 1},
}, {
path: []string{"zip", "zap"},
expected: map[int]int{0: 1, 7: 1},
}, {
path: []string{"foo", "zap"},
expected: map[int]int{0: 1, 1: 1, 8: 1, 7: 1},
}, {
path: []string{"quux", "quux", "quux"},
expected: map[int]int{0: 1, 7: 1, 10: 1},
}}
for _, tc := range testCases {
result := make(map[int]int, len(tc.expected))
m.VisitPrefix(tc.path, accumulator(result))
if diff := test.Diff(tc.expected, result); diff != "" {
t.Errorf("Test case %v: %s", tc.path, diff)
}
}
}
func TestString(t *testing.T) {
m := New()
m.Set([]string{}, 0)
m.Set([]string{"foo", "bar"}, 1)
m.Set([]string{"foo", "quux"}, 2)
m.Set([]string{"foo", "*"}, 3)
expected := `Val: 0
Child "foo":
Child "*":
Val: 3
Child "bar":
Val: 1
Child "quux":
Val: 2
`
got := fmt.Sprint(m)
if expected != got {
t.Errorf("Unexpected string. Expected:\n\n%s\n\nGot:\n\n%s", expected, got)
}
}
func genWords(count, wordLength int) []string {
chars := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
if count+wordLength > len(chars) {
panic("need more chars")
}
result := make([]string, count)
for i := 0; i < count; i++ {
result[i] = string(chars[i : i+wordLength])
}
return result
}
func benchmarkPathMap(pathLength, pathDepth int, b *testing.B) {
m := New()
// Push pathDepth paths, each of length pathLength
path := genWords(pathLength, 10)
words := genWords(pathDepth, 10)
n := m.(*node)
for _, element := range path {
n.children = map[string]*node{}
for _, word := range words {
n.children[word] = &node{}
}
n = n.children[element]
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.Visit(path, func(v interface{}) error { return nil })
}
}
func BenchmarkPathMap1x25(b *testing.B) { benchmarkPathMap(1, 25, b) }
func BenchmarkPathMap10x50(b *testing.B) { benchmarkPathMap(10, 25, b) }
func BenchmarkPathMap20x50(b *testing.B) { benchmarkPathMap(20, 25, b) }

View File

@ -0,0 +1,387 @@
// Copyright (c) 2017 Arista Networks, Inc. All rights reserved.
// Arista Networks, Inc. Confidential and Proprietary.
// Subject to Arista Networks, Inc.'s EULA.
// FOR INTERNAL USE ONLY. NOT FOR DISTRIBUTION.
package sizeof
import (
"errors"
"reflect"
"unsafe"
"github.com/aristanetworks/goarista/areflect"
)
// blocks are used to keep track of which memory areas were already
// been visited.
type block struct {
start uintptr
end uintptr
}
func (b block) size() uintptr {
return b.end - b.start
}
// DeepSizeof returns total memory occupied by each type for val.
// The value passed in argument must be a pointer.
func DeepSizeof(val interface{}) (map[string]uintptr, error) {
value := reflect.ValueOf(val)
// We want to force val to be a pointer to the original value, because if we get a copy, we
// can get some pointers that will point back to our original value.
if value.Kind() != reflect.Ptr {
return nil, errors.New("cannot get the deep size of a non-pointer value")
}
m := make(map[string]uintptr)
ptrsTypes := make(map[uintptr]map[string]struct{})
sizeof(value.Elem(), m, ptrsTypes, false, block{start: uintptr(value.Pointer())}, nil)
return m, nil
}
// Check if curBlock overlap tmpBlock
func isOverlapping(curBlock, tmpBlock block) bool {
return curBlock.start <= tmpBlock.end && tmpBlock.start <= curBlock.end
}
func getOverlappingBlocks(curBlock block, seen []block) ([]block, int) {
var tmp []block
for idx, a := range seen {
if a.start > curBlock.end {
return tmp, idx
}
if isOverlapping(curBlock, a) {
tmp = append(tmp, a)
}
}
return tmp, len(seen)
}
func insertBlock(curBlock block, idxToInsert int, seen []block) []block {
seen = append(seen, block{})
copy(seen[idxToInsert+1:], seen[idxToInsert:])
seen[idxToInsert] = curBlock
return seen
}
// get the size of our current block that is not overlapping other blocks.
func getUnseenSizeOfCurrentBlock(curBlock block, overlappingBlocks []block) uintptr {
var size uintptr
for idx, a := range overlappingBlocks {
if idx == 0 && curBlock.start < a.start {
size += a.start - curBlock.start
}
if idx == len(overlappingBlocks)-1 {
if curBlock.end > a.end {
size += curBlock.end - a.end
}
} else {
size += overlappingBlocks[idx+1].start - a.end
}
}
return size
}
func updateSeenBlocks(curBlock block, seen []block) ([]block, uintptr) {
if len(seen) == 0 {
return []block{curBlock}, curBlock.size()
}
overlappingBlocks, idx := getOverlappingBlocks(curBlock, seen)
if len(overlappingBlocks) == 0 {
// No overlap, so we will insert our new block in our list.
return insertBlock(curBlock, idx, seen), curBlock.size()
}
unseenSize := getUnseenSizeOfCurrentBlock(curBlock, overlappingBlocks)
idxFirstOverlappingBlock := idx - len(overlappingBlocks)
firstOverlappingBlock := &seen[idxFirstOverlappingBlock]
lastOverlappingBlock := seen[idx-1]
if firstOverlappingBlock.start > curBlock.start {
firstOverlappingBlock.start = curBlock.start
}
if lastOverlappingBlock.end < curBlock.end {
firstOverlappingBlock.end = curBlock.end
} else {
firstOverlappingBlock.end = lastOverlappingBlock.end
}
tailLen := len(seen[idx:])
copy(seen[idxFirstOverlappingBlock+1:], seen[idx:])
return seen[:idxFirstOverlappingBlock+1+tailLen], unseenSize
}
// Check if this current block is already fully contained in our list of seen blocks
func isKnownBlock(curBlock block, seen []block) bool {
for _, a := range seen {
if a.start <= curBlock.start &&
a.end >= curBlock.end {
// curBlock is fully contained in an other block
// that we already know
return true
}
if a.start > curBlock.start {
// Our slice of seens block is order by pointer address.
// That means, if curBlock was not contained in a previous known
// block, there is no need to continue.
return false
}
}
return false
}
func sizeof(v reflect.Value, m map[string]uintptr, ptrsTypes map[uintptr]map[string]struct{},
counted bool, curBlock block, seen []block) []block {
if !v.IsValid() {
return seen
}
vn := v.Type().String()
vs := v.Type().Size()
curBlock.end = vs + curBlock.start
if counted {
// already accounted for the size (field in struct, in array, etc)
vs = 0
}
if curBlock.start != 0 {
// A pointer can point to the same memory area than a previous pointer,
// but its type should be different (see tests struct_5 and struct_6).
if types, ok := ptrsTypes[curBlock.start]; ok {
if _, ok := types[vn]; ok {
return seen
}
types[vn] = struct{}{}
} else {
ptrsTypes[curBlock.start] = make(map[string]struct{})
}
if isKnownBlock(curBlock, seen) {
// we don't want to count this size if we have a known block
vs = 0
} else {
var tmpVs uintptr
seen, tmpVs = updateSeenBlocks(curBlock, seen)
if !counted {
vs = tmpVs
}
}
}
switch v.Kind() {
case reflect.Interface:
seen = sizeof(v.Elem(), m, ptrsTypes, false, block{}, seen)
case reflect.Ptr:
if v.IsNil() {
break
}
seen = sizeof(v.Elem(), m, ptrsTypes, false, block{start: uintptr(v.Pointer())}, seen)
case reflect.Array:
// get size of all elements in the array in case there are pointers
l := v.Len()
for i := 0; i < l; i++ {
seen = sizeof(v.Index(i), m, ptrsTypes, true, block{}, seen)
}
case reflect.Slice:
// get size of all elements in the slice in case there are pointers
// TODO: count elements that are not accessible after reslicing
l := v.Len()
vLen := v.Type().Elem().Size()
for i := 0; i < l; i++ {
e := v.Index(i)
eStart := uintptr(e.UnsafeAddr())
eBlock := block{
start: eStart,
end: eStart + vLen,
}
if !isKnownBlock(eBlock, seen) {
vs += vLen
seen = sizeof(e, m, ptrsTypes, true, eBlock, seen)
}
}
capStart := uintptr(v.Pointer()) + (v.Type().Elem().Size() * uintptr(v.Len()))
capEnd := uintptr(v.Pointer()) + (v.Type().Elem().Size() * uintptr(v.Cap()))
capBlock := block{start: capStart, end: capEnd}
if isKnownBlock(capBlock, seen) {
break
}
var tmpSize uintptr
seen, tmpSize = updateSeenBlocks(capBlock, seen)
vs += tmpSize
case reflect.Map:
if v.IsNil() {
break
}
var tmpSize uintptr
if tmpSize, seen = sizeofmap(v, seen); tmpSize == 0 {
// we saw this map
break
}
vs += tmpSize
for _, k := range v.MapKeys() {
kv := v.MapIndex(k)
seen = sizeof(k, m, ptrsTypes, true, block{}, seen)
seen = sizeof(kv, m, ptrsTypes, true, block{}, seen)
}
case reflect.Struct:
for i, n := 0, v.NumField(); i < n; i++ {
vf := areflect.ForceExport(v.Field(i))
seen = sizeof(vf, m, ptrsTypes, true, block{}, seen)
}
case reflect.String:
str := v.String()
strHdr := (*reflect.StringHeader)(unsafe.Pointer(&str))
tmpSize := uintptr(strHdr.Len)
strBlock := block{start: strHdr.Data, end: strHdr.Data + tmpSize}
if isKnownBlock(strBlock, seen) {
break
}
seen, tmpSize = updateSeenBlocks(strBlock, seen)
vs += tmpSize
case reflect.Chan:
var tmpSize uintptr
tmpSize, seen = sizeofChan(v, m, ptrsTypes, seen)
vs += tmpSize
}
if vs != 0 {
m[vn] += vs
}
return seen
}
//go:linkname typesByString reflect.typesByString
func typesByString(s string) []unsafe.Pointer
func sizeofmap(v reflect.Value, seen []block) (uintptr, []block) {
// get field typ *rtype of our Value v and store it in an interface
var ti interface{} = v.Type()
tp := (*unsafe.Pointer)(unsafe.Pointer(&ti))
// we know that this pointer rtype point at the begining of struct
// mapType defined in /go/src/reflect/type.go, so we can change the underlying
// type of the interface to be a pointer to runtime.maptype because it as the
// exact same definition as reflect.mapType.
*tp = typesByString("*runtime.maptype")[0]
maptypev := reflect.ValueOf(ti)
maptypev = reflect.Indirect(maptypev)
// now we can access field bucketsize in struct maptype
bucketsize := maptypev.FieldByName("bucketsize").Uint()
// get hmap
var m interface{} = v.Interface()
hmap := (*unsafe.Pointer)(unsafe.Pointer(&m))
*hmap = typesByString("*runtime.hmap")[0]
hmapv := reflect.ValueOf(m)
// account for the size of the hmap, buckets and oldbuckets
hmapv = reflect.Indirect(hmapv)
mapBlock := block{
start: hmapv.UnsafeAddr(),
end: hmapv.UnsafeAddr() + hmapv.Type().Size(),
}
// is it a map we already saw?
if isKnownBlock(mapBlock, seen) {
return 0, seen
}
seen, _ = updateSeenBlocks(mapBlock, seen)
B := hmapv.FieldByName("B").Uint()
oldbuckets := hmapv.FieldByName("oldbuckets").Pointer()
buckets := hmapv.FieldByName("buckets").Pointer()
noverflow := int16(hmapv.FieldByName("noverflow").Uint())
nb := 2
if B == 0 {
nb = 1
}
size := uint64((nb << B)) * bucketsize
if noverflow != 0 {
size += uint64(noverflow) * bucketsize
}
seen, _ = updateSeenBlocks(block{start: buckets, end: buckets + uintptr(size)},
seen)
// As defined in /go/src/runtime/hashmap.go in struct hmap, oldbuckets is the
// previous bucket array that is half the size of the current one. We need to
// also take that in consideration since there is still a pointer to this previous bucket.
if oldbuckets != 0 {
tmp := (2 << (B - 1)) * bucketsize
size += tmp
seen, _ = updateSeenBlocks(block{
start: oldbuckets,
end: oldbuckets + uintptr(tmp),
}, seen)
}
return hmapv.Type().Size() + uintptr(size), seen
}
func getSliceToChanBuffer(buff unsafe.Pointer, buffLen uint, dataSize uint) []byte {
var slice []byte
sliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
sliceHdr.Len = int(buffLen * dataSize)
sliceHdr.Cap = sliceHdr.Len
sliceHdr.Data = uintptr(buff)
return slice
}
func sizeofChan(v reflect.Value, m map[string]uintptr, ptrsTypes map[uintptr]map[string]struct{},
seen []block) (uintptr, []block) {
var c interface{} = v.Interface()
hchan := (*unsafe.Pointer)(unsafe.Pointer(&c))
*hchan = typesByString("*runtime.hchan")[0]
hchanv := reflect.ValueOf(c)
hchanv = reflect.Indirect(hchanv)
chanBlock := block{
start: hchanv.UnsafeAddr(),
end: hchanv.UnsafeAddr() + hchanv.Type().Size(),
}
// is it a chan we already saw?
if isKnownBlock(chanBlock, seen) {
return 0, seen
}
seen, _ = updateSeenBlocks(chanBlock, seen)
elemType := unsafe.Pointer(hchanv.FieldByName("elemtype").Pointer())
buff := unsafe.Pointer(hchanv.FieldByName("buf").Pointer())
buffLen := hchanv.FieldByName("dataqsiz").Uint()
elemSize := uint16(hchanv.FieldByName("elemsize").Uint())
seen, _ = updateSeenBlocks(block{
start: uintptr(buff),
end: uintptr(buff) + uintptr(buffLen*uint64(elemSize)),
}, seen)
buffSlice := getSliceToChanBuffer(buff, uint(buffLen), uint(elemSize))
recvx := hchanv.FieldByName("recvx").Uint()
qcount := hchanv.FieldByName("qcount").Uint()
var tmp interface{}
eface := (*struct {
typ unsafe.Pointer
ptr unsafe.Pointer
})(unsafe.Pointer(&tmp))
eface.typ = elemType
for i := uint64(0); buffLen > 0 && i < qcount; i++ {
idx := (recvx + i) % buffLen
// get the pointer to the data inside the chan buffer.
elem := unsafe.Pointer(&buffSlice[uint64(elemSize)*idx])
eface.ptr = elem
ev := reflect.ValueOf(tmp)
var blk block
k := ev.Kind()
if k == reflect.Ptr || k == reflect.Chan || k == reflect.Map || k == reflect.Func {
// let's say our chan is a chan *whatEver, or chan chan whatEver or
// chan map[whatEver]whatEver. In this case elemType will
// be either of type *whatEver, chan whatEver or map[whatEver]whatEver
// but what we set eface.ptr = elem above, we make it point to a pointer
// to where the data is sotred in the buffer of our channel.
// So the interface tmp would look like:
// chan *whatEver -> (type=*whatEver, ptr=**whatEver)
// chan chan whatEver -> (type= chan whatEver, ptr=*chan whatEver)
// chan map[whatEver]whatEver -> (type= map[whatEver]whatEver{},
// ptr=*map[whatEver]whatEver)
// So we need to take the ptr which is stored into the buffer and replace
// the ptr to the data of our interface tmp.
ptr := (*unsafe.Pointer)(elem)
eface.ptr = *ptr
ev = reflect.ValueOf(tmp)
ev = reflect.Indirect(ev)
blk.start = uintptr(*ptr)
}
// It seems that when the chan is of type chan *whatEver, the type in eface
// will be whatEver and not *whatEver, but ev.Kind() will be a reflect.ptr.
// So if k is a reflect.Ptr (i.e. a pointer) to a struct, then we want to take
// the size of the struct into account because
// vs := v.Type().Size() will return us the size of the struct and not the size
// of the pointer that is in the channel's buffer.
seen = sizeof(ev, m, ptrsTypes, true && k != reflect.Ptr, blk, seen)
}
return hchanv.Type().Size() + uintptr(uint64(elemSize)*buffLen), seen
}

View File

@ -0,0 +1,8 @@
// Copyright (c) 2017 Arista Networks, Inc. All rights reserved.
// Arista Networks, Inc. Confidential and Proprietary.
// Subject to Arista Networks, Inc.'s EULA.
// FOR INTERNAL USE ONLY. NOT FOR DISTRIBUTION.
// This file is intentionally empty.
// It's a workaround for https://github.com/golang/go/issues/15006

View File

@ -0,0 +1,573 @@
// Copyright (c) 2017 Arista Networks, Inc. All rights reserved.
// Arista Networks, Inc. Confidential and Proprietary.
// Subject to Arista Networks, Inc.'s EULA.
// FOR INTERNAL USE ONLY. NOT FOR DISTRIBUTION.
package sizeof
import (
"fmt"
"strconv"
"testing"
"unsafe"
"github.com/aristanetworks/goarista/test"
)
type yolo struct {
i int32
a [3]int8
p unsafe.Pointer
}
func (y yolo) String() string {
return "Yolo"
}
func TestDeepSizeof(t *testing.T) {
ptrSize := uintptr(unsafe.Sizeof(unsafe.Pointer(t)))
// hmapStructSize represent the size of struct hmap defined in
// file /go/src/runtime/hashmap.go
hmapStructSize := uintptr(unsafe.Sizeof(int(0)) + 2*1 + 2 + 4 +
2*ptrSize + ptrSize + ptrSize)
var alignement uintptr = 4
if ptrSize == 4 {
alignement = 0
}
strHdrSize := unsafe.Sizeof("") // int + ptr to data
sliceHdrSize := 3 * ptrSize // ptr to data + 2 * int
// struct hchan is defined in /go/src/runtime/chan.go
chanHdrSize := 2*ptrSize + ptrSize + 2 + 2 /* padding */ + 4 + ptrSize + 2*ptrSize +
2*(2*ptrSize) + ptrSize
yoloSize := unsafe.Sizeof(yolo{})
interfaceSize := 2 * ptrSize
topHashSize := uintptr(8)
tests := map[string]struct {
getStruct func() interface{}
expectedSize uintptr
}{
"bool": {
getStruct: func() interface{} {
var test bool
return &test
},
expectedSize: 1,
},
"int8": {
getStruct: func() interface{} {
test := int8(4)
return &test
},
expectedSize: 1,
},
"int16": {
getStruct: func() interface{} {
test := int16(4)
return &test
},
expectedSize: 2,
},
"int32": {
getStruct: func() interface{} {
test := int32(4)
return &test
},
expectedSize: 4,
},
"int64": {
getStruct: func() interface{} {
test := int64(4)
return &test
},
expectedSize: 8,
},
"uint": {
getStruct: func() interface{} {
test := uint(4)
return &test
},
expectedSize: ptrSize,
},
"uint8": {
getStruct: func() interface{} {
test := uint8(4)
return &test
},
expectedSize: 1,
},
"uint16": {
getStruct: func() interface{} {
test := uint16(4)
return &test
},
expectedSize: 2,
},
"uint32": {
getStruct: func() interface{} {
test := uint32(4)
return &test
},
expectedSize: 4,
},
"uint64": {
getStruct: func() interface{} {
test := uint64(4)
return &test
},
expectedSize: 8,
},
"uintptr": {
getStruct: func() interface{} {
test := uintptr(4)
return &test
},
expectedSize: ptrSize,
},
"float32": {
getStruct: func() interface{} {
test := float32(4)
return &test
},
expectedSize: 4,
},
"float64": {
getStruct: func() interface{} {
test := float64(4)
return &test
},
expectedSize: 8,
},
"complex64": {
getStruct: func() interface{} {
test := complex64(4 + 1i)
return &test
},
expectedSize: 8,
},
"complex128": {
getStruct: func() interface{} {
test := complex128(4 + 1i)
return &test
},
expectedSize: 16,
},
"string": {
getStruct: func() interface{} {
test := "Hello Dolly!"
return &test
},
expectedSize: strHdrSize + 12,
},
"unsafe_Pointer": {
getStruct: func() interface{} {
tmp := uint64(54)
var test unsafe.Pointer
test = unsafe.Pointer(&tmp)
return &test
},
expectedSize: ptrSize,
}, "rune": {
getStruct: func() interface{} {
test := rune('A')
return &test
},
expectedSize: 4,
}, "intPtr": {
getStruct: func() interface{} {
test := int(4)
return &test
},
expectedSize: ptrSize,
}, "FuncPtr": {
getStruct: func() interface{} {
test := TestDeepSizeof
return &test
},
expectedSize: ptrSize,
}, "struct_1": {
getStruct: func() interface{} {
v := struct {
a uint32
b *uint32
c struct {
e [8]byte
d string
}
f string
}{
a: 10,
c: struct {
e [8]byte
d string
}{
e: [8]byte{0, 1, 2, 3, 4, 5, 6, 7},
d: "Hello Test!",
},
f: "Hello Test!",
}
a := uint32(47)
v.b = &a
return &v
},
expectedSize: 4 + alignement + ptrSize + 8 + strHdrSize*2 +
11 /* "Hello Test!" */ + 4, /* uint32(47) */
}, "struct_2": {
getStruct: func() interface{} {
v := struct {
a []byte
b []byte
c []byte
}{
c: make([]byte, 32, 64),
}
v.a = v.c[20:32]
v.b = v.c[10:20]
return &v
},
expectedSize: 3*sliceHdrSize + 64, /*slice capacity*/
}, "struct_3": {
getStruct: func() interface{} {
type test struct {
a *byte
c []byte
}
tmp := make([]byte, 64, 128)
v := (*test)(unsafe.Pointer(&tmp[16]))
v.c = tmp
v.a = (*byte)(unsafe.Pointer(&tmp[5]))
return v
},
// we expect to see 128 bytes as struct test is part of the bytes slice
// and field c point to it.
expectedSize: 128,
}, "map_string_interface": {
getStruct: func() interface{} {
return &map[string]interface{}{}
},
expectedSize: ptrSize + topHashSize + hmapStructSize +
(8*(strHdrSize+interfaceSize) + ptrSize),
}, "map_interface_interface": {
getStruct: func() interface{} {
// All the map will use only one bucket because there is less than 8
// entries in each map. Also for amd64 and i386 the bucket size is
// computed like in function bucketOf in /go/src/reflect/type.go:
return &map[interface{}]interface{}{
// 2 + (8 + 4) 386
// 2 + (16 + 4) amd64
uint16(123): "yolo",
// (4 + 8) + (4 + 28 + 140) + (4 for SWAG) 386
// (4 + 16) + (8 + 48 + 272) + (4 for SWAG) amd64
"meow": map[string]string{"SWAG": "yolo"},
// (12) + (4 + 12) 386
// (16) + (8 + 16) amd64
yolo{i: 523}: &yolo{i: 126},
// (12) + (12) 386
// (16) + (16) amd64
fmt.Stringer(yolo{i: 123}): yolo{i: 234},
}
},
// Total
// 386: (4 + 28 + 140) + 2 + (8 + 4) + (4 + 8) + (4 + 28 + 140) + 4 + (12) +
// (4 + 12) + 12 + 12
// amd64: (8 + 48 + 272) + 2 + (16 + 4) + (4 + 16) + (8 + 48 + 272) + (4) +
// (16) + (8 + 16) + (16) + (16)
expectedSize: (ptrSize + topHashSize + hmapStructSize +
(8*(2*interfaceSize) + ptrSize)) +
(unsafe.Sizeof(uint16(123)) + strHdrSize + 4 /* "yolo" */) +
(strHdrSize + 4 /* "meow" */ +
(ptrSize + hmapStructSize + topHashSize +
(8*(2*strHdrSize) + ptrSize)) /*map[string]string*/ +
4 /* "SWAG" */) +
(yoloSize /* obj: */ + (ptrSize + yoloSize) /* &obj */) +
yoloSize*2,
}, "struct_4": {
getStruct: func() interface{} {
return &struct {
a map[interface{}]interface{}
c string
d []string
}{
a: map[interface{}]interface{}{
uint16(123): "yolo",
"meow": map[string]string{"SWAG": "yolo"},
yolo{i: 127}: &yolo{i: 124},
fmt.Stringer(yolo{i: 123}): yolo{i: 234},
}, // 4 (386) or 8 (amd64)
c: "Hello", // 8 (386) or 16 (amd64)
d: []string{"Bonjour", "Hello", "Hola"}, // 12 (386) or 24 (amd64)
}
},
// Total
// 386: sizeof(tmp map) + 8(test.c) + 12(test.d) +
// 3 * 8 (strSlice) + 16(len("Bonjour") + len("Hello")...)
// amd64: sizeof(tmp map) + 8 (test.b) + 16(test.c) + 24(test.d) +
// 3 * 16 (strSlice) + 16(len("Bonjour") + len("Hello")...)
expectedSize: (ptrSize + strHdrSize + sliceHdrSize) + (hmapStructSize +
topHashSize + (8*(2*2*ptrSize /* interface size */) + ptrSize) +
unsafe.Sizeof(uint16(123)) +
strHdrSize + 4 /* "yolo" */ + strHdrSize + 4 /* "meow" */ +
+(ptrSize + hmapStructSize + topHashSize + (8*(2*strHdrSize) + ptrSize) +
4 /* "SWAG" */) + yoloSize /* obj */ + (ptrSize + yoloSize) /* &obj */ +
yoloSize*2) + 5 /* "Hello" */ +
3*strHdrSize /*strings in strSlice*/ + 11, /* "Bonjour" + "Hola" */
}, "chan_int": {
getStruct: func() interface{} {
test := make(chan int)
return &test
},
// The expected size should be equal to the size of the struct hchan
// defined in /go/src/runtime/chan.go
expectedSize: ptrSize + chanHdrSize,
}, "chan_int_16": {
getStruct: func() interface{} {
test := make(chan int, 16)
return &test
},
expectedSize: ptrSize + chanHdrSize + 16*ptrSize,
}, "chan_yoloPtr_16": {
getStruct: func() interface{} {
test := make(chan *yolo, 16)
for i := 0; i < 16; i++ {
tmp := &yolo{
i: int32(i),
}
tmp.p = unsafe.Pointer(&tmp.i)
test <- tmp
}
return &test
},
expectedSize: ptrSize + chanHdrSize + 16*(ptrSize+yoloSize),
}, "struct_5": {
getStruct: func() interface{} {
tmp := make([]byte, 32)
test := struct {
a []byte
b **uint32
}{
a: tmp,
}
bob := uint32(42)
ptrInt := (*uintptr)(unsafe.Pointer(&tmp[0]))
*ptrInt = uintptr(unsafe.Pointer(&bob))
test.b = (**uint32)(unsafe.Pointer(&tmp[0]))
return &test
},
expectedSize: sliceHdrSize + ptrSize + 32 + 4,
}, "struct_6": {
getStruct: func() interface{} {
type A struct {
a uintptr
b *yolo
}
type B struct {
a *A
b uintptr
}
tmp := make([]byte, 32)
test := struct {
a []byte
b *B
}{
a: tmp,
}
y := yolo{i: 42}
test.b = (*B)(unsafe.Pointer(&tmp[0]))
test.b.a = (*A)(unsafe.Pointer(&tmp[0]))
test.b.a.b = &y
return &test
},
expectedSize: sliceHdrSize + ptrSize + 32 + yoloSize,
}, "chan_chan_int_16": {
getStruct: func() interface{} {
test := make(chan chan int, 16)
for i := 0; i < 16; i++ {
tmp := make(chan int)
test <- tmp
}
return &test
},
expectedSize: ptrSize + chanHdrSize*17 + 16*ptrSize,
}, "chan_yolo_16": {
getStruct: func() interface{} {
test := make(chan yolo, 16)
for i := 0; i < 16; i++ {
tmp := yolo{
i: int32(i),
}
test <- tmp
}
return &test
},
expectedSize: ptrSize + chanHdrSize + 16*yoloSize,
}, "chan_map_string_interface_16)": {
getStruct: func() interface{} {
test := make(chan map[string]interface{}, 16)
for i := 0; i < 16; i++ {
tmp := make(map[string]interface{})
test <- tmp
}
return &test
},
expectedSize: ptrSize + chanHdrSize + 16*(ptrSize+hmapStructSize+
(8*(1+strHdrSize+interfaceSize)+ptrSize)),
}, "chan_unsafe_Pointer_16": {
getStruct: func() interface{} {
test := make(chan unsafe.Pointer, 16)
for i := 0; i < 16; i++ {
var a int
ptrToA := (unsafe.Pointer)(unsafe.Pointer(&a))
test <- ptrToA
}
return &test
},
expectedSize: ptrSize + chanHdrSize + 16*ptrSize,
}, "chan_[]int_16": {
getStruct: func() interface{} {
test := make(chan []int, 16)
for i := 0; i < 8; i++ {
intSlice := make([]int, 16)
test <- intSlice
}
return &test
},
expectedSize: ptrSize + chanHdrSize + 16*sliceHdrSize + 8*16*ptrSize,
}, "chan_func": {
getStruct: func() interface{} {
test := make(chan func(), 16)
f := func() {
fmt.Printf("Hello!")
}
for i := 0; i < 8; i++ {
test <- f
}
return &test
},
expectedSize: ptrSize + chanHdrSize + 16*ptrSize,
},
}
for key, tcase := range tests {
t.Run(key, func(t *testing.T) {
v := tcase.getStruct()
m, err := DeepSizeof(v)
if err != nil {
t.Fatal(err)
}
var totalSize uintptr
for _, size := range m {
totalSize += size
}
expectedSize := tcase.expectedSize
if totalSize != expectedSize {
t.Fatalf("Expected size: %v, but got %v", expectedSize, totalSize)
}
})
}
}
func TestUpdateSeenAreas(t *testing.T) {
tests := []struct {
seen []block
expectedSeen []block
expectedSize uintptr
update block
}{{
seen: []block{
{start: 0x100000, end: 0x100050},
},
expectedSeen: []block{
{start: 0x100000, end: 0x100050},
{start: 0x100100, end: 0x100150},
},
expectedSize: 0x50,
update: block{start: 0x100100, end: 0x100150},
}, {
seen: []block{
{start: 0x100000, end: 0x100050},
},
expectedSeen: []block{
{start: 0x100, end: 0x150},
{start: 0x100000, end: 0x100050},
},
expectedSize: 0x50,
update: block{start: 0x100, end: 0x150},
}, {
seen: []block{
{start: 0x100000, end: 0x100500},
},
expectedSeen: []block{
{start: 0x100000, end: 0x100750},
},
expectedSize: 0x250,
update: block{start: 0x100250, end: 0x100750},
}, {
seen: []block{
{start: 0x100250, end: 0x100750},
},
expectedSeen: []block{
{start: 0x100000, end: 0x100750},
},
expectedSize: 0x250,
update: block{start: 0x100000, end: 0x100500},
}, {
seen: []block{
{start: 0x1000, end: 0x1250},
{start: 0x1500, end: 0x1750},
},
expectedSeen: []block{
{start: 0x1000, end: 0x1750},
},
expectedSize: 0x2B0,
update: block{start: 0x1200, end: 0x1700},
}, {
seen: []block{
{start: 0x1000, end: 0x1250},
{start: 0x1500, end: 0x1750},
{start: 0x1F50, end: 0x21A0},
},
expectedSeen: []block{
{start: 0xF00, end: 0x1F00},
{start: 0x1F50, end: 0x21A0},
},
expectedSize: 0xB60,
update: block{start: 0xF00, end: 0x1F00},
}, {
seen: []block{
{start: 0x1000, end: 0x1250},
{start: 0x1500, end: 0x1750},
{start: 0x1F00, end: 0x2150},
},
expectedSeen: []block{
{start: 0xF00, end: 0x2150},
},
expectedSize: 0xB60,
update: block{start: 0xF00, end: 0x1F00},
}, {
seen: []block{
{start: 0x1000, end: 0x1250},
{start: 0x1500, end: 0x1750},
{start: 0x1F00, end: 0x2150},
},
expectedSeen: []block{
{start: 0x1000, end: 0x1750},
{start: 0x1F00, end: 0x2150},
},
expectedSize: 0x2B0,
update: block{start: 0x1250, end: 0x1500},
}}
for i, tcase := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
seen, size := updateSeenBlocks(tcase.update, tcase.seen)
if !test.DeepEqual(seen, tcase.expectedSeen) {
t.Fatalf("seen blocks %x for iterration %v are different than the "+
"one expected:\n %x", seen, i, tcase.expectedSeen)
}
if size != tcase.expectedSize {
t.Fatalf("Size does not match, expected 0x%x got 0x%x",
tcase.expectedSize, size)
}
})
}
}

View File

@ -5,7 +5,6 @@
package test package test
import ( import (
"regexp"
"testing" "testing"
"github.com/aristanetworks/goarista/key" "github.com/aristanetworks/goarista/key"
@ -357,21 +356,21 @@ func getDeepEqualTests(t *testing.T) []deepEqualTestCase {
b: complexCompare{}, b: complexCompare{},
}, { }, {
a: complexCompare{ a: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 42}}, m: map[builtinCompare]int8{{1, "foo"}: 42}},
b: complexCompare{ b: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 42}}, m: map[builtinCompare]int8{{1, "foo"}: 42}},
}, { }, {
a: complexCompare{ a: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 42}}, m: map[builtinCompare]int8{{1, "foo"}: 42}},
b: complexCompare{ b: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 51}}, m: map[builtinCompare]int8{{1, "foo"}: 51}},
diff: `attributes "m" are different: for key test.builtinCompare{a:uint32(1),` + diff: `attributes "m" are different: for key test.builtinCompare{a:uint32(1),` +
` b:"foo"} in map, values are different: int8(42) != int8(51)`, ` b:"foo"} in map, values are different: int8(42) != int8(51)`,
}, { }, {
a: complexCompare{ a: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "foo"}: 42}}, m: map[builtinCompare]int8{{1, "foo"}: 42}},
b: complexCompare{ b: complexCompare{
m: map[builtinCompare]int8{builtinCompare{1, "bar"}: 42}}, m: map[builtinCompare]int8{{1, "bar"}: 42}},
diff: `attributes "m" are different: key test.builtinCompare{a:uint32(1),` + diff: `attributes "m" are different: key test.builtinCompare{a:uint32(1),` +
` b:"foo"} in map is missing in the actual map`, ` b:"foo"} in map is missing in the actual map`,
}, { }, {
@ -404,16 +403,16 @@ func getDeepEqualTests(t *testing.T) []deepEqualTestCase {
a: partialCompare{a: 42, b: "foo"}, a: partialCompare{a: 42, b: "foo"},
b: partialCompare{a: 42, b: "bar"}, b: partialCompare{a: 42, b: "bar"},
}, { }, {
a: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 42}, a: map[*builtinCompare]uint32{{1, "foo"}: 42},
b: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 42}, b: map[*builtinCompare]uint32{{1, "foo"}: 42},
}, { }, {
a: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 42}, a: map[*builtinCompare]uint32{{1, "foo"}: 42},
b: map[*builtinCompare]uint32{&builtinCompare{2, "foo"}: 42}, b: map[*builtinCompare]uint32{{2, "foo"}: 42},
diff: `complex key *test.builtinCompare{a:uint32(1), b:"foo"}` + diff: `complex key *test.builtinCompare{a:uint32(1), b:"foo"}` +
` in map is missing in the actual map`, ` in map is missing in the actual map`,
}, { }, {
a: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 42}, a: map[*builtinCompare]uint32{{1, "foo"}: 42},
b: map[*builtinCompare]uint32{&builtinCompare{1, "foo"}: 51}, b: map[*builtinCompare]uint32{{1, "foo"}: 51},
diff: `for complex key *test.builtinCompare{a:uint32(1), b:"foo"}` + diff: `for complex key *test.builtinCompare{a:uint32(1), b:"foo"}` +
` in map, values are different: uint32(42) != uint32(51)`, ` in map, values are different: uint32(42) != uint32(51)`,
}, { }, {
@ -436,10 +435,11 @@ func getDeepEqualTests(t *testing.T) []deepEqualTestCase {
b: key.New(map[string]interface{}{ b: key.New(map[string]interface{}{
"a": map[key.Key]interface{}{key.New(map[string]interface{}{"k": 51}): true}}), "a": map[key.Key]interface{}{key.New(map[string]interface{}{"k": 51}): true}}),
diff: `Comparable types are different: ` + diff: `Comparable types are different: ` +
`key.composite{sentinel:uintptr(18379810577513696751), m:map[string]interface {}` + `key.compositeKey{sentinel:uintptr(18379810577513696751), m:map[string]interface {}` +
`{"a":map[key.Key]interface {}{<max_depth>:<max_depth>}}} vs` + `{"a":map[key.Key]interface {}{<max_depth>:<max_depth>}}, s:[]interface {}{}}` +
` key.composite{sentinel:uintptr(18379810577513696751), m:map[string]interface {}` + ` vs key.compositeKey{sentinel:uintptr(18379810577513696751), ` +
`{"a":map[key.Key]interface {}{<max_depth>:<max_depth>}}}`, `m:map[string]interface {}{"a":map[key.Key]interface {}` +
`{<max_depth>:<max_depth>}}, s:[]interface {}{}}`,
}, { }, {
a: code(42), a: code(42),
b: code(42), b: code(42),
@ -464,8 +464,5 @@ func getDeepEqualTests(t *testing.T) []deepEqualTestCase {
}, { }, {
a: embedder{builtinCompare: builtinCompare{}}, a: embedder{builtinCompare: builtinCompare{}},
b: embedder{builtinCompare: builtinCompare{}}, b: embedder{builtinCompare: builtinCompare{}},
}, {
a: regexp.MustCompile("foo.*bar"),
b: regexp.MustCompile("foo.*bar"),
}} }}
} }

View File

@ -6,12 +6,14 @@ package test
import ( import (
"io" "io"
"io/ioutil"
"os" "os"
"testing" "testing"
) )
// CopyFile copies a file // CopyFile copies a file
func CopyFile(t *testing.T, srcPath, dstPath string) { func CopyFile(t *testing.T, srcPath, dstPath string) {
t.Helper()
src, err := os.Open(srcPath) src, err := os.Open(srcPath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -27,3 +29,14 @@ func CopyFile(t *testing.T, srcPath, dstPath string) {
t.Fatal(err) t.Fatal(err)
} }
} }
// TempDir creates a temporary directory under the default directory for temporary files (see
// os.TempDir) and returns the path of the new directory or fails the test trying.
func TempDir(t *testing.T, dirName string) string {
t.Helper()
tempDir, err := ioutil.TempDir("", dirName)
if err != nil {
t.Fatal(err)
}
return tempDir
}

View File

@ -12,7 +12,9 @@ import (
// ShouldPanic will test is a function is panicking // ShouldPanic will test is a function is panicking
func ShouldPanic(t *testing.T, fn func()) { func ShouldPanic(t *testing.T, fn func()) {
t.Helper()
defer func() { defer func() {
t.Helper()
if r := recover(); r == nil { if r := recover(); r == nil {
t.Errorf("%sThe function %p should have panicked", t.Errorf("%sThe function %p should have panicked",
getCallerInfo(), fn) getCallerInfo(), fn)
@ -24,10 +26,12 @@ func ShouldPanic(t *testing.T, fn func()) {
// ShouldPanicWith will test is a function is panicking with a specific message // ShouldPanicWith will test is a function is panicking with a specific message
func ShouldPanicWith(t *testing.T, msg interface{}, fn func()) { func ShouldPanicWith(t *testing.T, msg interface{}, fn func()) {
t.Helper()
defer func() { defer func() {
t.Helper()
if r := recover(); r == nil { if r := recover(); r == nil {
t.Errorf("%sThe function %p should have panicked", t.Errorf("%sThe function %p should have panicked with %#v",
getCallerInfo(), fn) getCallerInfo(), fn, msg)
} else if d := Diff(msg, r); len(d) != 0 { } else if d := Diff(msg, r); len(d) != 0 {
t.Errorf("%sThe function %p panicked with the wrong message.\n"+ t.Errorf("%sThe function %p panicked with the wrong message.\n"+
"Expected: %#v\nReceived: %#v\nDiff:%s", "Expected: %#v\nReceived: %#v\nDiff:%s",

View File

@ -1,7 +1,7 @@
language: go language: go
go: go:
- 1.8.x - "1.9.5"
- 1.9.x - "1.10.1"
sudo: false sudo: false
install: install:
- GLIDE_TAG=v0.12.3 - GLIDE_TAG=v0.12.3
@ -10,8 +10,8 @@ install:
- export PATH=$PATH:$PWD/linux-amd64/ - export PATH=$PATH:$PWD/linux-amd64/
- glide install - glide install
- go install . ./cmd/... - go install . ./cmd/...
- go get -v github.com/alecthomas/gometalinter - go get -u gopkg.in/alecthomas/gometalinter.v2
- gometalinter --install - gometalinter.v2 --install
script: script:
- export PATH=$PATH:$HOME/gopath/bin - export PATH=$PATH:$HOME/gopath/bin
- ./goclean.sh - ./goclean.sh

View File

@ -570,8 +570,8 @@ Changes in 0.8.0-beta (Sun May 25 2014)
- btcctl utility changes: - btcctl utility changes:
- Add createencryptedwallet command - Add createencryptedwallet command
- Add getblockchaininfo command - Add getblockchaininfo command
- Add importwallet commmand - Add importwallet command
- Add addmultisigaddress commmand - Add addmultisigaddress command
- Add setgenerate command - Add setgenerate command
- Accept --testnet and --wallet flags which automatically select - Accept --testnet and --wallet flags which automatically select
the appropriate port and TLS certificates needed to communicate the appropriate port and TLS certificates needed to communicate

View File

@ -92,7 +92,7 @@ $ go install . ./cmd/...
## Getting Started ## Getting Started
btcd has several configuration options avilable to tweak how it runs, but all btcd has several configuration options available to tweak how it runs, but all
of the basic operations described in the intro section work with zero of the basic operations described in the intro section work with zero
configuration. configuration.

View File

@ -63,7 +63,7 @@ is by no means exhaustive:
* [ProcessBlock Example](http://godoc.org/github.com/btcsuite/btcd/blockchain#example-BlockChain-ProcessBlock) * [ProcessBlock Example](http://godoc.org/github.com/btcsuite/btcd/blockchain#example-BlockChain-ProcessBlock)
Demonstrates how to create a new chain instance and use ProcessBlock to Demonstrates how to create a new chain instance and use ProcessBlock to
attempt to attempt add a block to the chain. This example intentionally attempt to add a block to the chain. This example intentionally
attempts to insert a duplicate genesis block to illustrate how an invalid attempts to insert a duplicate genesis block to illustrate how an invalid
block is handled. block is handled.
@ -73,7 +73,7 @@ is by no means exhaustive:
typical hex notation. typical hex notation.
* [BigToCompact Example](http://godoc.org/github.com/btcsuite/btcd/blockchain#example-BigToCompact) * [BigToCompact Example](http://godoc.org/github.com/btcsuite/btcd/blockchain#example-BigToCompact)
Demonstrates how to convert how to convert a target difficulty into the Demonstrates how to convert a target difficulty into the
compact "bits" in a block header which represent that target difficulty. compact "bits" in a block header which represent that target difficulty.
## GPG Verification Key ## GPG Verification Key

View File

@ -54,23 +54,24 @@ func (b *BlockChain) maybeAcceptBlock(block *btcutil.Block, flags BehaviorFlags)
// such as making blocks that never become part of the main chain or // such as making blocks that never become part of the main chain or
// blocks that fail to connect available for further analysis. // blocks that fail to connect available for further analysis.
err = b.db.Update(func(dbTx database.Tx) error { err = b.db.Update(func(dbTx database.Tx) error {
return dbMaybeStoreBlock(dbTx, block) return dbStoreBlock(dbTx, block)
}) })
if err != nil { if err != nil {
return false, err return false, err
} }
// Create a new block node for the block and add it to the in-memory // Create a new block node for the block and add it to the node index. Even
// block chain (could be either a side chain or the main chain). // if the block ultimately gets connected to the main chain, it starts out
// on a side chain.
blockHeader := &block.MsgBlock().Header blockHeader := &block.MsgBlock().Header
newNode := newBlockNode(blockHeader, blockHeight) newNode := newBlockNode(blockHeader, prevNode)
newNode.status = statusDataStored newNode.status = statusDataStored
if prevNode != nil {
newNode.parent = prevNode
newNode.height = blockHeight
newNode.workSum.Add(prevNode.workSum, newNode.workSum)
}
b.index.AddNode(newNode) b.index.AddNode(newNode)
err = b.index.flushToDB()
if err != nil {
return false, err
}
// Connect the passed block to the chain while respecting proper chain // Connect the passed block to the chain while respecting proper chain
// selection according to the chain with the most proof of work. This // selection according to the chain with the most proof of work. This

View File

@ -101,33 +101,33 @@ type blockNode struct {
status blockStatus status blockStatus
} }
// initBlockNode initializes a block node from the given header and height. The // initBlockNode initializes a block node from the given header and parent node,
// node is completely disconnected from the chain and the workSum value is just // calculating the height and workSum from the respective fields on the parent.
// the work for the passed block. The work sum must be updated accordingly when
// the node is inserted into a chain.
//
// This function is NOT safe for concurrent access. It must only be called when // This function is NOT safe for concurrent access. It must only be called when
// initially creating a node. // initially creating a node.
func initBlockNode(node *blockNode, blockHeader *wire.BlockHeader, height int32) { func initBlockNode(node *blockNode, blockHeader *wire.BlockHeader, parent *blockNode) {
*node = blockNode{ *node = blockNode{
hash: blockHeader.BlockHash(), hash: blockHeader.BlockHash(),
workSum: CalcWork(blockHeader.Bits), workSum: CalcWork(blockHeader.Bits),
height: height,
version: blockHeader.Version, version: blockHeader.Version,
bits: blockHeader.Bits, bits: blockHeader.Bits,
nonce: blockHeader.Nonce, nonce: blockHeader.Nonce,
timestamp: blockHeader.Timestamp.Unix(), timestamp: blockHeader.Timestamp.Unix(),
merkleRoot: blockHeader.MerkleRoot, merkleRoot: blockHeader.MerkleRoot,
} }
if parent != nil {
node.parent = parent
node.height = parent.height + 1
node.workSum = node.workSum.Add(parent.workSum, node.workSum)
}
} }
// newBlockNode returns a new block node for the given block header. It is // newBlockNode returns a new block node for the given block header and parent
// completely disconnected from the chain and the workSum value is just the work // node, calculating the height and workSum from the respective fields on the
// for the passed block. The work sum must be updated accordingly when the node // parent. This function is NOT safe for concurrent access.
// is inserted into a chain. func newBlockNode(blockHeader *wire.BlockHeader, parent *blockNode) *blockNode {
func newBlockNode(blockHeader *wire.BlockHeader, height int32) *blockNode {
var node blockNode var node blockNode
initBlockNode(&node, blockHeader, height) initBlockNode(&node, blockHeader, parent)
return &node return &node
} }
@ -136,7 +136,7 @@ func newBlockNode(blockHeader *wire.BlockHeader, height int32) *blockNode {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (node *blockNode) Header() wire.BlockHeader { func (node *blockNode) Header() wire.BlockHeader {
// No lock is needed because all accessed fields are immutable. // No lock is needed because all accessed fields are immutable.
prevHash := zeroHash prevHash := &zeroHash
if node.parent != nil { if node.parent != nil {
prevHash = &node.parent.hash prevHash = &node.parent.hash
} }
@ -231,6 +231,7 @@ type blockIndex struct {
sync.RWMutex sync.RWMutex
index map[chainhash.Hash]*blockNode index map[chainhash.Hash]*blockNode
dirty map[*blockNode]struct{}
} }
// newBlockIndex returns a new empty instance of a block index. The index will // newBlockIndex returns a new empty instance of a block index. The index will
@ -241,6 +242,7 @@ func newBlockIndex(db database.DB, chainParams *chaincfg.Params) *blockIndex {
db: db, db: db,
chainParams: chainParams, chainParams: chainParams,
index: make(map[chainhash.Hash]*blockNode), index: make(map[chainhash.Hash]*blockNode),
dirty: make(map[*blockNode]struct{}),
} }
} }
@ -265,16 +267,25 @@ func (bi *blockIndex) LookupNode(hash *chainhash.Hash) *blockNode {
return node return node
} }
// AddNode adds the provided node to the block index. Duplicate entries are not // AddNode adds the provided node to the block index and marks it as dirty.
// checked so it is up to caller to avoid adding them. // Duplicate entries are not checked so it is up to caller to avoid adding them.
// //
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bi *blockIndex) AddNode(node *blockNode) { func (bi *blockIndex) AddNode(node *blockNode) {
bi.Lock() bi.Lock()
bi.index[node.hash] = node bi.addNode(node)
bi.dirty[node] = struct{}{}
bi.Unlock() bi.Unlock()
} }
// addNode adds the provided node to the block index, but does not mark it as
// dirty. This can be used while initializing the block index.
//
// This function is NOT safe for concurrent access.
func (bi *blockIndex) addNode(node *blockNode) {
bi.index[node.hash] = node
}
// NodeStatus provides concurrent-safe access to the status field of a node. // NodeStatus provides concurrent-safe access to the status field of a node.
// //
// This function is safe for concurrent access. // This function is safe for concurrent access.
@ -293,6 +304,7 @@ func (bi *blockIndex) NodeStatus(node *blockNode) blockStatus {
func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) { func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) {
bi.Lock() bi.Lock()
node.status |= flags node.status |= flags
bi.dirty[node] = struct{}{}
bi.Unlock() bi.Unlock()
} }
@ -303,5 +315,34 @@ func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) {
func (bi *blockIndex) UnsetStatusFlags(node *blockNode, flags blockStatus) { func (bi *blockIndex) UnsetStatusFlags(node *blockNode, flags blockStatus) {
bi.Lock() bi.Lock()
node.status &^= flags node.status &^= flags
bi.dirty[node] = struct{}{}
bi.Unlock() bi.Unlock()
} }
// flushToDB writes all dirty block nodes to the database. If all writes
// succeed, this clears the dirty set.
func (bi *blockIndex) flushToDB() error {
bi.Lock()
if len(bi.dirty) == 0 {
bi.Unlock()
return nil
}
err := bi.db.Update(func(dbTx database.Tx) error {
for node := range bi.dirty {
err := dbStoreBlockNode(dbTx, node)
if err != nil {
return err
}
}
return nil
})
// If write was successful, clear the dirty set.
if err == nil {
bi.dirty = make(map[*blockNode]struct{})
}
bi.Unlock()
return err
}

View File

@ -1,4 +1,5 @@
// Copyright (c) 2013-2017 The btcsuite developers // Copyright (c) 2013-2018 The btcsuite developers
// Copyright (c) 2015-2018 The Decred developers
// Use of this source code is governed by an ISC // Use of this source code is governed by an ISC
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -399,7 +400,7 @@ func (b *BlockChain) calcSequenceLock(node *blockNode, tx *btcutil.Tx, utxoView
nextHeight := node.height + 1 nextHeight := node.height + 1
for txInIndex, txIn := range mTx.TxIn { for txInIndex, txIn := range mTx.TxIn {
utxo := utxoView.LookupEntry(&txIn.PreviousOutPoint.Hash) utxo := utxoView.LookupEntry(txIn.PreviousOutPoint)
if utxo == nil { if utxo == nil {
str := fmt.Sprintf("output %v referenced from "+ str := fmt.Sprintf("output %v referenced from "+
"transaction %s:%d either does not exist or "+ "transaction %s:%d either does not exist or "+
@ -495,6 +496,8 @@ func LockTimeToSequence(isSeconds bool, locktime uint32) uint32 {
// passed node is the new end of the main chain. The lists will be empty if the // passed node is the new end of the main chain. The lists will be empty if the
// passed node is not on a side chain. // passed node is not on a side chain.
// //
// This function may modify node statuses in the block index without flushing.
//
// This function MUST be called with the chain state lock held (for reads). // This function MUST be called with the chain state lock held (for reads).
func (b *BlockChain) getReorganizeNodes(node *blockNode) (*list.List, *list.List) { func (b *BlockChain) getReorganizeNodes(node *blockNode) (*list.List, *list.List) {
attachNodes := list.New() attachNodes := list.New()
@ -544,20 +547,6 @@ func (b *BlockChain) getReorganizeNodes(node *blockNode) (*list.List, *list.List
return detachNodes, attachNodes return detachNodes, attachNodes
} }
// dbMaybeStoreBlock stores the provided block in the database if it's not
// already there.
func dbMaybeStoreBlock(dbTx database.Tx, block *btcutil.Block) error {
hasBlock, err := dbTx.HasBlock(block.Hash())
if err != nil {
return err
}
if hasBlock {
return nil
}
return dbTx.StoreBlock(block)
}
// connectBlock handles connecting the passed node/block to the end of the main // connectBlock handles connecting the passed node/block to the end of the main
// (best) chain. // (best) chain.
// //
@ -569,7 +558,9 @@ func dbMaybeStoreBlock(dbTx database.Tx, block *btcutil.Block) error {
// it would be inefficient to repeat it. // it would be inefficient to repeat it.
// //
// This function MUST be called with the chain state lock held (for writes). // This function MUST be called with the chain state lock held (for writes).
func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block, view *UtxoViewpoint, stxos []spentTxOut) error { func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block,
view *UtxoViewpoint, stxos []SpentTxOut) error {
// Make sure it's extending the end of the best chain. // Make sure it's extending the end of the best chain.
prevHash := &block.MsgBlock().Header.PrevBlock prevHash := &block.MsgBlock().Header.PrevBlock
if !prevHash.IsEqual(&b.bestChain.Tip().hash) { if !prevHash.IsEqual(&b.bestChain.Tip().hash) {
@ -599,6 +590,12 @@ func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block, view *U
} }
} }
// Write any block status changes to DB before updating best state.
err := b.index.flushToDB()
if err != nil {
return err
}
// Generate a new best state snapshot that will be used to update the // Generate a new best state snapshot that will be used to update the
// database and later memory if all database updates are successful. // database and later memory if all database updates are successful.
b.stateLock.RLock() b.stateLock.RLock()
@ -611,7 +608,7 @@ func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block, view *U
curTotalTxns+numTxns, node.CalcPastMedianTime()) curTotalTxns+numTxns, node.CalcPastMedianTime())
// Atomically insert info into the database. // Atomically insert info into the database.
err := b.db.Update(func(dbTx database.Tx) error { err = b.db.Update(func(dbTx database.Tx) error {
// Update best block state. // Update best block state.
err := dbPutBestState(dbTx, state, node.workSum) err := dbPutBestState(dbTx, state, node.workSum)
if err != nil { if err != nil {
@ -644,7 +641,7 @@ func (b *BlockChain) connectBlock(node *blockNode, block *btcutil.Block, view *U
// optional indexes with the block being connected so they can // optional indexes with the block being connected so they can
// update themselves accordingly. // update themselves accordingly.
if b.indexManager != nil { if b.indexManager != nil {
err := b.indexManager.ConnectBlock(dbTx, block, view) err := b.indexManager.ConnectBlock(dbTx, block, stxos)
if err != nil { if err != nil {
return err return err
} }
@ -705,6 +702,12 @@ func (b *BlockChain) disconnectBlock(node *blockNode, block *btcutil.Block, view
return err return err
} }
// Write any block status changes to DB before updating best state.
err = b.index.flushToDB()
if err != nil {
return err
}
// Generate a new best state snapshot that will be used to update the // Generate a new best state snapshot that will be used to update the
// database and later memory if all database updates are successful. // database and later memory if all database updates are successful.
b.stateLock.RLock() b.stateLock.RLock()
@ -739,8 +742,15 @@ func (b *BlockChain) disconnectBlock(node *blockNode, block *btcutil.Block, view
return err return err
} }
// Before we delete the spend journal entry for this back,
// we'll fetch it as is so the indexers can utilize if needed.
stxos, err := dbFetchSpendJournalEntry(dbTx, block)
if err != nil {
return err
}
// Update the transaction spend journal by removing the record // Update the transaction spend journal by removing the record
// that contains all txos spent by the block . // that contains all txos spent by the block.
err = dbRemoveSpendJournalEntry(dbTx, block.Hash()) err = dbRemoveSpendJournalEntry(dbTx, block.Hash())
if err != nil { if err != nil {
return err return err
@ -750,7 +760,7 @@ func (b *BlockChain) disconnectBlock(node *blockNode, block *btcutil.Block, view
// optional indexes with the block being disconnected so they // optional indexes with the block being disconnected so they
// can update themselves accordingly. // can update themselves accordingly.
if b.indexManager != nil { if b.indexManager != nil {
err := b.indexManager.DisconnectBlock(dbTx, block, view) err := b.indexManager.DisconnectBlock(dbTx, block, stxos)
if err != nil { if err != nil {
return err return err
} }
@ -806,15 +816,49 @@ func countSpentOutputs(block *btcutil.Block) int {
// the chain) and nodes the are being attached must be in forwards order // the chain) and nodes the are being attached must be in forwards order
// (think pushing them onto the end of the chain). // (think pushing them onto the end of the chain).
// //
// This function may modify node statuses in the block index without flushing.
//
// This function MUST be called with the chain state lock held (for writes). // This function MUST be called with the chain state lock held (for writes).
func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error { func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error {
// Nothing to do if no reorganize nodes were provided.
if detachNodes.Len() == 0 && attachNodes.Len() == 0 {
return nil
}
// Ensure the provided nodes match the current best chain.
tip := b.bestChain.Tip()
if detachNodes.Len() != 0 {
firstDetachNode := detachNodes.Front().Value.(*blockNode)
if firstDetachNode.hash != tip.hash {
return AssertError(fmt.Sprintf("reorganize nodes to detach are "+
"not for the current best chain -- first detach node %v, "+
"current chain %v", &firstDetachNode.hash, &tip.hash))
}
}
// Ensure the provided nodes are for the same fork point.
if attachNodes.Len() != 0 && detachNodes.Len() != 0 {
firstAttachNode := attachNodes.Front().Value.(*blockNode)
lastDetachNode := detachNodes.Back().Value.(*blockNode)
if firstAttachNode.parent.hash != lastDetachNode.parent.hash {
return AssertError(fmt.Sprintf("reorganize nodes do not have the "+
"same fork point -- first attach parent %v, last detach "+
"parent %v", &firstAttachNode.parent.hash,
&lastDetachNode.parent.hash))
}
}
// Track the old and new best chains heads.
oldBest := tip
newBest := tip
// All of the blocks to detach and related spend journal entries needed // All of the blocks to detach and related spend journal entries needed
// to unspend transaction outputs in the blocks being disconnected must // to unspend transaction outputs in the blocks being disconnected must
// be loaded from the database during the reorg check phase below and // be loaded from the database during the reorg check phase below and
// then they are needed again when doing the actual database updates. // then they are needed again when doing the actual database updates.
// Rather than doing two loads, cache the loaded data into these slices. // Rather than doing two loads, cache the loaded data into these slices.
detachBlocks := make([]*btcutil.Block, 0, detachNodes.Len()) detachBlocks := make([]*btcutil.Block, 0, detachNodes.Len())
detachSpentTxOuts := make([][]spentTxOut, 0, detachNodes.Len()) detachSpentTxOuts := make([][]SpentTxOut, 0, detachNodes.Len())
attachBlocks := make([]*btcutil.Block, 0, attachNodes.Len()) attachBlocks := make([]*btcutil.Block, 0, attachNodes.Len())
// Disconnect all of the blocks back to the point of the fork. This // Disconnect all of the blocks back to the point of the fork. This
@ -822,7 +866,7 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// database and using that information to unspend all of the spent txos // database and using that information to unspend all of the spent txos
// and remove the utxos created by the blocks. // and remove the utxos created by the blocks.
view := NewUtxoViewpoint() view := NewUtxoViewpoint()
view.SetBestHash(&b.bestChain.Tip().hash) view.SetBestHash(&oldBest.hash)
for e := detachNodes.Front(); e != nil; e = e.Next() { for e := detachNodes.Front(); e != nil; e = e.Next() {
n := e.Value.(*blockNode) n := e.Value.(*blockNode)
var block *btcutil.Block var block *btcutil.Block
@ -834,6 +878,11 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
if err != nil { if err != nil {
return err return err
} }
if n.hash != *block.Hash() {
return AssertError(fmt.Sprintf("detach block node hash %v (height "+
"%v) does not match previous parent block hash %v", &n.hash,
n.height, block.Hash()))
}
// Load all of the utxos referenced by the block that aren't // Load all of the utxos referenced by the block that aren't
// already in the view. // already in the view.
@ -844,9 +893,9 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// Load all of the spent txos for the block from the spend // Load all of the spent txos for the block from the spend
// journal. // journal.
var stxos []spentTxOut var stxos []SpentTxOut
err = b.db.View(func(dbTx database.Tx) error { err = b.db.View(func(dbTx database.Tx) error {
stxos, err = dbFetchSpendJournalEntry(dbTx, block, view) stxos, err = dbFetchSpendJournalEntry(dbTx, block)
return err return err
}) })
if err != nil { if err != nil {
@ -857,10 +906,19 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
detachBlocks = append(detachBlocks, block) detachBlocks = append(detachBlocks, block)
detachSpentTxOuts = append(detachSpentTxOuts, stxos) detachSpentTxOuts = append(detachSpentTxOuts, stxos)
err = view.disconnectTransactions(block, stxos) err = view.disconnectTransactions(b.db, block, stxos)
if err != nil { if err != nil {
return err return err
} }
newBest = n.parent
}
// Set the fork point only if there are nodes to attach since otherwise
// blocks are only being disconnected and thus there is no fork point.
var forkNode *blockNode
if attachNodes.Len() > 0 {
forkNode = newBest
} }
// Perform several checks to verify each block that needs to be attached // Perform several checks to verify each block that needs to be attached
@ -875,17 +933,9 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// at least a couple of ways accomplish that rollback, but both involve // at least a couple of ways accomplish that rollback, but both involve
// tweaking the chain and/or database. This approach catches these // tweaking the chain and/or database. This approach catches these
// issues before ever modifying the chain. // issues before ever modifying the chain.
var validationError error
for e := attachNodes.Front(); e != nil; e = e.Next() { for e := attachNodes.Front(); e != nil; e = e.Next() {
n := e.Value.(*blockNode) n := e.Value.(*blockNode)
// If any previous nodes in attachNodes failed validation,
// mark this one as having an invalid ancestor.
if validationError != nil {
b.index.SetStatusFlags(n, statusInvalidAncestor)
continue
}
var block *btcutil.Block var block *btcutil.Block
err := b.db.View(func(dbTx database.Tx) error { err := b.db.View(func(dbTx database.Tx) error {
var err error var err error
@ -911,6 +961,8 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
if err != nil { if err != nil {
return err return err
} }
newBest = n
continue continue
} }
@ -918,23 +970,24 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// thus will not be generated. This is done because the state // thus will not be generated. This is done because the state
// is not being immediately written to the database, so it is // is not being immediately written to the database, so it is
// not needed. // not needed.
//
// In the case the block is determined to be invalid due to a
// rule violation, mark it as invalid and mark all of its
// descendants as having an invalid ancestor.
err = b.checkConnectBlock(n, block, view, nil) err = b.checkConnectBlock(n, block, view, nil)
if err != nil { if err != nil {
// If the block failed validation mark it as invalid, then
// continue to loop through remaining nodes, marking them as
// having an invalid ancestor.
if _, ok := err.(RuleError); ok { if _, ok := err.(RuleError); ok {
b.index.SetStatusFlags(n, statusValidateFailed) b.index.SetStatusFlags(n, statusValidateFailed)
validationError = err for de := e.Next(); de != nil; de = de.Next() {
continue dn := de.Value.(*blockNode)
b.index.SetStatusFlags(dn, statusInvalidAncestor)
}
} }
return err return err
} }
b.index.SetStatusFlags(n, statusValid) b.index.SetStatusFlags(n, statusValid)
}
if validationError != nil { newBest = n
return validationError
} }
// Reset the view for the actual connection code below. This is // Reset the view for the actual connection code below. This is
@ -959,7 +1012,8 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// Update the view to unspend all of the spent txos and remove // Update the view to unspend all of the spent txos and remove
// the utxos created by the block. // the utxos created by the block.
err = view.disconnectTransactions(block, detachSpentTxOuts[i]) err = view.disconnectTransactions(b.db, block,
detachSpentTxOuts[i])
if err != nil { if err != nil {
return err return err
} }
@ -987,7 +1041,7 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// as spent and add all transactions being created by this block // as spent and add all transactions being created by this block
// to it. Also, provide an stxo slice so the spent txout // to it. Also, provide an stxo slice so the spent txout
// details are generated. // details are generated.
stxos := make([]spentTxOut, 0, countSpentOutputs(block)) stxos := make([]SpentTxOut, 0, countSpentOutputs(block))
err = view.connectTransactions(block, &stxos) err = view.connectTransactions(block, &stxos)
if err != nil { if err != nil {
return err return err
@ -1002,12 +1056,14 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
// Log the point where the chain forked and old and new best chain // Log the point where the chain forked and old and new best chain
// heads. // heads.
firstAttachNode := attachNodes.Front().Value.(*blockNode) if forkNode != nil {
firstDetachNode := detachNodes.Front().Value.(*blockNode) log.Infof("REORGANIZE: Chain forks at %v (height %v)", forkNode.hash,
lastAttachNode := attachNodes.Back().Value.(*blockNode) forkNode.height)
log.Infof("REORGANIZE: Chain forks at %v", firstAttachNode.parent.hash) }
log.Infof("REORGANIZE: Old best chain head was %v", firstDetachNode.hash) log.Infof("REORGANIZE: Old best chain head was %v (height %v)",
log.Infof("REORGANIZE: New best chain head is %v", lastAttachNode.hash) &oldBest.hash, oldBest.height)
log.Infof("REORGANIZE: New best chain head is %v (height %v)",
newBest.hash, newBest.height)
return nil return nil
} }
@ -1029,6 +1085,17 @@ func (b *BlockChain) reorganizeChain(detachNodes, attachNodes *list.List) error
func (b *BlockChain) connectBestChain(node *blockNode, block *btcutil.Block, flags BehaviorFlags) (bool, error) { func (b *BlockChain) connectBestChain(node *blockNode, block *btcutil.Block, flags BehaviorFlags) (bool, error) {
fastAdd := flags&BFFastAdd == BFFastAdd fastAdd := flags&BFFastAdd == BFFastAdd
flushIndexState := func() {
// Intentionally ignore errors writing updated node status to DB. If
// it fails to write, it's not the end of the world. If the block is
// valid, we flush in connectBlock and if the block is invalid, the
// worst that can happen is we revalidate the block after a restart.
if writeErr := b.index.flushToDB(); writeErr != nil {
log.Warnf("Error flushing block index changes to disk: %v",
writeErr)
}
}
// We are extending the main (best) chain with a new block. This is the // We are extending the main (best) chain with a new block. This is the
// most common case. // most common case.
parentHash := &block.MsgBlock().Header.PrevBlock parentHash := &block.MsgBlock().Header.PrevBlock
@ -1041,16 +1108,22 @@ func (b *BlockChain) connectBestChain(node *blockNode, block *btcutil.Block, fla
// actually connecting the block. // actually connecting the block.
view := NewUtxoViewpoint() view := NewUtxoViewpoint()
view.SetBestHash(parentHash) view.SetBestHash(parentHash)
stxos := make([]spentTxOut, 0, countSpentOutputs(block)) stxos := make([]SpentTxOut, 0, countSpentOutputs(block))
if !fastAdd { if !fastAdd {
err := b.checkConnectBlock(node, block, view, &stxos) err := b.checkConnectBlock(node, block, view, &stxos)
if err != nil { if err == nil {
if _, ok := err.(RuleError); ok { b.index.SetStatusFlags(node, statusValid)
} else if _, ok := err.(RuleError); ok {
b.index.SetStatusFlags(node, statusValidateFailed) b.index.SetStatusFlags(node, statusValidateFailed)
} } else {
return false, err
}
flushIndexState()
if err != nil {
return false, err return false, err
} }
b.index.SetStatusFlags(node, statusValid)
} }
// In the fast add case the code to check the block connection // In the fast add case the code to check the block connection
@ -1071,9 +1144,28 @@ func (b *BlockChain) connectBestChain(node *blockNode, block *btcutil.Block, fla
// Connect the block to the main chain. // Connect the block to the main chain.
err := b.connectBlock(node, block, view, stxos) err := b.connectBlock(node, block, view, stxos)
if err != nil { if err != nil {
// If we got hit with a rule error, then we'll mark
// that status of the block as invalid and flush the
// index state to disk before returning with the error.
if _, ok := err.(RuleError); ok {
b.index.SetStatusFlags(
node, statusValidateFailed,
)
}
flushIndexState()
return false, err return false, err
} }
// If this is fast add, or this block node isn't yet marked as
// valid, then we'll update its status and flush the state to
// disk again.
if fastAdd || !b.index.NodeStatus(node).KnownValid() {
b.index.SetStatusFlags(node, statusValid)
flushIndexState()
}
return true, nil return true, nil
} }
if fastAdd { if fastAdd {
@ -1111,11 +1203,16 @@ func (b *BlockChain) connectBestChain(node *blockNode, block *btcutil.Block, fla
// Reorganize the chain. // Reorganize the chain.
log.Infof("REORGANIZE: Block %v is causing a reorganize.", node.hash) log.Infof("REORGANIZE: Block %v is causing a reorganize.", node.hash)
err := b.reorganizeChain(detachNodes, attachNodes) err := b.reorganizeChain(detachNodes, attachNodes)
if err != nil {
return false, err // Either getReorganizeNodes or reorganizeChain could have made unsaved
// changes to the block index, so flush regardless of whether there was an
// error. The index would only be dirty if the block failed to connect, so
// we can ignore any errors writing.
if writeErr := b.index.flushToDB(); writeErr != nil {
log.Warnf("Error flushing block index changes to disk: %v", writeErr)
} }
return true, nil return err == nil, err
} }
// isCurrent returns whether or not the chain believes it is current. Several // isCurrent returns whether or not the chain believes it is current. Several
@ -1168,25 +1265,17 @@ func (b *BlockChain) BestSnapshot() *BestState {
return snapshot return snapshot
} }
// FetchHeader returns the block header identified by the given hash or an error // HeaderByHash returns the block header identified by the given hash or an
// if it doesn't exist. // error if it doesn't exist. Note that this will return headers from both the
func (b *BlockChain) FetchHeader(hash *chainhash.Hash) (wire.BlockHeader, error) { // main and side chains.
// Reconstruct the header from the block index if possible. func (b *BlockChain) HeaderByHash(hash *chainhash.Hash) (wire.BlockHeader, error) {
if node := b.index.LookupNode(hash); node != nil { node := b.index.LookupNode(hash)
return node.Header(), nil if node == nil {
} err := fmt.Errorf("block %s is not known", hash)
// Fall back to loading it from the database.
var header *wire.BlockHeader
err := b.db.View(func(dbTx database.Tx) error {
var err error
header, err = dbFetchHeaderByHash(dbTx, hash)
return err
})
if err != nil {
return wire.BlockHeader{}, err return wire.BlockHeader{}, err
} }
return *header, nil
return node.Header(), nil
} }
// MainChainHasBlock returns whether or not the block with the given hash is in // MainChainHasBlock returns whether or not the block with the given hash is in
@ -1302,6 +1391,87 @@ func (b *BlockChain) HeightRange(startHeight, endHeight int32) ([]chainhash.Hash
return hashes, nil return hashes, nil
} }
// HeightToHashRange returns a range of block hashes for the given start height
// and end hash, inclusive on both ends. The hashes are for all blocks that are
// ancestors of endHash with height greater than or equal to startHeight. The
// end hash must belong to a block that is known to be valid.
//
// This function is safe for concurrent access.
func (b *BlockChain) HeightToHashRange(startHeight int32,
endHash *chainhash.Hash, maxResults int) ([]chainhash.Hash, error) {
endNode := b.index.LookupNode(endHash)
if endNode == nil {
return nil, fmt.Errorf("no known block header with hash %v", endHash)
}
if !b.index.NodeStatus(endNode).KnownValid() {
return nil, fmt.Errorf("block %v is not yet validated", endHash)
}
endHeight := endNode.height
if startHeight < 0 {
return nil, fmt.Errorf("start height (%d) is below 0", startHeight)
}
if startHeight > endHeight {
return nil, fmt.Errorf("start height (%d) is past end height (%d)",
startHeight, endHeight)
}
resultsLength := int(endHeight - startHeight + 1)
if resultsLength > maxResults {
return nil, fmt.Errorf("number of results (%d) would exceed max (%d)",
resultsLength, maxResults)
}
// Walk backwards from endHeight to startHeight, collecting block hashes.
node := endNode
hashes := make([]chainhash.Hash, resultsLength)
for i := resultsLength - 1; i >= 0; i-- {
hashes[i] = node.hash
node = node.parent
}
return hashes, nil
}
// IntervalBlockHashes returns hashes for all blocks that are ancestors of
// endHash where the block height is a positive multiple of interval.
//
// This function is safe for concurrent access.
func (b *BlockChain) IntervalBlockHashes(endHash *chainhash.Hash, interval int,
) ([]chainhash.Hash, error) {
endNode := b.index.LookupNode(endHash)
if endNode == nil {
return nil, fmt.Errorf("no known block header with hash %v", endHash)
}
if !b.index.NodeStatus(endNode).KnownValid() {
return nil, fmt.Errorf("block %v is not yet validated", endHash)
}
endHeight := endNode.height
resultsLength := int(endHeight) / interval
hashes := make([]chainhash.Hash, resultsLength)
b.bestChain.mtx.Lock()
defer b.bestChain.mtx.Unlock()
blockNode := endNode
for index := int(endHeight) / interval; index > 0; index-- {
// Use the bestChain chainView for faster lookups once lookup intersects
// the best chain.
blockHeight := int32(index * interval)
if b.bestChain.contains(blockNode) {
blockNode = b.bestChain.nodeByHeight(blockHeight)
} else {
blockNode = blockNode.Ancestor(blockHeight)
}
hashes[index-1] = blockNode.hash
}
return hashes, nil
}
// locateInventory returns the node of the block after the first known block in // locateInventory returns the node of the block after the first known block in
// the locator along with the number of subsequent nodes needed to either reach // the locator along with the number of subsequent nodes needed to either reach
// the provided stop hash or the provided max number of entries. // the provided stop hash or the provided max number of entries.
@ -1467,12 +1637,16 @@ type IndexManager interface {
Init(*BlockChain, <-chan struct{}) error Init(*BlockChain, <-chan struct{}) error
// ConnectBlock is invoked when a new block has been connected to the // ConnectBlock is invoked when a new block has been connected to the
// main chain. // main chain. The set of output spent within a block is also passed in
ConnectBlock(database.Tx, *btcutil.Block, *UtxoViewpoint) error // so indexers can access the previous output scripts input spent if
// required.
ConnectBlock(database.Tx, *btcutil.Block, []SpentTxOut) error
// DisconnectBlock is invoked when a block has been disconnected from // DisconnectBlock is invoked when a block has been disconnected from
// the main chain. // the main chain. The set of outputs scripts that were spent within
DisconnectBlock(database.Tx, *btcutil.Block, *UtxoViewpoint) error // this block is also returned so indexers can clean up the prior index
// state for this block.
DisconnectBlock(database.Tx, *btcutil.Block, []SpentTxOut) error
} }
// Config is a descriptor which specifies the blockchain instance configuration. // Config is a descriptor which specifies the blockchain instance configuration.
@ -1601,6 +1775,11 @@ func New(config *Config) (*BlockChain, error) {
return nil, err return nil, err
} }
// Perform any upgrades to the various chain-specific buckets as needed.
if err := b.maybeUpgradeDbBuckets(config.Interrupt); err != nil {
return nil, err
}
// Initialize and catch up all of the currently active optional indexes // Initialize and catch up all of the currently active optional indexes
// as needed. // as needed.
if config.IndexManager != nil { if config.IndexManager != nil {

View File

@ -800,3 +800,167 @@ func TestLocateInventory(t *testing.T) {
} }
} }
} }
// TestHeightToHashRange ensures that fetching a range of block hashes by start
// height and end hash works as expected.
func TestHeightToHashRange(t *testing.T) {
// Construct a synthetic block chain with a block index consisting of
// the following structure.
// genesis -> 1 -> 2 -> ... -> 15 -> 16 -> 17 -> 18
// \-> 16a -> 17a -> 18a (unvalidated)
tip := tstTip
chain := newFakeChain(&chaincfg.MainNetParams)
branch0Nodes := chainedNodes(chain.bestChain.Genesis(), 18)
branch1Nodes := chainedNodes(branch0Nodes[14], 3)
for _, node := range branch0Nodes {
chain.index.SetStatusFlags(node, statusValid)
chain.index.AddNode(node)
}
for _, node := range branch1Nodes {
if node.height < 18 {
chain.index.SetStatusFlags(node, statusValid)
}
chain.index.AddNode(node)
}
chain.bestChain.SetTip(tip(branch0Nodes))
tests := []struct {
name string
startHeight int32 // locator for requested inventory
endHash chainhash.Hash // stop hash for locator
maxResults int // max to locate, 0 = wire const
hashes []chainhash.Hash // expected located hashes
expectError bool
}{
{
name: "blocks below tip",
startHeight: 11,
endHash: branch0Nodes[14].hash,
maxResults: 10,
hashes: nodeHashes(branch0Nodes, 10, 11, 12, 13, 14),
},
{
name: "blocks on main chain",
startHeight: 15,
endHash: branch0Nodes[17].hash,
maxResults: 10,
hashes: nodeHashes(branch0Nodes, 14, 15, 16, 17),
},
{
name: "blocks on stale chain",
startHeight: 15,
endHash: branch1Nodes[1].hash,
maxResults: 10,
hashes: append(nodeHashes(branch0Nodes, 14),
nodeHashes(branch1Nodes, 0, 1)...),
},
{
name: "invalid start height",
startHeight: 19,
endHash: branch0Nodes[17].hash,
maxResults: 10,
expectError: true,
},
{
name: "too many results",
startHeight: 1,
endHash: branch0Nodes[17].hash,
maxResults: 10,
expectError: true,
},
{
name: "unvalidated block",
startHeight: 15,
endHash: branch1Nodes[2].hash,
maxResults: 10,
expectError: true,
},
}
for _, test := range tests {
hashes, err := chain.HeightToHashRange(test.startHeight, &test.endHash,
test.maxResults)
if err != nil {
if !test.expectError {
t.Errorf("%s: unexpected error: %v", test.name, err)
}
continue
}
if !reflect.DeepEqual(hashes, test.hashes) {
t.Errorf("%s: unxpected hashes -- got %v, want %v",
test.name, hashes, test.hashes)
}
}
}
// TestIntervalBlockHashes ensures that fetching block hashes at specified
// intervals by end hash works as expected.
func TestIntervalBlockHashes(t *testing.T) {
// Construct a synthetic block chain with a block index consisting of
// the following structure.
// genesis -> 1 -> 2 -> ... -> 15 -> 16 -> 17 -> 18
// \-> 16a -> 17a -> 18a (unvalidated)
tip := tstTip
chain := newFakeChain(&chaincfg.MainNetParams)
branch0Nodes := chainedNodes(chain.bestChain.Genesis(), 18)
branch1Nodes := chainedNodes(branch0Nodes[14], 3)
for _, node := range branch0Nodes {
chain.index.SetStatusFlags(node, statusValid)
chain.index.AddNode(node)
}
for _, node := range branch1Nodes {
if node.height < 18 {
chain.index.SetStatusFlags(node, statusValid)
}
chain.index.AddNode(node)
}
chain.bestChain.SetTip(tip(branch0Nodes))
tests := []struct {
name string
endHash chainhash.Hash
interval int
hashes []chainhash.Hash
expectError bool
}{
{
name: "blocks on main chain",
endHash: branch0Nodes[17].hash,
interval: 8,
hashes: nodeHashes(branch0Nodes, 7, 15),
},
{
name: "blocks on stale chain",
endHash: branch1Nodes[1].hash,
interval: 8,
hashes: append(nodeHashes(branch0Nodes, 7),
nodeHashes(branch1Nodes, 0)...),
},
{
name: "no results",
endHash: branch0Nodes[17].hash,
interval: 20,
hashes: []chainhash.Hash{},
},
{
name: "unvalidated block",
endHash: branch1Nodes[2].hash,
interval: 8,
expectError: true,
},
}
for _, test := range tests {
hashes, err := chain.IntervalBlockHashes(&test.endHash, test.interval)
if err != nil {
if !test.expectError {
t.Errorf("%s: unexpected error: %v", test.name, err)
}
continue
}
if !reflect.DeepEqual(hashes, test.hashes) {
t.Errorf("%s: unxpected hashes -- got %v, want %v",
test.name, hashes, test.hashes)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -11,7 +11,6 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/database" "github.com/btcsuite/btcd/database"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
) )
@ -38,19 +37,6 @@ func TestErrNotInMainChain(t *testing.T) {
} }
} }
// maybeDecompress decompresses the amount and public key script fields of the
// stxo and marks it decompressed if needed.
func (o *spentTxOut) maybeDecompress(version int32) {
// Nothing to do if it's not compressed.
if !o.compressed {
return
}
o.amount = int64(decompressTxOutAmount(uint64(o.amount)))
o.pkScript = decompressScript(o.pkScript, version)
o.compressed = false
}
// TestStxoSerialization ensures serializing and deserializing spent transaction // TestStxoSerialization ensures serializing and deserializing spent transaction
// output entries works as expected. // output entries works as expected.
func TestStxoSerialization(t *testing.T) { func TestStxoSerialization(t *testing.T) {
@ -58,43 +44,38 @@ func TestStxoSerialization(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
stxo spentTxOut stxo SpentTxOut
txVersion int32 // When the txout is not fully spent.
serialized []byte serialized []byte
}{ }{
// From block 170 in main blockchain. // From block 170 in main blockchain.
{ {
name: "Spends last output of coinbase", name: "Spends last output of coinbase",
stxo: spentTxOut{ stxo: SpentTxOut{
amount: 5000000000, Amount: 5000000000,
pkScript: hexToBytes("410411db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5cb2e0eaddfb84ccf9744464f82e160bfa9b8b64f9d4c03f999b8643f656b412a3ac"), PkScript: hexToBytes("410411db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5cb2e0eaddfb84ccf9744464f82e160bfa9b8b64f9d4c03f999b8643f656b412a3ac"),
isCoinBase: true, IsCoinBase: true,
height: 9, Height: 9,
version: 1,
}, },
serialized: hexToBytes("1301320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5c"), serialized: hexToBytes("1300320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5c"),
}, },
// Adapted from block 100025 in main blockchain. // Adapted from block 100025 in main blockchain.
{ {
name: "Spends last output of non coinbase", name: "Spends last output of non coinbase",
stxo: spentTxOut{ stxo: SpentTxOut{
amount: 13761000000, Amount: 13761000000,
pkScript: hexToBytes("76a914b2fb57eadf61e106a100a7445a8c3f67898841ec88ac"), PkScript: hexToBytes("76a914b2fb57eadf61e106a100a7445a8c3f67898841ec88ac"),
isCoinBase: false, IsCoinBase: false,
height: 100024, Height: 100024,
version: 1,
}, },
serialized: hexToBytes("8b99700186c64700b2fb57eadf61e106a100a7445a8c3f67898841ec"), serialized: hexToBytes("8b99700086c64700b2fb57eadf61e106a100a7445a8c3f67898841ec"),
}, },
// Adapted from block 100025 in main blockchain. // Adapted from block 100025 in main blockchain.
{ {
name: "Does not spend last output", name: "Does not spend last output, legacy format",
stxo: spentTxOut{ stxo: SpentTxOut{
amount: 34405000000, Amount: 34405000000,
pkScript: hexToBytes("76a9146edbc6c4d31bae9f1ccc38538a114bf42de65e8688ac"), PkScript: hexToBytes("76a9146edbc6c4d31bae9f1ccc38538a114bf42de65e8688ac"),
version: 1,
}, },
txVersion: 1,
serialized: hexToBytes("0091f20f006edbc6c4d31bae9f1ccc38538a114bf42de65e86"), serialized: hexToBytes("0091f20f006edbc6c4d31bae9f1ccc38538a114bf42de65e86"),
}, },
} }
@ -104,7 +85,7 @@ func TestStxoSerialization(t *testing.T) {
// actually serializing it is calculated properly. // actually serializing it is calculated properly.
gotSize := spentTxOutSerializeSize(&test.stxo) gotSize := spentTxOutSerializeSize(&test.stxo)
if gotSize != len(test.serialized) { if gotSize != len(test.serialized) {
t.Errorf("spentTxOutSerializeSize (%s): did not get "+ t.Errorf("SpentTxOutSerializeSize (%s): did not get "+
"expected size - got %d, want %d", test.name, "expected size - got %d, want %d", test.name,
gotSize, len(test.serialized)) gotSize, len(test.serialized))
continue continue
@ -129,15 +110,13 @@ func TestStxoSerialization(t *testing.T) {
// Ensure the serialized bytes are decoded back to the expected // Ensure the serialized bytes are decoded back to the expected
// stxo. // stxo.
var gotStxo spentTxOut var gotStxo SpentTxOut
gotBytesRead, err := decodeSpentTxOut(test.serialized, &gotStxo, gotBytesRead, err := decodeSpentTxOut(test.serialized, &gotStxo)
test.txVersion)
if err != nil { if err != nil {
t.Errorf("decodeSpentTxOut (%s): unexpected error: %v", t.Errorf("decodeSpentTxOut (%s): unexpected error: %v",
test.name, err) test.name, err)
continue continue
} }
gotStxo.maybeDecompress(test.stxo.version)
if !reflect.DeepEqual(gotStxo, test.stxo) { if !reflect.DeepEqual(gotStxo, test.stxo) {
t.Errorf("decodeSpentTxOut (%s) mismatched entries - "+ t.Errorf("decodeSpentTxOut (%s) mismatched entries - "+
"got %v, want %v", test.name, gotStxo, test.stxo) "got %v, want %v", test.name, gotStxo, test.stxo)
@ -159,53 +138,43 @@ func TestStxoDecodeErrors(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
stxo spentTxOut stxo SpentTxOut
txVersion int32 // When the txout is not fully spent.
serialized []byte serialized []byte
bytesRead int // Expected number of bytes read. bytesRead int // Expected number of bytes read.
errType error errType error
}{ }{
{ {
name: "nothing serialized", name: "nothing serialized",
stxo: spentTxOut{}, stxo: SpentTxOut{},
serialized: hexToBytes(""), serialized: hexToBytes(""),
errType: errDeserialize(""), errType: errDeserialize(""),
bytesRead: 0, bytesRead: 0,
}, },
{ {
name: "no data after header code w/o version", name: "no data after header code w/o reserved",
stxo: spentTxOut{}, stxo: SpentTxOut{},
serialized: hexToBytes("00"), serialized: hexToBytes("00"),
errType: errDeserialize(""), errType: errDeserialize(""),
bytesRead: 1, bytesRead: 1,
}, },
{ {
name: "no data after header code with version", name: "no data after header code with reserved",
stxo: spentTxOut{}, stxo: SpentTxOut{},
serialized: hexToBytes("13"), serialized: hexToBytes("13"),
errType: errDeserialize(""), errType: errDeserialize(""),
bytesRead: 1, bytesRead: 1,
}, },
{ {
name: "no data after version", name: "no data after reserved",
stxo: spentTxOut{}, stxo: SpentTxOut{},
serialized: hexToBytes("1301"), serialized: hexToBytes("1300"),
errType: errDeserialize(""), errType: errDeserialize(""),
bytesRead: 2, bytesRead: 2,
}, },
{
name: "no serialized tx version and passed -1",
stxo: spentTxOut{},
txVersion: -1,
serialized: hexToBytes("003205"),
errType: AssertError(""),
bytesRead: 1,
},
{ {
name: "incomplete compressed txout", name: "incomplete compressed txout",
stxo: spentTxOut{}, stxo: SpentTxOut{},
txVersion: 1, serialized: hexToBytes("1332"),
serialized: hexToBytes("0032"),
errType: errDeserialize(""), errType: errDeserialize(""),
bytesRead: 2, bytesRead: 2,
}, },
@ -214,7 +183,7 @@ func TestStxoDecodeErrors(t *testing.T) {
for _, test := range tests { for _, test := range tests {
// Ensure the expected error type is returned. // Ensure the expected error type is returned.
gotBytesRead, err := decodeSpentTxOut(test.serialized, gotBytesRead, err := decodeSpentTxOut(test.serialized,
&test.stxo, test.txVersion) &test.stxo)
if reflect.TypeOf(err) != reflect.TypeOf(test.errType) { if reflect.TypeOf(err) != reflect.TypeOf(test.errType) {
t.Errorf("decodeSpentTxOut (%s): expected error type "+ t.Errorf("decodeSpentTxOut (%s): expected error type "+
"does not match - got %T, want %T", test.name, "does not match - got %T, want %T", test.name,
@ -239,9 +208,8 @@ func TestSpendJournalSerialization(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
entry []spentTxOut entry []SpentTxOut
blockTxns []*wire.MsgTx blockTxns []*wire.MsgTx
utxoView *UtxoViewpoint
serialized []byte serialized []byte
}{ }{
// From block 2 in main blockchain. // From block 2 in main blockchain.
@ -249,18 +217,16 @@ func TestSpendJournalSerialization(t *testing.T) {
name: "No spends", name: "No spends",
entry: nil, entry: nil,
blockTxns: nil, blockTxns: nil,
utxoView: NewUtxoViewpoint(),
serialized: nil, serialized: nil,
}, },
// From block 170 in main blockchain. // From block 170 in main blockchain.
{ {
name: "One tx with one input spends last output of coinbase", name: "One tx with one input spends last output of coinbase",
entry: []spentTxOut{{ entry: []SpentTxOut{{
amount: 5000000000, Amount: 5000000000,
pkScript: hexToBytes("410411db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5cb2e0eaddfb84ccf9744464f82e160bfa9b8b64f9d4c03f999b8643f656b412a3ac"), PkScript: hexToBytes("410411db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5cb2e0eaddfb84ccf9744464f82e160bfa9b8b64f9d4c03f999b8643f656b412a3ac"),
isCoinBase: true, IsCoinBase: true,
height: 9, Height: 9,
version: 1,
}}, }},
blockTxns: []*wire.MsgTx{{ // Coinbase omitted. blockTxns: []*wire.MsgTx{{ // Coinbase omitted.
Version: 1, Version: 1,
@ -281,22 +247,21 @@ func TestSpendJournalSerialization(t *testing.T) {
}}, }},
LockTime: 0, LockTime: 0,
}}, }},
utxoView: NewUtxoViewpoint(), serialized: hexToBytes("1300320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5c"),
serialized: hexToBytes("1301320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a5c"),
}, },
// Adapted from block 100025 in main blockchain. // Adapted from block 100025 in main blockchain.
{ {
name: "Two txns when one spends last output, one doesn't", name: "Two txns when one spends last output, one doesn't",
entry: []spentTxOut{{ entry: []SpentTxOut{{
amount: 34405000000, Amount: 34405000000,
pkScript: hexToBytes("76a9146edbc6c4d31bae9f1ccc38538a114bf42de65e8688ac"), PkScript: hexToBytes("76a9146edbc6c4d31bae9f1ccc38538a114bf42de65e8688ac"),
version: 1, IsCoinBase: false,
Height: 100024,
}, { }, {
amount: 13761000000, Amount: 13761000000,
pkScript: hexToBytes("76a914b2fb57eadf61e106a100a7445a8c3f67898841ec88ac"), PkScript: hexToBytes("76a914b2fb57eadf61e106a100a7445a8c3f67898841ec88ac"),
isCoinBase: false, IsCoinBase: false,
height: 100024, Height: 100024,
version: 1,
}}, }},
blockTxns: []*wire.MsgTx{{ // Coinbase omitted. blockTxns: []*wire.MsgTx{{ // Coinbase omitted.
Version: 1, Version: 1,
@ -335,73 +300,7 @@ func TestSpendJournalSerialization(t *testing.T) {
}}, }},
LockTime: 0, LockTime: 0,
}}, }},
utxoView: &UtxoViewpoint{entries: map[chainhash.Hash]*UtxoEntry{ serialized: hexToBytes("8b99700086c64700b2fb57eadf61e106a100a7445a8c3f67898841ec8b99700091f20f006edbc6c4d31bae9f1ccc38538a114bf42de65e86"),
*newHashFromStr("c0ed017828e59ad5ed3cf70ee7c6fb0f426433047462477dc7a5d470f987a537"): {
version: 1,
isCoinBase: false,
blockHeight: 100024,
sparseOutputs: map[uint32]*utxoOutput{
1: {
amount: 34405000000,
pkScript: hexToBytes("76a9142084541c3931677527a7eafe56fd90207c344eb088ac"),
},
},
},
}},
serialized: hexToBytes("8b99700186c64700b2fb57eadf61e106a100a7445a8c3f67898841ec0091f20f006edbc6c4d31bae9f1ccc38538a114bf42de65e86"),
},
// Hand crafted.
{
name: "One tx, two inputs from same tx, neither spend last output",
entry: []spentTxOut{{
amount: 165125632,
pkScript: hexToBytes("51"),
version: 1,
}, {
amount: 154370000,
pkScript: hexToBytes("51"),
version: 1,
}},
blockTxns: []*wire.MsgTx{{ // Coinbase omitted.
Version: 1,
TxIn: []*wire.TxIn{{
PreviousOutPoint: wire.OutPoint{
Hash: *newHashFromStr("c0ed017828e59ad5ed3cf70ee7c6fb0f426433047462477dc7a5d470f987a537"),
Index: 1,
},
SignatureScript: hexToBytes(""),
Sequence: 0xffffffff,
}, {
PreviousOutPoint: wire.OutPoint{
Hash: *newHashFromStr("c0ed017828e59ad5ed3cf70ee7c6fb0f426433047462477dc7a5d470f987a537"),
Index: 2,
},
SignatureScript: hexToBytes(""),
Sequence: 0xffffffff,
}},
TxOut: []*wire.TxOut{{
Value: 165125632,
PkScript: hexToBytes("51"),
}, {
Value: 154370000,
PkScript: hexToBytes("51"),
}},
LockTime: 0,
}},
utxoView: &UtxoViewpoint{entries: map[chainhash.Hash]*UtxoEntry{
*newHashFromStr("c0ed017828e59ad5ed3cf70ee7c6fb0f426433047462477dc7a5d470f987a537"): {
version: 1,
isCoinBase: false,
blockHeight: 100000,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 165712179,
pkScript: hexToBytes("51"),
},
},
},
}},
serialized: hexToBytes("0087bc3707510084c3d19a790751"),
}, },
} }
@ -417,16 +316,12 @@ func TestSpendJournalSerialization(t *testing.T) {
// Deserialize to a spend journal entry. // Deserialize to a spend journal entry.
gotEntry, err := deserializeSpendJournalEntry(test.serialized, gotEntry, err := deserializeSpendJournalEntry(test.serialized,
test.blockTxns, test.utxoView) test.blockTxns)
if err != nil { if err != nil {
t.Errorf("deserializeSpendJournalEntry #%d (%s) "+ t.Errorf("deserializeSpendJournalEntry #%d (%s) "+
"unexpected error: %v", i, test.name, err) "unexpected error: %v", i, test.name, err)
continue continue
} }
for stxoIdx := range gotEntry {
stxo := &gotEntry[stxoIdx]
stxo.maybeDecompress(test.entry[stxoIdx].version)
}
// Ensure that the deserialized spend journal entry has the // Ensure that the deserialized spend journal entry has the
// correct properties. // correct properties.
@ -447,7 +342,6 @@ func TestSpendJournalErrors(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
blockTxns []*wire.MsgTx blockTxns []*wire.MsgTx
utxoView *UtxoViewpoint
serialized []byte serialized []byte
errType error errType error
}{ }{
@ -466,7 +360,6 @@ func TestSpendJournalErrors(t *testing.T) {
}}, }},
LockTime: 0, LockTime: 0,
}}, }},
utxoView: NewUtxoViewpoint(),
serialized: hexToBytes(""), serialized: hexToBytes(""),
errType: AssertError(""), errType: AssertError(""),
}, },
@ -484,7 +377,6 @@ func TestSpendJournalErrors(t *testing.T) {
}}, }},
LockTime: 0, LockTime: 0,
}}, }},
utxoView: NewUtxoViewpoint(),
serialized: hexToBytes("1301320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a"), serialized: hexToBytes("1301320511db93e1dcdb8a016b49840f8c53bc1eb68a382e97b1482ecad7b148a6909a"),
errType: errDeserialize(""), errType: errDeserialize(""),
}, },
@ -494,7 +386,7 @@ func TestSpendJournalErrors(t *testing.T) {
// Ensure the expected error type is returned and the returned // Ensure the expected error type is returned and the returned
// slice is nil. // slice is nil.
stxos, err := deserializeSpendJournalEntry(test.serialized, stxos, err := deserializeSpendJournalEntry(test.serialized,
test.blockTxns, test.utxoView) test.blockTxns)
if reflect.TypeOf(err) != reflect.TypeOf(test.errType) { if reflect.TypeOf(err) != reflect.TypeOf(test.errType) {
t.Errorf("deserializeSpendJournalEntry (%s): expected "+ t.Errorf("deserializeSpendJournalEntry (%s): expected "+
"error type does not match - got %T, want %T", "error type does not match - got %T, want %T",
@ -521,186 +413,52 @@ func TestUtxoSerialization(t *testing.T) {
serialized []byte serialized []byte
}{ }{
// From tx in main blockchain: // From tx in main blockchain:
// 0e3e2357e806b6cdb1f70b54c3a3a17b6714ee1f0e68bebb44a74b1efd512098 // 0e3e2357e806b6cdb1f70b54c3a3a17b6714ee1f0e68bebb44a74b1efd512098:0
{ {
name: "Only output 0, coinbase", name: "height 1, coinbase",
entry: &UtxoEntry{ entry: &UtxoEntry{
version: 1,
isCoinBase: true,
blockHeight: 1,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 5000000000, amount: 5000000000,
pkScript: hexToBytes("410496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52da7589379515d4e0a604f8141781e62294721166bf621e73a82cbf2342c858eeac"), pkScript: hexToBytes("410496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52da7589379515d4e0a604f8141781e62294721166bf621e73a82cbf2342c858eeac"),
},
},
},
serialized: hexToBytes("010103320496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52"),
},
// From tx in main blockchain:
// 8131ffb0a2c945ecaf9b9063e59558784f9c3a74741ce6ae2a18d0571dac15bb
{
name: "Only output 1, not coinbase",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 100001,
sparseOutputs: map[uint32]*utxoOutput{
1: {
amount: 1000000,
pkScript: hexToBytes("76a914ee8bd501094a7d5ca318da2506de35e1cb025ddc88ac"),
},
},
},
serialized: hexToBytes("01858c21040700ee8bd501094a7d5ca318da2506de35e1cb025ddc"),
},
// Adapted from tx in main blockchain:
// df3f3f442d9699857f7f49de4ff0b5d0f3448bec31cdc7b5bf6d25f2abd637d5
{
name: "Only output 2, coinbase",
entry: &UtxoEntry{
version: 1,
isCoinBase: true,
blockHeight: 99004,
sparseOutputs: map[uint32]*utxoOutput{
2: {
amount: 100937281,
pkScript: hexToBytes("76a914da33f77cee27c2a975ed5124d7e4f7f97513510188ac"),
},
},
},
serialized: hexToBytes("0185843c010182b095bf4100da33f77cee27c2a975ed5124d7e4f7f975135101"),
},
// Adapted from tx in main blockchain:
// 4a16969aa4764dd7507fc1de7f0baa4850a246de90c45e59a3207f9a26b5036f
{
name: "outputs 0 and 2 not coinbase",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 113931,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 20000000,
pkScript: hexToBytes("76a914e2ccd6ec7c6e2e581349c77e067385fa8236bf8a88ac"),
},
2: {
amount: 15000000,
pkScript: hexToBytes("76a914b8025be1b3efc63b0ad48e7f9f10e87544528d5888ac"),
},
},
},
serialized: hexToBytes("0185f90b0a011200e2ccd6ec7c6e2e581349c77e067385fa8236bf8a800900b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
// Adapted from tx in main blockchain:
// 4a16969aa4764dd7507fc1de7f0baa4850a246de90c45e59a3207f9a26b5036f
{
name: "outputs 0 and 2, not coinbase, 1 marked spent",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 113931,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 20000000,
pkScript: hexToBytes("76a914e2ccd6ec7c6e2e581349c77e067385fa8236bf8a88ac"),
},
1: { // This won't be serialized.
spent: true,
amount: 1000000,
pkScript: hexToBytes("76a914e43031c3e46f20bf1ccee9553ce815de5a48467588ac"),
},
2: {
amount: 15000000,
pkScript: hexToBytes("76a914b8025be1b3efc63b0ad48e7f9f10e87544528d5888ac"),
},
},
},
serialized: hexToBytes("0185f90b0a011200e2ccd6ec7c6e2e581349c77e067385fa8236bf8a800900b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
// Adapted from tx in main blockchain:
// 4a16969aa4764dd7507fc1de7f0baa4850a246de90c45e59a3207f9a26b5036f
{
name: "outputs 0 and 2, not coinbase, output 2 compressed",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 113931,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 20000000,
pkScript: hexToBytes("76a914e2ccd6ec7c6e2e581349c77e067385fa8236bf8a88ac"),
},
2: {
// Uncompressed Amount: 15000000
// Uncompressed PkScript: 76a914b8025be1b3efc63b0ad48e7f9f10e87544528d5888ac
compressed: true,
amount: 137,
pkScript: hexToBytes("00b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
},
},
serialized: hexToBytes("0185f90b0a011200e2ccd6ec7c6e2e581349c77e067385fa8236bf8a800900b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
// Adapted from tx in main blockchain:
// 4a16969aa4764dd7507fc1de7f0baa4850a246de90c45e59a3207f9a26b5036f
{
name: "outputs 0 and 2, not coinbase, output 2 compressed, packed indexes reversed",
entry: &UtxoEntry{
version: 1,
isCoinBase: false,
blockHeight: 113931,
sparseOutputs: map[uint32]*utxoOutput{
0: {
amount: 20000000,
pkScript: hexToBytes("76a914e2ccd6ec7c6e2e581349c77e067385fa8236bf8a88ac"),
},
2: {
// Uncompressed Amount: 15000000
// Uncompressed PkScript: 76a914b8025be1b3efc63b0ad48e7f9f10e87544528d5888ac
compressed: true,
amount: 137,
pkScript: hexToBytes("00b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
},
},
serialized: hexToBytes("0185f90b0a011200e2ccd6ec7c6e2e581349c77e067385fa8236bf8a800900b8025be1b3efc63b0ad48e7f9f10e87544528d58"),
},
// From tx in main blockchain:
// 0e3e2357e806b6cdb1f70b54c3a3a17b6714ee1f0e68bebb44a74b1efd512098
{
name: "Only output 0, coinbase, fully spent",
entry: &UtxoEntry{
version: 1,
isCoinBase: true,
blockHeight: 1, blockHeight: 1,
sparseOutputs: map[uint32]*utxoOutput{ packedFlags: tfCoinBase,
0: { },
spent: true, serialized: hexToBytes("03320496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52"),
},
// From tx in main blockchain:
// 0e3e2357e806b6cdb1f70b54c3a3a17b6714ee1f0e68bebb44a74b1efd512098:0
{
name: "height 1, coinbase, spent",
entry: &UtxoEntry{
amount: 5000000000, amount: 5000000000,
pkScript: hexToBytes("410496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52da7589379515d4e0a604f8141781e62294721166bf621e73a82cbf2342c858eeac"), pkScript: hexToBytes("410496b538e853519c726a2c91e61ec11600ae1390813a627c66fb8be7947be63c52da7589379515d4e0a604f8141781e62294721166bf621e73a82cbf2342c858eeac"),
}, blockHeight: 1,
}, packedFlags: tfCoinBase | tfSpent,
}, },
serialized: nil, serialized: nil,
}, },
// Adapted from tx in main blockchain: // From tx in main blockchain:
// 1b02d1c8cfef60a189017b9a420c682cf4a0028175f2f563209e4ff61c8c3620 // 8131ffb0a2c945ecaf9b9063e59558784f9c3a74741ce6ae2a18d0571dac15bb:1
{ {
name: "Only output 22, not coinbase", name: "height 100001, not coinbase",
entry: &UtxoEntry{ entry: &UtxoEntry{
version: 1, amount: 1000000,
isCoinBase: false, pkScript: hexToBytes("76a914ee8bd501094a7d5ca318da2506de35e1cb025ddc88ac"),
blockHeight: 338156, blockHeight: 100001,
sparseOutputs: map[uint32]*utxoOutput{ packedFlags: 0,
22: {
spent: false,
amount: 366875659,
pkScript: hexToBytes("a9141dd46a006572d820e448e12d2bbb38640bc718e687"),
}, },
serialized: hexToBytes("8b99420700ee8bd501094a7d5ca318da2506de35e1cb025ddc"),
}, },
// From tx in main blockchain:
// 8131ffb0a2c945ecaf9b9063e59558784f9c3a74741ce6ae2a18d0571dac15bb:1
{
name: "height 100001, not coinbase, spent",
entry: &UtxoEntry{
amount: 1000000,
pkScript: hexToBytes("76a914ee8bd501094a7d5ca318da2506de35e1cb025ddc88ac"),
blockHeight: 100001,
packedFlags: tfSpent,
}, },
serialized: hexToBytes("0193d06c100000108ba5b9e763011dd46a006572d820e448e12d2bbb38640bc718e6"), serialized: nil,
}, },
} }
@ -719,9 +477,9 @@ func TestUtxoSerialization(t *testing.T) {
continue continue
} }
// Don't try to deserialize if the test entry was fully spent // Don't try to deserialize if the test entry was spent since it
// since it will have a nil serialization. // will have a nil serialization.
if test.entry.IsFullySpent() { if test.entry.IsSpent() {
continue continue
} }
@ -733,12 +491,33 @@ func TestUtxoSerialization(t *testing.T) {
continue continue
} }
// Ensure that the deserialized utxo entry has the same // The deserialized entry must not be marked spent since unspent
// properties for the containing transaction and block height. // entries are not serialized.
if utxoEntry.Version() != test.entry.Version() { if utxoEntry.IsSpent() {
t.Errorf("deserializeUtxoEntry #%d (%s) output should "+
"not be marked spent", i, test.name)
continue
}
// Ensure the deserialized entry has the same properties as the
// ones in the test entry.
if utxoEntry.Amount() != test.entry.Amount() {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+ t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"version: got %d, want %d", i, test.name, "amounts: got %d, want %d", i, test.name,
utxoEntry.Version(), test.entry.Version()) utxoEntry.Amount(), test.entry.Amount())
continue
}
if !bytes.Equal(utxoEntry.PkScript(), test.entry.PkScript()) {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"scripts: got %x, want %x", i, test.name,
utxoEntry.PkScript(), test.entry.PkScript())
continue
}
if utxoEntry.BlockHeight() != test.entry.BlockHeight() {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"block height: got %d, want %d", i, test.name,
utxoEntry.BlockHeight(), test.entry.BlockHeight())
continue continue
} }
if utxoEntry.IsCoinBase() != test.entry.IsCoinBase() { if utxoEntry.IsCoinBase() != test.entry.IsCoinBase() {
@ -747,71 +526,6 @@ func TestUtxoSerialization(t *testing.T) {
utxoEntry.IsCoinBase(), test.entry.IsCoinBase()) utxoEntry.IsCoinBase(), test.entry.IsCoinBase())
continue continue
} }
if utxoEntry.BlockHeight() != test.entry.BlockHeight() {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"block height: got %d, want %d", i, test.name,
utxoEntry.BlockHeight(),
test.entry.BlockHeight())
continue
}
if utxoEntry.IsFullySpent() != test.entry.IsFullySpent() {
t.Errorf("deserializeUtxoEntry #%d (%s) mismatched "+
"fully spent: got %v, want %v", i, test.name,
utxoEntry.IsFullySpent(),
test.entry.IsFullySpent())
continue
}
// Ensure all of the outputs in the test entry match the
// spentness of the output in the deserialized entry and the
// deserialized entry does not contain any additional utxos.
var numUnspent int
for outputIndex := range test.entry.sparseOutputs {
gotSpent := utxoEntry.IsOutputSpent(outputIndex)
wantSpent := test.entry.IsOutputSpent(outputIndex)
if !wantSpent {
numUnspent++
}
if gotSpent != wantSpent {
t.Errorf("deserializeUtxoEntry #%d (%s) output "+
"#%d: mismatched spent: got %v, want "+
"%v", i, test.name, outputIndex,
gotSpent, wantSpent)
continue
}
}
if len(utxoEntry.sparseOutputs) != numUnspent {
t.Errorf("deserializeUtxoEntry #%d (%s): mismatched "+
"number of unspent outputs: got %d, want %d", i,
test.name, len(utxoEntry.sparseOutputs),
numUnspent)
continue
}
// Ensure all of the amounts and scripts of the utxos in the
// deserialized entry match the ones in the test entry.
for outputIndex := range utxoEntry.sparseOutputs {
gotAmount := utxoEntry.AmountByIndex(outputIndex)
wantAmount := test.entry.AmountByIndex(outputIndex)
if gotAmount != wantAmount {
t.Errorf("deserializeUtxoEntry #%d (%s) "+
"output #%d: mismatched amounts: got "+
"%d, want %d", i, test.name,
outputIndex, gotAmount, wantAmount)
continue
}
gotPkScript := utxoEntry.PkScriptByIndex(outputIndex)
wantPkScript := test.entry.PkScriptByIndex(outputIndex)
if !bytes.Equal(gotPkScript, wantPkScript) {
t.Errorf("deserializeUtxoEntry #%d (%s) "+
"output #%d mismatched scripts: got "+
"%x, want %x", i, test.name,
outputIndex, gotPkScript, wantPkScript)
continue
}
}
} }
} }
@ -824,20 +538,18 @@ func TestUtxoEntryHeaderCodeErrors(t *testing.T) {
name string name string
entry *UtxoEntry entry *UtxoEntry
code uint64 code uint64
bytesRead int // Expected number of bytes read.
errType error errType error
}{ }{
{ {
name: "Force assertion due to fully spent tx", name: "Force assertion due to spent output",
entry: &UtxoEntry{}, entry: &UtxoEntry{packedFlags: tfSpent},
errType: AssertError(""), errType: AssertError(""),
bytesRead: 0,
}, },
} }
for _, test := range tests { for _, test := range tests {
// Ensure the expected error type is returned and the code is 0. // Ensure the expected error type is returned and the code is 0.
code, gotBytesRead, err := utxoEntryHeaderCode(test.entry, 0) code, err := utxoEntryHeaderCode(test.entry)
if reflect.TypeOf(err) != reflect.TypeOf(test.errType) { if reflect.TypeOf(err) != reflect.TypeOf(test.errType) {
t.Errorf("utxoEntryHeaderCode (%s): expected error "+ t.Errorf("utxoEntryHeaderCode (%s): expected error "+
"type does not match - got %T, want %T", "type does not match - got %T, want %T",
@ -849,14 +561,6 @@ func TestUtxoEntryHeaderCodeErrors(t *testing.T) {
"on error - got %d, want 0", test.name, code) "on error - got %d, want 0", test.name, code)
continue continue
} }
// Ensure the expected number of bytes read is returned.
if gotBytesRead != test.bytesRead {
t.Errorf("utxoEntryHeaderCode (%s): unexpected number "+
"of bytes read - got %d, want %d", test.name,
gotBytesRead, test.bytesRead)
continue
}
} }
} }
@ -870,29 +574,14 @@ func TestUtxoEntryDeserializeErrors(t *testing.T) {
serialized []byte serialized []byte
errType error errType error
}{ }{
{
name: "no data after version",
serialized: hexToBytes("01"),
errType: errDeserialize(""),
},
{
name: "no data after block height",
serialized: hexToBytes("0101"),
errType: errDeserialize(""),
},
{ {
name: "no data after header code", name: "no data after header code",
serialized: hexToBytes("010102"), serialized: hexToBytes("02"),
errType: errDeserialize(""),
},
{
name: "not enough bytes for unspentness bitmap",
serialized: hexToBytes("01017800"),
errType: errDeserialize(""), errType: errDeserialize(""),
}, },
{ {
name: "incomplete compressed txout", name: "incomplete compressed txout",
serialized: hexToBytes("01010232"), serialized: hexToBytes("0232"),
errType: errDeserialize(""), errType: errDeserialize(""),
}, },
} }

View File

@ -27,16 +27,11 @@ func chainedNodes(parent *blockNode, numNodes int) []*blockNode {
// This is invalid, but all that is needed is enough to get the // This is invalid, but all that is needed is enough to get the
// synthetic tests to work. // synthetic tests to work.
header := wire.BlockHeader{Nonce: testNoncePrng.Uint32()} header := wire.BlockHeader{Nonce: testNoncePrng.Uint32()}
height := int32(0)
if tip != nil { if tip != nil {
header.PrevBlock = tip.hash header.PrevBlock = tip.hash
height = tip.height + 1
} }
node := newBlockNode(&header, height) nodes[i] = newBlockNode(&header, tip)
node.parent = tip tip = nodes[i]
tip = node
nodes[i] = node
} }
return nodes return nodes
} }
@ -74,7 +69,7 @@ func zipLocators(locators ...BlockLocator) BlockLocator {
} }
// TestChainView ensures all of the exported functionality of chain views works // TestChainView ensures all of the exported functionality of chain views works
// as intended with the expection of some special cases which are handled in // as intended with the exception of some special cases which are handled in
// other tests. // other tests.
func TestChainView(t *testing.T) { func TestChainView(t *testing.T) {
// Construct a synthetic block index consisting of the following // Construct a synthetic block index consisting of the following

View File

@ -190,10 +190,10 @@ func chainSetup(dbName string, params *chaincfg.Params) (*BlockChain, func(), er
// loadUtxoView returns a utxo view loaded from a file. // loadUtxoView returns a utxo view loaded from a file.
func loadUtxoView(filename string) (*UtxoViewpoint, error) { func loadUtxoView(filename string) (*UtxoViewpoint, error) {
// The utxostore file format is: // The utxostore file format is:
// <tx hash><serialized utxo len><serialized utxo> // <tx hash><output index><serialized utxo len><serialized utxo>
// //
// The serialized utxo len is a little endian uint32 and the serialized // The output index and serialized utxo len are little endian uint32s
// utxo uses the format described in chainio.go. // and the serialized utxo uses the format described in chainio.go.
filename = filepath.Join("testdata", filename) filename = filepath.Join("testdata", filename)
fi, err := os.Open(filename) fi, err := os.Open(filename)
@ -223,7 +223,14 @@ func loadUtxoView(filename string) (*UtxoViewpoint, error) {
return nil, err return nil, err
} }
// Num of serialize utxo entry bytes. // Output index of the utxo entry.
var index uint32
err = binary.Read(r, binary.LittleEndian, &index)
if err != nil {
return nil, err
}
// Num of serialized utxo entry bytes.
var numBytes uint32 var numBytes uint32
err = binary.Read(r, binary.LittleEndian, &numBytes) err = binary.Read(r, binary.LittleEndian, &numBytes)
if err != nil { if err != nil {
@ -238,16 +245,98 @@ func loadUtxoView(filename string) (*UtxoViewpoint, error) {
} }
// Deserialize it and add it to the view. // Deserialize it and add it to the view.
utxoEntry, err := deserializeUtxoEntry(serialized) entry, err := deserializeUtxoEntry(serialized)
if err != nil { if err != nil {
return nil, err return nil, err
} }
view.Entries()[hash] = utxoEntry view.Entries()[wire.OutPoint{Hash: hash, Index: index}] = entry
} }
return view, nil return view, nil
} }
// convertUtxoStore reads a utxostore from the legacy format and writes it back
// out using the latest format. It is only useful for converting utxostore data
// used in the tests, which has already been done. However, the code is left
// available for future reference.
func convertUtxoStore(r io.Reader, w io.Writer) error {
// The old utxostore file format was:
// <tx hash><serialized utxo len><serialized utxo>
//
// The serialized utxo len was a little endian uint32 and the serialized
// utxo uses the format described in upgrade.go.
littleEndian := binary.LittleEndian
for {
// Hash of the utxo entry.
var hash chainhash.Hash
_, err := io.ReadAtLeast(r, hash[:], len(hash[:]))
if err != nil {
// Expected EOF at the right offset.
if err == io.EOF {
break
}
return err
}
// Num of serialized utxo entry bytes.
var numBytes uint32
err = binary.Read(r, littleEndian, &numBytes)
if err != nil {
return err
}
// Serialized utxo entry.
serialized := make([]byte, numBytes)
_, err = io.ReadAtLeast(r, serialized, int(numBytes))
if err != nil {
return err
}
// Deserialize the entry.
entries, err := deserializeUtxoEntryV0(serialized)
if err != nil {
return err
}
// Loop through all of the utxos and write them out in the new
// format.
for outputIdx, entry := range entries {
// Reserialize the entries using the new format.
serialized, err := serializeUtxoEntry(entry)
if err != nil {
return err
}
// Write the hash of the utxo entry.
_, err = w.Write(hash[:])
if err != nil {
return err
}
// Write the output index of the utxo entry.
err = binary.Write(w, littleEndian, outputIdx)
if err != nil {
return err
}
// Write num of serialized utxo entry bytes.
err = binary.Write(w, littleEndian, uint32(len(serialized)))
if err != nil {
return err
}
// Write the serialized utxo.
_, err = w.Write(serialized)
if err != nil {
return err
}
}
}
return nil
}
// TstSetCoinbaseMaturity makes the ability to set the coinbase maturity // TstSetCoinbaseMaturity makes the ability to set the coinbase maturity
// available when running tests. // available when running tests.
func (b *BlockChain) TstSetCoinbaseMaturity(maturity uint16) { func (b *BlockChain) TstSetCoinbaseMaturity(maturity uint16) {
@ -261,7 +350,7 @@ func (b *BlockChain) TstSetCoinbaseMaturity(maturity uint16) {
func newFakeChain(params *chaincfg.Params) *BlockChain { func newFakeChain(params *chaincfg.Params) *BlockChain {
// Create a genesis block node and block index index populated with it // Create a genesis block node and block index index populated with it
// for use when creating the fake chain below. // for use when creating the fake chain below.
node := newBlockNode(&params.GenesisBlock.Header, 0) node := newBlockNode(&params.GenesisBlock.Header, nil)
index := newBlockIndex(nil, params) index := newBlockIndex(nil, params)
index.AddNode(node) index.AddNode(node)
@ -291,8 +380,5 @@ func newFakeNode(parent *blockNode, blockVersion int32, bits uint32, timestamp t
Bits: bits, Bits: bits,
Timestamp: timestamp, Timestamp: timestamp,
} }
node := newBlockNode(header, parent.height+1) return newBlockNode(header, parent)
node.parent = parent
node.workSum.Add(parent.workSum, node.workSum)
return node
} }

View File

@ -241,7 +241,7 @@ func isPubKey(script []byte) (bool, []byte) {
// compressedScriptSize returns the number of bytes the passed script would take // compressedScriptSize returns the number of bytes the passed script would take
// when encoded with the domain specific compression algorithm described above. // when encoded with the domain specific compression algorithm described above.
func compressedScriptSize(pkScript []byte, version int32) int { func compressedScriptSize(pkScript []byte) int {
// Pay-to-pubkey-hash script. // Pay-to-pubkey-hash script.
if valid, _ := isPubKeyHash(pkScript); valid { if valid, _ := isPubKeyHash(pkScript); valid {
return 21 return 21
@ -268,7 +268,7 @@ func compressedScriptSize(pkScript []byte, version int32) int {
// script, possibly followed by other data, and returns the number of bytes it // script, possibly followed by other data, and returns the number of bytes it
// occupies taking into account the special encoding of the script size by the // occupies taking into account the special encoding of the script size by the
// domain specific compression algorithm described above. // domain specific compression algorithm described above.
func decodeCompressedScriptSize(serialized []byte, version int32) int { func decodeCompressedScriptSize(serialized []byte) int {
scriptSize, bytesRead := deserializeVLQ(serialized) scriptSize, bytesRead := deserializeVLQ(serialized)
if bytesRead == 0 { if bytesRead == 0 {
return 0 return 0
@ -296,7 +296,7 @@ func decodeCompressedScriptSize(serialized []byte, version int32) int {
// target byte slice. The target byte slice must be at least large enough to // target byte slice. The target byte slice must be at least large enough to
// handle the number of bytes returned by the compressedScriptSize function or // handle the number of bytes returned by the compressedScriptSize function or
// it will panic. // it will panic.
func putCompressedScript(target, pkScript []byte, version int32) int { func putCompressedScript(target, pkScript []byte) int {
// Pay-to-pubkey-hash script. // Pay-to-pubkey-hash script.
if valid, hash := isPubKeyHash(pkScript); valid { if valid, hash := isPubKeyHash(pkScript); valid {
target[0] = cstPayToPubKeyHash target[0] = cstPayToPubKeyHash
@ -344,7 +344,7 @@ func putCompressedScript(target, pkScript []byte, version int32) int {
// NOTE: The script parameter must already have been proven to be long enough // NOTE: The script parameter must already have been proven to be long enough
// to contain the number of bytes returned by decodeCompressedScriptSize or it // to contain the number of bytes returned by decodeCompressedScriptSize or it
// will panic. This is acceptable since it is only an internal function. // will panic. This is acceptable since it is only an internal function.
func decompressScript(compressedPkScript []byte, version int32) []byte { func decompressScript(compressedPkScript []byte) []byte {
// In practice this function will not be called with a zero-length or // In practice this function will not be called with a zero-length or
// nil script since the nil script encoding includes the length, however // nil script since the nil script encoding includes the length, however
// the code below assumes the length exists, so just return nil now if // the code below assumes the length exists, so just return nil now if
@ -542,43 +542,27 @@ func decompressTxOutAmount(amount uint64) uint64 {
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// compressedTxOutSize returns the number of bytes the passed transaction output // compressedTxOutSize returns the number of bytes the passed transaction output
// fields would take when encoded with the format described above. The // fields would take when encoded with the format described above.
// preCompressed flag indicates the provided amount and script are already func compressedTxOutSize(amount uint64, pkScript []byte) int {
// compressed. This is useful since loaded utxo entries are not decompressed
// until the output is accessed.
func compressedTxOutSize(amount uint64, pkScript []byte, version int32, preCompressed bool) int {
if preCompressed {
return serializeSizeVLQ(amount) + len(pkScript)
}
return serializeSizeVLQ(compressTxOutAmount(amount)) + return serializeSizeVLQ(compressTxOutAmount(amount)) +
compressedScriptSize(pkScript, version) compressedScriptSize(pkScript)
} }
// putCompressedTxOut potentially compresses the passed amount and script // putCompressedTxOut compresses the passed amount and script according to their
// according to their domain specific compression algorithms and encodes them // domain specific compression algorithms and encodes them directly into the
// directly into the passed target byte slice with the format described above. // passed target byte slice with the format described above. The target byte
// The preCompressed flag indicates the provided amount and script are already // slice must be at least large enough to handle the number of bytes returned by
// compressed in which case the values are not modified. This is useful since // the compressedTxOutSize function or it will panic.
// loaded utxo entries are not decompressed until the output is accessed. The func putCompressedTxOut(target []byte, amount uint64, pkScript []byte) int {
// target byte slice must be at least large enough to handle the number of bytes
// returned by the compressedTxOutSize function or it will panic.
func putCompressedTxOut(target []byte, amount uint64, pkScript []byte, version int32, preCompressed bool) int {
if preCompressed {
offset := putVLQ(target, amount)
copy(target[offset:], pkScript)
return offset + len(pkScript)
}
offset := putVLQ(target, compressTxOutAmount(amount)) offset := putVLQ(target, compressTxOutAmount(amount))
offset += putCompressedScript(target[offset:], pkScript, version) offset += putCompressedScript(target[offset:], pkScript)
return offset return offset
} }
// decodeCompressedTxOut decodes the passed compressed txout, possibly followed // decodeCompressedTxOut decodes the passed compressed txout, possibly followed
// by other data, into its compressed amount and compressed script and returns // by other data, into its uncompressed amount and script and returns them along
// them along with the number of bytes they occupied. // with the number of bytes they occupied prior to decompression.
func decodeCompressedTxOut(serialized []byte, version int32) (uint64, []byte, int, error) { func decodeCompressedTxOut(serialized []byte) (uint64, []byte, int, error) {
// Deserialize the compressed amount and ensure there are bytes // Deserialize the compressed amount and ensure there are bytes
// remaining for the compressed script. // remaining for the compressed script.
compressedAmount, bytesRead := deserializeVLQ(serialized) compressedAmount, bytesRead := deserializeVLQ(serialized)
@ -589,15 +573,14 @@ func decodeCompressedTxOut(serialized []byte, version int32) (uint64, []byte, in
// Decode the compressed script size and ensure there are enough bytes // Decode the compressed script size and ensure there are enough bytes
// left in the slice for it. // left in the slice for it.
scriptSize := decodeCompressedScriptSize(serialized[bytesRead:], version) scriptSize := decodeCompressedScriptSize(serialized[bytesRead:])
if len(serialized[bytesRead:]) < scriptSize { if len(serialized[bytesRead:]) < scriptSize {
return 0, nil, bytesRead, errDeserialize("unexpected end of " + return 0, nil, bytesRead, errDeserialize("unexpected end of " +
"data after script size") "data after script size")
} }
// Make a copy of the compressed script so the original serialized data // Decompress and return the amount and script.
// can be released as soon as possible. amount := decompressTxOutAmount(compressedAmount)
compressedScript := make([]byte, scriptSize) script := decompressScript(serialized[bytesRead : bytesRead+scriptSize])
copy(compressedScript, serialized[bytesRead:bytesRead+scriptSize]) return amount, script, bytesRead + scriptSize, nil
return compressedAmount, compressedScript, bytesRead + scriptSize, nil
} }

Some files were not shown because too many files have changed in this diff Show More