Updated Vendoring

This commit is contained in:
Suraj Narwade 2017-11-22 17:44:47 +05:30
parent 5de4aa85f8
commit b3f2134cec
298 changed files with 27732 additions and 66948 deletions

48
glide.lock generated
View File

@ -1,5 +1,5 @@
hash: 14203f2282ff9f112d87d12b1c3f855502d3d0e808520a887757ce8f273f1086 hash: 0b7bed52681c512a3c15884363e1fd04ebc39efc52368cfa34ec4ef46f4eedc7
updated: 2017-10-10T12:44:09.037210249-04:00 updated: 2017-10-25T14:45:17.433498815+05:30
imports: imports:
- name: cloud.google.com/go - name: cloud.google.com/go
version: 3b1ae45394a234c385be014e9a488f2bb6eef821 version: 3b1ae45394a234c385be014e9a488f2bb6eef821
@ -104,7 +104,7 @@ imports:
- name: github.com/dgrijalva/jwt-go - name: github.com/dgrijalva/jwt-go
version: 01aeca54ebda6e0fbfafd0a524d234159c05ec20 version: 01aeca54ebda6e0fbfafd0a524d234159c05ec20
- name: github.com/docker/cli - name: github.com/docker/cli
version: 9bdb0763b9e667dc01adf36ba98a2b7bd47bdc75 version: 9b7656cc05d2878c85dc1252f5813f0ad77d808f
subpackages: subpackages:
- cli/compose/interpolation - cli/compose/interpolation
- cli/compose/loader - cli/compose/loader
@ -170,7 +170,6 @@ imports:
- api/types/swarm - api/types/swarm
- api/types/versions - api/types/versions
- pkg/urlutil - pkg/urlutil
- runconfig/opts
- name: github.com/docker/engine-api - name: github.com/docker/engine-api
version: dea108d3aa0c67d7162a3fd8aa65f38a430019fd version: dea108d3aa0c67d7162a3fd8aa65f38a430019fd
subpackages: subpackages:
@ -196,7 +195,7 @@ imports:
- name: github.com/docker/go-units - name: github.com/docker/go-units
version: 0bbddae09c5a5419a8c6dcdd7ff90da3d450393b version: 0bbddae09c5a5419a8c6dcdd7ff90da3d450393b
- name: github.com/docker/libcompose - name: github.com/docker/libcompose
version: 4a647d664afbe05c41455c9d534d8239671eb46a version: 57bd716502dcbe1799f026148016022b0f3b989c
subpackages: subpackages:
- config - config
- logger - logger
@ -216,11 +215,11 @@ imports:
- name: github.com/evanphx/json-patch - name: github.com/evanphx/json-patch
version: 465937c80b3c07a7c7ad20cc934898646a91c1de version: 465937c80b3c07a7c7ad20cc934898646a91c1de
- name: github.com/fatih/structs - name: github.com/fatih/structs
version: 7e5a8eef611ee84dd359503f3969f80df4c50723 version: dc3312cb1a4513a366c4c9e622ad55c32df12ed3
- name: github.com/flynn/go-shlex - name: github.com/flynn/go-shlex
version: 3f9db97f856818214da2e1057f8ad84803971cff version: 3f9db97f856818214da2e1057f8ad84803971cff
- name: github.com/fsnotify/fsnotify - name: github.com/fsnotify/fsnotify
version: 4da3e2cfbabc9f751898f250b49f2439785783a1 version: 629574ca2a5df945712d3079857300b5e4da0236
- name: github.com/fsouza/go-dockerclient - name: github.com/fsouza/go-dockerclient
version: bf97c77db7c945cbcdbf09d56c6f87a66f54537b version: bf97c77db7c945cbcdbf09d56c6f87a66f54537b
subpackages: subpackages:
@ -338,7 +337,7 @@ imports:
- runtime/internal - runtime/internal
- utilities - utilities
- name: github.com/hashicorp/hcl - name: github.com/hashicorp/hcl
version: 42e33e2d55a0ff1d6263f738896ea8c13571a8d0 version: 37ab263305aaeb501a60eb16863e808d426e37f2
subpackages: subpackages:
- hcl/ast - hcl/ast
- hcl/parser - hcl/parser
@ -359,7 +358,7 @@ imports:
- name: github.com/juju/ratelimit - name: github.com/juju/ratelimit
version: 77ed1c8a01217656d2080ad51981f6e99adaa177 version: 77ed1c8a01217656d2080ad51981f6e99adaa177
- name: github.com/magiconair/properties - name: github.com/magiconair/properties
version: 8d7837e64d3c1ee4e54a880c5a920ab4316fc90a version: be5ece7dd465ab0765a9682137865547526d1dfb
- name: github.com/mattn/go-shellwords - name: github.com/mattn/go-shellwords
version: 95c860c1895b21b58903abdd1d9c591560b0601c version: 95c860c1895b21b58903abdd1d9c591560b0601c
- name: github.com/matttproud/golang_protobuf_extensions - name: github.com/matttproud/golang_protobuf_extensions
@ -367,7 +366,7 @@ imports:
subpackages: subpackages:
- pbutil - pbutil
- name: github.com/mitchellh/mapstructure - name: github.com/mitchellh/mapstructure
version: d0303fe809921458f417bcf828397a65db30a7e4 version: 5a0325d7fafaac12dda6e7fb8bd222ec1b69875e
- name: github.com/novln/docker-parser - name: github.com/novln/docker-parser
version: 6030251119d652af8ead44ac7907444227b64d56 version: 6030251119d652af8ead44ac7907444227b64d56
subpackages: subpackages:
@ -427,10 +426,12 @@ imports:
- pkg/version - pkg/version
- name: github.com/pborman/uuid - name: github.com/pborman/uuid
version: ca53cad383cad2479bbba7f7a1a05797ec1386e4 version: ca53cad383cad2479bbba7f7a1a05797ec1386e4
- name: github.com/pelletier/go-buffruneio
version: c37440a7cf42ac63b919c752ca73a85067e05992
- name: github.com/pelletier/go-toml - name: github.com/pelletier/go-toml
version: 2009e44b6f182e34d8ce081ac2767622937ea3d4 version: 13d49d4606eb801b8f01ae542b4afc4c6ee3d84a
- name: github.com/pkg/errors - name: github.com/pkg/errors
version: 2b3a18b5f0fb6b4f9190549597d3f962c02bc5eb version: 645ef00459ed84a119197bfb8d8205042c6df63d
- name: github.com/prometheus/client_golang - name: github.com/prometheus/client_golang
version: e51041b3fa41cece0dca035740ba6411905be473 version: e51041b3fa41cece0dca035740ba6411905be473
subpackages: subpackages:
@ -447,26 +448,26 @@ imports:
- model - model
- name: github.com/prometheus/procfs - name: github.com/prometheus/procfs
version: 454a56f35412459b5e684fd5ec0f9211b94f002a version: 454a56f35412459b5e684fd5ec0f9211b94f002a
- name: github.com/sirupsen/logrus
version: f006c2ac4710855cf0f916dd6b77acf6b048dc6e
- name: github.com/Sirupsen/logrus - name: github.com/Sirupsen/logrus
version: f006c2ac4710855cf0f916dd6b77acf6b048dc6e version: f006c2ac4710855cf0f916dd6b77acf6b048dc6e
repo: git@github.com:/sirupsen/logrus repo: git@github.com:/sirupsen/logrus
vcs: git vcs: git
- name: github.com/sirupsen/logrus
version: f006c2ac4710855cf0f916dd6b77acf6b048dc6e
- name: github.com/spf13/afero - name: github.com/spf13/afero
version: e67d870304c4bca21331b02f414f970df13aa694 version: 2f30b2a92c0e5700bcfe4715891adb1f2a7a406d
subpackages: subpackages:
- mem - mem
- name: github.com/spf13/cast - name: github.com/spf13/cast
version: acbeb36b902d72a7a4c18e8f3241075e7ab763e4 version: 24b6558033ffe202bf42f0f3b870dcc798dd2ba8
- name: github.com/spf13/cobra - name: github.com/spf13/cobra
version: 4d6af280c76ff7d266434f2dba207c4b75dfc076 version: 9495bc009a56819bdb0ddbc1a373e29c140bc674
- name: github.com/spf13/jwalterweatherman - name: github.com/spf13/jwalterweatherman
version: 12bd96e66386c1960ab0f74ced1362f66f552f7b version: 33c24e77fb80341fe7130ee7c594256ff08ccc46
- name: github.com/spf13/pflag - name: github.com/spf13/pflag
version: a9789e855c7696159b7db0db7f440b449edf2b31 version: 5ccb023bc27df288a957c5e994cd44fd19619465
- name: github.com/spf13/viper - name: github.com/spf13/viper
version: d9cca5ef33035202efb1586825bdbb15ff9ec3ba version: 651d9d916abc3c3d6a91a12549495caba5edffd2
- name: github.com/ugorji/go - name: github.com/ugorji/go
version: f4485b318aadd133842532f841dc205a8e339d74 version: f4485b318aadd133842532f841dc205a8e339d74
subpackages: subpackages:
@ -478,7 +479,7 @@ imports:
- name: github.com/xeipuuv/gojsonschema - name: github.com/xeipuuv/gojsonschema
version: 93e72a773fade158921402d6a24c819b48aba29d version: 93e72a773fade158921402d6a24c819b48aba29d
- name: golang.org/x/crypto - name: golang.org/x/crypto
version: 1f22c0103821b9390939b6776727195525381532 version: 81e90905daefcd6fd217b62423c0908922eadb30
subpackages: subpackages:
- ssh/terminal - ssh/terminal
- name: golang.org/x/net - name: golang.org/x/net
@ -504,9 +505,10 @@ imports:
- jws - jws
- jwt - jwt
- name: golang.org/x/sys - name: golang.org/x/sys
version: ebfc5b4631820b793c9010c87fd8fef0f39eb082 version: 833a04a10549a95dc34458c195cbad61bbb6cb4d
subpackages: subpackages:
- unix - unix
- windows
- name: golang.org/x/text - name: golang.org/x/text
version: ceefd2213ed29504fff30155163c8f59827734f3 version: ceefd2213ed29504fff30155163c8f59827734f3
subpackages: subpackages:
@ -546,7 +548,7 @@ imports:
- name: gopkg.in/inf.v0 - name: gopkg.in/inf.v0
version: 3887ee99ecf07df5b447e9b00d9c0b2adaa9f3e4 version: 3887ee99ecf07df5b447e9b00d9c0b2adaa9f3e4
- name: gopkg.in/yaml.v2 - name: gopkg.in/yaml.v2
version: eb3733d160e74a9c7e442f435eb3bea458e1d19f version: a5b47d31c556af34a302ce5d659e6fea44d90de0
- name: k8s.io/client-go - name: k8s.io/client-go
version: d72c0e162789e1bbb33c33cfa26858a1375efe01 version: d72c0e162789e1bbb33c33cfa26858a1375efe01
subpackages: subpackages:

View File

@ -8,7 +8,6 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/Sirupsen/logrus"
"github.com/docker/cli/cli/compose/interpolation" "github.com/docker/cli/cli/compose/interpolation"
"github.com/docker/cli/cli/compose/schema" "github.com/docker/cli/cli/compose/schema"
"github.com/docker/cli/cli/compose/template" "github.com/docker/cli/cli/compose/template"
@ -19,6 +18,7 @@ import (
shellwords "github.com/mattn/go-shellwords" shellwords "github.com/mattn/go-shellwords"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
) )
@ -221,6 +221,7 @@ func createTransformHook() mapstructure.DecodeHookFuncType {
reflect.TypeOf(types.Labels{}): transformMappingOrListFunc("=", false), reflect.TypeOf(types.Labels{}): transformMappingOrListFunc("=", false),
reflect.TypeOf(types.MappingWithColon{}): transformMappingOrListFunc(":", false), reflect.TypeOf(types.MappingWithColon{}): transformMappingOrListFunc(":", false),
reflect.TypeOf(types.ServiceVolumeConfig{}): transformServiceVolumeConfig, reflect.TypeOf(types.ServiceVolumeConfig{}): transformServiceVolumeConfig,
reflect.TypeOf(types.BuildConfig{}): transformBuildConfig,
} }
return func(_ reflect.Type, target reflect.Type, data interface{}) (interface{}, error) { return func(_ reflect.Type, target reflect.Type, data interface{}) (interface{}, error) {
@ -563,6 +564,17 @@ func transformStringSourceMap(data interface{}) (interface{}, error) {
} }
} }
func transformBuildConfig(data interface{}) (interface{}, error) {
switch value := data.(type) {
case string:
return map[string]interface{}{"context": value}, nil
case map[string]interface{}:
return data, nil
default:
return data, errors.Errorf("invalid type %T for service build", value)
}
}
func transformServiceVolumeConfig(data interface{}) (interface{}, error) { func transformServiceVolumeConfig(data interface{}) (interface{}, error) {
switch value := data.(type) { switch value := data.(type) {
case string: case string:
@ -572,7 +584,6 @@ func transformServiceVolumeConfig(data interface{}) (interface{}, error) {
default: default:
return data, errors.Errorf("invalid type %T for service volume", value) return data, errors.Errorf("invalid type %T for service volume", value)
} }
} }
func transformServiceNetworkMap(value interface{}) (interface{}, error) { func transformServiceNetworkMap(value interface{}) (interface{}, error) {

View File

@ -112,6 +112,11 @@ func isFilePath(source string) bool {
return true return true
} }
// windows named pipes
if strings.HasPrefix(source, `\\`) {
return true
}
first, nextIndex := utf8.DecodeRuneInString(source) first, nextIndex := utf8.DecodeRuneInString(source)
return isWindowsDrive([]rune{first}, rune(source[nextIndex])) return isWindowsDrive([]rune{first}, rune(source[nextIndex]))
} }

File diff suppressed because one or more lines are too long

View File

@ -23,6 +23,7 @@ var UnsupportedProperties = []string{
"shm_size", "shm_size",
"sysctls", "sysctls",
"tmpfs", "tmpfs",
"ulimits",
"userns_mode", "userns_mode",
} }
@ -182,10 +183,10 @@ type DeployConfig struct {
// HealthCheckConfig the healthcheck configuration for a service // HealthCheckConfig the healthcheck configuration for a service
type HealthCheckConfig struct { type HealthCheckConfig struct {
Test HealthCheckTest Test HealthCheckTest
Timeout string Timeout *time.Duration
Interval string Interval *time.Duration
Retries *uint64 Retries *uint64
StartPeriod string StartPeriod *time.Duration `mapstructure:"start_period"`
Disable bool Disable bool
} }

View File

@ -95,12 +95,12 @@ func (n *NetworkOpt) String() string {
return "" return ""
} }
func parseDriverOpt(driverOpt string) (key string, value string, err error) { func parseDriverOpt(driverOpt string) (string, string, error) {
parts := strings.SplitN(driverOpt, "=", 2) parts := strings.SplitN(driverOpt, "=", 2)
if len(parts) != 2 { if len(parts) != 2 {
err = fmt.Errorf("invalid key value pair format in driver options") return "", "", fmt.Errorf("invalid key value pair format in driver options")
} }
key = strings.TrimSpace(strings.ToLower(parts[0])) key := strings.TrimSpace(strings.ToLower(parts[0]))
value = strings.TrimSpace(strings.ToLower(parts[1])) value := strings.TrimSpace(strings.ToLower(parts[1]))
return return key, value, nil
} }

View File

@ -1,81 +0,0 @@
package opts
import (
"bufio"
"bytes"
"fmt"
"os"
"strings"
"unicode"
"unicode/utf8"
)
// ParseEnvFile reads a file with environment variables enumerated by lines
//
// ``Environment variable names used by the utilities in the Shell and
// Utilities volume of IEEE Std 1003.1-2001 consist solely of uppercase
// letters, digits, and the '_' (underscore) from the characters defined in
// Portable Character Set and do not begin with a digit. *But*, other
// characters may be permitted by an implementation; applications shall
// tolerate the presence of such names.''
// -- http://pubs.opengroup.org/onlinepubs/009695399/basedefs/xbd_chap08.html
//
// As of #16585, it's up to application inside docker to validate or not
// environment variables, that's why we just strip leading whitespace and
// nothing more.
func ParseEnvFile(filename string) ([]string, error) {
fh, err := os.Open(filename)
if err != nil {
return []string{}, err
}
defer fh.Close()
lines := []string{}
scanner := bufio.NewScanner(fh)
currentLine := 0
utf8bom := []byte{0xEF, 0xBB, 0xBF}
for scanner.Scan() {
scannedBytes := scanner.Bytes()
if !utf8.Valid(scannedBytes) {
return []string{}, fmt.Errorf("env file %s contains invalid utf8 bytes at line %d: %v", filename, currentLine+1, scannedBytes)
}
// We trim UTF8 BOM
if currentLine == 0 {
scannedBytes = bytes.TrimPrefix(scannedBytes, utf8bom)
}
// trim the line from all leading whitespace first
line := strings.TrimLeftFunc(string(scannedBytes), unicode.IsSpace)
currentLine++
// line is not empty, and not starting with '#'
if len(line) > 0 && !strings.HasPrefix(line, "#") {
data := strings.SplitN(line, "=", 2)
// trim the front of a variable, but nothing else
variable := strings.TrimLeft(data[0], whiteSpaces)
if strings.ContainsAny(variable, whiteSpaces) {
return []string{}, ErrBadEnvVariable{fmt.Sprintf("variable '%s' has white spaces", variable)}
}
if len(data) > 1 {
// pass the value through, no trimming
lines = append(lines, fmt.Sprintf("%s=%s", variable, data[1]))
} else {
// if only a pass-through variable is given, clean it up.
lines = append(lines, fmt.Sprintf("%s=%s", strings.TrimSpace(line), os.Getenv(line)))
}
}
}
return lines, scanner.Err()
}
var whiteSpaces = " \t"
// ErrBadEnvVariable typed error for bad environment variable
type ErrBadEnvVariable struct {
msg string
}
func (e ErrBadEnvVariable) Error() string {
return fmt.Sprintf("poorly formatted environment: %s", e.msg)
}

View File

@ -1,87 +0,0 @@
package opts
import (
"fmt"
"strconv"
"strings"
"github.com/docker/docker/api/types/container"
)
// ReadKVStrings reads a file of line terminated key=value pairs, and overrides any keys
// present in the file with additional pairs specified in the override parameter
func ReadKVStrings(files []string, override []string) ([]string, error) {
envVariables := []string{}
for _, ef := range files {
parsedVars, err := ParseEnvFile(ef)
if err != nil {
return nil, err
}
envVariables = append(envVariables, parsedVars...)
}
// parse the '-e' and '--env' after, to allow override
envVariables = append(envVariables, override...)
return envVariables, nil
}
// ConvertKVStringsToMap converts ["key=value"] to {"key":"value"}
func ConvertKVStringsToMap(values []string) map[string]string {
result := make(map[string]string, len(values))
for _, value := range values {
kv := strings.SplitN(value, "=", 2)
if len(kv) == 1 {
result[kv[0]] = ""
} else {
result[kv[0]] = kv[1]
}
}
return result
}
// ConvertKVStringsToMapWithNil converts ["key=value"] to {"key":"value"}
// but set unset keys to nil - meaning the ones with no "=" in them.
// We use this in cases where we need to distinguish between
// FOO= and FOO
// where the latter case just means FOO was mentioned but not given a value
func ConvertKVStringsToMapWithNil(values []string) map[string]*string {
result := make(map[string]*string, len(values))
for _, value := range values {
kv := strings.SplitN(value, "=", 2)
if len(kv) == 1 {
result[kv[0]] = nil
} else {
result[kv[0]] = &kv[1]
}
}
return result
}
// ParseRestartPolicy returns the parsed policy or an error indicating what is incorrect
func ParseRestartPolicy(policy string) (container.RestartPolicy, error) {
p := container.RestartPolicy{}
if policy == "" {
return p, nil
}
parts := strings.Split(policy, ":")
if len(parts) > 2 {
return p, fmt.Errorf("invalid restart policy format")
}
if len(parts) == 2 {
count, err := strconv.Atoi(parts[1])
if err != nil {
return p, fmt.Errorf("maximum retry count must be an integer")
}
p.MaximumRetryCount = count
}
p.Name = parts[0]
return p, nil
}

View File

@ -86,8 +86,12 @@ func GetServiceHash(name string, config *ServiceConfig) string {
for _, sliceKey := range sliceKeys { for _, sliceKey := range sliceKeys {
io.WriteString(hash, fmt.Sprintf("%s, ", sliceKey)) io.WriteString(hash, fmt.Sprintf("%s, ", sliceKey))
} }
case *yaml.Networks:
io.WriteString(hash, fmt.Sprintf("%s, ", s.HashString()))
case *yaml.Volumes:
io.WriteString(hash, fmt.Sprintf("%s, ", s.HashString()))
default: default:
io.WriteString(hash, fmt.Sprintf("%v", serviceValue)) io.WriteString(hash, fmt.Sprintf("%v, ", serviceValue))
} }
} }

View File

@ -5,13 +5,19 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/Sirupsen/logrus" "github.com/sirupsen/logrus"
) )
var defaultValues = make(map[string]string)
func isNum(c uint8) bool { func isNum(c uint8) bool {
return c >= '0' && c <= '9' return c >= '0' && c <= '9'
} }
func validVariableDefault(c uint8, line string, pos int) bool {
return (c == ':' && line[pos+1] == '-') || (c == '-')
}
func validVariableNameChar(c uint8) bool { func validVariableNameChar(c uint8) bool {
return c == '_' || return c == '_' ||
c >= 'A' && c <= 'Z' || c >= 'A' && c <= 'Z' ||
@ -36,6 +42,30 @@ func parseVariable(line string, pos int, mapping func(string) string) (string, i
return mapping(buffer.String()), pos, true return mapping(buffer.String()), pos, true
} }
func parseDefaultValue(line string, pos int) (string, int, bool) {
var buffer bytes.Buffer
// only skip :, :- and - at the beginning
for ; pos < len(line); pos++ {
c := line[pos]
if c == ':' || c == '-' {
continue
}
break
}
for ; pos < len(line); pos++ {
c := line[pos]
if c == '}' {
return buffer.String(), pos - 1, true
}
err := buffer.WriteByte(c)
if err != nil {
return "", pos, false
}
}
return "", 0, false
}
func parseVariableWithBraces(line string, pos int, mapping func(string) string) (string, int, bool) { func parseVariableWithBraces(line string, pos int, mapping func(string) string) (string, int, bool) {
var buffer bytes.Buffer var buffer bytes.Buffer
@ -49,10 +79,13 @@ func parseVariableWithBraces(line string, pos int, mapping func(string) string)
if bufferString == "" { if bufferString == "" {
return "", 0, false return "", 0, false
} }
return mapping(buffer.String()), pos, true return mapping(buffer.String()), pos, true
case validVariableNameChar(c): case validVariableNameChar(c):
buffer.WriteByte(c) buffer.WriteByte(c)
case validVariableDefault(c, line, pos):
defaultValue := ""
defaultValue, pos, _ = parseDefaultValue(line, pos)
defaultValues[buffer.String()] = defaultValue
default: default:
return "", 0, false return "", 0, false
} }
@ -143,10 +176,19 @@ func Interpolate(key string, data *interface{}, environmentLookup EnvironmentLoo
values := environmentLookup.Lookup(s, nil) values := environmentLookup.Lookup(s, nil)
if len(values) == 0 { if len(values) == 0 {
if val, ok := defaultValues[s]; ok {
return val
}
logrus.Warnf("The %s variable is not set. Substituting a blank string.", s) logrus.Warnf("The %s variable is not set. Substituting a blank string.", s)
return "" return ""
} }
if strings.SplitN(values[0], "=", 2)[1] == "" {
if val, ok := defaultValues[s]; ok {
return val
}
}
// Use first result if many are given // Use first result if many are given
value := values[0] value := values[0]

View File

@ -6,11 +6,12 @@ import (
"fmt" "fmt"
"strings" "strings"
"reflect"
"github.com/docker/docker/pkg/urlutil" "github.com/docker/docker/pkg/urlutil"
"github.com/docker/libcompose/utils" "github.com/docker/libcompose/utils"
composeYaml "github.com/docker/libcompose/yaml" composeYaml "github.com/docker/libcompose/yaml"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"reflect"
) )
var ( var (
@ -229,19 +230,11 @@ func readEnvFile(resourceLookup ResourceLookup, inFile string, serviceData RawSe
serviceData["environment"] = vars serviceData["environment"] = vars
delete(serviceData, "env_file")
return serviceData, nil return serviceData, nil
} }
func mergeConfig(baseService, serviceData RawService) RawService { func mergeConfig(baseService, serviceData RawService) RawService {
for k, v := range serviceData { for k, v := range serviceData {
// Image and build are mutually exclusive in merge
if k == "image" {
delete(baseService, "build")
} else if k == "build" {
delete(baseService, "image")
}
existing, ok := baseService[k] existing, ok := baseService[k]
if ok { if ok {
baseService[k] = merge(existing, v) baseService[k] = merge(existing, v)

View File

@ -4,8 +4,8 @@ import (
"fmt" "fmt"
"path" "path"
"github.com/Sirupsen/logrus"
"github.com/docker/libcompose/utils" "github.com/docker/libcompose/utils"
"github.com/sirupsen/logrus"
) )
// MergeServicesV1 merges a v1 compose file into an existing set of service configs // MergeServicesV1 merges a v1 compose file into an existing set of service configs
@ -29,7 +29,7 @@ func MergeServicesV1(existingServices *ServiceConfigs, environmentLookup Environ
return nil, err return nil, err
} }
data = mergeConfig(rawExistingService, data) data = mergeConfigV1(rawExistingService, data)
} }
datas[name] = data datas[name] = data
@ -148,7 +148,7 @@ func parseV1(resourceLookup ResourceLookup, environmentLookup EnvironmentLookup,
} }
} }
baseService = mergeConfig(baseService, serviceData) baseService = mergeConfigV1(baseService, serviceData)
logrus.Debugf("Merged result %#v", baseService) logrus.Debugf("Merged result %#v", baseService)
@ -177,3 +177,22 @@ func resolveContextV1(inFile string, serviceData RawService) RawService {
return serviceData return serviceData
} }
func mergeConfigV1(baseService, serviceData RawService) RawService {
for k, v := range serviceData {
// Image and build are mutually exclusive in merge
if k == "image" {
delete(baseService, "build")
} else if k == "build" {
delete(baseService, "image")
}
existing, ok := baseService[k]
if ok {
baseService[k] = merge(existing, v)
} else {
baseService[k] = v
}
}
return baseService
}

View File

@ -5,8 +5,8 @@ import (
"path" "path"
"strings" "strings"
"github.com/Sirupsen/logrus"
"github.com/docker/libcompose/utils" "github.com/docker/libcompose/utils"
"github.com/sirupsen/logrus"
) )
// MergeServicesV2 merges a v2 compose file into an existing set of service configs // MergeServicesV2 merges a v2 compose file into an existing set of service configs

View File

@ -3,7 +3,7 @@ package lookup
import ( import (
"strings" "strings"
"github.com/docker/docker/runconfig/opts" "github.com/docker/cli/opts"
"github.com/docker/libcompose/config" "github.com/docker/libcompose/config"
) )

View File

@ -7,7 +7,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/Sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// relativePath returns the proper relative path for the given file path. If // relativePath returns the proper relative path for the given file path. If

View File

@ -9,9 +9,9 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/Sirupsen/logrus"
"github.com/docker/libcompose/config" "github.com/docker/libcompose/config"
"github.com/docker/libcompose/logger" "github.com/docker/libcompose/logger"
"github.com/sirupsen/logrus"
) )
var projectRegexp = regexp.MustCompile("[^a-zA-Z0-9_.-]") var projectRegexp = regexp.MustCompile("[^a-zA-Z0-9_.-]")

View File

@ -3,8 +3,8 @@ package project
import ( import (
"bytes" "bytes"
"github.com/Sirupsen/logrus"
"github.com/docker/libcompose/project/events" "github.com/docker/libcompose/project/events"
"github.com/sirupsen/logrus"
) )
var ( var (

View File

@ -30,7 +30,8 @@ type Create struct {
// Run holds options of compose run. // Run holds options of compose run.
type Run struct { type Run struct {
Detached bool Detached bool
DisableTty bool
} }
// Up holds options of compose up. // Up holds options of compose up.

View File

@ -10,13 +10,13 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
log "github.com/Sirupsen/logrus"
"github.com/docker/libcompose/config" "github.com/docker/libcompose/config"
"github.com/docker/libcompose/logger" "github.com/docker/libcompose/logger"
"github.com/docker/libcompose/lookup" "github.com/docker/libcompose/lookup"
"github.com/docker/libcompose/project/events" "github.com/docker/libcompose/project/events"
"github.com/docker/libcompose/utils" "github.com/docker/libcompose/utils"
"github.com/docker/libcompose/yaml" "github.com/docker/libcompose/yaml"
log "github.com/sirupsen/logrus"
) )
// ComposeVersion is name of docker-compose.yml file syntax supported version // ComposeVersion is name of docker-compose.yml file syntax supported version

View File

@ -2,6 +2,7 @@ package project
import ( import (
"fmt" "fmt"
"sync"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -12,6 +13,8 @@ import (
// the Filter struct. // the Filter struct.
func (p *Project) Containers(ctx context.Context, filter Filter, services ...string) ([]string, error) { func (p *Project) Containers(ctx context.Context, filter Filter, services ...string) ([]string, error) {
containers := []string{} containers := []string{}
var lock sync.Mutex
err := p.forEach(services, wrapperAction(func(wrapper *serviceWrapper, wrappers map[string]*serviceWrapper) { err := p.forEach(services, wrapperAction(func(wrapper *serviceWrapper, wrappers map[string]*serviceWrapper) {
wrapper.Do(nil, events.NoEvent, events.NoEvent, func(service Service) error { wrapper.Do(nil, events.NoEvent, events.NoEvent, func(service Service) error {
serviceContainers, innerErr := service.Containers(ctx) serviceContainers, innerErr := service.Containers(ctx)
@ -37,7 +40,9 @@ func (p *Project) Containers(ctx context.Context, filter Filter, services ...str
return fmt.Errorf("Invalid container filter: %s", filter.State) return fmt.Errorf("Invalid container filter: %s", filter.State)
} }
containerID := container.ID() containerID := container.ID()
lock.Lock()
containers = append(containers, containerID) containers = append(containers, containerID)
lock.Unlock()
} }
return nil return nil
}) })

View File

@ -5,7 +5,7 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
log "github.com/Sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Scale scales the specified services. // Scale scales the specified services.

View File

@ -3,8 +3,8 @@ package project
import ( import (
"sync" "sync"
log "github.com/Sirupsen/logrus"
"github.com/docker/libcompose/project/events" "github.com/docker/libcompose/project/events"
log "github.com/sirupsen/logrus"
) )
type serviceWrapper struct { type serviceWrapper struct {

View File

@ -5,7 +5,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/Sirupsen/logrus" "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )

View File

@ -3,6 +3,8 @@ package yaml
import ( import (
"errors" "errors"
"fmt" "fmt"
"sort"
"strings"
) )
// Networks represents a list of service networks in compose file. // Networks represents a list of service networks in compose file.
@ -20,6 +22,35 @@ type Network struct {
IPv6Address string `yaml:"ipv6_address,omitempty"` IPv6Address string `yaml:"ipv6_address,omitempty"`
} }
// Generate a hash string to detect service network config changes
func (n *Networks) HashString() string {
if n == nil {
return ""
}
result := []string{}
for _, net := range n.Networks {
result = append(result, net.HashString())
}
sort.Strings(result)
return strings.Join(result, ",")
}
// Generate a hash string to detect service network config changes
func (n *Network) HashString() string {
if n == nil {
return ""
}
result := []string{}
result = append(result, n.Name)
result = append(result, n.RealName)
sort.Strings(n.Aliases)
result = append(result, strings.Join(n.Aliases, ","))
result = append(result, n.IPv4Address)
result = append(result, n.IPv6Address)
sort.Strings(result)
return strings.Join(result, ",")
}
// MarshalYAML implements the Marshaller interface. // MarshalYAML implements the Marshaller interface.
func (n Networks) MarshalYAML() (interface{}, error) { func (n Networks) MarshalYAML() (interface{}, error) {
m := map[string]*Network{} m := map[string]*Network{}

View File

@ -3,6 +3,7 @@ package yaml
import ( import (
"errors" "errors"
"fmt" "fmt"
"sort"
"strings" "strings"
) )
@ -19,6 +20,19 @@ type Volume struct {
AccessMode string `yaml:"-"` AccessMode string `yaml:"-"`
} }
// Generate a hash string to detect service volume config changes
func (v *Volumes) HashString() string {
if v == nil {
return ""
}
result := []string{}
for _, vol := range v.Volumes {
result = append(result, vol.String())
}
sort.Strings(result)
return strings.Join(result, ",")
}
// String implements the Stringer interface. // String implements the Stringer interface.
func (v *Volume) String() string { func (v *Volume) String() string {
var paths []string var paths []string

View File

@ -530,22 +530,15 @@ func (s *Struct) nested(val reflect.Value) interface{} {
finalVal = m finalVal = m
} }
case reflect.Map: case reflect.Map:
// get the element type of the map v := val.Type().Elem()
mapElem := val.Type() if v.Kind() == reflect.Ptr {
switch val.Type().Kind() { v = v.Elem()
case reflect.Ptr, reflect.Array, reflect.Map,
reflect.Slice, reflect.Chan:
mapElem = val.Type().Elem()
if mapElem.Kind() == reflect.Ptr {
mapElem = mapElem.Elem()
}
} }
// only iterate over struct types, ie: map[string]StructType, // only iterate over struct types, ie: map[string]StructType,
// map[string][]StructType, // map[string][]StructType,
if mapElem.Kind() == reflect.Struct || if v.Kind() == reflect.Struct ||
(mapElem.Kind() == reflect.Slice && (v.Kind() == reflect.Slice && v.Elem().Kind() == reflect.Struct) {
mapElem.Elem().Kind() == reflect.Struct) {
m := make(map[string]interface{}, val.Len()) m := make(map[string]interface{}, val.Len())
for _, k := range val.MapKeys() { for _, k := range val.MapKeys() {
m[k.String()] = s.nested(val.MapIndex(k)) m[k.String()] = s.nested(val.MapIndex(k))

View File

@ -9,7 +9,6 @@ package fsnotify
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
) )
@ -61,6 +60,3 @@ func (op Op) String() string {
func (e Event) String() string { func (e Event) String() string {
return fmt.Sprintf("%q: %s", e.Name, e.Op.String()) return fmt.Sprintf("%q: %s", e.Name, e.Op.String())
} }
// Common errors that can be reported by a watcher
var ErrEventOverflow = errors.New("fsnotify queue overflow")

View File

@ -24,6 +24,7 @@ type Watcher struct {
Events chan Event Events chan Event
Errors chan error Errors chan error
mu sync.Mutex // Map access mu sync.Mutex // Map access
cv *sync.Cond // sync removing on rm_watch with IN_IGNORE
fd int fd int
poller *fdPoller poller *fdPoller
watches map[string]*watch // Map of inotify watches (key: path) watches map[string]*watch // Map of inotify watches (key: path)
@ -55,6 +56,7 @@ func NewWatcher() (*Watcher, error) {
done: make(chan struct{}), done: make(chan struct{}),
doneResp: make(chan struct{}), doneResp: make(chan struct{}),
} }
w.cv = sync.NewCond(&w.mu)
go w.readEvents() go w.readEvents()
return w, nil return w, nil
@ -101,23 +103,21 @@ func (w *Watcher) Add(name string) error {
var flags uint32 = agnosticEvents var flags uint32 = agnosticEvents
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() watchEntry, found := w.watches[name]
watchEntry := w.watches[name] w.mu.Unlock()
if watchEntry != nil { if found {
flags |= watchEntry.flags | unix.IN_MASK_ADD watchEntry.flags |= flags
flags |= unix.IN_MASK_ADD
} }
wd, errno := unix.InotifyAddWatch(w.fd, name, flags) wd, errno := unix.InotifyAddWatch(w.fd, name, flags)
if wd == -1 { if wd == -1 {
return errno return errno
} }
if watchEntry == nil { w.mu.Lock()
w.watches[name] = &watch{wd: uint32(wd), flags: flags} w.watches[name] = &watch{wd: uint32(wd), flags: flags}
w.paths[wd] = name w.paths[wd] = name
} else { w.mu.Unlock()
watchEntry.wd = uint32(wd)
watchEntry.flags = flags
}
return nil return nil
} }
@ -135,13 +135,6 @@ func (w *Watcher) Remove(name string) error {
if !ok { if !ok {
return fmt.Errorf("can't remove non-existent inotify watch for: %s", name) return fmt.Errorf("can't remove non-existent inotify watch for: %s", name)
} }
// We successfully removed the watch if InotifyRmWatch doesn't return an
// error, we need to clean up our internal state to ensure it matches
// inotify's kernel state.
delete(w.paths, int(watch.wd))
delete(w.watches, name)
// inotify_rm_watch will return EINVAL if the file has been deleted; // inotify_rm_watch will return EINVAL if the file has been deleted;
// the inotify will already have been removed. // the inotify will already have been removed.
// watches and pathes are deleted in ignoreLinux() implicitly and asynchronously // watches and pathes are deleted in ignoreLinux() implicitly and asynchronously
@ -159,6 +152,13 @@ func (w *Watcher) Remove(name string) error {
return errno return errno
} }
// wait until ignoreLinux() deleting maps
exists := true
for exists {
w.cv.Wait()
_, exists = w.watches[name]
}
return nil return nil
} }
@ -245,31 +245,13 @@ func (w *Watcher) readEvents() {
mask := uint32(raw.Mask) mask := uint32(raw.Mask)
nameLen := uint32(raw.Len) nameLen := uint32(raw.Len)
if mask&unix.IN_Q_OVERFLOW != 0 {
select {
case w.Errors <- ErrEventOverflow:
case <-w.done:
return
}
}
// If the event happened to the watched directory or the watched file, the kernel // If the event happened to the watched directory or the watched file, the kernel
// doesn't append the filename to the event, but we would like to always fill the // doesn't append the filename to the event, but we would like to always fill the
// the "Name" field with a valid filename. We retrieve the path of the watch from // the "Name" field with a valid filename. We retrieve the path of the watch from
// the "paths" map. // the "paths" map.
w.mu.Lock() w.mu.Lock()
name, ok := w.paths[int(raw.Wd)] name := w.paths[int(raw.Wd)]
// IN_DELETE_SELF occurs when the file/directory being watched is removed.
// This is a sign to clean up the maps, otherwise we are no longer in sync
// with the inotify kernel state which has already deleted the watch
// automatically.
if ok && mask&unix.IN_DELETE_SELF == unix.IN_DELETE_SELF {
delete(w.paths, int(raw.Wd))
delete(w.watches, name)
}
w.mu.Unlock() w.mu.Unlock()
if nameLen > 0 { if nameLen > 0 {
// Point "bytes" at the first byte of the filename // Point "bytes" at the first byte of the filename
bytes := (*[unix.PathMax]byte)(unsafe.Pointer(&buf[offset+unix.SizeofInotifyEvent])) bytes := (*[unix.PathMax]byte)(unsafe.Pointer(&buf[offset+unix.SizeofInotifyEvent]))
@ -280,7 +262,7 @@ func (w *Watcher) readEvents() {
event := newEvent(name, mask) event := newEvent(name, mask)
// Send the events that are not ignored on the events channel // Send the events that are not ignored on the events channel
if !event.ignoreLinux(mask) { if !event.ignoreLinux(w, raw.Wd, mask) {
select { select {
case w.Events <- event: case w.Events <- event:
case <-w.done: case <-w.done:
@ -297,9 +279,15 @@ func (w *Watcher) readEvents() {
// Certain types of events can be "ignored" and not sent over the Events // Certain types of events can be "ignored" and not sent over the Events
// channel. Such as events marked ignore by the kernel, or MODIFY events // channel. Such as events marked ignore by the kernel, or MODIFY events
// against files that do not exist. // against files that do not exist.
func (e *Event) ignoreLinux(mask uint32) bool { func (e *Event) ignoreLinux(w *Watcher, wd int32, mask uint32) bool {
// Ignore anything the inotify API says to ignore // Ignore anything the inotify API says to ignore
if mask&unix.IN_IGNORED == unix.IN_IGNORED { if mask&unix.IN_IGNORED == unix.IN_IGNORED {
w.mu.Lock()
defer w.mu.Unlock()
name := w.paths[int(wd)]
delete(w.paths, int(wd))
delete(w.watches, name)
w.cv.Broadcast()
return true return true
} }

View File

@ -89,7 +89,7 @@ func (d *decoder) decode(name string, node ast.Node, result reflect.Value) error
switch k.Kind() { switch k.Kind() {
case reflect.Bool: case reflect.Bool:
return d.decodeBool(name, node, result) return d.decodeBool(name, node, result)
case reflect.Float32, reflect.Float64: case reflect.Float64:
return d.decodeFloat(name, node, result) return d.decodeFloat(name, node, result)
case reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32, reflect.Int64:
return d.decodeInt(name, node, result) return d.decodeInt(name, node, result)
@ -137,13 +137,13 @@ func (d *decoder) decodeBool(name string, node ast.Node, result reflect.Value) e
func (d *decoder) decodeFloat(name string, node ast.Node, result reflect.Value) error { func (d *decoder) decodeFloat(name string, node ast.Node, result reflect.Value) error {
switch n := node.(type) { switch n := node.(type) {
case *ast.LiteralType: case *ast.LiteralType:
if n.Token.Type == token.FLOAT || n.Token.Type == token.NUMBER { if n.Token.Type == token.FLOAT {
v, err := strconv.ParseFloat(n.Token.Text, 64) v, err := strconv.ParseFloat(n.Token.Text, 64)
if err != nil { if err != nil {
return err return err
} }
result.Set(reflect.ValueOf(v).Convert(result.Type())) result.Set(reflect.ValueOf(v))
return nil return nil
} }
} }

View File

@ -3,7 +3,6 @@
package parser package parser
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -37,11 +36,6 @@ func newParser(src []byte) *Parser {
// Parse returns the fully parsed source and returns the abstract syntax tree. // Parse returns the fully parsed source and returns the abstract syntax tree.
func Parse(src []byte) (*ast.File, error) { func Parse(src []byte) (*ast.File, error) {
// normalize all line endings
// since the scanner and output only work with "\n" line endings, we may
// end up with dangling "\r" characters in the parsed data.
src = bytes.Replace(src, []byte("\r\n"), []byte("\n"), -1)
p := newParser(src) p := newParser(src)
return p.Parse() return p.Parse()
} }
@ -197,12 +191,9 @@ func (p *Parser) objectItem() (*ast.ObjectItem, error) {
keyStr = append(keyStr, k.Token.Text) keyStr = append(keyStr, k.Token.Text)
} }
return nil, &PosError{ return nil, fmt.Errorf(
Pos: p.tok.Pos, "key '%s' expected start of object ('{') or assignment ('=')",
Err: fmt.Errorf( strings.Join(keyStr, " "))
"key '%s' expected start of object ('{') or assignment ('=')",
strings.Join(keyStr, " ")),
}
} }
// do a look-ahead for line comment // do a look-ahead for line comment
@ -322,10 +313,7 @@ func (p *Parser) objectType() (*ast.ObjectType, error) {
// No error, scan and expect the ending to be a brace // No error, scan and expect the ending to be a brace
if tok := p.scan(); tok.Type != token.RBRACE { if tok := p.scan(); tok.Type != token.RBRACE {
return nil, &PosError{ return nil, fmt.Errorf("object expected closing RBRACE got: %s", tok.Type)
Pos: tok.Pos,
Err: fmt.Errorf("object expected closing RBRACE got: %s", tok.Type),
}
} }
o.List = l o.List = l
@ -358,7 +346,7 @@ func (p *Parser) listType() (*ast.ListType, error) {
} }
} }
switch tok.Type { switch tok.Type {
case token.BOOL, token.NUMBER, token.FLOAT, token.STRING, token.HEREDOC: case token.NUMBER, token.FLOAT, token.STRING, token.HEREDOC:
node, err := p.literalType() node, err := p.literalType()
if err != nil { if err != nil {
return nil, err return nil, err
@ -400,16 +388,12 @@ func (p *Parser) listType() (*ast.ListType, error) {
} }
l.Add(node) l.Add(node)
needComma = true needComma = true
case token.BOOL:
// TODO(arslan) should we support? not supported by HCL yet
case token.LBRACK: case token.LBRACK:
node, err := p.listType() // TODO(arslan) should we support nested lists? Even though it's
if err != nil { // written in README of HCL, it's not a part of the grammar
return nil, &PosError{ // (not defined in parse.y)
Pos: tok.Pos,
Err: fmt.Errorf(
"error while trying to parse list within list: %s", err),
}
}
l.Add(node)
case token.RBRACK: case token.RBRACK:
// finished // finished
l.Rbrack = p.tok.Pos l.Rbrack = p.tok.Pos

View File

@ -147,7 +147,7 @@ func (p *Parser) objectKey() ([]*ast.ObjectKey, error) {
// Done // Done
return keys, nil return keys, nil
case token.ILLEGAL: case token.ILLEGAL:
return nil, errors.New("illegal") fmt.Println("illegal")
default: default:
return nil, fmt.Errorf("expected: STRING got: %s", p.tok.Type) return nil, fmt.Errorf("expected: STRING got: %s", p.tok.Type)
} }

View File

@ -511,9 +511,6 @@ func (p *Properties) Set(key, value string) (prev string, ok bool, err error) {
if p.DisableExpansion { if p.DisableExpansion {
prev, ok = p.Get(key) prev, ok = p.Get(key)
p.m[key] = value p.m[key] = value
if !ok {
p.k = append(p.k, key)
}
return prev, ok, nil return prev, ok, nil
} }

View File

@ -38,6 +38,12 @@ func DecodeHookExec(
raw DecodeHookFunc, raw DecodeHookFunc,
from reflect.Type, to reflect.Type, from reflect.Type, to reflect.Type,
data interface{}) (interface{}, error) { data interface{}) (interface{}, error) {
// Build our arguments that reflect expects
argVals := make([]reflect.Value, 3)
argVals[0] = reflect.ValueOf(from)
argVals[1] = reflect.ValueOf(to)
argVals[2] = reflect.ValueOf(data)
switch f := typedDecodeHook(raw).(type) { switch f := typedDecodeHook(raw).(type) {
case DecodeHookFuncType: case DecodeHookFuncType:
return f(from, to, data) return f(from, to, data)
@ -115,11 +121,6 @@ func StringToTimeDurationHookFunc() DecodeHookFunc {
} }
} }
// WeaklyTypedHook is a DecodeHookFunc which adds support for weak typing to
// the decoder.
//
// Note that this is significantly different from the WeaklyTypedInput option
// of the DecoderConfig.
func WeaklyTypedHook( func WeaklyTypedHook(
f reflect.Kind, f reflect.Kind,
t reflect.Kind, t reflect.Kind,
@ -131,8 +132,9 @@ func WeaklyTypedHook(
case reflect.Bool: case reflect.Bool:
if dataVal.Bool() { if dataVal.Bool() {
return "1", nil return "1", nil
} else {
return "0", nil
} }
return "0", nil
case reflect.Float32: case reflect.Float32:
return strconv.FormatFloat(dataVal.Float(), 'f', -1, 64), nil return strconv.FormatFloat(dataVal.Float(), 'f', -1, 64), nil
case reflect.Int: case reflect.Int:

View File

@ -1,5 +1,5 @@
// Package mapstructure exposes functionality to convert an arbitrary // The mapstructure package exposes functionality to convert an
// map[string]interface{} into a native Go structure. // arbitrary map[string]interface{} into a native Go structure.
// //
// The Go structure can be arbitrarily complex, containing slices, // The Go structure can be arbitrarily complex, containing slices,
// other structs, etc. and the decoder will properly decode nested // other structs, etc. and the decoder will properly decode nested
@ -32,12 +32,7 @@ import (
// both. // both.
type DecodeHookFunc interface{} type DecodeHookFunc interface{}
// DecodeHookFuncType is a DecodeHookFunc which has complete information about
// the source and target types.
type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface{}, error) type DecodeHookFuncType func(reflect.Type, reflect.Type, interface{}) (interface{}, error)
// DecodeHookFuncKind is a DecodeHookFunc which knows only the Kinds of the
// source and target types.
type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error) type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, interface{}) (interface{}, error)
// DecoderConfig is the configuration that is used to create a new decoder // DecoderConfig is the configuration that is used to create a new decoder
@ -74,9 +69,6 @@ type DecoderConfig struct {
// - empty array = empty map and vice versa // - empty array = empty map and vice versa
// - negative numbers to overflowed uint values (base 10) // - negative numbers to overflowed uint values (base 10)
// - slice of maps to a merged map // - slice of maps to a merged map
// - single values are converted to slices if required. Each
// element is weakly decoded. For example: "4" can become []int{4}
// if the target type is an int slice.
// //
WeaklyTypedInput bool WeaklyTypedInput bool
@ -210,7 +202,7 @@ func (d *Decoder) decode(name string, data interface{}, val reflect.Value) error
d.config.DecodeHook, d.config.DecodeHook,
dataVal.Type(), val.Type(), data) dataVal.Type(), val.Type(), data)
if err != nil { if err != nil {
return fmt.Errorf("error decoding '%s': %s", name, err) return err
} }
} }
@ -237,8 +229,6 @@ func (d *Decoder) decode(name string, data interface{}, val reflect.Value) error
err = d.decodePtr(name, data, val) err = d.decodePtr(name, data, val)
case reflect.Slice: case reflect.Slice:
err = d.decodeSlice(name, data, val) err = d.decodeSlice(name, data, val)
case reflect.Func:
err = d.decodeFunc(name, data, val)
default: default:
// If we reached this point then we weren't able to decode it // If we reached this point then we weren't able to decode it
return fmt.Errorf("%s: unsupported type: %s", name, dataKind) return fmt.Errorf("%s: unsupported type: %s", name, dataKind)
@ -441,7 +431,7 @@ func (d *Decoder) decodeFloat(name string, data interface{}, val reflect.Value)
case dataKind == reflect.Uint: case dataKind == reflect.Uint:
val.SetFloat(float64(dataVal.Uint())) val.SetFloat(float64(dataVal.Uint()))
case dataKind == reflect.Float32: case dataKind == reflect.Float32:
val.SetFloat(dataVal.Float()) val.SetFloat(float64(dataVal.Float()))
case dataKind == reflect.Bool && d.config.WeaklyTypedInput: case dataKind == reflect.Bool && d.config.WeaklyTypedInput:
if dataVal.Bool() { if dataVal.Bool() {
val.SetFloat(1) val.SetFloat(1)
@ -556,12 +546,7 @@ func (d *Decoder) decodePtr(name string, data interface{}, val reflect.Value) er
// into that. Then set the value of the pointer to this type. // into that. Then set the value of the pointer to this type.
valType := val.Type() valType := val.Type()
valElemType := valType.Elem() valElemType := valType.Elem()
realVal := reflect.New(valElemType)
realVal := val
if realVal.IsNil() || d.config.ZeroFields {
realVal = reflect.New(valElemType)
}
if err := d.decode(name, data, reflect.Indirect(realVal)); err != nil { if err := d.decode(name, data, reflect.Indirect(realVal)); err != nil {
return err return err
} }
@ -570,19 +555,6 @@ func (d *Decoder) decodePtr(name string, data interface{}, val reflect.Value) er
return nil return nil
} }
func (d *Decoder) decodeFunc(name string, data interface{}, val reflect.Value) error {
// Create an element of the concrete (non pointer) type and decode
// into that. Then set the value of the pointer to this type.
dataVal := reflect.Indirect(reflect.ValueOf(data))
if val.Type() != dataVal.Type() {
return fmt.Errorf(
"'%s' expected type '%s', got unconvertible type '%s'",
name, val.Type(), dataVal.Type())
}
val.Set(dataVal)
return nil
}
func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error { func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) error {
dataVal := reflect.Indirect(reflect.ValueOf(data)) dataVal := reflect.Indirect(reflect.ValueOf(data))
dataValKind := dataVal.Kind() dataValKind := dataVal.Kind()
@ -590,44 +562,26 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)
valElemType := valType.Elem() valElemType := valType.Elem()
sliceType := reflect.SliceOf(valElemType) sliceType := reflect.SliceOf(valElemType)
valSlice := val // Check input type
if valSlice.IsNil() || d.config.ZeroFields { if dataValKind != reflect.Array && dataValKind != reflect.Slice {
// Check input type // Accept empty map instead of array/slice in weakly typed mode
if dataValKind != reflect.Array && dataValKind != reflect.Slice { if d.config.WeaklyTypedInput && dataVal.Kind() == reflect.Map && dataVal.Len() == 0 {
if d.config.WeaklyTypedInput { val.Set(reflect.MakeSlice(sliceType, 0, 0))
switch { return nil
// Empty maps turn into empty slices } else {
case dataValKind == reflect.Map:
if dataVal.Len() == 0 {
val.Set(reflect.MakeSlice(sliceType, 0, 0))
return nil
}
// All other types we try to convert to the slice type
// and "lift" it into it. i.e. a string becomes a string slice.
default:
// Just re-try this function with data as a slice.
return d.decodeSlice(name, []interface{}{data}, val)
}
}
return fmt.Errorf( return fmt.Errorf(
"'%s': source data must be an array or slice, got %s", name, dataValKind) "'%s': source data must be an array or slice, got %s", name, dataValKind)
} }
// Make a new slice to hold our result, same size as the original data.
valSlice = reflect.MakeSlice(sliceType, dataVal.Len(), dataVal.Len())
} }
// Make a new slice to hold our result, same size as the original data.
valSlice := reflect.MakeSlice(sliceType, dataVal.Len(), dataVal.Len())
// Accumulate any errors // Accumulate any errors
errors := make([]string, 0) errors := make([]string, 0)
for i := 0; i < dataVal.Len(); i++ { for i := 0; i < dataVal.Len(); i++ {
currentData := dataVal.Index(i).Interface() currentData := dataVal.Index(i).Interface()
for valSlice.Len() <= i {
valSlice = reflect.Append(valSlice, reflect.Zero(valElemType))
}
currentField := valSlice.Index(i) currentField := valSlice.Index(i)
fieldName := fmt.Sprintf("%s[%d]", name, i) fieldName := fmt.Sprintf("%s[%d]", name, i)

117
vendor/github.com/pelletier/go-buffruneio/buffruneio.go generated vendored Normal file
View File

@ -0,0 +1,117 @@
// Package buffruneio is a wrapper around bufio to provide buffered runes access with unlimited unreads.
package buffruneio
import (
"bufio"
"container/list"
"errors"
"io"
)
// Rune to indicate end of file.
const (
EOF = -(iota + 1)
)
// ErrNoRuneToUnread is returned by UnreadRune() when the read index is already at the beginning of the buffer.
var ErrNoRuneToUnread = errors.New("no rune to unwind")
// Reader implements runes buffering for an io.Reader object.
type Reader struct {
buffer *list.List
current *list.Element
input *bufio.Reader
}
// NewReader returns a new Reader.
func NewReader(rd io.Reader) *Reader {
return &Reader{
buffer: list.New(),
input: bufio.NewReader(rd),
}
}
type runeWithSize struct {
r rune
size int
}
func (rd *Reader) feedBuffer() error {
r, size, err := rd.input.ReadRune()
if err != nil {
if err != io.EOF {
return err
}
r = EOF
}
newRuneWithSize := runeWithSize{r, size}
rd.buffer.PushBack(newRuneWithSize)
if rd.current == nil {
rd.current = rd.buffer.Back()
}
return nil
}
// ReadRune reads the next rune from buffer, or from the underlying reader if needed.
func (rd *Reader) ReadRune() (rune, int, error) {
if rd.current == rd.buffer.Back() || rd.current == nil {
err := rd.feedBuffer()
if err != nil {
return EOF, 0, err
}
}
runeWithSize := rd.current.Value.(runeWithSize)
rd.current = rd.current.Next()
return runeWithSize.r, runeWithSize.size, nil
}
// UnreadRune pushes back the previously read rune in the buffer, extending it if needed.
func (rd *Reader) UnreadRune() error {
if rd.current == rd.buffer.Front() {
return ErrNoRuneToUnread
}
if rd.current == nil {
rd.current = rd.buffer.Back()
} else {
rd.current = rd.current.Prev()
}
return nil
}
// Forget removes runes stored before the current stream position index.
func (rd *Reader) Forget() {
if rd.current == nil {
rd.current = rd.buffer.Back()
}
for ; rd.current != rd.buffer.Front(); rd.buffer.Remove(rd.current.Prev()) {
}
}
// PeekRune returns at most the next n runes, reading from the uderlying source if
// needed. Does not move the current index. It includes EOF if reached.
func (rd *Reader) PeekRunes(n int) []rune {
res := make([]rune, 0, n)
cursor := rd.current
for i := 0; i < n; i++ {
if cursor == nil {
err := rd.feedBuffer()
if err != nil {
return res
}
cursor = rd.buffer.Back()
}
if cursor != nil {
r := cursor.Value.(runeWithSize).r
res = append(res, r)
if r == EOF {
return res
}
cursor = cursor.Next()
}
}
return res
}

View File

@ -1,23 +1,250 @@
// Package toml is a TOML parser and manipulation library. // Package toml is a TOML markup language parser.
// //
// This version supports the specification as described in // This version supports the specification as described in
// https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md // https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md
// //
// Marshaling // TOML Parsing
// //
// Go-toml can marshal and unmarshal TOML documents from and to data // TOML data may be parsed in two ways: by file, or by string.
// structures.
// //
// TOML document as a tree // // load TOML data by filename
// tree, err := toml.LoadFile("filename.toml")
// //
// Go-toml can operate on a TOML document as a tree. Use one of the Load* // // load TOML data stored in a string
// functions to parse TOML data and obtain a Tree instance, then one of its // tree, err := toml.Load(stringContainingTomlData)
// methods to manipulate the tree.
// //
// JSONPath-like queries // Either way, the result is a TomlTree object that can be used to navigate the
// structure and data within the original document.
// //
// The package github.com/pelletier/go-toml/query implements a system //
// similar to JSONPath to quickly retrive elements of a TOML document using a // Getting data from the TomlTree
// single expression. See the package documentation for more information. //
// After parsing TOML data with Load() or LoadFile(), use the Has() and Get()
// methods on the returned TomlTree, to find your way through the document data.
//
// if tree.Has("foo") {
// fmt.Println("foo is:", tree.Get("foo"))
// }
//
// Working with Paths
//
// Go-toml has support for basic dot-separated key paths on the Has(), Get(), Set()
// and GetDefault() methods. These are the same kind of key paths used within the
// TOML specification for struct tames.
//
// // looks for a key named 'baz', within struct 'bar', within struct 'foo'
// tree.Has("foo.bar.baz")
//
// // returns the key at this path, if it is there
// tree.Get("foo.bar.baz")
//
// TOML allows keys to contain '.', which can cause this syntax to be problematic
// for some documents. In such cases, use the GetPath(), HasPath(), and SetPath(),
// methods to explicitly define the path. This form is also faster, since
// it avoids having to parse the passed key for '.' delimiters.
//
// // looks for a key named 'baz', within struct 'bar', within struct 'foo'
// tree.HasPath([]string{"foo","bar","baz"})
//
// // returns the key at this path, if it is there
// tree.GetPath([]string{"foo","bar","baz"})
//
// Note that this is distinct from the heavyweight query syntax supported by
// TomlTree.Query() and the Query() struct (see below).
//
// Position Support
//
// Each element within the TomlTree is stored with position metadata, which is
// invaluable for providing semantic feedback to a user. This helps in
// situations where the TOML file parses correctly, but contains data that is
// not correct for the application. In such cases, an error message can be
// generated that indicates the problem line and column number in the source
// TOML document.
//
// // load TOML data
// tree, _ := toml.Load("filename.toml")
//
// // get an entry and report an error if it's the wrong type
// element := tree.Get("foo")
// if value, ok := element.(int64); !ok {
// return fmt.Errorf("%v: Element 'foo' must be an integer", tree.GetPosition("foo"))
// }
//
// // report an error if an expected element is missing
// if !tree.Has("bar") {
// return fmt.Errorf("%v: Expected 'bar' element", tree.GetPosition(""))
// }
//
// Query Support
//
// The TOML query path implementation is based loosely on the JSONPath specification:
// http://goessner.net/articles/JsonPath/
//
// The idea behind a query path is to allow quick access to any element, or set
// of elements within TOML document, with a single expression.
//
// result, err := tree.Query("$.foo.bar.baz")
//
// This is roughly equivalent to:
//
// next := tree.Get("foo")
// if next != nil {
// next = next.Get("bar")
// if next != nil {
// next = next.Get("baz")
// }
// }
// result := next
//
// err is nil if any parsing exception occurs.
//
// If no node in the tree matches the query, result will simply contain an empty list of
// items.
//
// As illustrated above, the query path is much more efficient, especially since
// the structure of the TOML file can vary. Rather than making assumptions about
// a document's structure, a query allows the programmer to make structured
// requests into the document, and get zero or more values as a result.
//
// The syntax of a query begins with a root token, followed by any number
// sub-expressions:
//
// $
// Root of the TOML tree. This must always come first.
// .name
// Selects child of this node, where 'name' is a TOML key
// name.
// ['name']
// Selects child of this node, where 'name' is a string
// containing a TOML key name.
// [index]
// Selcts child array element at 'index'.
// ..expr
// Recursively selects all children, filtered by an a union,
// index, or slice expression.
// ..*
// Recursive selection of all nodes at this point in the
// tree.
// .*
// Selects all children of the current node.
// [expr,expr]
// Union operator - a logical 'or' grouping of two or more
// sub-expressions: index, key name, or filter.
// [start:end:step]
// Slice operator - selects array elements from start to
// end-1, at the given step. All three arguments are
// optional.
// [?(filter)]
// Named filter expression - the function 'filter' is
// used to filter children at this node.
//
// Query Indexes And Slices
//
// Index expressions perform no bounds checking, and will contribute no
// values to the result set if the provided index or index range is invalid.
// Negative indexes represent values from the end of the array, counting backwards.
//
// // select the last index of the array named 'foo'
// tree.Query("$.foo[-1]")
//
// Slice expressions are supported, by using ':' to separate a start/end index pair.
//
// // select up to the first five elements in the array
// tree.Query("$.foo[0:5]")
//
// Slice expressions also allow negative indexes for the start and stop
// arguments.
//
// // select all array elements.
// tree.Query("$.foo[0:-1]")
//
// Slice expressions may have an optional stride/step parameter:
//
// // select every other element
// tree.Query("$.foo[0:-1:2]")
//
// Slice start and end parameters are also optional:
//
// // these are all equivalent and select all the values in the array
// tree.Query("$.foo[:]")
// tree.Query("$.foo[0:]")
// tree.Query("$.foo[:-1]")
// tree.Query("$.foo[0:-1:]")
// tree.Query("$.foo[::1]")
// tree.Query("$.foo[0::1]")
// tree.Query("$.foo[:-1:1]")
// tree.Query("$.foo[0:-1:1]")
//
// Query Filters
//
// Query filters are used within a Union [,] or single Filter [] expression.
// A filter only allows nodes that qualify through to the next expression,
// and/or into the result set.
//
// // returns children of foo that are permitted by the 'bar' filter.
// tree.Query("$.foo[?(bar)]")
//
// There are several filters provided with the library:
//
// tree
// Allows nodes of type TomlTree.
// int
// Allows nodes of type int64.
// float
// Allows nodes of type float64.
// string
// Allows nodes of type string.
// time
// Allows nodes of type time.Time.
// bool
// Allows nodes of type bool.
//
// Query Results
//
// An executed query returns a QueryResult object. This contains the nodes
// in the TOML tree that qualify the query expression. Position information
// is also available for each value in the set.
//
// // display the results of a query
// results := tree.Query("$.foo.bar.baz")
// for idx, value := results.Values() {
// fmt.Println("%v: %v", results.Positions()[idx], value)
// }
//
// Compiled Queries
//
// Queries may be executed directly on a TomlTree object, or compiled ahead
// of time and executed discretely. The former is more convienent, but has the
// penalty of having to recompile the query expression each time.
//
// // basic query
// results := tree.Query("$.foo.bar.baz")
//
// // compiled query
// query := toml.CompileQuery("$.foo.bar.baz")
// results := query.Execute(tree)
//
// // run the compiled query again on a different tree
// moreResults := query.Execute(anotherTree)
//
// User Defined Query Filters
//
// Filter expressions may also be user defined by using the SetFilter()
// function on the Query object. The function must return true/false, which
// signifies if the passed node is kept or discarded, respectively.
//
// // create a query that references a user-defined filter
// query, _ := CompileQuery("$[?(bazOnly)]")
//
// // define the filter, and assign it to the query
// query.SetFilter("bazOnly", func(node interface{}) bool{
// if tree, ok := node.(*TomlTree); ok {
// return tree.Has("baz")
// }
// return false // reject all other node types
// })
//
// // run the query
// query.Execute(tree)
// //
package toml package toml

View File

@ -6,12 +6,14 @@
package toml package toml
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"github.com/pelletier/go-buffruneio"
) )
var dateRegexp *regexp.Regexp var dateRegexp *regexp.Regexp
@ -21,29 +23,29 @@ type tomlLexStateFn func() tomlLexStateFn
// Define lexer // Define lexer
type tomlLexer struct { type tomlLexer struct {
inputIdx int input *buffruneio.Reader // Textual source
input []rune // Textual source buffer []rune // Runes composing the current token
currentTokenStart int tokens chan token
currentTokenStop int depth int
tokens []token line int
depth int col int
line int endbufferLine int
col int endbufferCol int
endbufferLine int
endbufferCol int
} }
// Basic read operations on input // Basic read operations on input
func (l *tomlLexer) read() rune { func (l *tomlLexer) read() rune {
r := l.peek() r, _, err := l.input.ReadRune()
if err != nil {
panic(err)
}
if r == '\n' { if r == '\n' {
l.endbufferLine++ l.endbufferLine++
l.endbufferCol = 1 l.endbufferCol = 1
} else { } else {
l.endbufferCol++ l.endbufferCol++
} }
l.inputIdx++
return r return r
} }
@ -51,13 +53,13 @@ func (l *tomlLexer) next() rune {
r := l.read() r := l.read()
if r != eof { if r != eof {
l.currentTokenStop++ l.buffer = append(l.buffer, r)
} }
return r return r
} }
func (l *tomlLexer) ignore() { func (l *tomlLexer) ignore() {
l.currentTokenStart = l.currentTokenStop l.buffer = make([]rune, 0)
l.line = l.endbufferLine l.line = l.endbufferLine
l.col = l.endbufferCol l.col = l.endbufferCol
} }
@ -74,46 +76,49 @@ func (l *tomlLexer) fastForward(n int) {
} }
func (l *tomlLexer) emitWithValue(t tokenType, value string) { func (l *tomlLexer) emitWithValue(t tokenType, value string) {
l.tokens = append(l.tokens, token{ l.tokens <- token{
Position: Position{l.line, l.col}, Position: Position{l.line, l.col},
typ: t, typ: t,
val: value, val: value,
}) }
l.ignore() l.ignore()
} }
func (l *tomlLexer) emit(t tokenType) { func (l *tomlLexer) emit(t tokenType) {
l.emitWithValue(t, string(l.input[l.currentTokenStart:l.currentTokenStop])) l.emitWithValue(t, string(l.buffer))
} }
func (l *tomlLexer) peek() rune { func (l *tomlLexer) peek() rune {
if l.inputIdx >= len(l.input) { r, _, err := l.input.ReadRune()
return eof if err != nil {
panic(err)
} }
return l.input[l.inputIdx] l.input.UnreadRune()
} return r
func (l *tomlLexer) peekString(size int) string {
maxIdx := len(l.input)
upperIdx := l.inputIdx + size // FIXME: potential overflow
if upperIdx > maxIdx {
upperIdx = maxIdx
}
return string(l.input[l.inputIdx:upperIdx])
} }
func (l *tomlLexer) follow(next string) bool { func (l *tomlLexer) follow(next string) bool {
return next == l.peekString(len(next)) for _, expectedRune := range next {
r, _, err := l.input.ReadRune()
defer l.input.UnreadRune()
if err != nil {
panic(err)
}
if expectedRune != r {
return false
}
}
return true
} }
// Error management // Error management
func (l *tomlLexer) errorf(format string, args ...interface{}) tomlLexStateFn { func (l *tomlLexer) errorf(format string, args ...interface{}) tomlLexStateFn {
l.tokens = append(l.tokens, token{ l.tokens <- token{
Position: Position{l.line, l.col}, Position: Position{l.line, l.col},
typ: tokenError, typ: tokenError,
val: fmt.Sprintf(format, args...), val: fmt.Sprintf(format, args...),
}) }
return nil return nil
} }
@ -214,7 +219,7 @@ func (l *tomlLexer) lexRvalue() tomlLexStateFn {
break break
} }
possibleDate := l.peekString(35) possibleDate := string(l.input.PeekRunes(35))
dateMatch := dateRegexp.FindString(possibleDate) dateMatch := dateRegexp.FindString(possibleDate)
if dateMatch != "" { if dateMatch != "" {
l.fastForward(len(dateMatch)) l.fastForward(len(dateMatch))
@ -531,7 +536,7 @@ func (l *tomlLexer) lexInsideTableArrayKey() tomlLexStateFn {
for r := l.peek(); r != eof; r = l.peek() { for r := l.peek(); r != eof; r = l.peek() {
switch r { switch r {
case ']': case ']':
if l.currentTokenStop > l.currentTokenStart { if len(l.buffer) > 0 {
l.emit(tokenKeyGroupArray) l.emit(tokenKeyGroupArray)
} }
l.next() l.next()
@ -554,7 +559,7 @@ func (l *tomlLexer) lexInsideTableKey() tomlLexStateFn {
for r := l.peek(); r != eof; r = l.peek() { for r := l.peek(); r != eof; r = l.peek() {
switch r { switch r {
case ']': case ']':
if l.currentTokenStop > l.currentTokenStart { if len(l.buffer) > 0 {
l.emit(tokenKeyGroup) l.emit(tokenKeyGroup)
} }
l.next() l.next()
@ -629,6 +634,7 @@ func (l *tomlLexer) run() {
for state := l.lexVoid; state != nil; { for state := l.lexVoid; state != nil; {
state = state() state = state()
} }
close(l.tokens)
} }
func init() { func init() {
@ -636,16 +642,16 @@ func init() {
} }
// Entry point // Entry point
func lexToml(inputBytes []byte) []token { func lexToml(input io.Reader) chan token {
runes := bytes.Runes(inputBytes) bufferedInput := buffruneio.NewReader(input)
l := &tomlLexer{ l := &tomlLexer{
input: runes, input: bufferedInput,
tokens: make([]token, 0, 256), tokens: make(chan token),
line: 1, line: 1,
col: 1, col: 1,
endbufferLine: 1, endbufferLine: 1,
endbufferCol: 1, endbufferCol: 1,
} }
l.run() go l.run()
return l.tokens return l.tokens
} }

View File

@ -1,508 +0,0 @@
package toml
import (
"bytes"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
type tomlOpts struct {
name string
comment string
commented bool
include bool
omitempty bool
}
var timeType = reflect.TypeOf(time.Time{})
var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
// Check if the given marshall type maps to a Tree primitive
func isPrimitive(mtype reflect.Type) bool {
switch mtype.Kind() {
case reflect.Ptr:
return isPrimitive(mtype.Elem())
case reflect.Bool:
return true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
case reflect.String:
return true
case reflect.Struct:
return mtype == timeType || isCustomMarshaler(mtype)
default:
return false
}
}
// Check if the given marshall type maps to a Tree slice
func isTreeSlice(mtype reflect.Type) bool {
switch mtype.Kind() {
case reflect.Slice:
return !isOtherSlice(mtype)
default:
return false
}
}
// Check if the given marshall type maps to a non-Tree slice
func isOtherSlice(mtype reflect.Type) bool {
switch mtype.Kind() {
case reflect.Ptr:
return isOtherSlice(mtype.Elem())
case reflect.Slice:
return isPrimitive(mtype.Elem()) || isOtherSlice(mtype.Elem())
default:
return false
}
}
// Check if the given marshall type maps to a Tree
func isTree(mtype reflect.Type) bool {
switch mtype.Kind() {
case reflect.Map:
return true
case reflect.Struct:
return !isPrimitive(mtype)
default:
return false
}
}
func isCustomMarshaler(mtype reflect.Type) bool {
return mtype.Implements(marshalerType)
}
func callCustomMarshaler(mval reflect.Value) ([]byte, error) {
return mval.Interface().(Marshaler).MarshalTOML()
}
// Marshaler is the interface implemented by types that
// can marshal themselves into valid TOML.
type Marshaler interface {
MarshalTOML() ([]byte, error)
}
/*
Marshal returns the TOML encoding of v. Behavior is similar to the Go json
encoder, except that there is no concept of a Marshaler interface or MarshalTOML
function for sub-structs, and currently only definite types can be marshaled
(i.e. no `interface{}`).
The following struct annotations are supported:
toml:"Field" Overrides the field's name to output.
omitempty When set, empty values and groups are not emitted.
comment:"comment" Emits a # comment on the same line. This supports new lines.
commented:"true" Emits the value as commented.
Note that pointers are automatically assigned the "omitempty" option, as TOML
explicity does not handle null values (saying instead the label should be
dropped).
Tree structural types and corresponding marshal types:
*Tree (*)struct, (*)map[string]interface{}
[]*Tree (*)[](*)struct, (*)[](*)map[string]interface{}
[]interface{} (as interface{}) (*)[]primitive, (*)[]([]interface{})
interface{} (*)primitive
Tree primitive types and corresponding marshal types:
uint64 uint, uint8-uint64, pointers to same
int64 int, int8-uint64, pointers to same
float64 float32, float64, pointers to same
string string, pointers to same
bool bool, pointers to same
time.Time time.Time{}, pointers to same
*/
func Marshal(v interface{}) ([]byte, error) {
mtype := reflect.TypeOf(v)
if mtype.Kind() != reflect.Struct {
return []byte{}, errors.New("Only a struct can be marshaled to TOML")
}
sval := reflect.ValueOf(v)
if isCustomMarshaler(mtype) {
return callCustomMarshaler(sval)
}
t, err := valueToTree(mtype, sval)
if err != nil {
return []byte{}, err
}
s, err := t.ToTomlString()
return []byte(s), err
}
// Convert given marshal struct or map value to toml tree
func valueToTree(mtype reflect.Type, mval reflect.Value) (*Tree, error) {
if mtype.Kind() == reflect.Ptr {
return valueToTree(mtype.Elem(), mval.Elem())
}
tval := newTree()
switch mtype.Kind() {
case reflect.Struct:
for i := 0; i < mtype.NumField(); i++ {
mtypef, mvalf := mtype.Field(i), mval.Field(i)
opts := tomlOptions(mtypef)
if opts.include && (!opts.omitempty || !isZero(mvalf)) {
val, err := valueToToml(mtypef.Type, mvalf)
if err != nil {
return nil, err
}
tval.Set(opts.name, opts.comment, opts.commented, val)
}
}
case reflect.Map:
for _, key := range mval.MapKeys() {
mvalf := mval.MapIndex(key)
val, err := valueToToml(mtype.Elem(), mvalf)
if err != nil {
return nil, err
}
tval.Set(key.String(), "", false, val)
}
}
return tval, nil
}
// Convert given marshal slice to slice of Toml trees
func valueToTreeSlice(mtype reflect.Type, mval reflect.Value) ([]*Tree, error) {
tval := make([]*Tree, mval.Len(), mval.Len())
for i := 0; i < mval.Len(); i++ {
val, err := valueToTree(mtype.Elem(), mval.Index(i))
if err != nil {
return nil, err
}
tval[i] = val
}
return tval, nil
}
// Convert given marshal slice to slice of toml values
func valueToOtherSlice(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
tval := make([]interface{}, mval.Len(), mval.Len())
for i := 0; i < mval.Len(); i++ {
val, err := valueToToml(mtype.Elem(), mval.Index(i))
if err != nil {
return nil, err
}
tval[i] = val
}
return tval, nil
}
// Convert given marshal value to toml value
func valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
if mtype.Kind() == reflect.Ptr {
return valueToToml(mtype.Elem(), mval.Elem())
}
switch {
case isCustomMarshaler(mtype):
return callCustomMarshaler(mval)
case isTree(mtype):
return valueToTree(mtype, mval)
case isTreeSlice(mtype):
return valueToTreeSlice(mtype, mval)
case isOtherSlice(mtype):
return valueToOtherSlice(mtype, mval)
default:
switch mtype.Kind() {
case reflect.Bool:
return mval.Bool(), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return mval.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return mval.Uint(), nil
case reflect.Float32, reflect.Float64:
return mval.Float(), nil
case reflect.String:
return mval.String(), nil
case reflect.Struct:
return mval.Interface().(time.Time), nil
default:
return nil, fmt.Errorf("Marshal can't handle %v(%v)", mtype, mtype.Kind())
}
}
}
// Unmarshal attempts to unmarshal the Tree into a Go struct pointed by v.
// Neither Unmarshaler interfaces nor UnmarshalTOML functions are supported for
// sub-structs, and only definite types can be unmarshaled.
func (t *Tree) Unmarshal(v interface{}) error {
mtype := reflect.TypeOf(v)
if mtype.Kind() != reflect.Ptr || mtype.Elem().Kind() != reflect.Struct {
return errors.New("Only a pointer to struct can be unmarshaled from TOML")
}
sval, err := valueFromTree(mtype.Elem(), t)
if err != nil {
return err
}
reflect.ValueOf(v).Elem().Set(sval)
return nil
}
// Unmarshal parses the TOML-encoded data and stores the result in the value
// pointed to by v. Behavior is similar to the Go json encoder, except that there
// is no concept of an Unmarshaler interface or UnmarshalTOML function for
// sub-structs, and currently only definite types can be unmarshaled to (i.e. no
// `interface{}`).
//
// The following struct annotations are supported:
//
// toml:"Field" Overrides the field's name to map to.
//
// See Marshal() documentation for types mapping table.
func Unmarshal(data []byte, v interface{}) error {
t, err := LoadReader(bytes.NewReader(data))
if err != nil {
return err
}
return t.Unmarshal(v)
}
// Convert toml tree to marshal struct or map, using marshal type
func valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, error) {
if mtype.Kind() == reflect.Ptr {
return unwrapPointer(mtype, tval)
}
var mval reflect.Value
switch mtype.Kind() {
case reflect.Struct:
mval = reflect.New(mtype).Elem()
for i := 0; i < mtype.NumField(); i++ {
mtypef := mtype.Field(i)
opts := tomlOptions(mtypef)
if opts.include {
baseKey := opts.name
keysToTry := []string{baseKey, strings.ToLower(baseKey), strings.ToTitle(baseKey)}
for _, key := range keysToTry {
exists := tval.Has(key)
if !exists {
continue
}
val := tval.Get(key)
mvalf, err := valueFromToml(mtypef.Type, val)
if err != nil {
return mval, formatError(err, tval.GetPosition(key))
}
mval.Field(i).Set(mvalf)
break
}
}
}
case reflect.Map:
mval = reflect.MakeMap(mtype)
for _, key := range tval.Keys() {
val := tval.Get(key)
mvalf, err := valueFromToml(mtype.Elem(), val)
if err != nil {
return mval, formatError(err, tval.GetPosition(key))
}
mval.SetMapIndex(reflect.ValueOf(key), mvalf)
}
}
return mval, nil
}
// Convert toml value to marshal struct/map slice, using marshal type
func valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) {
mval := reflect.MakeSlice(mtype, len(tval), len(tval))
for i := 0; i < len(tval); i++ {
val, err := valueFromTree(mtype.Elem(), tval[i])
if err != nil {
return mval, err
}
mval.Index(i).Set(val)
}
return mval, nil
}
// Convert toml value to marshal primitive slice, using marshal type
func valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (reflect.Value, error) {
mval := reflect.MakeSlice(mtype, len(tval), len(tval))
for i := 0; i < len(tval); i++ {
val, err := valueFromToml(mtype.Elem(), tval[i])
if err != nil {
return mval, err
}
mval.Index(i).Set(val)
}
return mval, nil
}
// Convert toml value to marshal value, using marshal type
func valueFromToml(mtype reflect.Type, tval interface{}) (reflect.Value, error) {
if mtype.Kind() == reflect.Ptr {
return unwrapPointer(mtype, tval)
}
switch {
case isTree(mtype):
return valueFromTree(mtype, tval.(*Tree))
case isTreeSlice(mtype):
return valueFromTreeSlice(mtype, tval.([]*Tree))
case isOtherSlice(mtype):
return valueFromOtherSlice(mtype, tval.([]interface{}))
default:
switch mtype.Kind() {
case reflect.Bool:
val, ok := tval.(bool)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to bool", tval, tval)
}
return reflect.ValueOf(val), nil
case reflect.Int:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
}
return reflect.ValueOf(int(val)), nil
case reflect.Int8:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
}
return reflect.ValueOf(int8(val)), nil
case reflect.Int16:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
}
return reflect.ValueOf(int16(val)), nil
case reflect.Int32:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
}
return reflect.ValueOf(int32(val)), nil
case reflect.Int64:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
}
return reflect.ValueOf(val), nil
case reflect.Uint:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
}
return reflect.ValueOf(uint(val)), nil
case reflect.Uint8:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
}
return reflect.ValueOf(uint8(val)), nil
case reflect.Uint16:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
}
return reflect.ValueOf(uint16(val)), nil
case reflect.Uint32:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
}
return reflect.ValueOf(uint32(val)), nil
case reflect.Uint64:
val, ok := tval.(int64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
}
return reflect.ValueOf(uint64(val)), nil
case reflect.Float32:
val, ok := tval.(float64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to float", tval, tval)
}
return reflect.ValueOf(float32(val)), nil
case reflect.Float64:
val, ok := tval.(float64)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to float", tval, tval)
}
return reflect.ValueOf(val), nil
case reflect.String:
val, ok := tval.(string)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to string", tval, tval)
}
return reflect.ValueOf(val), nil
case reflect.Struct:
val, ok := tval.(time.Time)
if !ok {
return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to time", tval, tval)
}
return reflect.ValueOf(val), nil
default:
return reflect.ValueOf(nil), fmt.Errorf("Unmarshal can't handle %v(%v)", mtype, mtype.Kind())
}
}
}
func unwrapPointer(mtype reflect.Type, tval interface{}) (reflect.Value, error) {
val, err := valueFromToml(mtype.Elem(), tval)
if err != nil {
return reflect.ValueOf(nil), err
}
mval := reflect.New(mtype.Elem())
mval.Elem().Set(val)
return mval, nil
}
func tomlOptions(vf reflect.StructField) tomlOpts {
tag := vf.Tag.Get("toml")
parse := strings.Split(tag, ",")
var comment string
if c := vf.Tag.Get("comment"); c != "" {
comment = c
}
commented, _ := strconv.ParseBool(vf.Tag.Get("commented"))
result := tomlOpts{name: vf.Name, comment: comment, commented: commented, include: true, omitempty: false}
if parse[0] != "" {
if parse[0] == "-" && len(parse) == 1 {
result.include = false
} else {
result.name = strings.Trim(parse[0], " ")
}
}
if vf.PkgPath != "" {
result.include = false
}
if len(parse) > 1 && strings.Trim(parse[1], " ") == "omitempty" {
result.omitempty = true
}
if vf.Type.Kind() == reflect.Ptr {
result.omitempty = true
}
return result
}
func isZero(val reflect.Value) bool {
switch val.Type().Kind() {
case reflect.Map:
fallthrough
case reflect.Array:
fallthrough
case reflect.Slice:
return val.Len() == 0
default:
return reflect.DeepEqual(val.Interface(), reflect.Zero(val.Type()).Interface())
}
}
func formatError(err error, pos Position) error {
if err.Error()[0] == '(' { // Error already contains position information
return err
}
return fmt.Errorf("%s: %s", pos, err)
}

234
vendor/github.com/pelletier/go-toml/match.go generated vendored Normal file
View File

@ -0,0 +1,234 @@
package toml
import (
"fmt"
)
// support function to set positions for tomlValues
// NOTE: this is done to allow ctx.lastPosition to indicate the start of any
// values returned by the query engines
func tomlValueCheck(node interface{}, ctx *queryContext) interface{} {
switch castNode := node.(type) {
case *tomlValue:
ctx.lastPosition = castNode.position
return castNode.value
case []*TomlTree:
if len(castNode) > 0 {
ctx.lastPosition = castNode[0].position
}
return node
default:
return node
}
}
// base match
type matchBase struct {
next pathFn
}
func (f *matchBase) setNext(next pathFn) {
f.next = next
}
// terminating functor - gathers results
type terminatingFn struct {
// empty
}
func newTerminatingFn() *terminatingFn {
return &terminatingFn{}
}
func (f *terminatingFn) setNext(next pathFn) {
// do nothing
}
func (f *terminatingFn) call(node interface{}, ctx *queryContext) {
switch castNode := node.(type) {
case *TomlTree:
ctx.result.appendResult(node, castNode.position)
case *tomlValue:
ctx.result.appendResult(node, castNode.position)
default:
// use last position for scalars
ctx.result.appendResult(node, ctx.lastPosition)
}
}
// match single key
type matchKeyFn struct {
matchBase
Name string
}
func newMatchKeyFn(name string) *matchKeyFn {
return &matchKeyFn{Name: name}
}
func (f *matchKeyFn) call(node interface{}, ctx *queryContext) {
if array, ok := node.([]*TomlTree); ok {
for _, tree := range array {
item := tree.values[f.Name]
if item != nil {
f.next.call(item, ctx)
}
}
} else if tree, ok := node.(*TomlTree); ok {
item := tree.values[f.Name]
if item != nil {
f.next.call(item, ctx)
}
}
}
// match single index
type matchIndexFn struct {
matchBase
Idx int
}
func newMatchIndexFn(idx int) *matchIndexFn {
return &matchIndexFn{Idx: idx}
}
func (f *matchIndexFn) call(node interface{}, ctx *queryContext) {
if arr, ok := tomlValueCheck(node, ctx).([]interface{}); ok {
if f.Idx < len(arr) && f.Idx >= 0 {
f.next.call(arr[f.Idx], ctx)
}
}
}
// filter by slicing
type matchSliceFn struct {
matchBase
Start, End, Step int
}
func newMatchSliceFn(start, end, step int) *matchSliceFn {
return &matchSliceFn{Start: start, End: end, Step: step}
}
func (f *matchSliceFn) call(node interface{}, ctx *queryContext) {
if arr, ok := tomlValueCheck(node, ctx).([]interface{}); ok {
// adjust indexes for negative values, reverse ordering
realStart, realEnd := f.Start, f.End
if realStart < 0 {
realStart = len(arr) + realStart
}
if realEnd < 0 {
realEnd = len(arr) + realEnd
}
if realEnd < realStart {
realEnd, realStart = realStart, realEnd // swap
}
// loop and gather
for idx := realStart; idx < realEnd; idx += f.Step {
f.next.call(arr[idx], ctx)
}
}
}
// match anything
type matchAnyFn struct {
matchBase
}
func newMatchAnyFn() *matchAnyFn {
return &matchAnyFn{}
}
func (f *matchAnyFn) call(node interface{}, ctx *queryContext) {
if tree, ok := node.(*TomlTree); ok {
for _, v := range tree.values {
f.next.call(v, ctx)
}
}
}
// filter through union
type matchUnionFn struct {
Union []pathFn
}
func (f *matchUnionFn) setNext(next pathFn) {
for _, fn := range f.Union {
fn.setNext(next)
}
}
func (f *matchUnionFn) call(node interface{}, ctx *queryContext) {
for _, fn := range f.Union {
fn.call(node, ctx)
}
}
// match every single last node in the tree
type matchRecursiveFn struct {
matchBase
}
func newMatchRecursiveFn() *matchRecursiveFn {
return &matchRecursiveFn{}
}
func (f *matchRecursiveFn) call(node interface{}, ctx *queryContext) {
if tree, ok := node.(*TomlTree); ok {
var visit func(tree *TomlTree)
visit = func(tree *TomlTree) {
for _, v := range tree.values {
f.next.call(v, ctx)
switch node := v.(type) {
case *TomlTree:
visit(node)
case []*TomlTree:
for _, subtree := range node {
visit(subtree)
}
}
}
}
f.next.call(tree, ctx)
visit(tree)
}
}
// match based on an externally provided functional filter
type matchFilterFn struct {
matchBase
Pos Position
Name string
}
func newMatchFilterFn(name string, pos Position) *matchFilterFn {
return &matchFilterFn{Name: name, Pos: pos}
}
func (f *matchFilterFn) call(node interface{}, ctx *queryContext) {
fn, ok := (*ctx.filters)[f.Name]
if !ok {
panic(fmt.Sprintf("%s: query context does not have filter '%s'",
f.Pos.String(), f.Name))
}
switch castNode := tomlValueCheck(node, ctx).(type) {
case *TomlTree:
for _, v := range castNode.values {
if tv, ok := v.(*tomlValue); ok {
if fn(tv.value) {
f.next.call(v, ctx)
}
} else {
if fn(v) {
f.next.call(v, ctx)
}
}
}
case []interface{}:
for _, v := range castNode {
if fn(v) {
f.next.call(v, ctx)
}
}
}
}

View File

@ -13,9 +13,9 @@ import (
) )
type tomlParser struct { type tomlParser struct {
flowIdx int flow chan token
flow []token tree *TomlTree
tree *Tree tokensBuffer []token
currentTable []string currentTable []string
seenTableKeys []string seenTableKeys []string
} }
@ -34,10 +34,16 @@ func (p *tomlParser) run() {
} }
func (p *tomlParser) peek() *token { func (p *tomlParser) peek() *token {
if p.flowIdx >= len(p.flow) { if len(p.tokensBuffer) != 0 {
return &(p.tokensBuffer[0])
}
tok, ok := <-p.flow
if !ok {
return nil return nil
} }
return &p.flow[p.flowIdx] p.tokensBuffer = append(p.tokensBuffer, tok)
return &tok
} }
func (p *tomlParser) assume(typ tokenType) { func (p *tomlParser) assume(typ tokenType) {
@ -51,12 +57,16 @@ func (p *tomlParser) assume(typ tokenType) {
} }
func (p *tomlParser) getToken() *token { func (p *tomlParser) getToken() *token {
tok := p.peek() if len(p.tokensBuffer) != 0 {
if tok == nil { tok := p.tokensBuffer[0]
p.tokensBuffer = p.tokensBuffer[1:]
return &tok
}
tok, ok := <-p.flow
if !ok {
return nil return nil
} }
p.flowIdx++ return &tok
return tok
} }
func (p *tomlParser) parseStart() tomlParserStateFn { func (p *tomlParser) parseStart() tomlParserStateFn {
@ -96,21 +106,21 @@ func (p *tomlParser) parseGroupArray() tomlParserStateFn {
} }
p.tree.createSubTree(keys[:len(keys)-1], startToken.Position) // create parent entries p.tree.createSubTree(keys[:len(keys)-1], startToken.Position) // create parent entries
destTree := p.tree.GetPath(keys) destTree := p.tree.GetPath(keys)
var array []*Tree var array []*TomlTree
if destTree == nil { if destTree == nil {
array = make([]*Tree, 0) array = make([]*TomlTree, 0)
} else if target, ok := destTree.([]*Tree); ok && target != nil { } else if target, ok := destTree.([]*TomlTree); ok && target != nil {
array = destTree.([]*Tree) array = destTree.([]*TomlTree)
} else { } else {
p.raiseError(key, "key %s is already assigned and not of type table array", key) p.raiseError(key, "key %s is already assigned and not of type table array", key)
} }
p.currentTable = keys p.currentTable = keys
// add a new tree to the end of the table array // add a new tree to the end of the table array
newTree := newTree() newTree := newTomlTree()
newTree.position = startToken.Position newTree.position = startToken.Position
array = append(array, newTree) array = append(array, newTree)
p.tree.SetPath(p.currentTable, "", false, array) p.tree.SetPath(p.currentTable, array)
// remove all keys that were children of this table array // remove all keys that were children of this table array
prefix := key.val + "." prefix := key.val + "."
@ -173,11 +183,11 @@ func (p *tomlParser) parseAssign() tomlParserStateFn {
} }
// find the table to assign, looking out for arrays of tables // find the table to assign, looking out for arrays of tables
var targetNode *Tree var targetNode *TomlTree
switch node := p.tree.GetPath(tableKey).(type) { switch node := p.tree.GetPath(tableKey).(type) {
case []*Tree: case []*TomlTree:
targetNode = node[len(node)-1] targetNode = node[len(node)-1]
case *Tree: case *TomlTree:
targetNode = node targetNode = node
default: default:
p.raiseError(key, "Unknown table type for path: %s", p.raiseError(key, "Unknown table type for path: %s",
@ -202,10 +212,10 @@ func (p *tomlParser) parseAssign() tomlParserStateFn {
var toInsert interface{} var toInsert interface{}
switch value.(type) { switch value.(type) {
case *Tree, []*Tree: case *TomlTree, []*TomlTree:
toInsert = value toInsert = value
default: default:
toInsert = &tomlValue{value: value, position: key.Position} toInsert = &tomlValue{value, key.Position}
} }
targetNode.values[keyVal] = toInsert targetNode.values[keyVal] = toInsert
return p.parseStart return p.parseStart
@ -279,8 +289,8 @@ func tokenIsComma(t *token) bool {
return t != nil && t.typ == tokenComma return t != nil && t.typ == tokenComma
} }
func (p *tomlParser) parseInlineTable() *Tree { func (p *tomlParser) parseInlineTable() *TomlTree {
tree := newTree() tree := newTomlTree()
var previous *token var previous *token
Loop: Loop:
for { for {
@ -299,7 +309,7 @@ Loop:
key := p.getToken() key := p.getToken()
p.assume(tokenEqual) p.assume(tokenEqual)
value := p.parseRvalue() value := p.parseRvalue()
tree.Set(key.val, "", false, value) tree.Set(key.val, value)
case tokenComma: case tokenComma:
if previous == nil { if previous == nil {
p.raiseError(follow, "inline table cannot start with a comma") p.raiseError(follow, "inline table cannot start with a comma")
@ -350,27 +360,27 @@ func (p *tomlParser) parseArray() interface{} {
p.getToken() p.getToken()
} }
} }
// An array of Trees is actually an array of inline // An array of TomlTrees is actually an array of inline
// tables, which is a shorthand for a table array. If the // tables, which is a shorthand for a table array. If the
// array was not converted from []interface{} to []*Tree, // array was not converted from []interface{} to []*TomlTree,
// the two notations would not be equivalent. // the two notations would not be equivalent.
if arrayType == reflect.TypeOf(newTree()) { if arrayType == reflect.TypeOf(newTomlTree()) {
tomlArray := make([]*Tree, len(array)) tomlArray := make([]*TomlTree, len(array))
for i, v := range array { for i, v := range array {
tomlArray[i] = v.(*Tree) tomlArray[i] = v.(*TomlTree)
} }
return tomlArray return tomlArray
} }
return array return array
} }
func parseToml(flow []token) *Tree { func parseToml(flow chan token) *TomlTree {
result := newTree() result := newTomlTree()
result.position = Position{1, 1} result.position = Position{1, 1}
parser := &tomlParser{ parser := &tomlParser{
flowIdx: 0,
flow: flow, flow: flow,
tree: result, tree: result,
tokensBuffer: make([]token, 0),
currentTable: make([]string, 0), currentTable: make([]string, 0),
seenTableKeys: make([]string, 0), seenTableKeys: make([]string, 0),
} }

153
vendor/github.com/pelletier/go-toml/query.go generated vendored Normal file
View File

@ -0,0 +1,153 @@
package toml
import (
"time"
)
// NodeFilterFn represents a user-defined filter function, for use with
// Query.SetFilter().
//
// The return value of the function must indicate if 'node' is to be included
// at this stage of the TOML path. Returning true will include the node, and
// returning false will exclude it.
//
// NOTE: Care should be taken to write script callbacks such that they are safe
// to use from multiple goroutines.
type NodeFilterFn func(node interface{}) bool
// QueryResult is the result of Executing a Query.
type QueryResult struct {
items []interface{}
positions []Position
}
// appends a value/position pair to the result set.
func (r *QueryResult) appendResult(node interface{}, pos Position) {
r.items = append(r.items, node)
r.positions = append(r.positions, pos)
}
// Values is a set of values within a QueryResult. The order of values is not
// guaranteed to be in document order, and may be different each time a query is
// executed.
func (r QueryResult) Values() []interface{} {
values := make([]interface{}, len(r.items))
for i, v := range r.items {
o, ok := v.(*tomlValue)
if ok {
values[i] = o.value
} else {
values[i] = v
}
}
return values
}
// Positions is a set of positions for values within a QueryResult. Each index
// in Positions() corresponds to the entry in Value() of the same index.
func (r QueryResult) Positions() []Position {
return r.positions
}
// runtime context for executing query paths
type queryContext struct {
result *QueryResult
filters *map[string]NodeFilterFn
lastPosition Position
}
// generic path functor interface
type pathFn interface {
setNext(next pathFn)
call(node interface{}, ctx *queryContext)
}
// A Query is the representation of a compiled TOML path. A Query is safe
// for concurrent use by multiple goroutines.
type Query struct {
root pathFn
tail pathFn
filters *map[string]NodeFilterFn
}
func newQuery() *Query {
return &Query{
root: nil,
tail: nil,
filters: &defaultFilterFunctions,
}
}
func (q *Query) appendPath(next pathFn) {
if q.root == nil {
q.root = next
} else {
q.tail.setNext(next)
}
q.tail = next
next.setNext(newTerminatingFn()) // init the next functor
}
// CompileQuery compiles a TOML path expression. The returned Query can be used
// to match elements within a TomlTree and its descendants.
func CompileQuery(path string) (*Query, error) {
return parseQuery(lexQuery(path))
}
// Execute executes a query against a TomlTree, and returns the result of the query.
func (q *Query) Execute(tree *TomlTree) *QueryResult {
result := &QueryResult{
items: []interface{}{},
positions: []Position{},
}
if q.root == nil {
result.appendResult(tree, tree.GetPosition(""))
} else {
ctx := &queryContext{
result: result,
filters: q.filters,
}
q.root.call(tree, ctx)
}
return result
}
// SetFilter sets a user-defined filter function. These may be used inside
// "?(..)" query expressions to filter TOML document elements within a query.
func (q *Query) SetFilter(name string, fn NodeFilterFn) {
if q.filters == &defaultFilterFunctions {
// clone the static table
q.filters = &map[string]NodeFilterFn{}
for k, v := range defaultFilterFunctions {
(*q.filters)[k] = v
}
}
(*q.filters)[name] = fn
}
var defaultFilterFunctions = map[string]NodeFilterFn{
"tree": func(node interface{}) bool {
_, ok := node.(*TomlTree)
return ok
},
"int": func(node interface{}) bool {
_, ok := node.(int64)
return ok
},
"float": func(node interface{}) bool {
_, ok := node.(float64)
return ok
},
"string": func(node interface{}) bool {
_, ok := node.(string)
return ok
},
"time": func(node interface{}) bool {
_, ok := node.(time.Time)
return ok
},
"bool": func(node interface{}) bool {
_, ok := node.(bool)
return ok
},
}

356
vendor/github.com/pelletier/go-toml/querylexer.go generated vendored Normal file
View File

@ -0,0 +1,356 @@
// TOML JSONPath lexer.
//
// Written using the principles developed by Rob Pike in
// http://www.youtube.com/watch?v=HxaD_trXwRE
package toml
import (
"fmt"
"strconv"
"strings"
"unicode/utf8"
)
// Lexer state function
type queryLexStateFn func() queryLexStateFn
// Lexer definition
type queryLexer struct {
input string
start int
pos int
width int
tokens chan token
depth int
line int
col int
stringTerm string
}
func (l *queryLexer) run() {
for state := l.lexVoid; state != nil; {
state = state()
}
close(l.tokens)
}
func (l *queryLexer) nextStart() {
// iterate by runes (utf8 characters)
// search for newlines and advance line/col counts
for i := l.start; i < l.pos; {
r, width := utf8.DecodeRuneInString(l.input[i:])
if r == '\n' {
l.line++
l.col = 1
} else {
l.col++
}
i += width
}
// advance start position to next token
l.start = l.pos
}
func (l *queryLexer) emit(t tokenType) {
l.tokens <- token{
Position: Position{l.line, l.col},
typ: t,
val: l.input[l.start:l.pos],
}
l.nextStart()
}
func (l *queryLexer) emitWithValue(t tokenType, value string) {
l.tokens <- token{
Position: Position{l.line, l.col},
typ: t,
val: value,
}
l.nextStart()
}
func (l *queryLexer) next() rune {
if l.pos >= len(l.input) {
l.width = 0
return eof
}
var r rune
r, l.width = utf8.DecodeRuneInString(l.input[l.pos:])
l.pos += l.width
return r
}
func (l *queryLexer) ignore() {
l.nextStart()
}
func (l *queryLexer) backup() {
l.pos -= l.width
}
func (l *queryLexer) errorf(format string, args ...interface{}) queryLexStateFn {
l.tokens <- token{
Position: Position{l.line, l.col},
typ: tokenError,
val: fmt.Sprintf(format, args...),
}
return nil
}
func (l *queryLexer) peek() rune {
r := l.next()
l.backup()
return r
}
func (l *queryLexer) accept(valid string) bool {
if strings.ContainsRune(valid, l.next()) {
return true
}
l.backup()
return false
}
func (l *queryLexer) follow(next string) bool {
return strings.HasPrefix(l.input[l.pos:], next)
}
func (l *queryLexer) lexVoid() queryLexStateFn {
for {
next := l.peek()
switch next {
case '$':
l.pos++
l.emit(tokenDollar)
continue
case '.':
if l.follow("..") {
l.pos += 2
l.emit(tokenDotDot)
} else {
l.pos++
l.emit(tokenDot)
}
continue
case '[':
l.pos++
l.emit(tokenLeftBracket)
continue
case ']':
l.pos++
l.emit(tokenRightBracket)
continue
case ',':
l.pos++
l.emit(tokenComma)
continue
case '*':
l.pos++
l.emit(tokenStar)
continue
case '(':
l.pos++
l.emit(tokenLeftParen)
continue
case ')':
l.pos++
l.emit(tokenRightParen)
continue
case '?':
l.pos++
l.emit(tokenQuestion)
continue
case ':':
l.pos++
l.emit(tokenColon)
continue
case '\'':
l.ignore()
l.stringTerm = string(next)
return l.lexString
case '"':
l.ignore()
l.stringTerm = string(next)
return l.lexString
}
if isSpace(next) {
l.next()
l.ignore()
continue
}
if isAlphanumeric(next) {
return l.lexKey
}
if next == '+' || next == '-' || isDigit(next) {
return l.lexNumber
}
if l.next() == eof {
break
}
return l.errorf("unexpected char: '%v'", next)
}
l.emit(tokenEOF)
return nil
}
func (l *queryLexer) lexKey() queryLexStateFn {
for {
next := l.peek()
if !isAlphanumeric(next) {
l.emit(tokenKey)
return l.lexVoid
}
if l.next() == eof {
break
}
}
l.emit(tokenEOF)
return nil
}
func (l *queryLexer) lexString() queryLexStateFn {
l.pos++
l.ignore()
growingString := ""
for {
if l.follow(l.stringTerm) {
l.emitWithValue(tokenString, growingString)
l.pos++
l.ignore()
return l.lexVoid
}
if l.follow("\\\"") {
l.pos++
growingString += "\""
} else if l.follow("\\'") {
l.pos++
growingString += "'"
} else if l.follow("\\n") {
l.pos++
growingString += "\n"
} else if l.follow("\\b") {
l.pos++
growingString += "\b"
} else if l.follow("\\f") {
l.pos++
growingString += "\f"
} else if l.follow("\\/") {
l.pos++
growingString += "/"
} else if l.follow("\\t") {
l.pos++
growingString += "\t"
} else if l.follow("\\r") {
l.pos++
growingString += "\r"
} else if l.follow("\\\\") {
l.pos++
growingString += "\\"
} else if l.follow("\\u") {
l.pos += 2
code := ""
for i := 0; i < 4; i++ {
c := l.peek()
l.pos++
if !isHexDigit(c) {
return l.errorf("unfinished unicode escape")
}
code = code + string(c)
}
l.pos--
intcode, err := strconv.ParseInt(code, 16, 32)
if err != nil {
return l.errorf("invalid unicode escape: \\u" + code)
}
growingString += string(rune(intcode))
} else if l.follow("\\U") {
l.pos += 2
code := ""
for i := 0; i < 8; i++ {
c := l.peek()
l.pos++
if !isHexDigit(c) {
return l.errorf("unfinished unicode escape")
}
code = code + string(c)
}
l.pos--
intcode, err := strconv.ParseInt(code, 16, 32)
if err != nil {
return l.errorf("invalid unicode escape: \\u" + code)
}
growingString += string(rune(intcode))
} else if l.follow("\\") {
l.pos++
return l.errorf("invalid escape sequence: \\" + string(l.peek()))
} else {
growingString += string(l.peek())
}
if l.next() == eof {
break
}
}
return l.errorf("unclosed string")
}
func (l *queryLexer) lexNumber() queryLexStateFn {
l.ignore()
if !l.accept("+") {
l.accept("-")
}
pointSeen := false
digitSeen := false
for {
next := l.next()
if next == '.' {
if pointSeen {
return l.errorf("cannot have two dots in one float")
}
if !isDigit(l.peek()) {
return l.errorf("float cannot end with a dot")
}
pointSeen = true
} else if isDigit(next) {
digitSeen = true
} else {
l.backup()
break
}
if pointSeen && !digitSeen {
return l.errorf("cannot start float with a dot")
}
}
if !digitSeen {
return l.errorf("no digit in that number")
}
if pointSeen {
l.emit(tokenFloat)
} else {
l.emit(tokenInteger)
}
return l.lexVoid
}
// Entry point
func lexQuery(input string) chan token {
l := &queryLexer{
input: input,
tokens: make(chan token),
line: 1,
col: 1,
}
go l.run()
return l.tokens
}

275
vendor/github.com/pelletier/go-toml/queryparser.go generated vendored Normal file
View File

@ -0,0 +1,275 @@
/*
Based on the "jsonpath" spec/concept.
http://goessner.net/articles/JsonPath/
https://code.google.com/p/json-path/
*/
package toml
import (
"fmt"
)
const maxInt = int(^uint(0) >> 1)
type queryParser struct {
flow chan token
tokensBuffer []token
query *Query
union []pathFn
err error
}
type queryParserStateFn func() queryParserStateFn
// Formats and panics an error message based on a token
func (p *queryParser) parseError(tok *token, msg string, args ...interface{}) queryParserStateFn {
p.err = fmt.Errorf(tok.Position.String()+": "+msg, args...)
return nil // trigger parse to end
}
func (p *queryParser) run() {
for state := p.parseStart; state != nil; {
state = state()
}
}
func (p *queryParser) backup(tok *token) {
p.tokensBuffer = append(p.tokensBuffer, *tok)
}
func (p *queryParser) peek() *token {
if len(p.tokensBuffer) != 0 {
return &(p.tokensBuffer[0])
}
tok, ok := <-p.flow
if !ok {
return nil
}
p.backup(&tok)
return &tok
}
func (p *queryParser) lookahead(types ...tokenType) bool {
result := true
buffer := []token{}
for _, typ := range types {
tok := p.getToken()
if tok == nil {
result = false
break
}
buffer = append(buffer, *tok)
if tok.typ != typ {
result = false
break
}
}
// add the tokens back to the buffer, and return
p.tokensBuffer = append(p.tokensBuffer, buffer...)
return result
}
func (p *queryParser) getToken() *token {
if len(p.tokensBuffer) != 0 {
tok := p.tokensBuffer[0]
p.tokensBuffer = p.tokensBuffer[1:]
return &tok
}
tok, ok := <-p.flow
if !ok {
return nil
}
return &tok
}
func (p *queryParser) parseStart() queryParserStateFn {
tok := p.getToken()
if tok == nil || tok.typ == tokenEOF {
return nil
}
if tok.typ != tokenDollar {
return p.parseError(tok, "Expected '$' at start of expression")
}
return p.parseMatchExpr
}
// handle '.' prefix, '[]', and '..'
func (p *queryParser) parseMatchExpr() queryParserStateFn {
tok := p.getToken()
switch tok.typ {
case tokenDotDot:
p.query.appendPath(&matchRecursiveFn{})
// nested parse for '..'
tok := p.getToken()
switch tok.typ {
case tokenKey:
p.query.appendPath(newMatchKeyFn(tok.val))
return p.parseMatchExpr
case tokenLeftBracket:
return p.parseBracketExpr
case tokenStar:
// do nothing - the recursive predicate is enough
return p.parseMatchExpr
}
case tokenDot:
// nested parse for '.'
tok := p.getToken()
switch tok.typ {
case tokenKey:
p.query.appendPath(newMatchKeyFn(tok.val))
return p.parseMatchExpr
case tokenStar:
p.query.appendPath(&matchAnyFn{})
return p.parseMatchExpr
}
case tokenLeftBracket:
return p.parseBracketExpr
case tokenEOF:
return nil // allow EOF at this stage
}
return p.parseError(tok, "expected match expression")
}
func (p *queryParser) parseBracketExpr() queryParserStateFn {
if p.lookahead(tokenInteger, tokenColon) {
return p.parseSliceExpr
}
if p.peek().typ == tokenColon {
return p.parseSliceExpr
}
return p.parseUnionExpr
}
func (p *queryParser) parseUnionExpr() queryParserStateFn {
var tok *token
// this state can be traversed after some sub-expressions
// so be careful when setting up state in the parser
if p.union == nil {
p.union = []pathFn{}
}
loop: // labeled loop for easy breaking
for {
if len(p.union) > 0 {
// parse delimiter or terminator
tok = p.getToken()
switch tok.typ {
case tokenComma:
// do nothing
case tokenRightBracket:
break loop
default:
return p.parseError(tok, "expected ',' or ']', not '%s'", tok.val)
}
}
// parse sub expression
tok = p.getToken()
switch tok.typ {
case tokenInteger:
p.union = append(p.union, newMatchIndexFn(tok.Int()))
case tokenKey:
p.union = append(p.union, newMatchKeyFn(tok.val))
case tokenString:
p.union = append(p.union, newMatchKeyFn(tok.val))
case tokenQuestion:
return p.parseFilterExpr
default:
return p.parseError(tok, "expected union sub expression, not '%s', %d", tok.val, len(p.union))
}
}
// if there is only one sub-expression, use that instead
if len(p.union) == 1 {
p.query.appendPath(p.union[0])
} else {
p.query.appendPath(&matchUnionFn{p.union})
}
p.union = nil // clear out state
return p.parseMatchExpr
}
func (p *queryParser) parseSliceExpr() queryParserStateFn {
// init slice to grab all elements
start, end, step := 0, maxInt, 1
// parse optional start
tok := p.getToken()
if tok.typ == tokenInteger {
start = tok.Int()
tok = p.getToken()
}
if tok.typ != tokenColon {
return p.parseError(tok, "expected ':'")
}
// parse optional end
tok = p.getToken()
if tok.typ == tokenInteger {
end = tok.Int()
tok = p.getToken()
}
if tok.typ == tokenRightBracket {
p.query.appendPath(newMatchSliceFn(start, end, step))
return p.parseMatchExpr
}
if tok.typ != tokenColon {
return p.parseError(tok, "expected ']' or ':'")
}
// parse optional step
tok = p.getToken()
if tok.typ == tokenInteger {
step = tok.Int()
if step < 0 {
return p.parseError(tok, "step must be a positive value")
}
tok = p.getToken()
}
if tok.typ != tokenRightBracket {
return p.parseError(tok, "expected ']'")
}
p.query.appendPath(newMatchSliceFn(start, end, step))
return p.parseMatchExpr
}
func (p *queryParser) parseFilterExpr() queryParserStateFn {
tok := p.getToken()
if tok.typ != tokenLeftParen {
return p.parseError(tok, "expected left-parenthesis for filter expression")
}
tok = p.getToken()
if tok.typ != tokenKey && tok.typ != tokenString {
return p.parseError(tok, "expected key or string for filter funciton name")
}
name := tok.val
tok = p.getToken()
if tok.typ != tokenRightParen {
return p.parseError(tok, "expected right-parenthesis for filter expression")
}
p.union = append(p.union, newMatchFilterFn(name, tok.Position))
return p.parseUnionExpr
}
func parseQuery(flow chan token) (*Query, error) {
parser := &queryParser{
flow: flow,
tokensBuffer: []token{},
query: newQuery(),
}
parser.run()
return parser.query, parser.err
}

View File

@ -135,6 +135,5 @@ func isDigit(r rune) bool {
func isHexDigit(r rune) bool { func isHexDigit(r rune) bool {
return isDigit(r) || return isDigit(r) ||
(r >= 'a' && r <= 'f') || r == 'A' || r == 'B' || r == 'C' || r == 'D' || r == 'E' || r == 'F'
(r >= 'A' && r <= 'F')
} }

View File

@ -4,50 +4,38 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"runtime" "runtime"
"strings" "strings"
) )
type tomlValue struct { type tomlValue struct {
value interface{} // string, int64, uint64, float64, bool, time.Time, [] of any of this list value interface{} // string, int64, uint64, float64, bool, time.Time, [] of any of this list
comment string position Position
commented bool
position Position
} }
// Tree is the result of the parsing of a TOML file. // TomlTree is the result of the parsing of a TOML file.
type Tree struct { type TomlTree struct {
values map[string]interface{} // string -> *tomlValue, *Tree, []*Tree values map[string]interface{} // string -> *tomlValue, *TomlTree, []*TomlTree
comment string position Position
commented bool
position Position
} }
func newTree() *Tree { func newTomlTree() *TomlTree {
return &Tree{ return &TomlTree{
values: make(map[string]interface{}), values: make(map[string]interface{}),
position: Position{}, position: Position{},
} }
} }
// TreeFromMap initializes a new Tree object using the given map. // TreeFromMap initializes a new TomlTree object using the given map.
func TreeFromMap(m map[string]interface{}) (*Tree, error) { func TreeFromMap(m map[string]interface{}) *TomlTree {
result, err := toTree(m) return &TomlTree{
if err != nil { values: m,
return nil, err
} }
return result.(*Tree), nil
}
// Position returns the position of the tree.
func (t *Tree) Position() Position {
return t.position
} }
// Has returns a boolean indicating if the given key exists. // Has returns a boolean indicating if the given key exists.
func (t *Tree) Has(key string) bool { func (t *TomlTree) Has(key string) bool {
if key == "" { if key == "" {
return false return false
} }
@ -55,26 +43,25 @@ func (t *Tree) Has(key string) bool {
} }
// HasPath returns true if the given path of keys exists, false otherwise. // HasPath returns true if the given path of keys exists, false otherwise.
func (t *Tree) HasPath(keys []string) bool { func (t *TomlTree) HasPath(keys []string) bool {
return t.GetPath(keys) != nil return t.GetPath(keys) != nil
} }
// Keys returns the keys of the toplevel tree (does not recurse). // Keys returns the keys of the toplevel tree.
func (t *Tree) Keys() []string { // Warning: this is a costly operation.
keys := make([]string, len(t.values)) func (t *TomlTree) Keys() []string {
i := 0 var keys []string
for k := range t.values { for k := range t.values {
keys[i] = k keys = append(keys, k)
i++
} }
return keys return keys
} }
// Get the value at key in the Tree. // Get the value at key in the TomlTree.
// Key is a dot-separated path (e.g. a.b.c). // Key is a dot-separated path (e.g. a.b.c).
// Returns nil if the path does not exist in the tree. // Returns nil if the path does not exist in the tree.
// If keys is of length zero, the current tree is returned. // If keys is of length zero, the current tree is returned.
func (t *Tree) Get(key string) interface{} { func (t *TomlTree) Get(key string) interface{} {
if key == "" { if key == "" {
return t return t
} }
@ -87,7 +74,7 @@ func (t *Tree) Get(key string) interface{} {
// GetPath returns the element in the tree indicated by 'keys'. // GetPath returns the element in the tree indicated by 'keys'.
// If keys is of length zero, the current tree is returned. // If keys is of length zero, the current tree is returned.
func (t *Tree) GetPath(keys []string) interface{} { func (t *TomlTree) GetPath(keys []string) interface{} {
if len(keys) == 0 { if len(keys) == 0 {
return t return t
} }
@ -98,9 +85,9 @@ func (t *Tree) GetPath(keys []string) interface{} {
return nil return nil
} }
switch node := value.(type) { switch node := value.(type) {
case *Tree: case *TomlTree:
subtree = node subtree = node
case []*Tree: case []*TomlTree:
// go to most recent element // go to most recent element
if len(node) == 0 { if len(node) == 0 {
return nil return nil
@ -120,7 +107,7 @@ func (t *Tree) GetPath(keys []string) interface{} {
} }
// GetPosition returns the position of the given key. // GetPosition returns the position of the given key.
func (t *Tree) GetPosition(key string) Position { func (t *TomlTree) GetPosition(key string) Position {
if key == "" { if key == "" {
return t.position return t.position
} }
@ -129,7 +116,7 @@ func (t *Tree) GetPosition(key string) Position {
// GetPositionPath returns the element in the tree indicated by 'keys'. // GetPositionPath returns the element in the tree indicated by 'keys'.
// If keys is of length zero, the current tree is returned. // If keys is of length zero, the current tree is returned.
func (t *Tree) GetPositionPath(keys []string) Position { func (t *TomlTree) GetPositionPath(keys []string) Position {
if len(keys) == 0 { if len(keys) == 0 {
return t.position return t.position
} }
@ -140,9 +127,9 @@ func (t *Tree) GetPositionPath(keys []string) Position {
return Position{0, 0} return Position{0, 0}
} }
switch node := value.(type) { switch node := value.(type) {
case *Tree: case *TomlTree:
subtree = node subtree = node
case []*Tree: case []*TomlTree:
// go to most recent element // go to most recent element
if len(node) == 0 { if len(node) == 0 {
return Position{0, 0} return Position{0, 0}
@ -156,9 +143,9 @@ func (t *Tree) GetPositionPath(keys []string) Position {
switch node := subtree.values[keys[len(keys)-1]].(type) { switch node := subtree.values[keys[len(keys)-1]].(type) {
case *tomlValue: case *tomlValue:
return node.position return node.position
case *Tree: case *TomlTree:
return node.position return node.position
case []*Tree: case []*TomlTree:
// go to most recent element // go to most recent element
if len(node) == 0 { if len(node) == 0 {
return Position{0, 0} return Position{0, 0}
@ -170,7 +157,7 @@ func (t *Tree) GetPositionPath(keys []string) Position {
} }
// GetDefault works like Get but with a default value // GetDefault works like Get but with a default value
func (t *Tree) GetDefault(key string, def interface{}) interface{} { func (t *TomlTree) GetDefault(key string, def interface{}) interface{} {
val := t.Get(key) val := t.Get(key)
if val == nil { if val == nil {
return def return def
@ -180,30 +167,30 @@ func (t *Tree) GetDefault(key string, def interface{}) interface{} {
// Set an element in the tree. // Set an element in the tree.
// Key is a dot-separated path (e.g. a.b.c). // Key is a dot-separated path (e.g. a.b.c).
// Creates all necessary intermediate trees, if needed. // Creates all necessary intermediates trees, if needed.
func (t *Tree) Set(key string, comment string, commented bool, value interface{}) { func (t *TomlTree) Set(key string, value interface{}) {
t.SetPath(strings.Split(key, "."), comment, commented, value) t.SetPath(strings.Split(key, "."), value)
} }
// SetPath sets an element in the tree. // SetPath sets an element in the tree.
// Keys is an array of path elements (e.g. {"a","b","c"}). // Keys is an array of path elements (e.g. {"a","b","c"}).
// Creates all necessary intermediate trees, if needed. // Creates all necessary intermediates trees, if needed.
func (t *Tree) SetPath(keys []string, comment string, commented bool, value interface{}) { func (t *TomlTree) SetPath(keys []string, value interface{}) {
subtree := t subtree := t
for _, intermediateKey := range keys[:len(keys)-1] { for _, intermediateKey := range keys[:len(keys)-1] {
nextTree, exists := subtree.values[intermediateKey] nextTree, exists := subtree.values[intermediateKey]
if !exists { if !exists {
nextTree = newTree() nextTree = newTomlTree()
subtree.values[intermediateKey] = nextTree // add new element here subtree.values[intermediateKey] = nextTree // add new element here
} }
switch node := nextTree.(type) { switch node := nextTree.(type) {
case *Tree: case *TomlTree:
subtree = node subtree = node
case []*Tree: case []*TomlTree:
// go to most recent element // go to most recent element
if len(node) == 0 { if len(node) == 0 {
// create element if it does not exist // create element if it does not exist
subtree.values[intermediateKey] = append(node, newTree()) subtree.values[intermediateKey] = append(node, newTomlTree())
} }
subtree = node[len(node)-1] subtree = node[len(node)-1]
} }
@ -212,18 +199,14 @@ func (t *Tree) SetPath(keys []string, comment string, commented bool, value inte
var toInsert interface{} var toInsert interface{}
switch value.(type) { switch value.(type) {
case *Tree: case *TomlTree:
tt := value.(*Tree)
tt.comment = comment
toInsert = value toInsert = value
case []*Tree: case []*TomlTree:
toInsert = value toInsert = value
case *tomlValue: case *tomlValue:
tt := value.(*tomlValue) toInsert = value
tt.comment = comment
toInsert = tt
default: default:
toInsert = &tomlValue{value: value, comment: comment, commented: commented} toInsert = &tomlValue{value: value}
} }
subtree.values[keys[len(keys)-1]] = toInsert subtree.values[keys[len(keys)-1]] = toInsert
@ -236,21 +219,21 @@ func (t *Tree) SetPath(keys []string, comment string, commented bool, value inte
// and tree[a][b][c] // and tree[a][b][c]
// //
// Returns nil on success, error object on failure // Returns nil on success, error object on failure
func (t *Tree) createSubTree(keys []string, pos Position) error { func (t *TomlTree) createSubTree(keys []string, pos Position) error {
subtree := t subtree := t
for _, intermediateKey := range keys { for _, intermediateKey := range keys {
nextTree, exists := subtree.values[intermediateKey] nextTree, exists := subtree.values[intermediateKey]
if !exists { if !exists {
tree := newTree() tree := newTomlTree()
tree.position = pos tree.position = pos
subtree.values[intermediateKey] = tree subtree.values[intermediateKey] = tree
nextTree = tree nextTree = tree
} }
switch node := nextTree.(type) { switch node := nextTree.(type) {
case []*Tree: case []*TomlTree:
subtree = node[len(node)-1] subtree = node[len(node)-1]
case *Tree: case *TomlTree:
subtree = node subtree = node
default: default:
return fmt.Errorf("unknown type for path %s (%s): %T (%#v)", return fmt.Errorf("unknown type for path %s (%s): %T (%#v)",
@ -260,8 +243,17 @@ func (t *Tree) createSubTree(keys []string, pos Position) error {
return nil return nil
} }
// LoadBytes creates a Tree from a []byte. // Query compiles and executes a query on a tree and returns the query result.
func LoadBytes(b []byte) (tree *Tree, err error) { func (t *TomlTree) Query(query string) (*QueryResult, error) {
q, err := CompileQuery(query)
if err != nil {
return nil, err
}
return q.Execute(t), nil
}
// LoadReader creates a TomlTree from any io.Reader.
func LoadReader(reader io.Reader) (tree *TomlTree, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok { if _, ok := r.(runtime.Error); ok {
@ -270,27 +262,17 @@ func LoadBytes(b []byte) (tree *Tree, err error) {
err = errors.New(r.(string)) err = errors.New(r.(string))
} }
}() }()
tree = parseToml(lexToml(b)) tree = parseToml(lexToml(reader))
return return
} }
// LoadReader creates a Tree from any io.Reader. // Load creates a TomlTree from a string.
func LoadReader(reader io.Reader) (tree *Tree, err error) { func Load(content string) (tree *TomlTree, err error) {
inputBytes, err := ioutil.ReadAll(reader) return LoadReader(strings.NewReader(content))
if err != nil {
return
}
tree, err = LoadBytes(inputBytes)
return
} }
// Load creates a Tree from a string. // LoadFile creates a TomlTree from a file.
func Load(content string) (tree *Tree, err error) { func LoadFile(path string) (tree *TomlTree, err error) {
return LoadBytes([]byte(content))
}
// LoadFile creates a Tree from a file.
func LoadFile(path string) (tree *Tree, err error) {
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -1,142 +0,0 @@
package toml
import (
"fmt"
"reflect"
"time"
)
var kindToType = [reflect.String + 1]reflect.Type{
reflect.Bool: reflect.TypeOf(true),
reflect.String: reflect.TypeOf(""),
reflect.Float32: reflect.TypeOf(float64(1)),
reflect.Float64: reflect.TypeOf(float64(1)),
reflect.Int: reflect.TypeOf(int64(1)),
reflect.Int8: reflect.TypeOf(int64(1)),
reflect.Int16: reflect.TypeOf(int64(1)),
reflect.Int32: reflect.TypeOf(int64(1)),
reflect.Int64: reflect.TypeOf(int64(1)),
reflect.Uint: reflect.TypeOf(uint64(1)),
reflect.Uint8: reflect.TypeOf(uint64(1)),
reflect.Uint16: reflect.TypeOf(uint64(1)),
reflect.Uint32: reflect.TypeOf(uint64(1)),
reflect.Uint64: reflect.TypeOf(uint64(1)),
}
// typeFor returns a reflect.Type for a reflect.Kind, or nil if none is found.
// supported values:
// string, bool, int64, uint64, float64, time.Time, int, int8, int16, int32, uint, uint8, uint16, uint32, float32
func typeFor(k reflect.Kind) reflect.Type {
if k > 0 && int(k) < len(kindToType) {
return kindToType[k]
}
return nil
}
func simpleValueCoercion(object interface{}) (interface{}, error) {
switch original := object.(type) {
case string, bool, int64, uint64, float64, time.Time:
return original, nil
case int:
return int64(original), nil
case int8:
return int64(original), nil
case int16:
return int64(original), nil
case int32:
return int64(original), nil
case uint:
return uint64(original), nil
case uint8:
return uint64(original), nil
case uint16:
return uint64(original), nil
case uint32:
return uint64(original), nil
case float32:
return float64(original), nil
case fmt.Stringer:
return original.String(), nil
default:
return nil, fmt.Errorf("cannot convert type %T to Tree", object)
}
}
func sliceToTree(object interface{}) (interface{}, error) {
// arrays are a bit tricky, since they can represent either a
// collection of simple values, which is represented by one
// *tomlValue, or an array of tables, which is represented by an
// array of *Tree.
// holding the assumption that this function is called from toTree only when value.Kind() is Array or Slice
value := reflect.ValueOf(object)
insideType := value.Type().Elem()
length := value.Len()
if length > 0 {
insideType = reflect.ValueOf(value.Index(0).Interface()).Type()
}
if insideType.Kind() == reflect.Map {
// this is considered as an array of tables
tablesArray := make([]*Tree, 0, length)
for i := 0; i < length; i++ {
table := value.Index(i)
tree, err := toTree(table.Interface())
if err != nil {
return nil, err
}
tablesArray = append(tablesArray, tree.(*Tree))
}
return tablesArray, nil
}
sliceType := typeFor(insideType.Kind())
if sliceType == nil {
sliceType = insideType
}
arrayValue := reflect.MakeSlice(reflect.SliceOf(sliceType), 0, length)
for i := 0; i < length; i++ {
val := value.Index(i).Interface()
simpleValue, err := simpleValueCoercion(val)
if err != nil {
return nil, err
}
arrayValue = reflect.Append(arrayValue, reflect.ValueOf(simpleValue))
}
return &tomlValue{value: arrayValue.Interface(), position: Position{}}, nil
}
func toTree(object interface{}) (interface{}, error) {
value := reflect.ValueOf(object)
if value.Kind() == reflect.Map {
values := map[string]interface{}{}
keys := value.MapKeys()
for _, key := range keys {
if key.Kind() != reflect.String {
if _, ok := key.Interface().(string); !ok {
return nil, fmt.Errorf("map key needs to be a string, not %T (%v)", key.Interface(), key.Kind())
}
}
v := value.MapIndex(key)
newValue, err := toTree(v.Interface())
if err != nil {
return nil, err
}
values[key.String()] = newValue
}
return &Tree{values: values, position: Position{}}, nil
}
if value.Kind() == reflect.Array || value.Kind() == reflect.Slice {
return sliceToTree(object)
}
simpleValue, err := simpleValueCoercion(object)
if err != nil {
return nil, err
}
return &tomlValue{value: simpleValue, position: Position{}}, nil
}

View File

@ -4,8 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"math"
"reflect"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -14,34 +12,33 @@ import (
// encodes a string to a TOML-compliant string value // encodes a string to a TOML-compliant string value
func encodeTomlString(value string) string { func encodeTomlString(value string) string {
var b bytes.Buffer result := ""
for _, rr := range value { for _, rr := range value {
switch rr { switch rr {
case '\b': case '\b':
b.WriteString(`\b`) result += "\\b"
case '\t': case '\t':
b.WriteString(`\t`) result += "\\t"
case '\n': case '\n':
b.WriteString(`\n`) result += "\\n"
case '\f': case '\f':
b.WriteString(`\f`) result += "\\f"
case '\r': case '\r':
b.WriteString(`\r`) result += "\\r"
case '"': case '"':
b.WriteString(`\"`) result += "\\\""
case '\\': case '\\':
b.WriteString(`\\`) result += "\\\\"
default: default:
intRr := uint16(rr) intRr := uint16(rr)
if intRr < 0x001F { if intRr < 0x001F {
b.WriteString(fmt.Sprintf("\\u%0.4X", intRr)) result += fmt.Sprintf("\\u%0.4X", intRr)
} else { } else {
b.WriteRune(rr) result += string(rr)
} }
} }
} }
return b.String() return result
} }
func tomlValueStringRepresentation(v interface{}) (string, error) { func tomlValueStringRepresentation(v interface{}) (string, error) {
@ -51,17 +48,9 @@ func tomlValueStringRepresentation(v interface{}) (string, error) {
case int64: case int64:
return strconv.FormatInt(value, 10), nil return strconv.FormatInt(value, 10), nil
case float64: case float64:
// Ensure a round float does contain a decimal point. Otherwise feeding
// the output back to the parser would convert to an integer.
if math.Trunc(value) == value {
return strconv.FormatFloat(value, 'f', 1, 32), nil
}
return strconv.FormatFloat(value, 'f', -1, 32), nil return strconv.FormatFloat(value, 'f', -1, 32), nil
case string: case string:
return "\"" + encodeTomlString(value) + "\"", nil return "\"" + encodeTomlString(value) + "\"", nil
case []byte:
b, _ := v.([]byte)
return tomlValueStringRepresentation(string(b))
case bool: case bool:
if value { if value {
return "true", nil return "true", nil
@ -71,14 +60,9 @@ func tomlValueStringRepresentation(v interface{}) (string, error) {
return value.Format(time.RFC3339), nil return value.Format(time.RFC3339), nil
case nil: case nil:
return "", nil return "", nil
} case []interface{}:
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Slice {
values := []string{} values := []string{}
for i := 0; i < rv.Len(); i++ { for _, item := range value {
item := rv.Index(i).Interface()
itemRepr, err := tomlValueStringRepresentation(item) itemRepr, err := tomlValueStringRepresentation(item)
if err != nil { if err != nil {
return "", err return "", err
@ -86,18 +70,19 @@ func tomlValueStringRepresentation(v interface{}) (string, error) {
values = append(values, itemRepr) values = append(values, itemRepr)
} }
return "[" + strings.Join(values, ",") + "]", nil return "[" + strings.Join(values, ",") + "]", nil
default:
return "", fmt.Errorf("unsupported value type %T: %v", value, value)
} }
return "", fmt.Errorf("unsupported value type %T: %v", v, v)
} }
func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) (int64, error) { func (t *TomlTree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) (int64, error) {
simpleValuesKeys := make([]string, 0) simpleValuesKeys := make([]string, 0)
complexValuesKeys := make([]string, 0) complexValuesKeys := make([]string, 0)
for k := range t.values { for k := range t.values {
v := t.values[k] v := t.values[k]
switch v.(type) { switch v.(type) {
case *Tree, []*Tree: case *TomlTree, []*TomlTree:
complexValuesKeys = append(complexValuesKeys, k) complexValuesKeys = append(complexValuesKeys, k)
default: default:
simpleValuesKeys = append(simpleValuesKeys, k) simpleValuesKeys = append(simpleValuesKeys, k)
@ -110,7 +95,7 @@ func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) (
for _, k := range simpleValuesKeys { for _, k := range simpleValuesKeys {
v, ok := t.values[k].(*tomlValue) v, ok := t.values[k].(*tomlValue)
if !ok { if !ok {
return bytesCount, fmt.Errorf("invalid value type at %s: %T", k, t.values[k]) return bytesCount, fmt.Errorf("invalid key type at %s: %T", k, t.values[k])
} }
repr, err := tomlValueStringRepresentation(v.value) repr, err := tomlValueStringRepresentation(v.value)
@ -118,24 +103,8 @@ func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) (
return bytesCount, err return bytesCount, err
} }
if v.comment != "" { kvRepr := fmt.Sprintf("%s%s = %s\n", indent, k, repr)
comment := strings.Replace(v.comment, "\n", "\n"+indent+"#", -1) writtenBytesCount, err := w.Write([]byte(kvRepr))
start := "# "
if strings.HasPrefix(comment, "#") {
start = ""
}
writtenBytesCountComment, errc := writeStrings(w, "\n", indent, start, comment, "\n")
bytesCount += int64(writtenBytesCountComment)
if errc != nil {
return bytesCount, errc
}
}
var commented string
if v.commented {
commented = "# "
}
writtenBytesCount, err := writeStrings(w, indent, commented, k, " = ", repr, "\n")
bytesCount += int64(writtenBytesCount) bytesCount += int64(writtenBytesCount)
if err != nil { if err != nil {
return bytesCount, err return bytesCount, err
@ -149,31 +118,12 @@ func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) (
if keyspace != "" { if keyspace != "" {
combinedKey = keyspace + "." + combinedKey combinedKey = keyspace + "." + combinedKey
} }
var commented string
if t.commented {
commented = "# "
}
switch node := v.(type) { switch node := v.(type) {
// node has to be of those two types given how keys are sorted above // node has to be of those two types given how keys are sorted above
case *Tree: case *TomlTree:
tv, ok := t.values[k].(*Tree) tableName := fmt.Sprintf("\n%s[%s]\n", indent, combinedKey)
if !ok { writtenBytesCount, err := w.Write([]byte(tableName))
return bytesCount, fmt.Errorf("invalid value type at %s: %T", k, t.values[k])
}
if tv.comment != "" {
comment := strings.Replace(tv.comment, "\n", "\n"+indent+"#", -1)
start := "# "
if strings.HasPrefix(comment, "#") {
start = ""
}
writtenBytesCountComment, errc := writeStrings(w, "\n", indent, start, comment)
bytesCount += int64(writtenBytesCountComment)
if errc != nil {
return bytesCount, errc
}
}
writtenBytesCount, err := writeStrings(w, "\n", indent, commented, "[", combinedKey, "]\n")
bytesCount += int64(writtenBytesCount) bytesCount += int64(writtenBytesCount)
if err != nil { if err != nil {
return bytesCount, err return bytesCount, err
@ -182,17 +132,20 @@ func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) (
if err != nil { if err != nil {
return bytesCount, err return bytesCount, err
} }
case []*Tree: case []*TomlTree:
for _, subTree := range node { for _, subTree := range node {
writtenBytesCount, err := writeStrings(w, "\n", indent, commented, "[[", combinedKey, "]]\n") if len(subTree.values) > 0 {
bytesCount += int64(writtenBytesCount) tableArrayName := fmt.Sprintf("\n%s[[%s]]\n", indent, combinedKey)
if err != nil { writtenBytesCount, err := w.Write([]byte(tableArrayName))
return bytesCount, err bytesCount += int64(writtenBytesCount)
} if err != nil {
return bytesCount, err
}
bytesCount, err = subTree.writeTo(w, indent+" ", combinedKey, bytesCount) bytesCount, err = subTree.writeTo(w, indent+" ", combinedKey, bytesCount)
if err != nil { if err != nil {
return bytesCount, err return bytesCount, err
}
} }
} }
} }
@ -201,28 +154,16 @@ func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) (
return bytesCount, nil return bytesCount, nil
} }
func writeStrings(w io.Writer, s ...string) (int, error) { // WriteTo encode the TomlTree as Toml and writes it to the writer w.
var n int
for i := range s {
b, err := io.WriteString(w, s[i])
n += b
if err != nil {
return n, err
}
}
return n, nil
}
// WriteTo encode the Tree as Toml and writes it to the writer w.
// Returns the number of bytes written in case of success, or an error if anything happened. // Returns the number of bytes written in case of success, or an error if anything happened.
func (t *Tree) WriteTo(w io.Writer) (int64, error) { func (t *TomlTree) WriteTo(w io.Writer) (int64, error) {
return t.writeTo(w, "", "", 0) return t.writeTo(w, "", "", 0)
} }
// ToTomlString generates a human-readable representation of the current tree. // ToTomlString generates a human-readable representation of the current tree.
// Output spans multiple lines, and is suitable for ingest by a TOML parser. // Output spans multiple lines, and is suitable for ingest by a TOML parser.
// If the conversion cannot be performed, ToString returns a non-nil error. // If the conversion cannot be performed, ToString returns a non-nil error.
func (t *Tree) ToTomlString() (string, error) { func (t *TomlTree) ToTomlString() (string, error) {
var buf bytes.Buffer var buf bytes.Buffer
_, err := t.WriteTo(&buf) _, err := t.WriteTo(&buf)
if err != nil { if err != nil {
@ -233,35 +174,36 @@ func (t *Tree) ToTomlString() (string, error) {
// String generates a human-readable representation of the current tree. // String generates a human-readable representation of the current tree.
// Alias of ToString. Present to implement the fmt.Stringer interface. // Alias of ToString. Present to implement the fmt.Stringer interface.
func (t *Tree) String() string { func (t *TomlTree) String() string {
result, _ := t.ToTomlString() result, _ := t.ToTomlString()
return result return result
} }
// ToMap recursively generates a representation of the tree using Go built-in structures. // ToMap recursively generates a representation of the tree using Go built-in structures.
// The following types are used: // The following types are used:
// // * uint64
// * bool // * int64
// * float64 // * bool
// * int64 // * string
// * string // * time.Time
// * uint64 // * map[string]interface{} (where interface{} is any of this list)
// * time.Time // * []interface{} (where interface{} is any of this list)
// * map[string]interface{} (where interface{} is any of this list) func (t *TomlTree) ToMap() map[string]interface{} {
// * []interface{} (where interface{} is any of this list)
func (t *Tree) ToMap() map[string]interface{} {
result := map[string]interface{}{} result := map[string]interface{}{}
for k, v := range t.values { for k, v := range t.values {
switch node := v.(type) { switch node := v.(type) {
case []*Tree: case []*TomlTree:
var array []interface{} var array []interface{}
for _, item := range node { for _, item := range node {
array = append(array, item.ToMap()) array = append(array, item.ToMap())
} }
result[k] = array result[k] = array
case *Tree: case *TomlTree:
result[k] = node.ToMap() result[k] = node.ToMap()
case map[string]interface{}:
sub := TreeFromMap(node)
result[k] = sub.ToMap()
case *tomlValue: case *tomlValue:
result[k] = node.value result[k] = node.value
} }

View File

@ -79,14 +79,6 @@ func (f Frame) Format(s fmt.State, verb rune) {
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). // StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) { func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb { switch verb {
case 'v': case 'v':

View File

@ -52,7 +52,7 @@ func validateBasePathName(name string) error {
// On Windows a common mistake would be to provide an absolute OS path // On Windows a common mistake would be to provide an absolute OS path
// We could strip out the base part, but that would not be very portable. // We could strip out the base part, but that would not be very portable.
if filepath.IsAbs(name) { if filepath.IsAbs(name) {
return &os.PathError{Op: "realPath", Path: name, Err: errors.New("got a real OS path instead of a virtual")} return &os.PathError{"realPath", name, errors.New("got a real OS path instead of a virtual")}
} }
return nil return nil
@ -60,14 +60,14 @@ func validateBasePathName(name string) error {
func (b *BasePathFs) Chtimes(name string, atime, mtime time.Time) (err error) { func (b *BasePathFs) Chtimes(name string, atime, mtime time.Time) (err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return &os.PathError{Op: "chtimes", Path: name, Err: err} return &os.PathError{"chtimes", name, err}
} }
return b.source.Chtimes(name, atime, mtime) return b.source.Chtimes(name, atime, mtime)
} }
func (b *BasePathFs) Chmod(name string, mode os.FileMode) (err error) { func (b *BasePathFs) Chmod(name string, mode os.FileMode) (err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return &os.PathError{Op: "chmod", Path: name, Err: err} return &os.PathError{"chmod", name, err}
} }
return b.source.Chmod(name, mode) return b.source.Chmod(name, mode)
} }
@ -78,66 +78,66 @@ func (b *BasePathFs) Name() string {
func (b *BasePathFs) Stat(name string) (fi os.FileInfo, err error) { func (b *BasePathFs) Stat(name string) (fi os.FileInfo, err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return nil, &os.PathError{Op: "stat", Path: name, Err: err} return nil, &os.PathError{"stat", name, err}
} }
return b.source.Stat(name) return b.source.Stat(name)
} }
func (b *BasePathFs) Rename(oldname, newname string) (err error) { func (b *BasePathFs) Rename(oldname, newname string) (err error) {
if oldname, err = b.RealPath(oldname); err != nil { if oldname, err = b.RealPath(oldname); err != nil {
return &os.PathError{Op: "rename", Path: oldname, Err: err} return &os.PathError{"rename", oldname, err}
} }
if newname, err = b.RealPath(newname); err != nil { if newname, err = b.RealPath(newname); err != nil {
return &os.PathError{Op: "rename", Path: newname, Err: err} return &os.PathError{"rename", newname, err}
} }
return b.source.Rename(oldname, newname) return b.source.Rename(oldname, newname)
} }
func (b *BasePathFs) RemoveAll(name string) (err error) { func (b *BasePathFs) RemoveAll(name string) (err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return &os.PathError{Op: "remove_all", Path: name, Err: err} return &os.PathError{"remove_all", name, err}
} }
return b.source.RemoveAll(name) return b.source.RemoveAll(name)
} }
func (b *BasePathFs) Remove(name string) (err error) { func (b *BasePathFs) Remove(name string) (err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return &os.PathError{Op: "remove", Path: name, Err: err} return &os.PathError{"remove", name, err}
} }
return b.source.Remove(name) return b.source.Remove(name)
} }
func (b *BasePathFs) OpenFile(name string, flag int, mode os.FileMode) (f File, err error) { func (b *BasePathFs) OpenFile(name string, flag int, mode os.FileMode) (f File, err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return nil, &os.PathError{Op: "openfile", Path: name, Err: err} return nil, &os.PathError{"openfile", name, err}
} }
return b.source.OpenFile(name, flag, mode) return b.source.OpenFile(name, flag, mode)
} }
func (b *BasePathFs) Open(name string) (f File, err error) { func (b *BasePathFs) Open(name string) (f File, err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: err} return nil, &os.PathError{"open", name, err}
} }
return b.source.Open(name) return b.source.Open(name)
} }
func (b *BasePathFs) Mkdir(name string, mode os.FileMode) (err error) { func (b *BasePathFs) Mkdir(name string, mode os.FileMode) (err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return &os.PathError{Op: "mkdir", Path: name, Err: err} return &os.PathError{"mkdir", name, err}
} }
return b.source.Mkdir(name, mode) return b.source.Mkdir(name, mode)
} }
func (b *BasePathFs) MkdirAll(name string, mode os.FileMode) (err error) { func (b *BasePathFs) MkdirAll(name string, mode os.FileMode) (err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return &os.PathError{Op: "mkdir", Path: name, Err: err} return &os.PathError{"mkdir", name, err}
} }
return b.source.MkdirAll(name, mode) return b.source.MkdirAll(name, mode)
} }
func (b *BasePathFs) Create(name string) (f File, err error) { func (b *BasePathFs) Create(name string) (f File, err error) {
if name, err = b.RealPath(name); err != nil { if name, err = b.RealPath(name); err != nil {
return nil, &os.PathError{Op: "create", Path: name, Err: err} return nil, &os.PathError{"create", name, err}
} }
return b.source.Create(name) return b.source.Create(name)
} }

View File

@ -64,10 +64,15 @@ func (u *CacheOnReadFs) cacheStatus(name string) (state cacheState, fi os.FileIn
return cacheHit, lfi, nil return cacheHit, lfi, nil
} }
if err == syscall.ENOENT || os.IsNotExist(err) { if err == syscall.ENOENT {
return cacheMiss, nil, nil return cacheMiss, nil, nil
} }
var ok bool
if err, ok = err.(*os.PathError); ok {
if err == os.ErrNotExist {
return cacheMiss, nil, nil
}
}
return cacheMiss, nil, err return cacheMiss, nil, err
} }

View File

@ -1,110 +0,0 @@
// Copyright © 2014 Steve Francia <spf@spf13.com>.
// Copyright 2009 The Go Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package afero
import (
"path/filepath"
"sort"
"strings"
)
// Glob returns the names of all files matching pattern or nil
// if there is no matching file. The syntax of patterns is the same
// as in Match. The pattern may describe hierarchical names such as
// /usr/*/bin/ed (assuming the Separator is '/').
//
// Glob ignores file system errors such as I/O errors reading directories.
// The only possible returned error is ErrBadPattern, when pattern
// is malformed.
//
// This was adapted from (http://golang.org/pkg/path/filepath) and uses several
// built-ins from that package.
func Glob(fs Fs, pattern string) (matches []string, err error) {
if !hasMeta(pattern) {
// afero does not support Lstat directly.
if _, err = lstatIfOs(fs, pattern); err != nil {
return nil, nil
}
return []string{pattern}, nil
}
dir, file := filepath.Split(pattern)
switch dir {
case "":
dir = "."
case string(filepath.Separator):
// nothing
default:
dir = dir[0 : len(dir)-1] // chop off trailing separator
}
if !hasMeta(dir) {
return glob(fs, dir, file, nil)
}
var m []string
m, err = Glob(fs, dir)
if err != nil {
return
}
for _, d := range m {
matches, err = glob(fs, d, file, matches)
if err != nil {
return
}
}
return
}
// glob searches for files matching pattern in the directory dir
// and appends them to matches. If the directory cannot be
// opened, it returns the existing matches. New matches are
// added in lexicographical order.
func glob(fs Fs, dir, pattern string, matches []string) (m []string, e error) {
m = matches
fi, err := fs.Stat(dir)
if err != nil {
return
}
if !fi.IsDir() {
return
}
d, err := fs.Open(dir)
if err != nil {
return
}
defer d.Close()
names, _ := d.Readdirnames(-1)
sort.Strings(names)
for _, n := range names {
matched, err := filepath.Match(pattern, n)
if err != nil {
return m, err
}
if matched {
m = append(m, filepath.Join(dir, n))
}
}
return
}
// hasMeta reports whether path contains any of the magic characters
// recognized by Match.
func hasMeta(path string) bool {
// TODO(niemeyer): Should other magic characters be added here?
return strings.IndexAny(path, "*?[") >= 0
}

View File

@ -74,24 +74,14 @@ func CreateDir(name string) *FileData {
} }
func ChangeFileName(f *FileData, newname string) { func ChangeFileName(f *FileData, newname string) {
f.Lock()
f.name = newname f.name = newname
f.Unlock()
} }
func SetMode(f *FileData, mode os.FileMode) { func SetMode(f *FileData, mode os.FileMode) {
f.Lock()
f.mode = mode f.mode = mode
f.Unlock()
} }
func SetModTime(f *FileData, mtime time.Time) { func SetModTime(f *FileData, mtime time.Time) {
f.Lock()
setModTime(f, mtime)
f.Unlock()
}
func setModTime(f *FileData, mtime time.Time) {
f.modtime = mtime f.modtime = mtime
} }
@ -112,7 +102,7 @@ func (f *File) Close() error {
f.fileData.Lock() f.fileData.Lock()
f.closed = true f.closed = true
if !f.readOnly { if !f.readOnly {
setModTime(f.fileData, time.Now()) SetModTime(f.fileData, time.Now())
} }
f.fileData.Unlock() f.fileData.Unlock()
return nil return nil
@ -196,7 +186,7 @@ func (f *File) Truncate(size int64) error {
return ErrFileClosed return ErrFileClosed
} }
if f.readOnly { if f.readOnly {
return &os.PathError{Op: "truncate", Path: f.fileData.name, Err: errors.New("file handle is read only")} return &os.PathError{"truncate", f.fileData.name, errors.New("file handle is read only")}
} }
if size < 0 { if size < 0 {
return ErrOutOfRange return ErrOutOfRange
@ -207,7 +197,7 @@ func (f *File) Truncate(size int64) error {
} else { } else {
f.fileData.data = f.fileData.data[0:size] f.fileData.data = f.fileData.data[0:size]
} }
setModTime(f.fileData, time.Now()) SetModTime(f.fileData, time.Now())
return nil return nil
} }
@ -228,7 +218,7 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
func (f *File) Write(b []byte) (n int, err error) { func (f *File) Write(b []byte) (n int, err error) {
if f.readOnly { if f.readOnly {
return 0, &os.PathError{Op: "write", Path: f.fileData.name, Err: errors.New("file handle is read only")} return 0, &os.PathError{"write", f.fileData.name, errors.New("file handle is read only")}
} }
n = len(b) n = len(b)
cur := atomic.LoadInt64(&f.at) cur := atomic.LoadInt64(&f.at)
@ -246,7 +236,7 @@ func (f *File) Write(b []byte) (n int, err error) {
f.fileData.data = append(f.fileData.data[:cur], b...) f.fileData.data = append(f.fileData.data[:cur], b...)
f.fileData.data = append(f.fileData.data, tail...) f.fileData.data = append(f.fileData.data, tail...)
} }
setModTime(f.fileData, time.Now()) SetModTime(f.fileData, time.Now())
atomic.StoreInt64(&f.at, int64(len(f.fileData.data))) atomic.StoreInt64(&f.at, int64(len(f.fileData.data)))
return return
@ -271,33 +261,17 @@ type FileInfo struct {
// Implements os.FileInfo // Implements os.FileInfo
func (s *FileInfo) Name() string { func (s *FileInfo) Name() string {
s.Lock()
_, name := filepath.Split(s.name) _, name := filepath.Split(s.name)
s.Unlock()
return name return name
} }
func (s *FileInfo) Mode() os.FileMode { func (s *FileInfo) Mode() os.FileMode { return s.mode }
s.Lock() func (s *FileInfo) ModTime() time.Time { return s.modtime }
defer s.Unlock() func (s *FileInfo) IsDir() bool { return s.dir }
return s.mode func (s *FileInfo) Sys() interface{} { return nil }
}
func (s *FileInfo) ModTime() time.Time {
s.Lock()
defer s.Unlock()
return s.modtime
}
func (s *FileInfo) IsDir() bool {
s.Lock()
defer s.Unlock()
return s.dir
}
func (s *FileInfo) Sys() interface{} { return nil }
func (s *FileInfo) Size() int64 { func (s *FileInfo) Size() int64 {
if s.IsDir() { if s.IsDir() {
return int64(42) return int64(42)
} }
s.Lock()
defer s.Unlock()
return int64(len(s.data)) return int64(len(s.data))
} }

View File

@ -45,7 +45,7 @@ func (m *MemMapFs) getData() map[string]*mem.FileData {
return m.data return m.data
} }
func (*MemMapFs) Name() string { return "MemMapFS" } func (MemMapFs) Name() string { return "MemMapFS" }
func (m *MemMapFs) Create(name string) (File, error) { func (m *MemMapFs) Create(name string) (File, error) {
name = normalizePath(name) name = normalizePath(name)
@ -66,10 +66,7 @@ func (m *MemMapFs) unRegisterWithParent(fileName string) error {
if parent == nil { if parent == nil {
log.Panic("parent of ", f.Name(), " is nil") log.Panic("parent of ", f.Name(), " is nil")
} }
parent.Lock()
mem.RemoveFromMemDir(parent, f) mem.RemoveFromMemDir(parent, f)
parent.Unlock()
return nil return nil
} }
@ -102,10 +99,8 @@ func (m *MemMapFs) registerWithParent(f *mem.FileData) {
} }
} }
parent.Lock()
mem.InitializeDir(parent) mem.InitializeDir(parent)
mem.AddToMemDir(parent, f) mem.AddToMemDir(parent, f)
parent.Unlock()
} }
func (m *MemMapFs) lockfreeMkdir(name string, perm os.FileMode) error { func (m *MemMapFs) lockfreeMkdir(name string, perm os.FileMode) error {
@ -113,7 +108,7 @@ func (m *MemMapFs) lockfreeMkdir(name string, perm os.FileMode) error {
x, ok := m.getData()[name] x, ok := m.getData()[name]
if ok { if ok {
// Only return ErrFileExists if it's a file, not a directory. // Only return ErrFileExists if it's a file, not a directory.
i := mem.FileInfo{FileData: x} i := mem.FileInfo{x}
if !i.IsDir() { if !i.IsDir() {
return ErrFileExists return ErrFileExists
} }
@ -132,17 +127,15 @@ func (m *MemMapFs) Mkdir(name string, perm os.FileMode) error {
_, ok := m.getData()[name] _, ok := m.getData()[name]
m.mu.RUnlock() m.mu.RUnlock()
if ok { if ok {
return &os.PathError{Op: "mkdir", Path: name, Err: ErrFileExists} return &os.PathError{"mkdir", name, ErrFileExists}
} else {
m.mu.Lock()
item := mem.CreateDir(name)
m.getData()[name] = item
m.registerWithParent(item)
m.mu.Unlock()
m.Chmod(name, perm)
} }
m.mu.Lock()
item := mem.CreateDir(name)
m.getData()[name] = item
m.registerWithParent(item)
m.mu.Unlock()
m.Chmod(name, perm|os.ModeDir)
return nil return nil
} }
@ -151,8 +144,9 @@ func (m *MemMapFs) MkdirAll(path string, perm os.FileMode) error {
if err != nil { if err != nil {
if err.(*os.PathError).Err == ErrFileExists { if err.(*os.PathError).Err == ErrFileExists {
return nil return nil
} else {
return err
} }
return err
} }
return nil return nil
} }
@ -194,7 +188,7 @@ func (m *MemMapFs) open(name string) (*mem.FileData, error) {
f, ok := m.getData()[name] f, ok := m.getData()[name]
m.mu.RUnlock() m.mu.RUnlock()
if !ok { if !ok {
return nil, &os.PathError{Op: "open", Path: name, Err: ErrFileNotFound} return nil, &os.PathError{"open", name, ErrFileNotFound}
} }
return f, nil return f, nil
} }
@ -251,11 +245,11 @@ func (m *MemMapFs) Remove(name string) error {
if _, ok := m.getData()[name]; ok { if _, ok := m.getData()[name]; ok {
err := m.unRegisterWithParent(name) err := m.unRegisterWithParent(name)
if err != nil { if err != nil {
return &os.PathError{Op: "remove", Path: name, Err: err} return &os.PathError{"remove", name, err}
} }
delete(m.getData(), name) delete(m.getData(), name)
} else { } else {
return &os.PathError{Op: "remove", Path: name, Err: os.ErrNotExist} return &os.PathError{"remove", name, os.ErrNotExist}
} }
return nil return nil
} }
@ -303,7 +297,7 @@ func (m *MemMapFs) Rename(oldname, newname string) error {
m.mu.Unlock() m.mu.Unlock()
m.mu.RLock() m.mu.RLock()
} else { } else {
return &os.PathError{Op: "rename", Path: oldname, Err: ErrFileNotFound} return &os.PathError{"rename", oldname, ErrFileNotFound}
} }
return nil return nil
} }
@ -319,12 +313,9 @@ func (m *MemMapFs) Stat(name string) (os.FileInfo, error) {
func (m *MemMapFs) Chmod(name string, mode os.FileMode) error { func (m *MemMapFs) Chmod(name string, mode os.FileMode) error {
name = normalizePath(name) name = normalizePath(name)
m.mu.RLock()
f, ok := m.getData()[name] f, ok := m.getData()[name]
m.mu.RUnlock()
if !ok { if !ok {
return &os.PathError{Op: "chmod", Path: name, Err: ErrFileNotFound} return &os.PathError{"chmod", name, ErrFileNotFound}
} }
m.mu.Lock() m.mu.Lock()
@ -336,12 +327,9 @@ func (m *MemMapFs) Chmod(name string, mode os.FileMode) error {
func (m *MemMapFs) Chtimes(name string, atime time.Time, mtime time.Time) error { func (m *MemMapFs) Chtimes(name string, atime time.Time, mtime time.Time) error {
name = normalizePath(name) name = normalizePath(name)
m.mu.RLock()
f, ok := m.getData()[name] f, ok := m.getData()[name]
m.mu.RUnlock()
if !ok { if !ok {
return &os.PathError{Op: "chtimes", Path: name, Err: ErrFileNotFound} return &os.PathError{"chtimes", name, ErrFileNotFound}
} }
m.mu.Lock() m.mu.Lock()
@ -353,7 +341,7 @@ func (m *MemMapFs) Chtimes(name string, atime time.Time, mtime time.Time) error
func (m *MemMapFs) List() { func (m *MemMapFs) List() {
for _, x := range m.data { for _, x := range m.data {
y := mem.FileInfo{FileData: x} y := mem.FileInfo{x}
fmt.Println(x.Name(), y.Size()) fmt.Println(x.Name(), y.Size())
} }
} }

14
vendor/github.com/spf13/afero/memradix.go generated vendored Normal file
View File

@ -0,0 +1,14 @@
// Copyright © 2014 Steve Francia <spf@spf13.com>.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package afero

76
vendor/github.com/spf13/cast/cast.go generated vendored
View File

@ -3,157 +3,81 @@
// Use of this source code is governed by an MIT-style // Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package cast provides easy and safe casting in Go.
package cast package cast
import "time" import "time"
// ToBool casts an interface to a bool type.
func ToBool(i interface{}) bool { func ToBool(i interface{}) bool {
v, _ := ToBoolE(i) v, _ := ToBoolE(i)
return v return v
} }
// ToTime casts an interface to a time.Time type.
func ToTime(i interface{}) time.Time { func ToTime(i interface{}) time.Time {
v, _ := ToTimeE(i) v, _ := ToTimeE(i)
return v return v
} }
// ToDuration casts an interface to a time.Duration type.
func ToDuration(i interface{}) time.Duration { func ToDuration(i interface{}) time.Duration {
v, _ := ToDurationE(i) v, _ := ToDurationE(i)
return v return v
} }
// ToFloat64 casts an interface to a float64 type.
func ToFloat64(i interface{}) float64 { func ToFloat64(i interface{}) float64 {
v, _ := ToFloat64E(i) v, _ := ToFloat64E(i)
return v return v
} }
// ToFloat32 casts an interface to a float32 type.
func ToFloat32(i interface{}) float32 {
v, _ := ToFloat32E(i)
return v
}
// ToInt64 casts an interface to an int64 type.
func ToInt64(i interface{}) int64 { func ToInt64(i interface{}) int64 {
v, _ := ToInt64E(i) v, _ := ToInt64E(i)
return v return v
} }
// ToInt32 casts an interface to an int32 type.
func ToInt32(i interface{}) int32 {
v, _ := ToInt32E(i)
return v
}
// ToInt16 casts an interface to an int16 type.
func ToInt16(i interface{}) int16 {
v, _ := ToInt16E(i)
return v
}
// ToInt8 casts an interface to an int8 type.
func ToInt8(i interface{}) int8 {
v, _ := ToInt8E(i)
return v
}
// ToInt casts an interface to an int type.
func ToInt(i interface{}) int { func ToInt(i interface{}) int {
v, _ := ToIntE(i) v, _ := ToIntE(i)
return v return v
} }
// ToUint casts an interface to a uint type.
func ToUint(i interface{}) uint {
v, _ := ToUintE(i)
return v
}
// ToUint64 casts an interface to a uint64 type.
func ToUint64(i interface{}) uint64 {
v, _ := ToUint64E(i)
return v
}
// ToUint32 casts an interface to a uint32 type.
func ToUint32(i interface{}) uint32 {
v, _ := ToUint32E(i)
return v
}
// ToUint16 casts an interface to a uint16 type.
func ToUint16(i interface{}) uint16 {
v, _ := ToUint16E(i)
return v
}
// ToUint8 casts an interface to a uint8 type.
func ToUint8(i interface{}) uint8 {
v, _ := ToUint8E(i)
return v
}
// ToString casts an interface to a string type.
func ToString(i interface{}) string { func ToString(i interface{}) string {
v, _ := ToStringE(i) v, _ := ToStringE(i)
return v return v
} }
// ToStringMapString casts an interface to a map[string]string type.
func ToStringMapString(i interface{}) map[string]string { func ToStringMapString(i interface{}) map[string]string {
v, _ := ToStringMapStringE(i) v, _ := ToStringMapStringE(i)
return v return v
} }
// ToStringMapStringSlice casts an interface to a map[string][]string type.
func ToStringMapStringSlice(i interface{}) map[string][]string { func ToStringMapStringSlice(i interface{}) map[string][]string {
v, _ := ToStringMapStringSliceE(i) v, _ := ToStringMapStringSliceE(i)
return v return v
} }
// ToStringMapBool casts an interface to a map[string]bool type.
func ToStringMapBool(i interface{}) map[string]bool { func ToStringMapBool(i interface{}) map[string]bool {
v, _ := ToStringMapBoolE(i) v, _ := ToStringMapBoolE(i)
return v return v
} }
// ToStringMap casts an interface to a map[string]interface{} type.
func ToStringMap(i interface{}) map[string]interface{} { func ToStringMap(i interface{}) map[string]interface{} {
v, _ := ToStringMapE(i) v, _ := ToStringMapE(i)
return v return v
} }
// ToSlice casts an interface to a []interface{} type.
func ToSlice(i interface{}) []interface{} { func ToSlice(i interface{}) []interface{} {
v, _ := ToSliceE(i) v, _ := ToSliceE(i)
return v return v
} }
// ToBoolSlice casts an interface to a []bool type.
func ToBoolSlice(i interface{}) []bool { func ToBoolSlice(i interface{}) []bool {
v, _ := ToBoolSliceE(i) v, _ := ToBoolSliceE(i)
return v return v
} }
// ToStringSlice casts an interface to a []string type.
func ToStringSlice(i interface{}) []string { func ToStringSlice(i interface{}) []string {
v, _ := ToStringSliceE(i) v, _ := ToStringSliceE(i)
return v return v
} }
// ToIntSlice casts an interface to a []int type.
func ToIntSlice(i interface{}) []int { func ToIntSlice(i interface{}) []int {
v, _ := ToIntSliceE(i) v, _ := ToIntSliceE(i)
return v return v
} }
// ToDurationSlice casts an interface to a []time.Duration type.
func ToDurationSlice(i interface{}) []time.Duration {
v, _ := ToDurationSliceE(i)
return v
}

779
vendor/github.com/spf13/cast/caste.go generated vendored

File diff suppressed because it is too large Load Diff

View File

@ -1,98 +0,0 @@
package cobra
import (
"fmt"
)
type PositionalArgs func(cmd *Command, args []string) error
// Legacy arg validation has the following behaviour:
// - root commands with no subcommands can take arbitrary arguments
// - root commands with subcommands will do subcommand validity checking
// - subcommands will always accept arbitrary arguments
func legacyArgs(cmd *Command, args []string) error {
// no subcommand, always take args
if !cmd.HasSubCommands() {
return nil
}
// root command with subcommands, do subcommand checking
if !cmd.HasParent() && len(args) > 0 {
return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0]))
}
return nil
}
// NoArgs returns an error if any args are included
func NoArgs(cmd *Command, args []string) error {
if len(args) > 0 {
return fmt.Errorf("unknown command %q for %q", args[0], cmd.CommandPath())
}
return nil
}
// OnlyValidArgs returns an error if any args are not in the list of ValidArgs
func OnlyValidArgs(cmd *Command, args []string) error {
if len(cmd.ValidArgs) > 0 {
for _, v := range args {
if !stringInSlice(v, cmd.ValidArgs) {
return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0]))
}
}
}
return nil
}
func stringInSlice(a string, list []string) bool {
for _, b := range list {
if b == a {
return true
}
}
return false
}
// ArbitraryArgs never returns an error
func ArbitraryArgs(cmd *Command, args []string) error {
return nil
}
// MinimumNArgs returns an error if there is not at least N args
func MinimumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < n {
return fmt.Errorf("requires at least %d arg(s), only received %d", n, len(args))
}
return nil
}
}
// MaximumNArgs returns an error if there are more than N args
func MaximumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) > n {
return fmt.Errorf("accepts at most %d arg(s), received %d", n, len(args))
}
return nil
}
}
// ExactArgs returns an error if there are not exactly n args
func ExactArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) != n {
return fmt.Errorf("accepts %d arg(s), received %d", n, len(args))
}
return nil
}
}
// RangeArgs returns an error if the number of args is not within the expected range
func RangeArgs(min int, max int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < min || len(args) > max {
return fmt.Errorf("accepts between %d and %d arg(s), received %d", min, max, len(args))
}
return nil
}
}

View File

@ -1,7 +1,6 @@
package cobra package cobra
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -11,7 +10,6 @@ import (
"github.com/spf13/pflag" "github.com/spf13/pflag"
) )
// Annotations for Bash completion.
const ( const (
BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extensions" BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extensions"
BashCompCustom = "cobra_annotation_bash_completion_custom" BashCompCustom = "cobra_annotation_bash_completion_custom"
@ -19,9 +17,12 @@ const (
BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir" BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
) )
func writePreamble(buf *bytes.Buffer, name string) { func preamble(out io.Writer, name string) error {
buf.WriteString(fmt.Sprintf("# bash completion for %-36s -*- shell-script -*-\n", name)) _, err := fmt.Fprintf(out, "# bash completion for %-36s -*- shell-script -*-\n", name)
buf.WriteString(` if err != nil {
return err
}
_, err = fmt.Fprint(out, `
__debug() __debug()
{ {
if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then
@ -86,8 +87,8 @@ __handle_reply()
local index flag local index flag
flag="${cur%%=*}" flag="${cur%%=*}"
__index_of_word "${flag}" "${flags_with_completion[@]}" __index_of_word "${flag}" "${flags_with_completion[@]}"
COMPREPLY=()
if [[ ${index} -ge 0 ]]; then if [[ ${index} -ge 0 ]]; then
COMPREPLY=()
PREFIX="" PREFIX=""
cur="${cur#*=}" cur="${cur#*=}"
${flags_completion[${index}]} ${flags_completion[${index}]}
@ -132,10 +133,7 @@ __handle_reply()
declare -F __custom_func >/dev/null && __custom_func declare -F __custom_func >/dev/null && __custom_func
fi fi
# available in bash-completion >= 2, not always present on macOS __ltrim_colon_completions "$cur"
if declare -F __ltrim_colon_completions >/dev/null; then
__ltrim_colon_completions "$cur"
fi
} }
# The arguments should be in the form "ext1|ext2|extn" # The arguments should be in the form "ext1|ext2|extn"
@ -226,7 +224,7 @@ __handle_command()
fi fi
c=$((c+1)) c=$((c+1))
__debug "${FUNCNAME[0]}: looking for ${next_command}" __debug "${FUNCNAME[0]}: looking for ${next_command}"
declare -F "$next_command" >/dev/null && $next_command declare -F $next_command >/dev/null && $next_command
} }
__handle_word() __handle_word()
@ -249,12 +247,16 @@ __handle_word()
} }
`) `)
return err
} }
func writePostscript(buf *bytes.Buffer, name string) { func postscript(w io.Writer, name string) error {
name = strings.Replace(name, ":", "__", -1) name = strings.Replace(name, ":", "__", -1)
buf.WriteString(fmt.Sprintf("__start_%s()\n", name)) _, err := fmt.Fprintf(w, "__start_%s()\n", name)
buf.WriteString(fmt.Sprintf(`{ if err != nil {
return err
}
_, err = fmt.Fprintf(w, `{
local cur prev words cword local cur prev words cword
declare -A flaghash 2>/dev/null || : declare -A flaghash 2>/dev/null || :
if declare -F _init_completion >/dev/null 2>&1; then if declare -F _init_completion >/dev/null 2>&1; then
@ -278,132 +280,197 @@ func writePostscript(buf *bytes.Buffer, name string) {
__handle_word __handle_word
} }
`, name)) `, name)
buf.WriteString(fmt.Sprintf(`if [[ $(type -t compopt) = "builtin" ]]; then if err != nil {
return err
}
_, err = fmt.Fprintf(w, `if [[ $(type -t compopt) = "builtin" ]]; then
complete -o default -F __start_%s %s complete -o default -F __start_%s %s
else else
complete -o default -o nospace -F __start_%s %s complete -o default -o nospace -F __start_%s %s
fi fi
`, name, name, name, name)) `, name, name, name, name)
buf.WriteString("# ex: ts=4 sw=4 et filetype=sh\n") if err != nil {
return err
}
_, err = fmt.Fprintf(w, "# ex: ts=4 sw=4 et filetype=sh\n")
return err
} }
func writeCommands(buf *bytes.Buffer, cmd *Command) { func writeCommands(cmd *Command, w io.Writer) error {
buf.WriteString(" commands=()\n") if _, err := fmt.Fprintf(w, " commands=()\n"); err != nil {
return err
}
for _, c := range cmd.Commands() { for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c == cmd.helpCommand { if !c.IsAvailableCommand() || c == cmd.helpCommand {
continue continue
} }
buf.WriteString(fmt.Sprintf(" commands+=(%q)\n", c.Name())) if _, err := fmt.Fprintf(w, " commands+=(%q)\n", c.Name()); err != nil {
return err
}
} }
buf.WriteString("\n") _, err := fmt.Fprintf(w, "\n")
return err
} }
func writeFlagHandler(buf *bytes.Buffer, name string, annotations map[string][]string) { func writeFlagHandler(name string, annotations map[string][]string, w io.Writer) error {
for key, value := range annotations { for key, value := range annotations {
switch key { switch key {
case BashCompFilenameExt: case BashCompFilenameExt:
buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name)) _, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name)
if err != nil {
var ext string return err
if len(value) > 0 { }
ext = "__handle_filename_extension_flag " + strings.Join(value, "|")
} else { if len(value) > 0 {
ext = "_filedir" ext := "__handle_filename_extension_flag " + strings.Join(value, "|")
_, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
} else {
ext := "_filedir"
_, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
}
if err != nil {
return err
} }
buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", ext))
case BashCompCustom: case BashCompCustom:
buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name)) _, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name)
if err != nil {
return err
}
if len(value) > 0 { if len(value) > 0 {
handlers := strings.Join(value, "; ") handlers := strings.Join(value, "; ")
buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", handlers)) _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", handlers)
} else { } else {
buf.WriteString(" flags_completion+=(:)\n") _, err = fmt.Fprintf(w, " flags_completion+=(:)\n")
}
if err != nil {
return err
} }
case BashCompSubdirsInDir: case BashCompSubdirsInDir:
buf.WriteString(fmt.Sprintf(" flags_with_completion+=(%q)\n", name)) _, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name)
var ext string
if len(value) == 1 { if len(value) == 1 {
ext = "__handle_subdirs_in_dir_flag " + value[0] ext := "__handle_subdirs_in_dir_flag " + value[0]
_, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
} else { } else {
ext = "_filedir -d" ext := "_filedir -d"
_, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
}
if err != nil {
return err
} }
buf.WriteString(fmt.Sprintf(" flags_completion+=(%q)\n", ext))
} }
} }
return nil
} }
func writeShortFlag(buf *bytes.Buffer, flag *pflag.Flag) { func writeShortFlag(flag *pflag.Flag, w io.Writer) error {
b := (len(flag.NoOptDefVal) > 0)
name := flag.Shorthand name := flag.Shorthand
format := " " format := " "
if len(flag.NoOptDefVal) == 0 { if !b {
format += "two_word_" format += "two_word_"
} }
format += "flags+=(\"-%s\")\n" format += "flags+=(\"-%s\")\n"
buf.WriteString(fmt.Sprintf(format, name)) if _, err := fmt.Fprintf(w, format, name); err != nil {
writeFlagHandler(buf, "-"+name, flag.Annotations) return err
}
return writeFlagHandler("-"+name, flag.Annotations, w)
} }
func writeFlag(buf *bytes.Buffer, flag *pflag.Flag) { func writeFlag(flag *pflag.Flag, w io.Writer) error {
b := (len(flag.NoOptDefVal) > 0)
name := flag.Name name := flag.Name
format := " flags+=(\"--%s" format := " flags+=(\"--%s"
if len(flag.NoOptDefVal) == 0 { if !b {
format += "=" format += "="
} }
format += "\")\n" format += "\")\n"
buf.WriteString(fmt.Sprintf(format, name)) if _, err := fmt.Fprintf(w, format, name); err != nil {
writeFlagHandler(buf, "--"+name, flag.Annotations) return err
}
return writeFlagHandler("--"+name, flag.Annotations, w)
} }
func writeLocalNonPersistentFlag(buf *bytes.Buffer, flag *pflag.Flag) { func writeLocalNonPersistentFlag(flag *pflag.Flag, w io.Writer) error {
b := (len(flag.NoOptDefVal) > 0)
name := flag.Name name := flag.Name
format := " local_nonpersistent_flags+=(\"--%s" format := " local_nonpersistent_flags+=(\"--%s"
if len(flag.NoOptDefVal) == 0 { if !b {
format += "=" format += "="
} }
format += "\")\n" format += "\")\n"
buf.WriteString(fmt.Sprintf(format, name)) _, err := fmt.Fprintf(w, format, name)
return err
} }
func writeFlags(buf *bytes.Buffer, cmd *Command) { func writeFlags(cmd *Command, w io.Writer) error {
buf.WriteString(` flags=() _, err := fmt.Fprintf(w, ` flags=()
two_word_flags=() two_word_flags=()
local_nonpersistent_flags=() local_nonpersistent_flags=()
flags_with_completion=() flags_with_completion=()
flags_completion=() flags_completion=()
`) `)
if err != nil {
return err
}
localNonPersistentFlags := cmd.LocalNonPersistentFlags() localNonPersistentFlags := cmd.LocalNonPersistentFlags()
var visitErr error
cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
if nonCompletableFlag(flag) { if nonCompletableFlag(flag) {
return return
} }
writeFlag(buf, flag) if err := writeFlag(flag, w); err != nil {
visitErr = err
return
}
if len(flag.Shorthand) > 0 { if len(flag.Shorthand) > 0 {
writeShortFlag(buf, flag) if err := writeShortFlag(flag, w); err != nil {
visitErr = err
return
}
} }
if localNonPersistentFlags.Lookup(flag.Name) != nil { if localNonPersistentFlags.Lookup(flag.Name) != nil {
writeLocalNonPersistentFlag(buf, flag) if err := writeLocalNonPersistentFlag(flag, w); err != nil {
visitErr = err
return
}
} }
}) })
if visitErr != nil {
return visitErr
}
cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) { cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
if nonCompletableFlag(flag) { if nonCompletableFlag(flag) {
return return
} }
writeFlag(buf, flag) if err := writeFlag(flag, w); err != nil {
visitErr = err
return
}
if len(flag.Shorthand) > 0 { if len(flag.Shorthand) > 0 {
writeShortFlag(buf, flag) if err := writeShortFlag(flag, w); err != nil {
visitErr = err
return
}
} }
}) })
if visitErr != nil {
return visitErr
}
buf.WriteString("\n") _, err = fmt.Fprintf(w, "\n")
return err
} }
func writeRequiredFlag(buf *bytes.Buffer, cmd *Command) { func writeRequiredFlag(cmd *Command, w io.Writer) error {
buf.WriteString(" must_have_one_flag=()\n") if _, err := fmt.Fprintf(w, " must_have_one_flag=()\n"); err != nil {
return err
}
flags := cmd.NonInheritedFlags() flags := cmd.NonInheritedFlags()
var visitErr error
flags.VisitAll(func(flag *pflag.Flag) { flags.VisitAll(func(flag *pflag.Flag) {
if nonCompletableFlag(flag) { if nonCompletableFlag(flag) {
return return
@ -412,93 +479,130 @@ func writeRequiredFlag(buf *bytes.Buffer, cmd *Command) {
switch key { switch key {
case BashCompOneRequiredFlag: case BashCompOneRequiredFlag:
format := " must_have_one_flag+=(\"--%s" format := " must_have_one_flag+=(\"--%s"
if flag.Value.Type() != "bool" { b := (flag.Value.Type() == "bool")
if !b {
format += "=" format += "="
} }
format += "\")\n" format += "\")\n"
buf.WriteString(fmt.Sprintf(format, flag.Name)) if _, err := fmt.Fprintf(w, format, flag.Name); err != nil {
visitErr = err
return
}
if len(flag.Shorthand) > 0 { if len(flag.Shorthand) > 0 {
buf.WriteString(fmt.Sprintf(" must_have_one_flag+=(\"-%s\")\n", flag.Shorthand)) if _, err := fmt.Fprintf(w, " must_have_one_flag+=(\"-%s\")\n", flag.Shorthand); err != nil {
visitErr = err
return
}
} }
} }
} }
}) })
return visitErr
} }
func writeRequiredNouns(buf *bytes.Buffer, cmd *Command) { func writeRequiredNouns(cmd *Command, w io.Writer) error {
buf.WriteString(" must_have_one_noun=()\n") if _, err := fmt.Fprintf(w, " must_have_one_noun=()\n"); err != nil {
return err
}
sort.Sort(sort.StringSlice(cmd.ValidArgs)) sort.Sort(sort.StringSlice(cmd.ValidArgs))
for _, value := range cmd.ValidArgs { for _, value := range cmd.ValidArgs {
buf.WriteString(fmt.Sprintf(" must_have_one_noun+=(%q)\n", value)) if _, err := fmt.Fprintf(w, " must_have_one_noun+=(%q)\n", value); err != nil {
return err
}
} }
return nil
} }
func writeArgAliases(buf *bytes.Buffer, cmd *Command) { func writeArgAliases(cmd *Command, w io.Writer) error {
buf.WriteString(" noun_aliases=()\n") if _, err := fmt.Fprintf(w, " noun_aliases=()\n"); err != nil {
return err
}
sort.Sort(sort.StringSlice(cmd.ArgAliases)) sort.Sort(sort.StringSlice(cmd.ArgAliases))
for _, value := range cmd.ArgAliases { for _, value := range cmd.ArgAliases {
buf.WriteString(fmt.Sprintf(" noun_aliases+=(%q)\n", value)) if _, err := fmt.Fprintf(w, " noun_aliases+=(%q)\n", value); err != nil {
return err
}
} }
return nil
} }
func gen(buf *bytes.Buffer, cmd *Command) { func gen(cmd *Command, w io.Writer) error {
for _, c := range cmd.Commands() { for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c == cmd.helpCommand { if !c.IsAvailableCommand() || c == cmd.helpCommand {
continue continue
} }
gen(buf, c) if err := gen(c, w); err != nil {
return err
}
} }
commandName := cmd.CommandPath() commandName := cmd.CommandPath()
commandName = strings.Replace(commandName, " ", "_", -1) commandName = strings.Replace(commandName, " ", "_", -1)
commandName = strings.Replace(commandName, ":", "__", -1) commandName = strings.Replace(commandName, ":", "__", -1)
buf.WriteString(fmt.Sprintf("_%s()\n{\n", commandName)) if _, err := fmt.Fprintf(w, "_%s()\n{\n", commandName); err != nil {
buf.WriteString(fmt.Sprintf(" last_command=%q\n", commandName)) return err
writeCommands(buf, cmd) }
writeFlags(buf, cmd) if _, err := fmt.Fprintf(w, " last_command=%q\n", commandName); err != nil {
writeRequiredFlag(buf, cmd) return err
writeRequiredNouns(buf, cmd) }
writeArgAliases(buf, cmd) if err := writeCommands(cmd, w); err != nil {
buf.WriteString("}\n\n") return err
}
if err := writeFlags(cmd, w); err != nil {
return err
}
if err := writeRequiredFlag(cmd, w); err != nil {
return err
}
if err := writeRequiredNouns(cmd, w); err != nil {
return err
}
if err := writeArgAliases(cmd, w); err != nil {
return err
}
if _, err := fmt.Fprintf(w, "}\n\n"); err != nil {
return err
}
return nil
} }
// GenBashCompletion generates bash completion file and writes to the passed writer. func (cmd *Command) GenBashCompletion(w io.Writer) error {
func (c *Command) GenBashCompletion(w io.Writer) error { if err := preamble(w, cmd.Name()); err != nil {
buf := new(bytes.Buffer) return err
writePreamble(buf, c.Name())
if len(c.BashCompletionFunction) > 0 {
buf.WriteString(c.BashCompletionFunction + "\n")
} }
gen(buf, c) if len(cmd.BashCompletionFunction) > 0 {
writePostscript(buf, c.Name()) if _, err := fmt.Fprintf(w, "%s\n", cmd.BashCompletionFunction); err != nil {
return err
_, err := buf.WriteTo(w) }
return err }
if err := gen(cmd, w); err != nil {
return err
}
return postscript(w, cmd.Name())
} }
func nonCompletableFlag(flag *pflag.Flag) bool { func nonCompletableFlag(flag *pflag.Flag) bool {
return flag.Hidden || len(flag.Deprecated) > 0 return flag.Hidden || len(flag.Deprecated) > 0
} }
// GenBashCompletionFile generates bash completion file. func (cmd *Command) GenBashCompletionFile(filename string) error {
func (c *Command) GenBashCompletionFile(filename string) error {
outFile, err := os.Create(filename) outFile, err := os.Create(filename)
if err != nil { if err != nil {
return err return err
} }
defer outFile.Close() defer outFile.Close()
return c.GenBashCompletion(outFile) return cmd.GenBashCompletion(outFile)
} }
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag, if it exists. // MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag, if it exists.
func (c *Command) MarkFlagRequired(name string) error { func (cmd *Command) MarkFlagRequired(name string) error {
return MarkFlagRequired(c.Flags(), name) return MarkFlagRequired(cmd.Flags(), name)
} }
// MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag, if it exists. // MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag, if it exists.
func (c *Command) MarkPersistentFlagRequired(name string) error { func (cmd *Command) MarkPersistentFlagRequired(name string) error {
return MarkFlagRequired(c.PersistentFlags(), name) return MarkFlagRequired(cmd.PersistentFlags(), name)
} }
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag in the flag set, if it exists. // MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag in the flag set, if it exists.
@ -508,20 +612,20 @@ func MarkFlagRequired(flags *pflag.FlagSet, name string) error {
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists. // MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func (c *Command) MarkFlagFilename(name string, extensions ...string) error { func (cmd *Command) MarkFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.Flags(), name, extensions...) return MarkFlagFilename(cmd.Flags(), name, extensions...)
} }
// MarkFlagCustom adds the BashCompCustom annotation to the named flag, if it exists. // MarkFlagCustom adds the BashCompCustom annotation to the named flag, if it exists.
// Generated bash autocompletion will call the bash function f for the flag. // Generated bash autocompletion will call the bash function f for the flag.
func (c *Command) MarkFlagCustom(name string, f string) error { func (cmd *Command) MarkFlagCustom(name string, f string) error {
return MarkFlagCustom(c.Flags(), name, f) return MarkFlagCustom(cmd.Flags(), name, f)
} }
// MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists. // MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists.
// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. // Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided.
func (c *Command) MarkPersistentFlagFilename(name string, extensions ...string) error { func (cmd *Command) MarkPersistentFlagFilename(name string, extensions ...string) error {
return MarkFlagFilename(c.PersistentFlags(), name, extensions...) return MarkFlagFilename(cmd.PersistentFlags(), name, extensions...)
} }
// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists. // MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists.

View File

@ -27,19 +27,17 @@ import (
) )
var templateFuncs = template.FuncMap{ var templateFuncs = template.FuncMap{
"trim": strings.TrimSpace, "trim": strings.TrimSpace,
"trimRightSpace": trimRightSpace, "trimRightSpace": trimRightSpace,
"trimTrailingWhitespaces": trimRightSpace, "appendIfNotPresent": appendIfNotPresent,
"appendIfNotPresent": appendIfNotPresent, "rpad": rpad,
"rpad": rpad, "gt": Gt,
"gt": Gt, "eq": Eq,
"eq": Eq,
} }
var initializers []func() var initializers []func()
// EnablePrefixMatching allows to set automatic prefix matching. Automatic prefix matching can be a dangerous thing // Automatic prefix matching can be a dangerous thing to automatically enable in CLI tools.
// to automatically enable in CLI tools.
// Set this to true to enable it. // Set this to true to enable it.
var EnablePrefixMatching = false var EnablePrefixMatching = false
@ -47,22 +45,13 @@ var EnablePrefixMatching = false
// To disable sorting, set it to false. // To disable sorting, set it to false.
var EnableCommandSorting = true var EnableCommandSorting = true
// MousetrapHelpText enables an information splash screen on Windows
// if the CLI is started from explorer.exe.
// To disable the mousetrap, just set this variable to blank string ("").
// Works only on Microsoft Windows.
var MousetrapHelpText string = `This is a command line tool.
You need to open cmd.exe and run it from there.
`
// AddTemplateFunc adds a template function that's available to Usage and Help // AddTemplateFunc adds a template function that's available to Usage and Help
// template generation. // template generation.
func AddTemplateFunc(name string, tmplFunc interface{}) { func AddTemplateFunc(name string, tmplFunc interface{}) {
templateFuncs[name] = tmplFunc templateFuncs[name] = tmplFunc
} }
// AddTemplateFuncs adds multiple template functions that are available to Usage and // AddTemplateFuncs adds multiple template functions availalble to Usage and
// Help template generation. // Help template generation.
func AddTemplateFuncs(tmplFuncs template.FuncMap) { func AddTemplateFuncs(tmplFuncs template.FuncMap) {
for k, v := range tmplFuncs { for k, v := range tmplFuncs {
@ -75,8 +64,6 @@ func OnInitialize(y ...func()) {
initializers = append(initializers, y...) initializers = append(initializers, y...)
} }
// FIXME Gt is unused by cobra and should be removed in a version 2. It exists only for compatibility with users of cobra.
// Gt takes two types and checks whether the first type is greater than the second. In case of types Arrays, Chans, // Gt takes two types and checks whether the first type is greater than the second. In case of types Arrays, Chans,
// Maps and Slices, Gt will compare their lengths. Ints are compared directly while strings are first parsed as // Maps and Slices, Gt will compare their lengths. Ints are compared directly while strings are first parsed as
// ints and then compared. // ints and then compared.
@ -107,8 +94,6 @@ func Gt(a interface{}, b interface{}) bool {
return left > right return left > right
} }
// FIXME Eq is unused by cobra and should be removed in a version 2. It exists only for compatibility with users of cobra.
// Eq takes two types and checks whether they are equal. Supported types are int and string. Unsupported types will panic. // Eq takes two types and checks whether they are equal. Supported types are int and string. Unsupported types will panic.
func Eq(a interface{}, b interface{}) bool { func Eq(a interface{}, b interface{}) bool {
av := reflect.ValueOf(a) av := reflect.ValueOf(a)
@ -129,8 +114,6 @@ func trimRightSpace(s string) string {
return strings.TrimRightFunc(s, unicode.IsSpace) return strings.TrimRightFunc(s, unicode.IsSpace)
} }
// FIXME appendIfNotPresent is unused by cobra and should be removed in a version 2. It exists only for compatibility with users of cobra.
// appendIfNotPresent will append stringToAppend to the end of s, but only if it's not yet present in s. // appendIfNotPresent will append stringToAppend to the end of s, but only if it's not yet present in s.
func appendIfNotPresent(s, stringToAppend string) string { func appendIfNotPresent(s, stringToAppend string) string {
if strings.Contains(s, stringToAppend) { if strings.Contains(s, stringToAppend) {

File diff suppressed because it is too large Load Diff

View File

@ -11,8 +11,14 @@ import (
var preExecHookFn = preExecHook var preExecHookFn = preExecHook
// enables an information splash screen on Windows if the CLI is started from explorer.exe.
var MousetrapHelpText string = `This is a command line tool
You need to open cmd.exe and run it from there.
`
func preExecHook(c *Command) { func preExecHook(c *Command) {
if MousetrapHelpText != "" && mousetrap.StartedByExplorer() { if mousetrap.StartedByExplorer() {
c.Print(MousetrapHelpText) c.Print(MousetrapHelpText)
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
os.Exit(1) os.Exit(1)

View File

@ -1,126 +0,0 @@
package cobra
import (
"bytes"
"fmt"
"io"
"os"
"strings"
)
// GenZshCompletionFile generates zsh completion file.
func (c *Command) GenZshCompletionFile(filename string) error {
outFile, err := os.Create(filename)
if err != nil {
return err
}
defer outFile.Close()
return c.GenZshCompletion(outFile)
}
// GenZshCompletion generates a zsh completion file and writes to the passed writer.
func (c *Command) GenZshCompletion(w io.Writer) error {
buf := new(bytes.Buffer)
writeHeader(buf, c)
maxDepth := maxDepth(c)
writeLevelMapping(buf, maxDepth)
writeLevelCases(buf, maxDepth, c)
_, err := buf.WriteTo(w)
return err
}
func writeHeader(w io.Writer, cmd *Command) {
fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
}
func maxDepth(c *Command) int {
if len(c.Commands()) == 0 {
return 0
}
maxDepthSub := 0
for _, s := range c.Commands() {
subDepth := maxDepth(s)
if subDepth > maxDepthSub {
maxDepthSub = subDepth
}
}
return 1 + maxDepthSub
}
func writeLevelMapping(w io.Writer, numLevels int) {
fmt.Fprintln(w, `_arguments \`)
for i := 1; i <= numLevels; i++ {
fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i)
fmt.Fprintln(w)
}
fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files")
fmt.Fprintln(w)
}
func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
fmt.Fprintln(w, "case $state in")
defer fmt.Fprintln(w, "esac")
for i := 1; i <= maxDepth; i++ {
fmt.Fprintf(w, " level%d)\n", i)
writeLevel(w, root, i)
fmt.Fprintln(w, " ;;")
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
}
func writeLevel(w io.Writer, root *Command, i int) {
fmt.Fprintf(w, " case $words[%d] in\n", i)
defer fmt.Fprintln(w, " esac")
commands := filterByLevel(root, i)
byParent := groupByParent(commands)
for p, c := range byParent {
names := names(c)
fmt.Fprintf(w, " %s)\n", p)
fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
fmt.Fprintln(w, " ;;")
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
}
func filterByLevel(c *Command, l int) []*Command {
cs := make([]*Command, 0)
if l == 0 {
cs = append(cs, c)
return cs
}
for _, s := range c.Commands() {
cs = append(cs, filterByLevel(s, l-1)...)
}
return cs
}
func groupByParent(commands []*Command) map[string][]*Command {
m := make(map[string][]*Command)
for _, c := range commands {
parent := c.Parent()
if parent == nil {
continue
}
m[parent.Name()] = append(m[parent.Name()], c)
}
return m
}
func names(commands []*Command) []string {
ns := make([]string, len(commands))
for i, c := range commands {
ns[i] = c.Name()
}
return ns
}

View File

@ -1,113 +0,0 @@
// Copyright © 2016 Steve Francia <spf@spf13.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package jwalterweatherman
import (
"io"
"io/ioutil"
"log"
"os"
)
var (
TRACE *log.Logger
DEBUG *log.Logger
INFO *log.Logger
WARN *log.Logger
ERROR *log.Logger
CRITICAL *log.Logger
FATAL *log.Logger
LOG *log.Logger
FEEDBACK *Feedback
defaultNotepad *Notepad
)
func reloadDefaultNotepad() {
TRACE = defaultNotepad.TRACE
DEBUG = defaultNotepad.DEBUG
INFO = defaultNotepad.INFO
WARN = defaultNotepad.WARN
ERROR = defaultNotepad.ERROR
CRITICAL = defaultNotepad.CRITICAL
FATAL = defaultNotepad.FATAL
LOG = defaultNotepad.LOG
FEEDBACK = defaultNotepad.FEEDBACK
}
func init() {
defaultNotepad = NewNotepad(LevelError, LevelWarn, os.Stdout, ioutil.Discard, "", log.Ldate|log.Ltime)
reloadDefaultNotepad()
}
// SetLogThreshold set the log threshold for the default notepad. Trace by default.
func SetLogThreshold(threshold Threshold) {
defaultNotepad.SetLogThreshold(threshold)
reloadDefaultNotepad()
}
// SetLogOutput set the log output for the default notepad. Discarded by default.
func SetLogOutput(handle io.Writer) {
defaultNotepad.SetLogOutput(handle)
reloadDefaultNotepad()
}
// SetStdoutThreshold set the standard output threshold for the default notepad.
// Info by default.
func SetStdoutThreshold(threshold Threshold) {
defaultNotepad.SetStdoutThreshold(threshold)
reloadDefaultNotepad()
}
// SetPrefix set the prefix for the default logger. Empty by default.
func SetPrefix(prefix string) {
defaultNotepad.SetPrefix(prefix)
reloadDefaultNotepad()
}
// SetFlags set the flags for the default logger. "log.Ldate | log.Ltime" by default.
func SetFlags(flags int) {
defaultNotepad.SetFlags(flags)
reloadDefaultNotepad()
}
// Level returns the current global log threshold.
func LogThreshold() Threshold {
return defaultNotepad.logThreshold
}
// Level returns the current global output threshold.
func StdoutThreshold() Threshold {
return defaultNotepad.stdoutThreshold
}
// GetStdoutThreshold returns the defined Treshold for the log logger.
func GetLogThreshold() Threshold {
return defaultNotepad.GetLogThreshold()
}
// GetStdoutThreshold returns the Treshold for the stdout logger.
func GetStdoutThreshold() Threshold {
return defaultNotepad.GetStdoutThreshold()
}
// LogCountForLevel returns the number of log invocations for a given threshold.
func LogCountForLevel(l Threshold) uint64 {
return defaultNotepad.LogCountForLevel(l)
}
// LogCountForLevelsGreaterThanorEqualTo returns the number of log invocations
// greater than or equal to a given threshold.
func LogCountForLevelsGreaterThanorEqualTo(threshold Threshold) uint64 {
return defaultNotepad.LogCountForLevelsGreaterThanorEqualTo(threshold)
}
// ResetLogCounters resets the invocation counters for all levels.
func ResetLogCounters() {
defaultNotepad.ResetLogCounters()
}

View File

@ -1,55 +0,0 @@
// Copyright © 2016 Steve Francia <spf@spf13.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package jwalterweatherman
import (
"sync/atomic"
)
type logCounter struct {
counter uint64
}
func (c *logCounter) incr() {
atomic.AddUint64(&c.counter, 1)
}
func (c *logCounter) resetCounter() {
atomic.StoreUint64(&c.counter, 0)
}
func (c *logCounter) getCount() uint64 {
return atomic.LoadUint64(&c.counter)
}
func (c *logCounter) Write(p []byte) (n int, err error) {
c.incr()
return len(p), nil
}
// LogCountForLevel returns the number of log invocations for a given threshold.
func (n *Notepad) LogCountForLevel(l Threshold) uint64 {
return n.logCounters[l].getCount()
}
// LogCountForLevelsGreaterThanorEqualTo returns the number of log invocations
// greater than or equal to a given threshold.
func (n *Notepad) LogCountForLevelsGreaterThanorEqualTo(threshold Threshold) uint64 {
var cnt uint64
for i := int(threshold); i < len(n.logCounters); i++ {
cnt += n.LogCountForLevel(Threshold(i))
}
return cnt
}
// ResetLogCounters resets the invocation counters for all levels.
func (n *Notepad) ResetLogCounters() {
for _, np := range n.logCounters {
np.resetCounter()
}
}

View File

@ -1,194 +0,0 @@
// Copyright © 2016 Steve Francia <spf@spf13.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package jwalterweatherman
import (
"fmt"
"io"
"log"
)
type Threshold int
func (t Threshold) String() string {
return prefixes[t]
}
const (
LevelTrace Threshold = iota
LevelDebug
LevelInfo
LevelWarn
LevelError
LevelCritical
LevelFatal
)
var prefixes map[Threshold]string = map[Threshold]string{
LevelTrace: "TRACE",
LevelDebug: "DEBUG",
LevelInfo: "INFO",
LevelWarn: "WARN",
LevelError: "ERROR",
LevelCritical: "CRITICAL",
LevelFatal: "FATAL",
}
// Notepad is where you leave a note!
type Notepad struct {
TRACE *log.Logger
DEBUG *log.Logger
INFO *log.Logger
WARN *log.Logger
ERROR *log.Logger
CRITICAL *log.Logger
FATAL *log.Logger
LOG *log.Logger
FEEDBACK *Feedback
loggers [7]**log.Logger
logHandle io.Writer
outHandle io.Writer
logThreshold Threshold
stdoutThreshold Threshold
prefix string
flags int
// One per Threshold
logCounters [7]*logCounter
}
// NewNotepad create a new notepad.
func NewNotepad(outThreshold Threshold, logThreshold Threshold, outHandle, logHandle io.Writer, prefix string, flags int) *Notepad {
n := &Notepad{}
n.loggers = [7]**log.Logger{&n.TRACE, &n.DEBUG, &n.INFO, &n.WARN, &n.ERROR, &n.CRITICAL, &n.FATAL}
n.outHandle = outHandle
n.logHandle = logHandle
n.stdoutThreshold = outThreshold
n.logThreshold = logThreshold
if len(prefix) != 0 {
n.prefix = "[" + prefix + "] "
} else {
n.prefix = ""
}
n.flags = flags
n.LOG = log.New(n.logHandle,
"LOG: ",
n.flags)
n.FEEDBACK = &Feedback{out: log.New(outHandle, "", 0), log: n.LOG}
n.init()
return n
}
// init creates the loggers for each level depending on the notepad thresholds.
func (n *Notepad) init() {
logAndOut := io.MultiWriter(n.outHandle, n.logHandle)
for t, logger := range n.loggers {
threshold := Threshold(t)
counter := &logCounter{}
n.logCounters[t] = counter
prefix := n.prefix + threshold.String() + " "
switch {
case threshold >= n.logThreshold && threshold >= n.stdoutThreshold:
*logger = log.New(io.MultiWriter(counter, logAndOut), prefix, n.flags)
case threshold >= n.logThreshold:
*logger = log.New(io.MultiWriter(counter, n.logHandle), prefix, n.flags)
case threshold >= n.stdoutThreshold:
*logger = log.New(io.MultiWriter(counter, n.outHandle), prefix, n.flags)
default:
// counter doesn't care about prefix and flags, so don't use them
// for performance.
*logger = log.New(counter, "", 0)
}
}
}
// SetLogThreshold changes the threshold above which messages are written to the
// log file.
func (n *Notepad) SetLogThreshold(threshold Threshold) {
n.logThreshold = threshold
n.init()
}
// SetLogOutput changes the file where log messages are written.
func (n *Notepad) SetLogOutput(handle io.Writer) {
n.logHandle = handle
n.init()
}
// GetStdoutThreshold returns the defined Treshold for the log logger.
func (n *Notepad) GetLogThreshold() Threshold {
return n.logThreshold
}
// SetStdoutThreshold changes the threshold above which messages are written to the
// standard output.
func (n *Notepad) SetStdoutThreshold(threshold Threshold) {
n.stdoutThreshold = threshold
n.init()
}
// GetStdoutThreshold returns the Treshold for the stdout logger.
func (n *Notepad) GetStdoutThreshold() Threshold {
return n.stdoutThreshold
}
// SetPrefix changes the prefix used by the notepad. Prefixes are displayed between
// brackets at the beginning of the line. An empty prefix won't be displayed at all.
func (n *Notepad) SetPrefix(prefix string) {
if len(prefix) != 0 {
n.prefix = "[" + prefix + "] "
} else {
n.prefix = ""
}
n.init()
}
// SetFlags choose which flags the logger will display (after prefix and message
// level). See the package log for more informations on this.
func (n *Notepad) SetFlags(flags int) {
n.flags = flags
n.init()
}
// Feedback writes plainly to the outHandle while
// logging with the standard extra information (date, file, etc).
type Feedback struct {
out *log.Logger
log *log.Logger
}
func (fb *Feedback) Println(v ...interface{}) {
fb.output(fmt.Sprintln(v...))
}
func (fb *Feedback) Printf(format string, v ...interface{}) {
fb.output(fmt.Sprintf(format, v...))
}
func (fb *Feedback) Print(v ...interface{}) {
fb.output(fmt.Sprint(v...))
}
func (fb *Feedback) output(s string) {
if fb.out != nil {
fb.out.Output(2, s)
}
if fb.log != nil {
fb.log.Output(2, s)
}
}

View File

@ -0,0 +1,256 @@
// Copyright © 2016 Steve Francia <spf@spf13.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package jwalterweatherman
import (
"fmt"
"io"
"io/ioutil"
"log"
"os"
"sync/atomic"
)
// Level describes the chosen log level between
// debug and critical.
type Level int
type NotePad struct {
Handle io.Writer
Level Level
Prefix string
Logger **log.Logger
counter uint64
}
func (n *NotePad) incr() {
atomic.AddUint64(&n.counter, 1)
}
func (n *NotePad) resetCounter() {
atomic.StoreUint64(&n.counter, 0)
}
func (n *NotePad) getCount() uint64 {
return atomic.LoadUint64(&n.counter)
}
type countingWriter struct {
incrFunc func()
}
func (cw *countingWriter) Write(p []byte) (n int, err error) {
cw.incrFunc()
return 0, nil
}
// Feedback is special. It writes plainly to the output while
// logging with the standard extra information (date, file, etc)
// Only Println and Printf are currently provided for this
type Feedback struct{}
const (
LevelTrace Level = iota
LevelDebug
LevelInfo
LevelWarn
LevelError
LevelCritical
LevelFatal
DefaultLogThreshold = LevelWarn
DefaultStdoutThreshold = LevelError
)
var (
TRACE *log.Logger
DEBUG *log.Logger
INFO *log.Logger
WARN *log.Logger
ERROR *log.Logger
CRITICAL *log.Logger
FATAL *log.Logger
LOG *log.Logger
FEEDBACK Feedback
LogHandle io.Writer = ioutil.Discard
OutHandle io.Writer = os.Stdout
BothHandle io.Writer = io.MultiWriter(LogHandle, OutHandle)
NotePads []*NotePad = []*NotePad{trace, debug, info, warn, err, critical, fatal}
trace *NotePad = &NotePad{Level: LevelTrace, Handle: os.Stdout, Logger: &TRACE, Prefix: "TRACE: "}
debug *NotePad = &NotePad{Level: LevelDebug, Handle: os.Stdout, Logger: &DEBUG, Prefix: "DEBUG: "}
info *NotePad = &NotePad{Level: LevelInfo, Handle: os.Stdout, Logger: &INFO, Prefix: "INFO: "}
warn *NotePad = &NotePad{Level: LevelWarn, Handle: os.Stdout, Logger: &WARN, Prefix: "WARN: "}
err *NotePad = &NotePad{Level: LevelError, Handle: os.Stdout, Logger: &ERROR, Prefix: "ERROR: "}
critical *NotePad = &NotePad{Level: LevelCritical, Handle: os.Stdout, Logger: &CRITICAL, Prefix: "CRITICAL: "}
fatal *NotePad = &NotePad{Level: LevelFatal, Handle: os.Stdout, Logger: &FATAL, Prefix: "FATAL: "}
logThreshold Level = DefaultLogThreshold
outputThreshold Level = DefaultStdoutThreshold
)
const (
DATE = log.Ldate
TIME = log.Ltime
SFILE = log.Lshortfile
LFILE = log.Llongfile
MSEC = log.Lmicroseconds
)
var logFlags = DATE | TIME | SFILE
func init() {
SetStdoutThreshold(DefaultStdoutThreshold)
}
// initialize will setup the jWalterWeatherman standard approach of providing the user
// some feedback and logging a potentially different amount based on independent log and output thresholds.
// By default the output has a lower threshold than logged
// Don't use if you have manually set the Handles of the different levels as it will overwrite them.
func initialize() {
BothHandle = io.MultiWriter(LogHandle, OutHandle)
for _, n := range NotePads {
if n.Level < outputThreshold && n.Level < logThreshold {
n.Handle = ioutil.Discard
} else if n.Level >= outputThreshold && n.Level >= logThreshold {
n.Handle = BothHandle
} else if n.Level >= outputThreshold && n.Level < logThreshold {
n.Handle = OutHandle
} else {
n.Handle = LogHandle
}
}
for _, n := range NotePads {
n.Handle = io.MultiWriter(n.Handle, &countingWriter{n.incr})
*n.Logger = log.New(n.Handle, n.Prefix, logFlags)
}
LOG = log.New(LogHandle,
"LOG: ",
logFlags)
}
// Set the log Flags (Available flag: DATE, TIME, SFILE, LFILE and MSEC)
func SetLogFlag(flags int) {
logFlags = flags
initialize()
}
// Level returns the current global log threshold.
func LogThreshold() Level {
return logThreshold
}
// Level returns the current global output threshold.
func StdoutThreshold() Level {
return outputThreshold
}
// Ensures that the level provided is within the bounds of available levels
func levelCheck(level Level) Level {
switch {
case level <= LevelTrace:
return LevelTrace
case level >= LevelFatal:
return LevelFatal
default:
return level
}
}
// Establishes a threshold where anything matching or above will be logged
func SetLogThreshold(level Level) {
logThreshold = levelCheck(level)
initialize()
}
// Establishes a threshold where anything matching or above will be output
func SetStdoutThreshold(level Level) {
outputThreshold = levelCheck(level)
initialize()
}
// Conveniently Sets the Log Handle to a io.writer created for the file behind the given filepath
// Will only append to this file
func SetLogFile(path string) {
file, err := os.OpenFile(path, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0666)
if err != nil {
CRITICAL.Println("Failed to open log file:", path, err)
os.Exit(-1)
}
INFO.Println("Logging to", file.Name())
LogHandle = file
initialize()
}
// Conveniently Creates a temporary file and sets the Log Handle to a io.writer created for it
func UseTempLogFile(prefix string) {
file, err := ioutil.TempFile(os.TempDir(), prefix)
if err != nil {
CRITICAL.Println(err)
}
INFO.Println("Logging to", file.Name())
LogHandle = file
initialize()
}
// LogCountForLevel returns the number of log invocations for a given level.
func LogCountForLevel(l Level) uint64 {
for _, np := range NotePads {
if np.Level == l {
return np.getCount()
}
}
return 0
}
// LogCountForLevelsGreaterThanorEqualTo returns the number of log invocations
// greater than or equal to a given level threshold.
func LogCountForLevelsGreaterThanorEqualTo(threshold Level) uint64 {
var cnt uint64
for _, np := range NotePads {
if np.Level >= threshold {
cnt += np.getCount()
}
}
return cnt
}
// ResetLogCounters resets the invocation counters for all levels.
func ResetLogCounters() {
for _, np := range NotePads {
np.resetCounter()
}
}
// Disables logging for the entire JWW system
func DiscardLogging() {
LogHandle = ioutil.Discard
initialize()
}
// Feedback is special. It writes plainly to the output while
// logging with the standard extra information (date, file, etc)
// Only Println and Printf are currently provided for this
func (fb *Feedback) Println(v ...interface{}) {
s := fmt.Sprintln(v...)
fmt.Print(s)
LOG.Output(2, s)
}
// Feedback is special. It writes plainly to the output while
// logging with the standard extra information (date, file, etc)
// Only Println and Printf are currently provided for this
func (fb *Feedback) Printf(format string, v ...interface{}) {
s := fmt.Sprintf(format, v...)
fmt.Print(s)
LOG.Output(2, s)
}

View File

@ -1,147 +0,0 @@
package pflag
import (
"io"
"strconv"
"strings"
)
// -- boolSlice Value
type boolSliceValue struct {
value *[]bool
changed bool
}
func newBoolSliceValue(val []bool, p *[]bool) *boolSliceValue {
bsv := new(boolSliceValue)
bsv.value = p
*bsv.value = val
return bsv
}
// Set converts, and assigns, the comma-separated boolean argument string representation as the []bool value of this flag.
// If Set is called on a flag that already has a []bool assigned, the newly converted values will be appended.
func (s *boolSliceValue) Set(val string) error {
// remove all quote characters
rmQuote := strings.NewReplacer(`"`, "", `'`, "", "`", "")
// read flag arguments with CSV parser
boolStrSlice, err := readAsCSV(rmQuote.Replace(val))
if err != nil && err != io.EOF {
return err
}
// parse boolean values into slice
out := make([]bool, 0, len(boolStrSlice))
for _, boolStr := range boolStrSlice {
b, err := strconv.ParseBool(strings.TrimSpace(boolStr))
if err != nil {
return err
}
out = append(out, b)
}
if !s.changed {
*s.value = out
} else {
*s.value = append(*s.value, out...)
}
s.changed = true
return nil
}
// Type returns a string that uniquely represents this flag's type.
func (s *boolSliceValue) Type() string {
return "boolSlice"
}
// String defines a "native" format for this boolean slice flag value.
func (s *boolSliceValue) String() string {
boolStrSlice := make([]string, len(*s.value))
for i, b := range *s.value {
boolStrSlice[i] = strconv.FormatBool(b)
}
out, _ := writeAsCSV(boolStrSlice)
return "[" + out + "]"
}
func boolSliceConv(val string) (interface{}, error) {
val = strings.Trim(val, "[]")
// Empty string would cause a slice with one (empty) entry
if len(val) == 0 {
return []bool{}, nil
}
ss := strings.Split(val, ",")
out := make([]bool, len(ss))
for i, t := range ss {
var err error
out[i], err = strconv.ParseBool(t)
if err != nil {
return nil, err
}
}
return out, nil
}
// GetBoolSlice returns the []bool value of a flag with the given name.
func (f *FlagSet) GetBoolSlice(name string) ([]bool, error) {
val, err := f.getFlagType(name, "boolSlice", boolSliceConv)
if err != nil {
return []bool{}, err
}
return val.([]bool), nil
}
// BoolSliceVar defines a boolSlice flag with specified name, default value, and usage string.
// The argument p points to a []bool variable in which to store the value of the flag.
func (f *FlagSet) BoolSliceVar(p *[]bool, name string, value []bool, usage string) {
f.VarP(newBoolSliceValue(value, p), name, "", usage)
}
// BoolSliceVarP is like BoolSliceVar, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) BoolSliceVarP(p *[]bool, name, shorthand string, value []bool, usage string) {
f.VarP(newBoolSliceValue(value, p), name, shorthand, usage)
}
// BoolSliceVar defines a []bool flag with specified name, default value, and usage string.
// The argument p points to a []bool variable in which to store the value of the flag.
func BoolSliceVar(p *[]bool, name string, value []bool, usage string) {
CommandLine.VarP(newBoolSliceValue(value, p), name, "", usage)
}
// BoolSliceVarP is like BoolSliceVar, but accepts a shorthand letter that can be used after a single dash.
func BoolSliceVarP(p *[]bool, name, shorthand string, value []bool, usage string) {
CommandLine.VarP(newBoolSliceValue(value, p), name, shorthand, usage)
}
// BoolSlice defines a []bool flag with specified name, default value, and usage string.
// The return value is the address of a []bool variable that stores the value of the flag.
func (f *FlagSet) BoolSlice(name string, value []bool, usage string) *[]bool {
p := []bool{}
f.BoolSliceVarP(&p, name, "", value, usage)
return &p
}
// BoolSliceP is like BoolSlice, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) BoolSliceP(name, shorthand string, value []bool, usage string) *[]bool {
p := []bool{}
f.BoolSliceVarP(&p, name, shorthand, value, usage)
return &p
}
// BoolSlice defines a []bool flag with specified name, default value, and usage string.
// The return value is the address of a []bool variable that stores the value of the flag.
func BoolSlice(name string, value []bool, usage string) *[]bool {
return CommandLine.BoolSliceP(name, "", value, usage)
}
// BoolSliceP is like BoolSlice, but accepts a shorthand letter that can be used after a single dash.
func BoolSliceP(name, shorthand string, value []bool, usage string) *[]bool {
return CommandLine.BoolSliceP(name, shorthand, value, usage)
}

View File

@ -11,13 +11,13 @@ func newCountValue(val int, p *int) *countValue {
} }
func (i *countValue) Set(s string) error { func (i *countValue) Set(s string) error {
// "+1" means that no specific value was passed, so increment v, err := strconv.ParseInt(s, 0, 64)
if s == "+1" { // -1 means that no specific value was passed, so increment
if v == -1 {
*i = countValue(*i + 1) *i = countValue(*i + 1)
return nil } else {
*i = countValue(v)
} }
v, err := strconv.ParseInt(s, 0, 0)
*i = countValue(v)
return err return err
} }
@ -54,7 +54,7 @@ func (f *FlagSet) CountVar(p *int, name string, usage string) {
// CountVarP is like CountVar only take a shorthand for the flag name. // CountVarP is like CountVar only take a shorthand for the flag name.
func (f *FlagSet) CountVarP(p *int, name, shorthand string, usage string) { func (f *FlagSet) CountVarP(p *int, name, shorthand string, usage string) {
flag := f.VarPF(newCountValue(0, p), name, shorthand, usage) flag := f.VarPF(newCountValue(0, p), name, shorthand, usage)
flag.NoOptDefVal = "+1" flag.NoOptDefVal = "-1"
} }
// CountVar like CountVar only the flag is placed on the CommandLine instead of a given flag set // CountVar like CountVar only the flag is placed on the CommandLine instead of a given flag set
@ -83,9 +83,7 @@ func (f *FlagSet) CountP(name, shorthand string, usage string) *int {
return p return p
} }
// Count defines a count flag with specified name, default value, and usage string. // Count like Count only the flag is placed on the CommandLine isntead of a given flag set
// The return value is the address of an int variable that stores the value of the flag.
// A count flag will add 1 to its value evey time it is found on the command line
func Count(name string, usage string) *int { func Count(name string, usage string) *int {
return CommandLine.CountP(name, "", usage) return CommandLine.CountP(name, "", usage)
} }

395
vendor/github.com/spf13/pflag/flag.go generated vendored
View File

@ -16,9 +16,9 @@ pflag is a drop-in replacement of Go's native flag package. If you import
pflag under the name "flag" then all code should continue to function pflag under the name "flag" then all code should continue to function
with no changes. with no changes.
import flag "github.com/spf13/pflag" import flag "github.com/ogier/pflag"
There is one exception to this: if you directly instantiate the Flag struct There is one exception to this: if you directly instantiate the Flag struct
there is one more field "Shorthand" that you will need to set. there is one more field "Shorthand" that you will need to set.
Most code never instantiates this struct directly, and instead uses Most code never instantiates this struct directly, and instead uses
functions such as String(), BoolVar(), and Var(), and is therefore functions such as String(), BoolVar(), and Var(), and is therefore
@ -134,21 +134,14 @@ type FlagSet struct {
// a custom error handler. // a custom error handler.
Usage func() Usage func()
// SortFlags is used to indicate, if user wants to have sorted flags in
// help/usage messages.
SortFlags bool
name string name string
parsed bool parsed bool
actual map[NormalizedName]*Flag actual map[NormalizedName]*Flag
orderedActual []*Flag
sortedActual []*Flag
formal map[NormalizedName]*Flag formal map[NormalizedName]*Flag
orderedFormal []*Flag
sortedFormal []*Flag
shorthands map[byte]*Flag shorthands map[byte]*Flag
args []string // arguments after flags args []string // arguments after flags
argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no -- argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no --
exitOnError bool // does the program exit if there's an error?
errorHandling ErrorHandling errorHandling ErrorHandling
output io.Writer // nil means stderr; use out() accessor output io.Writer // nil means stderr; use out() accessor
interspersed bool // allow interspersed option/non-option args interspersed bool // allow interspersed option/non-option args
@ -163,7 +156,7 @@ type Flag struct {
Value Value // value as set Value Value // value as set
DefValue string // default value (as text); for usage message DefValue string // default value (as text); for usage message
Changed bool // If the user set the value (or if left to default) Changed bool // If the user set the value (or if left to default)
NoOptDefVal string // default value (as text); if the flag is on the command line without any options NoOptDefVal string //default value (as text); if the flag is on the command line without any options
Deprecated string // If this flag is deprecated, this string is the new or now thing to use Deprecated string // If this flag is deprecated, this string is the new or now thing to use
Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text
ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use
@ -201,19 +194,11 @@ func sortFlags(flags map[NormalizedName]*Flag) []*Flag {
// "--getUrl" which may also be translated to "geturl" and everything will work. // "--getUrl" which may also be translated to "geturl" and everything will work.
func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) { func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) {
f.normalizeNameFunc = n f.normalizeNameFunc = n
f.sortedFormal = f.sortedFormal[:0] for k, v := range f.formal {
for fname, flag := range f.formal { delete(f.formal, k)
nname := f.normalizeFlagName(flag.Name) nname := f.normalizeFlagName(string(k))
if fname == nname { f.formal[nname] = v
continue v.Name = string(nname)
}
flag.Name = string(nname)
delete(f.formal, fname)
f.formal[nname] = flag
if _, set := f.actual[fname]; set {
delete(f.actual, fname)
f.actual[nname] = flag
}
} }
} }
@ -244,25 +229,10 @@ func (f *FlagSet) SetOutput(output io.Writer) {
f.output = output f.output = output
} }
// VisitAll visits the flags in lexicographical order or // VisitAll visits the flags in lexicographical order, calling fn for each.
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits all flags, even those not set. // It visits all flags, even those not set.
func (f *FlagSet) VisitAll(fn func(*Flag)) { func (f *FlagSet) VisitAll(fn func(*Flag)) {
if len(f.formal) == 0 { for _, flag := range sortFlags(f.formal) {
return
}
var flags []*Flag
if f.SortFlags {
if len(f.formal) != len(f.sortedFormal) {
f.sortedFormal = sortFlags(f.formal)
}
flags = f.sortedFormal
} else {
flags = f.orderedFormal
}
for _, flag := range flags {
fn(flag) fn(flag)
} }
} }
@ -283,39 +253,22 @@ func (f *FlagSet) HasAvailableFlags() bool {
return false return false
} }
// VisitAll visits the command-line flags in lexicographical order or // VisitAll visits the command-line flags in lexicographical order, calling
// in primordial order if f.SortFlags is false, calling fn for each. // fn for each. It visits all flags, even those not set.
// It visits all flags, even those not set.
func VisitAll(fn func(*Flag)) { func VisitAll(fn func(*Flag)) {
CommandLine.VisitAll(fn) CommandLine.VisitAll(fn)
} }
// Visit visits the flags in lexicographical order or // Visit visits the flags in lexicographical order, calling fn for each.
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits only those flags that have been set. // It visits only those flags that have been set.
func (f *FlagSet) Visit(fn func(*Flag)) { func (f *FlagSet) Visit(fn func(*Flag)) {
if len(f.actual) == 0 { for _, flag := range sortFlags(f.actual) {
return
}
var flags []*Flag
if f.SortFlags {
if len(f.actual) != len(f.sortedActual) {
f.sortedActual = sortFlags(f.actual)
}
flags = f.sortedActual
} else {
flags = f.orderedActual
}
for _, flag := range flags {
fn(flag) fn(flag)
} }
} }
// Visit visits the command-line flags in lexicographical order or // Visit visits the command-line flags in lexicographical order, calling fn
// in primordial order if f.SortFlags is false, calling fn for each. // for each. It visits only those flags that have been set.
// It visits only those flags that have been set.
func Visit(fn func(*Flag)) { func Visit(fn func(*Flag)) {
CommandLine.Visit(fn) CommandLine.Visit(fn)
} }
@ -325,22 +278,6 @@ func (f *FlagSet) Lookup(name string) *Flag {
return f.lookup(f.normalizeFlagName(name)) return f.lookup(f.normalizeFlagName(name))
} }
// ShorthandLookup returns the Flag structure of the short handed flag,
// returning nil if none exists.
// It panics, if len(name) > 1.
func (f *FlagSet) ShorthandLookup(name string) *Flag {
if name == "" {
return nil
}
if len(name) > 1 {
msg := fmt.Sprintf("can not look up shorthand which is more than one ASCII character: %q", name)
fmt.Fprintf(f.out(), msg)
panic(msg)
}
c := name[0]
return f.shorthands[c]
}
// lookup returns the Flag structure of the named flag, returning nil if none exists. // lookup returns the Flag structure of the named flag, returning nil if none exists.
func (f *FlagSet) lookup(name NormalizedName) *Flag { func (f *FlagSet) lookup(name NormalizedName) *Flag {
return f.formal[name] return f.formal[name]
@ -382,7 +319,7 @@ func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error {
if flag == nil { if flag == nil {
return fmt.Errorf("flag %q does not exist", name) return fmt.Errorf("flag %q does not exist", name)
} }
if usageMessage == "" { if len(usageMessage) == 0 {
return fmt.Errorf("deprecated message for flag %q must be set", name) return fmt.Errorf("deprecated message for flag %q must be set", name)
} }
flag.Deprecated = usageMessage flag.Deprecated = usageMessage
@ -397,7 +334,7 @@ func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) erro
if flag == nil { if flag == nil {
return fmt.Errorf("flag %q does not exist", name) return fmt.Errorf("flag %q does not exist", name)
} }
if usageMessage == "" { if len(usageMessage) == 0 {
return fmt.Errorf("deprecated message for flag %q must be set", name) return fmt.Errorf("deprecated message for flag %q must be set", name)
} }
flag.ShorthandDeprecated = usageMessage flag.ShorthandDeprecated = usageMessage
@ -421,12 +358,6 @@ func Lookup(name string) *Flag {
return CommandLine.Lookup(name) return CommandLine.Lookup(name)
} }
// ShorthandLookup returns the Flag structure of the short handed flag,
// returning nil if none exists.
func ShorthandLookup(name string) *Flag {
return CommandLine.ShorthandLookup(name)
}
// Set sets the value of the named flag. // Set sets the value of the named flag.
func (f *FlagSet) Set(name, value string) error { func (f *FlagSet) Set(name, value string) error {
normalName := f.normalizeFlagName(name) normalName := f.normalizeFlagName(name)
@ -434,30 +365,17 @@ func (f *FlagSet) Set(name, value string) error {
if !ok { if !ok {
return fmt.Errorf("no such flag -%v", name) return fmt.Errorf("no such flag -%v", name)
} }
err := flag.Value.Set(value) err := flag.Value.Set(value)
if err != nil { if err != nil {
var flagName string return err
if flag.Shorthand != "" && flag.ShorthandDeprecated == "" {
flagName = fmt.Sprintf("-%s, --%s", flag.Shorthand, flag.Name)
} else {
flagName = fmt.Sprintf("--%s", flag.Name)
}
return fmt.Errorf("invalid argument %q for %q flag: %v", value, flagName, err)
} }
if f.actual == nil {
if !flag.Changed { f.actual = make(map[NormalizedName]*Flag)
if f.actual == nil {
f.actual = make(map[NormalizedName]*Flag)
}
f.actual[normalName] = flag
f.orderedActual = append(f.orderedActual, flag)
flag.Changed = true
} }
f.actual[normalName] = flag
if flag.Deprecated != "" { flag.Changed = true
fmt.Fprintf(f.out(), "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated) if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
} }
return nil return nil
} }
@ -564,107 +482,36 @@ func UnquoteUsage(flag *Flag) (name string, usage string) {
name = "int" name = "int"
case "uint64": case "uint64":
name = "uint" name = "uint"
case "stringSlice":
name = "strings"
case "intSlice":
name = "ints"
} }
return return
} }
// Splits the string `s` on whitespace into an initial substring up to // FlagUsages Returns a string containing the usage information for all flags in
// `i` runes in length and the remainder. Will go `slop` over `i` if // the FlagSet
// that encompasses the entire string (which allows the caller to func (f *FlagSet) FlagUsages() string {
// avoid short orphan words on the final line). x := new(bytes.Buffer)
func wrapN(i, slop int, s string) (string, string) {
if i+slop > len(s) {
return s, ""
}
w := strings.LastIndexAny(s[:i], " \t")
if w <= 0 {
return s, ""
}
return s[:w], s[w+1:]
}
// Wraps the string `s` to a maximum width `w` with leading indent
// `i`. The first line is not indented (this is assumed to be done by
// caller). Pass `w` == 0 to do no wrapping
func wrap(i, w int, s string) string {
if w == 0 {
return s
}
// space between indent i and end of line width w into which
// we should wrap the text.
wrap := w - i
var r, l string
// Not enough space for sensible wrapping. Wrap as a block on
// the next line instead.
if wrap < 24 {
i = 16
wrap = w - i
r += "\n" + strings.Repeat(" ", i)
}
// If still not enough space then don't even try to wrap.
if wrap < 24 {
return s
}
// Try to avoid short orphan words on the final line, by
// allowing wrapN to go a bit over if that would fit in the
// remainder of the line.
slop := 5
wrap = wrap - slop
// Handle first line, which is indented by the caller (or the
// special case above)
l, s = wrapN(wrap, slop, s)
r = r + l
// Now wrap the rest
for s != "" {
var t string
t, s = wrapN(wrap, slop, s)
r = r + "\n" + strings.Repeat(" ", i) + t
}
return r
}
// FlagUsagesWrapped returns a string containing the usage information
// for all flags in the FlagSet. Wrapped to `cols` columns (0 for no
// wrapping)
func (f *FlagSet) FlagUsagesWrapped(cols int) string {
buf := new(bytes.Buffer)
lines := make([]string, 0, len(f.formal)) lines := make([]string, 0, len(f.formal))
maxlen := 0 maxlen := 0
f.VisitAll(func(flag *Flag) { f.VisitAll(func(flag *Flag) {
if flag.Deprecated != "" || flag.Hidden { if len(flag.Deprecated) > 0 || flag.Hidden {
return return
} }
line := "" line := ""
if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { if len(flag.Shorthand) > 0 && len(flag.ShorthandDeprecated) == 0 {
line = fmt.Sprintf(" -%s, --%s", flag.Shorthand, flag.Name) line = fmt.Sprintf(" -%s, --%s", flag.Shorthand, flag.Name)
} else { } else {
line = fmt.Sprintf(" --%s", flag.Name) line = fmt.Sprintf(" --%s", flag.Name)
} }
varname, usage := UnquoteUsage(flag) varname, usage := UnquoteUsage(flag)
if varname != "" { if len(varname) > 0 {
line += " " + varname line += " " + varname
} }
if flag.NoOptDefVal != "" { if len(flag.NoOptDefVal) > 0 {
switch flag.Value.Type() { switch flag.Value.Type() {
case "string": case "string":
line += fmt.Sprintf("[=\"%s\"]", flag.NoOptDefVal) line += fmt.Sprintf("[=\"%s\"]", flag.NoOptDefVal)
@ -672,10 +519,6 @@ func (f *FlagSet) FlagUsagesWrapped(cols int) string {
if flag.NoOptDefVal != "true" { if flag.NoOptDefVal != "true" {
line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) line += fmt.Sprintf("[=%s]", flag.NoOptDefVal)
} }
case "count":
if flag.NoOptDefVal != "+1" {
line += fmt.Sprintf("[=%s]", flag.NoOptDefVal)
}
default: default:
line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) line += fmt.Sprintf("[=%s]", flag.NoOptDefVal)
} }
@ -691,7 +534,7 @@ func (f *FlagSet) FlagUsagesWrapped(cols int) string {
line += usage line += usage
if !flag.defaultIsZeroValue() { if !flag.defaultIsZeroValue() {
if flag.Value.Type() == "string" { if flag.Value.Type() == "string" {
line += fmt.Sprintf(" (default %q)", flag.DefValue) line += fmt.Sprintf(" (default \"%s\")", flag.DefValue)
} else { } else {
line += fmt.Sprintf(" (default %s)", flag.DefValue) line += fmt.Sprintf(" (default %s)", flag.DefValue)
} }
@ -703,17 +546,10 @@ func (f *FlagSet) FlagUsagesWrapped(cols int) string {
for _, line := range lines { for _, line := range lines {
sidx := strings.Index(line, "\x00") sidx := strings.Index(line, "\x00")
spacing := strings.Repeat(" ", maxlen-sidx) spacing := strings.Repeat(" ", maxlen-sidx)
// maxlen + 2 comes from + 1 for the \x00 and + 1 for the (deliberate) off-by-one in maxlen-sidx fmt.Fprintln(x, line[:sidx], spacing, line[sidx+1:])
fmt.Fprintln(buf, line[:sidx], spacing, wrap(maxlen+2, cols, line[sidx+1:]))
} }
return buf.String() return x.String()
}
// FlagUsages returns a string containing the usage information for all flags in
// the FlagSet
func (f *FlagSet) FlagUsages() string {
return f.FlagUsagesWrapped(0)
} }
// PrintDefaults prints to standard error the default values of all defined command-line flags. // PrintDefaults prints to standard error the default values of all defined command-line flags.
@ -799,15 +635,16 @@ func (f *FlagSet) VarPF(value Value, name, shorthand, usage string) *Flag {
// VarP is like Var, but accepts a shorthand letter that can be used after a single dash. // VarP is like Var, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) VarP(value Value, name, shorthand, usage string) { func (f *FlagSet) VarP(value Value, name, shorthand, usage string) {
f.VarPF(value, name, shorthand, usage) _ = f.VarPF(value, name, shorthand, usage)
} }
// AddFlag will add the flag to the FlagSet // AddFlag will add the flag to the FlagSet
func (f *FlagSet) AddFlag(flag *Flag) { func (f *FlagSet) AddFlag(flag *Flag) {
// Call normalizeFlagName function only once
normalizedFlagName := f.normalizeFlagName(flag.Name) normalizedFlagName := f.normalizeFlagName(flag.Name)
_, alreadyThere := f.formal[normalizedFlagName] _, alreadythere := f.formal[normalizedFlagName]
if alreadyThere { if alreadythere {
msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name) msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name)
fmt.Fprintln(f.out(), msg) fmt.Fprintln(f.out(), msg)
panic(msg) // Happens only if flags are declared with identical names panic(msg) // Happens only if flags are declared with identical names
@ -818,31 +655,28 @@ func (f *FlagSet) AddFlag(flag *Flag) {
flag.Name = string(normalizedFlagName) flag.Name = string(normalizedFlagName)
f.formal[normalizedFlagName] = flag f.formal[normalizedFlagName] = flag
f.orderedFormal = append(f.orderedFormal, flag)
if flag.Shorthand == "" { if len(flag.Shorthand) == 0 {
return return
} }
if len(flag.Shorthand) > 1 { if len(flag.Shorthand) > 1 {
msg := fmt.Sprintf("%q shorthand is more than one ASCII character", flag.Shorthand) fmt.Fprintf(f.out(), "%s shorthand more than ASCII character: %s\n", f.name, flag.Shorthand)
fmt.Fprintf(f.out(), msg) panic("shorthand is more than one character")
panic(msg)
} }
if f.shorthands == nil { if f.shorthands == nil {
f.shorthands = make(map[byte]*Flag) f.shorthands = make(map[byte]*Flag)
} }
c := flag.Shorthand[0] c := flag.Shorthand[0]
used, alreadyThere := f.shorthands[c] old, alreadythere := f.shorthands[c]
if alreadyThere { if alreadythere {
msg := fmt.Sprintf("unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name) fmt.Fprintf(f.out(), "%s shorthand reused: %q for %s already used for %s\n", f.name, c, flag.Name, old.Name)
fmt.Fprintf(f.out(), msg) panic("shorthand redefinition")
panic(msg)
} }
f.shorthands[c] = flag f.shorthands[c] = flag
} }
// AddFlagSet adds one FlagSet to another. If a flag is already present in f // AddFlagSet adds one FlagSet to another. If a flag is already present in f
// the flag from newSet will be ignored. // the flag from newSet will be ignored
func (f *FlagSet) AddFlagSet(newSet *FlagSet) { func (f *FlagSet) AddFlagSet(newSet *FlagSet) {
if newSet == nil { if newSet == nil {
return return
@ -890,18 +724,45 @@ func (f *FlagSet) usage() {
} }
} }
func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) { func (f *FlagSet) setFlag(flag *Flag, value string, origArg string) error {
if err := flag.Value.Set(value); err != nil {
return f.failf("invalid argument %q for %s: %v", value, origArg, err)
}
// mark as visited for Visit()
if f.actual == nil {
f.actual = make(map[NormalizedName]*Flag)
}
f.actual[f.normalizeFlagName(flag.Name)] = flag
flag.Changed = true
if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
}
if len(flag.ShorthandDeprecated) > 0 && containsShorthand(origArg, flag.Shorthand) {
fmt.Fprintf(os.Stderr, "Flag shorthand -%s has been deprecated, %s\n", flag.Shorthand, flag.ShorthandDeprecated)
}
return nil
}
func containsShorthand(arg, shorthand string) bool {
// filter out flags --<flag_name>
if strings.HasPrefix(arg, "-") {
return false
}
arg = strings.SplitN(arg, "=", 2)[0]
return strings.Contains(arg, shorthand)
}
func (f *FlagSet) parseLongArg(s string, args []string) (a []string, err error) {
a = args a = args
name := s[2:] name := s[2:]
if len(name) == 0 || name[0] == '-' || name[0] == '=' { if len(name) == 0 || name[0] == '-' || name[0] == '=' {
err = f.failf("bad flag syntax: %s", s) err = f.failf("bad flag syntax: %s", s)
return return
} }
split := strings.SplitN(name, "=", 2) split := strings.SplitN(name, "=", 2)
name = split[0] name = split[0]
flag, exists := f.formal[f.normalizeFlagName(name)] flag, alreadythere := f.formal[f.normalizeFlagName(name)]
if !exists { if !alreadythere {
if name == "help" { // special case for nice help message. if name == "help" { // special case for nice help message.
f.usage() f.usage()
return a, ErrHelp return a, ErrHelp
@ -909,12 +770,11 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
err = f.failf("unknown flag: --%s", name) err = f.failf("unknown flag: --%s", name)
return return
} }
var value string var value string
if len(split) == 2 { if len(split) == 2 {
// '--flag=arg' // '--flag=arg'
value = split[1] value = split[1]
} else if flag.NoOptDefVal != "" { } else if len(flag.NoOptDefVal) > 0 {
// '--flag' (arg was optional) // '--flag' (arg was optional)
value = flag.NoOptDefVal value = flag.NoOptDefVal
} else if len(a) > 0 { } else if len(a) > 0 {
@ -926,74 +786,55 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
err = f.failf("flag needs an argument: %s", s) err = f.failf("flag needs an argument: %s", s)
return return
} }
err = f.setFlag(flag, value, s)
err = fn(flag, value)
if err != nil {
f.failf(err.Error())
}
return return
} }
func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parseFunc) (outShorts string, outArgs []string, err error) { func (f *FlagSet) parseSingleShortArg(shorthands string, args []string) (outShorts string, outArgs []string, err error) {
if strings.HasPrefix(shorthands, "test.") { if strings.HasPrefix(shorthands, "test.") {
return return
} }
outArgs = args outArgs = args
outShorts = shorthands[1:] outShorts = shorthands[1:]
c := shorthands[0] c := shorthands[0]
flag, exists := f.shorthands[c] flag, alreadythere := f.shorthands[c]
if !exists { if !alreadythere {
if c == 'h' { // special case for nice help message. if c == 'h' { // special case for nice help message.
f.usage() f.usage()
err = ErrHelp err = ErrHelp
return return
} }
//TODO continue on error
err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands)
return return
} }
var value string var value string
if len(shorthands) > 2 && shorthands[1] == '=' { if len(shorthands) > 2 && shorthands[1] == '=' {
// '-f=arg'
value = shorthands[2:] value = shorthands[2:]
outShorts = "" outShorts = ""
} else if flag.NoOptDefVal != "" { } else if len(flag.NoOptDefVal) > 0 {
// '-f' (arg was optional)
value = flag.NoOptDefVal value = flag.NoOptDefVal
} else if len(shorthands) > 1 { } else if len(shorthands) > 1 {
// '-farg'
value = shorthands[1:] value = shorthands[1:]
outShorts = "" outShorts = ""
} else if len(args) > 0 { } else if len(args) > 0 {
// '-f arg'
value = args[0] value = args[0]
outArgs = args[1:] outArgs = args[1:]
} else { } else {
// '-f' (arg was required)
err = f.failf("flag needs an argument: %q in -%s", c, shorthands) err = f.failf("flag needs an argument: %q in -%s", c, shorthands)
return return
} }
err = f.setFlag(flag, value, shorthands)
if flag.ShorthandDeprecated != "" {
fmt.Fprintf(f.out(), "Flag shorthand -%s has been deprecated, %s\n", flag.Shorthand, flag.ShorthandDeprecated)
}
err = fn(flag, value)
if err != nil {
f.failf(err.Error())
}
return return
} }
func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []string, err error) { func (f *FlagSet) parseShortArg(s string, args []string) (a []string, err error) {
a = args a = args
shorthands := s[1:] shorthands := s[1:]
// "shorthands" can be a series of shorthand letters of flags (e.g. "-vvv").
for len(shorthands) > 0 { for len(shorthands) > 0 {
shorthands, a, err = f.parseSingleShortArg(shorthands, args, fn) shorthands, a, err = f.parseSingleShortArg(shorthands, args)
if err != nil { if err != nil {
return return
} }
@ -1002,7 +843,7 @@ func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []stri
return return
} }
func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) { func (f *FlagSet) parseArgs(args []string) (err error) {
for len(args) > 0 { for len(args) > 0 {
s := args[0] s := args[0]
args = args[1:] args = args[1:]
@ -1022,9 +863,9 @@ func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) {
f.args = append(f.args, args...) f.args = append(f.args, args...)
break break
} }
args, err = f.parseLongArg(s, args, fn) args, err = f.parseLongArg(s, args)
} else { } else {
args, err = f.parseShortArg(s, args, fn) args, err = f.parseShortArg(s, args)
} }
if err != nil { if err != nil {
return return
@ -1039,43 +880,8 @@ func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) {
// The return value will be ErrHelp if -help was set but not defined. // The return value will be ErrHelp if -help was set but not defined.
func (f *FlagSet) Parse(arguments []string) error { func (f *FlagSet) Parse(arguments []string) error {
f.parsed = true f.parsed = true
if len(arguments) < 0 {
return nil
}
f.args = make([]string, 0, len(arguments)) f.args = make([]string, 0, len(arguments))
err := f.parseArgs(arguments)
set := func(flag *Flag, value string) error {
return f.Set(flag.Name, value)
}
err := f.parseArgs(arguments, set)
if err != nil {
switch f.errorHandling {
case ContinueOnError:
return err
case ExitOnError:
os.Exit(2)
case PanicOnError:
panic(err)
}
}
return nil
}
type parseFunc func(flag *Flag, value string) error
// ParseAll parses flag definitions from the argument list, which should not
// include the command name. The arguments for fn are flag and value. Must be
// called after all flags in the FlagSet are defined and before flags are
// accessed by the program. The return value will be ErrHelp if -help was set
// but not defined.
func (f *FlagSet) ParseAll(arguments []string, fn func(flag *Flag, value string) error) error {
f.parsed = true
f.args = make([]string, 0, len(arguments))
err := f.parseArgs(arguments, fn)
if err != nil { if err != nil {
switch f.errorHandling { switch f.errorHandling {
case ContinueOnError: case ContinueOnError:
@ -1101,14 +907,6 @@ func Parse() {
CommandLine.Parse(os.Args[1:]) CommandLine.Parse(os.Args[1:])
} }
// ParseAll parses the command-line flags from os.Args[1:] and called fn for each.
// The arguments for fn are flag and value. Must be called after all flags are
// defined and before flags are accessed by the program.
func ParseAll(fn func(flag *Flag, value string) error) {
// Ignore errors; CommandLine is set for ExitOnError.
CommandLine.ParseAll(os.Args[1:], fn)
}
// SetInterspersed sets whether to support interspersed option/non-option arguments. // SetInterspersed sets whether to support interspersed option/non-option arguments.
func SetInterspersed(interspersed bool) { func SetInterspersed(interspersed bool) {
CommandLine.SetInterspersed(interspersed) CommandLine.SetInterspersed(interspersed)
@ -1122,15 +920,14 @@ func Parsed() bool {
// CommandLine is the default set of command-line flags, parsed from os.Args. // CommandLine is the default set of command-line flags, parsed from os.Args.
var CommandLine = NewFlagSet(os.Args[0], ExitOnError) var CommandLine = NewFlagSet(os.Args[0], ExitOnError)
// NewFlagSet returns a new, empty flag set with the specified name, // NewFlagSet returns a new, empty flag set with the specified name and
// error handling property and SortFlags set to true. // error handling property.
func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet {
f := &FlagSet{ f := &FlagSet{
name: name, name: name,
errorHandling: errorHandling, errorHandling: errorHandling,
argsLenAtDash: -1, argsLenAtDash: -1,
interspersed: true, interspersed: true,
SortFlags: true,
} }
return f return f
} }

View File

@ -6,10 +6,13 @@ package pflag
import ( import (
goflag "flag" goflag "flag"
"fmt"
"reflect" "reflect"
"strings" "strings"
) )
var _ = fmt.Print
// flagValueWrapper implements pflag.Value around a flag.Value. The main // flagValueWrapper implements pflag.Value around a flag.Value. The main
// difference here is the addition of the Type method that returns a string // difference here is the addition of the Type method that returns a string
// name of the type. As this is generally unknown, we approximate that with // name of the type. As this is generally unknown, we approximate that with

View File

@ -6,6 +6,8 @@ import (
"strings" "strings"
) )
var _ = strings.TrimSpace
// -- net.IP value // -- net.IP value
type ipValue net.IP type ipValue net.IP

View File

@ -1,148 +0,0 @@
package pflag
import (
"fmt"
"io"
"net"
"strings"
)
// -- ipSlice Value
type ipSliceValue struct {
value *[]net.IP
changed bool
}
func newIPSliceValue(val []net.IP, p *[]net.IP) *ipSliceValue {
ipsv := new(ipSliceValue)
ipsv.value = p
*ipsv.value = val
return ipsv
}
// Set converts, and assigns, the comma-separated IP argument string representation as the []net.IP value of this flag.
// If Set is called on a flag that already has a []net.IP assigned, the newly converted values will be appended.
func (s *ipSliceValue) Set(val string) error {
// remove all quote characters
rmQuote := strings.NewReplacer(`"`, "", `'`, "", "`", "")
// read flag arguments with CSV parser
ipStrSlice, err := readAsCSV(rmQuote.Replace(val))
if err != nil && err != io.EOF {
return err
}
// parse ip values into slice
out := make([]net.IP, 0, len(ipStrSlice))
for _, ipStr := range ipStrSlice {
ip := net.ParseIP(strings.TrimSpace(ipStr))
if ip == nil {
return fmt.Errorf("invalid string being converted to IP address: %s", ipStr)
}
out = append(out, ip)
}
if !s.changed {
*s.value = out
} else {
*s.value = append(*s.value, out...)
}
s.changed = true
return nil
}
// Type returns a string that uniquely represents this flag's type.
func (s *ipSliceValue) Type() string {
return "ipSlice"
}
// String defines a "native" format for this net.IP slice flag value.
func (s *ipSliceValue) String() string {
ipStrSlice := make([]string, len(*s.value))
for i, ip := range *s.value {
ipStrSlice[i] = ip.String()
}
out, _ := writeAsCSV(ipStrSlice)
return "[" + out + "]"
}
func ipSliceConv(val string) (interface{}, error) {
val = strings.Trim(val, "[]")
// Emtpy string would cause a slice with one (empty) entry
if len(val) == 0 {
return []net.IP{}, nil
}
ss := strings.Split(val, ",")
out := make([]net.IP, len(ss))
for i, sval := range ss {
ip := net.ParseIP(strings.TrimSpace(sval))
if ip == nil {
return nil, fmt.Errorf("invalid string being converted to IP address: %s", sval)
}
out[i] = ip
}
return out, nil
}
// GetIPSlice returns the []net.IP value of a flag with the given name
func (f *FlagSet) GetIPSlice(name string) ([]net.IP, error) {
val, err := f.getFlagType(name, "ipSlice", ipSliceConv)
if err != nil {
return []net.IP{}, err
}
return val.([]net.IP), nil
}
// IPSliceVar defines a ipSlice flag with specified name, default value, and usage string.
// The argument p points to a []net.IP variable in which to store the value of the flag.
func (f *FlagSet) IPSliceVar(p *[]net.IP, name string, value []net.IP, usage string) {
f.VarP(newIPSliceValue(value, p), name, "", usage)
}
// IPSliceVarP is like IPSliceVar, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) IPSliceVarP(p *[]net.IP, name, shorthand string, value []net.IP, usage string) {
f.VarP(newIPSliceValue(value, p), name, shorthand, usage)
}
// IPSliceVar defines a []net.IP flag with specified name, default value, and usage string.
// The argument p points to a []net.IP variable in which to store the value of the flag.
func IPSliceVar(p *[]net.IP, name string, value []net.IP, usage string) {
CommandLine.VarP(newIPSliceValue(value, p), name, "", usage)
}
// IPSliceVarP is like IPSliceVar, but accepts a shorthand letter that can be used after a single dash.
func IPSliceVarP(p *[]net.IP, name, shorthand string, value []net.IP, usage string) {
CommandLine.VarP(newIPSliceValue(value, p), name, shorthand, usage)
}
// IPSlice defines a []net.IP flag with specified name, default value, and usage string.
// The return value is the address of a []net.IP variable that stores the value of that flag.
func (f *FlagSet) IPSlice(name string, value []net.IP, usage string) *[]net.IP {
p := []net.IP{}
f.IPSliceVarP(&p, name, "", value, usage)
return &p
}
// IPSliceP is like IPSlice, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) IPSliceP(name, shorthand string, value []net.IP, usage string) *[]net.IP {
p := []net.IP{}
f.IPSliceVarP(&p, name, shorthand, value, usage)
return &p
}
// IPSlice defines a []net.IP flag with specified name, default value, and usage string.
// The return value is the address of a []net.IP variable that stores the value of the flag.
func IPSlice(name string, value []net.IP, usage string) *[]net.IP {
return CommandLine.IPSliceP(name, "", value, usage)
}
// IPSliceP is like IPSlice, but accepts a shorthand letter that can be used after a single dash.
func IPSliceP(name, shorthand string, value []net.IP, usage string) *[]net.IP {
return CommandLine.IPSliceP(name, shorthand, value, usage)
}

View File

@ -27,6 +27,8 @@ func (*ipNetValue) Type() string {
return "ipNet" return "ipNet"
} }
var _ = strings.TrimSpace
func newIPNetValue(val net.IPNet, p *net.IPNet) *ipNetValue { func newIPNetValue(val net.IPNet, p *net.IPNet) *ipNetValue {
*p = val *p = val
return (*ipNetValue)(p) return (*ipNetValue)(p)

View File

@ -1,5 +1,11 @@
package pflag package pflag
import (
"fmt"
)
var _ = fmt.Fprint
// -- stringArray Value // -- stringArray Value
type stringArrayValue struct { type stringArrayValue struct {
value *[]string value *[]string

View File

@ -3,9 +3,12 @@ package pflag
import ( import (
"bytes" "bytes"
"encoding/csv" "encoding/csv"
"fmt"
"strings" "strings"
) )
var _ = fmt.Fprint
// -- stringSlice Value // -- stringSlice Value
type stringSliceValue struct { type stringSliceValue struct {
value *[]string value *[]string
@ -36,7 +39,7 @@ func writeAsCSV(vals []string) (string, error) {
return "", err return "", err
} }
w.Flush() w.Flush()
return strings.TrimSuffix(b.String(), "\n"), nil return strings.TrimSuffix(b.String(), fmt.Sprintln()), nil
} }
func (s *stringSliceValue) Set(val string) error { func (s *stringSliceValue) Set(val string) error {

View File

@ -1,126 +0,0 @@
package pflag
import (
"fmt"
"strconv"
"strings"
)
// -- uintSlice Value
type uintSliceValue struct {
value *[]uint
changed bool
}
func newUintSliceValue(val []uint, p *[]uint) *uintSliceValue {
uisv := new(uintSliceValue)
uisv.value = p
*uisv.value = val
return uisv
}
func (s *uintSliceValue) Set(val string) error {
ss := strings.Split(val, ",")
out := make([]uint, len(ss))
for i, d := range ss {
u, err := strconv.ParseUint(d, 10, 0)
if err != nil {
return err
}
out[i] = uint(u)
}
if !s.changed {
*s.value = out
} else {
*s.value = append(*s.value, out...)
}
s.changed = true
return nil
}
func (s *uintSliceValue) Type() string {
return "uintSlice"
}
func (s *uintSliceValue) String() string {
out := make([]string, len(*s.value))
for i, d := range *s.value {
out[i] = fmt.Sprintf("%d", d)
}
return "[" + strings.Join(out, ",") + "]"
}
func uintSliceConv(val string) (interface{}, error) {
val = strings.Trim(val, "[]")
// Empty string would cause a slice with one (empty) entry
if len(val) == 0 {
return []uint{}, nil
}
ss := strings.Split(val, ",")
out := make([]uint, len(ss))
for i, d := range ss {
u, err := strconv.ParseUint(d, 10, 0)
if err != nil {
return nil, err
}
out[i] = uint(u)
}
return out, nil
}
// GetUintSlice returns the []uint value of a flag with the given name.
func (f *FlagSet) GetUintSlice(name string) ([]uint, error) {
val, err := f.getFlagType(name, "uintSlice", uintSliceConv)
if err != nil {
return []uint{}, err
}
return val.([]uint), nil
}
// UintSliceVar defines a uintSlice flag with specified name, default value, and usage string.
// The argument p points to a []uint variable in which to store the value of the flag.
func (f *FlagSet) UintSliceVar(p *[]uint, name string, value []uint, usage string) {
f.VarP(newUintSliceValue(value, p), name, "", usage)
}
// UintSliceVarP is like UintSliceVar, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) UintSliceVarP(p *[]uint, name, shorthand string, value []uint, usage string) {
f.VarP(newUintSliceValue(value, p), name, shorthand, usage)
}
// UintSliceVar defines a uint[] flag with specified name, default value, and usage string.
// The argument p points to a uint[] variable in which to store the value of the flag.
func UintSliceVar(p *[]uint, name string, value []uint, usage string) {
CommandLine.VarP(newUintSliceValue(value, p), name, "", usage)
}
// UintSliceVarP is like the UintSliceVar, but accepts a shorthand letter that can be used after a single dash.
func UintSliceVarP(p *[]uint, name, shorthand string, value []uint, usage string) {
CommandLine.VarP(newUintSliceValue(value, p), name, shorthand, usage)
}
// UintSlice defines a []uint flag with specified name, default value, and usage string.
// The return value is the address of a []uint variable that stores the value of the flag.
func (f *FlagSet) UintSlice(name string, value []uint, usage string) *[]uint {
p := []uint{}
f.UintSliceVarP(&p, name, "", value, usage)
return &p
}
// UintSliceP is like UintSlice, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) UintSliceP(name, shorthand string, value []uint, usage string) *[]uint {
p := []uint{}
f.UintSliceVarP(&p, name, shorthand, value, usage)
return &p
}
// UintSlice defines a []uint flag with specified name, default value, and usage string.
// The return value is the address of a []uint variable that stores the value of the flag.
func UintSlice(name string, value []uint, usage string) *[]uint {
return CommandLine.UintSliceP(name, "", value, usage)
}
// UintSliceP is like UintSlice, but accepts a shorthand letter that can be used after a single dash.
func UintSliceP(name, shorthand string, value []uint, usage string) *[]uint {
return CommandLine.UintSliceP(name, shorthand, value, usage)
}

View File

@ -21,7 +21,6 @@ package viper
import ( import (
"bytes" "bytes"
"encoding/csv"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -41,11 +40,6 @@ import (
var v *Viper var v *Viper
type RemoteResponse struct {
Value []byte
Error error
}
func init() { func init() {
v = New() v = New()
} }
@ -53,7 +47,6 @@ func init() {
type remoteConfigFactory interface { type remoteConfigFactory interface {
Get(rp RemoteProvider) (io.Reader, error) Get(rp RemoteProvider) (io.Reader, error)
Watch(rp RemoteProvider) (io.Reader, error) Watch(rp RemoteProvider) (io.Reader, error)
WatchChannel(rp RemoteProvider) (<-chan *RemoteResponse, chan bool)
} }
// RemoteConfig is optional, see the remote package // RemoteConfig is optional, see the remote package
@ -69,7 +62,8 @@ func (str UnsupportedConfigError) Error() string {
} }
// UnsupportedRemoteProviderError denotes encountering an unsupported remote // UnsupportedRemoteProviderError denotes encountering an unsupported remote
// provider. Currently only etcd and Consul are supported. // provider. Currently only etcd and Consul are
// supported.
type UnsupportedRemoteProviderError string type UnsupportedRemoteProviderError string
// Error returns the formatted remote provider error. // Error returns the formatted remote provider error.
@ -282,8 +276,8 @@ func (v *Viper) WatchConfig() {
}() }()
} }
// SetConfigFile explicitly defines the path, name and extension of the config file. // SetConfigFile explicitly defines the path, name and extension of the config file
// Viper will use this and not check any of the config paths. // Viper will use this and not check any of the config paths
func SetConfigFile(in string) { v.SetConfigFile(in) } func SetConfigFile(in string) { v.SetConfigFile(in) }
func (v *Viper) SetConfigFile(in string) { func (v *Viper) SetConfigFile(in string) {
if in != "" { if in != "" {
@ -292,8 +286,8 @@ func (v *Viper) SetConfigFile(in string) {
} }
// SetEnvPrefix defines a prefix that ENVIRONMENT variables will use. // SetEnvPrefix defines a prefix that ENVIRONMENT variables will use.
// E.g. if your prefix is "spf", the env registry will look for env // E.g. if your prefix is "spf", the env registry
// variables that start with "SPF_". // will look for env. variables that start with "SPF_"
func SetEnvPrefix(in string) { v.SetEnvPrefix(in) } func SetEnvPrefix(in string) { v.SetEnvPrefix(in) }
func (v *Viper) SetEnvPrefix(in string) { func (v *Viper) SetEnvPrefix(in string) {
if in != "" { if in != "" {
@ -311,11 +305,11 @@ func (v *Viper) mergeWithEnvPrefix(in string) string {
// TODO: should getEnv logic be moved into find(). Can generalize the use of // TODO: should getEnv logic be moved into find(). Can generalize the use of
// rewriting keys many things, Ex: Get('someKey') -> some_key // rewriting keys many things, Ex: Get('someKey') -> some_key
// (camel case to snake case for JSON keys perhaps) // (cammel case to snake case for JSON keys perhaps)
// getEnv is a wrapper around os.Getenv which replaces characters in the original // getEnv is a wrapper around os.Getenv which replaces characters in the original
// key. This allows env vars which have different keys than the config object // key. This allows env vars which have different keys then the config object
// keys. // keys
func (v *Viper) getEnv(key string) string { func (v *Viper) getEnv(key string) string {
if v.envKeyReplacer != nil { if v.envKeyReplacer != nil {
key = v.envKeyReplacer.Replace(key) key = v.envKeyReplacer.Replace(key)
@ -323,7 +317,7 @@ func (v *Viper) getEnv(key string) string {
return os.Getenv(key) return os.Getenv(key)
} }
// ConfigFileUsed returns the file used to populate the config registry. // ConfigFileUsed returns the file used to populate the config registry
func ConfigFileUsed() string { return v.ConfigFileUsed() } func ConfigFileUsed() string { return v.ConfigFileUsed() }
func (v *Viper) ConfigFileUsed() string { return v.configFile } func (v *Viper) ConfigFileUsed() string { return v.configFile }
@ -596,33 +590,32 @@ func (v *Viper) Get(key string) interface{} {
return nil return nil
} }
valType := val
if v.typeByDefValue { if v.typeByDefValue {
// TODO(bep) this branch isn't covered by a single test. // TODO(bep) this branch isn't covered by a single test.
valType := val
path := strings.Split(lcaseKey, v.keyDelim) path := strings.Split(lcaseKey, v.keyDelim)
defVal := v.searchMap(v.defaults, path) defVal := v.searchMap(v.defaults, path)
if defVal != nil { if defVal != nil {
valType = defVal valType = defVal
} }
switch valType.(type) {
case bool:
return cast.ToBool(val)
case string:
return cast.ToString(val)
case int64, int32, int16, int8, int:
return cast.ToInt(val)
case float64, float32:
return cast.ToFloat64(val)
case time.Time:
return cast.ToTime(val)
case time.Duration:
return cast.ToDuration(val)
case []string:
return cast.ToStringSlice(val)
}
} }
switch valType.(type) {
case bool:
return cast.ToBool(val)
case string:
return cast.ToString(val)
case int64, int32, int16, int8, int:
return cast.ToInt(val)
case float64, float32:
return cast.ToFloat64(val)
case time.Time:
return cast.ToTime(val)
case time.Duration:
return cast.ToDuration(val)
case []string:
return cast.ToStringSlice(val)
}
return val return val
} }
@ -720,15 +713,7 @@ func (v *Viper) GetSizeInBytes(key string) uint {
// UnmarshalKey takes a single key and unmarshals it into a Struct. // UnmarshalKey takes a single key and unmarshals it into a Struct.
func UnmarshalKey(key string, rawVal interface{}) error { return v.UnmarshalKey(key, rawVal) } func UnmarshalKey(key string, rawVal interface{}) error { return v.UnmarshalKey(key, rawVal) }
func (v *Viper) UnmarshalKey(key string, rawVal interface{}) error { func (v *Viper) UnmarshalKey(key string, rawVal interface{}) error {
err := decode(v.Get(key), defaultDecoderConfig(rawVal)) return mapstructure.Decode(v.Get(key), rawVal)
if err != nil {
return err
}
v.insensitiviseMaps()
return nil
} }
// Unmarshal unmarshals the config into a Struct. Make sure that the tags // Unmarshal unmarshals the config into a Struct. Make sure that the tags
@ -747,16 +732,13 @@ func (v *Viper) Unmarshal(rawVal interface{}) error {
} }
// defaultDecoderConfig returns default mapsstructure.DecoderConfig with suppot // defaultDecoderConfig returns default mapsstructure.DecoderConfig with suppot
// of time.Duration values & string slices // of time.Duration values
func defaultDecoderConfig(output interface{}) *mapstructure.DecoderConfig { func defaultDecoderConfig(output interface{}) *mapstructure.DecoderConfig {
return &mapstructure.DecoderConfig{ return &mapstructure.DecoderConfig{
Metadata: nil, Metadata: nil,
Result: output, Result: output,
WeaklyTypedInput: true, WeaklyTypedInput: true,
DecodeHook: mapstructure.ComposeDecodeHookFunc( DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
),
} }
} }
@ -817,7 +799,7 @@ func (v *Viper) BindFlagValues(flags FlagValueSet) (err error) {
} }
// BindFlagValue binds a specific key to a FlagValue. // BindFlagValue binds a specific key to a FlagValue.
// Example (where serverCmd is a Cobra instance): // Example(where serverCmd is a Cobra instance):
// //
// serverCmd.Flags().Int("port", 1138, "Port to run Application server on") // serverCmd.Flags().Int("port", 1138, "Port to run Application server on")
// Viper.BindFlagValue("port", serverCmd.Flags().Lookup("port")) // Viper.BindFlagValue("port", serverCmd.Flags().Lookup("port"))
@ -898,9 +880,7 @@ func (v *Viper) find(lcaseKey string) interface{} {
return cast.ToBool(flag.ValueString()) return cast.ToBool(flag.ValueString())
case "stringSlice": case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]") return strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
default: default:
return flag.ValueString() return flag.ValueString()
} }
@ -967,9 +947,7 @@ func (v *Viper) find(lcaseKey string) interface{} {
return cast.ToBool(flag.ValueString()) return cast.ToBool(flag.ValueString())
case "stringSlice": case "stringSlice":
s := strings.TrimPrefix(flag.ValueString(), "[") s := strings.TrimPrefix(flag.ValueString(), "[")
s = strings.TrimSuffix(s, "]") return strings.TrimSuffix(s, "]")
res, _ := readAsCSV(s)
return res
default: default:
return flag.ValueString() return flag.ValueString()
} }
@ -979,15 +957,6 @@ func (v *Viper) find(lcaseKey string) interface{} {
return nil return nil
} }
func readAsCSV(val string) ([]string, error) {
if val == "" {
return []string{}, nil
}
stringReader := strings.NewReader(val)
csvReader := csv.NewReader(stringReader)
return csvReader.Read()
}
// IsSet checks to see if the key has been set in any of the data locations. // IsSet checks to see if the key has been set in any of the data locations.
// IsSet is case-insensitive for a key. // IsSet is case-insensitive for a key.
func IsSet(key string) bool { return v.IsSet(key) } func IsSet(key string) bool { return v.IsSet(key) }
@ -1124,30 +1093,24 @@ func (v *Viper) ReadInConfig() error {
return err return err
} }
config := make(map[string]interface{}) v.config = make(map[string]interface{})
err = v.unmarshalReader(bytes.NewReader(file), config) return v.unmarshalReader(bytes.NewReader(file), v.config)
if err != nil {
return err
}
v.config = config
return nil
} }
// MergeInConfig merges a new configuration with an existing config. // MergeInConfig merges a new configuration with an existing config.
func MergeInConfig() error { return v.MergeInConfig() } func MergeInConfig() error { return v.MergeInConfig() }
func (v *Viper) MergeInConfig() error { func (v *Viper) MergeInConfig() error {
jww.INFO.Println("Attempting to merge in config file") jww.INFO.Println("Attempting to merge in config file")
if !stringInSlice(v.getConfigType(), SupportedExts) {
return UnsupportedConfigError(v.getConfigType())
}
filename, err := v.getConfigFile() filename, err := v.getConfigFile()
if err != nil { if err != nil {
return err return err
} }
if !stringInSlice(v.getConfigType(), SupportedExts) {
return UnsupportedConfigError(v.getConfigType())
}
file, err := afero.ReadFile(v.fs, filename) file, err := afero.ReadFile(v.fs, filename)
if err != nil { if err != nil {
return err return err
@ -1286,11 +1249,7 @@ func (v *Viper) WatchRemoteConfig() error {
return v.watchKeyValueConfig() return v.watchKeyValueConfig()
} }
func (v *Viper) WatchRemoteConfigOnChannel() error { // Unmarshall a Reader into a map.
return v.watchKeyValueConfigOnChannel()
}
// Unmarshal a Reader into a map.
// Should probably be an unexported function. // Should probably be an unexported function.
func unmarshalReader(in io.Reader, c map[string]interface{}) error { func unmarshalReader(in io.Reader, c map[string]interface{}) error {
return v.unmarshalReader(in, c) return v.unmarshalReader(in, c)
@ -1333,23 +1292,6 @@ func (v *Viper) getRemoteConfig(provider RemoteProvider) (map[string]interface{}
return v.kvstore, err return v.kvstore, err
} }
// Retrieve the first found remote configuration.
func (v *Viper) watchKeyValueConfigOnChannel() error {
for _, rp := range v.remoteProviders {
respc, _ := RemoteConfig.WatchChannel(rp)
//Todo: Add quit channel
go func(rc <-chan *RemoteResponse) {
for {
b := <-rc
reader := bytes.NewReader(b.Value)
v.unmarshalReader(reader, v.kvstore)
}
}(respc)
return nil
}
return RemoteConfigError("No Files Found")
}
// Retrieve the first found remote configuration. // Retrieve the first found remote configuration.
func (v *Viper) watchKeyValueConfig() error { func (v *Viper) watchKeyValueConfig() error {
for _, rp := range v.remoteProviders { for _, rp := range v.remoteProviders {
@ -1549,6 +1491,7 @@ func (v *Viper) searchInPath(in string) (filename string) {
// Search all configPaths for any config file. // Search all configPaths for any config file.
// Returns the first path that exists (and is a config file). // Returns the first path that exists (and is a config file).
func (v *Viper) findConfigFile() (string, error) { func (v *Viper) findConfigFile() (string, error) {
jww.INFO.Println("Searching for config in ", v.configPaths) jww.INFO.Println("Searching for config in ", v.configPaths)
for _, cp := range v.configPaths { for _, cp := range v.configPaths {

View File

@ -132,8 +132,11 @@ const (
keyPasteEnd keyPasteEnd
) )
var pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'} var (
var pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'} crlf = []byte{'\r', '\n'}
pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'}
pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'}
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns // bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns utf8.RuneError. // the key and the remainder of the input. Otherwise it returns utf8.RuneError.
@ -333,7 +336,7 @@ func (t *Terminal) advanceCursor(places int) {
// So, if we are stopping at the end of a line, we // So, if we are stopping at the end of a line, we
// need to write a newline so that our cursor can be // need to write a newline so that our cursor can be
// advanced to the next line. // advanced to the next line.
t.outBuf = append(t.outBuf, '\n') t.outBuf = append(t.outBuf, '\r', '\n')
} }
} }
@ -593,6 +596,35 @@ func (t *Terminal) writeLine(line []rune) {
} }
} }
// writeWithCRLF writes buf to w but replaces all occurrences of \n with \r\n.
func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) {
for len(buf) > 0 {
i := bytes.IndexByte(buf, '\n')
todo := len(buf)
if i >= 0 {
todo = i
}
var nn int
nn, err = w.Write(buf[:todo])
n += nn
if err != nil {
return n, err
}
buf = buf[todo:]
if i >= 0 {
if _, err = w.Write(crlf); err != nil {
return n, err
}
n += 1
buf = buf[1:]
}
}
return n, nil
}
func (t *Terminal) Write(buf []byte) (n int, err error) { func (t *Terminal) Write(buf []byte) (n int, err error) {
t.lock.Lock() t.lock.Lock()
defer t.lock.Unlock() defer t.lock.Unlock()
@ -600,7 +632,7 @@ func (t *Terminal) Write(buf []byte) (n int, err error) {
if t.cursorX == 0 && t.cursorY == 0 { if t.cursorX == 0 && t.cursorY == 0 {
// This is the easy case: there's nothing on the screen that we // This is the easy case: there's nothing on the screen that we
// have to move out of the way. // have to move out of the way.
return t.c.Write(buf) return writeWithCRLF(t.c, buf)
} }
// We have a prompt and possibly user input on the screen. We // We have a prompt and possibly user input on the screen. We
@ -620,7 +652,7 @@ func (t *Terminal) Write(buf []byte) (n int, err error) {
} }
t.outBuf = t.outBuf[:0] t.outBuf = t.outBuf[:0]
if n, err = t.c.Write(buf); err != nil { if n, err = writeWithCRLF(t.c, buf); err != nil {
return return
} }
@ -740,8 +772,6 @@ func (t *Terminal) readLine() (line string, err error) {
t.remainder = t.inBuf[:n+len(t.remainder)] t.remainder = t.inBuf[:n+len(t.remainder)]
} }
panic("unreachable") // for Go 1.0.
} }
// SetPrompt sets the prompt to be used when reading subsequent lines. // SetPrompt sets the prompt to be used when reading subsequent lines.
@ -890,3 +920,32 @@ func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) {
} }
return s.entries[index], true return s.entries[index], true
} }
// readPasswordLine reads from reader until it finds \n or io.EOF.
// The slice returned does not include the \n.
// readPasswordLine also ignores any \r it finds.
func readPasswordLine(reader io.Reader) ([]byte, error) {
var buf [1]byte
var ret []byte
for {
n, err := reader.Read(buf[:])
if n > 0 {
switch buf[0] {
case '\n':
return ret, nil
case '\r':
// remove \r from passwords on Windows
default:
ret = append(ret, buf[0])
}
continue
}
if err != nil {
if err == io.EOF && len(ret) > 0 {
return ret, nil
}
return ret, err
}
}
}

View File

@ -17,9 +17,10 @@
package terminal // import "golang.org/x/crypto/ssh/terminal" package terminal // import "golang.org/x/crypto/ssh/terminal"
import ( import (
"io"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/sys/unix"
) )
// State contains the state of a terminal. // State contains the state of a terminal.
@ -44,8 +45,15 @@ func MakeRaw(fd int) (*State, error) {
} }
newState := oldState.termios newState := oldState.termios
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF // This attempts to replicate the behaviour documented for cfmakeraw in
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG // the termios(3) manpage.
newState.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
newState.Oflag &^= syscall.OPOST
newState.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
newState.Cflag &^= syscall.CSIZE | syscall.PARENB
newState.Cflag |= syscall.CS8
newState.Cc[unix.VMIN] = 1
newState.Cc[unix.VTIME] = 0
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 {
return nil, err return nil, err
} }
@ -67,8 +75,10 @@ func GetState(fd int) (*State, error) {
// Restore restores the terminal connected to the given file descriptor to a // Restore restores the terminal connected to the given file descriptor to a
// previous state. // previous state.
func Restore(fd int, state *State) error { func Restore(fd int, state *State) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0) if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0); err != 0 {
return err return err
}
return nil
} }
// GetSize returns the dimensions of the given terminal. // GetSize returns the dimensions of the given terminal.
@ -81,6 +91,13 @@ func GetSize(fd int) (width, height int, err error) {
return int(dimensions[1]), int(dimensions[0]), nil return int(dimensions[1]), int(dimensions[0]), nil
} }
// passwordReader is an io.Reader that reads from a specific file descriptor.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return syscall.Read(int(r), buf)
}
// ReadPassword reads a line of input from a terminal without local echo. This // ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice // is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n. // returned does not include the \n.
@ -102,27 +119,5 @@ func ReadPassword(fd int) ([]byte, error) {
syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0)
}() }()
var buf [16]byte return readPasswordLine(passwordReader(fd))
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
} }

View File

@ -6,7 +6,7 @@
package terminal package terminal
import "syscall" import "golang.org/x/sys/unix"
const ioctlReadTermios = syscall.TIOCGETA const ioctlReadTermios = unix.TIOCGETA
const ioctlWriteTermios = syscall.TIOCSETA const ioctlWriteTermios = unix.TIOCSETA

View File

@ -4,8 +4,7 @@
package terminal package terminal
// These constants are declared here, rather than importing import "golang.org/x/sys/unix"
// them from the syscall package as some syscall packages, even
// on linux, for example gccgo, do not declare them. const ioctlReadTermios = unix.TCGETS
const ioctlReadTermios = 0x5401 // syscall.TCGETS const ioctlWriteTermios = unix.TCSETS
const ioctlWriteTermios = 0x5402 // syscall.TCSETS

58
vendor/golang.org/x/crypto/ssh/terminal/util_plan9.go generated vendored Normal file
View File

@ -0,0 +1,58 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err)
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"fmt"
"runtime"
)
type State struct{}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
return false
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
return nil, fmt.Errorf("terminal: MakeRaw not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
return nil, fmt.Errorf("terminal: GetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
return fmt.Errorf("terminal: Restore not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
return 0, 0, fmt.Errorf("terminal: GetSize not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
return nil, fmt.Errorf("terminal: ReadPassword not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

128
vendor/golang.org/x/crypto/ssh/terminal/util_solaris.go generated vendored Normal file
View File

@ -0,0 +1,128 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build solaris
package terminal // import "golang.org/x/crypto/ssh/terminal"
import (
"golang.org/x/sys/unix"
"io"
"syscall"
)
// State contains the state of a terminal.
type State struct {
state *unix.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
_, err := unix.IoctlGetTermio(fd, unix.TCGETA)
return err == nil
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
// see also: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libast/common/uwin/getpass.c
val, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldState := *val
newState := oldState
newState.Lflag &^= syscall.ECHO
newState.Lflag |= syscall.ICANON | syscall.ISIG
newState.Iflag |= syscall.ICRNL
err = unix.IoctlSetTermios(fd, unix.TCSETS, &newState)
if err != nil {
return nil, err
}
defer unix.IoctlSetTermios(fd, unix.TCSETS, &oldState)
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
}
// MakeRaw puts the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
// see http://cr.illumos.org/~webrev/andy_js/1060/
func MakeRaw(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldTermios := *oldTermiosPtr
newTermios := oldTermios
newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
newTermios.Oflag &^= syscall.OPOST
newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB
newTermios.Cflag |= syscall.CS8
newTermios.Cc[unix.VMIN] = 1
newTermios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, oldState *State) error {
return unix.IoctlSetTermios(fd, unix.TCSETS, oldState.state)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
return &State{
state: oldTermiosPtr,
}, nil
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
return 0, 0, err
}
return int(ws.Col), int(ws.Row), nil
}

View File

@ -17,54 +17,7 @@
package terminal package terminal
import ( import (
"io" "golang.org/x/sys/windows"
"syscall"
"unsafe"
)
const (
enableLineInput = 2
enableEchoInput = 4
enableProcessedInput = 1
enableWindowInput = 8
enableMouseInput = 16
enableInsertMode = 32
enableQuickEditMode = 64
enableExtendedFlags = 128
enableAutoPosition = 256
enableProcessedOutput = 1
enableWrapAtEolOutput = 2
)
var kernel32 = syscall.NewLazyDLL("kernel32.dll")
var (
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
)
type (
short int16
word uint16
coord struct {
x short
y short
}
smallRect struct {
left short
top short
right short
bottom short
}
consoleScreenBufferInfo struct {
size coord
cursorPosition coord
attributes word
window smallRect
maximumWindowSize coord
}
) )
type State struct { type State struct {
@ -74,8 +27,8 @@ type State struct {
// IsTerminal returns true if the given file descriptor is a terminal. // IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool { func IsTerminal(fd int) bool {
var st uint32 var st uint32
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) err := windows.GetConsoleMode(windows.Handle(fd), &st)
return r != 0 && e == 0 return err == nil
} }
// MakeRaw put the terminal connected to the given file descriptor into raw // MakeRaw put the terminal connected to the given file descriptor into raw
@ -83,14 +36,12 @@ func IsTerminal(fd int) bool {
// restored. // restored.
func MakeRaw(fd int) (*State, error) { func MakeRaw(fd int) (*State, error) {
var st uint32 var st uint32
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
if e != 0 { return nil, err
return nil, error(e)
} }
st &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput) raw := st &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) if err := windows.SetConsoleMode(windows.Handle(fd), raw); err != nil {
if e != 0 { return nil, err
return nil, error(e)
} }
return &State{st}, nil return &State{st}, nil
} }
@ -99,9 +50,8 @@ func MakeRaw(fd int) (*State, error) {
// restore the terminal after a signal. // restore the terminal after a signal.
func GetState(fd int) (*State, error) { func GetState(fd int) (*State, error) {
var st uint32 var st uint32
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
if e != 0 { return nil, err
return nil, error(e)
} }
return &State{st}, nil return &State{st}, nil
} }
@ -109,18 +59,23 @@ func GetState(fd int) (*State, error) {
// Restore restores the terminal connected to the given file descriptor to a // Restore restores the terminal connected to the given file descriptor to a
// previous state. // previous state.
func Restore(fd int, state *State) error { func Restore(fd int, state *State) error {
_, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0) return windows.SetConsoleMode(windows.Handle(fd), state.mode)
return err
} }
// GetSize returns the dimensions of the given terminal. // GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) { func GetSize(fd int) (width, height int, err error) {
var info consoleScreenBufferInfo var info windows.ConsoleScreenBufferInfo
_, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0) if err := windows.GetConsoleScreenBufferInfo(windows.Handle(fd), &info); err != nil {
if e != 0 { return 0, 0, err
return 0, 0, error(e)
} }
return int(info.size.x), int(info.size.y), nil return int(info.Size.X), int(info.Size.Y), nil
}
// passwordReader is an io.Reader that reads from a specific Windows HANDLE.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return windows.Read(windows.Handle(r), buf)
} }
// ReadPassword reads a line of input from a terminal without local echo. This // ReadPassword reads a line of input from a terminal without local echo. This
@ -128,47 +83,20 @@ func GetSize(fd int) (width, height int, err error) {
// returned does not include the \n. // returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) { func ReadPassword(fd int) ([]byte, error) {
var st uint32 var st uint32
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
if e != 0 { return nil, err
return nil, error(e)
} }
old := st old := st
st &^= (enableEchoInput) st &^= (windows.ENABLE_ECHO_INPUT)
st |= (enableProcessedInput | enableLineInput | enableProcessedOutput) st |= (windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) if err := windows.SetConsoleMode(windows.Handle(fd), st); err != nil {
if e != 0 { return nil, err
return nil, error(e)
} }
defer func() { defer func() {
syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) windows.SetConsoleMode(windows.Handle(fd), old)
}() }()
var buf [16]byte return readPasswordLine(passwordReader(fd))
var ret []byte
for {
n, err := syscall.Read(syscall.Handle(fd), buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
if n > 0 && buf[n-1] == '\r' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
} }

10
vendor/golang.org/x/sys/unix/asm.s generated vendored Normal file
View File

@ -0,0 +1,10 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !gccgo
#include "textflag.h"
TEXT ·use(SB),NOSPLIT,$0
RET

View File

@ -1,28 +1,29 @@
// Copyright 2015 The Go Authors. All rights reserved. // Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build linux
// +build mips64 mips64le
// +build !gccgo // +build !gccgo
#include "textflag.h" #include "textflag.h"
// //
// System calls for mips64, Linux // System call support for 386, FreeBSD
// //
// Just jump to package syscall's implementation for all these functions. // Just jump to package syscall's implementation for all these functions.
// The runtime may know about them. // The runtime may know about them.
TEXT ·Syscall(SB),NOSPLIT,$0-56 TEXT ·Syscall(SB),NOSPLIT,$0-32
JMP syscall·Syscall(SB) JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80 TEXT ·Syscall6(SB),NOSPLIT,$0-44
JMP syscall·Syscall6(SB) JMP syscall·Syscall6(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56 TEXT ·Syscall9(SB),NOSPLIT,$0-56
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-32
JMP syscall·RawSyscall(SB) JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80 TEXT ·RawSyscall6(SB),NOSPLIT,$0-44
JMP syscall·RawSyscall6(SB) JMP syscall·RawSyscall6(SB)

Some files were not shown because too many files have changed in this diff Show More